import json
import os
import tempfile
import time
from argparse import Namespace
from copy import deepcopy
from typing import *

import mlflow
import numpy as np
import torch
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader

import utils
from approaches.hat.model_hat import ModelHAT
from approaches.param_consumable import ParamConsumable
from approaches.tag.tag_update import TAG
from utils import myprint as print, print_num_params


class Appr:
    """Task-by-task trainer that pairs a ``ModelHAT`` network with the TAG optimizer.

    ``model.feature`` is updated by a ``TAG`` optimizer that is created once in
    ``__init__``; ``model.fc`` is updated by a plain torch optimizer that is
    re-created at the start of every epoch (see ``train_epoch``).  The model is
    built with ``hat_enabled=False`` and ``self.lamb`` is 0, so the ``reg``
    term in ``compute_loss`` is weighted by zero unless ``lamb`` is changed.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int,
                 nhid: int, drop1: float, drop2: float, backbone: str):
        """
        Build the model and the TAG optimizer.

        :param device: Device the model and every batch are moved to
        :param list__ncls: One entry per task; its length defines ``num_task``
        :param inputsize: Forwarded to ``ModelHAT``
        :param lr: Initial learning rate; divided by 1000 before being stored
        :param lr_factor: Divisor applied to the TAG learning rate on plateau
        :param lr_min: Learning-rate floor; divided by 1000 before being stored
        :param epochs_max: Maximum number of epochs per task
        :param patience_max: Epochs without validation improvement before
            the learning rate is decayed or training stops
        :param nhid: Forwarded to ``ModelHAT``
        :param drop1: Forwarded to ``ModelHAT``
        :param drop2: Forwarded to ``ModelHAT``
        :param backbone: Forwarded to ``ModelHAT``
        """
        self.device = device

        # dataloader
        self.list__ncls = list__ncls
        self.num_task = len(list__ncls)
        self.inputsize = inputsize

        # variables
        # NOTE(review): lr and lr_min are both scaled down by this constant;
        # presumably callers pass values 1000x larger than intended -- confirm.
        m = 1000
        self.lr = lr / m
        self.lr_factor = lr_factor
        self.lr_min = lr_min / m
        self.epochs_max = epochs_max
        self.patience_max = patience_max
        self.smax = 1
        self.lamb = 0

        # misc
        self.criterion = nn.CrossEntropyLoss()
        self.model = ModelHAT(list__ncls=list__ncls, inputsize=inputsize,
                              smax=self.smax, hat_enabled=False,
                              nhid=nhid, drop1=drop1, drop2=drop2, backbone=backbone).to(self.device)
        # 'rms' or 'adam'; also selects the classifier optimizer in train_epoch
        self.tag_opt = 'rms'
        self.optimizer = TAG(model=self.model.feature,
                             args=Namespace(device=self.device),
                             num_tasks=self.num_task,
                             optim=self.tag_opt, lr=self.lr)

        print_num_params(self.model)
    # enddef

    def compute_loss(self, output: Tensor, target: Tensor, misc: Dict[str, Any]) -> Tensor:
        """
        Return ``criterion(output, target) + lamb * misc['reg']``.

        :param output: Model output passed to the criterion
        :param target: Targets passed to the criterion
        :param misc: Second value returned by the model; must contain ``'reg'``
        :return: Loss tensor
        """
        reg = misc['reg']
        return self.criterion(output, target) + self.lamb * reg
    # enddef

    def on_after_train_epoch(self):
        """Hook with no behaviour here; it is not called anywhere in this class."""
        pass
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              list__dl_test: Optional[List[DataLoader]] = None,
              ) -> Dict[str, float]:
        """
        Train on one task with plateau-based learning-rate decay and early stopping.

        After every epoch the validation loss is compared with the best one
        seen so far.  On improvement (or at epoch 0) a copy of the model state
        is kept.  Otherwise ``patience`` grows; once it reaches
        ``patience_max`` the TAG learning rate is divided by ``lr_factor`` and
        patience is reset, unless the rate is already at or below ``lr_min``,
        in which case training stops.  When training stops (or the last epoch
        is reached) the best state is loaded back into the model.

        :param idx_task: Index of the task being learned
        :param dl_train: Training data
        :param dl_val: Validation data
        :param args_on_forward: Extra arguments forwarded to the model
        :param args_on_after_backward: Forwarded to ``train_epoch``
        :param list__dl_test: Not used by this method
        :return: Dict with ``epoch`` (index of the last epoch that ran, not the
            best one), ``time_consumed`` (seconds), ``param_consumed``
            (``None`` unless ``self`` is a ``ParamConsumable``) and the
            train/val loss and accuracy of the best epoch
        """
        # metric
        patience = 0
        loss_val_best = np.inf
        acc_val_best = NotImplemented
        loss_train_best = NotImplemented
        acc_train_best = NotImplemented
        state_dict_best = NotImplemented
        epoch_best = NotImplemented

        # learn by epoch
        time_start = time.time()
        for epoch in range(self.epochs_max):
            # The second return value (running alpha means collected by
            # store_alpha) is not consumed by this method.
            results_train, _ \
                = self.train_epoch(epoch=epoch,
                                   optimizer=self.optimizer,
                                   idx_task=idx_task,
                                   dl_train=dl_train, dl_val=dl_val,
                                   args_on_forward=args_on_forward,
                                   args_on_after_backward=args_on_after_backward)
            loss_train, acc_train = results_train['loss_train'], results_train['acc_train']
            loss_val, acc_val = results_train['loss_val'], results_train['acc_val']

            self.freeze_mask_on_each_epoch(idx_task, epoch, is_final=False,
                                           dl_train=dl_train, dl_val=dl_val)

            # check and save
            show_msg = True
            if epoch == 0 or loss_val < loss_val_best:
                # epoch 0 is always accepted so that a best state exists
                loss_train_best = loss_train
                acc_train_best = acc_train
                loss_val_best = loss_val
                acc_val_best = acc_val
                epoch_best = epoch
                state_dict_best = self.copy_model()
                patience = 0
            else:
                show_msg = False
                patience += 1
                if patience >= self.patience_max:
                    if self.optimizer.lr <= self.lr_min:
                        # already at the floor: leave patience untouched so
                        # the early-stop check below fires
                        pass
                    else:
                        patience = 0
                        lr_prev = self.optimizer.lr
                        self.optimizer.lr /= self.lr_factor
                        lr_curr = self.optimizer.lr
                        print(f'lr: {lr_prev} -> {lr_curr}')
                    # endif
                # endif
            # endif

            if show_msg:
                msg = ' '.join([f'epoch: {epoch}/{self.epochs_max}, patience: {patience}/{self.patience_max}',
                                f'[train] loss: {loss_train_best:.4f}, acc: {acc_train_best:.4f}',
                                f'[val] loss: {loss_val_best:.4f}, acc: {acc_val_best:.4f}',
                                ])
                print(f'{msg}')
            # endif

            # early stop
            if patience >= self.patience_max or epoch == (self.epochs_max - 1):
                print(f'Load back to epoch={epoch_best}(loss: {loss_val_best:.4f}, acc: {acc_val_best:.4f})')
                self.load_model(state_dict_best)

                self.freeze_mask_on_each_epoch(idx_task, epoch, is_final=True,
                                               dl_train=dl_train, dl_val=dl_val)
                break
            # endif
        # endfor
        time_end = time.time()
        time_consumed = time_end - time_start

        if isinstance(self, ParamConsumable):
            param_consumed = self.compute_param_consumed(idx_task)
        else:
            param_consumed = None
        # endif

        results = {
            'epoch': epoch,
            'time_consumed': time_consumed,
            'param_consumed': param_consumed,
            'loss_train': loss_train_best,
            'acc_train': acc_train_best,
            'loss_val': loss_val_best,
            'acc_val': acc_val_best,
            }

        return results
    # enddef

    def freeze_mask_on_each_epoch(self, idx_task: int, epoch: int, is_final: bool, **kwargs) -> None:
        """
        Hook called by ``train`` after every epoch (``is_final=False``) and once
        more after the best state has been restored (``is_final=True``).
        Does nothing in this class.
        """
        pass
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        """Call ``update_all(idx_task)`` on the TAG optimizer; ``kwargs`` are ignored."""
        self.optimizer.update_all(idx_task)
    # enddef

    def before_blocking_params(self, optimizer: optim.Optimizer):
        """Hook with no behaviour here; it is not called anywhere in this class."""
        pass
    # enddef

    def step_optimize(self, idx_task: int, optimizer: TAG, idx_batch: int):
        """
        Apply one TAG update to ``model.feature``.

        :param idx_task: Passed to the optimizer as ``task_id``
        :param optimizer: TAG optimizer to step
        :param idx_batch: Passed to the optimizer as ``step``
        """
        optimizer.step(model=self.model.feature, task_id=idx_task, step=idx_batch)
    # enddef

    def train_epoch(self, epoch: int, optimizer: TAG,
                    idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
                    args_on_forward: Dict[str, Any],
                    args_on_after_backward: Dict[str, Any]) -> Tuple[Dict[str, float], Dict]:
        """
        Run one pass over ``dl_train`` followed by an evaluation on ``dl_val``.

        :param epoch: Epoch index, forwarded to the model via ``args_on_forward``
        :param optimizer: TAG optimizer used for ``model.feature``
        :param idx_task: Index of the task being learned
        :param dl_train: Training data, yielding ``(x, y)`` batches
        :param dl_val: Validation data
        :param args_on_forward: Copied per batch, extended with ``epoch`` and
            ``idx_batch``, then passed to the model
        :param args_on_after_backward: Accepted for interface compatibility;
            not used by this method
        :return: ``(results, alpha_mean)`` where ``results`` holds
            ``loss_train`` (sum of the per-batch losses), ``acc_train``,
            ``loss_val`` and ``acc_val``, and ``alpha_mean`` is the dict
            maintained by ``store_alpha``
        :raises NotImplementedError: If ``self.tag_opt`` is neither ``'adam'``
            nor ``'rms'``
        """
        # train
        self.model.train()
        num_batch_train = len(dl_train)
        list__target_train, list__output_train = [], []
        loss_train = 0  # type: Tensor

        # NOTE(review): the classifier optimizer is rebuilt every epoch and
        # always starts from self.lr, so its internal state is reset and the
        # decay applied to self.optimizer.lr in train() does not reach it --
        # confirm this is intended.
        if self.tag_opt == 'adam':
            optimizer_clf = optim.Adam(self.model.fc.parameters(), lr=self.lr)
        elif self.tag_opt == 'rms':
            optimizer_clf = optim.RMSprop(self.model.fc.parameters(), lr=self.lr)
        else:
            raise NotImplementedError
        # endif

        alpha_mean = {}
        for idx_batch, (x, y) in enumerate(dl_train):
            x = x.to(self.device)  # type: Tensor
            y = y.to(self.device)
            # grows linearly from 1/smax towards smax over the epoch
            # (constant 1 while smax == 1)
            s = 1 / self.smax + (self.smax - 1 / self.smax) * idx_batch / num_batch_train

            args_fw = args_on_forward.copy()
            args_fw['epoch'] = epoch
            args_fw['idx_batch'] = idx_batch

            output, misc = self.model(idx_task, x, s=s, args_on_forward=args_fw)
            loss = self.compute_loss(output, y, misc)
            # Detach what is kept for reporting so that the autograd history
            # of every batch is not held until the end of the epoch.
            loss_train += loss.detach()
            list__target_train.append(y)
            list__output_train.append(output.detach())

            # optim
            optimizer.zero_grad()
            optimizer_clf.zero_grad()
            loss.backward()

            self.model.on_after_backward_emb(s=s)
            self.step_optimize(idx_task=idx_task, optimizer=optimizer, idx_batch=idx_batch)
            optimizer_clf.step()
            if idx_task >= 0:
                alpha_mean = store_alpha(tag_optimizer=optimizer, task_id=idx_task,
                                         iter=idx_batch, alpha_mean=alpha_mean)
            # endif
        # endfor | idx_batch

        acc_train = utils.my_accuracy(torch.cat(list__target_train, dim=0),
                                      torch.cat(list__output_train, dim=0)).item()

        # val
        results_val = self._eval_common(idx_task, dl_val, args_on_forward=args_on_forward)

        results = {
            'loss_train': loss_train.item(),
            'acc_train': acc_train,
            'loss_val': results_val['loss'],
            'acc_val': results_val['acc'],
            }

        return results, alpha_mean
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """
        Evaluate on ``dl_test``.

        :return: Dict with ``loss_test`` and ``acc_test``
        """
        results_test = self._eval_common(idx_task, dl_test, args_on_forward=args_on_forward)

        results = {
            'loss_test': results_test['loss'],
            'acc_test': results_test['acc'],
            }
        return results
    # enddef

    def _eval_common(self, idx_task: int, dl: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """
        Evaluate the model on ``dl`` in eval mode, without gradients, at ``s=smax``.

        :param idx_task: Task index passed to the model
        :param dl: Data to evaluate on, yielding ``(x, y)`` batches
        :param args_on_forward: Passed to the model unchanged
        :return: Dict with ``loss`` (sum of the per-batch losses) and ``acc``
        """
        self.model.eval()
        list__target, list__output = [], []
        loss = 0  # type: Tensor

        with torch.no_grad():
            for idx_batch, (x, y) in enumerate(dl):
                x = x.to(self.device)
                y = y.to(self.device)

                output, misc = self.model(idx_task, x, s=self.smax, args_on_forward=args_on_forward)
                loss += self.compute_loss(output, y, misc)
                list__target.append(y)
                list__output.append(output)
            # endfor | idx_batch
        # endwith
        acc = utils.my_accuracy(torch.cat(list__target, dim=0),
                                torch.cat(list__output, dim=0)).item()

        results = {
            'loss': loss.item(),
            'acc': acc,
            }
        return results
    # enddef

    def copy_model(self) -> Dict[str, Tensor]:
        """Return a deep copy of the model's ``state_dict``."""
        return deepcopy(self.model.state_dict())
    # enddef

    def load_model(self, state_dict: Dict[str, Tensor]) -> None:
        """Load a deep copy of ``state_dict`` into the model, leaving the argument untouched."""
        self.model.load_state_dict(deepcopy(state_dict))
    # enddef

    def save_object_as_artifact(self, obj: Any, filename: str):
        """
        Dump ``obj`` as JSON into a temporary directory and log it with mlflow.

        :param obj: JSON-serialisable object
        :param filename: File name of the artifact
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            filepath = os.path.join(tmp_dir, filename)
            with open(filepath, 'w') as fp:
                json.dump(obj, fp)
            # endwith
            # must run before the temporary directory is removed
            mlflow.log_artifact(filepath)
        # endwith
    # enddef


# endclass


def store_alpha(tag_optimizer, task_id, iter, alpha_mean=None):
    """
    Fold the optimizer's current alpha values for ``task_id`` into a running mean.

    For every key ``tau`` of ``tag_optimizer.alpha_add_[task_id]`` the entry
    ``alpha_mean[tau]`` is overwritten at step 0 and afterwards updated as
    ``(alpha_mean[tau] * iter + alphas) / (iter + 1)``.

    :param tag_optimizer: Object exposing ``alpha_add_[task_id][tau]``
    :param task_id: Current task identity
    :param iter: Current step in the epoch (0-based)
    :param alpha_mean: Running means from the previous steps, updated in
        place; a new dict is created when omitted
    :return: alpha_mean: Dictionary keyed by ``tau``
    """
    if alpha_mean is None:
        # the default used to be subscripted directly, which raised TypeError
        alpha_mean = {}
    # endif

    alphas_by_tau = tag_optimizer.alpha_add_[task_id]
    for tau in alphas_by_tau:
        alphas = alphas_by_tau[tau]
        if iter == 0 or tau not in alpha_mean:
            # first observation of tau (also covers a tau with no running
            # mean yet at a later step): seed the mean with the current value
            alpha_mean[tau] = alphas
        else:
            alpha_mean[tau] = (alpha_mean[tau] * iter + alphas) / (iter + 1)
        # endif
    # endfor

    return alpha_mean
