import torch

from torch.utils.data import DataLoader
from utils import collate_fn
from tqdm import tqdm
import pandas as pd
import numpy as np


def get_logits(model, train_dataset):
    """Score every chosen/rejected pair in ``train_dataset`` with ``model``.

    The dataset is walked in order, one example per batch, with gradients
    disabled. The model is switched to evaluation mode and is left in that
    mode when the function returns.

    Args:
        model: Callable accepting ``input_ids`` and ``attention_mask`` keyword
            arguments and returning an object with a ``.logits`` attribute.
        train_dataset: Dataset whose collated batches provide the keys
            ``input_ids_chosen``, ``attention_mask_chosen``,
            ``input_ids_rejected`` and ``attention_mask_rejected``.

    Returns:
        A list with one dict per batch, holding the batch position under
        ``"index"`` plus the squeezed logits and the pairwise softmax
        probabilities (as NumPy arrays) for the chosen and rejected responses.
    """
    # One example per batch and no shuffling, so the enumeration position
    # below follows the dataset order.
    loader = DataLoader(train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

    records = []

    # Evaluation mode for the forward passes below.
    model.eval()

    with torch.no_grad():
        progress = tqdm(loader, desc="Computing label quality scores")
        for position, pair in enumerate(progress):
            # Forward both responses of the pair through the model.
            chosen_out = model(input_ids=pair["input_ids_chosen"], attention_mask=pair["attention_mask_chosen"])
            rejected_out = model(
                input_ids=pair["input_ids_rejected"], attention_mask=pair["attention_mask_rejected"]
            )

            # squeeze() drops every size-1 dimension; presumably this leaves a
            # single value per response given batch_size=1 -- TODO confirm the
            # model emits one logit per sequence.
            chosen_logit = chosen_out.logits.squeeze()
            rejected_logit = rejected_out.logits.squeeze()

            # Softmax over the stacked pair: entry 0 is the chosen response,
            # entry 1 the rejected one.
            pair_probs = torch.softmax(torch.stack([chosen_logit, rejected_logit]), dim=0)

            record = {
                "index": position,
                "logits_chosen": chosen_logit.cpu().numpy(),
                "logits_rejected": rejected_logit.cpu().numpy(),
                "probs_chosen": pair_probs[0].cpu().numpy(),
                "probs_rejected": pair_probs[1].cpu().numpy(),
            }
            records.append(record)

    return records


def compute_self_confidence(probs):
    """Return the negated first entry of ``probs``.

    Args:
        probs: Indexable collection of probabilities; only ``probs[0]`` is read.

    Returns:
        ``-probs[0]``, so a larger first probability yields a lower score.
    """
    first_entry = probs[0]
    return -first_entry


def compute_entropy(probs):
    """Return the Shannon entropy (natural log) summed over all entries of ``probs``.

    Args:
        probs: Tensor of probabilities.

    Returns:
        A zero-dimensional tensor equal to ``-sum(p * log(p))``.
    """
    # torch.xlogy defines the 0 * log(0) term as 0, so a saturated
    # distribution such as [1, 0] gives entropy 0 instead of NaN, which is
    # what the naive ``p * torch.log(p)`` product produces.
    return -torch.xlogy(probs, probs).sum()


def _as_prob_tensor(value):
    """Return ``value`` unchanged if it is a tensor, otherwise wrap it in one."""
    if isinstance(value, torch.Tensor):
        return value
    # Covers lists and NumPy arrays as well as plain Python floats, tuples and
    # NumPy scalars, which would otherwise reach torch.stack unconverted.
    return torch.tensor(value)


def compute_quality_scores(logits_list):
    """Compute self-confidence and entropy scores for each entry of ``logits_list``.

    Args:
        logits_list: Iterable of mappings with ``"probs_chosen"`` and
            ``"probs_rejected"`` entries. Each entry may be a tensor or
            anything ``torch.tensor`` accepts, and must hold a single value
            because the resulting scores are read back with ``.item()``.

    Returns:
        A dict of three parallel lists: ``"index"`` (position of the entry
        within ``logits_list``), ``"self_confidence"`` and ``"entropy"``
        (Python floats).
    """
    self_confidence_scores = []
    entropy_scores = []
    indices = []

    for idx, entry in enumerate(logits_list):
        # Row 0 is the chosen response, row 1 the rejected one; the scoring
        # helper reads row 0 as the chosen probability.
        probs = torch.stack(
            [
                _as_prob_tensor(entry["probs_chosen"]),
                _as_prob_tensor(entry["probs_rejected"]),
            ]
        )

        self_confidence_scores.append(compute_self_confidence(probs).item())
        entropy_scores.append(compute_entropy(probs).item())
        indices.append(idx)

    return {
        "index": indices,
        "self_confidence": self_confidence_scores,
        "entropy": entropy_scores,
    }
