import os
from pathlib import Path
import torch
import tqdm
import joblib
import numpy as np
from huggingface_hub import login
from transformers import AutoModel, AutoModelForCausalLM, BitsAndBytesConfig
from torch_pca import PCA
import torch.functional as F

from .base import BrainAlignLanguageModelBase
import time

# Disable autograd tracking by default for the thread that imports this module;
# code that needs gradients (e.g. attribution) must re-enable them explicitly.
torch.set_grad_enabled(mode=False)

class BrainAlignLanguageModel(BrainAlignLanguageModelBase):
    """Wrapper around a Hugging Face model that extracts per-layer representations
    (word embeddings, TR-aggregated embeddings with optional per-layer PCA), exposes an
    attribution helper, and computes next-word cross-entropy.
    """

    def __init__(
        self,
        hugging_face_model_id: str,
        apply_pca: bool,
        pca_components: int,
        device: str,
        model_type: str,
        load_head: bool = False,
    ):
        """Load a pre-trained Hugging Face model and put it in eval mode.

        Args:
            hugging_face_model_id: HF model ID. Substrings of the ID ("Jamba", "Mixtral",
                "Mistral", "llama", "qwen", "falcon-11B", "gemma-7B") select model-specific
                loading options (quantization, attention implementation, device map).
            apply_pca: Whether `extract_model_embeddings` reduces each layer's TR
                embeddings with PCA.
            pca_components: Number of PCA components kept per layer when `apply_pca`.
            device: Torch device inputs are moved to; the model is moved there too unless
                it was loaded with a `device_map`.
            model_type: either 'encoder', 'decoder', or 'encoder-decoder'.
            load_head: If True, load `AutoModelForCausalLM` (required by
                `forward_next_word_prediction`, which reads `.logits`); otherwise load
                the bare `AutoModel`.
        """
        self.hugging_face_model_id = hugging_face_model_id
        self.apply_pca = apply_pca
        self.pca_components = pca_components
        self.model_type = model_type
        self.device = device

        cache_dir = Path("./.hf_cache")
        cache_dir.mkdir(parents=True, exist_ok=True)

        kwargs = {}
        # Large models are loaded in 4-bit NF4 precision (bitsandbytes) so they fit on the
        # GPU when using long context lengths.
        if "Jamba" in self.hugging_face_model_id:
            kwargs["torch_dtype"] = torch.float16
            kwargs["attn_implementation"] = "flash_attention_2"
            kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True)
        elif "Mixtral" in self.hugging_face_model_id or "Mistral" in self.hugging_face_model_id:
            kwargs["device_map"] = "auto"
            kwargs["torch_dtype"] = torch.float16
            kwargs["attn_implementation"] = "flash_attention_2"
            kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
        elif "llama" in self.hugging_face_model_id:
            kwargs["attn_implementation"] = "flash_attention_2"
            kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True)
        elif "qwen" in self.hugging_face_model_id:
            kwargs["device_map"] = "auto"
            kwargs["attn_implementation"] = "flash_attention_2"
            kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)
        elif "falcon-11B" in self.hugging_face_model_id:
            kwargs["device_map"] = "auto"
            kwargs["torch_dtype"] = torch.float16
        elif "gemma-7B" in self.hugging_face_model_id:
            kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
        if load_head:
            self.model = AutoModelForCausalLM.from_pretrained(self.hugging_face_model_id, cache_dir=cache_dir, **kwargs)
        else:
            self.model = AutoModel.from_pretrained(self.hugging_face_model_id, cache_dir=cache_dir, **kwargs)
        # Models loaded with a `device_map` are already placed (possibly across several
        # devices) by accelerate; moving them again would undo that placement.
        if "device_map" not in kwargs:
            self.model.to(device)
        self.model.eval()
        # NOTE(review): HF models typically apply gradient checkpointing only in training
        # mode, so after eval() this is presumably a no-op for forward passes — confirm.
        self.model.gradient_checkpointing_enable()

        # layer index -> fitted PCA (filled by extract_model_embeddings / load_pca_weights)
        self.per_layer_pca = {}

        # Initialize layer_idx for attribution purposes
        self.layer_idx = None
        self.token_idxs_to_avg = None
        # Running sum / count of per-sample next-word CE (see forward_next_word_prediction)
        self.fold_loss, self.fold_count = 0, 0

    def __str__(self):
        return self.model.__str__()

    def get_embedding_layer(self):
        """Return the model's input token-embedding module, located by model family."""
        model_id = self.hugging_face_model_id
        if "mamba" in model_id:
            return self.model.embeddings
        if "longt5" in model_id:
            return self.model.shared
        if "gpt2" in model_id:
            return self.model.wte  # model.model.encoder.embed_tokens for the ones in the encoder
        # Falcon3 checkpoints fall through to `embed_tokens`; match them as a substring
        # because full IDs look like "tiiuae/Falcon3-7B-Base".
        if "falcon" in model_id and "falcon3" not in model_id.lower():
            return self.model.word_embeddings
        return self.model.embed_tokens

    def _check_model_type(self):
        """Raise a clear error for an unsupported `model_type`."""
        if self.model_type not in ('encoder', 'decoder', 'encoder-decoder'):
            raise ValueError(
                f"Unsupported model_type {self.model_type!r}; expected 'encoder', "
                "'decoder', or 'encoder-decoder'."
            )

    @torch.inference_mode
    def get_layer_representations(
        self,
        token_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
    ) -> torch.Tensor:
        """Gets the hidden representations of `token_ids` at every layer of the model.

        The embedding-layer output (index 0 of the HF `hidden_states` tuple) is dropped.
        For 'encoder-decoder' models the encoder layers are followed by the decoder
        layers, with `token_ids` also fed as `decoder_input_ids`.

        Args:
            token_ids: Indices of the input sequence tokens in the vocabulary, assumed to
                have dimension batch x num_tokens.
            attention_mask: Attention mask matching `token_ids`.
        Return:
            The per-layer hidden states stacked along a new leading dimension, i.e.
            ``[num_layers, batch, num_tokens, d_model]`` for HF hidden states.
        Raises:
            ValueError: if `self.model_type` is not supported.
        """
        assert len(token_ids) > 0
        self._check_model_type()

        if self.model_type == 'encoder-decoder':
            outputs = self.model(token_ids, decoder_input_ids=token_ids, attention_mask=attention_mask, output_hidden_states=True)
            encoder_hidden_states = outputs['encoder_hidden_states'][1:]
            decoder_hidden_states = outputs['decoder_hidden_states'][1:]
            hidden_states = encoder_hidden_states + decoder_hidden_states
        else:
            outputs = self.model(token_ids, attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs['hidden_states'][1:]

        return torch.stack(hidden_states)

    @torch.inference_mode
    def extract_word_embeddings(self, data_loader, token_idxs_to_avg, batch_size, verbose=False):
        """Average each context's selected token representations at every layer.

        Args:
            data_loader: Iterable of `(token_ids, attention_mask)` batches, in the same
                order as `token_idxs_to_avg` (contexts are numbered globally across batches).
            token_idxs_to_avg: Indexed by global context index; entry `c` holds the token
                positions averaged for context `c`.
            batch_size: Nominal loader batch size. Context offsets are tracked from the
                actual batch sizes, so this is kept only for API compatibility.
            verbose: Show a tqdm progress bar.
        Returns:
            Tensor of shape ``(num_layers, num_contexts, hidden_dim)``.
        """
        layer_embeddings = []
        context_offset = 0
        for batch_token_ids, batch_attention_mask in tqdm.tqdm(data_loader, disable=not verbose):
            batch_token_ids = batch_token_ids.to(self.device, non_blocking=True)
            batch_attention_mask = batch_attention_mask.to(self.device, non_blocking=True)
            batch_layer_embeddings = self.get_layer_representations(batch_token_ids, batch_attention_mask)

            num_in_batch = batch_layer_embeddings.shape[1]
            for b_idx in range(num_in_batch):
                token_idxs = token_idxs_to_avg[context_offset + b_idx]
                layer_embeddings.append(torch.mean(batch_layer_embeddings[:, b_idx, token_idxs, :], dim=1))
            context_offset += num_in_batch

        return torch.stack(layer_embeddings, dim=1)  # (num_layers, num_contexts, hidden_dim)

    def aggregate_tr_embeddings(self, word_embeddings, tr_to_word_idxs, verbose=False):
        """Average the embeddings of words that fall in the same TR.

        Args:
            word_embeddings: Tensor ``(num_layers, num_words, hidden_dim)``.
            tr_to_word_idxs: Mapping from TR key to the word indices in that TR; output
                TRs follow the mapping's iteration order.
            verbose: Print a progress message.
        Returns:
            float32 tensor ``(num_layers, num_TRs, hidden_dim)`` on `self.device`. A TR
            with no words gets NaN (mean over an empty selection).
        """
        if verbose:
            print("     2b - Aggregate TR's word embeddings.")
        tr_embeddings = torch.zeros((word_embeddings.shape[0], len(tr_to_word_idxs), word_embeddings.shape[2]), device=self.device)
        for tr_idx, word_idxs in enumerate(tr_to_word_idxs.values()):
            tr_embeddings[:, tr_idx, :] = torch.mean(word_embeddings[:, word_idxs, :], dim=1)

        return tr_embeddings

    def extract_model_embeddings(self, data_loader, token_idxs_to_avg, tr_to_word_idxs, batch_size, experiment_dir, verbose=False, story_idx=None):
        """Extract per-layer TR embeddings, optionally PCA-reduce them, and save to disk.

        Pipeline: `extract_word_embeddings` -> `aggregate_tr_embeddings` -> optional
        per-layer z-score + PCA (each fitted PCA is stored in `self.per_layer_pca` and
        dumped to ``experiment_dir/pca_weights/pca_layer{layer}.joblib``).

        Args:
            data_loader, token_idxs_to_avg, batch_size, verbose: see `extract_word_embeddings`.
            tr_to_word_idxs: see `aggregate_tr_embeddings`.
            experiment_dir: Output directory (str or Path).
            story_idx: If given, saves ``aggregated_embeddings_story_{story_idx}.npy``;
                otherwise ``aggregated_embeddings.npy``.
        Returns:
            NumPy array ``(num_layers, num_TRs, dim)`` where ``dim`` is `pca_components`
            when PCA is applied and the hidden size otherwise.
        """
        experiment_dir = Path(experiment_dir)
        pca_weights_dir = experiment_dir / 'pca_weights'
        # Create PCA weights directory
        if self.apply_pca:
            os.makedirs(pca_weights_dir, exist_ok=True)

        # Extract model embeddings
        start_time = time.time()
        word_embeddings = self.extract_word_embeddings(data_loader, token_idxs_to_avg, batch_size, verbose)
        print("Word embeddings shape: ", word_embeddings.shape)
        tr_embeddings = self.aggregate_tr_embeddings(word_embeddings, tr_to_word_idxs, verbose)

        # Apply PCA
        if self.apply_pca:
            num_layers = tr_embeddings.shape[0]
            reduced_trs = []
            for layer in range(num_layers):
                # Non-finite values (e.g. NaN for TRs without words) would poison the z-score and SVD.
                layer_trs = torch.nan_to_num(tr_embeddings[layer])  # (num_TRs, dim)
                layer_trs = (layer_trs - layer_trs.mean(axis=0)) / (layer_trs.std(axis=0) + 1e-10)  # zscore
                pca = PCA(n_components=self.pca_components, svd_solver='full', random_state=42)
                reduced = pca.fit_transform(layer_trs)  # (num_TRs, pca_components)
                reduced_trs.append(reduced.detach().cpu())
                self.per_layer_pca[layer] = pca
                joblib.dump(pca, pca_weights_dir / f'pca_layer{layer}.joblib')
            tr_embeddings = torch.stack(reduced_trs, dim=0)  # (num_layers, num_TRs, pca_components)

        print(f"TR embedding shape: {tr_embeddings.shape}")
        end_time = time.time()
        print(f"Time taken: {end_time - start_time:.2f} seconds")

        # Save TR embeddings
        tr_embeddings = tr_embeddings.detach().cpu().numpy()
        if story_idx is not None:
            np.save(experiment_dir / f'aggregated_embeddings_story_{story_idx}.npy', tr_embeddings)
        else:
            np.save(experiment_dir / 'aggregated_embeddings.npy', tr_embeddings)
        return tr_embeddings

    def load_pca_weights(self, experiment_dir):
        """Load per-layer PCA models saved by `extract_model_embeddings`.

        Every ``pca_layer{k}.joblib`` in ``experiment_dir/pca_weights`` is loaded into
        ``self.per_layer_pca[k]``. Prints a message when the directory is missing.

        NOTE(review): joblib files are unpickled; only load directories you trust.
        """
        pca_weights_dir = Path(experiment_dir) / 'pca_weights'
        if not os.path.isdir(pca_weights_dir):
            print("PCA weights not found.")
            return
        prefix = 'pca_layer'
        for weight_file in sorted(pca_weights_dir.glob(f'{prefix}*.joblib')):
            layer_str = weight_file.stem[len(prefix):]
            # Skip files whose suffix is not a plain layer index.
            if layer_str.isdigit():
                self.per_layer_pca[int(layer_str)] = joblib.load(weight_file)

    def set_attribution_params(self, layer_idx, token_idxs_to_avg):
        """Set the layer and token positions used by `extract_layer_aggregated_representation`."""
        self.layer_idx = layer_idx
        self.token_idxs_to_avg = token_idxs_to_avg

    def extract_layer_aggregated_representation(self, token_ids, target_idx):
        """Return one hidden unit's activation at layer `self.layer_idx`, averaged over the
        token positions `self.token_idxs_to_avg`, for the first sequence in `token_ids`.

        Call `set_attribution_params` first. This method is not wrapped in inference mode,
        so gradients can flow if the caller enables grad mode.

        Returns:
            Tensor of shape (1,).
        Raises:
            ValueError: if `self.model_type` is not supported.
        """
        assert len(token_ids) > 0
        self._check_model_type()

        if self.model_type == 'encoder-decoder':
            outputs = self.model(token_ids, decoder_input_ids=token_ids, output_hidden_states=True)
            encoder_hidden_states = outputs['encoder_hidden_states'][1:]
            decoder_hidden_states = outputs['decoder_hidden_states'][1:]
            hidden_states = encoder_hidden_states + decoder_hidden_states
            layer_embeddings = hidden_states[self.layer_idx][0]  # (num_tokens, hidden_dim)
        else:
            outputs = self.model(token_ids, output_hidden_states=True)
            layer_embeddings = outputs['hidden_states'][1:][self.layer_idx][0]  # (num_tokens, hidden_dim)

        # Average token embeddings for the tokens in the context's last word and get activation of target_idx neuron
        word_embeddings = torch.mean(layer_embeddings[self.token_idxs_to_avg, :], dim=0)[target_idx]
        return torch.unsqueeze(word_embeddings, 0)  # (1,)

    def forward_next_word_prediction(self, token_ids, attention_mask, next_word_lens, pad_token_id):
        """
            Predict *all* tokens of the next word and return the per-sample cross-entropy.

            For sample i (inputs assumed right-padded):
            * end      = attention_mask[i].sum()   (number of non-padding tokens)
            * context  = token_ids[i, : end - n_i ]
            * targets  = token_ids[i, end - n_i : end]
            where n_i = next_word_lens[i]. Logits at position t predict token t + 1, so the
            targets are scored with logits[i, end - n_i - 1 : end - 1].

            Requires a model loaded with `load_head=True` (reads `.logits`).

            Args:
                token_ids: (B, L) input ids.
                attention_mask: (B, L) mask; its row sums give each sample's length.
                next_word_lens: Per-sample number of next-word tokens (tensor or ints).
                pad_token_id: Unused; kept for backward compatibility. Targets are selected
                    by position, so a real target token equal to the pad id is still scored.

            Returns
            -------
            ce_per_sample : torch.Tensor, shape (B,)
                Mean cross-entropy over the n_i target tokens of each sample; 0.0 when
                n_i == 0 or no target token has a preceding context token.
        """
        B = token_ids.shape[0]
        device = token_ids.device

        # 1. Forward
        logits = self.model(input_ids=token_ids, attention_mask=attention_mask).logits
        seq_lens = attention_mask.sum(dim=1)  # number of non-padding tokens

        # 2. Per-sample CE over only the next-word positions. Slicing avoids a softmax over
        #    the full (B, L, vocab) tensor and uses the real logits width; upcasting to
        #    float32 keeps the loss stable for half-precision / quantized models.
        ce_per_sample = torch.zeros(B, device=device)
        for i in range(B):
            n = int(next_word_lens[i])
            end = int(seq_lens[i])
            # Token 0 has no preceding logit, so it can never be a scored target.
            start = max(end - n, 1)
            if n <= 0 or start >= end:
                continue
            pred_logits = logits[i, start - 1:end - 1, :].float()
            targets = token_ids[i, start:end].to(pred_logits.device)
            ce_per_sample[i] = torch.nn.functional.cross_entropy(pred_logits, targets)

        # 3. Bookkeeping for fold-level stats
        self.fold_loss += ce_per_sample.sum().item()
        self.fold_count += B

        return ce_per_sample

