import torch
import torch.nn as nn
import numpy as np
from scipy.optimize import linear_sum_assignment

class PermutedLinear(nn.Module):
    """
    A drop-in replacement for nn.Linear that learns a permutation of its input features.

    This module wraps a standard nn.Linear layer and prepends a learnable permutation
    operation. In training mode the input is multiplied by a dense, learnable
    (N x N) "soft" permutation matrix. In evaluation mode the soft matrix is
    rounded to the closest "hard" permutation (linear sum assignment) and applied
    by re-indexing the last input dimension.

    The permutation learning is guided by a penalty loss (see
    ``get_permutation_loss``) that must be added to the main task loss during
    training. Note that this module does not itself constrain the matrix to be
    doubly stochastic; it only provides the sparsity-inducing penalty.

    Args:
        in_features (int): Size of each input sample.
        out_features (int): Size of each output sample.
        bias (bool): If set to False, the layer will not learn an additive bias. Default: True.

    Attributes:
        linear (nn.Linear): The underlying linear layer.
        soft_permutation (nn.Parameter): The learnable (N x N) soft permutation matrix,
            where N is in_features.
        hard_permutation_indices (torch.Tensor): Buffer of N long indices; entry ``j``
            is the input feature that is routed to position ``j`` at inference time.
        is_hard_perm_cached (bool): Whether ``hard_permutation_indices`` reflects the
            current ``soft_permutation``. Reset whenever the module enters training
            mode or loads a state dict.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # The standard linear layer that holds the weights W
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        # Initialize the soft permutation matrix M.
        # We start with a matrix close to identity to ensure stable training at the beginning.
        # A small amount of uniform noise is added to break symmetry.
        identity = torch.eye(in_features)
        noise = torch.rand(in_features, in_features) / 1000.0
        initial_matrix = identity + noise
        self.soft_permutation = nn.Parameter(initial_matrix)

        # Buffer to store the hard permutation indices for efficient inference
        self.register_buffer('hard_permutation_indices', torch.arange(in_features))
        self.is_hard_perm_cached = False

    def train(self, mode: bool = True):
        """
        Sets the module mode and invalidates the cached hard permutation when
        entering training mode.

        The soft permutation may be updated by the optimizer while training, so
        any previously cached hard permutation must be recomputed on the next
        evaluation-mode forward pass.
        """
        if mode:
            self.is_hard_perm_cached = False
        return super().train(mode)

    def _load_from_state_dict(self, *args, **kwargs):
        """
        Loads parameters/buffers as usual, then invalidates the cached hard
        permutation so it is recomputed from the freshly loaded soft matrix.
        """
        super()._load_from_state_dict(*args, **kwargs)
        self.is_hard_perm_cached = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies the permuted linear transformation.

        In training mode, it uses the soft permutation matrix via matrix multiplication.
        In evaluation mode, it uses a cached hard permutation via efficient re-indexing.

        Args:
            x: Input tensor whose last dimension has size ``in_features``. Any
                number of leading dimensions is accepted in both modes.

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``out_features``.
        """
        if self.training:
            # During training, apply the learnable soft permutation matrix
            permuted_x = x @ self.soft_permutation
            return self.linear(permuted_x)

        # During inference, use the pre-computed hard permutation for efficiency
        if not self.is_hard_perm_cached:
            self._cache_hard_permutation()

        # Index the feature (last) dimension so that, like the training path,
        # inputs with arbitrary leading dimensions are supported.
        permuted_x = x[..., self.hard_permutation_indices]
        return self.linear(permuted_x)

    def get_permutation_loss(self) -> torch.Tensor:
        """
        Calculates the penalty loss that pushes the soft_permutation matrix
        towards a permutation matrix.

        For every row and every column of ``|M|`` the penalty is the L1 norm
        minus the L2 norm. Each term is non-negative and is zero exactly when
        that row/column has at most one non-zero entry. For a doubly-stochastic
        matrix this means the total is zero only at permutation matrices; the
        doubly-stochastic constraint itself is not enforced here.

        (The original author's note attributes this penalty to Equation 14 of
        the paper this module is based on.)

        Returns:
            Scalar tensor: sum of the row penalties and the column penalties.
        """
        # We work with the absolute value to handle potential negative values during optimization
        M = torch.abs(self.soft_permutation)

        # `norm` is used instead of sqrt(sum(M**2)) because its backward pass
        # yields a zero (sub)gradient for an all-zero row/column, whereas the
        # explicit sqrt produces NaN gradients there.
        # L1-L2 penalty for rows
        row_penalty = torch.sum(M.sum(dim=1) - M.norm(p=2, dim=1))

        # L1-L2 penalty for columns
        col_penalty = torch.sum(M.sum(dim=0) - M.norm(p=2, dim=0))

        return row_penalty + col_penalty

    def _cache_hard_permutation(self):
        """
        Computes and caches the hard permutation indices from the soft permutation matrix.

        This is done by solving the linear sum assignment problem (Hungarian algorithm),
        which finds the permutation maximizing the sum of the selected entries
        of the soft matrix. The result is stored in ``hard_permutation_indices``
        in "gather" form so that ``x[..., hard_permutation_indices]`` equals
        ``x @ P`` for the selected permutation matrix ``P``.
        """
        # Detach the matrix from the computation graph and move to CPU for scipy.
        # Cast to float64 so reduced-precision parameters (e.g. bfloat16) can be
        # converted to NumPy.
        soft_perm_np = self.soft_permutation.detach().double().cpu().numpy()

        # The Hungarian algorithm finds the minimum weight matching.
        # To find the permutation that maximizes the sum of entries, we solve for the minimum
        # of the negated matrix.
        row_ind, col_ind = linear_sum_assignment(-soft_perm_np)

        # The assignment pairs input feature `row_ind[k]` (a row of M) with output
        # position `col_ind[k]` (a column of M), because (x @ M)[j] = sum_i x[i] M[i, j].
        # Re-indexing needs the inverse view: for each output position, which
        # input feature to read. Scatter the rows into their assigned columns.
        source_for_output = np.empty_like(col_ind)
        source_for_output[col_ind] = row_ind

        self.hard_permutation_indices = torch.as_tensor(
            source_for_output,
            dtype=torch.long,
            device=self.soft_permutation.device,
        )
        self.is_hard_perm_cached = True
