import torch
from tqdm import tqdm
import pdb
import torch
import torch.nn as nn

def gumbel_sigmoid(logits, tau=0.5):
    """Return a Gumbel-perturbed, temperature-scaled sigmoid of ``logits``.

    One Gumbel(0, 1) sample per element is added to ``logits``; the sum is
    divided by ``tau`` and squashed with a sigmoid. The result is random:
    it draws from torch's global RNG on every call.

    Args:
        logits: Tensor of raw scores. The noise is created with the same
            shape, dtype and device as this tensor.
        tau: Temperature that divides the perturbed logits before the
            sigmoid is applied.

    Returns:
        Tensor with the same shape as ``logits`` holding the soft outputs.
    """
    eps = 1e-20  # small offset added inside both logarithms

    # Inverse-transform sampling: U ~ Uniform[0, 1)  ->  -log(-log(U)).
    uniform = torch.rand(logits.shape, device=logits.device, dtype=logits.dtype)
    neg_log_uniform = -torch.log(uniform + eps)
    noise = -torch.log(neg_log_uniform + eps)

    perturbed = logits + noise
    return torch.sigmoid(perturbed / tau)

# class LinearLowRank(torch.nn.Module):
#     def __init__(self, init_config):
#         super(LinearLowRank, self).__init__()
#         """
#         More efficient in the forward pass by avoiding first materialization of the weight

#         Inputs: Linear layer to perform ASVD on.
#         Approach: Parameter + gumbel sigmoid to generate mask
#         """
#         self.in_features = init_config['in_features']
#         self.out_features = init_config['out_features']
#         self.rank = init_config['rank']

#         self.V_t = torch.nn.Linear(self.in_features, self.rank, bias=False)
#         self.UE = torch.nn.Linear(self.rank, self.out_features, bias=False)

#     def forward(self, inputs):
#         x = self.V_t(inputs)
#         return self.UE(x)

#     def __str__(self):
#         return f"LinearLowRank(in_features={self.in_features}, out_features={self.out_features}, rank={self.rank})"

#     def __repr__(self):
#         return self.__str__()
    
class LinearLowRank(nn.Module):
    """Low-rank factorised linear layer with optional LoRA-style adapters.

    The base layer computes ``UE(V_t(x))``, where ``V_t`` maps
    ``in_features -> rank`` and ``UE`` maps ``rank -> out_features`` (both
    without bias). When ``init_config['peft_config']`` is truthy, each factor
    additionally gets an additive two-layer adapter (``*_A`` then ``*_B``)
    whose output is passed through dropout and scaled by ``lora_alpha``.

    Args:
        init_config: Mapping with the required keys ``'in_features'``,
            ``'out_features'`` and ``'rank'``, and an optional
            ``'peft_config'`` mapping. ``peft_config`` may contain
            ``'lora_rank'`` (default: ``rank``), ``'lora_alpha'``
            (default: ``2.``) and ``'lora_dropout'`` (default: ``0.0``).
            A missing, ``None`` or empty ``peft_config`` disables the
            adapters.

    Raises:
        KeyError: If a required key is missing from ``init_config``.
    """

    def __init__(self, init_config):
        super().__init__()

        self.in_features = init_config['in_features']
        self.out_features = init_config['out_features']
        self.rank = init_config['rank']

        # Base low-rank decomposition layers.
        self.V_t = nn.Linear(self.in_features, self.rank, bias=False)
        self.UE = nn.Linear(self.rank, self.out_features, bias=False)

        self.peft_config = init_config.get('peft_config', None)
        self.lora_alpha = 1.

        if self.peft_config:
            # Inner width of the adapters; falls back to the base rank.
            r2 = self.peft_config.get('lora_rank', self.rank)
            self.lora_alpha = self.peft_config.get('lora_alpha', 2.)
            self.lora_dropout_p = nn.Dropout(self.peft_config.get('lora_dropout', 0.0))

            # Adapter for V_t: in_features -> r2 -> rank.
            self.peft_V_t_A = nn.Linear(self.in_features, r2, bias=False)
            self.peft_V_t_B = nn.Linear(r2, self.rank, bias=False)

            # Adapter for UE: rank -> r2 -> out_features.
            self.peft_UE_A = nn.Linear(self.rank, r2, bias=False)
            self.peft_UE_B = nn.Linear(r2, self.out_features, bias=False)

            # Zero ONLY the "B" projections. The adapter output B(A(x)) is
            # then exactly zero at initialisation, so the layer starts out
            # identical to the base decomposition. The "A" projections keep
            # nn.Linear's default random init: if A and B were both zero,
            # the gradient of each would be proportional to the other and
            # therefore zero, and the adapters could never be trained.
            nn.init.zeros_(self.peft_V_t_B.weight)
            nn.init.zeros_(self.peft_UE_B.weight)

    def forward(self, inputs):
        """Apply the low-rank layer, plus the adapters when enabled.

        Args:
            inputs: Tensor whose last dimension is ``in_features``
                (as required by ``nn.Linear``).

        Returns:
            Tensor whose last dimension is ``out_features``.
        """
        x = self.V_t(inputs)

        if not self.peft_config:
            return self.UE(x)

        # Adapter correction for V_t. Out-of-place adds are used so that no
        # activation tensor is mutated after it has been produced.
        delta_v = self.peft_V_t_B(self.peft_V_t_A(inputs))
        x = x + self.lora_alpha * self.lora_dropout_p(delta_v)

        output = self.UE(x)

        # Adapter correction for UE; it consumes the already-adapted x.
        delta_ue = self.peft_UE_B(self.peft_UE_A(x))
        return output + self.lora_alpha * self.lora_dropout_p(delta_ue)

    def __str__(self):
        return f"LinearLowRankPEFT(in_features={self.in_features}, out_features={self.out_features}, rank={self.rank})"

    def __repr__(self):
        # Replaces nn.Module's default repr with the one-line summary above,
        # so the submodules are not listed when the layer is printed.
        return self.__str__()
