import time
from typing import *

from torch import Tensor
from torch.utils.data import DataLoader

from approaches.ewc.mixin_ewc import MixinEWC
from approaches.hat.appr_hat import Appr as HATAppr


class Appr(HATAppr, MixinEWC):
    """HAT approach with an additional EWC penalty.

    Loss computation and training are delegated to ``HATAppr``. This class adds
    the EWC term from ``MixinEWC`` to the loss, and runs the mixin's
    ``ewc_in_train`` step after each task's training.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int,
                 smax: float, lamb: float,
                 drop1: float, drop2: float):
        """Initialise the EWC mixin, then the HAT approach with all arguments.

        Every argument is forwarded unchanged to ``HATAppr.__init__``.
        """
        # Both bases are initialised explicitly rather than through one
        # cooperative super().__init__() chain.
        # NOTE(review): if HATAppr.__init__ itself calls super().__init__(),
        # MixinEWC.__init__ may run a second time via the MRO (after HATAppr has
        # already set attributes). Confirm this is harmless.
        MixinEWC.__init__(self)
        HATAppr.__init__(self, device=device, list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=smax, lamb=lamb,
                         drop1=drop1, drop2=drop2)
    # enddef

    def compute_loss(self, output: Tensor, target: Tensor, misc: Dict[str, Any]) -> Tensor:
        """Return the parent (HAT) loss plus ``self.lamb`` times the EWC penalty.

        Args:
            output: model output, passed through to ``HATAppr.compute_loss``.
            target: targets, passed through to ``HATAppr.compute_loss``.
            misc: extra values, passed through to ``HATAppr.compute_loss``.

        Returns:
            The sum of the parent loss and the scaled EWC term.
        """
        loss_base = super().compute_loss(output, target, misc)
        # NOTE(review): self.lamb is presumably set by HATAppr.__init__ from the
        # ``lamb`` argument. If HATAppr also uses it for its own regulariser, the
        # same coefficient weights both terms. Confirm this is intended.
        loss_ewc = self.lamb * self.ewc_compute_loss()

        return loss_base + loss_ewc
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              ) -> float:
        """Train on task ``idx_task`` with HAT, then run the EWC post-training step.

        Args:
            idx_task: index of the task being trained.
            dl_train: training data loader; also passed to ``ewc_in_train``.
            dl_val: validation data loader, passed to ``HATAppr.train``.
            args_on_forward: forwarded to ``HATAppr.train`` and ``ewc_in_train``.
            args_on_after_backward: forwarded to ``HATAppr.train``.

        Returns:
            The time reported by ``HATAppr.train`` plus the elapsed time of
            ``ewc_in_train``, in seconds.
        """
        time_consumed = super().train(idx_task=idx_task, dl_train=dl_train, dl_val=dl_val,
                                      args_on_forward=args_on_forward,
                                      args_on_after_backward=args_on_after_backward)

        # perf_counter() is monotonic, so the measured interval is not distorted
        # by system clock adjustments the way a time.time() difference can be.
        time_start = time.perf_counter()
        self.ewc_in_train(idx_task=idx_task, dl_train=dl_train,
                          smax=self.smax, args_on_forward=args_on_forward)
        time_consumed += time.perf_counter() - time_start

        return time_consumed
    # enddef
