import math
import time
from argparse import Namespace
from copy import deepcopy
from typing import *

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader

import utils
from approaches.ucl import ucl_utils
from approaches.ucl.bayes_layer import BayesianConv2D, BayesianLinear, _calculate_fan_in_and_fan_out
from utils import print_num_params


class Appr(object):
    def __init__(self, model: nn.Module, inputsize: Tuple[int, ...], device: str, list__dl_val: List[DataLoader],
                 epochs_max: int, sbatch: int, patience_max: int,
                 lr: float, lr_min: float, lr_factor: float,
                 lamb: float, log_name: str, args: Namespace,
                 log_dir: str,
                 clipgrad=100, split=True):
        """Store the training configuration and snapshot the initial model.

        Args:
            model: Network to train. A deep copy of it, moved to ``device``, is
                kept as ``self.model_old``.
            inputsize: Shape of one input sample. ``custom_regularization``
                reads its first entry (convolutional nets) or the product of
                its first three entries (otherwise).
            device: Device that batches and the model are moved to.
            list__dl_val: Validation loaders indexed by task id; ``train`` uses
                them to re-evaluate earlier tasks after every epoch.
            epochs_max: Maximum number of epochs per ``train`` call.
            sbatch: Batch size used by ``train_epoch`` to scale the loss.
            patience_max: Stored as ``self.lr_patience``; drives both the LR
                scheduler patience and the early-stopping test in ``train``.
            lr: Initial learning rate.
            lr_min: Lower bound for the learning rate.
            lr_factor: The scheduler multiplies the learning rate by
                ``1 / lr_factor`` on a plateau.
            lamb: Weight of the penalty computed in ``criterion``.
            log_name: File name handed to ``ucl_utils.logger``.
            args: Extra hyper-parameters. ``args.beta`` is read here;
                ``custom_regularization`` additionally reads ``args.alpha``,
                ``args.conv_net`` and ``args.ratio``.
            log_dir: Directory handed to ``ucl_utils.logger`` as ``path``.
            clipgrad: Maximum gradient norm applied in ``train_epoch``.
            split: When true, the model output is indexed by task id and
                passed through ``log_softmax`` in ``train_epoch`` / ``eval``.
        """
        # Environment and data.
        self.device, self.inputsize = device, inputsize
        self.list__dl_val = list__dl_val

        # Live model plus the snapshot used as regularisation reference.
        self.model = model
        self.model_old = deepcopy(self.model).to(self.device)

        self.logger = ucl_utils.logger(file_name=log_name, resume=False, path=log_dir, data_format='csv')

        # Optimisation schedule.
        self.nepochs, self.sbatch = epochs_max, sbatch
        self.lr, self.lr_min, self.lr_factor = lr, lr_min, lr_factor
        self.lr_patience = patience_max
        self.clipgrad = clipgrad

        # Bookkeeping; `saved` becomes 1 once `train` has finished a task.
        self.args = args
        self.iteration, self.epoch, self.saved = 0, 0, 0
        self.split = split
        self.beta = args.beta

        self.param_name = [name for name, _ in self.model.named_parameters()]

        self.lamb = lamb

        print_num_params(self.model)
    # enddef

    def _get_optimizer(self, lr: float):
        return torch.optim.SGD(self.model.parameters(), lr=lr)
        # if lr is None:
        #     lr = self.lr
        # if lr_rho is None:
        #     lr_rho = self.lr_rho
        # if args.optimizer == 'Adam':
        #     return Adam(self.model.parameters(), lr=lr, lr_rho=lr_rho, param_name=self.param_name)
        # if args.optimizer == 'SGD':
        #     return torch.optim.SGD(self.model.parameters(), lr=lr)
    # enddef

    def train(self, t: int, dl_train: DataLoader, dl_val: DataLoader):
        """Train the model on task ``t`` and keep the weights with the lowest validation loss.

        Every epoch runs one optimisation pass over ``dl_train``, evaluates on
        ``dl_train`` and ``dl_val``, logs the validation metrics of task ``t``
        and of each earlier task (through ``self.list__dl_val``) and steps a
        ``ReduceLROnPlateau`` scheduler on the validation loss. The loop ends
        after ``self.nepochs`` epochs, or earlier once the learning rate has
        reached ``self.lr_min`` and ``self.lr_patience`` consecutive epochs
        brought no improvement. Afterwards the best weights are restored,
        ``self.model_old`` becomes a copy of them, ``self.saved`` is set to 1
        and the log is written out.

        Args:
            t: Zero-based id of the task being learned.
            dl_train: Training batches for task ``t``.
            dl_val: Validation batches for task ``t``.
        """
        self.model = self.model.to(self.device)

        best_loss = np.inf
        best_model = ucl_utils.get_model(self.model)
        patience = 0
        self.optimizer = self._get_optimizer(self.lr)

        # NOTE(review): `verbose` is reported as deprecated in recent PyTorch
        # releases -- confirm against the pinned version before upgrading.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                         mode='min',
                                                         factor=1.0 / self.lr_factor,
                                                         patience=max(self.lr_patience - 1, 0),
                                                         min_lr=self.lr_min,
                                                         verbose=True,
                                                         )

        # Freeze the reference snapshot before the first batch, not after the
        # first epoch; presumably this only disables gradients on it -- verify
        # against ucl_utils.freeze_model.
        ucl_utils.freeze_model(self.model_old)

        # Number of batches per pass; guarded so the timing report cannot
        # divide by zero on an empty loader.
        num_batch = max(len(dl_train), 1)

        # Loop epochs
        for e in range(self.nepochs):
            # Running epoch counter across all `train` calls (all tasks).
            self.epoch = self.epoch + 1

            # Train
            clock0 = time.time()
            self.train_epoch(t, dl_train)
            clock1 = time.time()
            train_loss, train_acc = self.eval(t, dl_train)
            clock2 = time.time()

            # Report average milliseconds per batch for training / evaluation.
            print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format(
                e + 1, 1000 * (clock1 - clock0) / num_batch,
                1000 * (clock2 - clock1) / num_batch, train_loss, 100 * train_acc), end='')

            # Valid
            valid_loss, valid_acc = self.eval(t, dl_val)
            print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss, 100 * valid_acc), end='')

            # save log for current task & old tasks at every epoch
            self.logger.add(epoch=(t * self.nepochs) + e, task_num=t + 1, valid_loss=valid_loss, valid_acc=valid_acc)
            for task in range(t):
                dl_valid_t = self.list__dl_val[task]

                valid_loss_t, valid_acc_t = self.eval(task, dl_valid_t)
                self.logger.add(epoch=(t * self.nepochs) + e, task_num=task + 1, valid_loss=valid_loss_t,
                                valid_acc=valid_acc_t)
            # endfor

            # The first epoch of *this* call always becomes the initial best.
            # The per-call index `e` is used because `self.epoch` keeps
            # growing across tasks.
            if e == 0 or valid_loss < best_loss:
                best_loss = valid_loss
                best_model = ucl_utils.get_model(self.model)
                patience = 0
                print(' *', end='')
            else:
                # Only count stale epochs once the scheduler has driven the
                # learning rate down to its floor.
                if utils.get_current_lr(self.optimizer) <= self.lr_min:
                    patience += 1
                else:
                    patience = 0
                # endif
            # endif

            # Stop on exhausted patience or after the last epoch of this call.
            if patience >= self.lr_patience or e + 1 == self.nepochs:
                print()  # terminate the progress line before leaving the loop
                break
            # endif

            scheduler.step(valid_loss)
            print()
        # endfor

        # Restore best
        ucl_utils.set_model_(self.model, best_model)
        self.model_old = deepcopy(self.model)
        self.saved = 1

        self.logger.save()

        return
    # enddef

    def train_epoch(self, t: int, dl_train: DataLoader):
        self.model.train()

        # Loop batches
        for x, y in dl_train:
            images = x.to(self.device)
            targets = y.to(self.device)

            if self.split:
                output = F.log_softmax(self.model(images, sample=True)[t], dim=1)
            else:
                output = self.model(images, sample=True)
            # endif

            loss = F.nll_loss(output, targets, reduction='sum')
            loss = self.custom_regularization(self.model_old, self.model, self.sbatch, loss)
            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            # if args.optimizer == 'SGD' or args.optimizer == 'SGD_momentum_decay':
            torch.nn.utils.clip_grad_norm(self.model.parameters(), self.clipgrad)
            self.optimizer.step()
        # endfor

        return
    # enddef

    def eval(self, t: int, dl: DataLoader):
        total_loss = 0
        total_acc = 0
        total_num = 0
        self.model.eval()

        # Loop batches
        for x, y in dl:
            images = x.to(self.device)
            targets = y.to(self.device)

            if self.split:
                output = F.log_softmax(self.model(images, sample=False)[t], dim=1)
            else:
                output = self.model(images, sample=False)
            # endif

            loss = F.nll_loss(output, targets, reduction='sum')

            _, pred = output.max(1)
            hits = (pred == targets).float()

            total_loss += loss.data.cpu().numpy()
            total_acc += hits.sum().data.cpu().numpy()
            total_num += targets.shape[0]
        # endfor

        return total_loss / total_num, total_acc / total_num
    # enddef

    def criterion(self, t, output, targets):
        # Regularization for all previous tasks
        loss_reg = 0
        if t > 0:
            for (name, param), (_, param_old) in zip(self.model.named_parameters(), self.model_old.named_parameters()):
                loss_reg += torch.sum(self.fisher[name] * (param_old - param).pow(2)) / 2
            # endfor
        # endif

        return self.ce(output, targets) + self.lamb * loss_reg
    # enddef

    # custom regularization
    def custom_regularization(self, saver_net, trainer_net, mini_batch_size, loss=None):
        sigma_weight_reg_sum = 0
        sigma_bias_reg_sum = 0
        sigma_weight_normal_reg_sum = 0
        sigma_bias_normal_reg_sum = 0
        mu_weight_reg_sum = 0
        mu_bias_reg_sum = 0
        L1_mu_weight_reg_sum = 0
        L1_mu_bias_reg_sum = 0

        out_features_max = 512
        alpha = self.args.alpha
        if self.saved:
            alpha = 1
        # endif

        if self.args.conv_net:
            # prev_weight_strength = nn.Parameter(torch.Tensor(3, 1, 1, 1).uniform_(0, 0)).to(self.device)
            prev_weight_strength = nn.Parameter(torch.Tensor(self.inputsize[0], 1, 1, 1).uniform_(0, 0)).to(self.device)

            # if self.args.experiment == 'omniglot':
            #     prev_weight_strength = nn.Parameter(torch.Tensor(1, 1, 1, 1).uniform_(0, 0))
            # else:
            #     prev_weight_strength = nn.Parameter(torch.Tensor(3, 1, 1, 1).uniform_(0, 0))
            # # endif
        else:
            # prev_weight_strength = nn.Parameter(torch.Tensor(28 * 28, 1).uniform_(0, 0)).to(self.device)
            dimtotal = self.inputsize[0] * self.inputsize[1] * self.inputsize[2]
            prev_weight_strength = nn.Parameter(torch.Tensor(dimtotal, 1).uniform_(0, 0)).to(self.device)
        # endif

        for (_, saver_layer), (_, trainer_layer) in zip(saver_net.named_children(), trainer_net.named_children()):
            if isinstance(trainer_layer, BayesianLinear) == False and isinstance(trainer_layer, BayesianConv2D) == False:
                continue
            # endif

            # calculate mu regularization
            trainer_weight_mu = trainer_layer.weight_mu
            saver_weight_mu = saver_layer.weight_mu
            trainer_bias = trainer_layer.bias
            saver_bias = saver_layer.bias

            fan_in, fan_out = _calculate_fan_in_and_fan_out(trainer_weight_mu)

            trainer_weight_sigma = torch.log1p(torch.exp(trainer_layer.weight_rho))
            saver_weight_sigma = torch.log1p(torch.exp(saver_layer.weight_rho))

            if isinstance(trainer_layer, BayesianLinear):
                std_init = math.sqrt((2 / fan_in) * self.args.ratio)
            # endif
            if isinstance(trainer_layer, BayesianConv2D):
                std_init = math.sqrt((2 / fan_out) * self.args.ratio)
            # endif

            saver_weight_strength = (std_init / saver_weight_sigma)

            if len(saver_weight_mu.shape) == 4:
                out_features, in_features, _, _ = saver_weight_mu.shape
                curr_strength = saver_weight_strength.expand(out_features, in_features, 1, 1)
                prev_strength = prev_weight_strength.permute(1, 0, 2, 3).expand(out_features, in_features, 1, 1)
            else:
                out_features, in_features = saver_weight_mu.shape
                curr_strength = saver_weight_strength.expand(out_features, in_features)
                if len(prev_weight_strength.shape) == 4:
                    feature_size = in_features // (prev_weight_strength.shape[0])
                    prev_weight_strength = prev_weight_strength.reshape(prev_weight_strength.shape[0], -1)
                    prev_weight_strength = prev_weight_strength.expand(prev_weight_strength.shape[0], feature_size)
                    prev_weight_strength = prev_weight_strength.reshape(-1, 1)
                # endif

                prev_strength = prev_weight_strength.permute(1, 0).expand(out_features, in_features)
            # endif

            L2_strength = torch.max(curr_strength, prev_strength)
            bias_strength = torch.squeeze(saver_weight_strength)

            L1_sigma = saver_weight_sigma
            bias_sigma = torch.squeeze(saver_weight_sigma)

            prev_weight_strength = saver_weight_strength

            mu_weight_reg = (L2_strength * (trainer_weight_mu - saver_weight_mu)).norm(2) ** 2
            mu_bias_reg = (bias_strength * (trainer_bias - saver_bias)).norm(2) ** 2

            L1_mu_weight_reg = (torch.div(saver_weight_mu ** 2, L1_sigma ** 2) * (trainer_weight_mu - saver_weight_mu)).norm(1)
            L1_mu_bias_reg = (torch.div(saver_bias ** 2, bias_sigma ** 2) * (trainer_bias - saver_bias)).norm(1)

            L1_mu_weight_reg = L1_mu_weight_reg * (std_init ** 2)
            L1_mu_bias_reg = L1_mu_bias_reg * (std_init ** 2)

            weight_sigma = (trainer_weight_sigma ** 2 / saver_weight_sigma ** 2)

            normal_weight_sigma = trainer_weight_sigma ** 2

            sigma_weight_reg_sum = sigma_weight_reg_sum + (weight_sigma - torch.log(weight_sigma)).sum()
            sigma_weight_normal_reg_sum = sigma_weight_normal_reg_sum + (normal_weight_sigma - torch.log(normal_weight_sigma)).sum()

            mu_weight_reg_sum = mu_weight_reg_sum + mu_weight_reg
            mu_bias_reg_sum = mu_bias_reg_sum + mu_bias_reg
            L1_mu_weight_reg_sum = L1_mu_weight_reg_sum + L1_mu_weight_reg
            L1_mu_bias_reg_sum = L1_mu_bias_reg_sum + L1_mu_bias_reg
        # endfor

        # elbo loss
        loss = loss / mini_batch_size
        # L2 loss
        loss = loss + alpha * (mu_weight_reg_sum + mu_bias_reg_sum) / (2 * mini_batch_size)
        # L1 loss
        loss = loss + self.saved * (L1_mu_weight_reg_sum + L1_mu_bias_reg_sum) / (mini_batch_size)
        # sigma regularization
        loss = loss + self.beta * (sigma_weight_reg_sum + sigma_weight_normal_reg_sum) / (2 * mini_batch_size)

        return loss
    # enddef
