from typing import Optional, Union, List, Tuple
from dataclasses import dataclass
import torch
from torch import nn
from torch.autograd import Function
from transformers.models.llama.modeling_llama import LlamaForCausalLM, CausalLMOutputWithPast
from transformers.utils import ModelOutput
from transformers.cache_utils import Cache


class ConsistencyMaskGenerator(nn.Module):
    """Split hidden states into a "backdoor" part and a "benign" part with a soft mask.

    The mask is built from two per-token statistics of the hidden states
    (feature variance and distance from the mean over dim 0). These are mixed by
    the learnable weights ``w1``/``w2`` and projected through ``proj``.
    ``forward`` returns ``H * mask`` and ``H * (1 - mask)`` together with a
    regularization loss.

    Args:
        hidden_dim: size of the last dimension of the hidden states.
        kernel_size: kernel size of ``conv``.
    """

    # Weight of the orthogonality term in ``forward``'s decouple loss. It is a
    # class attribute so it can be overridden per instance (``m.ortho_weight = x``)
    # without changing ``__init__`` or the state_dict.
    ortho_weight = 1e-5

    def __init__(self, hidden_dim, kernel_size=3):
        super().__init__()
        # NOTE(review): ``conv`` is not used by any method of this class. It is kept
        # so existing checkpoints (state_dict keys) and RNG-dependent init order
        # stay unchanged.
        self.conv = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, padding=kernel_size // 2)
        self.proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        nn.init.xavier_uniform_(self.proj.weight, gain=0.1)
        nn.init.xavier_uniform_(self.conv.weight, gain=0.1)

        # Learnable mixing weights for the two per-token statistics.
        self.w1 = nn.Parameter(torch.tensor(0.5))
        self.w2 = nn.Parameter(torch.tensor(0.5))

    def generate_mask(self, H):
        """Return a soft mask with the same shape as ``H``, valued in ``[sigmoid(-1), sigmoid(1)]``.

        Presumably ``H`` is ``(batch, seq, hidden)`` (TODO confirm for other
        callers). The min/max normalizations run over the whole tensor, so a
        token's mask depends on the other tokens in the batch.
        """
        # Per-token feature variance, min-max normalized and inverted so that the
        # lowest-variance token scores ~1 and the highest-variance token scores 0.
        sigma_local = H.var(dim=-1)
        sigma_max = sigma_local.max()
        sigma_local_norm = (sigma_max - sigma_local) / (sigma_max - sigma_local.min() + 1e-4)

        # Per-token L2 distance from the mean over dim 0, scaled by its maximum.
        H_mean = H.mean(dim=0)
        delta = (H - H_mean).norm(dim=-1)
        delta_norm = delta / (delta.max() + 1e-4)

        s = self.w1 * sigma_local_norm + self.w2 * delta_norm
        # Broadcast the per-token score across the hidden dimension. expand_as is
        # identical to expand(-1, -1, D) for 3-D input but does not require rank 3.
        m_logits = self.proj(s.unsqueeze(-1).expand_as(H))
        # Clamping bounds the mask to roughly [0.27, 0.73]; clamp also gives zero
        # gradient for logits outside [-1, 1].
        m_logits = torch.clamp(m_logits, -1, 1)
        return torch.sigmoid(m_logits)

    def generate_random_mask(self, x):
        """Return uniform [0, 1) noise with the same shape, device and dtype as ``x``."""
        return torch.rand(x.shape, device=x.device, dtype=x.dtype)

    def forward(self, H):
        """Split ``H`` with the learned mask.

        Returns:
            ``(decouple_loss, H_bkd, H_benign)`` where ``H_bkd = H * mask``,
            ``H_benign = H * (1 - mask)`` and
            ``decouple_loss = ortho_weight * ortho_loss + sparse_loss``.
        """
        mask = self.generate_mask(H)
        H_bkd = H * mask
        H_benign = H * (1 - mask)

        # Orthogonality regularizer: mean squared inner product between every
        # benign token vector and every backdoor token vector. ``reshape`` (not
        # ``view``) so non-contiguous inputs do not raise.
        # NOTE(review): this materializes an (N x N) matrix, N = number of tokens.
        hidden = H.size(-1)
        H_benign_flat = H_benign.reshape(-1, hidden)
        H_bkd_flat = H_bkd.reshape(-1, hidden)
        cross = torch.matmul(H_benign_flat, H_bkd_flat.t())
        ortho_loss = torch.norm(cross, p='fro') ** 2 / (H_benign_flat.size(0) * H_bkd_flat.size(0))

        # Sparsity regularizer: mean of mask ** 2. numel() makes it rank-agnostic.
        sparse_loss = torch.norm(mask) ** 2 / mask.numel()

        # Both terms are already on H's device (derived from H), so no .to() is needed.
        decouple_loss = self.ortho_weight * ortho_loss + sparse_loss
        return decouple_loss, H_bkd, H_benign

    def forward_w_random_mask(self, H):
        """Ablation of ``forward``: split ``H`` with a uniform random mask and a zero loss."""
        mask = self.generate_random_mask(H)
        H_bkd = H * mask
        H_benign = H * (1 - mask)

        decouple_loss = torch.tensor(0.0, device=H.device, dtype=H.dtype)
        return decouple_loss, H_bkd, H_benign


class DecoupleLLAMA(LlamaForCausalLM):
    """``LlamaForCausalLM`` whose final hidden states are split by a ``ConsistencyMaskGenerator``.

    With ``decouple=True`` the returned logits come from the "benign" branch of
    the split. When ``labels`` are given, the loss is
    ``CE(benign) - CE(backdoor) + decouple_loss``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.masker = ConsistencyMaskGenerator(config.hidden_size)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        decouple: bool = True,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the LLaMA backbone and compute logits (and loss) from the decoupled hidden states.

        Args:
            num_logits_to_keep: compute logits only for the last N positions;
                0 keeps every position (``x[:, -0:]`` is the full sequence).
            decouple: if True, logits come from the benign branch of
                ``self.masker``. If False, this behaves like the plain LM head
                on the backbone output.
            kwargs: forwarded both to ``self.model`` and to ``self.loss_function``.

        Returns:
            ``CausalLMOutputWithPast`` or, when ``return_dict`` is False, a tuple
            ``(loss?, logits, *outputs[1:])``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Backbone pass; outputs[0] is the final hidden state.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs[0]

        loss = None
        if decouple:
            decouple_loss, H_bkd, H_benign = self.masker(hidden_states)
            # Ablation switch: replace the learned mask with a uniform random one.
            # decouple_loss, H_bkd, H_benign = self.masker.forward_w_random_mask(hidden_states)

            # [batch, kept_seq_len, vocab_size]
            logits = self.lm_head(H_benign[:, -num_logits_to_keep:, :])

            if labels is not None:
                # The backdoor-branch logits only feed the loss, so the extra vocab
                # projection is skipped at inference / generation time.
                away_logits = self.lm_head(H_bkd[:, -num_logits_to_keep:, :])
                logit_loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
                away_logit_loss = self.loss_function(logits=away_logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
                # NOTE(review): the "- away_logit_loss" term is unbounded below; consider
                # clamping it or using a bounded objective if training diverges.
                # decouple_loss is moved to the LM-head device so the sum also works
                # when the masker and lm_head are placed on different devices.
                loss = logit_loss - away_logit_loss + decouple_loss.to(logit_loss.device)
        else:
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
            if labels is not None:
                loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )






