import math
import torch.nn as nn
import torch
from loralib import LoRALayer
from utils.layers import T5LayerNorm
from torch import Tensor
import torch.nn.functional as F


class ConLoRALinear(nn.Linear, LoRALayer):

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        dataset: str = "multiwoz",
        layer_idx=0,
        **kwargs
    ):
        """Build the linear layer together with its LoRA branches.

        Args:
            in_features: Input width of the linear layer.
            out_features: Output width of the linear layer.
            r: LoRA rank; LoRA parameters are only created when ``r > 0``.
            lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
            lora_dropout: Forwarded to ``LoRALayer.__init__``.
            fan_in_fan_out: When true, the stored weight is transposed at the
                end of construction and transposed back whenever it is used.
            merge_weights: Forwarded to ``LoRALayer.__init__``; consulted by
                ``train`` when switching modes.
            dataset: Dataset name passed to the cluster-embedding lookups.
            layer_idx: Position of the layer hosting this module; compared
                against a hard-coded depth of 6 to derive ``is_high_layer``.
            **kwargs: Forwarded unchanged to ``nn.Linear.__init__``.
        """
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            merge_weights=merge_weights,
        )

        self.fan_in_fan_out = fan_in_fan_out
        self.dataset = dataset
        self.layer_idx = layer_idx

        # Layer-depth bookkeeping: layers with (index + 1) at or beyond half
        # of the hard-coded depth take the "high layer" path in forward().
        self.num_layers = 6
        self.is_high_layer = (self.layer_idx + 1) >= (self.num_layers // 2)

        # Weight of the correlation term used by the SVD-based initialisers.
        self.alpha = 0.5

        # Temperature schedule read and updated by select_ab_matrices().
        self.tau = 1.0
        self.tau_step = 5e-5
        self.tau_final = 0.1

        # Gate is expected to be injected from outside; None selects the
        # fixed 0.5 mixing weight in forward().
        self.gate = None

        if r > 0:
            def zero_param(*shape):
                # Zero tensor sharing dtype/device with the base weight.
                return nn.Parameter(self.weight.new_zeros(shape))

            # Placeholders only: reset_parameters() below replaces them.
            self.lora_UniRep_A = zero_param(r, in_features)
            self.lora_UniRep_B = zero_param(out_features, r)
            self.lora_SemAdapt_A = nn.ParameterList(
                [zero_param(r, in_features) for _ in range(2)]
            )
            self.lora_SemAdapt_B = nn.ParameterList(
                [zero_param(out_features, r) for _ in range(3)]
            )

            self.scaling = self.lora_alpha / self.r
            # The base weight matrix is frozen; only LoRA parameters train.
            self.weight.requires_grad = False

        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        """Re-initialise the base layer and, when present, the LoRA branches.

        ``nn.Linear.__init__`` calls this before any LoRA attribute exists,
        in which case only the plain linear initialisation runs.
        """
        nn.Linear.reset_parameters(self)
        if not hasattr(self, 'lora_UniRep_A'):
            return

        # SVD-based initialisation; both calls also subtract their low-rank
        # product from the base weight.
        self.lora_UniRep_init()
        self.lora_SemAdapt_init()

        # NOTE(review): the SemAdapt values computed just above are replaced
        # here (A: Kaiming uniform, B: zeros) while the base-weight
        # subtraction done by lora_SemAdapt_init() stays -- confirm intent.
        for down_proj in self.lora_SemAdapt_A:
            nn.init.kaiming_uniform_(down_proj, a=math.sqrt(5))
        for up_proj in self.lora_SemAdapt_B:
            nn.init.zeros_(up_proj)

        # Everything is moved to the CUDA device unconditionally.
        for uni_rep in (self.lora_UniRep_A, self.lora_UniRep_B):
            uni_rep.data = uni_rep.data.cuda()
        self.lora_SemAdapt_A = self.lora_SemAdapt_A.cuda()
        self.lora_SemAdapt_B = self.lora_SemAdapt_B.cuda()
        self.weight.data = self.weight.data.cuda()

    @staticmethod
    def compute_entropy(embedding: Tensor) -> Tensor:
        p = embedding / (torch.sum(embedding, dim=-1, keepdim=True) + 1e-8)  # Calculate probability distribution
        p = p + 1e-8  # Prevent numerical instability
        return -torch.sum(p * torch.log(p), dim=-1)

    @staticmethod
    def compute_alpha_layer(entropy: Tensor, mu: float = 0.5, k: float = 2.0) -> Tensor:
        return 1.0 / (1.0 + torch.exp(-k * (entropy - mu)))

    @classmethod
    def get_domain_slot_cluster_embeddings(cls, dataset='multiwoz'):
        """Return per-cluster mean slot embeddings for ``dataset`` (cached).

        The result is computed once per dataset name and stored in a
        class-level dictionary. Each row is the mean of the embeddings whose
        cluster entry carries that row index as ``'cluster_label'``, moved to
        CUDA and L2-normalised along dimension 1.
        """
        cache = getattr(cls, '_domain_slot_emb_cache', None)
        if cache is None:
            cache = {}
            cls._domain_slot_emb_cache = cache

        if dataset in cache:
            return cache[dataset]

        # Imported lazily, only when the cache misses.
        from utils.domain_slot_clustering import domain_slot_clustering
        clusters, cluster_num, slot_embeddings = domain_slot_clustering(dataset=dataset)

        device = slot_embeddings.device
        sums = torch.zeros(cluster_num, slot_embeddings.size(1), device=device)
        members = torch.zeros(cluster_num, device=device)

        # The i-th cluster entry labels the i-th embedding row.
        for row_idx, entry in enumerate(clusters):
            target = entry['cluster_label']
            sums[target] += slot_embeddings[row_idx]
            members[target] += 1

        # The epsilon keeps empty clusters from dividing by zero.
        means = sums / (members.unsqueeze(1) + 1e-8)
        means = means.cuda()
        means = F.normalize(means, p=2, dim=1)

        cache[dataset] = means
        return means

    @classmethod
    def get_domain_embeddings(cls, dataset='multiwoz'):
        """Return per-cluster mean domain embeddings for ``dataset`` (cached).

        The result is computed once per dataset name and stored in a
        class-level dictionary. Each row is the mean of the embeddings whose
        cluster entry carries that row index as ``'cluster_label'``, moved to
        CUDA and L2-normalised along dimension 1.
        """
        cache = getattr(cls, '_domain_emb_cache', None)
        if cache is None:
            cache = {}
            cls._domain_emb_cache = cache

        if dataset in cache:
            return cache[dataset]

        # Imported lazily, only when the cache misses.
        from utils.domain_slot_clustering import domain_clustering
        clusters, cluster_num, domain_embeddings = domain_clustering(dataset=dataset)

        device = domain_embeddings.device
        sums = torch.zeros(cluster_num, domain_embeddings.size(1), device=device)
        members = torch.zeros(cluster_num, device=device)

        # The i-th cluster entry labels the i-th embedding row.
        for row_idx, entry in enumerate(clusters):
            target = entry['cluster_label']
            sums[target] += domain_embeddings[row_idx]
            members[target] += 1

        # The epsilon keeps empty clusters from dividing by zero.
        means = sums / (members.unsqueeze(1) + 1e-8)
        means = means.cuda()
        means = F.normalize(means, p=2, dim=1)

        cache[dataset] = means
        return means

    def select_ab_matrices(self, prompt_emb, tau=1.0, training=True):
        """Pick one SemAdapt A and one SemAdapt B matrix per batch element.

        The prompt embedding is mean-pooled over dimension 1 and compared, by
        cosine similarity, with the cached domain-slot cluster embeddings
        (choosing A) and the cached domain cluster embeddings (choosing B).

        Args:
            prompt_emb: Tensor pooled over dim 1 and compared over dim 2 with
                the cluster embeddings, i.e. ``[batch, seq_len, embed_dim]``.
            tau: Gumbel-Softmax temperature used when ``training`` is true.
            training: When true, draw a straight-through Gumbel-Softmax
                sample; otherwise take the arg-max deterministically.

        Returns:
            ``(selected_a, selected_b)`` with shapes ``[batch, r, in_features]``
            and ``[batch, out_features, r]``. Each is rescaled by the number
            of candidate matrices.

        Side effects:
            Anneals ``self.tau`` by ``self.tau_step`` (floored at
            ``self.tau_final``), but only for training calls, so evaluation
            passes do not advance the schedule.
        """
        mu_m = self.get_domain_slot_cluster_embeddings(dataset=self.dataset)  # (M, embed_dim)
        nu_n = self.get_domain_embeddings(dataset=self.dataset)  # (N, embed_dim)
        batch_size = prompt_emb.size(0)

        # Pool over the sequence axis, then score against every cluster.
        prompt_emb_avg = prompt_emb.mean(dim=1)  # [batch, embed_dim]
        sim_m = F.cosine_similarity(prompt_emb_avg.unsqueeze(1), mu_m.unsqueeze(0), dim=2)  # [batch, M]
        match_n = F.cosine_similarity(prompt_emb_avg.unsqueeze(1), nu_n.unsqueeze(0), dim=2)  # [batch, N]

        def choose(scores):
            """Turn a [batch, K] score matrix into a one-hot selection."""
            if self.r <= 0:
                return torch.zeros(batch_size, scores.size(1), device=prompt_emb.device)
            if training:
                # F.gumbel_softmax already adds Gumbel noise and divides by
                # tau, so the raw scores go in directly. Perturbing and
                # scaling them beforehand would apply both steps twice.
                return F.gumbel_softmax(scores, tau=tau, hard=True)
            winners = torch.argmax(scores, dim=1)
            return F.one_hot(winners, num_classes=scores.size(1)).float()

        selected_m = choose(sim_m)    # [batch, M]
        selected_n = choose(match_n)  # [batch, N]

        # Stack the candidates so selection becomes a differentiable
        # weighted sum instead of an index lookup.
        batch_lora_SemAdapt_A = torch.stack(list(self.lora_SemAdapt_A))  # [M, r, in_features]
        batch_lora_SemAdapt_B = torch.stack(list(self.lora_SemAdapt_B))  # [N, out_features, r]

        selected_a = torch.einsum('bm,mri->bri', selected_m, batch_lora_SemAdapt_A) * len(self.lora_SemAdapt_A)  # [batch, r, in_features]
        selected_b = torch.einsum('bn,nor->bor', selected_n, batch_lora_SemAdapt_B) * len(self.lora_SemAdapt_B)  # [batch, out_features, r]

        if training:
            # Advance the temperature schedule on training steps only.
            self.tau = max(self.tau - self.tau_step, self.tau_final)
        return selected_a, selected_b

    def lora_UniRep_init(self):
        """Initialise the UniRep LoRA pair from the top-``r`` SVD components.

        The base weight is decomposed with a reduced SVD. The leading ``r``
        singular values are scaled by ``1 + alpha * c``, where ``c`` is the
        largest dot product between the normalised output-side singular
        vector and any domain-slot cluster embedding. The square root of the
        scaled values is split between the A and B factors, and
        ``scaling * B @ A`` is subtracted from the base weight.

        Raises:
            TypeError: If the weight dtype is not float32/float16/bfloat16.
        """
        base = self.weight.cuda()
        orig_dtype = base.dtype
        if orig_dtype not in [torch.float32, torch.float16, torch.bfloat16]:
            raise TypeError(
                "Initialize PiSSA under float32, float16, or bfloat16 data types."
                "Then re-quantize the residual model to help minimize quantization error."
            )
        # The SVD is carried out in float32.
        base = base.to(torch.float32)
        # weight = out_vecs @ diag(sing_vals) @ in_vecs_h
        out_vecs, sing_vals, in_vecs_h = torch.linalg.svd(base.data, full_matrices=False)

        # Cached, L2-normalised cluster embeddings, shape (K, embed_dim).
        cluster_emb = self.get_domain_slot_cluster_embeddings(dataset=self.dataset)

        rank = self.r
        top_out = out_vecs[:, :rank]   # (out_dim, rank): output-side singular vectors
        top_vals = sing_vals[:rank]    # leading singular values
        top_in_h = in_vecs_h[:rank]    # (rank, in_dim): input-side singular vectors

        # Correlate each output-side vector with every cluster embedding.
        # This requires out_dim == embed_dim.
        top_out_unit = F.normalize(top_out, p=2, dim=0)
        corr = torch.matmul(top_out_unit.T, cluster_emb.T)  # (rank, K)

        # Scale each singular value by its best cluster correlation.
        best_corr = corr.max(dim=1).values
        boosted_vals = top_vals * (1 + self.alpha * best_corr)

        # Split sqrt(boosted_vals) symmetrically between the two factors.
        root = torch.diag(torch.sqrt(boosted_vals))
        lora_A = root @ top_in_h
        lora_B = top_out @ root

        # Remove the low-rank part from the base weight.
        residual = base.data - self.scaling * (lora_B @ lora_A)

        # Restore the original dtype and publish the results.
        residual = residual.to(orig_dtype)
        self.weight.data = residual
        self.lora_UniRep_A = nn.Parameter(lora_A)
        self.lora_UniRep_B = nn.Parameter(lora_B)

    def lora_SemAdapt_init(self):
        """Initialise the SemAdapt A/B matrix lists from the base-weight SVD.

        One A matrix is built per domain-slot cluster and one B matrix per
        domain cluster. For each cluster, the leading ``r`` singular values
        are scaled by ``1 + alpha * corr``, where ``corr`` is the dot product
        between the normalised output-side singular vector and that cluster's
        embedding. The matrices are divided by the cluster count, and
        ``scaling`` times the sum of every ``B @ A`` pairing is subtracted
        from the base weight.

        Raises:
            TypeError: If the weight dtype is not float32/float16/bfloat16.
        """
        base = self.weight.cuda()
        orig_dtype = base.dtype
        if orig_dtype not in [torch.float32, torch.float16, torch.bfloat16]:
            raise TypeError(
                "Initialize SemAdapt-LoRA under float32, float16, or bfloat16 data types."
                "Then re-quantize the residual model to help minimize quantization error."
            )
        # The SVD is carried out in float32.
        base = base.to(torch.float32)
        # weight = out_vecs @ diag(sing_vals) @ in_vecs_h
        out_vecs, sing_vals, in_vecs_h = torch.linalg.svd(base.data, full_matrices=False)

        # Cached, L2-normalised cluster embeddings.
        slot_cluster_emb = self.get_domain_slot_cluster_embeddings(dataset=self.dataset)
        domain_cluster_emb = self.get_domain_embeddings(dataset=self.dataset)

        rank = self.r
        top_out = out_vecs[:, :rank]   # (out_dim, rank): output-side singular vectors
        top_vals = sing_vals[:rank]    # leading singular values
        top_in_h = in_vecs_h[:rank]    # (rank, in_dim): input-side singular vectors

        top_out_unit = F.normalize(top_out, p=2, dim=0)

        # Per-cluster scaled singular values for the A side.
        corr_slot = torch.matmul(top_out_unit.T, slot_cluster_emb.T)  # (rank, M)
        num_slot_clusters = corr_slot.shape[1]
        vals_per_slot = top_vals.unsqueeze(0) * (1 + self.alpha * corr_slot.T)  # (M, rank)

        # Per-cluster scaled singular values for the B side.
        corr_domain = torch.matmul(top_out_unit.T, domain_cluster_emb.T)  # (rank, N)
        num_domain_clusters = corr_domain.shape[1]
        vals_per_domain = top_vals.unsqueeze(0) * (1 + self.alpha * corr_domain.T)  # (N, rank)

        # Start from empty lists; any earlier placeholders are discarded.
        self.lora_SemAdapt_A = nn.ParameterList()
        self.lora_SemAdapt_B = nn.ParameterList()

        # A matrices, (rank, in_dim), each divided by M.
        for cluster_vals in vals_per_slot:
            root = torch.diag(torch.sqrt(cluster_vals))
            down_proj = (root @ top_in_h) / num_slot_clusters
            self.lora_SemAdapt_A.append(nn.Parameter(down_proj))

        # B matrices, (out_dim, rank), each divided by N.
        for cluster_vals in vals_per_domain:
            root = torch.diag(torch.sqrt(cluster_vals))
            up_proj = (top_out @ root) / num_domain_clusters
            self.lora_SemAdapt_B.append(nn.Parameter(up_proj))

        # Accumulate every B @ A pairing, A outer and B inner.
        pair_total = torch.zeros_like(base.data)
        for down_proj in self.lora_SemAdapt_A:
            for up_proj in self.lora_SemAdapt_B:
                pair_total += up_proj @ down_proj

        residual = base.data - self.scaling * pair_total

        # Restore the original dtype and store the adjusted base weight.
        residual = residual.to(orig_dtype)
        self.weight.data = residual

    def train(self, mode: bool = True):
        """Switch between training and evaluation, (un)merging UniRep LoRA.

        Only the UniRep product ``B @ A * scaling`` is folded into the base
        weight; the SemAdapt matrices are never merged here.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching the ``nn.Module.train`` contract so that calls
            such as ``layer.train(False).to(device)`` can be chained.

        Behaviour when ``merge_weights`` is enabled:
            - entering training mode with merged weights subtracts the UniRep
              product again;
            - entering evaluation mode with unmerged weights adds it.
        """
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        # Sets self.training on this module and on its children.
        nn.Linear.train(self, mode)

        if not self.merge_weights:
            return self

        if mode and self.merged:
            # Undo the merge so the LoRA factors train separately again.
            if self.r > 0:
                self.weight.data -= T(self.lora_UniRep_B @ self.lora_UniRep_A) * self.scaling
            self.merged = False
        elif not mode and not self.merged:
            # Fold the UniRep update into the base weight for inference.
            if self.r > 0:
                self.weight.data += T(self.lora_UniRep_B @ self.lora_UniRep_A) * self.scaling
            self.merged = True
        return self


    def forward(self, x, prompt_emb=None, prompt_embed_mask=None, hidden_attention_mask=None, global_prompt=None, p_bias=None):
        """Apply the base projection plus the UniRep / SemAdapt LoRA branches.

        Args:
            x: Input to the linear layer.
            prompt_emb: Prompt embedding used to select SemAdapt matrices on
                low layers; required there whenever ``global_prompt`` is given.
            prompt_embed_mask: Accepted for interface compatibility; unused.
            hidden_attention_mask: Accepted for interface compatibility; unused.
            global_prompt: Optional tensor routed through the SemAdapt branch;
                when ``None`` only the UniRep update is added.
            p_bias: Optional precomputed bias used on the merged / ``r == 0``
                path; it is averaged over dim 1 and broadcast back
                (presumably ``[batch, seq, out_features]`` -- confirm with
                the caller). Skipped when ``None``.

        Returns:
            The layer output, with the same leading dimensions as ``x``.
        """
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        # Base projection, shared by every path.
        result = F.linear(x, T(self.weight), bias=self.bias)

        if self.r <= 0 or self.merged:
            # The UniRep update is either merged into the weight or absent.
            # A missing p_bias must not crash a plain linear pass.
            if p_bias is not None:
                result = result + p_bias.mean(dim=1).unsqueeze(dim=1)
            return result

        # The UniRep low-rank update is always applied to the input itself.
        result = result + (self.lora_dropout(x) @ self.lora_UniRep_A.transpose(0, 1) @ self.lora_UniRep_B.transpose(0, 1)) * self.scaling

        if global_prompt is None:
            return result

        if self.is_high_layer:
            # High layer: all SemAdapt matrices collaborate (summed factors).
            A_sum = sum(self.lora_SemAdapt_A)  # [r, in_features]
            B_sum = sum(self.lora_SemAdapt_B)  # [out_features, r]
            gp = self.lora_dropout(global_prompt)
            glora_result = (gp @ A_sum.transpose(0, 1) @ B_sum.transpose(0, 1)) * self.scaling
        else:
            # Low layer: one A and one B matrix is selected per sample.
            selected_a, selected_b = self.select_ab_matrices(prompt_emb, tau=self.tau, training=self.training)
            glora_result = (self.lora_dropout(global_prompt) @ selected_a.transpose(1, 2) @ selected_b.transpose(1, 2)) * self.scaling

        # The prompt branch is pooled over dim 1 and broadcast onto the
        # hidden states.
        prompt_term = glora_result.mean(dim=1).unsqueeze(dim=1)
        if self.gate is not None:
            # Externally supplied gate blends the two branches.
            return self.gate * result + (1 - self.gate) * prompt_term
        # No gate configured: add the prompt branch with a fixed 0.5 weight.
        return result + 0.5 * prompt_term

class ControlPrompt(nn.Module):
    """Thin module wrapper that owns a single ``ConLoRALinear`` layer."""

    def __init__(self, in_proj, out_proj, lora_r, lora_alpha, lora_dropout, bias, args, layer_idx=0):
        """Create the wrapped layer.

        Args:
            in_proj: Input width of the wrapped linear layer.
            out_proj: Output width of the wrapped linear layer.
            lora_r: LoRA rank.
            lora_alpha: LoRA scaling numerator.
            lora_dropout: Dropout setting forwarded to the wrapped layer.
            bias: Forwarded to the wrapped layer as the ``bias`` keyword.
            args: Object optionally carrying a ``dataset`` attribute;
                ``'multiwoz'`` is used when it is absent.
            layer_idx: Layer position forwarded to the wrapped layer.
        """
        super().__init__()
        dataset_name = getattr(args, 'dataset', 'multiwoz')
        self.Con = ConLoRALinear(
            in_proj,
            out_proj,
            lora_r,
            lora_alpha,
            lora_dropout=lora_dropout,
            bias=bias,
            dataset=dataset_name,
            layer_idx=layer_idx,
        )

    def forward(self, x, prompt_emb=None, prompt_embed_mask=None,
               hidden_attention_mask=None, global_prompt=None, p_bias=None):
        """Delegate to the wrapped layer, passing every argument through."""
        return self.Con(
            x,
            prompt_emb,
            prompt_embed_mask,
            hidden_attention_mask,
            global_prompt,
            p_bias,
        )
