import torch
import math
import numpy as np
from abc import abstractclassmethod, abstractmethod
from symo.utils import to_dtype

# Type alias for array-like inputs, used in the `Factor.matvec` signature.
NDArray = torch.Tensor | np.ndarray


def to_partition(eq):
    """Group the positions of an index string by character.

    Args:
        eq (str): Index string, e.g. ``"ijij"``.

    Returns:
        dict[str, list[int]]: Maps each distinct character (in order of first
            appearance) to the ascending list of positions where it occurs.

    Example:
        >>> to_partition("ijai")
        {'i': [0, 3], 'j': [1], 'a': [2]}
    """
    groups = {}
    for position, char in enumerate(eq):
        groups.setdefault(char, []).append(position)
    return groups


def merge_partition(par1, par2):
    """Coarsen partition `par1` so that positions grouped in `par2` share a class.

    Both partitions map a label to a list of positions. For every class in
    `par2` with more than one member, the `par1` classes owning those positions
    are merged into the class owning the first member. Merging appends the
    absorbed positions and deletes the absorbed label.

    Note: `par1` is modified in place and also returned.

    Args:
        par1 (dict[str, list[int]]): Partition to coarsen (mutated).
        par2 (dict[str, list[int]]): Partition whose classes must be respected.

    Returns:
        dict[str, list[int]]: The (mutated) `par1`.
    """

    def owner(position):
        # Label of the par1 class holding `position`, or None if absent.
        return next(
            (label for label, members in par1.items() if position in members),
            None,
        )

    for group in par2.values():
        if len(group) <= 1:
            continue
        head, *tail = group
        root = owner(head)
        for position in tail:
            other = owner(position)
            if other == root:
                continue  # already in the same class
            par1[root].extend(par1[other])
            del par1[other]
    return par1


def to_einsum(par, n):
    """Render a partition back into an index string of length `n`.

    Every position listed under a label receives that label's character.

    Args:
        par (dict[str, list[int]]): Partition mapping label -> positions.
        n (int): Length of the resulting string.

    Returns:
        str: The index string. Positions not covered by `par` leave a
            non-string placeholder behind, which makes ``str.join`` raise
            TypeError.

    Example:
        >>> to_einsum({"a": [0, 2], "b": [1]}, 3)
        'aba'
    """
    slots = [0 for _ in range(n)]
    for label in par:
        for position in par[label]:
            slots[position] = label
    return "".join(slots)


def remove(eq, to_remove):
    """Return `eq` with every occurrence of each item of `to_remove` deleted.

    Args:
        eq (str): Source string.
        to_remove (Iterable[str]): Characters (or substrings) to delete;
            a plain string is treated as its individual characters.

    Returns:
        str: The filtered string.
    """
    result = eq
    for token in to_remove:
        if token in result:
            result = result.replace(token, "")
    return result


def dim_prod(char_lst, left, right, dims):
    """Computes the product of dimensions associated with a list of characters.

    Each character of `char_lst` is located in the concatenated index string
    ``left + right``. The size at its FIRST occurrence in the concatenated
    dimension list ``[*dims[0], *dims[1]]`` contributes one factor to the
    product. A character listed twice in `char_lst` contributes twice.

    Args:
        char_lst (Iterable[str]): Characters whose associated dimensions
            should be multiplied.
        left (str): Left-hand index string (e.g., indices for the first tensor).
        right (str): Right-hand index string (e.g., indices for the second tensor).
        dims (list[list[int]]):
            Dimension sizes corresponding to each index occurrence in the left
            and right strings (``dims[0]`` aligns with `left`, ``dims[1]`` with
            `right`).

    Returns:
        The product computed by ``np.prod``. For a non-empty `char_lst` this
        is a NumPy integer; an empty `char_lst` yields ``1.0``.

    Raises:
        ValueError: If a character in `char_lst` is not found in either
            `left` or `right` (raised by ``str.index``).

    Example:
        >>> int(dim_prod(['i', 'k'], "ij", "kl", [[2, 3], [5, 7]]))  # 2 * 5
        10
    """
    index_str = left + right
    flat_dims = [*dims[0], *dims[1]]
    return np.prod([flat_dims[index_str.index(char)] for char in char_lst])


def term_norm(term, dims):
    """Return the normalization constant of one invariance term.

    The term is ``(left, right, factor)``. The factor indices are stripped from
    ``left + right``. Only characters that then occur exactly once contribute.
    The result is the square root of the product of their dimension sizes, as
    computed by `dim_prod`.

    Args:
        term (tuple[str, str, str]): Left, right and factor index strings.
        dims (list[list[int]]): Sizes aligned with the left and right strings.

    Returns:
        float: ``sqrt(prod(sizes of once-occurring characters))``.

    Example:
        >>> term_norm(("ij", "jk", "j"), [[2, 3], [3, 4]])  # sqrt(2 * 4)
        2.8284271247461903
    """
    left, right, factor_eq = term[0], term[1], term[2]
    stripped = remove(left + right, factor_eq)

    # keep only indices that are not shared after removing the factor indices
    singles = set()
    for char in stripped:
        if stripped.count(char) == 1:
            singles.add(char)

    return math.sqrt(dim_prod(singles, left, right, dims))


def dot_prod(term1, term2, dims):
    """Return the normalized inner product between two invariance terms.

    Steps:
      1. Compute `term_norm` for both terms.
      2. Strip the factor indices of `term1` (``term1[2]``) from the
         concatenated left+right index strings of both terms.
      3. Merge the position classes of the two stripped strings
         (`merge_partition`) and render the result with `to_einsum`.
      4. Multiply the sizes of the distinct merged characters (`dim_prod`,
         looked up in ``term1``'s index strings) and divide by the product of
         the two norms.

    Args:
        term1 (tuple[str, str, str]): First term ``(left, right, factor)``.
        term2 (tuple[str, str, str]): Second term ``(left, right, factor)``.
        dims (list[list[int]]): Sizes aligned with the left and right strings,
            shared by both terms.

    Returns:
        The normalized inner product (a NumPy floating-point scalar).

    Raises:
        Errors from the helpers propagate unchanged. Examples include
        ValueError from `dim_prod`, and KeyError from `merge_partition` when the
        stripped index strings of the two terms do not line up.

    Example:
        >>> dot_prod(("ij", "jk", "j"), ("ij", "jk", "j"), [[2, 3], [3, 4]])
        1.0

        The self product need not be 1.0 in general. For instance it is 6.0
        for ``("ij", "ij", "")`` with dims ``[[2, 3], [2, 3]]``.
    """
    left1, right1, factor_eq = term1[0], term1[1], term1[2]

    norm_product = term_norm(term1, dims) * term_norm(term2, dims)

    stripped1 = remove(left1 + right1, factor_eq)
    stripped2 = remove(term2[0] + term2[1], factor_eq)

    classes = merge_partition(to_partition(stripped1), to_partition(stripped2))
    merged = to_einsum(classes, len(stripped1))

    return dim_prod(set(merged), left1, right1, dims) / norm_product


def generalized_einsum_compile(factor_eq, v_eq, out_eq, out_size, device):
    """
    Precompute what `generalized_einsum` needs for a factor-times-vector product.

    Output indices that occur in neither `v_eq` nor `factor_eq` are produced
    by broadcasting (size 1 before ``expand``). Output indices repeated in
    `out_eq` are kept only once in the einsum and restored afterwards by
    multiplying with generalized identity (diagonal) tensors.

    Args:
        factor_eq (str): Indices of the factor tensor in einsum notation.
        v_eq (str): Indices of `v`.
        out_eq (str): Indices of the output.
        out_size (list[int]): Size of each axis in the output tensor.
        device (torch.device or str): Device the identity tensors are moved to.

    Returns:
        reduced_out_eq (str): Distinct output indices that appear in `v_eq` or
            `factor_eq`, in order of first appearance in `out_eq`.
        axes (list[int]): One entry per `out_eq` position. This is the axis size
            for the first occurrence of an index kept in `reduced_out_eq`, and
            1 everywhere else.
        identities (list[torch.Tensor]): For every index repeated in
            `out_eq`, a diagonal tensor of ones shaped to broadcast against the
            full output (size 1 on the other axes).
    """
    contracted = set(v_eq + factor_eq)

    occurrences = {}
    reduced_out_eq = ""
    axes = []
    for position, char in enumerate(out_eq):
        occurrences.setdefault(char, []).append(position)
        keep = char in contracted and char not in reduced_out_eq
        if keep:
            reduced_out_eq += char
            axes.append(out_size[position])
        else:
            axes.append(1)

    ndim = len(out_eq)
    identities = []
    for positions in occurrences.values():
        if len(positions) < 2:
            continue
        size = out_size[positions[0]]
        diagonal = torch.zeros([size] * len(positions)).fill_diagonal_(1)
        broadcast_shape = [size if axis in positions else 1 for axis in range(ndim)]
        identities.append(diagonal.view(broadcast_shape).to(device))

    return reduced_out_eq, axes, identities


def generalized_einsum(
    factor, v, factor_eq, v_eq, out_size, reduced_out_eq, axes, identities
):
    """Performs a generalized einsum contraction specifically used for matrix (factor)
    vector product.

    If `v` has exactly one more dimension than `v_eq`, its leading axis is
    treated as a batch axis and carried through to the output. The batch axis
    gets an einsum letter that is not already used by `factor_eq`, `v_eq` or
    `reduced_out_eq`; ``"z"`` is preferred when it is free.

    Args:
        factor (torch.Tensor): The factor scalar/ tensor.
        v (torch.Tensor): The primary tensor for the einsum contraction.
        factor_eq (str): Einsum equation for `factor`. Empty string indicates a
            scalar factor.
        v_eq (str): Einsum equation for `v`.
        out_size (list[int]): Final desired output tensor shape.
        reduced_out_eq (str): Einsum equation for the reduced output after contraction.
        axes (list[int]): Shape used to reshape the reduced output before
            applying identity multipliers.
        identities (list[torch.Tensor]): Identity tensors broadcast-multiplied
            into the reshaped output.

    Returns:
        torch.Tensor: The broadcasted output tensor with shape `out_size`
            (prefixed by the batch size when `v` is batched).

    Raises:
        ValueError: If `v` has more than one extra (batch) dimension, or if no
            free einsum letter is left for the batch axis.
    """
    # handle batch
    extra_axes = v.dim() - len(v_eq)
    if extra_axes > 0:
        if extra_axes > 1:
            raise ValueError("only considering one batch axis")
        # A fixed "z" would collide with an index already named "z" and
        # produce an invalid equation such as "zz->zz".
        used = set(factor_eq) | set(v_eq) | set(reduced_out_eq)
        batch = next(
            (
                c
                for c in "zyxwvutsrqponmlkjihgfedcbaZYXWVUTSRQPONMLKJIHGFEDCBA"
                if c not in used
            ),
            None,
        )
        if batch is None:
            raise ValueError("no free einsum letter left for the batch axis")
        v_eq = batch + v_eq
        reduced_out_eq = batch + reduced_out_eq
        axes = [v.size(0), *axes]
        out_size = [v.size(0), *out_size]

    if factor_eq == "":  # scalar factor
        red = factor * torch.einsum(
            v_eq + "->" + reduced_out_eq,
            v,
        )
    else:  # non-scalar factor
        red = torch.einsum(
            factor_eq + "," + v_eq + "->" + reduced_out_eq,
            factor,
            v,
        )
    # broadcast to full out_eq: axes only inserts size-1 dims, so view is valid
    out = red.view(axes)
    for iden in identities:
        out = out * iden
    return out.expand(out_size)


def make_block(dims, device=None):
    """Create a reshaped identity tensor corresponding to right-hand dimensions.

    Args:
        dims (tuple[list[int], list[int]]): Left and right tensor dimensions of weights.
            Only ``dims[1]`` is used.
        device (torch.device or str, optional): Device of the returned tensor.

    Returns:
        torch.Tensor:
            A tensor of shape `(prod(right_dims), *right_dims)` created by reshaping
            an identity matrix of size `prod(right_dims) x prod(right_dims)`.
            For empty right dims this is the 1-element tensor ``[1.]``.
    """
    # np.prod of an empty list is the float 1.0, which torch.eye / reshape
    # reject; cast to int so empty right dims are handled too.
    right_mul = int(np.prod(dims[1]))

    # Identity matrix, reshaped into (right_mul, *right_dims)
    eye = torch.eye(right_mul, device=device).reshape(right_mul, *dims[1])
    return eye


def svd_inv_symm(x):
    """Compute the (Moore-Penrose pseudo-)inverse of a real square matrix via SVD.

    The SVD is computed in float64. Singular values at or below
    ``N * eps(float64) * max(s)`` are treated as zero; this is the same default
    relative tolerance as ``torch.linalg.pinv``. Rank-deficient inputs, such as
    the Gram matrix of linearly dependent invariance terms, therefore get a
    finite pseudoinverse instead of inf/huge entries. Using ``Vh`` (rather than
    reusing ``U``) keeps the result correct for symmetric indefinite matrices.
    Leading batch dimensions are supported.

    Args:
        x (torch.Tensor): A real square matrix of shape (..., N, N), typically
            symmetric.

    Returns:
        torch.Tensor: The pseudoinverse of `x` with the same shape as the input,
            in ``torch.get_default_dtype()``.
    """
    x64 = x.to(torch.float64)
    if x64.size(-1) == 0:  # nothing to invert; avoids reducing over an empty dim
        return x64.to(torch.get_default_dtype())

    u, s, vh = torch.linalg.svd(x64)
    cutoff = s.size(-1) * torch.finfo(torch.float64).eps * s.amax(dim=-1, keepdim=True)
    s_inv = torch.where(s > cutoff, s.reciprocal(), torch.zeros_like(s))

    # pinv = V diag(1/s) U^T ; unsqueeze(-2) scales the columns of V
    inv = (vh.mT * s_inv.unsqueeze(-2)) @ u.mT
    return inv.to(torch.get_default_dtype())


class Factor(torch.nn.Module):
    """Base module that stores a tensor of factor weights as a buffer.

    Subclasses supply the estimators (`outer_estimate`, `cov_estimate`) and
    the operators (`cov`, `matvec`). This base class provides alternate
    constructors and in-place update helpers built on top of them.
    """

    weights: torch.Tensor

    def __init__(self, weights):
        super().__init__()
        # a buffer rather than a Parameter
        self.register_buffer("weights", weights)

    @classmethod
    def from_param(cls, params: tuple[torch.Tensor, torch.Tensor]):
        """Build an instance whose weights come from `outer_estimate(params)`."""
        return cls(cls.outer_estimate(params))

    @classmethod
    def from_cov(cls, cov: torch.Tensor):
        """Build an instance whose weights come from `cov_estimate(cov)`."""
        return cls(cls.cov_estimate(cov))

    @classmethod
    @abstractmethod
    def outer_estimate(
        cls, vectors: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """Estimate weights from a pair of tensors."""

    @classmethod
    @abstractmethod
    def cov_estimate(
        cls, cov: torch.Tensor, surrogate: bool = False
    ) -> torch.Tensor:
        """Estimate weights from a covariance matrix."""

    @abstractmethod
    def cov(self, surrogate: bool = False) -> torch.Tensor:
        """Return the matrix represented by the current weights."""

    @abstractmethod
    def matvec(
        self, vec: NDArray, surrogate: bool = False, transpose: bool = False
    ) -> torch.Tensor:
        """Apply the represented matrix (or its transpose) to `vec`."""

    def outer_estimate_(self, lhs, rhs):
        """Re-estimate from ``(lhs, rhs)`` and overwrite `weights` in place."""
        self.weights.copy_(self.outer_estimate((lhs, rhs)))

    def cov_estimate_(self, cov, surrogate: bool = False):
        """Re-estimate from `cov` and overwrite `weights` in place."""
        estimate = self.cov_estimate(cov, surrogate=surrogate)
        self.weights.copy_(estimate)


class Compiler(Factor):
    """
    Compiles an invariance specification for a specific set of tensor parameter dimensions.

    This class precomputes matrices, inverses, norms, auxiliary tensors and all quantities
    needed to compute the symo preconditioners.

    Attributes:
        inv (list): List of invariance terms. Each term is a three-element tuple
            ``(left_eq, right_eq, factor_eq)``. In the non-transposed `matvec`,
            ``left_eq`` are the output indices and ``right_eq`` the indices of the
            input vector; the roles swap for ``transpose=True``.
        dims (tuple[list[int], list[int]]): Left and right tensor dimensions of weights.
        surr_dims (tuple[list[int], list[int]]): Left and right tensor dimensions of surrogate weights.
        device (torch.device or str)
        coeffs (torch.Tensor): Inverse (via `svd_inv_symm`) of the full design matrix.
        surr_coeffs (torch.Tensor): Inverse (via `svd_inv_symm`) of the surrogate design matrix.
        norm (list[float]): Per-term normalization constants (`term_norm`) for `dims`.
        surr_norm (list[float]): Per-term normalization constants for `surr_dims`.
        einsum_info (list): Precomputed einsum info for full terms.
        einsum_info_trans (list): Precomputed einsum info for transposed full terms.
        surr_einsum_info (list): Precomputed einsum info for surrogate terms.
        surr_einsum_info_trans (list): Precomputed einsum info for transposed surrogate terms.
        surr_eye_block (torch.Tensor): Identity block tensor for surrogate dimensions.
    """

    def __init__(self, inv, dims, surr_dims, device):
        """
        Args:
            inv (list): List of invariance terms. Each term is a three-element tuple
                ``(left_eq, right_eq, factor_eq)``.
            dims (tuple[list[int], list[int]]): Left and right tensor dimensions of weights.
            surr_dims (tuple[list[int], list[int]]): Left and right tensor dimensions of surrogate weights.
            device (torch.device or str)
        """
        weights = self.from_init_fn(inv, dims)

        super().__init__(weights)

        self.inv = inv
        self.surr_dims = surr_dims
        self.dims = dims
        self.device = device
        self.norm = [term_norm(term, dims) for term in self.inv]
        self.surr_norm = [term_norm(term, surr_dims) for term in self.inv]

        self.einsum_info = [
            generalized_einsum_compile(term[2], term[1], term[0], dims[0], device)
            for term in self.inv
        ]
        self.einsum_info_trans = [
            generalized_einsum_compile(term[2], term[0], term[1], dims[1], device)
            for term in self.inv
        ]
        self.surr_einsum_info = [
            generalized_einsum_compile(term[2], term[1], term[0], surr_dims[0], device)
            for term in self.inv
        ]
        # Needed by matvec(surrogate=True, transpose=True): the full-dims
        # transposed info has axis sizes from dims[1], not surr_dims[1].
        self.surr_einsum_info_trans = [
            generalized_einsum_compile(term[2], term[0], term[1], surr_dims[1], device)
            for term in self.inv
        ]

        coeffs = torch.asarray(
            [[dot_prod(t1, t2, dims) for t2 in self.inv] for t1 in self.inv],
        )

        surr_coeffs = torch.asarray(
            [[dot_prod(t1, t2, surr_dims) for t2 in self.inv] for t1 in self.inv],
        )

        coeffs = svd_inv_symm(coeffs)
        surr_coeffs = svd_inv_symm(surr_coeffs)
        surr_eye_block = make_block(self.surr_dims, self.device)

        self.register_buffer("surr_eye_block", surr_eye_block, persistent=False)
        self.register_buffer("coeffs", coeffs, persistent=False)
        self.register_buffer("surr_coeffs", surr_coeffs, persistent=False)

    def from_init_fn(self, inv, dims):
        """
        Initialize a zero factor tensor for the invariance terms.

        The factor shape is derived from the first term of `inv` only, so all
        terms are assumed to share the same `factor_eq`. If `factor_eq` is empty,
        each factor is a scalar. Otherwise the size of each factor index is
        the size at its first occurrence in ``left_eq + right_eq``, looked up
        in ``[*dims[0], *dims[1]]``.

        Args:
            inv (list): Invariance terms ``(left_eq, right_eq, factor_eq)``.
            dims (tuple[list[int], list[int]]): Left and right tensor dimensions.

        Returns:
            torch.Tensor: A tensor of shape `(len(inv), *factor_shape)` filled with zeros.
        """
        left_eq, right_eq, factor_eq = inv[0]
        if factor_eq == "":
            factor_shape = ()
        else:
            eq = left_eq + right_eq
            shape = [*dims[0], *dims[1]]
            factor_shape = tuple(shape[eq.index(c)] for c in factor_eq)

        return torch.stack([torch.zeros(factor_shape) for _ in inv])

    def cov_estimate(self, cov, surrogate=False):
        """
        Estimate factor tensors from the covariance matrix.

        This method computes the right-hand side `b` of the linear system using
        the invariance terms and norms, then multiplies by the precomputed
        inverse (`coeffs` or `surr_coeffs`) to estimate the factor tensors.

        Args:
            cov (torch.Tensor): The covariance matrix from which factors are
                estimated. It must hold ``prod(dims[0]) * prod(dims[1])``
                elements (surrogate dims when `surrogate` is True).
            surrogate (bool, optional): If True, use surrogate matrices and norms
                instead of full ones. Defaults to False.

        Returns:
            torch.Tensor: Estimated factors of shape `(len(self.inv), *factor_shape)`.
        """
        coeffs: torch.Tensor = self.surr_coeffs if surrogate else self.coeffs
        norm = self.surr_norm if surrogate else self.norm
        dims = self.surr_dims if surrogate else self.dims

        # reshape (not view) so non-contiguous inputs such as a transposed
        # covariance are accepted; contiguous inputs are still not copied.
        cov_reshaped = cov.reshape(*dims[0], *dims[1])
        rhs = [
            torch.einsum(einsum_expr(*term), cov_reshaped) / scale
            for term, scale in zip(self.inv, norm)
        ]
        b = torch.stack(rhs)

        # x = C^-1 b, solved independently for every factor entry
        k = b.size(0)
        factor_shape = b.size()[1:]

        factors = coeffs @ b.view(k, -1)  # k x k, k x m -> k x m
        factors = factors.view(k, *factor_shape)
        return factors

    def outer_estimate(self, params):
        """
        Estimate factor tensors from a pair of tensors (e.g. gradients).

        Same procedure as `cov_estimate`. Each term's right-hand side is
        ``einsum("left,right->factor", *params) / norm``. Only the full
        (non-surrogate) `coeffs` and `norm` are used.

        Args:
            params (tuple[torch.Tensor, torch.Tensor]): The two operands
                contracted against each invariance term.

        Returns:
            torch.Tensor: Estimated factors of shape `(len(self.inv), *factor_shape)`.
        """
        # form right-hand side, b
        b = torch.stack(
            [
                torch.einsum(contract_expr(*term), *params) / scale
                for term, scale in zip(self.inv, self.norm)
            ],
        )

        k = b.size(0)
        factor_shape = b.size()[1:]
        factors = self.coeffs @ b.view(k, -1)  # k x k, k x m -> k x m
        factors = factors.view(k, *factor_shape)
        return factors

    def matvec(self, v, surrogate=False, transpose=False):
        """
        Perform a matrix-vector multiplication between the factors and the vector.

        This method sums the generalized einsum contractions of every term,
        each divided by its norm. It supports the forward and the transposed
        product, each with full or surrogate dimensions.

        Args:
            v (torch.Tensor): Vector or tensor to be multiplied. It may carry one
                extra leading batch axis.
            surrogate (bool, optional): If True, use surrogate dimensions, norms
                and precomputed einsum info. Defaults to False.
            transpose (bool, optional): If True, perform the transposed operation.
                Defaults to False.

        Returns:
            torch.Tensor: Result of the matrix-vector multiplication.
        """
        norm = self.surr_norm if surrogate else self.norm
        dims = self.surr_dims if surrogate else self.dims

        if transpose:
            einsum_info = (
                self.surr_einsum_info_trans if surrogate else self.einsum_info_trans
            )
            v_slot, out_size = 0, dims[1]  # v indexed by left_eq, output by right_eq
        else:
            einsum_info = self.surr_einsum_info if surrogate else self.einsum_info
            v_slot, out_size = 1, dims[0]  # v indexed by right_eq, output by left_eq

        return sum(
            [
                generalized_einsum(
                    factor,
                    v,
                    term[2],
                    term[v_slot],
                    out_size,
                    *info,
                )
                / scale
                for term, factor, info, scale in zip(
                    self.inv, self.weights, einsum_info, norm
                )
            ]
        )

    def cov(self, surrogate=True):
        """Materialize the matrix represented by the current factors.

        The matrix is obtained by applying `matvec` to an identity block
        (`make_block`) over the right dimensions.

        Args:
            surrogate (bool, optional): If True (default), use surrogate
                dimensions; otherwise use the full dimensions.

        Returns:
            torch.Tensor:
                A 2-D tensor of shape `(prod(left_dims), prod(right_dims))`.
        """
        device = self.coeffs.device
        dims = self.surr_dims if surrogate else self.dims
        eye_block = self.surr_eye_block if surrogate else make_block(self.dims, device)
        # surrogate = False will only be called during testing
        # hence not saving eye_block once and for all for full dimension

        left_mul = int(np.prod(dims[0]))
        right_mul = int(np.prod(dims[1]))

        res = self.matvec(eye_block, surrogate=surrogate)  # right_mul, *left_dims

        # reshape to (right_mul, left_mul) then transpose
        res = res.reshape(right_mul, left_mul).T
        return res


def einsum_expr(a: str, b: str, c: str) -> str:
    """Build a single-operand einsum equation ``"<a><b>-><c>"``."""
    return "{}{}->{}".format(a, b, c)


def contract_expr(a: str, b: str, c: str) -> str:
    """Build a two-operand einsum equation ``"<a>,<b>-><c>"``."""
    return "{},{}->{}".format(a, b, c)
