Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions numpy/_core/src/umath/string_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -1149,49 +1149,54 @@ string_lrstrip_whitespace(Buffer<enc> buf, Buffer<enc> out, STRIPTYPE striptype)
return 0;
}

size_t i = 0;
size_t new_start = 0;

size_t num_bytes = (buf.after - buf.buf);
Buffer traverse_buf = Buffer<enc>(buf.buf, num_bytes);

if (striptype != STRIPTYPE::RIGHTSTRIP) {
while (i < len) {
while (new_start < len) {
if (!traverse_buf.first_character_isspace()) {
break;
}
num_bytes -= traverse_buf.num_bytes_next_character();
traverse_buf++;
i++;
new_start++;
traverse_buf++; // may go one beyond buffer
}
}

npy_intp j = len - 1; // Could also turn negative if we're stripping the whole string
size_t new_stop = len; // New stop is a range (beyond last char)
if (enc == ENCODING::UTF8) {
traverse_buf = Buffer<enc>(buf.after, 0) - 1;
}
else {
traverse_buf = buf + j;
traverse_buf = buf + (new_stop - 1);
}

if (striptype != STRIPTYPE::LEFTSTRIP) {
while (j >= static_cast<npy_intp>(i)) {
while (new_stop > new_start) {
if (*traverse_buf != 0 && !traverse_buf.first_character_isspace()) {
break;
}

num_bytes -= traverse_buf.num_bytes_next_character();
traverse_buf--;
j--;
new_stop--;

// Do not step to character -1: can't find it's start for utf-8.
if (new_stop > 0) {
traverse_buf--;
}
}
}

Buffer offset_buf = buf + i;
Buffer offset_buf = buf + new_start;
if (enc == ENCODING::UTF8) {
offset_buf.buffer_memcpy(out, num_bytes);
return num_bytes;
}
offset_buf.buffer_memcpy(out, j - i + 1);
out.buffer_fill_with_zeros_after_index(j - i + 1);
return j - i + 1;
offset_buf.buffer_memcpy(out, new_stop - new_start);
out.buffer_fill_with_zeros_after_index(new_stop - new_start);
return new_stop - new_start;
}


Expand All @@ -1218,13 +1223,13 @@ string_lrstrip_chars(Buffer<enc> buf1, Buffer<enc> buf2, Buffer<enc> out, STRIPT
return len1;
}

size_t i = 0;
size_t new_start = 0;

size_t num_bytes = (buf1.after - buf1.buf);
Buffer traverse_buf = Buffer<enc>(buf1.buf, num_bytes);

if (striptype != STRIPTYPE::RIGHTSTRIP) {
while (i < len1) {
for (; new_start < len1; traverse_buf++) {
Py_ssize_t res;
switch (enc) {
case ENCODING::ASCII:
Expand All @@ -1245,21 +1250,20 @@ string_lrstrip_chars(Buffer<enc> buf1, Buffer<enc> buf2, Buffer<enc> out, STRIPT
break;
}
num_bytes -= traverse_buf.num_bytes_next_character();
traverse_buf++;
i++;
new_start++;
}
}

npy_intp j = len1 - 1;
size_t new_stop = len1; // New stop is a range (beyond last char)
if (enc == ENCODING::UTF8) {
traverse_buf = Buffer<enc>(buf1.after, 0) - 1;
}
else {
traverse_buf = buf1 + j;
traverse_buf = buf1 + (new_stop - 1);
}

if (striptype != STRIPTYPE::LEFTSTRIP) {
while (j >= static_cast<npy_intp>(i)) {
while (new_stop > new_start) {
Py_ssize_t res;
switch (enc) {
case ENCODING::ASCII:
Expand All @@ -1280,19 +1284,22 @@ string_lrstrip_chars(Buffer<enc> buf1, Buffer<enc> buf2, Buffer<enc> out, STRIPT
break;
}
num_bytes -= traverse_buf.num_bytes_next_character();
j--;
traverse_buf--;
new_stop--;
// Do not step to character -1: can't find it's start for utf-8.
if (new_stop > 0) {
traverse_buf--;
}
}
}

Buffer offset_buf = buf1 + i;
Buffer offset_buf = buf1 + new_start;
if (enc == ENCODING::UTF8) {
offset_buf.buffer_memcpy(out, num_bytes);
return num_bytes;
}
offset_buf.buffer_memcpy(out, j - i + 1);
out.buffer_fill_with_zeros_after_index(j - i + 1);
return j - i + 1;
offset_buf.buffer_memcpy(out, new_stop - new_start);
out.buffer_fill_with_zeros_after_index(new_stop - new_start);
return new_stop - new_start;
}

template <typename char_type>
Expand Down
13 changes: 9 additions & 4 deletions numpy/_core/tests/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,10 +467,12 @@ def test_endswith(self, a, suffix, start, end, out, dt):
])
def test_lstrip(self, a, chars, out, dt):
a = np.array(a, dtype=dt)
out = np.array(out, dtype=dt)
if chars is not None:
chars = np.array(chars, dtype=dt)
out = np.array(out, dtype=dt)
assert_array_equal(np.strings.lstrip(a, chars), out)
assert_array_equal(np.strings.lstrip(a, chars), out)
else:
assert_array_equal(np.strings.lstrip(a), out)

@pytest.mark.parametrize("a,chars,out", [
("", None, ""),
Expand All @@ -486,17 +488,20 @@ def test_lstrip(self, a, chars, out, dt):
("xyzzyhelloxyzzy", "xyz", "xyzzyhello"),
("hello", "xyz", "hello"),
("xyxz", "xyxz", ""),
(" ", None, ""),
("xyxzx", "x", "xyxz"),
(["xyzzyhelloxyzzy", "hello"], ["xyz", "xyz"],
["xyzzyhello", "hello"]),
(["ab", "ac", "aab", "abb"], "b", ["a", "ac", "aa", "a"]),
])
def test_rstrip(self, a, chars, out, dt):
a = np.array(a, dtype=dt)
out = np.array(out, dtype=dt)
if chars is not None:
chars = np.array(chars, dtype=dt)
out = np.array(out, dtype=dt)
assert_array_equal(np.strings.rstrip(a, chars), out)
assert_array_equal(np.strings.rstrip(a, chars), out)
else:
assert_array_equal(np.strings.rstrip(a), out)

@pytest.mark.parametrize("a,chars,out", [
("", None, ""),
Expand Down