import time
from copy import deepcopy
from itertools import zip_longest
from typing import *

import numpy as np
import torch
from torch.utils.data import DataLoader

import utils


class Appr(object):
    def __init__(self, device: str,
                 model, args, lr_min, lr_factor, clipgrad, lamb, smax,
                 list__dl_train: List[DataLoader], list__dl_val: List[DataLoader], list__dl_test: List[DataLoader]):
        self.model = model
        self.device = device
        self.model = self.model.to(self.device)

        self.list__dl_train = list__dl_train
        self.list__dl_val = list__dl_val
        self.list__dl_test = list__dl_test

        self.nepochs = args.nepochs
        self.lr = args.lr
        self.lr_min = lr_min
        self.lr_factor = lr_factor
        self.lr_patience = args.lr_patience
        self.clipgrad = clipgrad
        self.args = args
        self.ce = torch.nn.CrossEntropyLoss()
        self.optimizer = self._get_optimizer()

        '''
        self.nepochs_kt = args.nepochs_kt
        self.lr_kt = args.lr_kt
        self.lr_patience_kt = args.lr_patience_kt
        '''

        self.lamb = lamb  # Grid search = [0.1, 0.25, 0.5, 0.75, 1, 1.5, 2.5, 4]; chosen was 0.75
        self.smax = smax  # Grid search = [25, 50, 100, 200, 400, 800]; chosen was 400
        if len(args.parameter) >= 1:
            params = args.parameter.split(',')
            print('Setting parameters to', params)
            self.lamb = float(params[0])
            self.smax = float(params[1])
        # endif

        self.mask_pre = None
        self.mask_back = None

        self.transfer_initial_model = deepcopy(model.transfer)

        return
    # enddef

    def _get_optimizer(self, lr=None, phase=None, args=None):
        # if phase==None:
        # raise NotImplementedError
        if lr is None:
            lr = self.lr
        elif phase == 'mcl' and 'no_attention' in args.loss_type:
            return torch.optim.SGD(list(self.model.mcl.parameters()), lr=lr)
        elif phase == 'mcl' and 'joint' in args.loss_type:
            return torch.optim.SGD(list(self.model.kt.parameters()) + list(self.model.mcl.parameters()), lr=lr)
        elif phase == 'transfer':
            return torch.optim.SGD(list(self.model.transfer.parameters()), lr=lr)
        elif phase == 'reference':
            return torch.optim.SGD(list(self.model.transfer.parameters()), lr=lr)
        # endif
    # enddef

    def train(self, t, phase, args,
              pre_mask, pre_task,
              similarity, history_mask_back,
              history_mask_pre, check_federated):
        """Train task ``t`` in the given phase and keep the best-validation weights.

        The transfer network is first reset to the snapshot taken in
        ``__init__``. Each epoch runs ``train_epoch`` and then evaluates on the
        training and validation loaders of task ``t``. When the validation loss
        does not improve for ``lr_patience`` consecutive epochs, the learning
        rate is divided by ``lr_factor`` and a fresh optimizer is built;
        training stops once the learning rate falls below ``lr_min`` or after
        ``nepochs`` epochs. A ``KeyboardInterrupt`` ends training early without
        propagating. Afterwards the best model state (by validation loss) is
        restored, and in the ``'mcl'`` phase ``self.mask_pre`` and
        ``self.mask_back`` are updated.

        Args:
            t: Task index; selects the loaders and the output head.
            phase: ``'mcl'``, ``'transfer'`` or ``'reference'``.
            args: Configuration forwarded to ``_get_optimizer``.
            pre_mask, pre_task: Forwarded to the model in the transfer and
                reference phases.
            similarity, history_mask_pre, check_federated: Forwarded to the
                model in the ``'mcl'`` phase.
            history_mask_back: Forwarded to ``train_epoch`` only.

        Raises:
            NotImplementedError: If ``phase`` is not one of the three above.
        """

        self.model.transfer = deepcopy(self.transfer_initial_model)  # Restart transfer network: isolate

        best_loss = np.inf
        best_model = utils.get_model(self.model)

        # All supported phases currently share the same schedule settings.
        if phase == 'mcl' or phase == 'transfer' or phase == 'reference':
            lr = self.lr
            patience = self.lr_patience
            nepochs = self.nepochs
        else:
            raise NotImplementedError
        # endif

        # Built after the transfer reset so it refers to the new transfer parameters.
        self.optimizer = self._get_optimizer(lr, phase, args)
        print('similarity: ', similarity)

        try:
            for e in range(nepochs):
                # Train
                clock0 = time.time()
                self.train_epoch(t, phase=phase, pre_mask=pre_mask,
                                 pre_task=pre_task, similarity=similarity, history_mask_back=history_mask_back,
                                 history_mask_pre=history_mask_pre, check_federated=check_federated)
                clock1 = time.time()
                train_loss, train_acc = self.eval(t, self.list__dl_train, phase=phase, pre_mask=pre_mask,
                                                  pre_task=pre_task, similarity=similarity,
                                                  history_mask_pre=history_mask_pre, check_federated=check_federated)
                clock2 = time.time()

                # Timings are reported as milliseconds per batch
                # (elapsed time * batch size / number of samples).
                sbatch = self.list__dl_train[t].batch_size
                print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |' \
                      .format(e + 1,
                              1000 * sbatch * (clock1 - clock0) / len(self.list__dl_train[t].dataset),
                              1000 * sbatch * (clock2 - clock1) / len(self.list__dl_train[t].dataset),
                              train_loss,
                              100 * train_acc),
                      end='')
                # Valid
                valid_loss, valid_acc = self.eval(t, self.list__dl_val, phase=phase, pre_mask=pre_mask,
                                                  pre_task=pre_task, similarity=similarity,
                                                  history_mask_pre=history_mask_pre, check_federated=check_federated)
                print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss, 100 * valid_acc), end='')
                # Adapt lr
                if valid_loss < best_loss:
                    # New best: snapshot the model and reset the patience counter.
                    best_loss = valid_loss
                    best_model = utils.get_model(self.model)
                    patience = self.lr_patience
                    print(' *', end='')
                else:
                    patience -= 1
                    if patience <= 0:
                        # Patience exhausted: decay the learning rate, or stop if it is too small.
                        lr /= self.lr_factor
                        print(' lr={:.1e}'.format(lr), end='')
                        if lr < self.lr_min:
                            print()
                            break
                        # endif
                        patience = self.lr_patience
                        self.optimizer = self._get_optimizer(lr, phase, args)
                    # endif
                # endif
                print()
            # endfor
        except KeyboardInterrupt:
            # Manual interruption stops training; the best model so far is still restored below.
            print()
        # endtry

        # Restore best validation model
        utils.set_model_(self.model, best_model)

        if phase == 'mcl':
            # Activations mask
            # task = torch.autograd.Variable(torch.LongTensor([t]).cuda(), volatile=False)
            task = torch.tensor([t], requires_grad=False, device=self.device).long()
            mask = self.model.mask(task, s=self.smax)

            # Replace each mask entry with a copy that does not require gradients.
            for i in range(len(mask)):
                mask[i] = torch.autograd.Variable(mask[i].data.clone(), requires_grad=False)
            # endfor

            # Accumulate masks over tasks with an element-wise maximum.
            # NOTE(review): for t > 0 this assumes self.mask_pre was already set by an
            # earlier 'mcl' run on this object — confirm against the caller.
            if t == 0:
                self.mask_pre = mask
            else:
                for i in range(len(self.mask_pre)):
                    self.mask_pre[i] = torch.max(self.mask_pre[i], mask[i])
                # endfor
            # endif

            # Weights mask
            self.mask_back = {}

            # For every parameter the model provides a view for, store 1 - view;
            # train_epoch multiplies gradients by this value.
            for n, _ in self.model.named_parameters():
                vals = self.model.get_view_for(n, self.mask_pre)
                if vals is not None:
                    self.mask_back[n] = 1 - vals
                # endif
            # endfor
        # endif

        return
    # enddef

    def train_epoch(self, t, phase,
                    pre_mask, pre_task,
                    similarity, history_mask_back,
                    history_mask_pre, check_federated,
                    thres_cosh=50, thres_emb=6, ):
        self.model.train()
        dataloader_train = self.list__dl_train[t]

        # Loop batches
        for idx_batch, (images, targets) in enumerate(dataloader_train):
            images = images.to(self.device)
            targets = targets.to(self.device)
            # task = torch.autograd.Variable(torch.LongTensor([t]).to(self.device), volatile=False)
            task = torch.tensor([t], requires_grad=False, device=self.device).long()
            s = (self.smax - 1 / self.smax) * idx_batch / len(dataloader_train) + 1 / self.smax

            # Forward
            if phase == 'mcl':
                outputs, masks, outputs_attn = self.model.forward(task, images, s=s, phase=phase,
                                                                  similarity=similarity,
                                                                  history_mask_pre=history_mask_pre,
                                                                  check_federated=check_federated)
                output = outputs[t]

                if outputs_attn is None:
                    print(f'training {t}, output_attn is None')
                    raise NotImplementedError
                    loss = self.criterion(output, targets, masks)
                else:
                    # print(f'training {t}, output_attn is used')
                    output_attn = outputs_attn[t]
                    loss = self.joint_criterion(output, targets, masks, output_attn)
                # endif
            elif phase == 'transfer' or phase == 'reference':

                outputs = self.model.forward(task, images, s=s, phase=phase,
                                             pre_mask=pre_mask, pre_task=pre_task)
                output = outputs[t]
                loss = self.transfer_criterion(output, targets)
            # endif

            # Backward
            self.optimizer.zero_grad()
            loss.backward()

            if phase == 'mcl':
                # Restrict layer gradients in backprop
                if t > 0:
                    for n, p in self.model.named_parameters():
                        if n in self.mask_back and p.grad is not None:
                            Tsim_mask = self.model.Tsim_mask(task, history_mask_pre=history_mask_pre,
                                                             similarity=similarity)
                            Tsim_vals = self.model.get_view_for(n, Tsim_mask).clone()
                            p.grad.data *= torch.max(self.mask_back[n], Tsim_vals)
                        # endif
                    # endfor
                # endif

                # Compensate embedding gradients
                for n, p in self.model.named_parameters():
                    if n.startswith('mcl.e') and p.grad is not None:
                        num = torch.cosh(torch.clamp(s * p.data, -thres_cosh, thres_cosh)) + 1
                        den = torch.cosh(p.data) + 1
                        p.grad.data *= self.smax / s * num / den
                    # endif
                # endfor
            elif phase == 'reference':
                # Compensate embedding gradients
                for n, p in self.model.named_parameters():
                    if n.startswith('transfer.e') and p.grad is not None:
                        num = torch.cosh(torch.clamp(s * p.data, -thres_cosh, thres_cosh)) + 1
                        den = torch.cosh(p.data) + 1
                        p.grad.data *= self.smax / s * num / den
                    # endif
                # endfor
            # endif

            # Apply step
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            self.optimizer.step()

            if phase == 'mcl':
                # Constrain embeddings
                for n, p in self.model.named_parameters():
                    if n.startswith('mcl.e'):
                        p.data = torch.clamp(p.data, -thres_emb, thres_emb)
                    # endif
                # endfor
            elif phase == 'reference':
                # Constrain embeddings
                for n, p in self.model.named_parameters():
                    if n.startswith('transfer.e'):
                        p.data = torch.clamp(p.data, -thres_emb, thres_emb)
                    # endif
                # endfor
            # endif
        # endfor
        return
    # enddef

    def eval(self, t, list__dl: List[DataLoader], phase,
             pre_mask, pre_task, similarity,
             history_mask_pre, check_federated
             ):
        dataloader = list__dl[t]

        total_att_loss = 0
        total_att_acc = 0

        total_mask_loss = 0
        total_mask_acc = 0

        total_num = 0
        self.model.eval()

        with torch.no_grad():
            # Loop batches
            for idx_batch, (images, targets) in enumerate(dataloader):
                images = images.to(self.device)
                targets = targets.to(self.device)
                # task = torch.autograd.Variable(torch.LongTensor([t]).to(self.device), volatile=True)
                task = torch.tensor([t], requires_grad=False, device=self.device).long()

                # Forward
                if phase == 'mcl':
                    outputs, masks, outputs_attn = self.model.forward(task, images, s=self.smax, phase=phase,
                                                                      similarity=similarity,
                                                                      history_mask_pre=history_mask_pre,
                                                                      check_federated=check_federated)
                    output = outputs[t]

                    if outputs_attn is None:
                        loss = self.criterion(output, targets, masks)
                    else:
                        output_attn = outputs_attn[t]
                        loss = self.joint_criterion(output, targets, masks, output_attn)
                    # endif
                elif phase == 'transfer' or phase == 'reference':
                    outputs = self.model.forward(task, images, s=self.smax, phase=phase,
                                                 pre_mask=pre_mask, pre_task=pre_task)
                    output = outputs[t]
                    loss = self.transfer_criterion(output, targets)
                # endif
                # if phase=='mcl' and (similarity is not None and t<len(similarity) and np.count_nonzero(similarity[:t])>1 and similarity[t]==1):

                if phase == 'mcl' and 'no_attention' not in self.args.loss_type and outputs_attn is not None:
                    # print(f'evaluating {t}, output_attn is used')

                    _, att_pred = output_attn.max(1)
                    _, mask_pred = output.max(1)

                    mask_hits = (mask_pred == targets).float()
                    att_hits = (att_pred == targets).float()

                    # Log
                    total_mask_loss += loss.data.cpu().numpy().item() * targets.shape[0]
                    total_mask_acc += mask_hits.sum().data.cpu().numpy().item()

                    # Log
                    total_att_loss += loss.data.cpu().numpy().item() * targets.shape[0]
                    total_att_acc += att_hits.sum().data.cpu().numpy().item()
                else:
                    # print(f'evaluating {t}, output_attn is None')
                    # raise NotImplementedError

                    _, pred = output.max(1)
                    hits = (pred == targets).float()

                    # Log
                    total_mask_loss += loss.data.cpu().numpy().item() * targets.shape[0]
                    total_mask_acc += hits.sum().data.cpu().numpy().item()
                # endif
                total_num += targets.shape[0]
            # endfor

            if 'all-one' in self.args.similarity_detection:
                total_loss = total_att_loss
                total_acc = total_att_acc
            elif phase == 'mcl' and 'no_attention' not in self.args.loss_type:
                if total_att_acc > total_mask_acc:
                    total_loss = total_att_loss
                    total_acc = total_att_acc
                else:
                    total_loss = total_mask_loss
                    total_acc = total_mask_acc
                # endif
            else:
                total_loss = total_mask_loss
                total_acc = total_mask_acc
            # endif
            return total_loss / total_num, total_acc / total_num
        # endwith
    # enddef

    def test(self, t, phase,
             pre_mask, pre_task, similarity,
             history_mask_pre, check_federated,
             ):
        dataloader_val = self.list__dl_val[t]
        dataloader_test = self.list__dl_test[t]

        choose_att = False
        total_att_loss = 0
        total_att_acc = 0

        total_mask_loss = 0
        total_mask_acc = 0

        total_num = 0

        self.model.eval()
        with torch.no_grad():
            # Loop validation batches
            for idx_batch, (images, targets) in enumerate(dataloader_val):
                images = images.to(self.device)
                targets = targets.to(self.device)
                # task = torch.autograd.Variable(torch.LongTensor([t]).to(self.device), volatile=True)
                task = torch.tensor([t], requires_grad=False, device=self.device).long()

                # Forward
                if phase == 'mcl':
                    outputs, masks, outputs_attn = self.model.forward(task, images, s=self.smax, phase=phase,
                                                                      similarity=similarity,
                                                                      history_mask_pre=history_mask_pre,
                                                                      check_federated=check_federated)
                    output = outputs[t]
                    if outputs_attn is None:
                        loss = self.criterion(output, targets, masks)
                    else:
                        output_attn = outputs_attn[t]
                        loss = self.joint_criterion(output, targets, masks, output_attn)
                    # endif
                elif phase == 'transfer' or phase == 'reference':
                    outputs = self.model.forward(task, images, s=self.smax, phase=phase,
                                                 pre_mask=pre_mask, pre_task=pre_task)
                    output = outputs[t]
                    loss = self.transfer_criterion(output, targets)
                # endif
                # if phase=='mcl' and (similarity is not None and t<len(similarity) and np.count_nonzero(similarity[:t])>1 and similarity[t]==1):

                if phase == 'mcl' and 'no_attention' not in self.args.loss_type and outputs_attn is not None:
                    _, att_pred = output_attn.max(1)
                    _, mask_pred = output.max(1)

                    mask_hits = (mask_pred == targets).float()
                    att_hits = (att_pred == targets).float()

                    # Log
                    total_mask_loss += loss.data.cpu().numpy().item() * targets.shape[0]
                    total_mask_acc += mask_hits.sum().data.cpu().numpy().item()

                    # Log
                    total_att_loss += loss.data.cpu().numpy().item() * targets.shape[0]
                    total_att_acc += att_hits.sum().data.cpu().numpy().item()
                else:
                    _, pred = output.max(1)
                    hits = (pred == targets).float()

                    # Log
                    total_mask_loss += loss.data.cpu().numpy().item() * targets.shape[0]
                    total_mask_acc += hits.sum().data.cpu().numpy().item()
                # endif
                total_num += targets.shape[0]
            # endfor
            if 'all-one' in self.args.similarity_detection:
                choose_att = True
            elif phase == 'mcl' and 'no_attention' not in self.args.loss_type:
                if total_att_acc > total_mask_acc:
                    choose_att = True
                # endif
            # endif

            print('choose_att: ', choose_att)
            # Here simply use validation to choose attention in testing.
            # One can also remember which tasks have used the attention in training and then apply attention for testing

            # Loop batches
            for idx_batch, (images, targets) in enumerate(dataloader_test):
                images = images.to(self.device)
                targets = targets.to(self.device)
                # task = torch.autograd.Variable(torch.LongTensor([t]).to(self.device), volatile=True)
                task = torch.tensor([t], requires_grad=False, device=self.device).long()

                # Forward
                if phase == 'mcl':
                    outputs, masks, outputs_attn = self.model.forward(task, images, s=self.smax, phase=phase,
                                                                      similarity=similarity,
                                                                      history_mask_pre=history_mask_pre,
                                                                      check_federated=check_federated)
                    output = outputs[t]

                    if outputs_attn is None:
                        loss = self.criterion(output, targets, masks)
                    else:
                        output_attn = outputs_attn[t]
                        loss = self.joint_criterion(output, targets, masks, output_attn)
                    # endif
                elif phase == 'transfer' or phase == 'reference':
                    outputs = self.model.forward(task, images, s=self.smax, phase=phase,
                                                 pre_mask=pre_mask, pre_task=pre_task)
                    output = outputs[t]
                    loss = self.transfer_criterion(output, targets)
                # endif
                if phase == 'mcl' and 'no_attention' not in self.args.loss_type and outputs_attn is not None:
                    _, att_pred = output_attn.max(1)
                    _, mask_pred = output.max(1)

                    mask_hits = (mask_pred == targets).float()
                    att_hits = (att_pred == targets).float()

                    # Log
                    total_mask_loss += loss.data.cpu().numpy().item() * targets.shape[0]
                    total_mask_acc += mask_hits.sum().data.cpu().numpy().item()

                    # Log
                    total_att_loss += loss.data.cpu().numpy().item() * targets.shape[0]
                    total_att_acc += att_hits.sum().data.cpu().numpy().item()
                else:
                    _, pred = output.max(1)
                    hits = (pred == targets).float()

                    # Log
                    total_mask_loss += loss.data.cpu().numpy().item() * targets.shape[0]
                    total_mask_acc += hits.sum().data.cpu().numpy().item()
                # endif
                total_num += targets.shape[0]
            # endfor
            if choose_att == True:
                total_loss = total_att_loss
                total_acc = total_att_acc
            else:
                total_loss = total_mask_loss
                total_acc = total_mask_acc
            # endif
            return total_loss / total_num, total_acc / total_num
        # endwith
    # enddef

    def transfer_criterion(self, outputs, targets, mask=None):
        return self.ce(outputs, targets)
    # enddef

    def joint_criterion(self, outputs, targets, masks, outputs_attn):
        return self.criterion(outputs, targets, masks) + self.args.model_weights * self.ce(outputs_attn, targets)
    # enddef

    def criterion(self, outputs, targets, masks):
        reg = 0
        count = 0

        if self.mask_pre is not None:
            for m, mp in zip(masks, self.mask_pre):
                aux = 1 - mp
                reg += (m * aux).sum()
                count += aux.sum()
            # endfor
        else:
            for m in masks:
                reg += m.sum()
                count += np.prod(m.size()).item()
            # endfor
        # endif
        reg /= count

        return self.ce(outputs, targets) + self.lamb * reg
    # enddef


# endclass


class CheckFederated():
    """Holds a table of task similarities and answers per-task queries on it.

    ``similarities`` is treated as a sequence of rows of numbers; rows may
    have different lengths (shorter rows count as zero-padded when columns
    are summed).
    """

    def __init__(self):
        # Start empty so the accessors work before ``set_similarities`` is called.
        self.similarities = []

    def set_similarities(self, similarities):
        """Replace the stored similarity table."""
        self.similarities = similarities

    def fix_length(self):
        """Return the number of rows in the similarity table."""
        return len(self.similarities)

    def get_similarities(self):
        """Return the stored similarity table."""
        return self.similarities

    def check_t(self, t):
        """Return ``True`` if task ``t`` has any recorded similarity.

        That is the case when column ``t`` has a positive sum, when row ``t``
        contains a non-zero entry, or when entry ``t`` of the last row equals
        1. An index beyond the table simply yields ``False``.
        """
        # Column sums over the (possibly ragged) rows, computed once.
        column_sums = [sum(column) for column in zip_longest(*self.similarities, fillvalue=0)]
        if t < len(column_sums) and column_sums[t] > 0:
            return True

        # Row lookup is bounds-checked like the other two tests; previously an
        # out-of-range ``t`` raised IndexError here.
        if t < len(self.similarities) and np.count_nonzero(self.similarities[t]) > 0:
            return True

        # ``len(...) > 0`` rather than truthiness so array-like tables also work.
        if len(self.similarities) > 0 and t < len(self.similarities[-1]) \
                and self.similarities[-1][t] == 1:
            return True

        return False
