from collections import defaultdict
from pathlib import Path

import einops
import pandas as pd
import torch

from tabicl.core.enums import DownstreamTask
from tabicl.utils.paths_and_filenames import METRICS_TRAIN_FILE_NAME, METRICS_VAL_FILE_NAME


class MetricsFewSteps():
    """Running cross-entropy loss and accuracy accumulated over several updates.

    Sums are accumulated over every non-ignored target seen since the last
    ``reset()``; ``loss`` and ``accuracy`` expose the per-target averages.
    """

    # Target value that marks a position as "not scored". It is passed to
    # cross_entropy as ``ignore_index`` and excluded from the target count.
    IGNORE_INDEX = -100

    def __init__(self) -> None:
        self.reset()

    def update(self, pred: torch.Tensor, target: torch.Tensor) -> None:
        """Add one batch to the running sums.

        Args:
            pred: Logits, rearranged here from ``(b, n, c)`` to ``(b*n, c)``.
            target: Class indices, rearranged here from ``(b, n)`` to ``(b*n,)``.
                Entries equal to ``IGNORE_INDEX`` contribute neither to the
                loss nor to the target count.
        """

        pred = einops.rearrange(pred, 'b n c -> (b n) c')
        target = einops.rearrange(target, 'b n -> (b n)')

        # Summed (not averaged) so batches with different numbers of scored
        # targets are weighted correctly when dividing by the running total.
        batch_loss = torch.nn.functional.cross_entropy(
            pred, target, reduction='sum', ignore_index=self.IGNORE_INDEX,
        )
        self._loss += batch_loss.item()

        # argmax yields an index into the class dimension, so it can never
        # equal the negative IGNORE_INDEX: ignored positions are never counted
        # as correct.
        predicted_class = pred.argmax(dim=-1)
        self._correct += predicted_class.eq(target).sum().item()

        self._total += int((target != self.IGNORE_INDEX).sum().item())

    def reset(self) -> None:
        """Clear all running sums."""
        self._loss = 0.0
        self._correct = 0
        self._total = 0

    @property
    def loss(self) -> float:
        """Mean loss per scored target; NaN if no target has been scored yet."""
        if self._total == 0:
            return float('nan')
        return self._loss / self._total

    @property
    def accuracy(self) -> float:
        """Fraction of scored targets predicted correctly; NaN if none scored yet."""
        if self._total == 0:
            return float('nan')
        return self._correct / self._total
    


class MetricsFullRun():
    """Per-step training metrics and periodic validation metrics for a whole run.

    ``update`` appends one entry per training step; ``update_val`` appends one
    entry per downstream task per validation step; ``save`` writes both
    histories to CSV files.
    """

    # Target value that marks a position as "not scored". It is passed to
    # cross_entropy as ``ignore_index`` and excluded from the target count.
    IGNORE_INDEX = -100

    def __init__(self) -> None:
        self.reset()

    def update(self, pred: torch.Tensor, target: torch.Tensor, scaler: torch.amp.GradScaler) -> None:
        """Record mean loss, accuracy and gradient scale for one training step.

        Args:
            pred: Logits, rearranged here from ``(b, n, c)`` to ``(b*n, c)``.
            target: Class indices, rearranged here from ``(b, n)`` to ``(b*n,)``.
                Entries equal to ``IGNORE_INDEX`` are not scored.
            scaler: Gradient scaler whose current scale is logged for this step.
        """

        pred = einops.rearrange(pred, 'b n c -> (b n) c')
        target = einops.rearrange(target, 'b n -> (b n)')

        loss_sum = torch.nn.functional.cross_entropy(
            pred, target, reduction='sum', ignore_index=self.IGNORE_INDEX,
        ).item()

        # argmax yields an index into the class dimension, so ignored
        # (negative) targets can never be counted as correct.
        correct = pred.argmax(dim=-1).eq(target).sum().item()

        total = int((target != self.IGNORE_INDEX).sum().item())

        if total > 0:
            self._loss.append(loss_sum / total)
            self._correct.append(correct / total)
        else:
            # Nothing was scored this step. Still append so that list
            # positions keep matching step numbers in save().
            self._loss.append(float('nan'))
            self._correct.append(float('nan'))

        # Use the public accessor: the scaler's private `_scale` tensor may be
        # unset (e.g. scaling disabled, or before its first use), in which
        # case reading it directly fails.
        self._scale.append(scaler.get_scale())

    def update_val(self, norm_acc_val: float, norm_acc_test: float, step: int, task: DownstreamTask) -> None:
        """Record validation and test scores of one downstream task at ``step``."""

        self._norm_acc_val[task].append(norm_acc_val)
        self._norm_acc_test[task].append(norm_acc_test)

        if step not in self._val_step:
            # Multiple downstream tasks can add the same step
            self._val_step.append(step)

    def reset(self) -> None:
        """Discard all recorded training and validation history."""
        self._loss: list[float] = []
        self._correct: list[float] = []
        self._norm_acc_val: dict[DownstreamTask, list[float]] = defaultdict(list)
        self._norm_acc_test: dict[DownstreamTask, list[float]] = defaultdict(list)
        self._val_step: list[int] = []
        self._scale: list[float] = []

    def save(self, output_dir: Path) -> None:
        """Write the training and validation histories as CSV files.

        Args:
            output_dir: Directory receiving ``METRICS_TRAIN_FILE_NAME`` and
                ``METRICS_VAL_FILE_NAME``. It is created if it does not exist.
        """

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        metrics_train = {
            # Training entries are appended once per update(), so the list
            # position is used as the step number.
            'step': list(range(len(self._loss))),
            'loss': self._loss,
            'accuracy': self._correct,
            'scale': self._scale,
        }

        pd.DataFrame(metrics_train).to_csv(output_dir / METRICS_TRAIN_FILE_NAME, index=False)

        # NOTE(review): building the frame needs equally long columns, so each
        # task is expected to have reported once per recorded validation step.
        metrics_val: dict[str, list] = {
            'step': self._val_step,
        }

        for task, val_scores in self._norm_acc_val.items():
            metrics_val[f'norm_acc_val_{task.value}'] = val_scores
            metrics_val[f'norm_acc_test_{task.value}'] = self._norm_acc_test[task]

        pd.DataFrame(metrics_val).to_csv(output_dir / METRICS_VAL_FILE_NAME, index=False)
