import operator
from typing import Union

import torch

from discrete_trpl.sparse.densify import densify
from discrete_trpl.sparse.broadcasts import maybe_broadcast


class DefaultSparse:
    """
    Sparse COO tensor wrapper with a per-batch "default" value for all unstored
    entries along a chosen axis.

    Concept
    -------
    A `DefaultSparse` represents a sparse tensor `x` together with an implicit
    fill value. The sparse tensor stores only a subset of entries along the
    `dim` axis, while all other entries are implicitly assigned the given
    `fill_value`.

    This lets you work with very large tensors where most entries of each row
    share the same constant default, without having to explicitly store or
    materialize them.

    Parameters
    ----------
    x : torch.Tensor
        Sparse COO tensor (must satisfy `x.is_sparse`). It is coalesced on
        construction.
    fill_value : float | int | torch.Tensor
        The default value for all implicit (unstored) entries along `dim`.
        May be a scalar, or a dense tensor broadcastable to the "batch shape":
        `x.shape[:dim] + x.shape[dim+1:]`.
    dim : int, default=-1
        Axis along which `fill_value` is applied. All other axes are batch axes.

    Raises
    ------
    TypeError
        If `x` is not a sparse tensor.

    Notes
    -----
    - The sparsity pattern is assumed fixed across arithmetic with other
      `DefaultSparse` instances (same indices, same shape, same `dim`).
      Shape and nnz are checked; the indices themselves are not compared.
    - Arithmetic with scalars and per-batch dense tensors is supported
      (addition, subtraction, multiplication, division, negation, including
      reflected forms such as `1 + ds`, `2 * ds`, `1 - ds`). Division between
      two `DefaultSparse` objects is supported as well. Operations produce new
      `DefaultSparse` objects with updated explicit values and defaults.
    - `.densify()` returns a dense tensor where implicit entries are filled
      with the broadcasted `fill_value`.
    - `.fill_value_tensor` returns the broadcasted tensor of per-batch defaults.
    - `.fill_value_scalar` returns a 0-dim tensor if the default is constant,
      otherwise None.

    This class does not subclass `torch.Tensor`; it wraps a sparse tensor and
    carries the `fill_value` metadata explicitly.
    """

    def __init__(self, x: torch.Tensor, fill_value: Union[float, int, torch.Tensor], dim: int = -1):
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not x.is_sparse:
            raise TypeError("x must be a sparse tensor")
        self.x = x.coalesce()
        self._fill_value = fill_value
        self.dim = dim

    @property
    def device(self): return self.x.device

    @property
    def dtype(self):  return self.x.dtype

    @property
    def shape(self):  return self.x.size()

    @property
    def _axis(self) -> int:
        """`self.dim` normalized to a non-negative axis index."""
        return self.dim if self.dim >= 0 else self.x.dim() + self.dim

    def coalesce(self):
        """Coalesce the wrapped sparse tensor in place and return `self`."""
        self.x = self.x.coalesce()
        return self

    def to(self, *args, **kwargs):
        """Return a copy moved/cast via `Tensor.to`; the fill value follows the new device/dtype."""
        coo = self.x.to(*args, **kwargs)
        d = torch.as_tensor(self._fill_value, device=coo.device, dtype=coo.dtype)
        return self.__class__(coo, d, self.dim)

    @property
    def batch_shape(self):
        """Shape of `x` with the `dim` axis removed."""
        sizes = self.shape
        dim = self._axis
        return sizes[:dim] + sizes[dim + 1:]

    def densify(self) -> torch.Tensor:
        """Materialize a dense tensor with implicit entries set to the fill value."""
        return densify(self.x, self._fill_value, self.dim)

    @property
    def fill_value_scalar(self):
        """
        Return a 0-dim tensor holding the fill value if it is constant across
        the batch (up to `torch.allclose`, NaNs compared equal for real dtypes);
        otherwise None. Returns None for an empty batch.
        """
        d = self.fill_value_tensor
        if d.numel() == 0:
            # Nothing to report; indexing element 0 below would raise.
            return None
        if d.numel() == 1:
            return d.detach().clone().reshape(())
        first = d.reshape(-1)[0]
        # `equal_nan` is not supported by isclose for complex inputs.
        if torch.allclose(d, first.expand_as(d), equal_nan=not d.is_complex()):
            return first.detach().clone()
        return None

    @property
    def fill_value_tensor(self) -> torch.Tensor:
        """The fill value broadcast to `batch_shape` on this object's device and dtype."""
        return maybe_broadcast(self._fill_value, target_shape=self.batch_shape, device=self.device,
                               dtype=self.dtype)

    def clone(self):
        """Deep copy of both the sparse tensor and the fill value."""
        return self.__class__(self.x.clone(),
                              fill_value=torch.as_tensor(self._fill_value, device=self.device,
                                                         dtype=self.dtype).clone(),
                              dim=self.dim)

    def detach(self):
        """Copy detached from the autograd graph (sparse tensor and fill value)."""
        return self.__class__(self.x.detach(),
                              fill_value=torch.as_tensor(self._fill_value, device=self.device,
                                                         dtype=self.dtype).detach(),
                              dim=self.dim)

    def exp(self) -> "DefaultSparse":
        """
        Elementwise exp of explicit values and of the fill value
        (e.g. log-probabilities -> probabilities).
        """
        y = torch.sparse_coo_tensor(self.x.indices(),
                                    torch.exp(self.x.values()),
                                    self.x.size(),
                                    device=self.device,
                                    dtype=self.dtype)

        new_d = self.fill_value_tensor.exp()
        return DefaultSparse(y, new_d, self.dim)

    def dot(self, other: "DefaultSparse") -> torch.Tensor:
        """
        Rowwise dot product between two DefaultSparse objects with identical sparsity patterns.
        Only supports 2D inputs and reduction along the last axis (dim=-1, or equivalently
        dim=1 for 2D).

        Raises
        ------
        NotImplementedError
            If either operand's `dim` is not the last axis, or either is not 2D.
        ValueError
            If the shapes differ.
        """
        from discrete_trpl.sparse.dot import sparse_dot_product_with_defaults
        # Compare normalized axes so dim=1 on a 2D tensor is accepted like dim=-1.
        if self._axis != self.x.dim() - 1 or other._axis != other.x.dim() - 1:
            raise NotImplementedError("Only dim=-1 supported for dot product")
        if self.shape != other.shape:
            raise ValueError("Shape mismatch")
        if self.x.ndim != 2 or other.x.ndim != 2:
            raise NotImplementedError("Only 2D tensors supported for dot product")
        return sparse_dot_product_with_defaults(self.x,
                                                other.x,
                                                x_default=self.fill_value_tensor,
                                                y_default=other.fill_value_tensor)

    def nan_to_num(
            self,
            nan: float = 0.0,
            posinf: float = None,
            neginf: float = None,
    ) -> "DefaultSparse":
        """
        Replace NaN, +inf, -inf in both explicit values and fill defaults,
        without densifying.

        Parameters
        ----------
        nan : float, default=0.0
            Value to replace NaNs with.
        posinf : float, optional
            Value to replace +inf with. If None, uses the largest finite value
            representable by self.dtype.
        neginf : float, optional
            Value to replace -inf with. If None, uses the smallest finite value
            representable by self.dtype.
        """
        # explicit values
        new_vals = torch.nan_to_num(self.x.values(), nan=nan, posinf=posinf, neginf=neginf)
        new_x = torch.sparse_coo_tensor(self.x.indices(), new_vals, self.x.shape)

        # fill defaults (tensor or scalar)
        fv = self._fill_value
        if isinstance(fv, torch.Tensor):
            new_fv = torch.nan_to_num(fv, nan=nan, posinf=posinf, neginf=neginf)
        else:
            # Python scalar — wrap, run nan_to_num, unwrap
            new_fv = torch.nan_to_num(torch.as_tensor(fv, dtype=self.dtype, device=self.device),
                                      nan=nan, posinf=posinf, neginf=neginf).item()

        return DefaultSparse(new_x, new_fv, self.dim)

    def _drop_reduced_axis(self, t: torch.Tensor) -> torch.Tensor:
        """Squeeze the `dim` axis of a full-rank tensor whose size there is 1 (keepdim layout)."""
        if t.dim() == self.x.dim() and t.shape[self._axis] == 1:
            return t.squeeze(self._axis)
        return t

    def where(self, cond: torch.Tensor, other: Union[float, int, torch.Tensor]) -> "DefaultSparse":
        """
        Sparse-aware analogue of torch.where(cond, self, other).

        Parameters
        ----------
        cond : torch.Tensor
            Boolean mask broadcastable to self.batch_shape. A full-rank mask
            with size 1 along `dim` (keepdim layout) is also accepted.
        other : scalar | tensor
            The "else" branch, broadcastable to self.batch_shape (the same
            keepdim layout is accepted). It is applied per batch element to
            both explicit entries and the default.

        Returns
        -------
        DefaultSparse

        Raises
        ------
        TypeError
            If `cond` is not boolean.
        """
        if cond.dtype != torch.bool:
            raise TypeError("cond must be boolean")

        # Only squeeze `dim` when the mask really carries that axis; an
        # unconditional squeeze could remove a size-1 *batch* axis instead.
        cond_b = maybe_broadcast(self._drop_reduced_axis(cond), self.batch_shape, self.device, torch.bool)
        if torch.is_tensor(other):
            other = self._drop_reduced_axis(other)
        other_b = maybe_broadcast(other, self.batch_shape, self.device, self.dtype)

        # Explicit entries: look up the per-batch mask and "else" value at each
        # stored entry's batch coordinates so they align with values().
        new_vals = torch.where(self._broadcast_to_values(cond_b),
                               self.x.values(),
                               self._broadcast_to_values(other_b))
        new_x = torch.sparse_coo_tensor(self.x.indices(), new_vals, self.shape)

        # Defaults
        new_fill = torch.where(cond_b, self.fill_value_tensor, other_b)

        return DefaultSparse(new_x, new_fill, self.dim)

    ### Define scalar / per-batch tensor / DefaultSparse operations, i.e., +-*/
    def _binary_op(self, other, op):
        """
        Shared implementation of `self <op> other` for elementwise `op`
        (called with `operator.add` / `operator.mul`).

        `other` may be a DefaultSparse (same pattern), a Python int/float,
        a 0-dim tensor, or a tensor broadcastable to `batch_shape` (keepdim
        layout accepted). Returns NotImplemented for other types.
        """
        if isinstance(other, DefaultSparse):
            return self._apply_ds(other, op)
        if isinstance(other, (int, float)) or (torch.is_tensor(other) and other.ndim == 0):
            other_vals = other_fill = other
        elif torch.is_tensor(other):
            other_fill = torch.broadcast_to(self._drop_reduced_axis(other), self.batch_shape)
            other_vals = self._broadcast_to_values(other_fill)
        else:
            return NotImplemented

        new_vals = op(self.x.values(), other_vals)
        # Python operators (not torch.add/mul) so a Python-scalar fill works too.
        new_fill = op(self._fill_value, other_fill)
        new_x = torch.sparse_coo_tensor(self.x.indices(), new_vals, self.x.shape)
        return DefaultSparse(new_x, new_fill, self.dim)

    def __add__(self, other):
        return self._binary_op(other, operator.add)

    def __radd__(self, other):
        # Addition is commutative for the supported operand types.
        return self.__add__(other)

    def __mul__(self, other):
        return self._binary_op(other, operator.mul)

    def __rmul__(self, other):
        # Multiplication is commutative for the supported operand types.
        return self.__mul__(other)

    def __sub__(self, other): return self.__add__(-other)

    def __rsub__(self, other):
        # other - self == (-self) + other
        return self.__neg__().__add__(other)

    def __truediv__(self, other):
        if isinstance(other, DefaultSparse):
            # `1.0 / other` below would fail for a DefaultSparse divisor.
            return self._apply_ds(other, operator.truediv)
        return self.__mul__(1.0 / other)

    def __neg__(self): return self.__mul__(-1)

    def _broadcast_to_values(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Gather a batch-shaped tensor at each stored entry's batch coordinates
        and align the result with self.x.values().
        Ensures result.shape == self.x.values().shape.
        """
        axis = self._axis
        batch_idx = tuple(idx for i, idx in enumerate(self.x.indices()) if i != axis)

        gathered = tensor[batch_idx]
        target_shape = self.x.values().shape  # (nnz,) or (nnz, *dense_shape)

        if gathered.shape == target_shape:
            return gathered
        if gathered.dim() == 1 and gathered.shape[0] == target_shape[0]:
            # gathered is (nnz,) and we need (nnz, *dense)
            return gathered.view(-1, *([1] * (len(target_shape) - 1))).expand(target_shape)
        if gathered.dim() >= 1 and gathered.shape[-1] == 1 and target_shape == (gathered.shape[0],):
            # gathered is (nnz, 1) but we need (nnz,)
            return gathered.squeeze(-1)
        # Final fallback: rely on broadcasting. This also covers a 0-dim
        # gather (1-D `x` has no batch axes).
        return gathered.expand(target_shape)

    def _fill_as_tensor(self) -> torch.Tensor:
        """The raw (unbroadcast) fill value as a tensor on this object's device/dtype."""
        fv = self._fill_value
        if torch.is_tensor(fv):
            return fv
        return torch.as_tensor(fv, device=self.device, dtype=self.dtype)

    # DefaultSparse–DefaultSparse ops (same sparsity pattern assumed)
    def _apply_ds(self, other: "DefaultSparse", op):
        """
        Apply binary operation `op` elementwise between two DefaultSparse objects.

        Assumes the sparsity pattern (indices) is identical; shape and nnz are
        verified, indices are not compared (that would cost an O(nnz) check
        and a device sync).

        Raises
        ------
        ValueError
            On `dim`, shape, or nnz mismatch.
        """
        if self._axis != other._axis:
            raise ValueError("dim mismatch")
        if self.shape != other.shape:
            raise ValueError("Shape mismatch")
        if self.x._nnz() != other.x._nnz():
            raise ValueError("Sparsity pattern mismatch (different nnz)")

        new_vals = op(self.x.values(), other.x.values())
        # Promote Python-scalar fills to tensors: torch ops such as torch.add
        # reject a Number as their first argument. Tensor broadcasting still
        # handles scalar vs per-batch fills.
        new_fill = op(self._fill_as_tensor(), other._fill_as_tensor())

        new_x = torch.sparse_coo_tensor(self.x.indices(), new_vals, self.x.shape)
        return DefaultSparse(new_x, new_fill, self.dim)

    def __repr__(self):
        cls = self.__class__.__name__
        shape = tuple(self.shape)
        nnz = self.x._nnz()
        device = self.device
        dtype = self.dtype
        # show scalar if possible, else just the shape of the tensor-valued default
        fv = self.fill_value_scalar
        if fv is not None:
            fv_repr = f"{fv.item():.4g}"
        else:
            fv_repr = f"Tensor{tuple(self.fill_value_tensor.shape)}"
        return (f"{cls}(shape={shape}, nnz={nnz}, dim={self.dim}, "
                f"fill_value={fv_repr}, dtype={dtype}, device={device})")