import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import torch.nn.init as init
import gc

from typing import Dict, Any, Optional, Tuple, Union

from transformers import (
    AutoConfig, AutoModel, AutoTokenizer
)


def clean_gpus() -> None:
    """Free unreachable Python objects, then flush PyTorch's CUDA cache.

    The garbage collector runs first so tensors held only by unreachable
    objects are released before ``torch.cuda.empty_cache()`` is asked to
    return cached allocator memory.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
# Executed at import time: merely importing this module runs a garbage
# collection pass and a CUDA cache flush.
clean_gpus()

# Access token passed as ``token=`` to the ``from_pretrained`` calls in
# ModelWrapper. "Fill in" is a placeholder, not a real token.
# NOTE(review): the placeholder is sent as an actual credential; confirm the
# hub accepts it for public checkpoints, and consider reading the token from
# the environment instead of editing source.
hugging_token = "Fill in"



class LMProtoNet(nn.Module):
    """Prototype classifier on top of an encoder backbone.

    The backbone yields one representation vector per example. Each vector is
    L2-normalised and compared by cosine similarity with a learnable bank of
    ``num_protos_per_class * num_labels`` prototypes. A bias-free linear layer
    then maps the similarity vector to class logits.

    Args:
        backbone: Encoder module. It must expose ``model_type`` (``'bert'`` or
            ``'llm'``) and ``latent_size``. For ``'llm'`` it must also expose
            ``no_llm_head``, and ``prototype_dim`` when ``no_llm_head`` is
            false. It needs at least one parameter, which is used to infer the
            device.
        num_labels: Number of output classes.
        num_protos_per_class: Number of prototypes allocated per class.
        init_prototypes: Optional ``[P, D]`` tensor used as-is for the initial
            prototypes (``P = num_protos_per_class * num_labels``). When
            omitted, prototypes are Xavier-uniform initialised.
        baseline: Stored on the instance; not read inside this class.

    Raises:
        NameError: If ``backbone.model_type`` is neither ``'bert'`` nor ``'llm'``.
        AssertionError: If ``init_prototypes`` does not have shape ``[P, D]``.
    """

    def __init__(self, backbone: nn.Module,
                 num_labels: int = 2,
                 num_protos_per_class: int = 5,
                 init_prototypes: Optional[torch.Tensor] = None,
                 baseline=True,
                ):
        super().__init__()

        self.baseline = baseline

        self.backbone = backbone
        self.num_labels = num_labels
        self.num_total_prototypes = num_protos_per_class * num_labels
        # New tensors are created on the same device as the backbone weights.
        self.device = next(self.backbone.parameters()).device

        prototype_dim = self._resolve_prototype_dim(backbone)

        # ----- prototypes -----
        if init_prototypes is not None:
            # Keep user-provided prototypes exactly as they are.
            expected = (self.num_total_prototypes, prototype_dim)
            assert tuple(init_prototypes.shape) == expected, (
                f"init_prototypes has shape {tuple(init_prototypes.shape)}, "
                f"expected {expected}"
            )
            self.prototypes = nn.Parameter(init_prototypes.to(self.device))
        else:
            # Xavier-uniform rather than torch.rand, so the prototypes start
            # zero-centred.
            proto_tensor = torch.empty(
                self.num_total_prototypes, prototype_dim, device=self.device
            )
            init.xavier_uniform_(proto_tensor, gain=1.0)
            self.prototypes = nn.Parameter(proto_tensor)

        # ----- linear classifier over prototype activations -----
        self.classfn_model = nn.Linear(self.num_total_prototypes, num_labels, bias=False)
        init.xavier_uniform_(self.classfn_model.weight, gain=1.0)
        self.classfn_model.to(self.device)

    @staticmethod
    def _resolve_prototype_dim(backbone: nn.Module) -> int:
        """Return the prototype dimensionality implied by the backbone's configuration."""
        if backbone.model_type == 'bert':
            return backbone.latent_size
        if backbone.model_type == 'llm':
            # With a trainable head, prototypes use backbone.prototype_dim;
            # without one, raw encodings of size latent_size are compared.
            if not backbone.no_llm_head:
                return backbone.prototype_dim
            return backbone.latent_size
        raise NameError('Wrong model_type in LMProtoNet')

    def forward(self, input_ids=None, attention_mask=None, llm_encodings=None, forward_type='train'):
        """Score a batch against the prototype bank.

        All arguments are forwarded unchanged to ``self.backbone``.

        Returns:
            dict with keys
              ``logits``: ``[B, C]`` class scores;
              ``acts``: ``[B, P]`` cosine similarity to every prototype;
              ``cls_rep_normalized``: ``[B, H]`` unit-norm representations;
              ``l_p1``: mean over prototypes of the cosine distance to the
                closest example in the batch;
              ``l_p2``: mean over examples of the cosine distance to the
                closest prototype;
              ``l_p3``: mean of ``1 + cos`` over distinct prototype pairs, a
                separation penalty (0 when there are fewer than two prototypes).
        """
        rep = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            llm_encodings=llm_encodings,
            forward_type=forward_type,
        )

        cls_rep_norm = F.normalize(rep, p=2, dim=1)              # [B, H]
        proto_norm = F.normalize(self.prototypes, p=2, dim=1)    # [P, H]

        # Cosine similarities, computed once; distances are derived from them.
        acts = cls_rep_norm @ proto_norm.T                       # [B, P]
        loss_acts = 1 - acts                                     # cosine distance

        l_p1 = loss_acts.min(dim=0).values.mean()
        l_p2 = loss_acts.min(dim=1).values.mean()

        # Prototype separation over the strict upper triangle (each pair once).
        proto_sim = proto_norm @ proto_norm.T
        mask = torch.triu(torch.ones_like(proto_sim, dtype=torch.bool), diagonal=1)
        pair_sims = proto_sim[mask]
        if pair_sims.numel() > 0:
            l_p3 = (1 + pair_sims).mean()
        else:
            # A single prototype has no pairs; mean() of an empty tensor is NaN.
            l_p3 = proto_sim.new_zeros(())

        # Classification logits
        logits = self.classfn_model(acts)                        # [B, C]
        return {
            "logits": logits,
            "acts": acts,
            "cls_rep_normalized": cls_rep_norm,
            "l_p1": l_p1,
            "l_p2": l_p2,
            "l_p3": l_p3,
        }


class ModelWrapper(nn.Module):
    """Pretrained Hugging Face encoder with most layers frozen.

    Loads a checkpoint and its tokenizer and moves the model to ``device``.
    Only the parameters selected per model (the last encoder block plus some
    head modules) stay trainable. ``self.encoder`` returns the first-token
    hidden state for a batch.

    Args:
        model_name: ``"bert"``, ``"electra"`` or ``"roberta"`` are loadable.
            ``"llama"``, ``"qwen"``, ``"modern_bert"`` and ``"mpnet"`` are
            recognised but have no loader here.
        latent_dim: Accepted for interface compatibility; not used.
        max_length: Stored on the instance; not read in this class.
        prototype_dim: Stored; read by callers (e.g. LMProtoNet) for 'llm'
            backbones with a head.
        device: Device the backbone is moved to; inputs are moved there too.
        no_llm_head: Stored; in 'train' mode, selects whether 'llm' backbones
            apply ``self.trainable_head``.

    Raises:
        ValueError: If ``model_name`` is not recognised.
        NotImplementedError: If ``model_name`` is recognised but has no loader.
    """

    _SUPPORTED_NAMES = {"bert", "electra", "llama", "qwen", "modern_bert", 'roberta', 'mpnet'}

    # model_name -> (checkpoint id, predicate deciding whether a parameter stays trainable).
    # NOTE: the pooler is left trainable for bert/roberta, but CLSWrapper reads
    # last_hidden_state rather than pooler_output, so those weights get no gradient.
    _HF_BACKBONES = {
        "bert": (
            "bert-base-uncased",
            lambda name: name.startswith(("encoder.layer.11.", "pooler.", "classifier.", "cls.")),
        ),
        "electra": (
            "google/electra-base-discriminator",
            lambda name: "encoder.layer.11" in name,
        ),
        "roberta": (
            "FacebookAI/roberta-base",
            lambda name: ("encoder.layer.11." in name) or name.startswith("pooler"),
        ),
    }

    def __init__(self, model_name: str, latent_dim: Optional[int] = None, max_length: int = 128, prototype_dim=128, device='cuda:0', no_llm_head=True):
        super().__init__()
        self.model_name = model_name
        self.max_length = max_length
        self.prototype_dim = prototype_dim
        self.device = device
        self.no_llm_head = no_llm_head

        if model_name not in self._SUPPORTED_NAMES:
            raise ValueError(f"Unsupported model_name: {model_name}")
        if model_name not in self._HF_BACKBONES:
            # Without a loader the wrapper would lack model_type, encoder,
            # latent_size and any parameters, and fail later far from here.
            raise NotImplementedError(
                f"model_name {model_name!r} is recognised but has no loader in ModelWrapper"
            )

        base, is_trainable = self._HF_BACKBONES[model_name]
        self._load_encoder_backbone(base, is_trainable)

    def _load_encoder_backbone(self, base: str, is_trainable) -> None:
        """Load checkpoint ``base``, move it to ``self.device`` and freeze it selectively."""
        config = AutoConfig.from_pretrained(base)
        self.config = config
        clean_gpus()
        self.hugging_model = AutoModel.from_pretrained(base, token=hugging_token)
        self.hugging_model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(base, token=hugging_token)
        self.model_type = "bert"

        # Freeze everything except the parameters the predicate selects.
        for name, param in self.hugging_model.named_parameters():
            param.requires_grad = is_trainable(name)

        self.encoder = partial(CLSWrapper, self.hugging_model)
        self.latent_size = config.hidden_size

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        llm_encodings: Optional[torch.Tensor] = None,
        forward_type: str = 'full',
    ) -> torch.Tensor:
        """Encode a batch according to ``forward_type``.

        Modes:
            ``'collect_llm_encodings'``: under ``torch.no_grad``, moves the
                inputs to ``self.device``. Encoder models (bert/electra/roberta)
                raise ``TypeError``; llama/qwen call ``self.llm_encoder``.
            ``'train'``: encoder models return the first-token representation.
                'llm' models return ``self.trainable_head(llm_encodings)``, or
                ``llm_encodings`` unchanged when ``no_llm_head`` is set.
            ``'full'``: returns ``self.encoder(input_ids, attention_mask)``.

        NOTE(review): ``self.llm_encoder`` and ``self.trainable_head`` are never
        assigned in this class; those branches need an external setup.

        Raises:
            ValueError: Missing inputs in collect mode, or unknown ``forward_type``.
            TypeError: Collect mode requested for an encoder model.
            NameError: Unknown model name/type for the requested mode.
        """
        # --- collect_llm_encodings MODE ---
        if forward_type == 'collect_llm_encodings':
            with torch.no_grad():
                if input_ids is None or attention_mask is None:
                    raise ValueError("enc mode requires input_ids and attention_mask")

                input_ids = input_ids.to(self.device)
                attention_mask = attention_mask.to(self.device)

                if self.model_name in ('bert', 'electra', 'roberta'):
                    raise TypeError('Do not use BERT model in encoder setting, only for saving data')
                if self.model_name in ('llama', 'qwen'):
                    return self.llm_encoder(input_ids)
                raise NameError('wrong model name')

        # --- TRAIN MODE ---
        if forward_type == 'train':
            if self.model_type == 'bert':
                return self.encoder(input_ids=input_ids, attention_mask=attention_mask)
            if self.model_type == 'llm':
                if not self.no_llm_head:
                    return self.trainable_head(llm_encodings)
                return llm_encodings
            raise NameError('wrong model name')

        # --- FULL MODE ---
        if forward_type == 'full':
            return self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        raise ValueError(f"Unknown forward_type: {forward_type}")




def CLSWrapper(base_model, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Run ``base_model`` on a batch and return the first-token hidden state.

    Args:
        base_model: Model exposing ``.device`` that, when called with
            ``return_dict=True``, returns an object with ``last_hidden_state``.
        input_ids: Token ids; moved to ``base_model.device``.
        attention_mask: Attention mask; moved to ``base_model.device``.
        **kwargs: Extra keyword arguments forwarded to ``base_model``. Each
            value may be of any type, hence ``Any`` rather than ``Dict``.

    Returns:
        ``last_hidden_state[:, 0, :]``, i.e. ``[B, H]``.
    """
    device = base_model.device
    outputs = base_model(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        return_dict=True,
        **kwargs,
    )
    return outputs.last_hidden_state[:, 0, :]   # first-token (CLS) vector, [B, H]















