import time
from argparse import Namespace
from typing import *

import numpy as np
import torch
from torch.utils.data import DataLoader

from approaches.abst_appr import AbstractAppr
from approaches.cat.appr_cat_orig import Appr as OrigAppr, CheckFederated
from approaches.cat.model_cat_alexnet import ModelCATAlexNet
from approaches.cat.model_cat_mlp import ModelCATMLP
from approaches.param_consumable import ParamConsumable
from utils import myprint as print, print_num_params


class Appr(AbstractAppr, ParamConsumable):
    """Adapter exposing the original CAT approach (``OrigAppr``) through ``AbstractAppr``.

    Each call to ``train`` runs two stages:

    1. ``auto_similarity``: for every previous task, ``OrigAppr`` is trained and
       validated in the ``'transfer'`` phase with that task's mask; for the current
       task it runs the ``'reference'`` phase. The validation accuracies are turned
       into a 0/1 similarity vector.
    2. ``OrigAppr`` is trained in the ``'mcl'`` phase with that similarity vector,
       after which every task seen so far is tested and the results are stored in
       ``acc_mcl`` / ``lss_mcl`` (row = last trained task, column = tested task).

    The data loaders are taken once from ``dict__idx_task__dataloader`` at
    construction time; the loaders passed to ``train``/``test`` are not used.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int,
                 smax: float, lamb: float,
                 backbone: str,
                 nhid: int, drop1: float, drop2: float, nheads: int,
                 dict__idx_task__dataloader: Dict[int, Dict[str, Any]]):
        """Build the backbone model and the wrapped ``OrigAppr``.

        Args:
            backbone: ``'mlp'`` or ``'alexnet'``; anything else raises
                ``NotImplementedError``.
            dict__idx_task__dataloader: maps each task index to a dict holding the
                ``'train'``, ``'val'`` and ``'test'`` loaders and the task ``'name'``.
                Keys are looked up as ``0 .. len(dict) - 1``, so they are assumed
                to be contiguous from 0.
            The remaining arguments are forwarded to ``AbstractAppr``, to the
            ``Namespace`` consumed by the original CAT code, to the model
            constructor (``nhid``) and to ``OrigAppr``.

        Raises:
            NotImplementedError: for an unsupported ``backbone``.
        """
        super().__init__(device=device,
                         list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=smax, lamb=lamb)

        self.backbone = backbone

        # Hyper-parameters in the attribute format the original CAT code reads.
        # NOTE: there is deliberately no `output` entry, so nothing is saved to disk.
        self.args = Namespace(similarity_detection='auto',
                              scenario='til',
                              loss_type='multi-loss-joint-Tsim',
                              model_weights=1,
                              nepochs=epochs_max,
                              lr=lr,
                              lr_patience=patience_max,
                              n_head=nheads,
                              pdrop1=drop1,
                              pdrop2=drop2,
                              smax=smax,
                              parameter='',
                              )

        if backbone == 'mlp':
            model_cls = ModelCATMLP
        elif backbone == 'alexnet':
            model_cls = ModelCATAlexNet
        else:
            raise NotImplementedError(backbone)
        # endif
        model = model_cls(device=self.device, list__ncls=list__ncls, inputsize=inputsize,
                          nhid=nhid, args=self.args)

        print_num_params(model)

        num_tasks = len(dict__idx_task__dataloader)
        per_task = [dict__idx_task__dataloader[idx_task] for idx_task in range(num_tasks)]
        self.list__dl_train = [entry['train'] for entry in per_task]
        self.list__dl_val = [entry['val'] for entry in per_task]
        self.list__dl_test = [entry['test'] for entry in per_task]
        self.list__name = [entry['name'] for entry in per_task]

        self.appr = OrigAppr(device=device,
                             model=model,
                             args=self.args,
                             lr_min=lr_min,
                             lr_factor=lr_factor,
                             clipgrad=10000,
                             lamb=lamb,
                             smax=smax,
                             list__dl_train=self.list__dl_train,
                             list__dl_val=self.list__dl_val,
                             list__dl_test=self.list__dl_test,
                             )
        self.check_federated = CheckFederated()

        # Result matrices are (num_tasks x num_tasks), sized from list__ncls.
        num_cls_tasks = len(list__ncls)

        def _square() -> np.ndarray:
            return np.zeros((num_cls_tasks, num_cls_tasks), dtype=np.float32)
        # enddef

        self.acc_ac = _square()
        self.lss_ac = _square()

        self.acc_mcl = _square()
        self.lss_mcl = _square()

        self.acc_an = _square()
        self.lss_an = _square()

        self.unit_overlap_sum_transfer = _square()
        self.norm_transfer_raw = _square()
        self.norm_transfer_one = _square()
        self.norm_transfer = _square()
        self.acc_transfer = _square()
        self.acc_reference = _square()
        self.lss_transfer = _square()
        self.similarity_transfer = _square()

        self.history_mask_back = []
        self.history_mask_pre = []
        self.similarities = []

        # Index of the task most recently trained by `train`; `test` reads this row
        # of acc_mcl / lss_mcl. None until the first `train` call.
        self._idx_task_latest: Optional[int] = None

        # timer
        self.timer = dict()
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """No post-training step is needed for this approach."""
        pass
    # enddef

    def compute_param_consumed(self, idx_task: int) -> float:
        """Return the fraction of masked parameter entries blocked by ``self.appr.mask_pre``.

        For every model parameter that ``get_view_for`` maps to a mask view, the
        view's values are summed and divided by the total number of view entries.
        Returns 0.0 for the first task, and also when no parameter has a mask view
        (previously a ZeroDivisionError).
        """
        if idx_task == 0:
            return 0.0
        # endif

        num_all = 0
        num_blocked = 0.0
        model = self.appr.model
        for name, _ in model.named_parameters():
            vals = model.get_view_for(name, self.appr.mask_pre)
            if vals is None:
                continue
            # endif
            num_all += vals.numel()
            # assumes mask values lie in [0, 1]; the sum counts partial blocking
            # fractionally — TODO confirm against the model's mask definition
            num_blocked += vals.sum().item()
        # endfor

        if num_all == 0:
            return 0.0
        # endif
        return num_blocked / num_all
    # enddef

    def _detached_masks(self, idx_task: int) -> List[torch.Tensor]:
        """Return the model's per-layer masks for ``idx_task``, detached from the graph."""
        task_tensor = torch.tensor([idx_task], dtype=torch.long, device=self.device)
        if self.backbone == 'mlp':
            gfc1, gfc2 = self.appr.model.mask(task_tensor)
            masks = [gfc1, gfc2]
        elif self.backbone == 'alexnet':
            gc1, gc2, gc3, gfc1, gfc2 = self.appr.model.mask(task_tensor)
            masks = [gc1, gc2, gc3, gfc1, gfc2]
        else:
            raise NotImplementedError(self.backbone)
        # endif
        return [m.detach() for m in masks]
    # enddef

    def auto_similarity(self, t: int) -> List[int]:
        """Detect which previous tasks are similar to task ``t``.

        For every previous task ``pre_task < t``, ``OrigAppr`` is trained and
        validated on task ``t`` in the ``'transfer'`` phase using ``pre_task``'s
        mask. For ``pre_task == t`` it runs the ``'reference'`` phase instead.
        Validation accuracy/loss are stored in ``acc_transfer[t, pre_task]`` /
        ``lss_transfer[t, pre_task]``.

        Returns:
            ``[0]`` for ``t == 0``. Otherwise a list of length ``t`` where entry
            ``i`` is 0 if the transfer accuracy from task ``i`` is ``<=`` the
            reference accuracy, and 1 otherwise. The same values are written to
            ``similarity_transfer[t, :t]``.
        """
        if t > 0:
            for pre_task in range(t + 1):
                print('pre_task: ', pre_task)
                print('t: ', t)
                pre_mask = self._detached_masks(pre_task)

                # The current task is the reference run (OrigAppr implements it
                # with a random mask); earlier tasks are transfer candidates.
                phase = 'reference' if pre_task == t else 'transfer'

                print('>>> Now Training Phase: {:6s} <<<'.format(phase))
                self.appr.train(t, phase=phase, args=self.args,
                                pre_mask=pre_mask, pre_task=pre_task,
                                similarity=None, history_mask_back=None, history_mask_pre=None,
                                check_federated=None)

                test_loss, test_acc = self.appr.eval(t, self.list__dl_val, phase=phase,
                                                     pre_mask=pre_mask, pre_task=pre_task,
                                                     similarity=None, history_mask_pre=None,
                                                     check_federated=None)

                print('>>> Test on task {:2d} - {:15s}: loss={:.3f}, acc={:5.1f}% <<<'
                      .format(t, self.list__name[t], test_loss, 100 * test_acc))

                self.acc_transfer[t, pre_task] = test_acc
                self.lss_transfer[t, pre_task] = test_loss
            # endfor
        # endif
        print('test_acc: ', self.acc_transfer[t][:t + 1])
        print('test_loss: ', self.lss_transfer[t][:t + 1])

        similarity = [0]
        if t > 0:
            acc_list = self.acc_transfer[t][:t]  # t from 0
            print('acc_list: ', acc_list)

            if 'auto' in self.args.similarity_detection:
                acc_reference = self.acc_transfer[t][t]
                # A previous task counts as similar only if transferring from it
                # beats the reference run on task t.
                similarity = [0 if acc <= acc_reference else 1 for acc in acc_list]
            else:
                raise NotImplementedError(self.args.similarity_detection)
            # endif

            self.similarity_transfer[t, :t] = similarity
        # endif

        print('similarity: ', similarity)
        return similarity
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_after_backward: Dict[str, Any],
              args_on_forward: Dict[str, Any],
              **kwargs) -> Dict[str, Any]:
        """Run similarity detection and then 'mcl' training for task ``idx_task``.

        After training, every task ``0 .. idx_task`` is tested and the results
        are written to row ``idx_task`` of ``acc_mcl`` / ``lss_mcl``.

        ``dl_train``, ``dl_val``, ``args_on_after_backward``, ``args_on_forward``
        and ``kwargs`` are not used; ``OrigAppr`` uses the loaders given at
        construction time.

        Returns:
            ``{'time_consumed': seconds spent on similarity detection plus 'mcl'
            training (the final testing is excluded),
            'param_consumed': compute_param_consumed(idx_task)}``.

        Raises:
            NotImplementedError: if ``args.similarity_detection`` is not an
                'auto' mode; the other modes of the original code are not ported.
        """
        time_start = time.time()

        candidate_phase = 'mcl'
        if 'auto' in self.args.similarity_detection:
            similarity = self.auto_similarity(idx_task)
        else:
            # 'by-name', 'all-one' and 'all-zero' exist in the original CAT code
            # but are not ported here.
            raise NotImplementedError(self.args.similarity_detection)
        # endif

        self.similarities.append(similarity)
        self.check_federated.set_similarities(self.similarities)

        print('>>> Now Training Phase: {:6s} <<<'.format(candidate_phase))

        self.appr.train(t=idx_task, phase=candidate_phase, args=self.args,
                        similarity=similarity, history_mask_back=self.history_mask_back,
                        history_mask_pre=self.history_mask_pre, check_federated=self.check_federated,
                        pre_mask=NotImplemented, pre_task=NotImplemented)
        time_end = time.time()  # the testing below is intentionally not timed
        print('-' * 100)

        # Snapshot the masks held by OrigAppr after learning this task.
        self.history_mask_back.append({k: v.data.clone() for k, v in self.appr.mask_back.items()})
        self.history_mask_pre.append([m.data.clone() for m in self.appr.mask_pre])

        for u in range(idx_task + 1):
            test_loss, test_acc = self.appr.test(t=u, phase=candidate_phase,
                                                 similarity=similarity, history_mask_pre=self.history_mask_pre,
                                                 check_federated=self.check_federated,
                                                 pre_mask=NotImplemented, pre_task=NotImplemented)
            print('>>> Test on task {:2d} - {:15s}: loss={:.3f}, acc={:5.1f}% <<<'
                  .format(u, self.list__name[u], test_loss, 100 * test_acc))

            self.acc_mcl[idx_task, u] = test_acc
            self.lss_mcl[idx_task, u] = test_loss
        # endfor
        self._idx_task_latest = idx_task

        ret = {
            'time_consumed': time_end - time_start,
            'param_consumed': self.compute_param_consumed(idx_task),
            }

        print(f'similarity_transfer(at task={idx_task}):\n{self.similarity_transfer}')

        return ret
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader,
             args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Return the cached test metrics for ``idx_task``.

        ``train`` tests every task seen so far and stores the results in the row
        of the task it just trained. Reading that latest row (instead of the
        diagonal ``[idx_task, idx_task]``) means earlier tasks are reported as of
        the most recent training, not as of the moment they were learned. Before
        any ``train`` call the diagonal (all zeros) is read. ``dl_test`` and
        ``args_on_forward`` are not used.
        """
        row = idx_task if self._idx_task_latest is None else self._idx_task_latest
        return {
            'loss_test': float(self.lss_mcl[row, idx_task]),
            'acc_test': float(self.acc_mcl[row, idx_task]),
            }
    # enddef
# endclass
