import collections
import copy
import logging
import os
import pickle

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from inclearn.lib import factory, herding, losses, network, schedulers, utils
from inclearn.lib.network import hook
from inclearn.models.finetune import Finetune

EPSILON = 1e-8

logger = logging.getLogger(__name__)


class ESMER(Finetune):
    """Incremental learner with an EMA teacher and loss-aware sample weighting.

    Two copies of the backbone are maintained:

    * ``self._network``: the working model, updated by the optimizer.
    * ``self._ema_network``: an exponential moving average of the working model.
      It is used to down-weight current-task samples with a high loss, to provide
      consistency targets on replayed memory samples, to choose which
      current-task samples are passed to ``inc_dataset.update_exemplar`` after a
      task, and to produce the predictions returned by ``_eval_task``.

    Keys read from ``args`` (defaults in parentheses):
        regularization (0.1): weight of the MSE consistency loss on memory samples.
        loss_margin (1.0): scales the running mean loss into the threshold used
            both to down-weight current-task samples during training and to
            exclude samples from the exemplar set after a task.
        loss_alpha (0.99): upper bound on the decay of the running loss statistics.
        std_margin (1.0): per-batch EMA losses above ``mean + std_margin * std`` are
            left out of the running statistics.
        ema_model_update (required): dict with ``"freq"`` (probability of an EMA
            update after each optimizer step) and ``"alpha"`` (upper bound on the
            EMA decay).
        warmup_phase (1): epochs whose index is below this value do not update the
            running loss statistics.
    """

    def __init__(self, args):
        super().__init__(args)

        # Teacher network, initialised as a copy of the working network.
        # NOTE(review): the copy is taken once, here; if the working network
        # changes its architecture between tasks (e.g. grows its classifier),
        # confirm the EMA network is updated accordingly -- the EMA update zips
        # parameter lists and would silently ignore any extra parameters.
        self._ema_network = copy.deepcopy(self._network).to(self._device)

        # Weight of the consistency (MSE) loss on memory samples.
        self._reg_weight = args.get("regularization", 0.1)

        # Loss-based sample weighting / filtering.
        self._loss_margin = args.get("loss_margin", 1.0)
        self._loss_alpha = args.get("loss_alpha", 0.99)
        self._std_margin = args.get("std_margin", 1.0)

        # EMA update schedule: dict with "freq" and "alpha".
        self._ema_update = args["ema_model_update"]
        self._warmup_phase = args.get("warmup_phase", 1)

        # Running estimates of the (outlier-filtered) EMA loss on current-task
        # samples. A mean of 0 is treated as "no statistics collected yet".
        self._loss_running_mean = 0
        self._loss_running_std = 0
        self._global_step = 0

    def _training_step(
        self, train_loader, initial_epoch, nb_epochs, record_bn=True, clipper=None
    ):
        """Train the working network for epochs ``initial_epoch`` .. ``nb_epochs - 1``.

        Args:
            train_loader: iterable of dicts holding the model inputs plus the
                ``"target"`` and ``"task_id"`` tensors (both are popped off).
            initial_epoch: first epoch index (inclusive).
            nb_epochs: end epoch index (exclusive).
            record_bn: unused; kept for interface compatibility.
            clipper: optional callable applied via ``Module.apply`` after each step.
        """
        if len(self._multiple_devices) > 1:
            logger.info("Duplicating model on {} gpus.".format(len(self._multiple_devices)))
            training_network = nn.DataParallel(self._network, self._multiple_devices)
        else:
            training_network = self._network

        for epoch in range(initial_epoch, nb_epochs):
            self._metrics = collections.defaultdict(float)
            self._epoch_percent = epoch / (nb_epochs - initial_epoch)

            prog_bar = tqdm(
                train_loader,
                disable=self._disable_progressbar,
                ascii=True,
                bar_format="{desc}: {percentage:3.0f}% | {n_fmt}/{total_fmt} | {rate_fmt}{postfix}"
            )
            # Batch counter; stays 0 when the loader yields nothing.
            i = 0
            for i, input_dict in enumerate(prog_bar, start=1):
                targets = input_dict.pop("target")
                task_id = input_dict.pop("task_id")

                self._forward_loss(
                    training_network,
                    input_dict,
                    targets,
                    task_id,
                    warmup_phase=epoch < self._warmup_phase,
                )

                if clipper:
                    training_network.apply(clipper)

                self._print_metrics(prog_bar, epoch, nb_epochs, i)

            # Metrics are averaged over ``i`` batches; skip empty epochs.
            if self._disable_progressbar and i > 0:
                self._print_metrics(None, epoch, nb_epochs, i)

            if self._scheduler:
                self._scheduler.step()

    def _forward_loss(
        self,
        training_network,
        inputs,
        targets,
        task_id,
        warmup_phase,
        **kwargs
    ):
        """Run one optimization step on a batch and return the total loss.

        Current-task samples get a cross-entropy loss, down-weighted by
        ``running_mean / ema_loss`` when the EMA loss reaches
        ``loss_margin * running_mean``. From the second task on, samples whose
        task id differs from ``self._task`` are treated as memory samples and
        get a plain cross-entropy loss plus an MSE consistency loss towards the
        EMA logits.

        Args:
            training_network: module used for the forward pass (the working
                network, or a ``DataParallel`` wrapper around it).
            inputs: dict of input tensors, moved to ``self._device``.
            targets: class labels for the batch.
            task_id: per-sample task indices for the batch.
            warmup_phase: when True, the running loss statistics are not updated.
            **kwargs: unused.

        Returns:
            The total loss tensor (already back-propagated and applied).

        Raises:
            ValueError: if the loss is NaN (checked before the optimizer step).
        """
        self._optimizer.zero_grad()

        inputs = {key: item.to(self._device) for key, item in inputs.items()}
        targets = targets.to(self._device)

        # The EMA network only provides targets and weights, so its forward pass
        # runs without autograd: no gradient may flow into its parameters.
        self._ema_network.train()
        with torch.no_grad():
            ema_outputs = self._ema_network(inputs)
        outputs = training_network(inputs)
        acc = (targets == outputs["logits"].argmax(dim=1)).float().mean().item()

        logits = outputs["logits"]
        ema_logits = ema_outputs["logits"]

        has_memory = False
        if self._task > 0:
            # Separate replayed memory samples (task id != current task) from
            # current-task samples with boolean masks, so the result does not
            # depend on how the loader orders samples inside the batch.
            mem_mask = (task_id != self._task).to(self._device)
            cur_mask = ~mem_mask
            has_memory = bool(mem_mask.any())
            mem_logits = logits[mem_mask]
            mem_ema_logits = ema_logits[mem_mask]
            mem_targets = targets[mem_mask]
            logits, ema_logits, targets = logits[cur_mask], ema_logits[cur_mask], targets[cur_mask]

        ema_loss = F.cross_entropy(ema_logits, targets, reduction="none")
        cur_loss = F.cross_entropy(logits, targets, reduction="none")

        loss = logits.new_zeros(())
        if cur_loss.numel() > 0:
            if self._loss_running_mean > 0:
                # Samples the EMA teacher finds hard relative to the running mean
                # are scaled by running_mean / ema_loss; the clamp only guards
                # against a zero loss when loss_margin <= 0.
                sample_weight = torch.where(
                    ema_loss >= self._loss_margin * self._loss_running_mean,
                    self._loss_running_mean / ema_loss.clamp(min=EPSILON),
                    torch.ones_like(ema_loss),
                )
                loss_ce = (sample_weight * cur_loss).mean()
            else:
                # No running statistics yet: unweighted cross-entropy.
                loss_ce = cur_loss.mean()

            self._metrics["ce"] += loss_ce.item()
            loss = loss + loss_ce

        # Skipped when the batch has no memory samples, which would otherwise
        # yield NaN from losses over empty tensors.
        if has_memory:
            loss_con = self._reg_weight * F.mse_loss(mem_logits, mem_ema_logits)
            loss_mem = F.cross_entropy(mem_logits, mem_targets)

            self._metrics["con"] += loss_con.item()
            self._metrics["old"] += loss_mem.item()

            loss = loss + loss_con + loss_mem

        # Check before stepping so a NaN loss cannot corrupt the weights.
        if bool(torch.isnan(loss).item()):
            raise ValueError("A loss is NaN: {}".format(self._metrics))

        loss.backward()
        self._optimizer.step()

        self._global_step += 1
        if torch.rand(1) < self._ema_update["freq"]:
            self._update_ema_model_variables()

        if not warmup_phase:
            self._update_running_loss_ema(self._filter_outlier_losses(ema_loss))

        self._metrics["loss"] += loss.item()
        self._metrics["acc"] += acc

        return loss

    def _filter_outlier_losses(self, losses):
        """Drop per-sample losses above ``mean + std_margin * std`` of the batch.

        Batches with fewer than two elements are returned unchanged, since their
        (unbiased) standard deviation is undefined.
        """
        if losses.numel() < 2:
            return losses
        threshold = losses.mean() + self._std_margin * losses.std()
        return losses[~(losses > threshold)]

    def _update_ema_model_variables(self):
        """Move each EMA parameter towards the matching working-network parameter.

        The decay is ``min(1 - 1 / (global_step + 1), ema_model_update["alpha"])``,
        so early steps follow the working network more closely. Only the parameter
        groups returned by ``get_group_parameters()`` are averaged.
        """
        alpha = min(1 - 1 / (self._global_step + 1), self._ema_update["alpha"])
        ema_parameters = self._ema_network.get_group_parameters()
        parameters = self._network.get_group_parameters()

        with torch.no_grad():
            for modality in ema_parameters.keys():
                for ema_param, param in zip(ema_parameters[modality], parameters[modality]):
                    ema_param.mul_(alpha).add_(param, alpha=1 - alpha)

    def _update_running_loss_ema(self, batch_loss):
        """Fold a batch of per-sample losses into the running mean/std estimates.

        Empty batches are ignored, and the std estimate is only updated when at
        least two losses are available, so the estimates never become NaN.
        """
        batch_loss = batch_loss.detach()
        if batch_loss.numel() == 0:
            return

        alpha = min(1 - 1 / (self._global_step + 1), self._loss_alpha)
        self._loss_running_mean = alpha * self._loss_running_mean + (1 - alpha) * batch_loss.mean()
        if batch_loss.numel() > 1:
            self._loss_running_std = alpha * self._loss_running_std + (1 - alpha) * batch_loss.std()

    def _after_task_intensive(self, inc_dataset):
        """Hand the current-task samples that the EMA network fits well to the exemplar store.

        A training sample is kept when its EMA cross-entropy loss does not exceed
        ``loss_margin * running_mean``. If no running statistics were collected
        (running mean still 0), every sample is kept, matching ``_forward_loss``,
        which also skips loss-based weighting in that case.
        """
        self._ema_network.eval()
        use_threshold = bool(self._loss_running_mean > 0)
        threshold = self._loss_margin * self._loss_running_mean

        save_mask = []
        with torch.no_grad():
            for input_dict in inc_dataset.get_cur_train_loader(shuffle=False, num_workers=8):
                targets = input_dict["target"].to(self._device)
                inputs = {
                    key: item.to(self._device)
                    for key, item in input_dict.items()
                    if key not in ("target", "task_id")
                }
                ema_logits = self._ema_network(inputs)["logits"]
                ema_loss = F.cross_entropy(ema_logits, targets, reduction="none")
                if use_threshold:
                    keep = ~(ema_loss > threshold)
                else:
                    keep = torch.ones_like(ema_loss, dtype=torch.bool)
                save_mask.append(keep.cpu().numpy())

        save_mask = np.concatenate(save_mask) if save_mask else np.zeros(0, dtype=bool)

        inc_dataset.update_exemplar(np.where(save_mask)[0])

    def save_parameters(self, directory, run_id):
        """Save the base-class parameters, then the EMA network for this task."""
        super().save_parameters(directory, run_id)

        path = os.path.join(directory, f"ema_net_{run_id}_task_{self._task}.pth")
        self._ema_network.save(path)

    def load_parameters(self, directory, run_id):
        """Load the base-class parameters, then the EMA network if its file exists.

        Raises:
            ValueError: if the EMA weight file exists but cannot be loaded
                (the original error is chained as the cause).
        """
        super().load_parameters(directory, run_id)

        path = os.path.join(directory, f"ema_net_{run_id}_task_{self._task}.pth")
        if not os.path.exists(path):
            return

        try:
            self._ema_network.load(path)
        except Exception as err:
            raise ValueError("Cannot load weights") from err

    def save_metadata(self, directory, run_id):
        """Pickle the running loss statistics and the global step for this task."""
        path = os.path.join(directory, f"meta_{run_id}_task_{self._task}.pkl")

        logger.info("Saving metadata at {}.".format(path))
        # Store plain floats rather than (possibly CUDA) tensors so the file can
        # be loaded without the original device being available.
        with open(path, "wb") as f:
            pickle.dump(
                [
                    float(self._loss_running_mean),
                    float(self._loss_running_std),
                    self._global_step,
                ],
                f
            )

    def load_metadata(self, directory, run_id):
        """Restore the running loss statistics and global step if the file exists."""
        path = os.path.join(directory, f"meta_{run_id}_task_{self._task}.pkl")
        if not os.path.exists(path):
            return

        logger.info("Loading metadata at {}.".format(path))
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # metadata files produced by save_metadata from a trusted location.
        with open(path, "rb") as f:
            self._loss_running_mean, self._loss_running_std, self._global_step = pickle.load(
                f
            )

    def _eval_task(self, data_loader):
        """Predict class probabilities for every batch with the EMA network.

        Returns:
            Tuple ``(ypred, ytrue, zid)`` of numpy arrays concatenated over the
            loader: softmax probabilities, targets, and task ids.
        """
        self._ema_network.eval()
        ypred, ytrue, zid = [], [], []

        with torch.no_grad():
            for input_dict in data_loader:
                ytrue.append(input_dict.pop("target").numpy())
                zid.append(input_dict.pop("task_id").numpy())

                inputs = {key: item.to(self._device) for key, item in input_dict.items()}
                logits = self._ema_network(inputs)["logits"]
                ypred.append(F.softmax(logits, dim=-1).cpu().numpy())

        return np.concatenate(ypred), np.concatenate(ytrue), np.concatenate(zid)




