Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions numpy/_core/src/multiarray/abstractdtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ npy_mark_tmp_array_if_pyscalar(
* a custom DType registered, and then we should use that.
* Further, `np.float64` is a double subclass, so must reject it.
*/
// TODO,NOTE: This function should be changed to do exact long checks
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an issue for this? Otherwise, we'll probably find this for numpy 2.13 or so...

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created a specific one that is milestoned.

// For NumPy 2.1!
if (PyLong_Check(obj)
&& (PyArray_ISINTEGER(arr) || PyArray_ISOBJECT(arr))) {
((PyArrayObject_fields *)arr)->flags |= NPY_ARRAY_WAS_PYTHON_INT;
Expand Down
24 changes: 17 additions & 7 deletions numpy/_core/src/umath/dispatching.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "common.h"
#include "npy_pycompat.h"

#include "arrayobject.h"
#include "dispatching.h"
#include "dtypemeta.h"
#include "npy_hashtable.h"
Expand All @@ -64,7 +65,7 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
PyArrayObject *const ops[],
PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *op_dtypes[],
npy_bool allow_legacy_promotion);
npy_bool legacy_promotion_is_possible);


/**
Expand Down Expand Up @@ -759,7 +760,7 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
PyArrayObject *const ops[],
PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *op_dtypes[],
npy_bool allow_legacy_promotion)
npy_bool legacy_promotion_is_possible)
{
/*
* Fetch the dispatching info which consists of the implementation and
Expand Down Expand Up @@ -828,7 +829,7 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
* However, we need to give the legacy implementation a chance here.
* (it will modify `op_dtypes`).
*/
if (!allow_legacy_promotion || ufunc->type_resolver == NULL ||
if (!legacy_promotion_is_possible || ufunc->type_resolver == NULL ||
(ufunc->ntypes == 0 && ufunc->userloops == NULL)) {
/* Already tried or not a "legacy" ufunc (no loop found, return) */
return NULL;
Expand Down Expand Up @@ -935,11 +936,11 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *op_dtypes[],
npy_bool force_legacy_promotion,
npy_bool allow_legacy_promotion,
npy_bool promoting_pyscalars,
npy_bool ensure_reduce_compatible)
{
int nin = ufunc->nin, nargs = ufunc->nargs;
npy_bool legacy_promotion_is_possible = NPY_TRUE;

/*
* Get the actual DTypes we operate with by setting op_dtypes[i] from
Expand All @@ -964,11 +965,20 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
*/
Py_CLEAR(op_dtypes[i]);
}
/*
* If the op_dtype ends up being a non-legacy one, then we cannot use
* legacy promotion (unless this is a python scalar).
*/
if (op_dtypes[i] != NULL && !NPY_DT_is_legacy(op_dtypes[i]) && (
signature[i] != NULL || // signature cannot be a pyscalar
!(PyArray_FLAGS(ops[i]) & NPY_ARRAY_WAS_PYTHON_LITERAL))) {
legacy_promotion_is_possible = NPY_FALSE;
}
}

int current_promotion_state = get_npy_promotion_state();

if (force_legacy_promotion
if (force_legacy_promotion && legacy_promotion_is_possible
&& current_promotion_state == NPY_USE_LEGACY_PROMOTION
&& (ufunc->ntypes != 0 || ufunc->userloops != NULL)) {
/*
Expand All @@ -986,7 +996,7 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
/* Pause warnings and always use "new" path */
set_npy_promotion_state(NPY_USE_WEAK_PROMOTION);
PyObject *info = promote_and_get_info_and_ufuncimpl(ufunc,
ops, signature, op_dtypes, allow_legacy_promotion);
ops, signature, op_dtypes, legacy_promotion_is_possible);
set_npy_promotion_state(current_promotion_state);

if (info == NULL) {
Expand Down Expand Up @@ -1032,7 +1042,7 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
Py_INCREF(signature[0]);
return promote_and_get_ufuncimpl(ufunc,
ops, signature, op_dtypes,
force_legacy_promotion, allow_legacy_promotion,
force_legacy_promotion,
promoting_pyscalars, NPY_FALSE);
}

Expand Down
1 change: 0 additions & 1 deletion numpy/_core/src/umath/dispatching.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *op_dtypes[],
npy_bool force_legacy_promotion,
npy_bool allow_legacy_promotion,
npy_bool promote_pyscalars,
npy_bool ensure_reduce_compatible);

Expand Down
30 changes: 30 additions & 0 deletions numpy/_core/src/umath/stringdtype_ufuncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,25 @@ all_strings_promoter(PyObject *NPY_UNUSED(ufunc),
PyArray_DTypeMeta *const signature[],
PyArray_DTypeMeta *new_op_dtypes[])
{
if ((op_dtypes[0] != &PyArray_StringDType &&
op_dtypes[1] != &PyArray_StringDType &&
op_dtypes[2] != &PyArray_StringDType)) {
/*
* This promoter was triggered with only unicode arguments, so use
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems confusing - should StringDType really decide the result for something that does not include itself? Shouldn't that be up to UnicodeDType? What happens if we return -1 here?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Things would just fail the operation, I am very sure I added it for a reason. The problem is that you would have to decide that UU->U is clearly better than UU->T. But there is nothing to decide that by, so the machinery prefers the promoter, since the promoter has the ability to resolve which one to actual use (also go to UU->U).

The other thing you might not like is that it matches at all, but that would need one of two new features (which is fine):

  • Always call promoters right away even if there may be much better matches and give them the ability to say "don't know" (i.e. a -1 return might be that without an error set).
  • Allow some heuristic for it like, "must contain this DType" as an additional dtype.
  • Allow "matching" via a second function.

The other solution for the particular case is that if there wasn't legacy promotion involved, I would like a default promoter that ensures that ufunc(..., dtype=X) will search for ufunc(..., signature=(X, X, X)) as well if there is otherwise no match. That would do the right thing here.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was probably good to get this in, but it still feels weird to have a promotor for a given dtype return a result that does not involve that type at all - how can it decide for another type what is acceptable? Two of your solutions sound reasonable: returning the equivalent of NotImplemented (your -1 with no error set), or a default promotor. Maybe worth a new issue?

* unicode. This can happen due to `dtype=` support which sets the
* output DType/signature.
*/
new_op_dtypes[0] = NPY_DT_NewRef(&PyArray_UnicodeDType);
new_op_dtypes[1] = NPY_DT_NewRef(&PyArray_UnicodeDType);
new_op_dtypes[2] = NPY_DT_NewRef(&PyArray_UnicodeDType);
return 0;
}
if ((signature[0] == &PyArray_UnicodeDType &&
signature[1] == &PyArray_UnicodeDType &&
signature[2] == &PyArray_UnicodeDType)) {
/* Unicode forced, but didn't override a string input: invalid */
return -1;
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part makes me wonder if I should just check it after the promoter is done and invalidate the result if this is violated. But it is OK here also.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree it would be better to enforce that there, if only because IMO DType authors shouldn't have to worry about that case or add code to account for it to write a correct DType.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, it seems strange one would even get here if the signature is already clear that StringDType should not be involved.

}
new_op_dtypes[0] = NPY_DT_NewRef(&PyArray_StringDType);
new_op_dtypes[1] = NPY_DT_NewRef(&PyArray_StringDType);
new_op_dtypes[2] = NPY_DT_NewRef(&PyArray_StringDType);
Expand Down Expand Up @@ -2532,6 +2551,17 @@ init_stringdtype_ufuncs(PyObject *umath)
return -1;
}

PyArray_DTypeMeta *out_strings_promoter_dtypes[] = {
&PyArray_UnicodeDType,
&PyArray_UnicodeDType,
&PyArray_StringDType,
};

if (add_promoter(umath, "add", out_strings_promoter_dtypes, 3,
all_strings_promoter) < 0) {
return -1;
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!


INIT_MULTIPLY(Int64, int64);
INIT_MULTIPLY(UInt64, uint64);

Expand Down
38 changes: 12 additions & 26 deletions numpy/_core/src/umath/ufunc_object.c
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ static int
convert_ufunc_arguments(PyUFuncObject *ufunc,
ufunc_full_args full_args, PyArrayObject *out_op[],
PyArray_DTypeMeta *out_op_DTypes[],
npy_bool *force_legacy_promotion, npy_bool *allow_legacy_promotion,
npy_bool *force_legacy_promotion,
npy_bool *promoting_pyscalars,
PyObject *order_obj, NPY_ORDER *out_order,
PyObject *casting_obj, NPY_CASTING *out_casting,
Expand All @@ -622,7 +622,6 @@ convert_ufunc_arguments(PyUFuncObject *ufunc,
/* Convert and fill in input arguments */
npy_bool all_scalar = NPY_TRUE;
npy_bool any_scalar = NPY_FALSE;
*allow_legacy_promotion = NPY_TRUE;
*force_legacy_promotion = NPY_FALSE;
*promoting_pyscalars = NPY_FALSE;
for (int i = 0; i < nin; i++) {
Expand Down Expand Up @@ -657,11 +656,6 @@ convert_ufunc_arguments(PyUFuncObject *ufunc,
break;
}

if (!NPY_DT_is_legacy(out_op_DTypes[i])) {
*allow_legacy_promotion = NPY_FALSE;
// TODO: A subclass of int, float, complex could reach here and
// it should not be flagged as "weak" if it does.
}
if (PyArray_NDIM(out_op[i]) == 0) {
any_scalar = NPY_TRUE;
}
Expand Down Expand Up @@ -707,7 +701,7 @@ convert_ufunc_arguments(PyUFuncObject *ufunc,
*promoting_pyscalars = NPY_TRUE;
}
}
if (*allow_legacy_promotion && (!all_scalar && any_scalar)) {
if ((!all_scalar && any_scalar)) {
*force_legacy_promotion = should_use_min_scalar(nin, out_op, 0, NULL);
}

Expand Down Expand Up @@ -2351,8 +2345,7 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc,
}

PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc,
ops, signature, operation_DTypes, NPY_FALSE, NPY_TRUE,
NPY_FALSE, NPY_TRUE);
ops, signature, operation_DTypes, NPY_FALSE, NPY_FALSE, NPY_TRUE);
if (evil_ndim_mutating_hack) {
((PyArrayObject_fields *)out)->nd = 0;
}
Expand Down Expand Up @@ -4433,13 +4426,12 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc,
npy_bool subok = NPY_TRUE;
int keepdims = -1; /* We need to know if it was passed */
npy_bool force_legacy_promotion;
npy_bool allow_legacy_promotion;
npy_bool promoting_pyscalars;
if (convert_ufunc_arguments(ufunc,
/* extract operand related information: */
full_args, operands,
operand_DTypes,
&force_legacy_promotion, &allow_legacy_promotion,
&force_legacy_promotion,
&promoting_pyscalars,
/* extract general information: */
order_obj, &order,
Expand All @@ -4460,7 +4452,7 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc,
*/
PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc,
operands, signature,
operand_DTypes, force_legacy_promotion, allow_legacy_promotion,
operand_DTypes, force_legacy_promotion,
promoting_pyscalars, NPY_FALSE);
if (ufuncimpl == NULL) {
goto fail;
Expand Down Expand Up @@ -5790,22 +5782,20 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
operand_DTypes[0] = NPY_DTYPE(PyArray_DESCR(op1_array));
Py_INCREF(operand_DTypes[0]);
int force_legacy_promotion = 0;
int allow_legacy_promotion = NPY_DT_is_legacy(operand_DTypes[0]);

if (op2_array != NULL) {
tmp_operands[1] = op2_array;
operand_DTypes[1] = NPY_DTYPE(PyArray_DESCR(op2_array));
Py_INCREF(operand_DTypes[1]);
allow_legacy_promotion &= NPY_DT_is_legacy(operand_DTypes[1]);
tmp_operands[2] = tmp_operands[0];
operand_DTypes[2] = operand_DTypes[0];
Py_INCREF(operand_DTypes[2]);

if (allow_legacy_promotion && ((PyArray_NDIM(op1_array) == 0)
!= (PyArray_NDIM(op2_array) == 0))) {
/* both are legacy and only one is 0-D: force legacy */
force_legacy_promotion = should_use_min_scalar(2, tmp_operands, 0, NULL);
}
if ((PyArray_NDIM(op1_array) == 0)
!= (PyArray_NDIM(op2_array) == 0)) {
/* both are legacy and only one is 0-D: force legacy */
force_legacy_promotion = should_use_min_scalar(2, tmp_operands, 0, NULL);
}
}
else {
tmp_operands[1] = tmp_operands[0];
Expand All @@ -5816,7 +5806,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)

ufuncimpl = promote_and_get_ufuncimpl(ufunc, tmp_operands, signature,
operand_DTypes, force_legacy_promotion,
allow_legacy_promotion, NPY_FALSE, NPY_FALSE);
NPY_FALSE, NPY_FALSE);
if (ufuncimpl == NULL) {
for (int i = 0; i < 3; i++) {
Py_XDECREF(signature[i]);
Expand Down Expand Up @@ -6058,7 +6048,6 @@ py_resolve_dtypes_generic(PyUFuncObject *ufunc, npy_bool return_context,
set_npy_promotion_state(NPY_USE_WEAK_PROMOTION);

npy_bool promoting_pyscalars = NPY_FALSE;
npy_bool allow_legacy_promotion = NPY_TRUE;

if (_get_fixed_signature(ufunc, NULL, signature_obj, signature) < 0) {
goto finish;
Expand Down Expand Up @@ -6091,9 +6080,6 @@ py_resolve_dtypes_generic(PyUFuncObject *ufunc, npy_bool return_context,
}
DTypes[i] = NPY_DTYPE(descr);
Py_INCREF(DTypes[i]);
if (!NPY_DT_is_legacy(DTypes[i])) {
allow_legacy_promotion = NPY_FALSE;
}
}
/* Explicitly allow int, float, and complex for the "weak" types. */
else if (descr_obj == (PyObject *)&PyLong_Type) {
Expand Down Expand Up @@ -6149,7 +6135,7 @@ py_resolve_dtypes_generic(PyUFuncObject *ufunc, npy_bool return_context,
if (!reduction) {
ufuncimpl = promote_and_get_ufuncimpl(ufunc,
dummy_arrays, signature, DTypes, NPY_FALSE,
allow_legacy_promotion, promoting_pyscalars, NPY_FALSE);
promoting_pyscalars, NPY_FALSE);
if (ufuncimpl == NULL) {
goto finish;
}
Expand Down
25 changes: 25 additions & 0 deletions numpy/_core/tests/test_stringdtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,31 @@ def test_add_promoter(string_list):
assert_array_equal(op + arr, lresult)
assert_array_equal(arr + op, rresult)

# The promoter should be able to handle things if users pass `dtype=`
res = np.add("hello", string_list, dtype=StringDType)
Copy link
Copy Markdown
Member

@ngoldbaum ngoldbaum Jun 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not worth using the dtype fixture for this since na_object and coerce doesn't matter, but maybe worth making dtype a parameter of the test that can either by StringDType or StringDType(). I could also see perhaps defining a dtype_lass_or_instance fixture and using that in a few other places in this file where we just test with StringDType() or "T".

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That doesn't work. Signatures are DType classes. It should work at least for "T", but otherwise need to look into logic to say "this is OK".
I am not sure it should be, but it is a different issue in either case.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah thanks for explaining. I saw that error before but thought this change made dtype instances OK.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, not hat happens much earlier, I explicitly allowed the singleton instancs of legacy dtypes (or maybe all singleton instances, not sure), because otherwise things would be tricky.

But, we have the "give me the DType" now also, which maybe (not sure!) makes T work. It should make T work in either case, though.

assert res.dtype == StringDType()

# The promoter should not kick in if users override the input,
# which means arr is cast, this fails because of the unknown length.
with pytest.raises(TypeError, match="cannot cast dtype"):
np.add(arr, "add", signature=("U", "U", None), casting="unsafe")

# But it must simply reject the following:
with pytest.raises(TypeError, match=".*did not contain a loop"):
np.add(arr, "add", signature=(None, "U", None))

with pytest.raises(TypeError, match=".*did not contain a loop"):
np.add("a", "b", signature=("U", "U", StringDType))


def test_add_no_legacy_promote_with_signature():
# Possibly misplaced, but useful to test with string DType. We check that
# if there is clearly no loop found, a stray `dtype=` doesn't break things
# Regression test for the bad error in gh-26735
# (If legacy promotion is gone, this can be deleted...)
with pytest.raises(TypeError, match=".*did not contain a loop"):
np.add("3", 6, dtype=StringDType)


def test_add_promoter_reduce():
# Exact TypeError could change, but ensure StringDtype doesn't match
Expand Down