Source code for pytensor.tensor.sort

import typing

import numpy as np

from pytensor.gradient import grad_undefined
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor.basic import arange, as_tensor_variable, switch
from pytensor.tensor.math import eq, ge
from pytensor.tensor.type import TensorType


KIND = typing.Literal["quicksort", "mergesort", "heapsort", "stable"]
KIND_VALUES = typing.get_args(KIND)


def _parse_sort_args(kind: KIND | None, order, stable: bool | None) -> KIND:
    if order is not None:
        raise ValueError("The order argument is not applicable to PyTensor graphs")
    if stable is not None and kind is not None:
        raise ValueError("kind and stable cannot be set at the same time")
    if stable:
        kind = "stable"
    elif kind is None:
        kind = "quicksort"
    if kind not in KIND_VALUES:
        raise ValueError(f"kind must be one of {KIND_VALUES}, got {kind}")
    return kind


class SortOp(Op):
    """
    Op that sorts a tensor along one axis by delegating to ``numpy.sort``.

    The sorting algorithm is stored in ``kind``, which is the Op's only prop.

    """

    __props__ = ("kind",)

    def __init__(self, kind: KIND):
        self.kind = kind

    def make_node(self, input, axis=-1):
        """Create the Apply node that sorts ``input`` along ``axis``.

        Both arguments are converted with ``as_tensor_variable``; the axis is
        requested as a 0-d integer tensor. The single output has the same
        type as the converted input.

        Raises
        ------
        ValueError
            If the dtype kind of the converted axis is not ``"i"``.
        """
        tensor = as_tensor_variable(input)
        axis = as_tensor_variable(axis, ndim=0, dtype=int)
        if axis.type.numpy_dtype.kind != "i":
            raise ValueError(
                f"Sort axis must have an integer dtype, got {axis.type.dtype}"
            )
        return Apply(self, [tensor, axis], [tensor.type()])

    def perform(self, node, inputs, output_storage):
        """Sort the runtime array with ``numpy.sort`` and store the result."""
        values, axis = inputs
        sorted_values = np.sort(values, axis=axis, kind=self.kind)
        output_storage[0][0] = sorted_values

    def infer_shape(self, fgraph, node, inputs_shapes):
        """Return the shape of the first input as the output shape."""
        in_ndim = node.inputs[0].ndim
        out_ndim = node.outputs[0].ndim
        assert in_ndim == out_ndim
        # The axis input is expected to be a scalar.
        assert inputs_shapes[1] == ()
        return [inputs_shapes[0]]

    def pullback(self, inputs, outputs, output_grads):
        """Return gradients with respect to the input tensor and the axis.

        The output gradient is indexed with the indices computed by
        ``__get_argsort_indices``; the axis gradient is reported as
        undefined.
        """
        a, axis = inputs
        unsort_index = tuple(self.__get_argsort_indices(a, axis))
        return [
            output_grads[0][unsort_index],
            grad_undefined(
                self,
                1,
                axis,
                "The gradient of sort is not defined "
                "with respect to the integer axes itself",
            ),
        ]

    def __get_expanded_dim(self, a, axis, i):
        """Return ``arange(a.shape[i])`` reshaped to length 1 on every other axis.

        ``axis`` is accepted but not used.
        """
        # it's a way to emulate
        # numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]]
        target_shape = [1] * a.ndim
        target_shape[i] = a.shape[i]
        return arange(a.shape[i]).reshape(target_shape)

    def __get_argsort_indices(self, a, axis):
        """
        Calculates indices which can be used to reverse sorting operation of
        "a" tensor along "axis".

        Reads ``axis.data``, so ``axis`` must carry a concrete value
        (NOTE(review): presumably a constant tensor - confirm).

        Returns
        -------
        list of length a.ndim, one index tensor per dimension of "a"

        """
        # The goal is to get gradient wrt input from gradient
        # wrt sort(input, axis)
        forward = argsort(a, axis, kind=self.kind)
        # argsort of the argsort result gives the reverse mapping
        inverse = argsort(forward, axis, kind=self.kind)

        axis_value = axis.data
        # Map a negative axis onto its non-negative equivalent.
        positive_axis = switch(ge(axis_value, 0), axis_value, a.ndim + axis_value)

        # Along the sorted axis use the reverse mapping; elsewhere use a
        # plain range over that dimension.
        return [
            switch(
                eq(dim, positive_axis),
                inverse,
                self.__get_expanded_dim(a, axis, dim),
            )
            for dim in range(a.ndim)
        ]

    # Disabled implementation, kept for reference:
    #
    # def pushforward(self, inputs, outputs, eval_points):
    #     # pushforward can receive DisconnectedType as eval_points.
    #     # That means there is no differentiable path through that input.
    #     # If this implies that you cannot compute some outputs,
    #     # return disconnected_type() for those.
    #     if isinstance(eval_points[0].type, DisconnectedType):
    #         return list(eval_points)
    #     return self.pullback(inputs, outputs, eval_points)


def sort(
    a, axis=-1, kind: KIND | None = None, order=None, *, stable: bool | None = None
):
    """Return a sorted copy of a tensor.

    Parameters
    ----------
    a: TensorVariable
        Tensor to be sorted. Array-likes are converted with
        ``as_tensor_variable``.
    axis: TensorVariable
        Axis along which to sort. If None, the array is flattened before
        sorting.
    kind: {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional
        Sorting algorithm. Default is 'quicksort' unless stable is defined.
    order: list, optional
        For compatibility with numpy sort signature. Cannot be specified.
    stable: bool, optional
        Same as specifying kind = 'stable'. Cannot be specified at the same
        time as kind

    Returns
    -------
    array
        A sorted copy of an array.

    Raises
    ------
    ValueError
        If ``order`` is given, if ``kind`` and ``stable`` are both given, or
        if ``kind`` is not a supported algorithm name.

    """
    kind = _parse_sort_args(kind, order, stable)
    if axis is None:
        # Convert before flattening so inputs without a ``flatten`` method
        # (e.g. nested lists) are accepted here just as they are for any
        # other axis value.
        a = as_tensor_variable(a).flatten()
        axis = 0
    return SortOp(kind)(a, axis)


class ArgSortOp(Op):
    """
    Op that returns the indices sorting a tensor along one axis, by
    delegating to ``numpy.argsort``.

    The sorting algorithm is stored in ``kind``, which is the Op's only prop.

    """

    __props__ = ("kind",)

    def __init__(self, kind: KIND):
        self.kind = kind

    def make_node(self, input, axis=-1):
        """Create the Apply node that argsorts ``input`` along ``axis``.

        The single output is an int64 tensor with the same static shape as
        the converted input.

        Raises
        ------
        ValueError
            If the dtype kind of the converted axis is not ``"i"``.
        """
        tensor = as_tensor_variable(input)
        axis = as_tensor_variable(axis, ndim=0, dtype=int)
        if axis.type.numpy_dtype.kind != "i":
            raise ValueError(
                f"ArgSort axis must have an integer dtype, got {axis.type.dtype}"
            )
        index_type = TensorType(dtype="int64", shape=tensor.type.shape)
        return Apply(self, [tensor, axis], [index_type()])

    def perform(self, node, inputs, output_storage):
        """Run ``numpy.argsort`` and store the result in the output dtype."""
        values, axis = inputs
        sort_order = np.argsort(values, axis=axis, kind=self.kind)
        # Convert to the dtype declared for the node's output.
        output_storage[0][0] = np.asarray(sort_order, dtype=node.outputs[0].dtype)

    def infer_shape(self, fgraph, node, inputs_shapes):
        """Return the shape of the first input as the output shape."""
        in_ndim = node.inputs[0].ndim
        out_ndim = node.outputs[0].ndim
        assert in_ndim == out_ndim
        # The axis input is expected to be a scalar.
        assert inputs_shapes[1] == ()
        return [inputs_shapes[0]]

    def pullback(self, inputs, outputs, output_grads):
        """Return a zero gradient for the input and an undefined one for the axis."""
        # No grad defined for integers.
        inp, axis = inputs
        return [
            inp.zeros_like(),
            grad_undefined(
                self,
                1,
                axis,
                "argsort is not defined for non-integer axes so"
                " argsort(x, axis+eps) is undefined",
            ),
        ]

    # Disabled implementation, kept for reference:
    #
    # def pushforward(self, inputs, outputs, eval_points):
    #     # pushforward can receive DisconnectedType as eval_points.
    #     # That means there is no differentiable path through that input.
    #     # If this implies that you cannot compute some outputs,
    #     # return disconnected_type() for those.
    #     if isinstance(eval_points[0].type, DisconnectedType):
    #         return list(eval_points)
    #     return self.pullback(inputs, outputs, eval_points)


def argsort(
    a, axis=-1, kind: KIND | None = None, order=None, stable: bool | None = None
):
    """
    Returns the indices that would sort an array.

    Perform an indirect sort along the given axis using the algorithm
    specified by the kind keyword.  It returns an array of indices of
    the same shape as a that index data along the given axis in sorted
    order.

    Parameters
    ----------
    a: TensorVariable
        Tensor to be sorted. Array-likes are converted with
        ``as_tensor_variable``.
    axis: TensorVariable
        Axis along which to sort. If None, the array is flattened before
        sorting, so the result is one-dimensional.
    kind: {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional
        Sorting algorithm. Default is 'quicksort' unless stable is defined.
    order: list, optional
        For compatibility with numpy argsort signature. Cannot be specified.
    stable: bool, optional
        Same as specifying kind = 'stable'. Cannot be specified at the same
        time as kind

    Returns
    -------
    array
        Integer indices that sort the input along the given axis.

    Raises
    ------
    ValueError
        If ``order`` is given, if ``kind`` and ``stable`` are both given, or
        if ``kind`` is not a supported algorithm name.

    """
    kind = _parse_sort_args(kind, order, stable)
    if axis is None:
        # Convert before flattening so inputs without a ``flatten`` method
        # (e.g. nested lists) are accepted here just as they are for any
        # other axis value.
        a = as_tensor_variable(a).flatten()
        axis = 0
    return ArgSortOp(kind)(a, axis)