"""
Same as discrete_trpl_layer.py, but sparse
"""

import math
from functools import partial
from typing import Tuple

import torch
from torch import nn

from discrete_trpl.sparse.tensor_default import DefaultSparse
from discrete_trpl.dtrpl_config import DtrplOptConfig
from discrete_trpl.empty_info_dict import EMPTY_INFO_DICT
from discrete_trpl.optimizer import Optimizer1D
from discrete_trpl.sparse.helpers import subindex_required_projections, sparse_remap, sparse_where
from discrete_trpl.sparse.densify import densify
from discrete_trpl.sparse.reshape import sparse_reshape
from discrete_trpl.sparse.validate_distribution import sparse_validate_distribution
from discrete_trpl.sparse.tensor_log_prob import SparseLogProb
from discrete_trpl.sparse.align_tensors import sparse_align_tensors
from discrete_trpl.sparse.broadcasts import maybe_broadcast, broadcast_bound


class _SdtrplDualSolver(torch.autograd.Function):
    """
    Autograd function that solves the per-row trust-region dual on aligned sparse COO inputs.

    `forward` returns the optimal dual variable eta for every row, `backward` supplies a hand-written
    gradient of eta w.r.t. the target log probabilities.
    """

    @staticmethod
    def _reps_dual_sparse(
            log_eta: torch.Tensor,  # (K,P)
            bound: torch.Tensor,  # (K,1) or (K,P)
            t_sparse: torch.Tensor,  # (K,V) COO, coalesced
            r_sparse: torch.Tensor,  # (K,V) COO, coalesced
            default_log_prob: float,
    ) -> torch.Tensor:  # (K,P)
        """
        Evaluate the dual objective on sparse inputs without materializing (K,V) dense tensors.

        Conventions:
          - `t_sparse` and `r_sparse` are aligned, i.e. they carry the same COO indices.
          - Every stored entry (explicit zeros included) is part of the union set I_k of row k.
          - Every entry that is not stored has the value `default_log_prob = d` in both tensors.

        For a stored entry the interpolated value is (eta * r + t) / (eta + 1); for an implicit entry the same
        formula collapses to d. The row-wise LogSumExp over all V entries is therefore assembled from
          1) a scatter-based max and sum over the stored entries, and
          2) a tail term `(V - |I_k|) * exp(d - m)` for the implicit entries,
        where m is the row-wise maximum over the stored values and d (used for numerical stability).

        Dense reference implementation:
            eta = log_eta.exp()
            eta_ext = eta[..., None]
            inner = (eta_ext * log_ref_prob + log_target_prob) / (eta_ext + 1)
            log_integral = torch.logsumexp(inner, axis=-1)
            return bound * eta + (eta + 1) * log_integral

        Small numerical deviations from the dense reference are expected (more so below torch.float64), because
        the summation order and the treatment of the implicit entries differ.
        """
        num_rows, vocab_size = t_sparse.shape

        entry_rows = t_sparse.indices()[0].long()  # [NNZ] row id of every stored entry
        target_vals = t_sparse.values()  # [NNZ]
        ref_vals = r_sparse.values()  # [NNZ]
        device = target_vals.device
        dtype = target_vals.dtype

        eta = log_eta.exp()  # (K,P)
        num_candidates = eta.size(1)
        if entry_rows.numel():
            eta_per_entry = eta.index_select(0, entry_rows)  # (NNZ,P)
        else:
            eta_per_entry = eta.new_empty(0, num_candidates)

        # Interpolated log-probability of every stored entry, for every candidate eta.
        inner = (eta_per_entry * ref_vals.unsqueeze(-1) + target_vals.unsqueeze(-1)) / (eta_per_entry + 1)
        default_value = target_vals.new_tensor(default_log_prob)

        has_explicit = bool(inner.numel())

        # Row-wise maximum over the stored entries, then over the default value as well.
        union_max = torch.full((num_rows, num_candidates), float("-inf"), device=device, dtype=dtype)
        if has_explicit:
            scatter_index = entry_rows.unsqueeze(-1).expand(-1, num_candidates)  # (NNZ,P)
            union_max = union_max.scatter_reduce(0, scatter_index, inner, reduce="amax", include_self=True)
        row_max = torch.maximum(union_max, default_value)  # (K,P)

        # Shifted exp-sum over the stored entries.
        explicit_exp_sum = torch.zeros((num_rows, num_candidates), device=device, dtype=dtype)
        if has_explicit:
            shifted = inner - row_max.index_select(0, entry_rows)
            explicit_exp_sum.scatter_add_(0, scatter_index, torch.exp(shifted))

        # Stored entries per row; only the remaining (implicit) positions take the default value.
        stored_per_row = torch.zeros((num_rows,), device=device, dtype=torch.long)
        if entry_rows.numel():
            stored_per_row.scatter_add_(0, entry_rows, torch.ones_like(entry_rows, dtype=torch.long))
        implicit_count = (vocab_size - stored_per_row).to(dtype)[:, None]  # (K,1)

        total_exp = explicit_exp_sum + implicit_count * torch.exp(default_value - row_max)  # (K,P)
        log_sum_exp = row_max + torch.log(total_exp)  # (K,P)

        return bound * eta + (eta + 1) * log_sum_exp

    @staticmethod
    def forward(ctx, log_target_prob, log_ref_prob, bound, opt_cfg: DtrplOptConfig, default_log_prob=-20.0):
        """
        Optimize the dual over log(eta) for every row and return the optimal eta.

        Args:
            log_target_prob: Sparse COO tensor with target log probabilities. Shape (K, V)
            log_ref_prob: Sparse COO tensor with reference log probabilities.
            Must be aligned with the target, i.e., carry the same nnz and indices
            bound: Bound tensor. Shape (K, 1) or (1,)
            opt_cfg: Optimization configuration (lower, upper, num_points, max_steps, x_threshold are read)
            default_log_prob: Value that implicit (non-stored) entries of both sparse tensors stand for.

        Returns:
            Tuple of the optimal eta (exp of the optimizer result) and the statistics dict of the optimizer.
        """
        assert log_target_prob.is_sparse and log_ref_prob.is_sparse, "Expect aligned sparse COO tensors"

        num_rows = log_target_prob.shape[0]
        column_kwargs = dict(device=log_target_prob.device, dtype=log_target_prob.dtype)

        # The search runs in log space, so the configured eta range is mapped through log first.
        log_lower = torch.full((num_rows, 1), math.log(opt_cfg.lower), **column_kwargs)
        log_upper = torch.full((num_rows, 1), math.log(opt_cfg.upper), **column_kwargs)

        # Freeze everything but log_eta, the optimizer only varies that argument.
        objective = partial(
            _SdtrplDualSolver._reps_dual_sparse,
            bound=bound,
            t_sparse=log_target_prob,
            r_sparse=log_ref_prob,
            default_log_prob=default_log_prob,
        )

        opt_log_eta, stats_dict = Optimizer1D.optimize(
            objective,
            lower=log_lower,
            upper=log_upper,
            num_points=opt_cfg.num_points,
            max_steps=opt_cfg.max_steps,
            x_threshold=opt_cfg.x_threshold
        )
        opt_eta = opt_log_eta.exp()

        # Everything the backward pass needs.
        ctx.opt_cfg = opt_cfg
        ctx.default_log_prob = default_log_prob
        ctx.save_for_backward(log_target_prob, log_ref_prob, opt_eta)
        return opt_eta, stats_dict

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor, _):
        """
        Gradient of the optimal eta w.r.t. the target log probabilities.

        Only the first forward input receives a gradient, the remaining four get None.

        Args:
            grad_output: Gradient of the loss w.r.t. the eta returned by forward. Shape (K, 1)
            _: Gradient slot of the statistics dict returned by forward, ignored.
        """
        target_coo, ref_coo, opt_eta = ctx.saved_tensors
        fill_value = ctx.default_log_prob

        # Rebuild the projected distribution that belongs to the stored eta.
        log_primal = _sparse_reps_update(log_target_prob=target_coo,
                                         log_ref_prob=ref_coo,
                                         opt_eta=opt_eta,
                                         default_log_prob=fill_value)
        primal = log_primal.exp()  # (K, V)

        # Wrap target and reference as DefaultSparse so the arithmetic below honors the default value.
        target_ds = DefaultSparse(target_coo, fill_value, dim=-1)
        ref_ds = DefaultSparse(ref_coo, fill_value, dim=-1)

        ref_minus_target = ref_ds - target_ds  # (K, V)
        expected_diff = primal.dot(ref_minus_target)  # (K,)

        d_primal_d_eta = primal * (ref_minus_target - expected_diff[..., None]) / (opt_eta + 1) ** 2
        v_temp: DefaultSparse = (log_primal - ref_ds) + 1
        denominator = (-v_temp).dot(d_primal_d_eta)  # (K,)

        expected_v = primal.dot(v_temp)
        numerator = primal * (v_temp - expected_v[..., None]) / (opt_eta + 1)

        ratio = numerator / (denominator.unsqueeze(-1))

        # Keep the gradient finite, and drop it entirely for rows whose eta is not above 10 * lower.
        ratio = ratio.nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)
        ratio = ratio.where(opt_eta > 10 * ctx.opt_cfg.lower, other=0.0)
        grad_log_target_prob = ratio * grad_output

        # The gradient is handed back through `.x` (the wrapped tensor); call `.densify()` on the
        # DefaultSparse instead if a dense gradient is wanted.
        return grad_log_target_prob.x, None, None, None, None


def _sparse_reps_update(log_target_prob: torch.Tensor,
                        log_ref_prob: torch.Tensor,
                        opt_eta: torch.Tensor,
                        default_log_prob: float | torch.Tensor) -> SparseLogProb:
    """
    Performs a REPS-style update to compute the primal solution given the optimal dual variables.

    The dual solver returns one dense eta value per logit set (row). This function computes the projected
    distribution for these rows:
        inner = (opt_eta * log_ref_prob + log_target_prob) / (opt_eta + 1)
    followed by a log-softmax over the logit dimension.

    The implicit (non-stored) entries never have to be touched explicitly: both ref and target carry the same
    default value there, and thus
        inner = (opt_eta * default_log_prob + default_log_prob) / (opt_eta + 1) = default_log_prob
    before the renormalization. The interpolation is therefore done on the stored values only, with eta gathered
    per stored entry via its row index, and the result is wrapped in a SparseLogProb that carries
    `default_log_prob` as fill value for the normalization.

    Args:
        log_target_prob: Sparse COO tensor with target log probabilities. Shape (K, V)
        log_ref_prob: Sparse COO tensor with reference log probabilities.
        Expects same nnz and indices as target, i.e., the sparse tensors must be aligned
        opt_eta: Optimal dual variable tensor. Shape (K, 1)
        default_log_prob: Default value for the implicit entries of the sparse tensors.

    Returns:
        sparse_primal_solution: SparseLogProb with the projected and normalized log probabilities. Shape (K, V)
    """
    # Coalesce both operands once and read indices AND values from the coalesced versions, so that the two
    # always belong together. `coalesce()` hands back the tensor itself if it is already coalesced.
    target = log_target_prob.coalesce()
    ref = log_ref_prob.coalesce()

    idx = target.indices()  # [2, NNZ']
    rows = idx[0].long()
    t_vals = target.values()  # [NNZ']
    r_vals = ref.values()  # [NNZ']

    # One eta per stored entry, picked by the row the entry lives in.
    eta_row = opt_eta.squeeze(-1).index_select(0, rows)  # [NNZ']
    inner_vals = (eta_row * r_vals + t_vals) / (eta_row + 1)  # [NNZ']

    unnormalized_primal = torch.sparse_coo_tensor(idx, inner_vals, target.shape,
                                                  device=t_vals.device, dtype=t_vals.dtype).coalesce()
    unnormalized_primal = SparseLogProb(unnormalized_primal, default_log_prob, dim=-1)
    sparse_primal_solution = unnormalized_primal.log_softmax()  # (K, V)
    return sparse_primal_solution


class SdtrplLayer(nn.Module):
    """
    Sparse discrete trust region projection layer.

    Attributes:
        check_valid: If True, both inputs are validated as distributions on every forward call.
        opt_cfg: Configuration handed to the dual solver; a default `DtrplOptConfig()` if none is given.
    """

    def __init__(self,
                 check_valid: bool = False,
                 opt_cfg: DtrplOptConfig = None,
                 ):
        super().__init__()

        self.check_valid = check_valid
        self.opt_cfg = opt_cfg if opt_cfg is not None else DtrplOptConfig()

    def forward(self, log_target_prob: torch.Tensor,
                log_ref_prob: torch.Tensor,
                bound: torch.Tensor | float,
                default_log_prob: float = -20.0,
                val_eps: float = 1e-5,
                return_dense: bool = False
                ) -> Tuple[torch.Tensor | SparseLogProb, dict]:
        """
        Forward pass of the Sparse Discrete Trust Region Projection Layer.

        Projects log_target_prob onto the trust region defined by a KL divergence constraint
        relative to log_ref_prob. Both inputs are sparse tensors whose implicit (non-stored) positions stand
        for default_log_prob; the computation stays sparse unless return_dense is set.

        Args:
            log_target_prob (torch.Tensor): Log probabilities of target distribution.
                Is a sparse tensor of arbitrary shape (..., logits_size).
                Implicit positions are treated as default_log_prob.
            log_ref_prob (torch.Tensor): Log probabilities of reference distribution.
                Must have same shape as log_target_prob, but may have different non-zero positions.
            bound (torch.Tensor | float): KL divergence bound for trust region constraint.
                If float, broadcasted to match batch dimensions of input tensors.
                If tensor, should have shape (..., 1) where ... matches input shape[:-1].
                I.e., can set a different bound for each set of logits
            default_log_prob: (float, optional): Default log probability value for implicit positions in the
                sparse tensors. Defaults to -20.0.
            val_eps (float, optional): Validation epsilon for distribution checks.
                Only used when check_valid=True. Defaults to 1e-5.
            return_dense (bool, optional): If True, the projected log probabilities are returned as a dense
                tensor, otherwise as a SparseLogProb. Defaults to False.

        Returns:
            tuple[torch.Tensor | SparseLogProb, dict]: A tuple containing:
                - primal_solution: Projected log probabilities with same shape as the input tensors, dense if
                  return_dense is set and a SparseLogProb otherwise.
                - info_dict (dict): Dictionary containing optimization information. If at least one batch
                  element is projected, this is the statistics dict of the dual solver extended by:
                    - "opt_eta" (torch.Tensor): Optimal dual variable per flattened batch element,
                      zero for elements that were not projected
                    - "projected_elements" (int): Number of batch elements that required projection
                    - "projected_frac" (float): Fraction of batch elements that required projection
                    - "violation" (torch.Tensor): Remaining constraint violation max(final_kl - bound, 0)
                    - "final_kl" (torch.Tensor): Final KL divergence values for each batch element
                    - "initial_kl" (torch.Tensor): Initial KL divergence values before projection
                  If nothing is projected, it is a copy of EMPTY_INFO_DICT with "initial_kl" and "final_kl"
                  both set to the initial KL divergence.

        Raises:
            AssertionError: If one of the inputs is not a sparse tensor, or if check_valid=True and an input
                distribution doesn't sum to 1 within val_eps tolerance. The latter assumes that the implicit
                positions of the sparse tensors have been filled with default_log_prob. I.e., the stored
                entries must consider this "additional mass" and should sum to slightly less than 1.0 to
                account for the default_log_prob positions.

        Note:
            The method works with arbitrary tensor shapes by treating the last dimension
            as the vocabulary/logits dimension and all preceding dimensions as batch
            dimensions that are processed independently.
        """
        # Validate distributions using sparse operations (if enabled)
        assert log_target_prob.is_sparse, "log_target_prob must be a sparse tensor"
        assert log_ref_prob.is_sparse, "log_ref_prob must be a sparse tensor"
        if self.check_valid:
            assert sparse_validate_distribution(log_target_prob,
                                                default_log_prob=default_log_prob,
                                                val_eps=val_eps), \
                "log_target_prob must be a valid distribution"
            assert sparse_validate_distribution(log_ref_prob,
                                                default_log_prob=default_log_prob,
                                                val_eps=val_eps), \
                "log_ref_prob must be a valid distribution"

        # Broadcast bound if it is a float, such that it has shape (..., 1)
        bound = broadcast_bound(bound, log_target_prob)

        # Build unions of the sparse tensors for efficient computation.
        # I.e., align both tensors along the logit dimensions to have the same indices/sparsity pattern.
        # From here on, we will only need these aligned tensors.
        target_aligned, ref_aligned = sparse_align_tensors(log_target_prob, log_ref_prob,
                                                           default_log_prob=default_log_prob)

        # Find which unique batch elements need projection and extract them. Filter and flatten the target and ref
        # tensors to only those elements.
        # Returns:
        #         unique_batch_indices: 1D int64 tensor with the *flattened* original batch indices
        #                               (range [0, prod(batch_dims))) that require projection.
        #         target_union:         COO sparse tensor (K, V) with only those rows (K=len(unique)).
        #         ref_union:            COO sparse tensor (K, V) with only those rows.
        #         kl_div:               Tensor with KL divergence for each batch element.
        #         _bound:               Tensor (K, 1) with bounds for those rows.
        unique_batch_indices, target_union, ref_union, kl_div, _bound = subindex_required_projections(
            target_aligned, ref_aligned, bound
        )

        if len(unique_batch_indices) == 0:
            # Nothing needs to be projected, so the original target is returned unchanged.
            # Work on a copy: EMPTY_INFO_DICT is a shared module-level object and must not be written to,
            # otherwise the KL tensors of this call would leak into every later user of the template.
            info_dict = dict(EMPTY_INFO_DICT)
            info_dict["initial_kl"] = kl_div.detach()
            info_dict["final_kl"] = kl_div.detach()
            if return_dense:
                return densify(log_target_prob, fill_value=default_log_prob), info_dict
            return SparseLogProb(log_target_prob, fill_value=default_log_prob), info_dict

        # Pass sparse tensors directly to the dual solver. This is the heart of the method, and solves a 1d convex
        # optimization problem for each unique batch element that requires projection.
        _opt_eta, info_dict = _SdtrplDualSolver.apply(
            target_union, ref_union, _bound, self.opt_cfg, default_log_prob
        )

        # Compute the (sparse) reps update, i.e.,
        # sparse_primal_solution = log_softmax((opt_eta * log_ref_prob + log_target_prob) / (opt_eta + 1))
        sparse_primal_solution = _sparse_reps_update(target_union, ref_union, _opt_eta, default_log_prob)

        # Given the sparse primal solution for the filtered rows, we now need to put it back into the original
        # tensor structure. We do this by creating a mask for the rows that were projected, and then using sparse_where
        # to combine the original log_target_prob with the new primal solution.
        # This ensures that we only replace the rows that were projected, and keep the original values
        # for the rows that were not projected.

        # Create shape aliases
        # log_target_prob: (*batch, V)  |  sparse_primal_solution: (K, V)  |  unique_batch_indices: (K,) in [0, ∏batch)
        batch_shape = tuple(log_target_prob.shape[:-1])  # (*batch,)
        logit_dim = log_target_prob.shape[-1]  # (V,)
        batch_entries = math.prod(batch_shape)  # (N,)

        # Flatten to [N, V] (sparse_reshape must not densify)
        log_bs = sparse_reshape(log_target_prob, (batch_entries, logit_dim))  # [N, V]
        prim_bs = sparse_remap(sparse_primal_solution.x, unique_batch_indices, batch_entries)  # [N, V]

        mask = torch.zeros(batch_entries, 1, dtype=torch.bool, device=log_bs.device)  # [N, 1] -> broadcasts to [N, V]
        mask[unique_batch_indices] = True  # True on rows to replace
        result_bs = sparse_where(mask, prim_bs,
                                 log_bs).coalesce()  # [N, V], primary on True/projected rows, target else

        # Restore original shape, re-align the default_log_prob_tensor() with new primal solution shape
        primal_solution = sparse_reshape(result_bs, log_target_prob.shape)  # [*batch, V]

        # Projected rows were renormalized, so their fill value is taken from the primal solution;
        # all other rows keep default_log_prob.
        new_default_log_probs = maybe_broadcast(default_log_prob,
                                                target_shape=(batch_entries,),
                                                device=primal_solution.device,
                                                dtype=primal_solution.dtype)  # [N]
        new_default_log_probs[unique_batch_indices] = sparse_primal_solution.fill_value_tensor  # [K]

        # Final SparseLogProb of shape [*batch, V] with correct default_log_prob_tensor()
        joint_primal = SparseLogProb(primal_solution,
                                     new_default_log_probs.reshape(batch_shape),  # [*batch]
                                     dim=-1)

        # For violation computation, use sparse operations
        with torch.no_grad():
            # compare to the reference
            log_ref_prob = SparseLogProb(log_ref_prob, default_log_prob, dim=-1)
            final_kl_divergences = joint_primal.kl(log_ref_prob)

            kl_violation = torch.clamp(final_kl_divergences - bound.squeeze(-1), min=0.0)
            info_dict["violation"] = kl_violation.detach()
            info_dict["final_kl"] = final_kl_divergences.detach()
            info_dict["initial_kl"] = kl_div.detach()

            # Scatter the per-row eta back to all flattened batch elements (zero where nothing was projected).
            opt_eta = torch.zeros((batch_entries, 1),
                                  device=joint_primal.device,
                                  dtype=joint_primal.dtype)
            opt_eta[unique_batch_indices] = _opt_eta
            info_dict["opt_eta"] = opt_eta.detach()
            info_dict["projected_elements"] = len(unique_batch_indices)
            info_dict["projected_frac"] = len(unique_batch_indices) / len(opt_eta)

        if return_dense:
            return joint_primal.densify(), info_dict
        return joint_primal, info_dict
