from typing import *

import torch
from torch.utils.data import DataLoader

from approaches.abst_appr import AbstractAppr
from approaches.param_consumable import ParamConsumable
from approaches.supsup.appr_supsup_orig import Appr as OrigAppr
from approaches.supsup.args import args
from approaches.supsup.models import module_util
from approaches.supsup.models.modules import SupSupModule


class Appr(AbstractAppr, ParamConsumable):
    """Adapter exposing the original SupSup implementation (``OrigAppr``)
    through the ``AbstractAppr`` / ``ParamConsumable`` interface.

    Construction writes configuration into the module-level
    ``approaches.supsup.args`` namespace (presumably consumed by ``OrigAppr``
    and its models -- TODO confirm), then training and testing are delegated
    to ``self.appr``.
    """

    def __init__(self,
                 device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int,
                 backbone: str,
                 nhid: int, drop1: float, drop2: float, momentum: float, sparsity: float,
                 expname: str, log_dir: str, batch_size: int,
                 ):
        """Configure the shared SupSup ``args`` namespace and build the wrapped approach.

        Args:
            device: ``'cpu'`` or another device string; any non-``'cpu'``
                value is stored as the single entry of ``args.multigpu``.
            list__ncls: classes per task; its length sets ``args.num_tasks``
                and its maximum sets ``args.output_size``.
            inputsize: input shape, forwarded to the base class and ``OrigAppr``.
            lr, lr_factor, lr_min, epochs_max, patience_max: optimisation
                settings, forwarded to the base class, ``args`` and/or ``OrigAppr``.
            backbone: ``'mlp'`` or ``'alexnet'``.
            nhid, drop1, drop2: forwarded to ``OrigAppr``.
            momentum, sparsity, batch_size: written to ``args``.
            expname: written to both ``args.name`` and ``args.set``.
            log_dir: written to ``args.log_dir``.

        Raises:
            NotImplementedError: if ``backbone`` is not supported.

        Note:
            ``args`` is a module-level object, so every construction overwrites
            the settings of any previously constructed instance.
        """
        super().__init__(device=device,
                         list__ncls=list__ncls,
                         inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=0, lamb=0,
                         )

        if backbone == 'mlp':
            modelname = 'MLP'
        elif backbone == 'alexnet':
            modelname = 'AlexNet'
        else:
            raise NotImplementedError(backbone)
        # endif

        for k, v in [('name', expname),
                     ('log_dir', log_dir),
                     ('num_tasks', len(list__ncls)),
                     ('model', modelname),
                     ('individual_heads', True),
                     ('conv_type', 'MultitaskMaskConv'),
                     ('bn_type', 'MultitaskNonAffineBN'),
                     ('conv_init', 'default'),
                     ('width_mult', 1.0),
                     ('output_size', max(list__ncls)),
                     ('er_sparsity', True),
                     # Candidate values noted for sparsity: 1, 2, 4, 8, 16, 32.
                     ('sparsity', sparsity),
                     # The device string is passed through as-is (not parsed
                     # into an integer GPU index).
                     ('multigpu', None if device == 'cpu' else [device]),
                     ('trainer', None),
                     ('set', expname),
                     # Alternative setting noted in earlier code: 1.
                     ('train_weight_tasks', 0),
                     ('lr', lr),
                     ('train_weight_lr', lr),
                     ('optimizer', 'sgd'),
                     ('momentum', momentum),
                     ('wd', 0.0),
                     ('epochs', epochs_max),
                     ('no_scheduler', False),
                     ('iter_lim', -1),
                     ('log_interval', 10),
                     ('batch_size', batch_size),
                     ('eval_ckpts', []),
                     ('adaptor', 'gt'),
                     ]:
            setattr(args, k, v)
        # endfor

        self.appr = OrigAppr(inputsize=inputsize,
                             lr_factor=lr_factor, lr_min=lr_min,
                             patience_max=patience_max,
                             list__ncls=list__ncls,
                             nhid=nhid, drop1=drop1, drop2=drop2,
                             )
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any], args_on_after_backward: Dict[str, Any],
              **kwargs) -> Dict[str, Any]:
        """Train task ``idx_task`` through the wrapped approach.

        ``args_on_forward``, ``args_on_after_backward`` and ``kwargs`` are
        accepted for interface compatibility but are not used here.

        Returns:
            Dict with ``'time_consumed'`` (the value returned by
            ``OrigAppr.train_task``) and ``'param_consumed'`` (see
            :meth:`compute_param_consumed`).
        """
        time_consumed = self.appr.train_task(idx_task,
                                             dl_train=dl_train, dl_val=dl_val,
                                             )

        ret = {
            'time_consumed': time_consumed,
            'param_consumed': self.compute_param_consumed(idx_task),
            }

        return ret
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate task ``idx_task`` on ``dl_test``.

        The accuracy is read from ``self.appr.adapt_acc1[idx_task]`` after
        ``test_task`` runs (presumably filled in by ``test_task`` -- confirm).
        The wrapped approach's loss is not surfaced, so ``'loss_test'`` is a
        constant ``0.0`` placeholder. ``args_on_forward`` is unused.
        """
        self.appr.test_task(idx_task, dl_test)
        acc_test = self.appr.adapt_acc1[idx_task]

        return {
            'loss_test': 0.0,
            'acc_test': acc_test,
            }
    # enddef

    def compute_param_consumed(self, idx_task: int) -> float:
        """Return the fraction of SupSup mask entries that are blocked.

        For every ``SupSupModule`` in the model, the mask for ``idx_task`` is
        obtained from its scores and ``(1 - mask)`` is summed over all entries.
        When ``idx_task < 0``, the mask is instead the ``alphas``-weighted sum
        of the masks of all learned tasks.

        Returns:
            ``sum(1 - mask) / total_mask_entries``, or ``0.0`` when the model
            contains no ``SupSupModule`` (nothing to count).
        """
        num_all = 0
        num_blocked = 0.0

        # Only counting happens here, so avoid recording an autograd graph
        # through the subnet functions.
        with torch.no_grad():
            for m in self.appr.model.modules():
                if not isinstance(m, SupSupModule):
                    continue
                # endif

                if idx_task < 0:
                    num_used = min(len(self.list__ncls), m.num_tasks_learned)
                    stacked = torch.stack([
                        module_util.get_subnet_fast(m.scores[j])
                        for j in range(num_used)
                        ])
                    # Slice alphas with the same count as `stacked` so the
                    # leading dimensions of the product always agree.
                    alpha_weights = m.alphas[:num_used]
                    subnet = (alpha_weights * stacked).sum(dim=0)
                else:
                    subnet = module_util.GetSubnetFast.apply(m.scores[idx_task])
                # endif

                num_all += subnet.numel()
                num_blocked += (1 - subnet).sum().item()
            # endfor
        # endwith

        if num_all == 0:
            # No SupSup modules: avoid ZeroDivisionError.
            return 0.0
        # endif

        return num_blocked / num_all
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """Post-task hook; intentionally a no-op for SupSup.

        Freezing the classifier head (``self.appr.model.clf``) is deliberately
        not performed here.
        """
    # enddef
