# Source code for pytensor.tensor.linalg.summary

import warnings
from collections.abc import Callable
from typing import Literal

import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import TensorLike
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.linalg.decomposition.svd import svd
from pytensor.tensor.linalg.dtype_utils import (
    linalg_output_dtype,
    linalg_real_output_dtype,
)
from pytensor.tensor.type import scalar


def trace(X):
    """Return the sum of the diagonal elements of matrix ``X``.

    .. deprecated::
        Use :func:`pytensor.tensor.trace` instead.
    """
    deprecation_msg = "pytensor.tensor.linalg.trace is deprecated. Use pytensor.tensor.trace instead."
    warnings.warn(deprecation_msg, FutureWarning)
    return diagonal(X).sum()


class Det(Op):
    """
    Matrix determinant. Input should be a square matrix.

    """

    __props__ = ()
    gufunc_signature = "(m,m)->()"
    gufunc_spec = ("numpy.linalg.det", 1, 1)

    def make_node(self, x):
        """Validate the input and build an Apply node with a single scalar output.

        Raises
        ------
        ValueError
            If ``x`` is not 2D, or is statically known to be non-square.
        """
        x = as_tensor_variable(x)
        if x.ndim != 2:
            raise ValueError(
                f"Input passed is not a valid 2D matrix. Current ndim {x.ndim} != 2"
            )
        # When both dimensions are statically known, reject non-square inputs early.
        n_rows, n_cols = x.type.shape
        if n_rows is not None and n_cols is not None and n_rows != n_cols:
            raise ValueError(
                f"Determinant not defined for non-square matrix inputs. Shape received is {x.type.shape}"
            )
        out = scalar(dtype=linalg_output_dtype(x.type.dtype))
        return Apply(self, [x], [out])

    def perform(self, node, inputs, outputs):
        """Numerically evaluate the determinant by delegating to numpy."""
        [x] = inputs
        [z] = outputs
        try:
            det_value = np.linalg.det(x)
            z[0] = np.asarray(det_value)
        except Exception as e:
            raise ValueError("Failed to compute determinant", x) from e

    def pullback(self, inputs, outputs, g_outputs):
        """Gradient of the determinant (Jacobi's formula): d det(A)/dA = det(A) * inv(A).T."""
        from pytensor.tensor.linalg.inverse import matrix_inverse

        [x] = inputs
        [gz] = g_outputs
        return [gz * self(x) * matrix_inverse(x).T]

    def infer_shape(self, fgraph, node, shapes):
        # The determinant of a single matrix is always a scalar.
        return [()]

    def __str__(self):
        return "Det"


det = Blockwise(Det())


class SLogDet(Op):
    """
    Compute the log determinant and its sign of the matrix. Input should be a square matrix.
    """

    __props__ = ()
    gufunc_signature = "(m,m)->(),()"
    gufunc_spec = ("numpy.linalg.slogdet", 1, 2)

    def make_node(self, x):
        """Validate the input and build an Apply node with two scalar outputs.

        Parameters
        ----------
        x : tensor_like
            A square 2D matrix (batched inputs are handled by ``Blockwise``).

        Returns
        -------
        Apply
            Node with outputs ``(sign, logabsdet)``.

        Raises
        ------
        ValueError
            If ``x`` is not 2D, or is statically known to be non-square.
        """
        x = as_tensor_variable(x)
        # Raise explicitly instead of `assert`: asserts are stripped under
        # `python -O` and give a less informative error. Mirrors Det.make_node.
        if x.ndim != 2:
            raise ValueError(
                f"Input passed is not a valid 2D matrix. Current ndim {x.ndim} != 2"
            )
        # Reject statically-known non-square shapes, consistent with Det.
        if None not in x.type.shape and (x.type.shape[0] != x.type.shape[1]):
            raise ValueError(
                f"Determinant not defined for non-square matrix inputs. Shape received is {x.type.shape}"
            )
        # sign may be complex for complex inputs; log|det| is always real.
        sign = scalar(dtype=linalg_output_dtype(x.type.dtype))
        det = scalar(dtype=linalg_real_output_dtype(x.type.dtype))
        return Apply(self, [x], [sign, det])

    def perform(self, node, inputs, outputs):
        """Numerically evaluate sign and log|det| by delegating to numpy."""
        (x,) = inputs
        (sign, det) = outputs
        try:
            sign[0], det[0] = (np.array(z) for z in np.linalg.slogdet(x))
        except Exception as e:
            raise ValueError("Failed to compute determinant", x) from e

    def infer_shape(self, fgraph, node, shapes):
        # Both outputs (sign, logabsdet) are scalars.
        return [(), ()]

    def __str__(self):
        return "SLogDet"


[docs] def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]: """ Compute the sign and (natural) logarithm of the determinant of an array. Returns a naive graph which is optimized later using rewrites with the det operation. Parameters ---------- x : (..., M, M) tensor or tensor_like Input tensor, has to be square. Returns ------- A tuple with the following attributes: sign : (...) tensor_like A number representing the sign of the determinant. For a real matrix, this is 1, 0, or -1. logabsdet : (...) tensor_like The natural log of the absolute value of the determinant. If the determinant is zero, then `sign` will be 0 and `logabsdet` will be -inf. In all cases, the determinant is equal to ``sign * exp(logabsdet)``. """ det_val = det(x) return ptm.sign(det_val), ptm.log(ptm.abs(det_val))
def _multi_svd_norm( x: ptb.TensorVariable, row_axis: int, col_axis: int, reduce_op: Callable ): """Compute a function of the singular values of the 2-D matrices in `x`. This is a private utility function used by `pytensor.tensor.linalg.norm()`. Copied from `np.linalg._multi_svd_norm`. Parameters ---------- x : TensorVariable Input tensor. row_axis, col_axis : int The axes of `x` that hold the 2-D matrices. reduce_op : callable Reduction op. Should be one of `pt.min`, `pt.max`, or `pt.sum` Returns ------- result : float or ndarray If `x` is 2-D, the return values is a float. Otherwise, it is an array with ``x.ndim - 2`` dimensions. The return values are either the minimum or maximum or sum of the singular values of the matrices, depending on whether `op` is `pt.amin` or `pt.amax` or `pt.sum`. """ y = ptb.moveaxis(x, (row_axis, col_axis), (-2, -1)) result = reduce_op(svd(y, compute_uv=False), axis=-1) return result VALID_ORD = Literal["fro", "f", "nuc", "inf", "-inf", 0, 1, -1, 2, -2]
def norm(
    x: ptb.TensorVariable,
    ord: float | VALID_ORD | None = None,
    axis: int | tuple[int, ...] | None = None,
    keepdims: bool = False,
):
    """
    Matrix or vector norm.

    Parameters
    ----------
    x: TensorVariable
        Tensor to take norm of.

    ord: float, str or int, optional
        Order of norm. If `ord` is a str, it must be one of the following:
            - 'fro' or 'f' : Frobenius norm
            - 'nuc' : nuclear norm
            - 'inf' : Infinity norm
            - '-inf' : Negative infinity norm
        If an integer, order can be one of -2, -1, 0, 1, or 2.
        Otherwise `ord` must be a float.

        Default is the Frobenius (L2) norm.

    axis: tuple of int, optional
        Axes over which to compute the norm. If None, norm of entire matrix (or vector) is computed. Row or column
        norms can be computed by passing a single integer; this will treat a matrix like a batch of vectors.

    keepdims: bool
        If True, dummy axes will be inserted into the output so that norm.ndim == x.ndim. Default is False.

    Returns
    -------
    TensorVariable
        Norm of `x` along axes specified by `axis`.

    Notes
    -----
    Batched dimensions are supported to the left of the core dimensions. For example, if `x` is a 3D tensor with
    shape (2, 3, 4), then `norm(x)` will compute the norm of each 3x4 matrix in the batch.

    If the input is a 2D tensor and should be treated as a batch of vectors, the `axis` argument must be specified.
    """
    x = ptb.as_tensor_variable(x)

    ndim = x.ndim
    core_ndim = min(2, ndim)
    batch_ndim = ndim - core_ndim

    if axis is None:
        # Handle some common cases first. These can be computed more quickly than the default SVD way, so we always
        # want to check for them.
        if (
            (ord is None)
            or (ord in ("f", "fro") and core_ndim == 2)
            or (ord == 2 and core_ndim == 1)
        ):
            # Flatten the core dims into a column vector so the norm becomes sqrt(x.T @ x).
            x = x.reshape(tuple(x.shape[:-2]) + (-1,) + (1,) * (core_ndim - 1))
            batch_T_dim_order = tuple(range(batch_ndim)) + tuple(
                range(batch_ndim + core_ndim - 1, batch_ndim - 1, -1)
            )

            if x.dtype.startswith("complex"):
                x_real = x.real  # type: ignore
                x_imag = x.imag  # type: ignore
                sqnorm = (
                    ptb.transpose(x_real, batch_T_dim_order) @ x_real
                    + ptb.transpose(x_imag, batch_T_dim_order) @ x_imag
                )
            else:
                sqnorm = ptb.transpose(x, batch_T_dim_order) @ x
            ret = ptm.sqrt(sqnorm).squeeze()
            if keepdims:
                ret = ptb.shape_padright(ret, core_ndim)
            return ret

        # No special computation to exploit -- set default axis before continuing
        axis = tuple(range(core_ndim))

    elif not isinstance(axis, tuple):
        try:
            axis = int(axis)
        except Exception as e:
            raise TypeError(
                "'axis' must be None, an integer, or a tuple of integers"
            ) from e
        axis = (axis,)

    if len(axis) == 1:
        # Vector norms
        if ord in [None, "fro", "f"] and (core_ndim == 2):
            # This is here to catch the case where X is a 2D tensor but the user wants to treat it as a batch of
            # vectors. Other vector norms will work fine in this case.
            ret = ptm.sqrt(ptm.sum((x.conj() * x).real, axis=axis, keepdims=keepdims))

        elif (ord == "inf") or (ord == np.inf):
            ret = ptm.max(ptm.abs(x), axis=axis, keepdims=keepdims)

        elif (ord == "-inf") or (ord == -np.inf):
            ret = ptm.min(ptm.abs(x), axis=axis, keepdims=keepdims)

        elif ord == 0:
            ret = ptm.neq(x, 0).sum(axis=axis, keepdims=keepdims)

        elif ord == 1:
            ret = ptm.sum(ptm.abs(x), axis=axis, keepdims=keepdims)

        elif isinstance(ord, str):
            raise ValueError(f"Invalid norm order '{ord}' for vectors")

        else:
            # General p-norm: (sum |x|^p) ** (1/p)
            ret = ptm.sum(ptm.abs(x) ** ord, axis=axis, keepdims=keepdims)
            ret **= ptm.reciprocal(ord)

        return ret

    elif len(axis) == 2:
        # Matrix norms
        # Loop variable renamed to `a` so it does not shadow the tensor `x`.
        row_axis, col_axis = (
            batch_ndim + a for a in normalize_axis_tuple(axis, core_ndim)
        )
        axis = (row_axis, col_axis)

        if ord in [None, "fro", "f"]:
            ret = ptm.sqrt(ptm.sum((x.conj() * x).real, axis=axis))

        elif (ord == "inf") or (ord == np.inf):
            # After reducing over col_axis, a later row_axis shifts down by one.
            if row_axis > col_axis:
                row_axis -= 1
            ret = ptm.max(ptm.sum(ptm.abs(x), axis=col_axis), axis=row_axis)

        elif (ord == "-inf") or (ord == -np.inf):
            if row_axis > col_axis:
                row_axis -= 1
            ret = ptm.min(ptm.sum(ptm.abs(x), axis=col_axis), axis=row_axis)

        elif ord == 1:
            # After reducing over row_axis, a later col_axis shifts down by one.
            if col_axis > row_axis:
                col_axis -= 1
            ret = ptm.max(ptm.sum(ptm.abs(x), axis=row_axis), axis=col_axis)

        elif ord == -1:
            if col_axis > row_axis:
                col_axis -= 1
            ret = ptm.min(ptm.sum(ptm.abs(x), axis=row_axis), axis=col_axis)

        elif ord == 2:
            # Spectral norm: largest singular value.
            ret = _multi_svd_norm(x, row_axis, col_axis, ptm.max)

        elif ord == -2:
            # Smallest singular value.
            ret = _multi_svd_norm(x, row_axis, col_axis, ptm.min)

        elif ord == "nuc":
            # Nuclear norm: sum of singular values.
            ret = _multi_svd_norm(x, row_axis, col_axis, ptm.sum)

        else:
            raise ValueError(f"Invalid norm order for matrices: {ord}")

        if keepdims:
            ret = ptb.expand_dims(ret, axis)

        return ret
    else:
        raise ValueError(
            f"Cannot compute norm when core_dims < 1 or core_dims > 3, found: core_dims = {core_ndim}"
        )