import pandas as pd
import torch
from tqdm import tqdm

import torch.nn as nn
from torchmetrics import Metric
from torchmetrics import MeanMetric
import numpy
from torch.nn import KLDivLoss

from defog.metrics.train_metrics import TrainLossDiscrete


class CrossEntropyTracker(Metric):
    """Track a per-time-step running mean of the training loss.

    The tracker holds ``dim`` independent ``MeanMetric`` instances, one for
    each time value in ``t_array`` (``dim`` evenly spaced points in
    ``[0, 1)``). Every call to :meth:`update` evaluates the loss once per time
    step and folds the result into the matching running mean, weighted by the
    batch size.

    Args:
        dim: Number of time steps to track. Must be at least 1.
        min_samples: Stored on the instance as ``self.min_samples``; it is not
            used by any method of this class.
        lambda_train: Forwarded to ``TrainLossDiscrete`` and also stored as
            ``self.lambda_train``.
        class_weight: Forwarded to ``TrainLossDiscrete``.

    Raises:
        ValueError: If ``dim`` is smaller than 1.
    """

    def __init__(self, dim: int, min_samples: int, lambda_train, class_weight):
        super().__init__()
        if dim < 1:
            raise ValueError(f"dim must be a positive integer, got {dim}")
        self.dim = dim
        self.loss = TrainLossDiscrete(
            lambda_train=lambda_train,
            label_smoothing=0,
            class_weight=class_weight,
            kld=False,
        )
        self.metrics = torch.nn.ModuleList([MeanMetric() for _ in range(dim)])
        # ASSUMPTION: sweep t evenly from 0 to 1
        # ``bin_edges`` holds the dim + 1 histogram edges 0, 1/dim, ..., 1;
        # the tracked time steps are the left edge of every bin. Building the
        # edges directly (instead of extrapolating from t_array[1] -
        # t_array[0]) also works for dim == 1 and makes the last edge exactly 1.
        bin_edges = torch.linspace(0, 1, steps=dim + 1)
        self.t_array = bin_edges[:-1]
        self.min_samples = min_samples
        # Copy so the NumPy edges do not share memory with ``t_array``.
        self.hist_t_array = bin_edges.cpu().numpy().copy()
        self.lambda_train = lambda_train

    @property
    def seen_samples(self):
        """Return the accumulated weight of the first per-step metric.

        Every per-step metric is updated with the same weight in
        :meth:`update`, so the first one is representative of all of them.
        """
        return self.metrics[0].weight

    def update(self, list_masked_pred_holder, list_masked_true_holder):
        """pred should be in logits, true should be in probabilities.

        Args:
            list_masked_pred_holder: Indexable collection with one prediction
                holder per time step; each holder exposes ``X``, ``E`` and
                ``y``.
            list_masked_true_holder: Indexable collection with one target
                holder per time step, with the same attributes.
        """
        # assumes the first dimension of X is the batch dimension - TODO confirm
        num_samples = list_masked_true_holder[0].X.shape[0]
        ce_array = self.get_loss_array(
            list_masked_pred_holder=list_masked_pred_holder,
            list_masked_true_holder=list_masked_true_holder,
        )
        # add weight because train loss is not averaged per sample and batches may have different size
        for metric, ce_value in zip(self.metrics, ce_array.tolist()):
            metric.update(ce_value, weight=num_samples)

    def compute(self):
        """Return a 1-D tensor holding the current mean loss of every time step."""
        return torch.tensor([metric.compute() for metric in self.metrics])

    def reset(self):
        """Reset the tracker and every per-step running mean.

        The base-class reset is called first so that the parent metric's own
        bookkeeping (e.g. the cached result of the last ``compute()`` call)
        is cleared as well; otherwise ``compute()`` could keep returning the
        pre-reset values until the next ``update()``.
        """
        super().reset()
        for metric in self.metrics:
            metric.reset()

    @torch.no_grad()
    def get_loss_array(
        self,
        list_masked_pred_holder,
        list_masked_true_holder,
    ):
        """Evaluate the loss once per time step, without tracking gradients.

        Args:
            list_masked_pred_holder: Per-time-step prediction holders.
            list_masked_true_holder: Per-time-step target holders.

        Returns:
            A tensor of shape ``(dim,)`` whose i-th entry is the loss computed
            from the i-th pair of holders.

        Raises:
            AssertionError: If any entry of the result is NaN or infinite.
        """
        ce_array = torch.zeros(self.dim)
        for i in tqdm(range(self.dim), desc="CE track - Computing loss array"):
            self.loss.reset()  # zero out the loss
            pred_holder = list_masked_pred_holder[i]
            true_holder = list_masked_true_holder[i]
            ce_array[i] = self.loss.forward(
                masked_pred_X=pred_holder.X,
                masked_pred_E=pred_holder.E,
                pred_y=pred_holder.y,
                true_X=true_holder.X,
                true_E=true_holder.E,
                true_y=true_holder.y,
                weight=None,
                log=False,
            )

        assert (
            not torch.isnan(ce_array).any().item()
        ), f"ce array has nan values at indices: {torch.isnan(ce_array).nonzero(as_tuple=True)[0]}"
        assert (
            not torch.isinf(ce_array).any().item()
        ), f"ce array has inf values at indices: {torch.isinf(ce_array).nonzero(as_tuple=True)[0]}"

        return ce_array

    def save_to_csv(self, path):
        """Write the per-time-step mean loss to ``path`` as a two-column CSV.

        The file has the columns ``t`` and ``ce`` and no index column.

        Args:
            path: Destination passed to ``DataFrame.to_csv``.

        Returns:
            The written ``DataFrame``, or ``None`` when every tracked value is
            NaN, in which case nothing is written.
        """
        ce_array = self.compute()
        if torch.isnan(ce_array).all():
            print("CE TRACKER - CE array contains only NaNs, not saving to csv")
            return None

        df = pd.DataFrame(
            {
                "t": self.t_array.cpu().numpy(),
                "ce": ce_array.cpu().numpy(),
            }
        )
        df.to_csv(path, index=False)
        return df

    def get_histogram(self):
        """Return ``(values, bin_edges)`` as NumPy arrays.

        ``values`` has ``dim`` entries (the current per-step means) and
        ``bin_edges`` has ``dim + 1`` entries.
        """
        ce_array = self.compute().cpu().numpy()
        return ce_array, self.hist_t_array
