from abc import ABC, abstractmethod
import numpy as np

class Trainer(ABC):
    def __init__(
        self, optimizer, loss_fn, device,
    ):
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.device = device

    def _update_losses(self, names, losses, d):
        for name, loss in zip(names, losses):
            name = "loss/" + name
            if name in d.keys():
                d[name] += loss if isinstance(loss, int) else loss.detach().item()
            else:
                d[name] = loss if isinstance(loss, int) else loss.detach().item()
        return d
    
    def _update_norms(self, model, d):
        for name, p in model.named_parameters():
            if p.grad is not None:
                name_g = "gradients/" + name
                if name_g in d.keys():
                    d[name_g] += p.grad.detach().norm(2).item()
                else:
                    d[name_g] = p.grad.detach().norm(2).item()
            if "weight" in name:
                name_w = "weights/" + name
                if name_w in d.keys():
                    d[name_w] += p.detach().norm(2).item()
                else:
                    d[name_w] = p.detach().norm(2).item()
        return d
    
    def _update_hidden_layer(self, names, values, d):
        for name, v in zip(names, values):
            name = "hidden_layer/" + name
            if name in d.keys():
                d[name] += v
            else:
                d[name] = v
        return d

    @abstractmethod
    def train_epoch(self):
        pass
    
    @abstractmethod
    def test_epoch(self):
        pass


class EarlyStopping:
    """
    Two-stage early-stopping callback driven by a history of validation losses.

    Each call inspects a trailing window of the loss history. The first time
    the window looks "bad", ``reduced_lr`` is set to True so the caller can
    lower the learning rate. From then on only losses recorded *after* that
    point are judged. Once ``es_tolerance`` of them exist, a second bad window
    sets ``early_stop`` to True.

    A window of length ``tolerance`` is bad when either condition holds:

    - it contains at least ``int(0.55 * tolerance)`` epoch-to-epoch increases
      (rising or noisy loss);
    - every consecutive pair of losses differs by at most ``min_delta``
      (plateau).

    Nothing happens until the history holds at least
    ``max(lr_tolerance, es_tolerance)`` losses.

    Parameters:
    -----------
    lr_tolerance : int
        Window length (in epochs) examined before the learning-rate reduction.
    es_tolerance : int
        Window length (in epochs) examined after the learning-rate reduction.
        It is also the number of post-reduction losses required before early
        stopping can trigger.
    min_delta : float
        Largest absolute change between consecutive losses that still counts
        as a plateau.

    Attributes:
    -----------
    reduced_lr : bool
        True once the first bad window has been seen.
    early_stop : bool
        True once a bad window has been seen after ``reduced_lr`` was set.

    Notes:
    ------
    Tolerances below 2 make the increase threshold ``int(0.55 * tolerance)``
    equal to 0, so every call triggers. Use tolerances of at least 2.
    """

    def __init__(self, lr_tolerance, es_tolerance, min_delta):
        self.lr_tolerance = lr_tolerance
        self.es_tolerance = es_tolerance
        self.min_delta = min_delta

        self.reduced_lr = False
        self.early_stop = False
        # Length of the loss history when ``reduced_lr`` was set; post-reduction
        # checks only look at losses appended after this index.
        self._reduced_at = None

    def __call__(self, val_loss_list):
        """Update ``reduced_lr`` / ``early_stop`` from the loss history.

        Parameters:
        -----------
        val_loss_list : sequence of float
            Validation losses so far, oldest first (presumably one per epoch).
        """
        if len(val_loss_list) < max(self.lr_tolerance, self.es_tolerance):
            return

        if self.reduced_lr:
            # Judge the reduced learning rate only on losses it produced. The
            # pre-reduction window already triggered once; reusing it would
            # stop training almost immediately after the LR drop.
            since_reduction = val_loss_list[self._reduced_at:]
            if len(since_reduction) < self.es_tolerance:
                return
            tolerance = self.es_tolerance
            window = np.array(since_reduction[-tolerance:])
        else:
            tolerance = self.lr_tolerance
            window = np.array(val_loss_list[-tolerance:])

        # Take care of increasing loss or outliers.
        n_increases = np.count_nonzero(window[1:] > window[:-1])
        if n_increases >= int(0.55 * tolerance):
            self._trigger(len(val_loss_list))
            return

        # Take care of plateauing loss: every consecutive change <= min_delta.
        if np.allclose(window[:-1], window[1:], rtol=0, atol=self.min_delta):
            self._trigger(len(val_loss_list))

    def _trigger(self, n_losses):
        """Advance one stage: first request an LR reduction, then stop."""
        if self.reduced_lr:
            self.early_stop = True
        else:
            self.reduced_lr = True
            self._reduced_at = n_losses
