from collections import defaultdict
from pathlib import Path

import einops
import pandas as pd
import torch

from tabicl.core.enums import DownstreamTask, Task
from tabicl.utils.paths_and_filenames import METRICS_TRAIN_FILE_NAME, METRICS_VAL_FILE_NAME


class MetricsFewSteps():

    def __init__(self, task: Task):
        self.task = task
        self.reset()

    def update(self, pred: torch.Tensor, target: torch.Tensor):
        "Predictions are assumed to be logits, if the task is classification"

        match self.task:
            case Task.REGRESSION:
                ss_res = (pred - target).pow(2)
                ss_abs = (pred - target).abs()
                ss_res = einops.reduce(ss_res, 'b n -> b', 'sum')
                target_mean = einops.reduce(target, 'b n -> b', 'mean')
                ss_tot = (target_mean[:, None] - target).pow(2)
                ss_tot = einops.reduce(ss_tot, 'b n -> b', 'sum')
                r2 = 1 - ss_res / ss_tot
                r2 = einops.reduce(r2, 'b -> ()', 'sum') * pred.shape[1]
                loss = einops.reduce(ss_res, 'b -> ()', 'sum').item()
                loss2 = einops.reduce(ss_abs, 'b n -> ()', 'sum').item()
                score = r2.item()

            case Task.CLASSIFICATION:
                pred = einops.rearrange(pred, 'b n c -> (b n) c')
                target = einops.rearrange(target, 'b n -> (b n)')
                loss = torch.nn.functional.cross_entropy(pred, target, reduction='sum').item()
                loss2 = 0

                pred_ = pred.argmax(dim=-1)
                score = pred_.eq(target).sum().item()

        self._score += score
        self._loss += loss
        self._loss2 += loss2
        self._total += target[target != -100].shape[0]


    def reset(self):
        self._loss = 0
        self._loss2 = 0
        self._score = 0
        self._total = 0

    
    @property
    def loss(self):
        return self._loss / self._total
    
    @property
    def loss2(self):
        return self._loss2 / self._total
    
    @property
    def score(self):
        return self._score / self._total
    


class MetricsFullRun():

    def __init__(self, task: Task):
        self.task = task
        self.reset()

    def update(self, pred: torch.Tensor, target: torch.Tensor, scaler: torch.amp.GradScaler, grad_norm: float):
        "Predictions are assumed to be logits"

        
        match self.task:
            case Task.REGRESSION:
                ss_res = (pred - target).pow(2)
                ss_abs = (pred - target).abs()
                ss_res = einops.reduce(ss_res, 'b n -> b', 'sum')
                target_mean = einops.reduce(target, 'b n -> b', 'mean')
                ss_tot = (target_mean[:, None] - target).pow(2)
                ss_tot = einops.reduce(ss_tot, 'b n -> b', 'sum')
                r2 = 1 - ss_res / ss_tot
                loss = einops.reduce(ss_res, 'b -> ()', 'sum')
                loss2 = einops.reduce(ss_abs, 'b n -> ()', 'sum')
                r2 = einops.reduce(r2, 'b -> ()', 'sum') * pred.shape[1]
                score = r2

            case Task.CLASSIFICATION:
                pred = einops.rearrange(pred, 'b n c -> (b n) c')
                target = einops.rearrange(target, 'b n -> (b n)')
                loss = torch.nn.functional.cross_entropy(pred, target, reduction='sum')

                pred_class = pred.argmax(dim=-1)
                acc = (pred_class == target).sum()
                score = acc

        total = target[target != -100].shape[0]

        self._score.append((score / total).item())
        self._loss.append((loss / total).item())
        self._loss2.append((loss2 / total).item())
        self._scale.append(scaler._scale.cpu().item())     # type: ignore 
        self._grad_norm.append(grad_norm)


    def update_val(self, norm_metric_val: float, norm_metric_test: float, step: int, task: DownstreamTask):

        self._norm_metric_val[task].append(norm_metric_val)
        self._norm_metric_test[task].append(norm_metric_test)

        if step not in self._val_step:
            # Multiple downstream tasks can add the same step
            self._val_step.append(step)



    def reset(self):
        self._loss = []
        self._loss2 = []
        self._score = []
        self._norm_metric_val: dict[DownstreamTask, list[float]] = defaultdict(list)
        self._norm_metric_test: dict[DownstreamTask, list[float]] = defaultdict(list)
        self._val_step = []
        self._scale = []
        self._grad_norm = []


    def save(self, output_dir: Path):

        steps = list(range(len(self._loss)))

        match self.task:
            case Task.REGRESSION:
                metrics_train = {
                    'step': steps,
                    'mse': self._loss,
                    'mae': self._loss2,
                    'r2': self._score,
                    'scale': self._scale,
                    'grad_norm': self._grad_norm,
                }
            case Task.CLASSIFICATION:
                metrics_train = {
                    'step': steps,
                    'cross_entropy': self._loss,
                    'accuracy': self._score,
                    'scale': self._scale,
                }

        pd.DataFrame(metrics_train).to_csv(output_dir / METRICS_TRAIN_FILE_NAME, index=False)

        metrics_val = {
            'step': self._val_step,
        }

        for task in self._norm_metric_val.keys():
            metrics_val[f'norm_acc_val_{task.value}'] = self._norm_metric_val[task]
            metrics_val[f'norm_acc_test_{task.value}'] = self._norm_metric_test[task]

        pd.DataFrame(metrics_val).to_csv(output_dir / METRICS_VAL_FILE_NAME, index=False)
