import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass
from peft import PeftModel
import os
import torch
from tqdm import tqdm

@dataclass
class TTMMOutput:
    """Result container returned by ``TTMM.forward``.

    Attributes:
        logits: Logits tensor taken from the underlying language model output.
    """

    # Raw logits of the (possibly adapter-merged) model for the given input.
    logits: torch.Tensor

class TTMM(nn.Module):
    """Test-Time Model Merging (TTMM).

    For every input, the first ``prefix_length`` tokens are embedded, compared
    against one embedding per expert cluster, and the adapters of the most
    similar clusters are merged into a single temporary adapter that is used
    for that call only.
    """

    def __init__(
        self,
        corpus: torch.Tensor, 
        tokenizer=None, 
        device=torch.device("cpu"), 
        max_merge_count=50, 
        verbose=False, 
        encoder=None, 
        rbf_length_scale=0.2, 
        rbf_threshold=0.01, 
        prefix_length=50, 
        adapter_location="/path/to/adapters", 
        base_model_name="meta-llama/Llama-3.2-1B", 
        null_adapter="/path/to/null_adapter"
    ):
        """
        Initializes the TTMM (Test-Time Model Merging) class.

        This class merges specialized language models based on topic similarity.

        Parameters:
            corpus (torch.Tensor): Embeddings of model clusters with shape (n_clusters, embedding_dim), 
                where each row represents a cluster embedding. A CPU copy is stored.
            tokenizer: Tokenizer for the language model. When None, the tokenizer of
                ``base_model_name`` is loaded.
            device (torch.device, optional): Device used for the similarity computation (default: "cpu").
            max_merge_count (int, optional): Maximum number of models to merge (default: 50).
            verbose (bool, optional): Whether to enable verbose logging (default: False).
            encoder (SentenceTransformer, optional): Encoder for text embeddings (default: 'all-mpnet-base-v2').
            rbf_length_scale (float, optional): Length scale for RBF (beta) (default: 0.2).
            rbf_threshold (float, optional): Similarity threshold for RBF (sparsity) (default: 0.01).
            prefix_length (int, optional): Prefix length used for generating query embeddings (default: 50).
            adapter_location (str): Path to the directory containing expert model adapters. 
                Expected structure:
                    ├── adapter_location/
                    │   ├── 0/
                    │   │   ├── adapter_model.safetensors
                    │   ├── 1/
                    │   │   ├── adapter_model.safetensors
                    │   ├── ...
            base_model_name (str, optional): Name of the base model (default: "meta-llama/Llama-3.2-1B").
            null_adapter (str): Path to an untrained adapter used for initialization.
        """
        super().__init__()
        self.corpus = corpus.cpu()
        self.max_merge_count = max_merge_count
        self.device = device
        # The model name must be stored before the tokenizer fallback reads it.
        self.modelName = base_model_name
        self.tokenizer = (
            tokenizer if tokenizer is not None
            else AutoTokenizer.from_pretrained(self.modelName)
        )
        self.encoder = (
            encoder if encoder is not None
            else SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
        )
        # NOTE(review): attention backend and device placement are hard-coded
        # here and independent of the `device` argument -- confirm this is
        # intended for CPU-only setups.
        self.baseModel = AutoModelForCausalLM.from_pretrained(self.modelName, attn_implementation="flash_attention_2", device_map="auto")
        self.config = self.baseModel.config
        self.tie_weights = self.baseModel.tie_weights
        # The null adapter is registered as "initZero"; it is the adapter that
        # gets activated when no expert passes the threshold.
        self.peft_model = PeftModel.from_pretrained(
            self.baseModel,
            model_id=null_adapter,
            adapter_name="initZero"
        )
        self.peft_model.eval()
        self.verbose = verbose
        self.rbf_length_scale = rbf_length_scale
        self.rbf_threshold = rbf_threshold
        self.prefix_length = prefix_length
        self.adapterLocation = adapter_location

    def _select_experts(self, similarities):
        """Choose which expert clusters to merge and with which weights.

        Args:
            similarities: 1-D tensor holding the similarity between the query
                and every cluster embedding.

        Returns:
            ``None`` when no cluster passes ``rbf_threshold`` (the caller
            falls back to the base model); otherwise a tuple
            ``(cluster_indices, weights)`` of equally long lists, where the
            weights sum to 1.
        """
        # RBF kernel values between the query and all cluster embeddings.
        distances = 1 - similarities
        kernel_values = torch.exp(-(1.0 / (self.rbf_length_scale ** 2)) * distances)

        # Normalize the kernel values and select the models above the threshold.
        normalized_weights = kernel_values / kernel_values.sum()
        selected_mask = normalized_weights > self.rbf_threshold
        num_selected = int(selected_mask.sum())

        if num_selected == 0:
            # NaNs (e.g. every kernel value underflowed to zero) compare as
            # False against the threshold, so they end up in this branch.
            if torch.isnan(normalized_weights).any():
                if self.verbose:
                    print("NANs in normalized weights, falling back to k=1", flush=True)
                    print("FALLBACK_TO_NN", flush=True)
                # Use the single closest cluster; the similarities are already
                # available, so there is no need to recompute them.
                return [torch.argmax(similarities).item()], [1.0]

            # No model is above the threshold: fall back to the base model.
            if self.verbose:
                print("NO MODEL ABOVE THRESH, falling back to base model", flush=True)
                print("FALLBACK_TO_BASE_MODEL", flush=True)
            return None

        if num_selected > self.max_merge_count:
            # More models pass the threshold than may be merged: keep only
            # the `max_merge_count` models with the largest kernel values.
            if self.verbose:
                print("Too many models above threshold, falling back to max_merge_count", flush=True)
                print("FALLBACK_TO_MAX_MERGE_COUNT", flush=True)
            _, selected_indices = torch.topk(kernel_values, self.max_merge_count)
            selected_mask = torch.zeros_like(normalized_weights, dtype=torch.bool)
            selected_mask[selected_indices] = True

        selected_clusters = torch.where(selected_mask)[0].tolist()
        selected_weights = normalized_weights[selected_mask]
        if self.verbose:
            print(f"Selected pre-norm weights: {selected_weights}", flush=True)

        # Renormalize so the weights of the kept models sum to 1.
        renormalized_weights = selected_weights / selected_weights.sum()
        return selected_clusters, renormalized_weights.tolist()

    def _load_merged_adapter(self, x, merged_adapter_name):
        """Build and activate the merged adapter for the input ``x``.

        Args:
            x: Token-id tensor with batch size 1.
            merged_adapter_name: Name under which the merged adapter is
                registered on the PEFT model.

        Returns:
            bool: True when a merged adapter was created and activated (the
            caller is responsible for deleting it), False when the
            "initZero" adapter was activated instead.

        Raises:
            ValueError: If the batch size of ``x`` is not 1.
        """
        if x.shape[0] != 1:
            raise ValueError("Batch size must be 1. Please provide input with a batch size of 1.")

        # Compute the normalized query embedding from the decoded prefix.
        query = self.tokenizer.decode(x[0][:self.prefix_length], skip_special_tokens=True)
        query_embedding = self.encoder.encode(query, convert_to_tensor=True, show_progress_bar=False).unsqueeze(0).to(self.device)
        query_embedding = util.normalize_embeddings(query_embedding)

        # Similarity between the query and every cluster embedding.
        corpus_tensor = self.corpus.to(self.device)
        similarities = torch.mm(query_embedding, corpus_tensor.T).squeeze(0)

        selection = self._select_experts(similarities)
        if selection is None:
            self.peft_model.set_adapter("initZero")
            if self.verbose:
                print("Unique Adapter count: 0", flush=True)
            return False
        selected_clusters, weights = selection

        adapters = []
        try:
            for cluster_num in tqdm(selected_clusters, disable=not self.verbose):
                name = str(cluster_num)
                complete_path = os.path.join(self.adapterLocation, name)
                self.peft_model.load_adapter(complete_path, adapter_name=name, load_as="adapter")
                adapters.append(name)

            if self.verbose:
                print(f"Unique adapters: {adapters}", flush=True)
                print(f"Unique Adapter count: {len(adapters)}", flush=True)
                print(f"Weights: {weights}", flush=True)

            self.peft_model.add_weighted_adapter(adapters, weights, merged_adapter_name,  combination_type="cat")
            self.peft_model.set_adapter(merged_adapter_name)
            self.peft_model.eval()
        finally:
            # The individual experts are only needed to build the merged
            # adapter; drop them even when loading or merging failed so they
            # do not accumulate across calls.
            for adapter in adapters:
                self.peft_model.delete_adapter(adapter)

        return True

    def _release_merged_adapter(self, has_merged, merged_adapter_name):
        """Delete the temporary merged adapter (if any) and free cached GPU memory."""
        if has_merged:
            self.peft_model.delete_adapter(merged_adapter_name)
        torch.cuda.empty_cache()

    def forward(self, x : torch.Tensor):
        """
        Forward pass for the TTMM model. Processes input by loading and merging the necessary expert adapter(s).
        It evaluates the model without gradient calculation and returns the resulting logits. 
        The merged adapter is deleted after use to free up memory, also when the model call fails.
        Args:
            x: Input tensor containing token IDs to be processed by the model (batch size 1)
        Returns:
            TTMMOutput: An object containing the output logits from the model evaluation
        Raises:
            ValueError: If the batch size of ``x`` is not 1.
        """
        merged_adapter_name = "merged"
        has_merged = self._load_merged_adapter(x, merged_adapter_name)

        try:
            with torch.no_grad():
                self.peft_model.eval()
                outputs = self.peft_model(input_ids=x)
                average_logits = outputs.logits
        finally:
            # Always clean up so a failed call cannot leave a stale adapter
            # registered under the same name for the next call.
            self._release_merged_adapter(has_merged, merged_adapter_name)

        return TTMMOutput(logits=average_logits)

    def generate(self, x, max_length, do_sample):
        """
        Generates text using the loaded adapter configuration.
        
        This method loads an appropriate merged adapter based on the input context,
        then uses the configured model to generate text up to the specified max_length.
        After generation, it cleans up any temporarily created adapters to free memory,
        also when generation fails.
        
        Args:
            x: Input tensor containing token IDs that serve as the generation prompt (batch size 1)
            max_length: Maximum length of the generated sequence (including prompt)
            do_sample: Whether to use sampling for generation (True) or greedy decoding (False)
            
        Returns:
            torch.Tensor: Generated token IDs

        Raises:
            ValueError: If the batch size of ``x`` is not 1.
        """
        merged_adapter_name = "merged"
        has_merged = self._load_merged_adapter(x, merged_adapter_name)

        try:
            self.peft_model.eval()
            outputs = self.peft_model.generate(input_ids=x, max_length=max_length, do_sample=do_sample)
        finally:
            # Always clean up so a failed call cannot leave a stale adapter
            # registered under the same name for the next call.
            self._release_merged_adapter(has_merged, merged_adapter_name)

        return outputs