import time
from typing import *

from torch import Tensor
from torch.utils.data import DataLoader

from approaches.ewc.mixin_ewc import MixinEWC
from approaches.hat.appr_hat import Appr as HATAppr
from utils import print_num_params


class Appr(HATAppr, MixinEWC):
    """Approach built on the HAT approach class with an added EWC loss term.

    The HAT base class is constructed with ``hat_enabled=False``. The loss
    produced by the base class is augmented with ``self.lamb`` times the value
    returned by ``ewc_compute_loss()``, and ``ewc_in_train()`` is invoked after
    every call to the base ``train()``.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int, backbone: str,
                 smax: float, lamb: float,
                 nhid: int, drop1: float, drop2: float):
        """Initialise the EWC mixin, then the HAT base, then report model size.

        Every argument is forwarded unchanged, by keyword, to
        ``HATAppr.__init__``; ``hat_enabled`` is fixed to ``False``.
        """
        # Both bases are initialised explicitly, the mixin first.
        MixinEWC.__init__(self)
        HATAppr.__init__(
            self,
            backbone=backbone,
            device=device,
            list__ncls=list__ncls,
            inputsize=inputsize,
            lr=lr,
            lr_factor=lr_factor,
            lr_min=lr_min,
            epochs_max=epochs_max,
            patience_max=patience_max,
            smax=smax,
            lamb=lamb,
            nhid=nhid,
            drop1=drop1,
            drop2=drop2,
            hat_enabled=False,
        )

        # NOTE(review): self.model is presumably created by HATAppr.__init__ -- confirm.
        print_num_params(self.model)
    # enddef

    def compute_loss(self, output: Tensor, target: Tensor, misc: Dict[str, Any]) -> Tensor:
        """Return the base-class loss plus ``self.lamb * ewc_compute_loss()``."""
        base_loss = super().compute_loss(output, target, misc)
        ewc_penalty = self.lamb * self.ewc_compute_loss()

        return base_loss + ewc_penalty
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              **kwargs) -> Dict[str, Any]:
        """Run the base training, then the EWC post-training step.

        The wall-clock duration of ``ewc_in_train`` is added to
        ``'time_consumed'`` in the dictionary returned by the base ``train``,
        and that same dictionary is returned.

        ``**kwargs`` is accepted but is not forwarded to the base ``train``.
        """
        result = super().train(
            idx_task=idx_task,
            dl_train=dl_train,
            dl_val=dl_val,
            args_on_forward=args_on_forward,
            args_on_after_backward=args_on_after_backward,
        )

        # Only the EWC step is timed here; its duration is added on top of
        # whatever the base class already stored under 'time_consumed'.
        started_at = time.time()
        self.ewc_in_train(
            idx_task=idx_task,
            dl_train=dl_train,
            smax=self.smax,
            args_on_forward=args_on_forward,
        )
        finished_at = time.time()
        result['time_consumed'] += (finished_at - started_at)

        return result
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """Do nothing; this approach has no work to perform in this hook."""
        return None
    # enddef

    def compute_param_consumed(self, idx_task: int) -> float:
        """Return ``0`` regardless of ``idx_task``."""
        return 0
    # enddef
