import json
import os
import tempfile
import time
from copy import deepcopy
from typing import *

import mlflow
import numpy as np
import torch
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader

import utils
from approaches.param_consumable import ParamConsumable
from utils import myprint as print


class AbstractAppr:
    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float,
                 epochs_max: int, patience_max: int,
                 smax: float, lamb: float):
        self.device = device

        # dataloader
        self.list__ncls = list__ncls
        self.inputsize = inputsize

        # variables
        self.lr = lr
        self.lr_factor = lr_factor
        self.lr_min = lr_min
        self.epochs_max = epochs_max
        self.patience_max = patience_max
        self.smax = smax
        self.lamb = lamb

        # misc
        self.criterion = nn.CrossEntropyLoss()
        self.model = NotImplemented  # type: nn.Module
    # enddef

    def compute_loss(self, output: Tensor, target: Tensor, misc: Dict[str, Any]) -> Tensor:
        reg = misc['reg']
        return self.criterion(output, target) + self.lamb * reg
    # enddef

    def on_after_train_epoch(self):
        pass
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              list__dl_test: List[DataLoader] = None,
              ) -> Dict[str, float]:
        # optimizer
        optimizer = optim.SGD(self.model.parameters(), lr=self.lr)
        # optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         mode='min',
                                                         factor=1.0 / self.lr_factor,
                                                         patience=max(self.patience_max - 1, 0),
                                                         min_lr=self.lr_min,
                                                         verbose=True,
                                                         )
        # metric
        patience = 0
        loss_val_best = np.inf
        acc_val_best = -np.inf
        loss_train_best = np.inf
        acc_train_best = -np.inf
        state_dict_best = self.copy_model()
        epoch_best = 0

        # learn by epoch
        time_start = time.time()
        for epoch in range(self.epochs_max):
            results_train = self.train_epoch(epoch=epoch,
                                             optimizer=optimizer,
                                             idx_task=idx_task,
                                             dl_train=dl_train, dl_val=dl_val,
                                             args_on_forward=args_on_forward,
                                             args_on_after_backward=args_on_after_backward)
            loss_train, acc_train = results_train['loss_train'], results_train['acc_train']
            loss_val, acc_val = results_train['loss_val'], results_train['acc_val']

            # precise_bn.update_bn_stats(self.model, dl_tmp)

            self.freeze_mask_on_each_epoch(idx_task, epoch, is_final=False,
                                           dl_train=dl_train, dl_val=dl_val)

            lr_curr = utils.get_current_lr(optimizer)
            # check and save
            show_msg = True
            if loss_val < loss_val_best:
                loss_train_best = loss_train
                acc_train_best = acc_train
                loss_val_best = loss_val
                acc_val_best = acc_val
                epoch_best = epoch
                state_dict_best = self.copy_model()
                patience = 0
            else:
                show_msg = False
                if lr_curr <= self.lr_min:
                    patience += 1
                else:
                    patience = 0
                # endif
            # endif

            if show_msg:
                msg = ' '.join([f'epoch: {epoch}/{self.epochs_max}, patience: {patience}/{self.patience_max}',
                                f'[train] loss: {loss_train_best:.4f}, acc: {acc_train_best:.4f}',
                                f'[val] loss: {loss_val_best:.4f}, acc: {acc_val_best:.4f}',
                                ])
                print(f'{msg}')
            # endif

            # early stop
            if patience >= self.patience_max or epoch == (self.epochs_max - 1):
                print(f'Load back to epoch={epoch_best}(loss: {loss_val_best:.4f}, acc: {acc_val_best:.4f})')
                self.load_model(state_dict_best)

                self.freeze_mask_on_each_epoch(idx_task, epoch, is_final=True,
                                               dl_train=dl_train, dl_val=dl_val)
                break
            # endif

            scheduler.step(loss_val)

            if lr_curr != utils.get_current_lr(optimizer):
                pass
            # endif
            if np.isnan(loss_val) or np.isnan(loss_train):
                print(f'Loaded model at epoch={epoch_best}')
                self.load_model(state_dict_best)
            # endif
        # endfor
        time_end = time.time()
        time_consumed = time_end - time_start

        '''
        if isinstance(self, ParamConsumable):
            param_consumed = self.compute_param_consumed(idx_task)
        else:
            param_consumed = None
        # endif
        '''

        results = {
            'epoch': epoch,
            'time_consumed': time_consumed,
            # 'param_consumed': param_consumed,
            'loss_train': loss_train_best,
            'acc_train': acc_train_best,
            'loss_val': loss_val_best,
            'acc_val': acc_val_best,
            }

        return results
    # enddef

    def freeze_mask_on_each_epoch(self, idx_task: int, epoch: int, is_final: bool, **kwargs) -> None:
        pass
    # enddef

    def complete_learning(self, idx_task: int, **kwargs) -> None:
        raise NotImplementedError
    # enddef

    def before_blocking_params(self, optimizer: optim.Optimizer):
        pass
    # enddef

    def step_optimize(self, idx_task: int, optimizer: optim.Optimizer, blocking: Dict[str, Tensor], x: Tensor, s: float):
        optimizer.step()
    # enddef

    def train_epoch(self, epoch: int, optimizer: optim.Optimizer,
                    idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
                    args_on_forward: Dict[str, Any],
                    args_on_after_backward: Dict[str, Any]) -> Dict[str, float]:
        """Run one training pass over ``dl_train`` followed by validation on ``dl_val``.

        For each batch the value ``s`` grows linearly with the batch index,
        starting at ``1 / smax`` for the first batch and approaching ``smax``
        towards the end of the epoch. After ``backward`` the hooks run in this
        order: ``model.on_after_backward_emb``, ``before_blocking_params``,
        ``model.on_after_backward_params`` and finally ``step_optimize``.

        Args:
            epoch: Index of the current epoch; added to the per-batch hook arguments.
            optimizer: Optimizer whose gradients are zeroed and stepped per batch.
            idx_task: Index of the task being learned.
            dl_train: Loader of training batches ``(x, y)``.
            dl_val: Loader of validation batches, evaluated with ``_eval_common``.
            args_on_forward: Base arguments for the forward pass (copied per batch).
            args_on_after_backward: Base arguments for the after-backward hook
                (copied per batch).

        Returns:
            A dict with ``'loss_train'`` (sum of the per-batch losses),
            ``'acc_train'``, ``'loss_val'`` and ``'acc_val'``.
        """
        # train
        self.model.train()
        num_batch_train = len(dl_train)
        list__target_train, list__output_train = [], []
        loss_train = 0  # type: Tensor

        for idx_batch, (x, y) in enumerate(dl_train):
            x = x.to(self.device)  # type: Tensor
            y = y.to(self.device)
            s = 1 / self.smax + (self.smax - 1 / self.smax) * idx_batch / num_batch_train

            # copy so the per-batch keys never leak into the caller's dict
            args_fw = args_on_forward.copy()
            args_fw['epoch'] = epoch
            args_fw['idx_batch'] = idx_batch

            output, misc = self.model(idx_task, x, s=s, args_on_forward=args_fw)
            loss = self.compute_loss(output, y, misc)

            # Detach everything kept for reporting: holding the attached tensors
            # would keep every batch's autograd graph alive until the epoch ends.
            loss_train += loss.detach()
            list__target_train.append(y)
            list__output_train.append(output.detach())

            # optim
            optimizer.zero_grad()
            loss.backward()
            args_bw = args_on_after_backward.copy()
            args_bw['epoch'] = epoch
            args_bw['idx_batch'] = idx_batch

            self.model.on_after_backward_emb(s=s)
            self.before_blocking_params(optimizer)
            blocking = self.model.on_after_backward_params(idx_task, s=s, args=args_bw)
            self.step_optimize(idx_task, optimizer, blocking, x, s)
        # endfor | idx_batch

        acc_train = utils.my_accuracy(torch.cat(list__target_train, dim=0),
                                      torch.cat(list__output_train, dim=0)).item()

        # val
        results_val = self._eval_common(idx_task, dl_val, args_on_forward=args_on_forward)

        results = {
            'loss_train': loss_train.item(),
            'acc_train': acc_train,
            'loss_val': results_val['loss'],
            'acc_val': results_val['acc'],
            }

        return results
    # enddef

    def test(self, idx_task: int, dl_test: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        results_test = self._eval_common(idx_task, dl_test, args_on_forward=args_on_forward)

        results = {
            'loss_test': results_test['loss'],
            'acc_test': results_test['acc'],
            }
        return results
    # enddef

    def _eval_common(self, idx_task: int, dl: DataLoader, args_on_forward: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate the model on ``dl`` without tracking gradients.

        The model is switched to eval mode and every batch is forwarded with
        ``s=self.smax``. The loss is the sum of the per-batch losses (it is not
        averaged); the accuracy is computed once over all batches together.

        Returns:
            A dict with the keys ``'loss'`` and ``'acc'``.
        """
        self.model.eval()
        seen_targets, seen_outputs = [], []
        total = 0  # type: Tensor

        with torch.no_grad():
            for inputs, labels in dl:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                out, misc = self.model(idx_task, inputs, s=self.smax, args_on_forward=args_on_forward)
                total = total + self.compute_loss(out, labels, misc)
                seen_targets.append(labels)
                seen_outputs.append(out)
            # endfor
        # endwith

        all_targets = torch.cat(seen_targets, dim=0)
        all_outputs = torch.cat(seen_outputs, dim=0)
        accuracy = utils.my_accuracy(all_targets, all_outputs).item()

        return {
            'loss': total.item(),
            'acc': accuracy,
            }
    # enddef

    def copy_model(self) -> Dict[str, Tensor]:
        return deepcopy(self.model.state_dict())
    # enddef

    def load_model(self, state_dict: Dict[str, Tensor]) -> None:
        self.model.load_state_dict(deepcopy(state_dict))
    # enddef

    def save_object_as_artifact(self, obj: Any, filename: str):
        """Serialise ``obj`` as JSON and log the file as an MLflow artifact.

        The file is written under ``filename`` inside a temporary directory
        and passed to ``mlflow.log_artifact`` before that directory is removed.

        Args:
            obj: Object to serialise; must be accepted by ``json.dump``.
            filename: Name of the file created inside the temporary directory.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            artifact_path = os.path.join(tmp_dir, filename)
            with open(artifact_path, 'w') as stream:
                json.dump(obj, stream)
            # endwith
            # log while the temporary directory still exists
            mlflow.log_artifact(artifact_path)
        # endwith
    # enddef
# endclass
