import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F

from src.nlp.experiments.base_experiment import BaseExperiment

from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr

# Data related imports
from torch.utils.data import DataLoader
from src.nlp.dataloaders.stsb_loader import STSBDataset

class STSBSimilarity(BaseExperiment):
    """Sentence-pair similarity experiment on the STS-B dataset.

    Both sentences of a pair are embedded with the same model, and the cosine
    similarity of the two embeddings is used as the predicted score. Training
    minimises the mean-squared error between that score and the label;
    evaluation reports Pearson and Spearman correlation plus the mean loss.

    Attributes:
        device: Device the labels are moved to before the loss is computed.
        model: Callable mapping a batch of sentences to embeddings.
        train_data / val_data / test_data: ``STSBDataset`` splits.
        train_dataloader / val_dataloader / test_dataloader: Loaders over the
            splits; only the training loader shuffles.
    """

    def __init__(self, cfg, device, model=None):
        """Load the three STS-B splits and build their dataloaders.

        Args:
            cfg: Experiment configuration. ``cfg.learning.batch_size`` is read
                here; ``evaluate`` later reads ``self.cfg.debug_data``
                (presumably stored by ``BaseExperiment.__init__`` -- confirm).
            device: Target device for labels in ``compute_loss``.
            model: Embedding model; it must be set before ``forward`` is used.
        """
        super().__init__(cfg)

        self.device = device
        self.model = model

        # Load datasets
        self.train_data = STSBDataset("train")
        self.val_data = STSBDataset("validation")
        self.test_data = STSBDataset("test")

        # Create dataloaders
        # NOTE(review): no DistributedSampler is passed, so in a multi-process
        # run every rank iterates the complete split.
        batch_size = cfg.learning.batch_size
        self.train_dataloader = DataLoader(self.train_data, batch_size=batch_size, shuffle=True)
        self.val_dataloader = DataLoader(self.val_data, batch_size=batch_size, shuffle=False)
        self.test_dataloader = DataLoader(self.test_data, batch_size=batch_size, shuffle=False)

    def forward(self, sentence1, sentence2):
        """Return the cosine similarity between the embeddings of two batches.

        Args:
            sentence1: First batch of sentences, passed to ``self.model`` as is.
            sentence2: Second batch of sentences, passed to ``self.model`` as is.

        Returns:
            Tensor of cosine similarities computed over the last embedding
            dimension.
        """
        # NOTE(review): the input format expected by the model (raw strings vs.
        # token ids) is not visible here -- confirm against the model wrapper.
        sentence1_embed = self.model(sentence1)
        sentence2_embed = self.model(sentence2)
        return F.cosine_similarity(sentence1_embed, sentence2_embed, dim=-1)

    def compute_loss(self, sim, label):
        """Return the MSE between predicted similarities and labels.

        Args:
            sim: Predicted similarities, as returned by ``forward``.
            label: Target scores; moved to ``self.device`` and cast to the
                dtype of ``sim`` before the comparison.

        Returns:
            Scalar loss tensor in the dtype of ``sim``.
        """
        # NOTE(review): cosine similarity lies in [-1, 1]; this assumes the
        # dataset already scales labels to a compatible range -- confirm.
        target = label.to(self.device).to(sim.dtype)
        # The trailing cast keeps the loss in the dtype of ``sim`` even if
        # mse_loss itself returns a different precision.
        return F.mse_loss(sim, target).to(sim.dtype)

    def evaluate(self, split: str, **kwargs):
        """Evaluate the model on the validation or test split.

        Works both inside an initialised ``torch.distributed`` process group
        and in a plain single-process run. In the distributed case the
        per-rank predictions, labels and batch losses are gathered on every
        rank and the metrics are computed on rank 0 only.

        Args:
            split: ``"val"`` or ``"test"``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            On rank 0 (or when not running distributed) a dict with keys
            ``"pearson"``, ``"spearman"`` and ``"loss"``, where ``"loss"`` is
            the mean of the per-batch losses. ``None`` on all other ranks.

        Raises:
            ValueError: If ``split`` is neither ``"val"`` nor ``"test"``.
        """
        # Validate the argument before touching the model state.
        if split == "val":
            dataloader = self.val_dataloader
        elif split == "test":
            dataloader = self.test_dataloader
        else:
            raise ValueError("Invalid split")

        # NOTE: the model is left in eval mode; callers that continue training
        # are expected to switch it back to train mode themselves.
        self.model.eval()

        # Fall back to single-process behaviour when no process group exists,
        # instead of failing inside dist.get_rank().
        distributed = dist.is_available() and dist.is_initialized()
        rank = dist.get_rank() if distributed else 0

        all_predictions = []
        all_labels = []
        all_losses = []

        with torch.no_grad():
            # Only rank 0 shows a progress bar.
            for sentence1, sentence2, label in tqdm(dataloader, disable=rank != 0):
                sim = self.forward(sentence1, sentence2)
                loss = self.compute_loss(sim, label)

                # Cast to float32 first: NumPy has no bfloat16 and the model
                # may run in reduced precision.
                all_predictions.extend(sim.cpu().float().numpy())
                all_labels.extend(label.cpu().float().numpy())
                all_losses.append(loss.cpu().float().numpy())

                # Debug mode: stop once more than 10 samples were collected.
                if self.cfg.debug_data and len(all_predictions) > 10:
                    break

        if distributed:
            # Gather predictions, labels and losses from all processes
            world_size = dist.get_world_size()
            gathered_predictions = [None] * world_size
            gathered_labels = [None] * world_size
            gathered_losses = [None] * world_size

            dist.all_gather_object(gathered_predictions, all_predictions)
            dist.all_gather_object(gathered_labels, all_labels)
            dist.all_gather_object(gathered_losses, all_losses)
        else:
            gathered_predictions = [all_predictions]
            gathered_labels = [all_labels]
            gathered_losses = [all_losses]

        # Only rank 0 computes the metrics
        if rank == 0:
            flat_predictions = [item for sublist in gathered_predictions for item in sublist]
            flat_labels = [item for sublist in gathered_labels for item in sublist]
            # Flatten before averaging: ranks may contribute different numbers
            # of batches, and np.array() raises on a ragged nested list.
            flat_losses = [item for sublist in gathered_losses for item in sublist]

            pearson_score = pearsonr(flat_predictions, flat_labels)[0]
            spearman_score = spearmanr(flat_predictions, flat_labels)[0]
            loss = np.array(flat_losses).mean()

            print(f"pearson: {pearson_score:.5f}, spearman: {spearman_score:.5f}")
            metrics = {"pearson": pearson_score, "spearman": spearman_score, "loss": loss}
        else:
            metrics = None

        # Synchronize all processes before returning
        if distributed:
            dist.barrier()

        return metrics

    def finetune_pass(self, batch, **kwargs):
        """Run one training forward pass and return its loss.

        Args:
            batch: A ``(sentence1, sentence2, label)`` triple as yielded by the
                dataloaders.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Scalar MSE loss tensor for the batch.
        """
        sentence1, sentence2, label = batch
        sim = self.forward(sentence1, sentence2)
        return self.compute_loss(sim, label)