import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from src.nanoGPT.utils import pad_sequence
from src.nlp.experiments.base_experiment import BaseExperiment

from tqdm import tqdm
import warnings

from src.nlp.models.DummyDistributed import DummyDistributed

# Import dataloader
from src.nlp.dataloaders.tinystories_loader import TinyStoriesDataset
from src.nlp.finetune_heads.PoolingModule import CausalAggregator

class NextTokenPrediction(BaseExperiment):
    def __init__(self, cfg, device, model=None):
        """
        Set up the next-token-prediction experiment.

        Stores the model and device, optionally builds a causal aggregator for
        the configured pooling mode, and creates the TinyStories datasets and
        dataloaders for the train / validation / test splits.

        Args:
            cfg: Experiment configuration. This method reads ``pooling`` (via
                ``self.cfg``) and ``learning.batch_size``; ``cfg`` is also
                forwarded to the base class and to the datasets.
            device: Device that inputs are moved to in the forward passes.
            model: Model used by the experiment. Although it defaults to
                ``None``, it is dereferenced below (``model.tokenizer``), so it
                has to be provided.
        """
        super().__init__(cfg)

        # For this experiment the model is supplied at construction time.
        self.model = model

        # Only these pooling modes get a causal aggregator; for any other mode
        # the attribute is not created here.
        pooling_mode = self.cfg.pooling
        if pooling_mode in ["sum", "avg", "attention_pool", "weighted_avg", "max"]:
            # The learned pooling sub-modules are looked up only for the mode
            # that needs them; the other keyword is passed as None.
            attn_pool = None
            if pooling_mode == "attention_pool":
                attn_pool = self.model.pooling.attention_pooling
            weighted_avg_pool = None
            if pooling_mode == "weighted_avg":
                weighted_avg_pool = self.model.pooling.weighted_average_pooling
            self.causal_aggregator = CausalAggregator(
                pooling_mode,
                attn_pool=attn_pool,
                weighted_avg_pool=weighted_avg_pool,
            )

        self.distributed = False
        self.device = device

        # Datasets, all tokenised with the model's tokenizer.
        tokenizer = self.model.tokenizer
        self.train_data = TinyStoriesDataset("train", tokenizer=tokenizer, cfg=cfg)
        self.val_data = TinyStoriesDataset("validation", tokenizer=tokenizer, cfg=cfg)
        self.test_data = TinyStoriesDataset("test", tokenizer=tokenizer, cfg=cfg)

        # Dataloaders; only the training split is shuffled.
        batch_size = cfg.learning.batch_size
        self.train_dataloader = DataLoader(self.train_data, batch_size=batch_size, shuffle=True)
        self.val_dataloader = DataLoader(self.val_data, batch_size=batch_size, shuffle=False)
        self.test_dataloader = DataLoader(self.test_data, batch_size=batch_size, shuffle=False)

    def _model_forward(self, x, target_idx):
        """
        Forward pass used for evaluation on raw text.

        Each sample is tokenised, the token at ``target_idx[i]`` becomes the
        prediction target, and only the tokens before that position are fed to
        the backbone (re-padded to ``cfg.datasets.max_length - 1``). The pooled
        backbone output is projected onto the lm_head weight to obtain logits.

        Args:
            x: Batch of raw inputs accepted by ``self.model.module._tokenize``.
            target_idx: Iterable of scalar tensors, one per sample, giving the
                position of the target token in the tokenised sequence.

        Returns:
            Tuple ``(logits, target_tokens)``. ``target_tokens`` is a 1-D tensor
            on ``self.device``; ``logits`` is the pooled representation times
            ``lm_head.weight.T`` (presumably ``[B, vocab]`` - depends on what
            the pooling module returns, confirm against it).
        """
        module = self.model.module
        is_gpt_backbone = self.cfg.backbone in ["NanoGPT", "GPT2", "GPT2-Custom"]

        input_ids = module._tokenize(x)["input_ids"]
        positions = [idx.item() for idx in target_idx]

        # Target token for every sample, taken at its target position.
        target_tokens = torch.stack(
            [input_ids[i, pos] for i, pos in enumerate(positions)]
        ).to(self.device)

        # Keep only the tokens that precede the target.
        prefixes = [input_ids[i, :pos] for i, pos in enumerate(positions)]

        # The GPT-style backbones pad with the module's eos token; every other
        # backbone uses the tokenizer's pad token id.
        pad_value = module.eos_token if is_gpt_backbone else module.tokenizer.pad_token_id
        padded = [
            pad_sequence(
                prefix.tolist(),
                pad_token=pad_value,
                target_length=self.cfg.datasets.max_length - 1,
                side=self.cfg.padding_side,
            )
            for prefix in prefixes
        ]

        # pad_sequence uses 'attn_mask' as the key for the attention mask.
        model_inputs = {
            "input_ids": torch.stack([p['input_ids'] for p in padded]).to(self.device),
            "attention_mask": torch.stack([p['attn_mask'] for p in padded]).to(self.device),
        }

        # Get the hidden states
        hidden = module.backbone(**model_inputs)
        if is_gpt_backbone:
            hidden = hidden.to(torch.bfloat16)
        else:
            hidden = hidden.hidden_states[-1]
        pooled = module.pooling(hidden, attention_mask=model_inputs["attention_mask"])

        # Align the dtype with the lm_head (as finetune_pass does); otherwise a
        # bfloat16 representation against a float32 head makes matmul raise.
        lm_weight = module.lm_head.weight
        pooled = pooled.to(lm_weight.dtype)

        # Get logits
        logits = torch.matmul(pooled, lm_weight.T)

        return logits, target_tokens

    def _model_forward_oom(self, batch):
        # Note: embeddings, attention_mask, and target_ids are already shifted appropriately
        # during the precomputation so we don't have to do it here.
        # The target_ids have also been filled with -100 targets for any masked positions.
        embeddings, attention_mask, target_ids, target_indices = batch

        embeddings = embeddings.to(self.model.device)
        attention_mask = attention_mask.to(self.model.device)
        target_ids = target_ids.to(self.model.device)
        target_indices = target_indices.to(self.model.device)

        embeddings = self.causal_aggregator(embeddings, padding_mask=attention_mask)

        logits = torch.matmul(embeddings, self.model.module.lm_head.weight.T)
        logits = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)

        # TODO: Why are we not computing the metrics for all positions?
        # Select the logits corresponding to the target indices used in evaluation.
        b, _, _ = embeddings.shape
        logits = logits[torch.arange(b), target_indices, :]
        target_ids = target_ids[torch.arange(b), target_indices]

        # Flatten.
        logits = logits.view(-1, logits.size(-1))
        target_ids = target_ids.view(-1)

        return logits, target_ids

    def enable_gradient(self):
        """
        Put require gradient on the input embedding matrix
        """
        self.model.module.linear_head.weight.requires_grad = False # Disable gradient for linear head, as its not used
        self.model.module.linear_head.bias.requires_grad = False
        self.model.module.lm_head.weight.requires_grad = True
        # Count how many parameters require gradient
        n_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Number of parameters requiring gradient: {n_params}")


    def evaluate(self, split: str, finetuned=False, oom_case: bool = False, **kwargs):
        """
        Compute top-10 accuracy and MRR@10 for next-token prediction.

        Every sample contributes ``1 / rank`` to the MRR when its target token
        is among the top-10 logits and ``0`` otherwise. Per-process sums are
        gathered across processes when a process group is initialised; without
        one the method behaves like a single process with rank 0.

        Args:
            split: ``"val"`` or ``"test"``.
            finetuned: Unused here; kept for interface compatibility.
            oom_case: If True, batches hold precomputed embeddings and are
                handled by ``_model_forward_oom``; otherwise raw
                ``(words, target_idx)`` batches go through ``_model_forward``.
            **kwargs: Ignored.

        Returns:
            ``{"topk_accuracy": float, "mrr": float}`` on rank 0, ``None`` on
            every other rank.

        Raises:
            ValueError: If ``split`` is neither ``"val"`` nor ``"test"``.
        """
        self.model.eval()

        if split == "val":
            dataloader = self.val_dataloader
        elif split == "test":
            dataloader = self.test_dataloader
        else:
            raise ValueError("Invalid split")

        # Fall back to single-process behaviour when no process group exists.
        use_dist = dist.is_available() and dist.is_initialized()
        rank = dist.get_rank() if use_dist else 0
        world_size = dist.get_world_size() if use_dist else 1

        k = 10

        # Per-process accumulators
        total_samples = 0
        correct_topk = 0
        mrr_sum = 0.0

        with torch.no_grad():
            for step, batch in enumerate(tqdm(dataloader, disable=rank != 0), start=1):
                if oom_case:
                    logits, target_tokens = self._model_forward_oom(batch)
                else:
                    words, target_idx = batch
                    logits, target_tokens = self._model_forward(words, target_idx)

                # topk raises if asked for more entries than the last dim has.
                top_k = min(k, logits.size(-1))
                topk = torch.topk(logits, top_k, dim=-1).indices  # [B, top_k]
                hits = (topk == target_tokens.unsqueeze(-1))      # [B, top_k]

                # Top-k indices are distinct, so each row has at most one hit.
                # Summing hit / rank therefore yields the sum of reciprocal
                # ranks, with misses contributing 0. (Averaging over the hits
                # only and scaling by the batch size would inflate the MRR.)
                ranks = torch.arange(1, top_k + 1, device=hits.device, dtype=torch.float32)
                reciprocal_rank_sum = (hits.float() / ranks).sum().item()

                total_samples += target_tokens.shape[0]
                correct_topk += hits.any(dim=-1).sum().item()
                mrr_sum += reciprocal_rank_sum

                # DEBUG MODE: only look at the first 10 batches.
                if self.cfg.debug_data and step == 10:
                    break

        # Gather results across processes (a single collective for all sums).
        local_stats = (correct_topk, mrr_sum, total_samples)
        if use_dist:
            gathered = [None for _ in range(world_size)]
            dist.all_gather_object(gathered, local_stats)
        else:
            gathered = [local_stats]

        if rank == 0:
            total_correct_topk = sum(stats[0] for stats in gathered)
            total_mrr_sum = sum(stats[1] for stats in gathered)
            total_count = sum(stats[2] for stats in gathered)

            if total_count == 0:
                # Avoid a ZeroDivisionError on rank 0, which would also leave
                # the other ranks waiting at the barrier below.
                warnings.warn(f"No samples were evaluated for split '{split}'.")
                topk_accuracy = 0.0
                mrr = 0.0
            else:
                topk_accuracy = total_correct_topk / total_count
                mrr = total_mrr_sum / total_count

            print(f"Top-{k} Accuracy: {topk_accuracy:.4f}")
            print(f"MRR: {mrr:.4f}")

            metrics = {"topk_accuracy": topk_accuracy, "mrr": mrr}
        else:
            metrics = None

        if use_dist:
            dist.barrier()
        return metrics
    
    def finetune_pass(self, batch, oom_case: bool = False):
        if oom_case:
            return self.finetune_pass_oom(batch)

        words, _ = batch
        tokenized = self.model.module._tokenize(words)
        targets = tokenized["input_ids"][:, 1:] # Shift targets by one position
        targets = targets.to(self.device)
        target_mask = tokenized['attention_mask'][:, 1:] # Shift attention mask by one position
        tokenized['input_ids'] = tokenized['input_ids'][:, :-1]  # Remove last token from input
        tokenized['attention_mask'] = tokenized['attention_mask'][:, :-1]


        # Move to device
        tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
        target_mask = target_mask.to(self.device)

        encoded_output = self.model.module.backbone(**tokenized)
        if self.cfg.backbone not in ["NanoGPT", "GPT2", "GPT2-Custom"]:
            encoded_output = encoded_output.hidden_states[-1]
        else:
            encoded_output = encoded_output.to(torch.bfloat16) # the Wrapper returns float32
        if self.cfg.pooling in ["sum", "avg", "attention_pool", "weighted_avg", "max"]:

            encoded_output = self.causal_aggregator(
                encoded_output, 
                padding_mask=tokenized['attention_mask']
            )

        if self.cfg.svd.add_to_pooling:
            singular_pool = self.model.module.singular_pooler(encoded_output, attention_mask=tokenized['attention_mask'])
            assert singular_pool.shape == encoded_output.shape, "Singular pooling output must have same shape as pooled output."
            encoded_output = encoded_output + singular_pool

        # Make sure encoded_output is same dtype as lm_head
        encoded_output = encoded_output.to(self.model.module.lm_head.weight.dtype)

        # Pooled representations done, get logits
        logits = torch.matmul(encoded_output, self.model.module.lm_head.weight.T)
        # Push nans to zero
        logits = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)

        # Compute loss (cross-entropy)
        # Set target to -100 where attention mask is 0
        targets = targets.masked_fill(target_mask == 0, -100)
        loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1))
        return loss

    def finetune_pass_oom(self, batch):
        # Note: embeddings, attention_mask, and target_ids are already shifted appropriately
        # during the precomputation so we don't have to do it here.
        # The target_ids have also been filled with -100 targets for any masked positions.
        embeddings, attention_mask, target_ids, _ = batch
        
        embeddings = embeddings.to(self.model.device)
        attention_mask = attention_mask.to(self.model.device)
        target_ids = target_ids.to(self.model.device)

        embeddings = self.causal_aggregator(embeddings, padding_mask=attention_mask)

        logits = torch.matmul(embeddings, self.model.module.lm_head.weight.T)
        logits = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)
        logits = logits.view(-1, logits.size(-1))
        
        target_ids = target_ids.reshape(-1)
        
        loss = torch.nn.functional.cross_entropy(logits, target_ids)
        
        return loss