import logging
import sys
import warnings
from collections.abc import Callable, Sequence
from itertools import chain, groupby, zip_longest
from typing import TypeVar, cast, overload
import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple
import pytensor
from pytensor import scalar as ps
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, disconnected_type
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant, ScalarVariable
from pytensor.tensor import (
TensorLike,
_get_vector_length,
as_tensor_variable,
get_vector_length,
)
from pytensor.tensor.basic import (
ScalarFromTensor,
alloc,
get_scalar_constant_value,
nonzero,
)
from pytensor.tensor.basic import (
constant as tensor_constant,
)
from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import add, clip
from pytensor.tensor.shape import (
Reshape,
Shape_i,
specify_broadcastable,
)
from pytensor.tensor.type import (
TensorType,
complex_dtypes,
discrete_dtypes,
integer_dtypes,
tensor,
)
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.variable import TensorConstant, TensorVariable
from pytensor.utils import unzip
_logger = logging.getLogger("pytensor.tensor.subtensor")
T = TypeVar("T")
def flatten_index_variables(
idx_vars: Sequence[T | None | slice],
) -> tuple[list[int | slice], list[T]]:
counter = 0
idx_list: list[int | slice] = []
flat_vars = []
for idx_var in idx_vars:
if isinstance(idx_var, slice):
slice_idx_list: list[None | int] = []
for arg_entry in (idx_var.start, idx_var.stop, idx_var.step):
if arg_entry is None or (
isinstance(arg_entry, Variable)
and isinstance(arg_entry.type, NoneTypeT)
):
slice_idx_list.append(None)
else:
flat_vars.append(arg_entry)
slice_idx_list.append(counter)
counter += 1
idx_list.append(slice(*slice_idx_list))
else:
flat_vars.append(idx_var)
idx_list.append(counter)
counter += 1
return idx_list, flat_vars
def unflatten_index_variables(
flat_indices: Sequence[T],
idx_list: Sequence[slice | int],
) -> tuple[slice | T, ...]:
indices: list[T | slice] = []
for idx_entry in idx_list:
if isinstance(idx_entry, int):
indices.append(flat_indices[idx_entry])
elif isinstance(idx_entry, slice):
start, stop, step = idx_entry.start, idx_entry.stop, idx_entry.step
indices.append(
slice(
None if idx_entry.start is None else flat_indices[start],
None if idx_entry.stop is None else flat_indices[stop],
None if idx_entry.step is None else flat_indices[step],
)
)
else:
raise ValueError(f"idx_entry must be int or slice, got {type(idx_entry)}")
return tuple(indices)
def indices_from_subtensor(
op_indices: Sequence[Variable],
idx_list: tuple[slice | int, ...],
) -> tuple[slice | Variable, ...]:
"""Recreate the index tuple from which a ``*Subtensor**`` ``Op`` was created.
Parameters
----------
op_indices
The flattened indices obtained from ``x.inputs``, when ``x`` is a ``*Subtensor*`` node.
idx_list
The values describing each dimension's index. This is obtained from
``op.idx_list``. Entries can be:
- Integer positions (indices into op_indices)
- slice objects with int/None components
Returns
-------
tuple[slice | Variable, ...]
A tuple containing a mix of ``slice`` objects and ``Variable`` objects.
Each element corresponds to one indexing dimension:
- ``slice`` objects for slice-based indexing (e.g., ``x[1:3]``)
- ``Variable`` objects for scalar or array-based indexing
Callers should handle both types when iterating over the result.
Example
-------
array, *op_indices = subtensor_node.inputs
indices = indices_from_subtensor(op_indices, subtensor_node.op.idx_list)
"""
return unflatten_index_variables(op_indices, idx_list)
def as_index_constant(
a: slice | int | np.integer | Variable | None | TensorLike,
) -> Variable | slice | None:
r"""Convert Python literals to PyTensor constants--when possible--in `Subtensor` arguments.
This will leave `Variable`\s untouched.
"""
if a is None:
return a
elif isinstance(a, slice):
return slice(
as_index_constant(a.start),
as_index_constant(a.stop),
as_index_constant(a.step),
)
elif isinstance(a, int | np.integer):
return ps.ScalarConstant(ps.int64, a)
elif isinstance(a, Variable):
return a
return as_tensor_variable(a)
@overload
def as_index_literal(idx: int | np.integer) -> int | np.integer: ...
@overload
def as_index_literal(idx: None) -> None: ...
@overload
def as_index_literal(idx: slice) -> slice: ...
@overload
def as_index_literal(idx: ScalarConstant | TensorConstant) -> int | np.integer: ...
@overload
def as_index_literal(idx: Variable): ...
def as_index_literal(
idx: None | int | np.integer | slice | ScalarConstant | TensorConstant | Variable,
) -> int | np.integer | slice | None:
"""Convert a symbolic index element to its Python equivalent.
This is like the inverse of `as_index_constant`
Raises
------
NotScalarConstantError
"""
if idx is None or isinstance(idx, int | np.integer):
return idx
if isinstance(idx, slice):
return slice(
as_index_literal(idx.start),
as_index_literal(idx.stop),
as_index_literal(idx.step),
)
if not isinstance(idx, Variable):
raise TypeError(f"Not an index element: {idx}")
if isinstance(idx, ScalarConstant):
return cast(int, idx.data)
if (
isinstance(idx.type, ps.ScalarType)
and idx.owner
and isinstance(idx.owner.op, ScalarFromTensor)
):
return cast(int | np.integer, as_index_literal(idx.owner.inputs[0]))
if isinstance(idx, TensorConstant):
return cast(int, idx.data.item())
# Other kinds of variables are not supported
raise NotScalarConstantError()
def get_idx_list(inputs, idx_list):
return indices_from_subtensor(inputs[1:], idx_list)
@overload
def get_canonical_form_slice(
theslice: slice,
length: int | np.integer | ScalarVariable | TensorVariable,
) -> tuple[slice, int | TensorVariable]: ...
@overload
def get_canonical_form_slice(
theslice: int | np.integer | ScalarVariable | TensorVariable,
length: int | np.integer | ScalarVariable | TensorVariable,
) -> tuple[TensorVariable, int]: ...
def get_canonical_form_slice(
theslice: slice | int | np.integer | ScalarVariable | TensorVariable,
length: int | np.integer | ScalarVariable | TensorVariable,
) -> tuple[slice | TensorVariable, int | TensorVariable]:
"""Convert indices or slices to canonical form.
Handles slice objects with ScalarVariable (including ScalarConstant) or None components.
Vector indices and advanced indexing operations are handled separately by AdvancedSubtensor.
Given a slice [start:stop:step] transform it into a canonical form
that respects the conventions imposed by python and numpy.
In a canonical form a slice is represented by a canonical form slice,
in which 0 <= start <= stop <= length and step > 0, and a flag which says
if the resulting set of numbers needs to be reversed or not.
Given a scalar index `idx` that may or not be negative, convert it to
a certainly positive form `idx if idx >= 0 else length + idx`.
Returns
-------
slc
Canonical form slice or scalar variable.
direction
Direction to iterate the resulting elements in. (-1 or 1). May be symbolic.
"""
from pytensor.tensor import ge, lt, sign, switch
def undo_scalarization(x):
"""Undo scalarization of a variable.
PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
But reasoning symbolically about the result of multiple indexing operations, we usually
want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
"""
if isinstance(x, ScalarVariable):
if isinstance(x, ScalarConstant):
return tensor_constant(x.data, dtype=x.dtype)
elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
return x.owner.inputs[0]
else:
return as_tensor_variable(x)
return x
def analyze(x):
try:
x_constant = as_index_literal(x)
is_constant = True
except NotScalarConstantError:
x_constant = undo_scalarization(x)
is_constant = False
return x_constant, is_constant
length, is_length_constant = analyze(length)
# Other non-slice types are the scalar indexing case
if not isinstance(theslice, slice):
if not (
isinstance(theslice, int | np.integer | ScalarVariable)
or (isinstance(theslice, TensorVariable) and theslice.ndim == 0)
):
raise ValueError(f"Slice {theslice} is not a supported slice type.")
idx, is_index_constant = analyze(theslice)
if is_index_constant:
if idx >= 0:
return idx, 1
else:
return idx + length, 1
else:
return switch(lt(idx, 0), idx + length, idx), 1
# At this point we have a slice object. Possibly with symbolic inputs.
start, is_start_constant = analyze(theslice.start)
stop, is_stop_constant = analyze(theslice.stop)
step, is_step_constant = analyze(theslice.step)
if (
is_start_constant
and is_stop_constant
and is_step_constant
and is_length_constant
):
assert isinstance(length, int | np.integer)
_start, _stop, _step = slice(start, stop, step).indices(length)
if _start <= _stop and _step >= 1:
return slice(_start, _stop, _step), 1
if step is None:
step = 1
is_step_constant = True
# First handle the easier and common case where `step` is 1 and
# either `start` or `stop` is a range boundary. More specializations
# could be added later. This makes the resulting graph smaller than
# in the generic case below.
if step == 1:
is_start_0 = (
start is None
or start == 0
or (
is_start_constant
and is_length_constant
and start < 0
and start + length <= 0
)
)
is_stop_length = (
stop is None
or stop in [length, sys.maxsize]
or (is_stop_constant and is_length_constant and stop >= length)
)
if is_start_0:
# 0:stop:1
if is_stop_length:
# Full slice.
return slice(0, length, 1), 1
if is_stop_constant and stop >= 0:
return (slice(0, switch(lt(stop, length), stop, length), 1), 1)
stop_plus_len = stop + length
stop = switch(
lt(stop, 0),
# stop < 0
switch(
lt(stop_plus_len, 0),
# stop + len < 0
0,
# stop + len >= 0
stop_plus_len,
),
# stop >= 0: use min(stop, length)
switch(lt(stop, length), stop, length),
)
return slice(0, stop, 1), 1
elif is_stop_length:
# start:length:1
if is_start_constant and start >= 0:
return slice(switch(lt(start, length), start, length), length, 1), 1
start_plus_len = start + length
start = switch(
lt(start, 0),
# start < 0
switch(
lt(start_plus_len, 0),
# start + len < 0
0,
# start + len >= 0
start_plus_len,
),
# start >= 0: use min(start, length)
switch(lt(start, length), start, length),
)
return slice(start, length, 1), 1
# This is the generic case.
if is_step_constant:
# When we know the sign of `step`, the graph can be made simpler.
assert step != 0
if step > 0:
def switch_neg_step(a, b):
return b
abs_step = step
sgn_step = 1
else:
def switch_neg_step(a, b):
return a
abs_step = -step
sgn_step = -1
else:
is_step_neg = lt(step, 0)
def switch_neg_step(a, b):
return switch(is_step_neg, a, b)
abs_step = abs(step)
sgn_step = sign(step)
defstart = switch_neg_step(length - 1, 0)
defstop = switch_neg_step(-1, length)
if start is None:
start = defstart
else:
start = switch(lt(start, 0), start + length, start)
start = switch(lt(start, 0), switch_neg_step(-1, 0), start)
start = switch(ge(start, length), switch_neg_step(length - 1, length), start)
if stop is None or stop == sys.maxsize:
# The special "maxsize" case is probably not needed here,
# as slices containing maxsize are not generated by
# __getslice__ anymore.
stop = defstop
else:
stop = switch(lt(stop, 0), stop + length, stop)
stop = switch(lt(stop, 0), -1, stop)
stop = switch(ge(stop, length), length, stop)
nw_stop = switch_neg_step(start + 1, stop)
slice_len = (start - stop - 1) // abs_step + 1
slice_len = switch(lt(slice_len, 0), 0, slice_len)
neg_start = nw_stop - (slice_len - 1) * abs_step - 1
neg_start = switch(lt(neg_start, 0), (nw_stop - 1), neg_start)
nw_start = switch_neg_step(neg_start, start)
nw_start = switch(lt(nw_start, 0), 0, nw_start)
nw_stop = switch(lt(nw_stop, 0), 0, nw_stop)
# Ensure start <= stop.
nw_start = switch(lt(nw_start, nw_stop), nw_start, nw_stop)
nw_step = abs_step
if step != 1:
reverse = sgn_step
return slice(nw_start, nw_stop, nw_step), reverse
else:
return slice(nw_start, nw_stop, nw_step), 1
def slice_len(slc, n):
"""Compute the length of a slice for an array of a given length.
We're essentially computing `len(range(*slc.indices(n)))`.
Adapted from CPython.
"""
from pytensor.tensor import and_, gt, lt, switch
# TODO: Do we need to do this or should we expect `slc` to already be canonicalized?
canon_slc, _ = get_canonical_form_slice(slc, n)
start, stop, step = tuple(
as_index_constant(a) for a in [canon_slc.start, canon_slc.stop, canon_slc.step]
)
return switch(
and_(gt(step, 0), lt(start, stop)),
1 + (stop - 1 - start) // step,
switch(
and_(lt(step, 0), gt(start, stop)),
1 + (start - 1 - stop) // (-step),
ps.ScalarConstant(ps.int64, 0),
),
)
def basic_shape(shape, indices):
r"""Computes the shape resulting from basic NumPy indexing.
Basic indices are either ``slice``\s or ``None``\s. ``Ellipsis`` are not
supported here; convert them to ``slice``\s first.
Parameters
----------
shape: Tuple[int, ...]
The shape of the array being indexed
indices: Sequence[Or[slice, NoneType]]
A sequence of basic indices used to index an array.
"""
res_shape = ()
for n, idx in zip(shape[: len(indices)], indices, strict=True):
if isinstance(idx, slice):
res_shape += (slice_len(idx, n),)
elif idx is None:
res_shape += (ps.ScalarConstant(ps.int64, 1),)
else:
raise ValueError(f"Invalid index type: {idx}")
return res_shape
def group_indices(indices):
"""Group indices sequentially by whether or not they're basic or advanced.
Returns
-------
Tuple[Boolean, List[Tuple[Integer, Any]]]
The boolean indicates whether or not the group is a set of basic
indices. The list contains the contiguous set of indices paired with their
corresponding dimension number in the array being indexed.
"""
idx_groups = []
dim_num = -1
for basic, grp_indices in groupby(indices, key=lambda x: isinstance(x, slice)):
enum_grp_indices = []
for idx in grp_indices:
# We "zip" the dimension number to each index, which means we can't
# count indices that add new axes
if idx is not None:
dim_num += 1
enum_grp_indices.append((dim_num, idx))
idx_groups.append((basic, enum_grp_indices))
return idx_groups
def _non_consecutive_adv_indexing(indices) -> bool:
"""Check if the advanced indexing is non-consecutive (i.e., split by basic indexing)."""
idx_groups = group_indices(indices)
# This means that there are at least two groups of advanced indexing separated by basic indexing
return len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0])
def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
"""Compute the symbolic shape resulting from `a[indices]` for `a.shape == array_shape`.
This function uses NumPy's basic and advanced indexing logic. It can also
handle combinations of advanced and basic indices.
Parameters
----------
array_shape: Tuple[Variable, ...]
Shape of the array being indexed.
indices: Sequence[Union[TensorVariable, Tuple[Union[None, slice, Variable], ...]]]
Either the indices themselves or the shapes of each index--depending
on the value of `indices_are_shapes`.
indices_are_shapes: bool (Optional)
Indicates whether or not the `indices` contains shape tuples instead of
the actual index arrays. If you use this approach, make sure that the
broadcastable dimensions are (scalar) constants with the value `1`, or `1`
exactly.
"""
res_shape = ()
remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
idx_groups = group_indices(indices)
if _non_consecutive_adv_indexing(indices):
# In this case NumPy places the advanced index groups in the front of the array
# https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
idx_groups = sorted(idx_groups, key=lambda x: x[0])
idx_groups = groupby(
chain.from_iterable(d_idx for _, d_idx in idx_groups),
key=lambda x: isinstance(x[1], slice),
)
for basic, grp_dim_indices in idx_groups:
dim_nums, grp_indices = unzip(grp_dim_indices, n=2, strict=True)
remaining_dims = tuple(dim for dim in remaining_dims if dim not in dim_nums)
if basic:
grp_shapes = tuple(array_shape[dim] for dim in dim_nums)
res_shape += basic_shape(grp_shapes, grp_indices)
else:
from pytensor.tensor.extra_ops import broadcast_shape
res_shape += broadcast_shape(
*grp_indices,
arrays_are_shapes=indices_are_shapes,
# The AdvancedIndexing Op relies on the Numpy implementation which allows runtime broadcasting.
# As long as that is true, the shape inference has to respect that this is not an error.
allow_runtime_broadcast=True,
)
res_shape += tuple(array_shape[dim] for dim in remaining_dims)
return res_shape
def get_slice_elements(
idxs: Sequence,
cond: Callable = lambda x: isinstance(x, Variable),
) -> list:
"""Extract slice elements conditional on a given predicate function.
Parameters
----------
idxs : a list of indices or slices.
cond : a callable that returns a bool
Returns
-------
list
idxs, with the slices flattened out into a list.
If cond is true for an entry, does not flatten it.
"""
ret = []
def helper(entry):
if cond(entry):
ret.append(entry)
elif isinstance(entry, slice):
helper(entry.start)
helper(entry.stop)
helper(entry.step)
for idx in idxs:
helper(idx)
return ret
def get_constant_idx(
idx_list, inputs, allow_partial=False, only_process_constants=False, elemwise=True
):
r"""Return an `idx_list` with its constant inputs replaced by their Python scalar equivalents.
May raise `NotScalarConstantError` if the indices contain non-constant entries.
If `allow_partial` is ``True``, then entries that are not constant will
stay as their input variable rather than raising an exception.
``None`` entries are always left as-is.
Parameters
----------
only_process_constants
If ``True``, we only attempt to obtain the value of an index/slice if
it's directly constant and don't try to dig through `DimShuffle`\s,
fills, `Alloc`\s, and other to figure out its value.
Examples
--------
Example usage where `v` and `a` are appropriately typed PyTensor variables :
>>> from pytensor.scalar import int64
>>> from pytensor.tensor import matrix
>>> import numpy as np
>>>
>>> v = int64("v")
>>> a = matrix("a")
>>> b = a[v, 1:3]
>>> b.owner.op.idx_list
(0, slice(1, 2, None))
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
[v, slice(1, 3, None)]
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
Traceback (most recent call last):
pytensor.tensor.exceptions.NotScalarConstantError
"""
real_idx = get_idx_list(inputs, idx_list)
# TODO: Combine this with `as_index_literal`
def conv(val):
if val is None:
return None
elif isinstance(val, slice):
return slice(conv(val.start), conv(val.stop), conv(val.step))
else:
try:
return get_scalar_constant_value(
val,
only_process_constants=only_process_constants,
elemwise=elemwise,
).item()
except NotScalarConstantError:
if allow_partial:
return val
else:
raise
return list(map(conv, real_idx))
def as_scalar_index_variable(idx) -> ps.ScalarVariable:
idx = ps.as_scalar(idx)
if idx.type.dtype not in integer_dtypes:
raise TypeError("basic indices must be integers")
return idx # type: ignore[no-any-return]
def slice_static_length(slc, dim_length):
if dim_length is None:
# TODO: Some cases must be zero by definition, we could handle those
return None
entries = [None, None, None]
for i, entry in enumerate((slc.start, slc.stop, slc.step)):
if entry is None:
continue
try:
entries[i] = get_scalar_constant_value(entry)
except NotScalarConstantError:
return None
return len(range(*slice(*entries).indices(dim_length)))
class BaseSubtensor:
"""Base class for Subtensor operations that handles idx_list and hash/equality."""
def __init__(self, idx_list: Sequence[int | slice]):
index_counter = -1
for idx_entry in idx_list:
if isinstance(idx_entry, int):
if idx_entry != (index_counter + 1):
raise ValueError(
f"idx_list entries should have consecutive integers, got {idx_list}"
)
index_counter = idx_entry
elif isinstance(idx_entry, slice):
for slice_idx_entry in (
idx_entry.start,
idx_entry.stop,
idx_entry.step,
):
if slice_idx_entry is not None:
if not isinstance(slice_idx_entry, int):
raise ValueError(
f"idx_list slice entries must be None or integer, got {slice_idx_entry} in {idx_entry}"
)
if slice_idx_entry != (index_counter + 1):
raise ValueError(
f"idx_list entries should have consecutive integers, got {idx_list}"
)
index_counter = slice_idx_entry
else:
raise ValueError(
f"idx_list entries must be int or slice, got {idx_entry}"
)
self.n_index_vars = index_counter + 1
self.idx_list = tuple(idx_list)
def _hashable_idx_list(self):
"""Return a hashable version of idx_list (slices converted to tuples).
Slices are not hashable in Python < 3.12, so we convert them to tuples.
"""
return tuple(
(slice, entry.start, entry.stop, entry.step)
if isinstance(entry, slice)
else entry
for entry in self.idx_list
)
def __hash__(self):
# Temporary workaround: slices are hashable in Python 3.12+
props_values = tuple(
self._hashable_idx_list() if prop == "idx_list" else getattr(self, prop)
for prop in self.__props__
)
return hash((type(self), props_values))
class Subtensor(BaseSubtensor, COp):
"""Basic NumPy indexing operator."""
check_input = False
view_map = {0: [0]}
_f16_ok = True
__props__ = ("idx_list",)
__hash__ = BaseSubtensor.__hash__
def make_node(self, x, *inputs):
"""
Parameters
----------
x
The tensor to take a subtensor of.
inputs
A list of pytensor Scalars.
"""
x = as_tensor_variable(x)
inputs = tuple(as_scalar_index_variable(a) for a in inputs)
idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim:
raise IndexError("too many indices for array")
input_positions = get_slice_elements(
idx_list, lambda entry: isinstance(entry, int)
)
assert len(inputs) == len(input_positions)
padded = [
*indices_from_subtensor(inputs, self.idx_list),
*[slice(None, None, None)] * (x.type.ndim - len(idx_list)),
]
out_shape = [
slice_static_length(slc, length)
for slc, length in zip(padded, x.type.shape, strict=True)
if isinstance(slc, slice)
]
return Apply(
self,
(x, *inputs),
[tensor(dtype=x.type.dtype, shape=out_shape)],
)
def perform(self, node, inputs, out_):
(out,) = out_
x, *index_variables = inputs
cdata = unflatten_index_variables(index_variables, self.idx_list)
out[0] = np.asarray(x.__getitem__(tuple(cdata)))
def infer_shape(self, fgraph, node, shapes):
def _is_constant(const, x):
return isinstance(const, Constant) and const.data.item() == x
xshp = shapes[0]
assert len(xshp) == node.inputs[0].ndim
outshp = []
actual_idx_list = list(get_idx_list(node.inputs, self.idx_list))
padded = actual_idx_list + [slice(None, None, None)] * (
len(xshp) - len(self.idx_list)
)
i = 0
for idx, xl in zip(padded, xshp, strict=True):
if isinstance(idx, slice):
# If it is the default (None, None, None) slice, or a variant,
# the shape will be xl
if (
(idx.start is None or _is_constant(idx.start, 0))
and (idx.stop is None or _is_constant(idx.stop, sys.maxsize))
and (idx.step is None or _is_constant(idx.step, 1))
):
outshp.append(xl)
elif (
(idx.start is None)
and (idx.stop is None)
and _is_constant(idx.step, -1)
):
# Reverse slice
outshp.append(xl)
else:
cnf = get_canonical_form_slice(idx, xl)[0]
if cnf.step == 1:
length = cnf.stop - cnf.start
else:
length = (cnf.stop - cnf.start - 1) // cnf.step + 1
outshp.append(length)
i += 1
else:
# That dimension is dropped
pass
assert i == node.outputs[0].ndim
assert len(outshp) == node.outputs[0].ndim
return [outshp]
def pullback(self, inputs, outputs, output_grads):
(gz,) = output_grads
x, *index_variables = inputs
if x.dtype in discrete_dtypes:
first = x.zeros_like(dtype=config.floatX)
else:
# For best optimization, we let this as an inc.
# This allow the opt local_IncSubtensor_serialize to apply first.
# We have an optimization that will convert this to a
# set subtensor here at:
# pytensor/tensor/opt.py:local_incsubtensor_of_zeros_to_setsubtensor()
first = IncSubtensor(self.idx_list)(x.zeros_like(), gz, *index_variables)
return [first, *(disconnected_type() for _ in range(len(index_variables)))]
def connection_pattern(self, node):
_x, *index_variables = node.inputs
rval = [[True], *([False] for _ in index_variables)]
return rval
@staticmethod
def str_from_slice(entry):
if entry.step is not None:
return ":".join(
(
"start" if entry.start is not None else "",
"stop" if entry.stop is not None else "",
"step",
)
)
if entry.stop is not None:
return f"{'start' if entry.start is not None else ''}:stop"
if entry.start is not None:
return "start:"
return ":"
@staticmethod
def str_from_indices(idx_list):
indices = []
letter_indexes = 0
for entry in idx_list:
if isinstance(entry, slice):
indices.append(Subtensor.str_from_slice(entry))
else:
indices.append("ijk"[letter_indexes % 3] * (letter_indexes // 3 + 1))
letter_indexes += 1
return ", ".join(indices)
def __str__(self):
return f"{self.__class__.__name__}{{{self.str_from_indices(self.idx_list)}}}"
@staticmethod
def default_helper_c_code_args():
"""
Returns a dictionary of default arguments to helper_c_code.
"""
return {"c_prefix": "PyArray", "strides_mul": 1}
@staticmethod
def helper_c_code(
node,
name,
inputs,
outputs,
sub,
idx_list,
view_ndim,
c_prefix=None,
strides_mul=None,
):
"""
The parameters c_prefix are there to allow reusing this
function on PyArray object.
This fct take as input the x.
"""
default_args = Subtensor.default_helper_c_code_args()
if strides_mul is None:
strides_mul = default_args["strides_mul"]
if c_prefix is None:
c_prefix = default_args["c_prefix"]
#
# two arrays are created in C code:
# is_slice: len == ndim, 0 means int, 1 means slice
# subtensor_spec: len = n_ints + 3 * n_slices
#
fail = sub["fail"]
init_cmds = [] # initialization for subtensor_spec
is_slice = []
# TODO: change that, it might lead to unexpected results,
# see assembla-#767
NONE_CODE = sys.maxsize - 1
pos = [0, 1] # annoying version of global variable for init_entry
def inc_spec_pos(amt):
pos[0] += amt
def inc_input_pos(amt):
pos[1] += amt
def spec_pos():
return pos[0]
def input_pos():
return pos[1]
def init_entry(entry, depth=0):
if isinstance(entry, int):
init_cmds.append(
f"subtensor_spec[{spec_pos()}] = {inputs[input_pos()]};"
)
inc_spec_pos(1)
inc_input_pos(1)
if depth == 0:
is_slice.append(0)
elif entry is None:
init_cmds.append(f"subtensor_spec[{spec_pos()}] = {NONE_CODE};")
inc_spec_pos(1)
if depth == 0:
is_slice.append(0)
elif depth == 0 and isinstance(entry, slice):
init_entry(entry.start, depth + 1)
init_entry(entry.stop, depth + 1)
init_entry(entry.step, depth + 1)
is_slice.append(1)
else:
assert 0, entry
for entry in idx_list:
init_entry(entry)
# make sure we used all inputs
assert input_pos() == len(inputs), input_pos()
assert len(is_slice) <= node.inputs[0].ndim, node.inputs[0].ndim
len_is_slice = len(is_slice)
len_subtensor_spec = spec_pos()
subensor_spec = f"npy_intp subtensor_spec[{len_subtensor_spec}];"
if len_subtensor_spec == 0:
subensor_spec = "npy_intp * subtensor_spec = NULL;"
if is_slice:
is_slice_init = (
"int is_slice[] = {" + ",".join(str(s) for s in is_slice) + "};"
)
else:
is_slice_init = "int* is_slice = NULL;"
subtensor_init = "\n".join(init_cmds)
(x,) = inputs[:1]
(_z,) = outputs
if view_ndim:
rval = f"""
// Argument of the view
npy_intp xview_dims[{view_ndim}];
npy_intp xview_strides[{view_ndim}];
"""
else:
rval = """
// Argument of the view
npy_intp* xview_dims = NULL;
npy_intp* xview_strides = NULL;
"""
rval += f"""
// One more argument of the view
npy_intp xview_offset = 0;
// The subtensor is created by iterating over the dimensions
// and updating stride, shape, and data pointers
{is_slice_init}
{subensor_spec}
{subtensor_init};
int spec_pos = 0; //position in subtensor_spec
int inner_ii = 0; // the current dimension of zview
int outer_ii = 0; // current dimension of z
for (; outer_ii < {len_is_slice}; ++outer_ii)
{{
if (is_slice[outer_ii])
{{
npy_intp length = {c_prefix}_DIMS({x})[outer_ii];
npy_intp slicelength;
npy_intp start = subtensor_spec[spec_pos+0];
npy_intp stop = subtensor_spec[spec_pos+1];
npy_intp step = subtensor_spec[spec_pos+2];
if (step == {NONE_CODE}) step = 1;
npy_intp defstart = step < 0 ? length-1 : 0;
npy_intp defstop = step < 0 ? -1 : length;
// logic adapted from
// PySlice_GetIndicesEx in python source
if (!step)
{{
PyErr_Format(PyExc_ValueError,
"slice step cannot be zero");
{fail};
}}
if (start == {NONE_CODE})
{{
start = defstart;
}}
else
{{
if (start < 0) start += length;
if (start < 0) start = (step < 0) ? -1 : 0;
if (start >= length)
start = (step < 0) ? length - 1 : length;
}}
if (stop == {NONE_CODE})
{{
stop = defstop;
}}
else
{{
if (stop < 0) stop += length;
if (stop < 0) stop = (step < 0) ? -1 : 0;
if (stop >= length)
stop = (step < 0) ? length - 1 : length;
}}
if ((step < 0 && stop >= start)
|| (step > 0 && start >= stop)) {{
slicelength = 0;
}}
else if (step < 0) {{
slicelength = (stop-start+1)/step+1;
}}
else {{
slicelength = (stop-start-1)/step+1;
}}
if (0){{
fprintf(stdout, "start %zi\\n", start);
fprintf(stdout, "stop %zi\\n", stop);
fprintf(stdout, "step %zi\\n", step);
fprintf(stdout, "length %zi\\n", length);
fprintf(stdout, "slicelength %zi\\n", slicelength);
}}
assert (slicelength <= length);
xview_offset += (npy_intp){c_prefix}_STRIDES({x})[outer_ii]
* start * {strides_mul};
xview_dims[inner_ii] = slicelength;
xview_strides[inner_ii] = (npy_intp){c_prefix}_STRIDES({x})[outer_ii] * step;
inner_ii += 1;
spec_pos += 3;
}}
else // tuple coord `outer_ii` is an int
{{
int idx = subtensor_spec[spec_pos];
if (idx < 0) idx += {c_prefix}_DIMS({x})[outer_ii];
if (idx >= 0)
{{
if (idx < {c_prefix}_DIMS({x})[outer_ii])
{{
xview_offset += (npy_intp){c_prefix}_STRIDES({x})[outer_ii] * idx *
{strides_mul};
}}
else
{{
PyErr_Format(PyExc_IndexError,"index out of bounds");
{fail};
}}
}}
else
{{
PyErr_Format(PyExc_IndexError,"index out of bounds");
{fail};
}}
spec_pos += 1;
}}
}}
assert (inner_ii <= {view_ndim});
while (inner_ii < {view_ndim})
{{
assert (outer_ii < {c_prefix}_NDIM({x}));
xview_dims[inner_ii] = {c_prefix}_DIMS({x})[outer_ii];
xview_strides[inner_ii] = {c_prefix}_STRIDES({x})[outer_ii];
inner_ii += 1;
outer_ii += 1;
}}
"""
# print rval
return rval
@staticmethod
def helper_c_code_cache_version():
return (9,)
def c_code(self, node, name, inputs, outputs, sub): # DEBUG
if not isinstance(node.inputs[0].type, TensorType):
raise NotImplementedError()
x = inputs[0]
(z,) = outputs
ndim = node.inputs[0].ndim
view_ndim = node.outputs[0].ndim
fail = sub["fail"]
decl = "PyArrayObject * xview = NULL;"
checkNDim = f"""
if (PyArray_NDIM({x}) != {ndim}){{
PyErr_SetString(PyExc_ValueError,
"Expected {ndim} dimensions input"
);
{fail}
}}
"""
get_xview = self.helper_c_code(
node, name, inputs, outputs, sub, self.idx_list, view_ndim
)
build_view = f"""
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR({x}));
xview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR({x}),
{view_ndim},
xview_dims,
xview_strides,
PyArray_BYTES({x}) + xview_offset,
PyArray_FLAGS({x}),
NULL);
assert (PyArray_NDIM(xview) == {view_ndim});
if (!xview)
{{
{fail};
}}
"""
finish_view = f"""
Py_XDECREF({z});
Py_INCREF(py_{x});
PyArray_SetBaseObject(xview, py_{x});
assert(py_{x} == (PyObject*){x});
{z} = xview;
"""
return decl + checkNDim + "{" + get_xview + build_view + finish_view + "}"
def c_code_cache_version(self):
hv = self.helper_c_code_cache_version()
# If `helper_c_code_cache_version` is not versioned we do not want to
# have a versioned version of this op's C code.
if len(hv) == 0:
return ()
return (4, hv)
def pushforward(self, inputs, outputs, eval_points):
# Subtensor is not differentiable wrt to its indices, therefore we
# do not even need to consider the eval_points provided for those
# (they should be defaulted to zeros_like by the global R_op)
if isinstance(eval_points[0].type, DisconnectedType):
return [disconnected_type()]
_x, *index_variables = inputs
return self(eval_points[0], *index_variables, return_list=True)
def basic_subtensor(x, *index_variables):
idx_list, flat_index_vars = flatten_index_variables(index_variables)
return Subtensor(idx_list)(x, *flat_index_vars)
@_get_vector_length.register(Subtensor) # type: ignore
def _get_vector_length_Subtensor(op, var):
# If we take a slice, we know how many elements it will result in
# TODO: We can cover more `*Subtensor` cases.
try:
indices = get_idx_list(var.owner.inputs, var.owner.op.idx_list)
start = (
None
if indices[0].start is None
else get_scalar_constant_value(indices[0].start)
)
stop = (
None
if indices[0].stop is None
else get_scalar_constant_value(indices[0].stop)
)
step = (
None
if indices[0].step is None
else get_scalar_constant_value(indices[0].step)
)
if start == stop:
return 0
arg_len = get_vector_length(var.owner.inputs[0])
return len(range(*slice(start, stop, step).indices(arg_len)))
except (ValueError, NotScalarConstantError):
raise ValueError(f"Length of {var} cannot be determined")
@_vectorize_node.register(Subtensor)
def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
"""Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
# TODO: Vectorize Subtensor with non-slice batched indexes as AdvancedSubtensor
if any(batch_inp.type.ndim > 0 for batch_inp in batch_idxs):
return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
old_x, *_ = node.inputs
batch_ndims = batch_x.type.ndim - old_x.type.ndim
new_idx_list = (slice(None),) * batch_ndims + op.idx_list
return Subtensor(new_idx_list).make_node(batch_x, *batch_idxs)
class SubtensorPrinter(Printer):
def process(self, r, pstate):
return self._process(r.owner.op.idx_list, r.owner.inputs, pstate)
def _process(self, idxs, op_inputs, pstate):
inputs = list(op_inputs)
input = inputs.pop(0)
sidxs = []
getattr(pstate, "precedence", None)
def process_slice_component(comp):
"""Process a slice component, returning string representation."""
if comp is None:
return ""
elif isinstance(comp, int):
with set_precedence(pstate):
return pstate.pprinter.process(inputs.pop(0))
else:
return str(comp)
for entry in idxs:
if isinstance(entry, int):
with set_precedence(pstate):
sidxs.append(pstate.pprinter.process(inputs.pop(0)))
elif isinstance(entry, slice):
msg1 = process_slice_component(entry.start)
msg2 = process_slice_component(entry.stop)
if entry.step is None:
msg3 = ""
else:
msg3 = f":{process_slice_component(entry.step)}"
sidxs.append(f"{msg1}:{msg2}{msg3}")
with set_precedence(pstate, 1000):
sub = pstate.pprinter.process(input, pstate)
return f"{sub}[{', '.join(sidxs)}]"
pprint.assign(Subtensor, SubtensorPrinter())
class IncSubtensor(BaseSubtensor, COp):
"""
Increment a subtensor.
This is like numpy's
x[i,j,k] += y
It is used internally to implement the gradient on SubTensor.
Parameters
----------
set_instead_of_inc
If True set the subtensor to the value instead of incrementing it by
that value.
"""
check_input = False
__props__ = (
"idx_list",
"inplace",
"set_instead_of_inc",
"destroyhandler_tolerate_aliased",
)
__hash__ = BaseSubtensor.__hash__
def __init__(
self,
idx_list,
inplace=False,
set_instead_of_inc=False,
destroyhandler_tolerate_aliased=None,
):
if destroyhandler_tolerate_aliased is None:
destroyhandler_tolerate_aliased = ()
super().__init__(idx_list)
self.inplace = inplace
if inplace:
self.destroy_map = {0: [0]}
self.destroyhandler_tolerate_aliased = tuple(destroyhandler_tolerate_aliased)
self.set_instead_of_inc = set_instead_of_inc
def __str__(self):
name = "SetSubtensor" if self.set_instead_of_inc else "IncSubtensor"
return f"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}"
def make_node(self, x, y, *inputs):
"""
Parameters
----------
x
The tensor to increment.
y
The value to increment by.
inputs
The indeces/slices list to increment in combination with idx_list.
E.g. self._idx_list = (0, slice(1, None, None), 2, slice(3, None, 4))
tell to use inputs[0] as the first dim.
"""
x, y = map(as_tensor_variable, [x, y])
if y.ndim > x.ndim:
raise ValueError(
f"Trying to increment a {int(x.ndim)}-dimensional "
f"subtensor with a {int(y.ndim)}-dimensional value."
)
inputs = tuple(map(as_scalar_index_variable, inputs))
idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim:
raise IndexError("too many indices for array")
if len(inputs) != self.n_index_vars:
raise ValueError(
"Not enough inputs to fill in the Subtensor template.", inputs, idx_list
)
return Apply(self, (x, y, *inputs), [x.type()])
def decl_view(self):
return "PyArrayObject * zview = NULL;"
def perform(self, node, inputs, output_storage):
x, y, *flat_indices = inputs
flat_indices_iterator = iter(flat_indices)
indices = tuple(
(
next(flat_indices_iterator)
if isinstance(entry, int)
else slice(
None if entry.start is None else next(flat_indices_iterator),
None if entry.stop is None else next(flat_indices_iterator),
None if entry.step is None else next(flat_indices_iterator),
)
)
for entry in self.idx_list
)
if not self.inplace:
x = x.copy()
if self.set_instead_of_inc:
x[indices] = y
else:
x[indices] += y
output_storage[0][0] = x
def c_code(self, node, name, inputs, outputs, sub):
# This method delegates much of the work to helper
# methods. This method implements the main logic
# but subclasses may override the helper methods
# to change the particulars.
self.do_type_checking(node)
if self.inplace: # convert bool to int
inplace = 1
else:
inplace = 0
x = inputs[0]
y = inputs[1]
(z,) = outputs
if self.set_instead_of_inc: # convert bool to int
op_is_set = 1
else:
op_is_set = 0
fail = sub["fail"]
view_ndim = node.inputs[0].ndim - sum(
not isinstance(idx, slice) for idx in self.idx_list
)
copy_of_x = self.copy_of_x(x)
copy_input_if_necessary = f"""
if ({inplace})
{{
if ({x} != {z})
{{
Py_XDECREF({z});
Py_INCREF({x});
{z} = {x};
}}
}}
else
{{
Py_XDECREF({z});
{z} = {copy_of_x};
if (!{z}) {{
// Exception already set
{fail}
}}
}}
"""
# get info needed to make zview: a view of %(z)s
helper_args = self.get_helper_c_code_args()
get_zview = Subtensor.helper_c_code(
node=node,
name=name,
inputs=outputs[:1] + inputs[2:],
outputs=outputs,
sub=sub,
idx_list=self.idx_list,
view_ndim=view_ndim,
**helper_args,
)
# Make a view on the output, as we will write into it.
alloc_zview = self.make_view_array(z, view_ndim)
build_view = f"""
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
{alloc_zview};
if (!zview)
{{
{fail};
}}
"""
copy_into = self.copy_into("zview", y)
add_to_zview = self.add_to_zview(name, y, fail)
make_modification = f"""
if ({op_is_set})
{{
if ({copy_into}) // does broadcasting
{{
Py_DECREF(zview);
{fail};
}}
}}
else
{{
{add_to_zview}
}}
"""
return (
self.decl_view()
+ copy_input_if_necessary
+ "{"
+ get_zview
+ build_view
+ make_modification
+ "Py_DECREF(zview);"
+ "}"
)
def do_type_checking(self, node):
"""
Should raise NotImplementedError if c_code does not support
the types involved in this node.
"""
if not isinstance(node.inputs[0].type, TensorType):
raise NotImplementedError()
def c_code_cache_version(self):
hv = Subtensor.helper_c_code_cache_version()
if hv:
return (3, hv)
else:
return ()
def copy_of_x(self, x):
"""
Parameters
----------
x
A string giving the name of a C variable pointing to an array.
Returns
-------
object
C code expression to make a copy of x.
Base class uses PyArrayObject *, subclasses may override for
different types of arrays.
"""
# Parameters of PyArray_FromAny are:
# array
# dtype: we pass NULL to say any dtype is acceptable, so the existing
# dtype will be copied
# min_depth: we pass 0 to have this parameter ignored
# max_depth: we pass 0 to have this parameter ignored
# requirements: here we pass NPY_ARRAY_ENSURECOPY to force a copy
# context: this is almost always NULL, I'm not sure what it's used for
return f"""(PyArrayObject*)PyArray_FromAny(py_{x}, NULL, 0, 0,
NPY_ARRAY_ENSURECOPY, NULL)"""
def make_view_array(self, x, view_ndim):
"""
Parameters
----------
x
A string identifying an array to be viewed.
view_ndim
A string specifying the number of dimensions to have in the view.
This doesn't need to actually set up the view with the right indexing;
we'll do that manually later.
"""
return f"""Py_INCREF(PyArray_DESCR({x}));
zview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR({x}),
{view_ndim},
xview_dims, //PyArray_DIMS({x}),
xview_strides, //PyArray_STRIDES({x}),
PyArray_BYTES({x}) + xview_offset, //PyArray_DATA({x}),
PyArray_FLAGS({x}),
NULL);
"""
def get_helper_c_code_args(self):
"""
Return a dictionary of arguments to pass to helper_c_code.
"""
return Subtensor.default_helper_c_code_args()
def copy_into(self, view, source):
"""
Parameters
----------
view : string
C code expression for an array.
source : string
C code expression for an array.
Returns
-------
object
C code expression to copy source into view, and 0 on success.
"""
return f"""PyArray_CopyInto({view}, {source})"""
def add_to_zview(self, name, x, fail):
"""
Return C code to add x to zview. Should DECREF zview if the
add fails.
"""
return f"""
PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd(
(PyObject*)zview, py_{x});
if (add_rval)
{{
assert (PyArray_Check((PyObject*)add_rval));
assert (PyArray_DATA(add_rval) == PyArray_DATA(zview));
Py_DECREF(add_rval);
}}
else
{{
Py_DECREF(zview);
{fail};
}}"""
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
def pushforward(self, inputs, outputs, eval_points):
if isinstance(eval_points[0].type, DisconnectedType) or isinstance(
eval_points[1].type, DisconnectedType
):
return [disconnected_type()]
# Again we ignore eval points for indices because incsubtensor is
# not differentiable wrt to those
_x, _y, *index_variables = inputs
return self(eval_points[0], eval_points[1], *index_variables, return_list=True)
def connection_pattern(self, node):
_x, _y, *index_variables = node.inputs
rval = [[True], [True], *([False] for _ in index_variables)]
return rval
def pullback(self, inputs, outputs, output_grads):
(g_output,) = output_grads
x, y, *index_variables = inputs
if x.dtype in discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=config.floatX)
if y.dtype in discrete_dtypes:
gy = y.zeros_like(dtype=config.floatX)
else:
gy = y.zeros_like()
elif x.dtype in complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
if self.set_instead_of_inc:
gx = set_subtensor(
Subtensor(idx_list=self.idx_list)(g_output, *index_variables),
pytensor.tensor.zeros_like(y),
)
else:
gx = g_output
gy = Subtensor(idx_list=self.idx_list)(g_output, *index_variables)
gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy, *(disconnected_type() for _ in range(len(index_variables)))]
class IncSubtensorPrinter(SubtensorPrinter):
def process(self, r, pstate):
x, y, *index_variables = r.owner.inputs
res = self._process(r.owner.op.idx_list, [x, *index_variables], pstate)
with set_precedence(pstate, 1000):
y_str = pstate.pprinter.process(y, pstate)
if r.owner.op.set_instead_of_inc:
res = f"set_subtensor({res}, {y_str})"
else:
res = f"inc_subtensor({res}, {y_str})"
return res
pprint.assign(IncSubtensor, IncSubtensorPrinter())
def _sum_grad_over_bcasted_dims(x, gx):
"""
Sum of gx over dimensions to reproduce x.broadcastable.
This is useful to sum gradients over certain dimensions when
x has been broadcasted, and we need to sum the gradient contributions
over all duplications.
"""
if gx.broadcastable != x.broadcastable:
x_dim_added = gx.ndim - x.ndim
x_broad = (True,) * x_dim_added + x.broadcastable
axis_to_sum = []
for i in range(gx.ndim):
if gx.broadcastable[i] is False and x_broad[i] is True:
axis_to_sum.append(i)
elif gx.broadcastable[i] is True and x_broad[i] is False:
# This means that PyTensor was able to infer that
# gx.shape[i] is 1, so x.shape[i] is 1, but we
# didn't know it. It is fine.
pass
else:
assert gx.broadcastable[i] == x_broad[i]
gx = gx.sum(axis=axis_to_sum, keepdims=True)
if gx.ndim != x.ndim:
assert gx.ndim > x.ndim
for i in range(x_dim_added):
assert gx.broadcastable[i]
gx = gx.dimshuffle(*range(x_dim_added, gx.ndim))
# Broadcastable flags of gx can be the same or more specific than x.
# Only unallowed case is x_dim_b == True and gx_dim_b == False.
assert not any(
x_dim_b and not gx_dim_b
for x_dim_b, gx_dim_b in zip(
x.type.broadcastable, gx.type.broadcastable, strict=True
)
), (x.type, gx.type)
return gx
class AdvancedSubtensor1(COp):
"""
Implement x[ilist] where ilist is a vector of integers.
"""
# sparse_grad doesn't go in here since it only affects the output
# of the grad() method.
__props__ = ()
idx_list = (0,)
_f16_ok = True
check_input = False
def __hash__(self):
return hash(type(self))
def __init__(self, sparse_grad=False):
self.sparse_grad = sparse_grad
def make_node(self, x, ilist):
x_ = as_tensor_variable(x)
ilist_ = as_tensor_variable(ilist)
if ilist_.type.dtype not in integer_dtypes:
raise TypeError("index must be integers")
if ilist_.type.ndim != 1:
raise TypeError("index must be vector")
if x_.type.ndim == 0:
raise TypeError("cannot index into a scalar")
out_shape = (ilist_.type.shape[0], *x_.type.shape[1:])
return Apply(self, [x_, ilist_], [TensorType(dtype=x.dtype, shape=out_shape)()])
def perform(self, node, inp, output_storage):
x, i = inp
# Numpy take is always slower when out is provided
# https://github.com/numpy/numpy/issues/28636
output_storage[0][0] = x.take(i, axis=0, out=None)
def connection_pattern(self, node):
_x, *index_variables = node.inputs
rval = [[True], *([False] for _ in index_variables)]
return rval
def pullback(self, inputs, outputs, output_grads):
x, ilist = inputs
(gz,) = output_grads
assert len(inputs) == 2
if self.sparse_grad:
if x.type.ndim != 2:
raise TypeError(
"AdvancedSubtensor1: you can't take the sparse grad"
" from a tensor with ndim != 2. ndim is " + str(x.type.ndim)
)
rval1 = pytensor.sparse.construct_sparse_from_list(x, gz, ilist)
else:
if x.dtype in discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=config.floatX)
elif x.dtype in complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
gx = x.zeros_like()
rval1 = advanced_inc_subtensor1(gx, gz, ilist)
return [rval1, *(disconnected_type() for _ in range(len(inputs) - 1))]
def pushforward(self, inputs, outputs, eval_points):
if isinstance(eval_points[0].type, DisconnectedType):
return [disconnected_type()]
_x, *index_variables = inputs
return self.make_node(eval_points[0], *index_variables).outputs
def infer_shape(self, fgraph, node, ishapes):
x, ilist = ishapes
return [ilist + x[1:]]
def c_code(self, node, name, input_names, output_names, sub):
if self.__class__ is not AdvancedSubtensor1:
raise MethodNotDefined(
"c_code defined for AdvancedSubtensor1, not for child class",
type(self),
)
x, idxs = node.inputs
if self._idx_may_be_invalid(x, idxs):
mode = "NPY_RAISE"
else:
# We can know ahead of time that all indices are valid, so we can use a faster mode
mode = "NPY_WRAP" # This seems to be faster than NPY_CLIP
a_name, i_name = input_names[0], input_names[1]
output_name = output_names[0]
fail = sub["fail"]
if mode == "NPY_RAISE":
# numpy_take always makes an intermediate copy if NPY_RAISE which is slower than just allocating a new buffer
# We can remove this special case after https://github.com/numpy/numpy/issues/28636
manage_pre_allocated_out = f"""
if ({output_name} != NULL) {{
// Numpy TakeFrom is always slower when copying
// https://github.com/numpy/numpy/issues/28636
Py_CLEAR({output_name});
}}
"""
else:
manage_pre_allocated_out = f"""
if ({output_name} != NULL) {{
npy_intp nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
if (PyArray_NDIM({output_name}) != nd) {{
Py_CLEAR({output_name});
}}
else {{
int i;
npy_intp* shape = PyArray_DIMS({output_name});
for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
if (shape[i] != PyArray_DIMS({i_name})[i]) {{
Py_CLEAR({output_name});
break;
}}
}}
if ({output_name} != NULL) {{
for (; i < nd; i++) {{
if (shape[i] != PyArray_DIMS({a_name})[i-PyArray_NDIM({i_name})+1]) {{
Py_CLEAR({output_name});
break;
}}
}}
}}
}}
}}
"""
return f"""
{manage_pre_allocated_out}
{output_name} = (PyArrayObject*)PyArray_TakeFrom(
{a_name}, (PyObject*){i_name}, 0, {output_name}, {mode});
if ({output_name} == NULL) {fail};
"""
def c_code_cache_version(self):
return (5,)
@staticmethod
def _idx_may_be_invalid(x, idx) -> bool:
if idx.type.shape[0] == 0:
# Empty index is always valid
return False
if x.type.shape[0] is None:
# We can't know if in index is valid if we don't know the length of x
return True
if not isinstance(idx, Constant):
# This is conservative, but we don't try to infer lower/upper bound symbolically
return True
shape0 = x.type.shape[0]
min_idx, max_idx = idx.data.min(), idx.data.max()
return not (min_idx >= 0 or min_idx >= -shape0) and (
max_idx < 0 or max_idx < shape0
)
advanced_subtensor1 = AdvancedSubtensor1()
class AdvancedIncSubtensor1(BaseSubtensor, COp):
"""
Increments a subtensor using advanced slicing (list of index).
"""
__props__ = (
"inplace",
"set_instead_of_inc",
)
idx_list = (0,)
check_input = False
params_type = ParamsType(inplace=ps.bool, set_instead_of_inc=ps.bool)
_runtime_broadcast_error_msg = (
"Runtime broadcasting not allowed. "
"AdvancedIncSubtensor1 was asked to broadcast the second input (y) along a dimension that was not marked as broadcastable. "
"If broadcasting was intended, use `specify_broadcastable` on the relevant dimension(s)."
)
def __init__(self, inplace=False, set_instead_of_inc=False):
self.inplace = bool(inplace)
self.set_instead_of_inc = bool(set_instead_of_inc)
if inplace:
self.destroy_map = {0: [0]}
def __hash__(self):
return hash(
(
type(self),
self.inplace,
self.set_instead_of_inc,
)
)
def clone_inplace(self):
return self.__class__(
inplace=True,
set_instead_of_inc=self.set_instead_of_inc,
)
def __str__(self):
if self.inplace:
msg = "inplace"
else:
msg = "no_inplace"
if self.set_instead_of_inc:
msg += ",set"
else:
msg += ",inc"
return self.__class__.__name__ + f"{{{msg}}}"
def make_node(self, x, y, ilist):
x_ = as_tensor_variable(x)
y_ = as_tensor_variable(y)
ilist_ = as_tensor_variable(ilist)
if ilist_.type.dtype not in integer_dtypes:
raise TypeError("index must be integers")
if ilist_.type.ndim != 1:
raise TypeError("index must be vector")
if x_.type.ndim == 0:
raise TypeError("cannot index into a scalar")
if y_.type.ndim > x_.type.ndim:
if self.set_instead_of_inc:
opname = "set"
else:
opname = "increment"
raise TypeError(
f"cannot {opname} x subtensor with ndim={x_.type.ndim} by y with ndim={y_.type.ndim}."
)
return Apply(self, [x_, y_, ilist_], [x_.type()])
def copy_of_x(self, x):
"""
Parameters
----------
x : string
Gives the name of a C variable pointing to an array.
Returns
-------
object
C code expression to make a copy of x.
Base class uses PyArrayObject *, subclasses may override for
different types of arrays.
"""
# Parameters of PyArray_FromAny are:
# array
# dtype: we pass NULL to say any dtype is acceptable, so the existing
# dtype will be copied
# min_depth: we pass 0 to have this parameter ignored
# max_depth: we pass 0 to have this parameter ignored
# requirements: here we pass NPY_ARRAY_ENSURECOPY to force a copy
# context: this is almost always NULL, I'm not sure what it's used for
return f"""(PyArrayObject*)PyArray_FromAny(py_{x}, NULL, 0, 0,
NPY_ARRAY_ENSURECOPY, NULL)"""
def c_code(self, node, name, input_names, output_names, sub):
x, y, idx = input_names
[out] = output_names
copy_of_x = self.copy_of_x(x)
params = sub["params"]
fail = sub["fail"]
x_, y_, idx_ = node.inputs
y_cdtype = y_.type.dtype_specs()[1]
idx_cdtype = idx_.type.dtype_specs()[1]
out_cdtype = node.outputs[0].type.dtype_specs()[1]
y_bcast = y_.type.broadcastable != idx_.type.broadcastable
if (
x_.type.ndim == 1
and y_.type.ndim == 1
and not y_bcast
and x_.type.dtype not in complex_dtypes
and y_.type.dtype not in complex_dtypes
):
# Simple implementation for vector x, y cases
idx_may_be_neg = not (
# Empty idx needs no negative checks
idx_.type.shape[0] == 0
or (isinstance(idx_, Constant) and idx_.data.min() >= 0)
)
idx_may_be_invalid = AdvancedSubtensor1._idx_may_be_invalid(x_, idx_)
shape0 = x_.type.shape[0]
# This is used to make sure that when we trust the indices to be valid
# we are not fooled by a wrong static shape
# We mention x to the user in error messages but we work (and make checks) on out,
# which should be x or a copy of it
unexpected_shape0 = (
f"PyArray_SHAPE({out})[0] != {shape0}" if shape0 is not None else "0"
)
op = "=" if self.set_instead_of_inc else "+="
code = f"""
if ({params}->inplace)
{{
if ({x} != {out})
{{
Py_XDECREF({out});
Py_INCREF({x});
{out} = {x};
}}
}}
else
{{
Py_XDECREF({out});
{out} = {copy_of_x};
if (!{out}) {{
// Exception already set
{fail}
}}
}}
if (PyArray_NDIM({out}) != 1) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) ndim should be 1, got %d", PyArray_NDIM({out}));
{fail}
}}
if ({unexpected_shape0}) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) shape should be {shape0}, got %d", PyArray_SHAPE({out})[0]);
{fail}
}}
if (PyArray_NDIM({idx}) != 1) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: indices ndim should be 1, got %d", PyArray_NDIM({idx}));
{fail}
}}
if (PyArray_NDIM({y}) != 1) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: second input (y) ndim should be 1, got %d", PyArray_NDIM({y}));
{fail}
}}
if (PyArray_SHAPE({y})[0] != PyArray_SHAPE({idx})[0]) {{
if ((PyArray_NDIM({y}) == 1) && (PyArray_SHAPE({y})[0] == 1)){{
PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
}} else {{
PyErr_Format(PyExc_ValueError,
"AdvancedIncSubtensor1: Shapes of second input (y) and indices do not match: %d, %d",
PyArray_SHAPE({y})[0], PyArray_SHAPE({idx})[0]);
}}
{fail}
}}
{{
npy_intp out_shape0 = PyArray_SHAPE({out})[0];
{out_cdtype}* out_data = ({out_cdtype}*)PyArray_DATA({out});
{y_cdtype}* y_data = ({y_cdtype}*)PyArray_DATA({y});
{idx_cdtype}* idx_data = ({idx_cdtype}*)PyArray_DATA({idx});
npy_intp n = PyArray_SHAPE({idx})[0];
npy_intp out_jump = PyArray_STRIDES({out})[0] / PyArray_ITEMSIZE({out});
npy_intp y_jump = PyArray_STRIDES({y})[0] / PyArray_ITEMSIZE({y});
npy_intp idx_jump = PyArray_STRIDES({idx})[0] / PyArray_ITEMSIZE({idx});
for(int i = 0; i < n; i++){{
{idx_cdtype} idx = idx_data[i * idx_jump];
if ({int(idx_may_be_neg)}){{
if (idx < 0) {{
idx += out_shape0;
}}
}}
if ({int(idx_may_be_invalid)}){{
if ((idx < 0) || (idx >= out_shape0)) {{
PyErr_Format(PyExc_IndexError,"index %d out of bounds for array with shape %d", idx_data[i * idx_jump], out_shape0);
{fail}
}}
}}
out_data[idx * out_jump] {op} y_data[i * y_jump];
}}
}}
"""
return code
raise NotImplementedError
def c_code_cache_version(self):
return (10,)
def _check_runtime_broadcasting(
self, node: Apply, x: np.ndarray, y: np.ndarray, idx: np.ndarray
) -> None:
if y.ndim > 0:
y_pt_bcast = node.inputs[1].broadcastable # type: ignore
if not y_pt_bcast[0] and y.shape[0] == 1 and y.shape[0] != idx.shape[0]:
# Attempting to broadcast with index
raise ValueError(self._runtime_broadcast_error_msg)
if any(
not y_bcast and y_dim == 1 and y_dim != x_dim
for y_bcast, y_dim, x_dim in zip(
reversed(y_pt_bcast),
reversed(y.shape),
reversed(x.shape),
strict=False,
)
):
# Attempting to broadcast with buffer
raise ValueError(self._runtime_broadcast_error_msg)
def perform(self, node, inputs, output_storage):
x, y, idx = inputs
if not self.inplace:
x = x.copy()
self._check_runtime_broadcasting(node, x, y, idx)
if self.set_instead_of_inc:
x[idx] = y
else:
# In Numpy, `x[idx] += y` doesn't work if the same index is present
# many times: it does it only once.
np.add.at(x, idx, y)
output_storage[0][0] = x
def infer_shape(self, fgraph, node, ishapes):
x, _y, _ilist = ishapes
return [x]
def pushforward(self, inputs, outputs, eval_points):
if any(isinstance(t.type, DisconnectedType) for t in eval_points[:2]):
return [disconnected_type()]
_x, _y, *index_variables = inputs
return self.make_node(eval_points[0], eval_points[1], *index_variables).outputs
def connection_pattern(self, node):
rval = [[True], [True], [False]]
return rval
def pullback(self, inputs, outputs, output_grads):
(g_output,) = output_grads
x, y, idx_list = inputs
if x.dtype in discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=config.floatX)
if y.dtype in discrete_dtypes:
gy = y.zeros_like(dtype=config.floatX)
else:
gy = y.zeros_like()
elif x.dtype in complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
if self.set_instead_of_inc:
gx = advanced_set_subtensor1(g_output, y.zeros_like(), idx_list)
else:
gx = g_output
gy = advanced_subtensor1(g_output, idx_list)
gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy, disconnected_type()]
advanced_inc_subtensor1 = AdvancedIncSubtensor1()
advanced_set_subtensor1 = AdvancedIncSubtensor1(set_instead_of_inc=True)
def as_tensor_index_variable(idx):
"""Convert index to Variable form for advanced indexing."""
idx = as_tensor_variable(idx)
if idx.type.dtype not in discrete_dtypes:
raise TypeError("index must be integers or a boolean mask")
if idx.type.dtype == "bool" and idx.type.ndim == 0:
raise NotImplementedError(
"Boolean scalar indexing not implemented. "
"Open an issue in https://github.com/pymc-devs/pytensor/issues if you need this behavior."
)
return idx
class AdvancedSubtensor(BaseSubtensor, COp):
"""Implements NumPy's advanced indexing."""
__props__ = ("idx_list",)
__hash__ = BaseSubtensor.__hash__
def c_code_cache_version(self):
hv = Subtensor.helper_c_code_cache_version()
if hv:
return (3, hv)
else:
return ()
def make_node(self, x, *index_variables):
if len(index_variables) != self.n_index_vars:
raise ValueError(
f"Expected {self.n_index_vars} inputs, got {len(index_variables)}"
)
x = as_tensor_variable(x)
index_variables = tuple(as_tensor_index_variable(a) for a in index_variables)
idx_list = self.idx_list
if len(idx_list) > x.type.ndim:
raise IndexError("too many indices for array")
reconstructed_indices = unflatten_index_variables(index_variables, idx_list)
explicit_indices = []
for idx in reconstructed_indices:
if isinstance(idx, slice):
explicit_indices.append(idx)
elif hasattr(idx, "dtype") and idx.dtype == "bool":
if idx.type.ndim == 0:
raise NotImplementedError(
"Indexing with scalar booleans not supported"
)
axis = len(explicit_indices)
indexed_shape = x.type.shape[axis : axis + idx.type.ndim]
for j, (indexed_length, indexer_length) in enumerate(
zip(indexed_shape, idx.type.shape)
):
if (
indexed_length is not None
and indexer_length is not None
and indexed_length != indexer_length
):
raise IndexError(
f"boolean index did not match indexed tensor along axis {axis + j};"
f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}"
)
# Convert boolean indices to integer with nonzero, to reason about static shape next
if isinstance(idx, Constant):
nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()]
else:
nonzero_indices = idx.nonzero()
explicit_indices.extend(nonzero_indices)
else:
explicit_indices.append(idx)
if len(explicit_indices) > x.type.ndim:
raise IndexError(
f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices)} were indexed"
)
# Perform basic and advanced indexing shape inference separately (no newaxis)
basic_group_shape = []
advanced_indices = []
adv_group_axis = None
last_adv_group_axis = None
for i, (idx, dim_length) in enumerate(
zip_longest(explicit_indices, x.type.shape, fillvalue=slice(None))
):
if isinstance(idx, slice):
basic_group_shape.append(slice_static_length(idx, dim_length))
else: # TensorType (advanced index)
# Keep track of advanced group axis
if adv_group_axis is None:
# First time we see an advanced index
adv_group_axis, last_adv_group_axis = i, i
elif last_adv_group_axis == (i - 1):
# Another advanced indexing aligned with the first group
last_adv_group_axis = i
else:
# Non-consecutive advanced index, all advanced index views get moved to the front
adv_group_axis = 0
advanced_indices.append(idx)
if advanced_indices:
try:
# Use variadic add to infer static shape of advanced integer indices
advanced_group_static_shape = add(*advanced_indices).type.shape
except ValueError:
# It fails when static shapes are inconsistent
static_shapes = [idx.type.shape for idx in advanced_indices]
raise IndexError(
f"shape mismatch: indexing tensors could not be broadcast together with shapes {static_shapes}"
)
# Combine advanced and basic views
indexed_shape = [
*basic_group_shape[:adv_group_axis],
*advanced_group_static_shape,
*basic_group_shape[adv_group_axis:],
]
else:
# This could have been a basic subtensor!
indexed_shape = basic_group_shape
return Apply(
self,
[x, *index_variables],
[tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))],
)
def pushforward(self, inputs, outputs, eval_points):
if isinstance(eval_points[0].type, DisconnectedType):
return [disconnected_type()]
_x, *index_variables = inputs
return self.make_node(eval_points[0], *index_variables).outputs
def infer_shape(self, fgraph, node, ishapes):
def is_bool_index(idx):
return (
isinstance(idx, np.bool_ | bool)
or getattr(idx, "dtype", None) == "bool"
)
_x, *index_variables = node.inputs
full_indices = unflatten_index_variables(index_variables, self.idx_list)
index_shapes = []
for idx in full_indices:
if isinstance(idx, slice):
index_shapes.append(idx)
else:
shape0_op = Shape_i(0)
if is_bool_index(idx):
index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx))
else:
input_shape_idx = (
index_variables.index(idx) + 1
) # +1 because ishapes[0] is x
index_shapes.append(ishapes[input_shape_idx])
res_shape = list(
indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
)
for i, res_dim_length in enumerate(res_shape):
if res_dim_length is None:
# We must compute the Op to find its shape
res_shape[i] = Shape_i(i)(node.out)
adv_indices = [idx for idx in full_indices if not isinstance(idx, slice)]
bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
# Special logic when the only advanced index group is of bool type.
# We can replace the nonzeros by a sum of the whole bool variable.
if len(bool_indices) == 1 and len(adv_indices) == 1:
[bool_index] = bool_indices
# Find the output dim associated with the bool index group
# Because there are no more advanced index groups, there is exactly
# one output dim per index variable up to the bool group.
# Note: Scalar integer indexing counts as advanced indexing.
start_dim = full_indices.index(bool_index)
res_shape[start_dim] = bool_index.sum()
assert node.outputs[0].ndim == len(res_shape)
return [res_shape]
def perform(self, node, inputs, out_):
(out,) = out_
x, *index_variables = inputs
full_indices = unflatten_index_variables(index_variables, self.idx_list)
rval = x.__getitem__(tuple(full_indices))
# When there are no arrays, we are not actually doing advanced
# indexing, so __getitem__ will not return a copy.
# Since no view_map is set, we need to copy the returned value
if not any(
isinstance(idx, np.ndarray) and idx.ndim > 0 for idx in full_indices
):
rval = rval.copy()
out[0] = rval
def connection_pattern(self, node):
_x, *index_variables = node.inputs
rval = [[True], *([False] for _ in index_variables)]
return rval
def pullback(self, inputs, outputs, output_grads):
(gz,) = output_grads
x, *index_variables = inputs
if x.dtype in discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=config.floatX)
elif x.dtype in complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
gx = x.zeros_like()
return [
AdvancedIncSubtensor(self.idx_list)(gx, gz, *index_variables),
*(disconnected_type() for _ in range(len(index_variables))),
]
@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
warnings.warn(
"Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
)
return AdvancedSubtensor.non_consecutive_adv_indexing(node)
@staticmethod
def non_consecutive_adv_indexing(node: Apply) -> bool:
"""
Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).
This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their original position.
See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
Parameters
----------
node : Apply
The node of the AdvancedSubtensor operation.
Returns
-------
bool
True if the advanced indexing is non-consecutive, False otherwise.
"""
indices = indices_from_subtensor(node.inputs[1:], node.op.idx_list)
return _non_consecutive_adv_indexing(indices)
class AdvancedSubtensorPrinter(SubtensorPrinter):
def process(self, r, pstate):
return self._process(r.owner.op.idx_list, r.owner.inputs, pstate)
pprint.assign(AdvancedSubtensor, AdvancedSubtensorPrinter())
def advanced_subtensor(x, *index_variables):
idx_list, flat_index_vars = flatten_index_variables(index_variables)
return AdvancedSubtensor(idx_list)(x, *flat_index_vars)
@_vectorize_node.register(AdvancedSubtensor)
def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
x, *idxs = node.inputs
batch_x, *batch_idxs = batch_inputs
x_is_batched = x.type.ndim < batch_x.type.ndim
idxs_are_batched = any(
batch_idx.type.ndim > idx.type.ndim
for batch_idx, idx in zip(batch_idxs, idxs, strict=True)
if isinstance(batch_idx, TensorVariable)
)
if idxs_are_batched or (x_is_batched and op.non_consecutive_adv_indexing(node)):
# Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing
# which would put the indexed results to the left of the batch dimensions!
# TODO: Not all cases must be handled by Blockwise, but the logic is complex
return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
# Otherwise we just need to add None slices for every new batch dim
x_batch_ndim = batch_x.type.ndim - x.type.ndim
new_idx_list = (slice(None),) * x_batch_ndim + op.idx_list
return type(op)(new_idx_list).make_node(batch_x, *batch_idxs)
class AdvancedIncSubtensor(BaseSubtensor, Op):
"""Increments a subtensor using advanced indexing."""
__props__ = (
"idx_list",
"inplace",
"set_instead_of_inc",
"ignore_duplicates",
)
__hash__ = BaseSubtensor.__hash__
def __init__(
self,
idx_list,
inplace=False,
set_instead_of_inc=False,
ignore_duplicates=False,
):
super().__init__(idx_list)
self.set_instead_of_inc = set_instead_of_inc
self.inplace = inplace
if inplace:
self.destroy_map = {0: [0]}
self.ignore_duplicates = ignore_duplicates
def __str__(self):
return (
"AdvancedSetSubtensor"
if self.set_instead_of_inc
else "AdvancedIncSubtensor"
)
def make_node(self, x, y, *index_variables):
if len(index_variables) != self.n_index_vars:
raise ValueError(
f"Expected {self.n_index_vars} tensor inputs but got {len(index_variables)}"
)
index_variables = tuple(
as_tensor_index_variable(idx) for idx in index_variables
)
x = as_tensor_variable(x)
y = as_tensor_variable(y)
return Apply(
self,
[x, y, *index_variables],
[x.type()],
)
def perform(self, node, inputs, out_):
x, y, *index_variables = inputs
full_indices = unflatten_index_variables(index_variables, self.idx_list)
(out,) = out_
if not self.inplace:
out[0] = x.copy()
else:
out[0] = x
if self.set_instead_of_inc:
out[0][tuple(full_indices)] = y
elif self.ignore_duplicates:
out[0][tuple(full_indices)] += y
else:
np.add.at(out[0], tuple(full_indices), y)
def infer_shape(self, fgraph, node, ishapes):
return [ishapes[0]]
def connection_pattern(self, node):
_x, _y, *index_variables = node.inputs
rval = [[True], [True], *([False] for _ in index_variables)]
return rval
def pushforward(self, inputs, outputs, eval_points):
if any(isinstance(t.type, DisconnectedType) for t in eval_points[:2]):
return [disconnected_type()]
_x, _y, *index_variables = inputs
return self.make_node(eval_points[0], eval_points[1], *index_variables).outputs
def pullback(self, inputs, outputs, output_grads):
x, y, *index_variables = inputs
(outgrad,) = output_grads
if x.dtype in discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=config.floatX)
if y.dtype in discrete_dtypes:
gy = y.zeros_like(dtype=config.floatX)
else:
gy = y.zeros_like()
elif x.dtype in complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
if self.set_instead_of_inc:
gx = (
type(self)(self.idx_list, set_instead_of_inc=True)
.make_node(outgrad, y.zeros_like(), *index_variables)
.outputs[0]
)
else:
gx = outgrad
gy = (
AdvancedSubtensor(self.idx_list)
.make_node(outgrad, *index_variables)
.outputs[0]
)
# Make sure to sum gy over the dimensions of y that have been
# added or broadcasted
gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy, *(disconnected_type() for _ in range(len(index_variables)))]
@staticmethod
def non_consecutive_adv_indexing(node: Apply) -> bool:
"""
Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).
This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their original position.
See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
Parameters
----------
node : Apply
The node of the AdvancedSubtensor operation.
Returns
-------
bool
True if the advanced indexing is non-consecutive, False otherwise.
"""
indices = indices_from_subtensor(node.inputs[2:], node.op.idx_list)
return _non_consecutive_adv_indexing(indices)
def advanced_inc_subtensor(x, y, *args, **kwargs):
idx_list, flat_index_vars = flatten_index_variables(args)
return AdvancedIncSubtensor(idx_list, **kwargs)(x, y, *flat_index_vars)
def advanced_set_subtensor(x, y, *args, **kwargs):
return advanced_inc_subtensor(x, y, *args, set_instead_of_inc=True, **kwargs)
class AdvancedIncSubtensorPrinter(SubtensorPrinter):
def process(self, r, pstate):
x, y, *index_variables = r.owner.inputs
res = self._process(r.owner.op.idx_list, [x, *index_variables], pstate)
with set_precedence(pstate, 1000):
y_str = pstate.pprinter.process(y, pstate)
if r.owner.op.set_instead_of_inc:
res = f"set_subtensor({res}, {y_str})"
else:
res = f"inc_subtensor({res}, {y_str})"
return res
pprint.assign(AdvancedIncSubtensor, AdvancedIncSubtensorPrinter())
def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False):
"""
Return x with the given subtensor overwritten by y.
Parameters
----------
x
Symbolic variable for the lvalue of = operation.
y
Symbolic variable for the rvalue of = operation.
tolerate_inplace_aliasing
See inc_subtensor for documentation.
Examples
--------
To replicate the numpy expression ``r[10:] = 5``, type
.. code-block:: python
from pytensor.tensor import set_subtensor, vector
r = vector("r")
new_r = set_subtensor(r[10:], 5)
Consider using :meth:`pytensor.tensor.variable.TensorVariable.set` instead.
"""
return inc_subtensor(
x,
y,
inplace,
set_instead_of_inc=True,
tolerate_inplace_aliasing=tolerate_inplace_aliasing,
)
def inc_subtensor(
x,
y,
inplace=False,
set_instead_of_inc=False,
tolerate_inplace_aliasing=False,
ignore_duplicates=False,
):
"""Update the value of an indexed array by a given amount.
This is equivalent to ``x[indices] += y`` or ``np.add.at(x, indices, y)``,
depending on the value of `ignore_duplicates`.
Parameters
----------
x
The symbolic result of a Subtensor operation.
y
The amount by which to increment the array.
inplace
Don't use. PyTensor will do in-place operations itself, when possible.
set_instead_of_inc
If True, do a set_subtensor instead.
tolerate_inplace_aliasing:
Allow `x` and `y` to be views of a single underlying array even while
working in-place. For correct results, `x` and `y` must not be overlapping
views; if they overlap, the result of this `Op` will generally be
incorrect. This value has no effect if ``inplace=False``.
ignore_duplicates
This determines whether ``x[indices] += y`` is used or
``np.add.at(x, indices, y)``.
Examples
--------
To replicate the expression ``r[10:] += 5``:
.. code-block:: python
from pytensor.tensor import ivector, inc_subtensor
r = ivector("r")
new_r = inc_subtensor(r[10:], 5)
To replicate the expression ``r[[0, 1, 0]] += 5``:
.. code-block:: python
r = ivector("r")
new_r = inc_subtensor(r[[0, 1, 0]], 5, ignore_duplicates=True)
Consider using :meth:`pytensor.tensor.variable.TensorVariable.inc` instead.
"""
# First of all, y cannot have a higher dimension than x,
# nor have non-broadcastable dimensions where x is broadcastable.
x = as_tensor_variable(x)
y = as_tensor_variable(y)
if y.ndim > x.ndim:
raise TypeError(
f"Trying to increment a {int(x.ndim)}-dimensional "
f"subtensor with a {int(y.ndim)}-dimensional value."
)
dim_offset = x.ndim - y.ndim
for dim in range(y.ndim):
if x.broadcastable[dim + dim_offset] and not y.broadcastable[dim]:
# It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1.
# We insert a SpecifyShape Op to make sure it is the case.
y = specify_broadcastable(y, dim)
if x.owner is None:
raise TypeError("x must be the result of a subtensor operation")
# retrieve idx_list from x.owner
if isinstance(x.owner.op, Subtensor):
if tolerate_inplace_aliasing:
destroyhandler_tolerate_aliased = [[0, 1]]
else:
destroyhandler_tolerate_aliased = []
the_op = IncSubtensor(
x.owner.op.idx_list,
inplace,
set_instead_of_inc,
destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased,
)
real_x, *index_variables = x.owner.inputs
return the_op(real_x, y, *index_variables)
elif isinstance(x.owner.op, AdvancedSubtensor1):
real_x = x.owner.inputs[0]
ilist = x.owner.inputs[1]
if ignore_duplicates:
the_op = AdvancedIncSubtensor(
(0,),
inplace,
set_instead_of_inc=set_instead_of_inc,
ignore_duplicates=True,
)
else:
the_op = AdvancedIncSubtensor1(
inplace, set_instead_of_inc=set_instead_of_inc
)
return the_op(real_x, y, ilist)
elif isinstance(x.owner.op, AdvancedSubtensor):
real_x, *index_variables = x.owner.inputs
the_op = AdvancedIncSubtensor(
x.owner.op.idx_list,
inplace,
set_instead_of_inc=set_instead_of_inc,
ignore_duplicates=ignore_duplicates,
)
return the_op(real_x, y, *index_variables)
elif isinstance(x.owner.op, DimShuffle):
inner_x = x.owner.inputs[0]
# In the dimshuffle case, there are in fact two dimshuffles:
# one to make the indexed dimension the last one,
# and one to put it back where it was. So, in the case where we have
# inc_subtensor(x[:,i], y), the graph is actually
# inc_subtensor((x.T)[i].T, y).
# We could get all the way to x, and then get rid of the dimshuffles
# completely, but the problem is that advanced_inc_subtensor1 can only
# work on the first (outer-most, left-most) dimension of x,
# just like advanced_subtensor1.
# So we call advanced_inc_subtensor1(x.T, i, y.T) (as we also need to
# transpose y if it is not a scalar or a vector), but then we need to
# return something that has the same shape as x, not as x.T (inner_x).
# So re-apply the outer dimshuffle on the new inc_subtensor,
# and return advanced_inc_subtensor1(x.T, i, y.T).T.
# Get the dimshuffle pattern to apply to y.
x_order = x.owner.op.new_order
y_order = ["x"] * x.ndim
for i, v in enumerate(x_order):
if v != "x" and (v - dim_offset) >= 0:
y_order[v - dim_offset] = i
inner_incsubtensor = inc_subtensor(
inner_x,
y.dimshuffle(y_order),
inplace=inplace,
set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing,
ignore_duplicates=ignore_duplicates,
)
# The broadcastable pattern of inner_x may not be the same as
# the one of x, so we have to build a new dimshuffle here,
# instead of reusing x.owner.op().
return inner_incsubtensor.dimshuffle(x.owner.op.new_order)
elif isinstance(x.owner.op, Reshape):
# This case happens when the indices are not arranged as a vector, but
# as a higher-dimensional array. This is handled by the subtensor
# by flattening this list, taking the subtensor, then reshaping the
# result.
inner_x = x.owner.inputs[0]
# Try to apply inc_subtensor on inner_x.
# If it works, there is no need to reshape, as the inc_subtensor
# will have the same shape as inner_x, which is what we want.
# We also explicitly duplicate y to its broadcasted shape
# before we partially flatten it to inner_x dimension. This is
# not strictly needed in all cases, but it is easier this way.
if y.ndim > 0:
# This if is needed to prevent some useless warning about
# old code bug.
expanded_y = alloc(y, *[x.shape[i] for i in range(x.ndim)])
flattened_y = expanded_y.reshape(inner_x.shape)
else:
flattened_y = y
inner_incsubtensor = inc_subtensor(
inner_x,
flattened_y,
inplace=inplace,
set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing,
ignore_duplicates=ignore_duplicates,
)
return inner_incsubtensor
else:
raise TypeError("x must be the result of a subtensor operation")
[docs]
def take(a, indices, axis=None, mode="raise"):
"""Take elements from an array along an axis.
When axis is not None, this function does the same thing as "fancy"
indexing (indexing arrays using arrays); however, it can be easier to use
if you need elements along a given axis. A call such as
``np.take(arr, indices, axis=3)`` is equivalent to
``arr[:,:,:,indices,...]``.
See `np.take`
Parameters
----------
a : TensorVariable
The source array.
indices : TensorVariable, ndarray, list, tuple
The indices of the values to extract.
axis : int, optional
The axis over which to select values. By default, the flattened
input array is used.
"""
a = as_tensor_variable(a)
indices = as_tensor_variable(indices)
if not isinstance(axis, int | type(None)):
raise TypeError("`axis` must be an integer or None")
if axis is None:
return advanced_subtensor(a.flatten(), indices)
elif axis < 0:
axis += a.ndim
if mode == "clip":
indices = clip(indices, 0, a.shape[axis] - 1)
elif mode == "wrap":
indices = indices % a.shape[axis]
full_indices = (slice(None),) * axis + (indices,)
return a[full_indices]
def slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
"""
Construct tuple of slices to slice an array in the given dimension.
Copied from numpy.lib.arraypad._slice_at_axis
https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L33
Parameters
----------
sl : slice
The slice for the given dimension.
axis : int
The axis to which `sl` is applied. All other dimensions are left
"unsliced".
Returns
-------
sl : tuple of slices
A tuple with slices matching `shape` in length.
Examples
--------
.. testcode::
import pytensor.tensor as pt
s = pt.slice_at_axis(slice(None, 1), 1)
print(s)
.. testoutput::
(slice(None, None, None), slice(None, 1, None), Ellipsis)
.. testcode::
x = pt.tensor('x', shape=(None, None, None))
x_sliced = x[s]
f = pytensor.function([x], x_sliced)
x = np.arange(27).reshape(3, 3, 3)
print(f(x))
.. testoutput::
[[[ 0. 1. 2.]]
[[ 9. 10. 11.]]
[[18. 19. 20.]]]
"""
if axis >= 0:
return (slice(None),) * axis + (sl,) + (...,) # type: ignore
else:
# If axis = -1 we want zero right padding (and so on), so subtract one
axis = abs(axis) - 1
return (...,) + (sl,) + (slice(None),) * axis # type: ignore
def flip(
arr: TensorVariable, axis: int | tuple[int] | TensorVariable | None = None
) -> TensorVariable:
"""
Reverse the order of elements in an tensor along the given axis.
Parameters
----------
arr: TensorVariable
Input tensor.
axis: int | tuple[int] | TensorVariable, optional
Axis or axes along which to flip over. The default is to flip over all of the axes of the input tensor.
Returns
-------
arr: TensorVariable
A view of `arr` with the entries of axis reversed.
Examples
--------
.. testcode::
import pytensor
import pytensor.tensor as pt
x = pt.tensor('x', shape=(None, None))
x_flipped = pt.flip(x, axis=0)
f = pytensor.function([x], x_flipped)
x = [[1, 2], [3, 4]]
print(f(x))
.. testoutput::
[[3. 4.]
[1. 2.]]
"""
if axis is None:
index = ((slice(None, None, -1)),) * arr.ndim
else:
normalized_axis = normalize_axis_tuple(axis, arr.ndim)
index = tuple(
[
slice(None, None, -1)
if i in normalized_axis
else slice(None, None, None)
for i in range(arr.ndim)
]
)
return cast(TensorVariable, arr[index])
__all__ = [
"flip",
"inc_subtensor",
"set_subtensor",
"slice_at_axis",
"take",
]