import torch
import torch.distributed as dist

from src.nlp.experiments.base_experiment import BaseExperiment

from tqdm import tqdm
from sklearn.metrics import f1_score

# Data related imports
from torch.utils.data import DataLoader
from src.nlp.dataloaders.hellaswag_loader import HellaswagDataset

import warnings

class HellaswagMatching(BaseExperiment):
    """HellaSwag ending selection scored by embedding cosine similarity.

    The model embeds the context and every candidate ending; the ending whose
    embedding has the highest cosine similarity to the context embedding is
    the prediction.
    """

    def __init__(self, cfg, device, model=None):
        """Load the HellaSwag splits and build one dataloader per split.

        Args:
            cfg: Experiment configuration; ``cfg.learning.batch_size`` sets
                the batch size of every dataloader.
            device: Device the labels are moved to when computing the loss.
            model: Callable mapping a batch of inputs to embeddings.
        """
        super().__init__(cfg)

        self.device = device
        self.model = model

        # Load datasets
        self.train_data = HellaswagDataset("train")
        self.val_data = HellaswagDataset("validation")
        # No public test set available, so "test" reuses the validation split.
        self.test_data = HellaswagDataset("validation")

        # Create dataloaders; only the training loader is shuffled.
        batch_size = cfg.learning.batch_size
        self.train_dataloader = DataLoader(self.train_data, batch_size=batch_size, shuffle=True)
        self.val_dataloader = DataLoader(self.val_data, batch_size=batch_size, shuffle=False)
        self.test_dataloader = DataLoader(self.test_data, batch_size=batch_size, shuffle=False)

    @staticmethod
    def _distributed_active():
        """Return True when a default torch.distributed process group exists."""
        return dist.is_available() and dist.is_initialized()

    def _score_endings(self, ctx, endings):
        """Return the cosine similarity between the context and each ending.

        The result is indexed as [batch_size, num_endings].
        """
        ctx_embed = self.model(ctx)  # [batch_size, hidden_dim]
        endings_embed = [self.model(ending) for ending in endings]  # list of [batch_size, hidden_dim]
        endings_embed = torch.stack(endings_embed).permute(1, 0, 2)  # [batch_size, num_endings, hidden_dim]

        return torch.nn.functional.cosine_similarity(ctx_embed.unsqueeze(1), endings_embed, dim=-1)

    def evaluate(self, split: str, **kwargs):
        """Evaluate the model on the ``"val"`` or ``"test"`` split.

        Works both inside an initialized ``torch.distributed`` process group
        and in a plain single-process run (treated as rank 0, world size 1).

        Args:
            split: Either ``"val"`` or ``"test"``.

        Returns:
            On rank 0, a dict with ``"accuracy"`` and weighted ``"f1"``;
            ``None`` on every other rank.

        Raises:
            ValueError: If ``split`` is unknown, or if no samples were
                evaluated (raised on every rank so none is left waiting).
        """
        self.model.eval()

        if split == "val":
            dataloader = self.val_dataloader
        elif split == "test":
            dataloader = self.test_dataloader
        else:
            raise ValueError("Invalid split")

        distributed = self._distributed_active()
        rank = dist.get_rank() if distributed else 0

        local_predictions = []
        local_labels = []

        with torch.no_grad():
            for batch in tqdm(dataloader, disable=rank != 0):
                ctx, endings, label = batch

                sim = self._score_endings(ctx, endings)
                prediction = sim.argmax(dim=1)

                # Plain Python floats keep the gathered payload small and
                # compare identically to the tensor values.
                local_predictions.extend(prediction.cpu().float().tolist())
                local_labels.extend(label.cpu().float().tolist())

        # Gather predictions and labels from all processes.
        # NOTE(review): the dataloaders are built without a distributed
        # sampler, so each rank appears to iterate the full split and the
        # gathered lists then hold one copy per rank; the metrics are
        # unaffected by that duplication.
        if distributed:
            world_size = dist.get_world_size()
            gathered_predictions = [None] * world_size
            gathered_labels = [None] * world_size

            dist.all_gather_object(gathered_predictions, local_predictions)
            dist.all_gather_object(gathered_labels, local_labels)
        else:
            gathered_predictions = [local_predictions]
            gathered_labels = [local_labels]

        # Every rank sees the same gathered data, so they all fail together
        # instead of rank 0 crashing while the others wait at the barrier.
        num_samples = sum(len(chunk) for chunk in gathered_labels)
        if num_samples == 0:
            raise ValueError(f"No samples to evaluate for split '{split}'")

        if rank == 0:
            flat_predictions = [p for chunk in gathered_predictions for p in chunk]
            flat_labels = [lab for chunk in gathered_labels for lab in chunk]

            correct = sum(pred == lab for pred, lab in zip(flat_predictions, flat_labels))
            accuracy = correct / num_samples

            f1 = f1_score(flat_labels, flat_predictions, average='weighted')
            print(f"Accuracy: {accuracy:.5f}, F1: {f1:.5f}")

            metrics = {"accuracy": accuracy, "f1": f1}
        else:
            metrics = None

        # Sync all processes
        if distributed:
            dist.barrier()

        return metrics

    def finetune_pass(self, batch, **kwargs):
        """Compute the finetuning loss for one batch.

        Only the forward pass and the loss are computed here; no backward
        pass or optimizer step is performed by this method.

        Args:
            batch: A ``(ctx, endings, label)`` tuple as yielded by the
                dataloaders.

        Returns:
            Cross-entropy loss over the context/ending cosine similarities,
            which are used directly as logits.
        """
        ctx, endings, label = batch

        # Cosine similarity between ctx and each ending
        sim = self._score_endings(ctx, endings)

        # Use cross entropy loss
        return torch.nn.functional.cross_entropy(sim, label.to(self.device))