import torch
from tqdm import tqdm
import pdb
from utils import train_utils


class SVDLinear(torch.nn.Module):
    """Fixed-rank replacement for a linear layer, built from a truncated SVD.

    The weight of ``current_layer`` (``out_features x in_features``) is
    approximated as ``(U * E) @ V^T`` using ``torch.svd_lowrank`` and applied
    as two bias-free ``torch.nn.Linear`` modules: ``V_t`` followed by ``UE``.

    Only ``current_layer.weight`` is read; a bias on the original layer is not
    carried over. The factor weights are stored in float32 because the weight
    is cast with ``.float()`` before the decomposition.
    """

    def __init__(self, current_layer, rank, svd_vector=None, alpha=0.5):
        """
        Decomposes the weight of ``current_layer`` and stores the two factors.

        Args:
            current_layer: Module exposing a 2-D ``weight`` of shape
                (out_features, in_features).
            rank (int): Number of singular components requested from
                ``torch.svd_lowrank`` (passed as ``q``).
            svd_vector (torch.Tensor, optional): Scales multiplied onto the
                columns of the weight before the SVD and divided out of ``V``
                afterwards. Expected to hold one value per input feature
                (it is broadcast via ``unsqueeze(0)``). It is NOT modified.
            alpha (float): Exponent applied to ``svd_vector + 1e-6``.
        """
        super(SVDLinear, self).__init__()
        self.in_features = current_layer.weight.shape[1]
        self.out_features = current_layer.weight.shape[0]
        self.rank = rank

        layer_device = current_layer.weight.device

        # The decomposition is pure preprocessing: keep it out of autograd so
        # no graph is built on top of the original layer's weight.
        with torch.no_grad():
            weight = current_layer.weight.float()

            # push to gpu only for svd
            if torch.cuda.is_available():
                weight = weight.cuda()

            scales = None
            if svd_vector is not None:
                # Out-of-place add: the caller's tensor must stay untouched.
                # The offset keeps the later division by the scales finite.
                scales = (svd_vector.to(weight.device) + 1e-6) ** alpha
                weight = weight * scales.unsqueeze(0)

            U, E, V = torch.svd_lowrank(weight, q=self.rank, niter=2)

            if scales is not None:
                # Undo the column scaling so that (U * E) @ V^T approximates
                # the unscaled weight.
                V = V / scales.unsqueeze(1)

            # Move the factors back to where the original layer lives.
            U, E, V = U.to(layer_device), E.to(layer_device), V.to(layer_device)
            UE = U * E.unsqueeze(0)

        self.V_t = torch.nn.Linear(self.in_features, V.shape[1], bias=False)
        self.UE = torch.nn.Linear(U.shape[1], self.out_features, bias=False)

        self.UE.weight.data = UE.contiguous()
        self.V_t.weight.data = V.T.contiguous()

    def forward(self, inputs):
        """Applies ``V_t`` and then ``UE`` to ``inputs``."""
        x = self.V_t(inputs)
        return self.UE(x)

    def __str__(self):
        return f"LinearLowRank(in_features={self.in_features}, out_features={self.out_features}, rank={self.rank})"

    def __repr__(self):
        return self.__str__()
    
class LowrankLinear(torch.nn.Module):
    """Linear layer factorised by SVD with a learnable per-component mask.

    The weight is stored as ``UE`` (out_features x rank) and ``V_t``
    (rank x in_features); ``E_train`` holds one logit per singular component
    and is turned into a mask by :meth:`calculate_mask`. Only
    ``current_layer.weight`` is read; a bias on the original layer is not
    carried over.
    """

    def __init__(self, current_layer, init_frac, svd_vector, alpha=1., niter=2, tau=0.1, bias_init=False, mask_eval_type=""):
        """
        Decomposes the weight in a linear layer into its singular vectors and values and introduces a learnable mask.

        Args:
            current_layer (nn.Linear): Current linear layer to be decomposed.
            init_frac (float): Fraction of the leading mask logits that start
                at a positive value; the remaining ones start at -3.5.
            svd_vector (torch.Tensor or None): Weight scales for weighted SVD,
                broadcast across the columns of the weight. Pass ``None`` to
                run a plain SVD. The tensor is NOT modified.
            alpha (float): Exponent applied to ``svd_vector + 1e-6`` (default: 1.0).
            niter (int): Number of SVD iterations (default: 2).
            tau (float): Temperature of Gumbel sigmoid (lower values create harder boundaries) (default: 0.1).
            bias_init (bool): If True, initialise the mask logits with a
                linear ramp from 6 down to 3 instead of a constant 3.5
                (default: False).
            mask_eval_type (str): Type of evaluation mode: 'topk', 'threshold', or '' (default: "").

        Raises:
            ValueError: If ``current_layer`` is not a ``torch.nn.Linear``.
            NotImplementedError: If ``mask_eval_type`` is not supported.
        """
        super(LowrankLinear, self).__init__()

        if not isinstance(current_layer, torch.nn.Linear):
            raise ValueError(f"Expected input into SVDLayer be of instance nn.Linear, got {type(current_layer)}")

        # Validate the cheap option before running the full-rank SVD.
        self.mask_eval_type = mask_eval_type
        if self.mask_eval_type not in ["", "threshold", "topk"]:
            raise NotImplementedError(f"mask_eval_type: {self.mask_eval_type} not supported in LowrankLinear")

        dtype = current_layer.weight.dtype
        self.in_features, self.out_features = current_layer.in_features, current_layer.out_features
        # Start from the full rank; the mask decides what is kept.
        self.rank = min(current_layer.weight.shape[1], current_layer.weight.shape[0])

        layer_device = current_layer.weight.device

        # The decomposition is pure preprocessing: keep it out of autograd so
        # no graph is built on top of the original layer's weight.
        with torch.no_grad():
            weight = current_layer.weight.float()

            # push to gpu only for svd
            if torch.cuda.is_available():
                weight = weight.cuda()

            scales = None
            if svd_vector is not None:
                # Out-of-place add: the caller's tensor must stay untouched.
                # The offset keeps the later division by the scales finite.
                scales = (svd_vector.to(weight.device) + 1e-6) ** alpha
                weight = weight * scales.unsqueeze(0)

            U, E, V = torch.svd_lowrank(weight,
                                        q=self.rank,
                                        niter=niter)

            if scales is not None:
                # Undo the column scaling so that (U * E) @ V^T approximates
                # the unscaled weight.
                V = V / scales.unsqueeze(1)

            U, E, V = U.to(layer_device), E.to(layer_device), V.to(layer_device)

        assert len(E.shape) == 1, 'expected singular values to have only one dim'

        # precompute UE for efficiency; factors are stored in the layer's original dtype
        self.UE = torch.nn.Parameter((U * E.unsqueeze(0)).to(dtype))
        self.V_t = torch.nn.Parameter(V.T.to(dtype))

        if not bias_init:
            init_vector = torch.ones_like(E, dtype=dtype, device=E.device) * 3.5
        else:
            init_vector = torch.linspace(6, 3., len(E), device=E.device).to(dtype)
        # Components beyond the initial fraction start switched off.
        init_vector[round(len(init_vector) * init_frac):] = -3.5

        self.E_train = torch.nn.Parameter(init_vector)
        self.tau = tau

    def forward(self, inputs):
        """
        Computes forward pass with selection of singular values through predicted mask.

        Args:
            inputs (torch.Tensor): Input tensor with at least two dimensions
                whose last dimension is ``in_features``.

        Returns:
            torch.Tensor: Output tensor with the same leading dimensions and
            ``out_features`` as the last dimension.
        """
        E_train_mask = self.calculate_mask(is_training=self.training)

        # Put the feature axis second-to-last so the factors can be applied
        # by left-multiplication; negative dims keep this valid for any
        # number of leading dimensions.
        inputs = inputs.transpose(-1, -2)

        if inputs.device != self.V_t.device:  # multi-gpu setup
            inputs = inputs.to(self.V_t.device)

        output = (self.UE * E_train_mask.unsqueeze(0)) @ (self.V_t @ inputs)
        output = output.transpose(-1, -2)

        return output

    def calculate_mask(self, is_training, return_probs=False):
        """
        Calculates the mask for singular value selection. During training, it uses Gumbel-Sigmoid approximation.
        During evaluation, various non-differentiable operations can be used

        Args:
            is_training (bool): Whether the model is in training mode. When
                True, the soft Gumbel-sigmoid mask is returned and
                ``return_probs`` is ignored.
            return_probs (bool): In evaluation, return ``sigmoid(E_train)``
                instead of a binary mask (default: False). The full-rank
                override described below still applies to this value.

        Returns:
            torch.Tensor: Mask for singular value selection. In evaluation,
            if the mask's sum implies a parameter ratio above 0.99 relative
            to the dense layer, an all-True mask is returned instead.

        Raises:
            NotImplementedError: If ``self.mask_eval_type`` is not supported.
        """

        if is_training:
            E_train_mask = gumbel_sigmoid(self.E_train, tau=self.tau)
        else:
            if return_probs:  # if we dont need binary mask, just return E_train
                E_train_mask = torch.sigmoid(self.E_train)
            elif not self.mask_eval_type:
                E_train_mask = gumbel_sigmoid(self.E_train, tau=self.tau) > 0.5
            elif self.mask_eval_type == 'threshold':
                E_train_mask = torch.sigmoid(self.E_train) > 0.50
            elif self.mask_eval_type == 'topk':
                # Keep the leading components, as many as pass the threshold.
                topk = int((torch.sigmoid(self.E_train) > 0.50).sum().item())
                E_train_mask = torch.zeros_like(self.E_train, dtype=torch.bool)
                E_train_mask[:topk] = True

            else:
                raise NotImplementedError(f"self.mask_eval_type: {self.mask_eval_type} not supported")

            # in eval, if the factorised form would not save parameters, use the full rank
            compression_rate = E_train_mask.sum().item() * (self.in_features + self.out_features) / (self.in_features * self.out_features)
            if compression_rate > 0.99:
                E_train_mask = torch.ones_like(self.E_train, dtype=torch.bool)

        return E_train_mask

    def __str__(self):
        return f"LowrankLinear(in_features={self.in_features}, out_features={self.out_features}, rank={self.rank})"

    def __repr__(self):
        return self.__str__()
    
def gumbel_sigmoid(logits, tau=0.5):
    """Return a noisy, temperature-scaled sigmoid of ``logits``.

    One Gumbel(0, 1) sample per element is added to ``logits``; the sum is
    divided by ``tau`` and passed through a sigmoid.

    Args:
        logits (torch.Tensor): Unnormalised scores.
        tau (float): Temperature dividing the perturbed logits (default: 0.5).

    Returns:
        torch.Tensor: Tensor of the same shape as ``logits``.
    """
    eps = 1e-20  # small offset added inside both logarithms

    # Gumbel(0, 1) sample via the inverse transform -log(-log(U)).
    uniform = torch.rand(logits.shape, device=logits.device, dtype=logits.dtype)
    neg_log_uniform = -torch.log(uniform + eps)
    noise = -torch.log(neg_log_uniform + eps)

    perturbed = logits + noise
    return torch.sigmoid(perturbed / tau)
