import json
import os
import tempfile
import time
from copy import deepcopy
from typing import *

import mlflow
import numpy as np
import torch
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader

import utils
from approaches.hat.model_hat import ModelHAT
from approaches.param_consumable import ParamConsumable
from utils import myprint as print, print_num_params


class Appr:
    """Synaptic Intelligence (SI) approach on top of a ModelHAT backbone.

    The model is built with ``hat_enabled=False`` and trained with SGD. While a task is
    being learned, the per-step path integral ``w`` of the unregularized loss gradient is
    accumulated for every ``model.feature`` parameter. Once the task is finished (end of
    ``train``), it is turned into a per-parameter importance (Omega). Later tasks are then
    penalised with ``lamb * sum Omega * (theta - theta_task)**2``.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int, lamb: float,
                 nhid: int, drop1: float, drop2: float, backbone: str):
        """
        Args:
            device: torch device string the model is moved to.
            list__ncls: per-task class counts, forwarded to ModelHAT.
            inputsize: input shape, forwarded to ModelHAT.
            lr: base learning rate; divided by 10 before use.
            lr_factor: ReduceLROnPlateau multiplies the LR by ``1 / lr_factor`` on a plateau.
            lr_min: minimum learning rate; divided by 10 like ``lr``.
            epochs_max: maximum number of epochs per task.
            patience_max: number of non-improving epochs (at the minimum LR) before stopping.
            lamb: weight of the SI regularization term.
            nhid, drop1, drop2, backbone: forwarded to ModelHAT.
        """
        self.device = device

        # dataloader
        self.list__ncls = list__ncls
        self.inputsize = inputsize

        # variables
        lr_divisor = 10  # the configured learning rates are scaled down by this divisor
        self.lr = lr / lr_divisor
        self.lr_factor = lr_factor
        self.lr_min = lr_min / lr_divisor
        self.epochs_max = epochs_max
        self.patience_max = patience_max
        self.smax = 1
        self.lamb = lamb

        # misc
        self.criterion = nn.CrossEntropyLoss()
        self.model = ModelHAT(list__ncls=list__ncls, inputsize=inputsize,
                              smax=self.smax, hat_enabled=False,
                              nhid=nhid, drop1=drop1, drop2=drop2, backbone=backbone).to(self.device)

        # [SI] running path-integral accumulator, one zero tensor per feature parameter
        self.w = {n: p.clone().detach().zero_() for n, p in self.model.feature.named_parameters()}
        # [SI] parameters before any training; the reference point used only for the first
        # task (i.e. while regularization_terms is still empty)
        self.initial_params = {n: p.clone().detach() for n, p in self.model.feature.named_parameters()}
        # [SI] slot id -> {'importance': {name: Tensor}, 'task_param': {name: Tensor}}
        self.regularization_terms = {}
        self.task_count = 0
        self.online_reg = True
        # [SI] added to delta_theta**2 in the importance denominator to avoid dividing by ~0
        self.damping_factor = 0.1

        print_num_params(self.model)
    # enddef

    def compute_param_consumed(self, idx_task: int) -> float:
        """Parameter consumption reported for ``idx_task``; always 0 for this approach."""
        return 0
    # enddef

    def compute_loss(self, output: Tensor, target: Tensor, misc: Dict[str, Any], reg: bool) -> Tensor:
        """Cross-entropy loss, optionally plus the SI quadratic penalty.

        Args:
            output: model outputs fed to ``nn.CrossEntropyLoss``.
            target: targets fed to ``nn.CrossEntropyLoss``.
            misc: extra model outputs (unused here).
            reg: if True and at least one regularization slot exists, add
                ``lamb * sum_k sum_n (Omega_k[n] * (theta[n] - theta_k[n]) ** 2).sum()``
                over all stored slots k and all ``model.feature`` parameters n.

        Returns:
            Scalar loss tensor.
        """
        loss = self.criterion(output, target)

        if reg and len(self.regularization_terms) > 0:
            reg_loss = 0
            for reg_term in self.regularization_terms.values():
                importance = reg_term['importance']
                task_param = reg_term['task_param']
                for n, p in self.model.feature.named_parameters():
                    reg_loss += (importance[n] * (p - task_param[n]) ** 2).sum()
                # endfor
            # endfor

            # out-of-place: do not modify the tensor returned by the criterion
            loss = loss + self.lamb * reg_loss
        # endif

        return loss
    # enddef

    def on_after_train_epoch(self):
        """Hook; no-op in this approach."""
        pass
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              list__dl_test: Optional[List[DataLoader]] = None,
              ) -> Dict[str, float]:
        """Train on task ``idx_task`` with LR scheduling and early stopping, then consolidate SI.

        The epoch with the lowest validation loss is restored at the end. Only after that
        is the SI importance computed and the anchor stored, so the anchor matches the
        parameters actually kept.

        Args:
            idx_task: index of the task being learned.
            dl_train, dl_val: training / validation loaders.
            args_on_forward, args_on_after_backward: extra arguments forwarded to the model hooks.
            list__dl_test: currently unused; kept for interface compatibility.

        Returns:
            Dict with 'epoch' (last epoch run), 'time_consumed', 'param_consumed', and the
            train/val loss and accuracy of the restored best epoch.
        """
        # optimizer
        optimizer = optim.SGD(self.model.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         mode='min',
                                                         factor=1.0 / self.lr_factor,
                                                         patience=max(self.patience_max - 1, 0),
                                                         min_lr=self.lr_min,
                                                         verbose=True,
                                                         )
        # metric
        patience = 0
        loss_val_best = np.inf
        acc_val_best = NotImplemented
        loss_train_best = NotImplemented
        acc_train_best = NotImplemented
        state_dict_best = NotImplemented
        epoch_best = NotImplemented

        # learn by epoch
        time_start = time.time()
        for epoch in range(self.epochs_max):
            results_train = self.train_epoch(epoch=epoch,
                                             optimizer=optimizer,
                                             idx_task=idx_task,
                                             dl_train=dl_train, dl_val=dl_val,
                                             args_on_forward=args_on_forward,
                                             args_on_after_backward=args_on_after_backward)
            loss_train, acc_train = results_train['loss_train'], results_train['acc_train']
            loss_val, acc_val = results_train['loss_val'], results_train['acc_val']

            self.freeze_mask_on_each_epoch(idx_task, epoch, is_final=False,
                                           dl_train=dl_train, dl_val=dl_val)

            # check and save
            show_msg = True
            if epoch == 0 or loss_val < loss_val_best:
                loss_train_best = loss_train
                acc_train_best = acc_train
                loss_val_best = loss_val
                acc_val_best = acc_val
                epoch_best = epoch
                state_dict_best = self.copy_model()
                patience = 0
            else:
                show_msg = False
                # patience only runs down once the LR has hit its floor
                if utils.get_current_lr(optimizer) <= self.lr_min:
                    patience += 1
                else:
                    patience = 0
                # endif
            # endif

            if show_msg:
                msg = ' '.join([f'epoch: {epoch}/{self.epochs_max}, patience: {patience}/{self.patience_max}',
                                f'[train] loss: {loss_train_best:.4f}, acc: {acc_train_best:.4f}',
                                f'[val] loss: {loss_val_best:.4f}, acc: {acc_val_best:.4f}',
                                ])
                print(f'{msg}')
            # endif

            # early stop
            if patience >= self.patience_max or epoch == (self.epochs_max - 1):
                print(f'Load back to epoch={epoch_best}(loss: {loss_val_best:.4f}, acc: {acc_val_best:.4f})')
                self.load_model(state_dict_best)

                self.freeze_mask_on_each_epoch(idx_task, epoch, is_final=True,
                                               dl_train=dl_train, dl_val=dl_val)
                break
            # endif

            scheduler.step(loss_val)
        # endfor

        # [SI] Consolidate once per task, after the best model has been restored: the
        # importance covers the whole task trajectory accumulated in self.w, and the
        # stored anchor equals the parameters that are actually kept.
        self._consolidate_task()

        time_end = time.time()
        time_consumed = time_end - time_start

        if isinstance(self, ParamConsumable):
            param_consumed = self.compute_param_consumed(idx_task)
        else:
            param_consumed = None
        # endif

        results = {
            'epoch': epoch,
            'time_consumed': time_consumed,
            'param_consumed': param_consumed,
            'loss_train': loss_train_best,
            'acc_train': acc_train_best,
            'loss_val': loss_val_best,
            'acc_val': acc_val_best,
            }

        return results
    # enddef

    def freeze_mask_on_each_epoch(self, idx_task: int, epoch: int, is_final: bool, **kwargs) -> None:
        """Hook called after each epoch and once more after the best model is restored; no-op here."""
        pass
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """Hook; no-op in this approach."""
        pass
    # enddef

    def before_blocking_params(self, optimizer: optim.Optimizer):
        """Hook called right before ``model.on_after_backward_params``; no-op here."""
        pass
    # enddef

    def step_optimize(self, idx_task: int, optimizer: optim.Optimizer, blocking: Dict[str, Tensor], x: Tensor, s: float):
        """Apply one optimizer step; ``blocking``, ``x`` and ``s`` are unused in this approach."""
        optimizer.step()
    # enddef

    def train_epoch(self, epoch: int, optimizer: optim.Optimizer,
                    idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
                    args_on_forward: Dict[str, Any],
                    args_on_after_backward: Dict[str, Any]) -> Dict[str, float]:
        """Run one training epoch (accumulating the SI path integral) followed by validation.

        The SI importance and anchor are NOT updated here; ``train`` does that once per task.

        Returns:
            Dict with 'loss_train' (sum of the per-batch regularized losses), 'acc_train',
            'loss_val' and 'acc_val'.
        """
        # train
        self.model.train()
        num_batch_train = len(dl_train)
        list__target_train, list__output_train = [], []
        loss_train = 0  # type: Tensor

        # [SI] 1.Learn the parameters for current task
        for idx_batch, (x, y) in enumerate(dl_train):
            x = x.to(self.device)  # type: Tensor
            y = y.to(self.device)
            # annealed scale; with the default smax == 1 this is 1 for every batch
            s = 1 / self.smax + (self.smax - 1 / self.smax) * idx_batch / num_batch_train

            # [SI] 1.Save current parameters
            old_params = {n: p.clone().detach() for n, p in self.model.feature.named_parameters()}

            args_fw = args_on_forward.copy()
            args_fw['epoch'] = epoch
            args_fw['idx_batch'] = idx_batch

            output, misc = self.model(idx_task, x, s=s, args_on_forward=args_fw)

            # [SI] 2. Collect the gradients without regularization term
            loss_tmp = self.compute_loss(output=output, target=y, misc=misc, reg=False)
            optimizer.zero_grad()
            # keep the graph: it is reused by the regularized backward pass below
            loss_tmp.backward(retain_graph=True)
            unreg_gradients = {n: p.grad.clone().detach()
                               for n, p in self.model.feature.named_parameters()
                               if p.grad is not None}

            # [SI] 3. Normal update with regularization
            loss = self.compute_loss(output=output, target=y, misc=misc, reg=True)
            # only the values are needed for metrics; detaching avoids keeping every
            # batch's autograd graph alive until the end of the epoch
            loss_train += loss.detach()
            list__target_train.append(y)
            list__output_train.append(output.detach())

            # optim
            optimizer.zero_grad()
            loss.backward()
            args_bw = args_on_after_backward.copy()
            args_bw['epoch'] = epoch
            args_bw['idx_batch'] = idx_batch

            self.model.on_after_backward_emb(s=s)
            self.before_blocking_params(optimizer)
            blocking = self.model.on_after_backward_params(idx_task, s=s, args=args_bw)
            self.step_optimize(idx_task, optimizer, blocking, x, s)

            # [SI] 4. Accumulate the w: w[n] -= g_unreg[n] * (theta_after - theta_before)
            for n, p in self.model.feature.named_parameters():
                # some parameters may receive no gradient (e.g. parts not on this task's loss path)
                if n in unreg_gradients:
                    delta = p.detach() - old_params[n]
                    self.w[n] -= unreg_gradients[n] * delta
                # endif
            # endfor
        # endfor | idx_batch

        acc_train = utils.my_accuracy(torch.cat(list__target_train, dim=0),
                                      torch.cat(list__output_train, dim=0)).item()

        # val
        results_val = self._eval_common(idx_task, dl_val, args_on_forward=args_on_forward)

        results = {
            'loss_train': loss_train.item(),
            'acc_train': acc_train,
            'loss_val': results_val['loss'],
            'acc_val': results_val['acc'],
            }

        return results
    # enddef

    def _consolidate_task(self) -> None:
        """[SI] Store the current feature parameters and their importance as a regularization slot."""
        # [SI] Backup the weight of current task
        task_param = {n: p.clone().detach() for n, p in self.model.feature.named_parameters()}

        # [SI] Calculate the importance of weights for current task (also resets self.w)
        importance = self.calculate_importance()

        # [SI] Save the weight and importance of weights of current task
        self.task_count += 1
        if self.online_reg and len(self.regularization_terms) > 0:
            # Always use only one slot in self.regularization_terms
            self.regularization_terms[1] = {'importance': importance, 'task_param': task_param}
        else:
            # Use a new slot to store the task-specific information
            self.regularization_terms[self.task_count] = {'importance': importance, 'task_param': task_param}
        # endif
    # enddef

    def calculate_importance(self) -> Dict[str, Tensor]:
        """[SI] Compute or accumulate the per-parameter importance (Omega).

        ``Omega[n] += w[n] / ((theta[n] - theta_prev[n]) ** 2 + damping_factor)``.
        Here ``theta_prev`` is the anchor in slot 1, or ``initial_params`` when no slot exists
        yet. When slot 1 exists, its importance tensors are updated in place. ``self.w`` is
        reset to zero afterwards.

        Returns:
            Mapping from feature-parameter name to its importance tensor.
        """
        assert self.online_reg, 'SI needs online_reg=True'

        # Initialize the importance matrix
        if len(self.regularization_terms) > 0:  # The case of after the first task
            importance = self.regularization_terms[1]['importance']
            prev_params = self.regularization_terms[1]['task_param']
        else:  # It is in the first task
            importance = {n: p.clone().detach().fill_(0)  # zero initialized
                          for n, p in self.model.feature.named_parameters()}
            prev_params = self.initial_params
        # endif

        # Calculate or accumulate the Omega (the importance matrix)
        params = dict(self.model.feature.named_parameters())
        for n, omega in importance.items():
            delta_theta = params[n].detach() - prev_params[n]
            omega += self.w[n] / (delta_theta ** 2 + self.damping_factor)
            self.w[n].zero_()
        # endfor

        return importance
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate on ``dl_test``; returns 'loss_test' (summed over batches) and 'acc_test'."""
        results_test = self._eval_common(idx_task, dl_test, args_on_forward=args_on_forward)

        results = {
            'loss_test': results_test['loss'],
            'acc_test': results_test['acc'],
            }
        return results
    # enddef

    def _eval_common(self, idx_task: int, dl: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate in eval mode without gradients; the loss is unregularized and summed over batches."""
        self.model.eval()
        list__target, list__output = [], []
        loss = 0  # type: Tensor

        with torch.no_grad():
            for idx_batch, (x, y) in enumerate(dl):
                x = x.to(self.device)
                y = y.to(self.device)

                output, misc = self.model(idx_task, x, s=self.smax, args_on_forward=args_on_forward)
                loss += self.compute_loss(output, y, misc, reg=False)
                list__target.append(y)
                list__output.append(output)
            # endfor | idx_batch
        # endwith
        acc = utils.my_accuracy(torch.cat(list__target, dim=0),
                                torch.cat(list__output, dim=0)).item()

        results = {
            'loss': loss.item(),
            'acc': acc,
            }
        return results
    # enddef

    def copy_model(self) -> Dict[str, Tensor]:
        """Return a deep copy of the model's state_dict."""
        return deepcopy(self.model.state_dict())
    # enddef

    def load_model(self, state_dict: Dict[str, Tensor]) -> None:
        """Load a deep copy of ``state_dict`` into the model (the argument is left untouched)."""
        self.model.load_state_dict(deepcopy(state_dict))
    # enddef

    def save_object_as_artifact(self, obj: Any, filename: str):
        """Dump ``obj`` as JSON to a temporary file named ``filename`` and log it to MLflow."""
        with tempfile.TemporaryDirectory() as dir_tmp:
            filepath = os.path.join(dir_tmp, filename)
            with open(filepath, 'w', encoding='utf-8') as fp:
                json.dump(obj, fp)
            # endwith
            mlflow.log_artifact(filepath)
        # endwith
    # enddef
# endclass
