from typing import *

from torch.utils.data import DataLoader

from approaches.abst_appr import AbstractAppr
from approaches.supsup.appr_supsup_orig import Appr as OrigAppr
from approaches.supsup.args import args
from torch import nn


class Appr(AbstractAppr):
    """Adapter exposing the original SupSup implementation via ``AbstractAppr``.

    Construction writes the experiment configuration onto the module-level
    ``args`` namespace (imported from ``approaches.supsup.args``) and then
    builds the wrapped ``OrigAppr``. ``train``, ``test`` and
    ``complete_learning`` delegate to that wrapped object (``self.appr``).
    """

    def __init__(self,
                 device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int,
                 momentum: float, sparsity: int,
                 expname: str, log_dir: str, batch_size: int,
                 ):
        """Configure the shared ``args`` namespace and build the wrapped SupSup approach.

        Args:
            device: Device string; ``'cpu'`` disables ``args.multigpu``,
                any other value is passed through as ``[device]``.
            list__ncls: Number of classes per task. Its length becomes
                ``args.num_tasks`` and its maximum ``args.output_size``.
            inputsize: Input shape forwarded to the base class and ``OrigAppr``.
            lr: Learning rate. It is used for both ``args.lr`` and
                ``args.train_weight_lr``.
            lr_factor, lr_min, patience_max: Learning-rate schedule settings
                forwarded to the base class and ``OrigAppr``.
            epochs_max: Maximum number of epochs (also ``args.epochs``).
            momentum: SGD momentum (``args.momentum``).
            sparsity: Value stored as ``args.sparsity``.
            expname: Experiment name. It is used for both ``args.name`` and
                ``args.set``.
            log_dir: Log directory (``args.log_dir``).
            batch_size: Batch size (``args.batch_size``).
        """
        super().__init__(device=device,
                         list__ncls=list__ncls,
                         inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=0, lamb=0,
                         )

        # NOTE(review): this is a multi-GPU device list for OrigAppr. It passes
        # the device string itself; an earlier (commented-out) variant passed
        # the integer GPU index instead. Confirm which form OrigAppr expects.
        multigpu = None if device == 'cpu' else [device]

        # ``args`` is shared module-level state: every construction of this
        # class overwrites these fields. They are assigned before OrigAppr is
        # built, which presumably reads its configuration from ``args``.
        settings = [('name', expname),
                    ('log_dir', log_dir),
                    ('num_tasks', len(list__ncls)),
                    ('model', 'FC2048'),
                    ('individual_heads', True),
                    ('conv_type', 'MultitaskMaskConv'),
                    ('bn_type', 'MultitaskNonAffineBN'),
                    ('conv_init', 'default'),
                    ('width_mult', 1.0),
                    ('output_size', max(list__ncls)),
                    ('er_sparsity', True),
                    ('sparsity', sparsity),
                    ('multigpu', multigpu),
                    ('trainer', None),
                    ('set', expname),
                    ('train_weight_tasks', -1),
                    ('lr', lr),
                    ('train_weight_lr', lr),
                    ('optimizer', 'sgd'),
                    ('momentum', momentum),
                    ('wd', 0.0),
                    ('epochs', epochs_max),
                    ('no_scheduler', False),
                    ('iter_lim', -1),
                    ('log_interval', 10),
                    ('batch_size', batch_size),
                    ('eval_ckpts', []),
                    ('adaptor', 'gt'),
                    ]
        for k, v in settings:
            setattr(args, k, v)
        # endfor
        self.appr = OrigAppr(inputsize=inputsize,
                             lr_factor=lr_factor, lr_min=lr_min,
                             patience_max=patience_max,
                             list__ncls=list__ncls,
                             )
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              ) -> float:
        """Train task ``idx_task`` by delegating to ``OrigAppr.train_task``.

        ``args_on_forward`` and ``args_on_after_backward`` are accepted for
        interface compatibility but are not used here.

        Returns:
            The value returned by ``train_task``. Judging by the variable
            name, this is presumably the time spent training; confirm in
            appr_supsup_orig.
        """
        time_consumed = self.appr.train_task(idx_task,
                                             dl_train=dl_train, dl_val=dl_val,
                                             )
        return time_consumed
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate task ``idx_task`` via ``OrigAppr.test_task``.

        ``args_on_forward`` is accepted for interface compatibility but is
        not used here.

        Returns:
            A dict with ``'acc_test'`` read from ``self.appr.adapt_acc1[idx_task]``
            after ``test_task`` runs. ``'loss_test'`` is always ``0.0``
            because no loss is read back from the wrapped implementation.
        """
        self.appr.test_task(idx_task, dl_test)
        acc_test = self.appr.adapt_acc1[idx_task]
        return {
            'loss_test': 0.0,
            'acc_test': acc_test,
            }
    # enddef

    def complete_learning(self, idx_task: int) -> None:
        """Freeze ``self.appr.model.fc[idx_task]`` once task ``idx_task`` is done.

        This turns off ``requires_grad`` on every parameter of that module.
        The module is presumably the task-specific output head
        (``args.individual_heads`` is set to True above); confirm in the
        model definition.
        """
        fc: nn.Module = self.appr.model.fc[idx_task]
        for p in fc.parameters():
            p.requires_grad_(False)
        # endfor
    # enddef
