from abc import ABC, abstractmethod
import torch.nn as nn
from typing import Dict
import torch
from loss import loss_registry
from collections import defaultdict
from copy import deepcopy
import numpy as np
from models.s4.s4 import S4Block
from models.s4seq_model import S4DualSeqModel


def eval_to_wandb(eval_dict: Dict[str, Dict[str, float]], is_train: bool, prefix: str = ""):
    '''Flattens an evaluation dictionary into a single-level dict of log entries.

    Categories whose name starts with "Average_" are left out. Every other
    entry is stored under "<category>/<prefix>_<item>". Entries of the
    "nRMSE" and "Step_nRMSE" categories are additionally duplicated under the
    older "Loss/..." and "StepLoss/..." names.

    :param eval_dict: dictionary {category: {item: value}} of evaluation values
    :param is_train: bool, accepted for signature compatibility; not used here
    :param prefix: str placed in front of every item name
    :return: flat dictionary {log key: value}
    '''
    # Categories that are also written under their legacy group name.
    legacy_groups = {"nRMSE": "Loss", "Step_nRMSE": "StepLoss"}

    payload = {}
    for group, entries in eval_dict.items():
        if group.startswith("Average_"):
            continue
        legacy = legacy_groups.get(group)
        for name, val in entries.items():
            # Only floats are bounded; any other value type is passed through untouched.
            if isinstance(val, float):
                val = np.clip(val, -1e5, 1e5)
            payload[f"{group}/{prefix}_{name}"] = val
            if legacy is not None:
                payload[f"{legacy}/{prefix}_loss_{name}"] = val
    return payload

def eval_to_wandb_summary(eval_dict: Dict[str, Dict[str, float]], is_train: bool, is_final: bool = False):
    '''Flattens only the "Average_*" categories of an evaluation dictionary.

    :param eval_dict: dictionary {category: {item: value}} of evaluation values
    :param is_train: bool, selects the "train" (True) or "test" (False) key prefix
    :param is_final: bool, accepted for signature compatibility; not used here
    :return: flat dictionary {"<category>/<train|test>_<item>": value}
    '''
    split = "train" if is_train else "test"

    def _bounded(val):
        # Floats are limited to [-1e5, 1e5]; other value types are returned as they are.
        return np.clip(val, -1e5, 1e5) if isinstance(val, float) else val

    return {
        f"{group}/{split}_{name}": _bounded(val)
        for group, entries in eval_dict.items()
        if group.startswith("Average_")
        for name, val in entries.items()
    }

def eval_to_print(eval_dict: Dict[str, Dict[str, float]], is_train: bool):
    '''Builds a one-line console summary of selected evaluation values.

    Only the items "step", "step_inL", "step_LG" and "step_id" of the "nRMSE"
    category are included, each formatted with five decimals.

    :param eval_dict: dictionary {category: {item: value}} of evaluation values
    :param is_train: bool, selects the "train" (True) or "test" (False) label prefix
    :return: str made of " | <prefix>_<item> <value>" fragments (empty if nothing matches)
    '''
    split = "train" if is_train else "test"
    shown_items = ("step", "step_inL", "step_LG", "step_id")
    shown_categories = ("nRMSE",)

    fragments = [
        f" | {split}_{name} {val:.5f}"
        for group, entries in eval_dict.items()
        for name, val in entries.items()
        if name in shown_items and group in shown_categories
    ]
    return "".join(fragments)

class Evaluator(ABC):
    '''Base class for objects that accumulate metric values over several batches.

    Subclasses create ``self.eval_dict`` ({category: {item: running sum}}) in
    ``_init`` and add to it in ``_evaluate``. ``self.n`` is increased by
    ``y_pred.shape[0]`` on every ``evaluate`` call and is the divisor used by
    ``get_evaluation``.
    '''
    def __init__(self, metrics = None, global_metrics = None, **kwargs):
        '''
        :param metrics: iterable of keys into ``loss_registry``; each entry is instantiated without arguments
        :param global_metrics: accepted for signature compatibility; not used here
        :param kwargs: forwarded unchanged to ``_init``
        '''
        super().__init__()
        if metrics is None:
            metrics = []
        self.metrics = [loss_registry[metric]() for metric in metrics]
        self.n = 0
        # self.metrics must exist before _init runs: subclasses read it there.
        self._init(**kwargs)

    @abstractmethod
    def _init(self, **kwargs):
        '''Initializes the evaluation dictionary (must set ``self.eval_dict``)'''
        pass

    def reset(self):
        '''Sets every accumulated value to 0.0 (keys are kept) and clears the sample counter'''
        for items in self.eval_dict.values():
            # Only values are overwritten, so iterating the keys while assigning is safe.
            for item in items:
                items[item] = 0.0
        self.n = 0

    def evaluate(self, y_pred, y_true) -> None:
        '''Accumulates every configured metric for one batch.
        :param y_pred: predictions; ``y_pred.shape[0]`` is added to the sample counter
        :param y_true: targets, passed together with y_pred to ``_evaluate``
        '''
        for metric in self.metrics:
            self._evaluate(y_pred, y_true, metric)
        self.n += y_pred.shape[0]

    @abstractmethod
    def _evaluate(self, y_pred, y_true, loss_fn):
        '''Computes one metric for one batch and adds it to ``self.eval_dict``
        :param y_pred: (B, L, D)
        :param y_true: (B, L, D)
        :param loss_fn: metric object taken from ``self.metrics``
        The accumulated dictionary has the form Dict[key1: Dict[key2, val]], where key1 is
        evaluation category, key2 is evaluation item, and val is evaluation value,
        eg. {"Loss": {"loss_full": 0.1}}
        '''
        raise NotImplementedError

    def get_evaluation(self, reset = True) -> Dict[str, Dict[str, float]]:
        '''Returns the accumulated values divided by the sample counter
        :param reset: if True (default) the accumulators are cleared after the averages
            have been computed, so the next call starts from scratch; pass False to read
            the averages without flushing
        :return: dictionary of averaged evaluation values, or {} if nothing was evaluated
        '''
        if self.n == 0:
            return {}
        avg_dict = {
            category: {item: value / self.n for item, value in items.items()}
            for category, items in self.eval_dict.items()
        }
        # Honour the ``reset`` flag: averages are computed first, then state is flushed.
        if reset:
            self.reset()
        return avg_dict
    
class DummyEvaluator(Evaluator):
    '''Evaluator that records nothing; its evaluation dictionary stays empty.'''
    def _init(self, **kwargs):
        '''Initializes the (empty) evaluation dictionary'''
        self.eval_dict = {}

    def _evaluate(self, y_pred, y_true, loss_fn = None) -> None:
        '''Does nothing.
        :param y_pred: ignored
        :param y_true: ignored
        :param loss_fn: ignored; present because ``Evaluator.evaluate`` always passes the
            metric as a third positional argument
        '''
        pass


class SingleEvaluator(Evaluator):
    '''Evaluator that accumulates one scalar per metric under the "Loss" category.'''
    def _init(self, **kwargs):
        '''Initializes the evaluation dictionary'''
        self.eval_dict = {"Loss": {"loss_full": 0.0}}

    def _evaluate(self, y_pred, y_true, loss_fn, local = True) -> None:
        '''Computes the loss and adds it to the running sum stored under ``loss_fn.name``
        :param y_pred: (B, L, D)
        :param y_true: (B, L, D)
        :param loss_fn: callable metric with a ``name`` attribute; its result must support ``.item()``
        :param local: accepted for signature compatibility; not used here
        The accumulated dictionary has the form Dict[key1: Dict[key2, val]], where key1 is
        evaluation category, key2 is evaluation item, and val is evaluation value,
        eg. {"Loss": {"loss_full": 0.1}}
        '''
        loss = loss_fn(y_pred, y_true)
        totals = self.eval_dict["Loss"]
        # _init only seeds "loss_full"; start any other metric name at 0.0 instead of
        # raising KeyError on the first batch.
        totals[loss_fn.name] = totals.get(loss_fn.name, 0.0) + loss.item()


class RolloutEvaluator(Evaluator):
    '''Evaluator that scores a rollout one time step at a time.

    For every metric it keeps two categories: "<name>" with the item "step"
    (all time steps combined) and "Step_<name>" with one item per time step.
    '''
    def _init(self, average_steps = False, **kwargs):
        '''Creates the (initially empty) accumulators for every configured metric
        :param average_steps: if True the combined "step" value is the mean over
            time steps instead of the sum
        '''
        self.eval_dict = {}
        metric_names = [metric.name for metric in self.metrics]
        for metric_name in metric_names:
            self.eval_dict[metric_name] = defaultdict(lambda: 0.0)
            self.eval_dict[f"Step_{metric_name}"] = defaultdict(lambda: 0.0)
        self.average_steps = average_steps

    def _evaluate(self, y_pred, y_true, loss_fn) -> Dict:
        '''Adds the per-step and combined losses of one batch to the accumulators
        :param y_pred: (B, Sx, T, D)
        :param y_true: (B, Sx, T, D)
        :param loss_fn: callable metric with a ``name`` attribute
        Nothing is returned; results are stored in ``self.eval_dict``.
        '''
        num_steps, _ = y_pred.shape[-2:]
        metric_name = loss_fn.name

        # One loss value per index of the second-to-last axis.
        per_step = torch.empty(num_steps)
        for step in range(num_steps):
            per_step[step] = loss_fn(y_pred[..., step, :], y_true[..., step, :])

        combined = per_step.sum().item()
        if self.average_steps:
            combined = combined / num_steps
        self.eval_dict[metric_name]["step"] += combined

        for step, step_loss in enumerate(per_step):
            self.eval_dict[f"Step_{metric_name}"][f"step_{step:02}"] += step_loss.item()


class RolloutOODEvaluator(Evaluator):
    '''Rollout evaluator that additionally reports losses over fixed ranges of time steps.

    For every metric it keeps three categories:
      * "<name>": "step" (all steps) and, when ``t_train - 1 >= 1``, "step_inL",
        "step_LG", "step_id" and "step_ood" (see ``_evaluate`` for the ranges)
      * "Step_<name>": one item per time step
      * "Average_<name>": for each step t, the mean of the per-step losses of steps 0..t
    '''
    def _init(self, train_timesteps, t_train, average_steps = False, **kwargs):
        '''Creates the (initially empty) accumulators for every configured metric
        :param train_timesteps: index at which the "step_inL" range ends and "step_LG" begins
        :param t_train: ``t_train - 1`` is the index at which "step_id" ends and "step_ood" begins
        :param average_steps: if True, range values are means over their steps instead of sums
        '''
        self.eval_dict = {}
        for lname in [loss_fn.name for loss_fn in self.metrics]:
            self.eval_dict.update( {lname: defaultdict(lambda: 0.0),
                            f"Step_{lname}": defaultdict(lambda: 0.0),
                            f"Average_{lname}": defaultdict(lambda: 0.0) } )
        self.LG_initial_t = train_timesteps
        self.ood_initial_t = t_train - 1
        self.average_steps = average_steps

    def _segment_loss(self, segment) -> float:
        '''Reduces a 1-D slice of per-step losses to a Python float (sum, or mean if average_steps)'''
        total = segment.sum()
        if self.average_steps:
            # Division stays on the tensor so that an empty range yields nan rather than
            # raising ZeroDivisionError.
            total = total / len(segment)
        return total.item()

    def _evaluate(self, y_pred, y_true, loss_fn) -> None:
        '''Adds the per-step, per-range and running-average losses of one batch to the accumulators
        :param y_pred: (B, Sx, T, V)
        :param y_true: (B, Sx, T, V)
        :param loss_fn: callable metric with a ``name`` attribute
        Nothing is returned; results are stored in ``self.eval_dict``.
        '''
        T = y_pred.shape[-2]
        lname = loss_fn.name
        # calculate loss for each time step
        loss_step = torch.empty(T)
        for t in range(T):
            loss_step[t] = loss_fn(y_pred[...,t,:], y_true[...,t,:])
        loss_step_full = loss_step.sum().item()
        ls = loss_step_full / T if self.average_steps else loss_step_full

        self.eval_dict[lname]["step"] += ls

        for t in range(T):
            # step index is zero-padded to two digits, ie step_01 instead of step_1
            loss_step_key = f"step_{t:02}"
            self.eval_dict[f"Step_{lname}"][loss_step_key] += loss_step[t].item()

        if self.ood_initial_t >= 1:
            lg_start, ood_start = self.LG_initial_t, self.ood_initial_t
            segments = {
                "step_inL": loss_step[:lg_start],
                "step_LG": loss_step[lg_start:ood_start],
                "step_id": loss_step[:ood_start],
                "step_ood": loss_step[ood_start:],
            }
            for key, segment in segments.items():
                # Accumulate plain floats (not 0-dim tensors) so these entries have the
                # same type as every other value in eval_dict.
                self.eval_dict[lname][key] += self._segment_loss(segment)

        for t in range(T):
            loss_step_key = f"step_{t:02}"
            self.eval_dict[f"Average_{lname}"][f"_{loss_step_key}"] += loss_step[:t+1].sum().item() / (t+1)

def s4model_eval(model):
    '''Summarizes the magnitudes of the kernel parameters of the S4 blocks of a model.

    For every ``S4Block`` in ``model.s4_layers`` the absolute values of
    ``kernel.A_real + i * kernel.A_imag``, ``kernel.B`` and ``kernel.C`` are
    collected, stacked over blocks, and reduced to their mean and maximum.

    :param model: any object; only ``S4DualSeqModel`` instances are inspected
    :return: dictionary {"S4Params/<A|B|C>_<mean|max>": float}, or {} when the
        model is not an ``S4DualSeqModel`` or contains no ``S4Block``
    '''
    if not isinstance(model, S4DualSeqModel):
        return {}

    magnitudes = {"A": [], "B": [], "C": []}
    for block in model.s4_layers:
        if not isinstance(block, S4Block):
            continue
        kernel = block.layer.kernel
        magnitudes["A"].append(torch.complex(kernel.A_real, kernel.A_imag).abs())
        magnitudes["B"].append(kernel.B.abs())
        magnitudes["C"].append(kernel.C.abs())

    # torch.stack rejects an empty list; with no S4 blocks there is nothing to report.
    if not magnitudes["A"]:
        return {}

    stacked = {name: torch.stack(tensors) for name, tensors in magnitudes.items()}
    summary = {}
    # Means first, then maxima, each in A, B, C order.
    for stat in ("mean", "max"):
        for name, tensor in stacked.items():
            summary[f"S4Params/{name}_{stat}"] = getattr(tensor, stat)().item()
    return summary



# Maps an evaluator name to the class that implements it.
# NOTE(review): RolloutOODEvaluator is defined in this file but has no entry here — confirm
# whether that is intentional.
evaluator_registry = dict(
    single=SingleEvaluator,
    rollout=RolloutEvaluator,
    dummy=DummyEvaluator,
)

