import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F

from src.nlp.experiments.base_experiment import BaseExperiment

from tqdm import tqdm
from sklearn.metrics import f1_score

# Data related imports
from torch.utils.data import DataLoader

from src.nlp.dataloaders.tweet_loader import TweetSentimentExtractionDataset
from src.nlp.dataloaders.banking77_loader import Banking77Dataset

import warnings

class ClassificationTask(BaseExperiment):
    """Text-classification fine-tuning task.

    Loads the dataset selected by ``cfg.datasets.name``, installs a linear
    classification head sized to the dataset's ``num_classes`` on the model,
    builds train/val/test dataloaders, and provides loss computation, a single
    fine-tuning pass and (optionally distributed) evaluation.
    """

    # Maps ``cfg.datasets.name`` to the dataset class used for every split.
    _DATASETS = {
        "tweet_sentiment": TweetSentimentExtractionDataset,
        "banking77": Banking77Dataset,
    }

    def __init__(self, cfg, device, model=None):
        """Set up datasets, the classification head and the dataloaders.

        Args:
            cfg: Experiment config. Reads ``cfg.datasets.name`` and
                ``cfg.learning.batch_size``; also forwarded to ``BaseExperiment``.
            device: Device that labels are moved to in ``compute_loss``.
            model: Backbone exposing ``hidden_dim``, ``get_dtype()`` and
                ``change_linear_head(head)``. Despite the ``None`` default, a
                model is required.

        Raises:
            ValueError: If ``model`` is ``None`` or the dataset name is unknown.
        """
        super().__init__(cfg)

        self.device = device
        self.model = model

        if model is None:
            raise ValueError("ClassificationTask requires a model; got None")

        dataset_name = cfg.datasets.name
        try:
            dataset_cls = self._DATASETS[dataset_name]
        except KeyError:
            raise ValueError(
                f"Unknown dataset {dataset_name!r}; "
                f"expected one of {sorted(self._DATASETS)}"
            ) from None

        self.train_data = dataset_cls("train")
        # NOTE(review): no separate validation split is loaded -- "val" reuses
        # the test split, so model selection on val metrics sees test data.
        self.val_data = dataset_cls("test")
        self.test_data = dataset_cls("test")

        self.loss = F.cross_entropy
        self.num_classes = self.train_data.num_classes

        linear_head = torch.nn.Linear(
            self.model.hidden_dim, self.num_classes, dtype=self.model.get_dtype()
        )
        self.model.change_linear_head(linear_head)

        # NOTE(review): no DistributedSampler is used, so under torch.distributed
        # every rank iterates the full dataset -- confirm this is intended.
        batch_size = cfg.learning.batch_size
        self.train_dataloader = DataLoader(self.train_data, batch_size=batch_size, shuffle=True)
        self.val_dataloader = DataLoader(self.val_data, batch_size=batch_size, shuffle=False)
        self.test_dataloader = DataLoader(self.test_data, batch_size=batch_size, shuffle=False)

    def compute_loss(self, logits, labels, **kwargs):
        """Return the loss of ``logits`` against ``labels``.

        ``labels`` is moved to ``self.device``. A single output column
        (``logits.shape[1] == 1``) is treated as binary classification: labels
        are reshaped to ``(batch, 1)`` and cast to the logits' dtype.
        Otherwise ``self.loss`` is applied to the logits and labels directly.
        """
        labels = labels.to(self.device)
        if logits.shape[1] == 1:  # means binary classification
            # cross_entropy over a single class is identically zero
            # (log_softmax of one logit is 0), so it cannot train a binary
            # head; use the sigmoid-based loss unless a custom one is set.
            if self.loss is F.cross_entropy:
                loss_fn = F.binary_cross_entropy_with_logits
            else:
                loss_fn = self.loss
            return loss_fn(logits, labels.unsqueeze(-1).to(logits.dtype))
        return self.loss(logits, labels)

    @staticmethod
    def _is_distributed():
        """Return True when a torch.distributed process group is initialised."""
        return dist.is_available() and dist.is_initialized()

    @classmethod
    def _rank(cls):
        """Return this process's rank, or 0 when not running distributed."""
        return dist.get_rank() if cls._is_distributed() else 0

    @classmethod
    def _gather(cls, obj):
        """Return a list holding ``obj`` from every process (``[obj]`` if not distributed)."""
        if not cls._is_distributed():
            return [obj]
        gathered = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, obj)
        return gathered

    @staticmethod
    def _compute_metrics(gathered_predictions, gathered_labels, gathered_losses):
        """Compute accuracy, weighted F1 and mean batch loss from per-rank lists."""
        flat_predictions = [p for per_rank in gathered_predictions for p in per_rank]
        flat_labels = [y for per_rank in gathered_labels for y in per_rank]
        # Flatten before averaging so ranks with different batch counts work.
        flat_losses = [l for per_rank in gathered_losses for l in per_rank]

        if not flat_labels:
            # Raising here would leave the other ranks blocked at the barrier.
            warnings.warn("Evaluation split is empty; metrics are NaN")
            nan = float("nan")
            return {"accuracy": nan, "f1": nan, "loss": nan}

        correct = int(np.sum(np.asarray(flat_predictions) == np.asarray(flat_labels)))
        accuracy = correct / len(flat_labels)

        f1 = f1_score(flat_labels, flat_predictions, average='weighted')
        print(f"Accuracy: {accuracy:.5f}, F1: {f1:.5f}")

        # Mean of per-batch losses (not weighted by batch size).
        loss = np.mean(flat_losses)

        return {"accuracy": accuracy, "f1": f1, "loss": loss}

    def evaluate(self, split: str, **kwargs):
        """Evaluate the model on the ``"val"`` or ``"test"`` split.

        Puts the model in eval mode and runs it under ``torch.no_grad()``.
        When a process group is initialised, predictions, labels and batch
        losses are gathered from all ranks, and every rank waits at a barrier
        before returning.

        Returns:
            On rank 0 (or when not distributed): a dict with ``"accuracy"``,
            ``"f1"`` (weighted) and ``"loss"`` (mean per-batch loss). On other
            ranks: ``None``.

        Raises:
            ValueError: If ``split`` is not ``"val"`` or ``"test"``.
        """
        self.model.eval()

        if split == "val":
            dataloader = self.val_dataloader
        elif split == "test":
            dataloader = self.test_dataloader
        else:
            raise ValueError("Invalid split")

        is_main = self._rank() == 0
        all_predictions = []
        all_labels = []
        all_losses = []

        with torch.no_grad():
            for text, label in tqdm(dataloader, disable=not is_main):
                logits = self.model(text)
                loss = self.compute_loss(logits, label)

                if logits.shape[-1] == 1:
                    # Single-logit (binary) head: argmax over one column is
                    # always 0, so threshold the logit at 0 instead.
                    predictions = (logits.squeeze(-1) > 0).long()
                else:
                    predictions = logits.argmax(dim=-1)

                all_predictions.extend(predictions.cpu().float().numpy())
                all_labels.extend(label.cpu().float().numpy())
                all_losses.append(loss.cpu().float().numpy())

        # Gather predictions, labels and losses from all processes.
        gathered_predictions = self._gather(all_predictions)
        gathered_labels = self._gather(all_labels)
        gathered_losses = self._gather(all_losses)

        # Only rank 0 computes metrics.
        metrics = None
        if is_main:
            metrics = self._compute_metrics(
                gathered_predictions, gathered_labels, gathered_losses
            )

        # Ensure all processes wait for rank 0 to finish.
        if self._is_distributed():
            dist.barrier()

        return metrics

    def finetune_pass(self, batch, **kwargs):
        """Run one forward pass on ``batch`` and return its loss.

        ``batch`` is unpacked as ``(text, label)``. The backward step is left
        to the caller; this method only computes the loss.
        """
        text, label = batch
        logits = self.model(text)
        return self.compute_loss(logits, label)