import collections
from collections import OrderedDict
import copy
import logging
import os
import pickle

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from inclearn.lib import factory, herding, losses, network, schedulers, utils
from inclearn.lib.network import hook
from inclearn.lib.network.mlp import MLP
from inclearn.models.finetune import Finetune

# Small constant for numerically safe divisions (exported module constant).
EPSILON = 1.0e-8

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class Ours(Finetune):
    """Multi-modal incremental learner built on top of ``Finetune``.

    Besides the cross-entropy objective it can add:

    * a supervised contrastive loss between each modality's projected
      features and the (normalized) fused features (``contrastive_loss``
      config), plus a KL term that keeps the pairwise-similarity
      distribution of old-task samples close to the EMA teacher's
      (``past_lambda``);
    * an MSE distillation of old-task logits towards the EMA teacher
      (``distil_loss`` config);
    * an exponential-moving-average copy of the network
      (``ema_model_update`` config), which is also used for evaluation.
    """

    def __init__(self, args):
        super().__init__(args)

        # Number of optimizer steps taken so far; drives the EMA warm-up.
        self._global_step = 0

        # EMA teacher, only built when an ``ema_model_update`` config is given.
        self._ema_update = args.get("ema_model_update", None)
        if self._ema_update:
            self._ema_network = copy.deepcopy(self._network).to(self._device)
        else:
            self._ema_network = None

        # NOTE(review): the default contrastive config has no "temperature"
        # key, which the contrastive loss requires -- confirm callers always
        # supply one when lambda > 0.
        self._contra_loss = args.get("contrastive_loss", {'lambda': 1.0})
        self._distil_loss = args.get("distil_loss", {'lambda': 1.0})

        # Per-modality projection heads whose outputs are contrasted against
        # the fused features; only needed by the contrastive loss.
        self._forward_projection = {}
        if self._contra_loss['lambda'] > 0:
            for modality in self._network.modalities:
                dim = self._network.encoders[modality].out_dim
                self._forward_projection[modality] = MLP(
                    input_dim=dim,
                    hidden_dims=[2048, self._network.features_dim],
                    use_bn=False,
                    normalize=True,
                ).to(self._device)

    def _resolve_lr_factor(self, factor):
        """Return the learning-rate factor for the current task.

        A list factor is read as ``[factor_for_task_0, factor_for_later_tasks]``;
        any other value is returned unchanged.
        """
        if isinstance(factor, list):
            return factor[0] if self._task == 0 else factor[1]
        return factor

    def _before_task(self):
        """Build the optimizer and LR scheduler for the upcoming task.

        Each parameter group from ``get_group_parameters`` gets
        ``lr * factor``, with ``factor`` looked up in ``self._groupwise_factors``
        (default 1.0). Groups whose factor is 0 are left out of the optimizer.
        When the contrastive loss is enabled, each projection head gets its own
        group scaled by the ``"<modality>_forward_project"`` factor.
        """
        # ``_groupwise_factors`` may be None/empty: every factor is then 1.0.
        factors = self._groupwise_factors or {}

        params = []
        for group_name, group_params in self._network.get_group_parameters().items():
            factor = self._resolve_lr_factor(factors.get(group_name, 1.0))
            if factor == 0.:
                continue
            params.append({"params": group_params, "lr": self._lr * factor})
            logger.info(f"Group: {group_name}, lr: {self._lr * factor}.")

        if self._contra_loss['lambda'] > 0:
            for modality in self._network.modalities:
                factor = self._resolve_lr_factor(
                    factors.get(f"{modality}_forward_project", 1.0)
                )
                params.append({
                    "params": self._forward_projection[modality].parameters(),
                    "lr": self._lr * factor,
                })

        self._optimizer = factory.get_optimizer(
            params, self._opt_name, self._lr, self._weight_decay
        )

        self._scheduler = factory.get_lr_scheduler(
            self._scheduling,
            self._optimizer,
            nb_epochs=self._n_epochs,
            lr_decay=self._lr_decay,
            task=self._task
        )

    def _training_step(
        self, train_loader, initial_epoch, nb_epochs, record_bn=True, clipper=None
    ):
        """Train for the epochs ``[initial_epoch, nb_epochs)``.

        After every optimizer step the EMA teacher (if any) is updated, then the
        optional ``clipper`` is applied to the training network. The scheduler
        is stepped once per epoch. ``record_bn`` is accepted for interface
        compatibility and is not used here.
        """
        if len(self._multiple_devices) > 1:
            logger.info("Duplicating model on {} gpus.".format(len(self._multiple_devices)))
            training_network = nn.DataParallel(self._network, self._multiple_devices)
        else:
            training_network = self._network

        if self._ema_network is not None:
            self._ema_network.train()

        for epoch in range(initial_epoch, nb_epochs):
            self._metrics = collections.defaultdict(float)

            self._epoch_percent = epoch / (nb_epochs - initial_epoch)

            prog_bar = tqdm(
                train_loader,
                disable=self._disable_progressbar,
                ascii=True,
                bar_format="{desc}: {percentage:3.0f}% | {n_fmt}/{total_fmt} | {rate_fmt}{postfix}"
            )
            # Stays 0 when the loader yields nothing.
            n_batches = 0
            for n_batches, input_dict in enumerate(prog_bar, start=1):
                targets = input_dict.pop("target")
                task_id = input_dict.pop("task_id")
                inputs = input_dict

                self._optimizer.zero_grad()
                loss = self._forward_loss(
                    training_network,
                    inputs,
                    targets,
                    task_id
                )
                loss.backward()
                self._optimizer.step()

                self._global_step += 1
                if self._ema_network is not None:
                    self._update_ema_model_variables()

                if clipper:
                    training_network.apply(clipper)

                self._print_metrics(prog_bar, epoch, nb_epochs, n_batches)

            # Nothing to report (and no batch count to average by) for an empty loader.
            if self._disable_progressbar and n_batches:
                self._print_metrics(None, epoch, nb_epochs, n_batches)

            if self._scheduler:
                self._scheduler.step()

    def _compute_loss(self, inputs, outputs, targets, task_id):
        """Return the total training loss for one batch.

        Components, each added only when enabled:

        * cross-entropy on ``outputs["logits"]`` (metric ``clf``);
        * the contrastive term, optionally including the old-task similarity
          KL (metrics ``con`` / ``con_old``);
        * MSE logit distillation on old-task samples against the EMA teacher
          (metric ``dis``).
        """
        old_outputs_mem = None
        if self._ema_network is not None:
            # Teacher outputs are only used as detached targets, so no autograd
            # graph is needed. The forward still runs in the teacher's current
            # mode (train mode during ``_training_step``).
            with torch.no_grad():
                old_outputs_mem = self._ema_network(inputs)

        loss = F.cross_entropy(outputs["logits"], targets)
        self._metrics["clf"] += loss.item()

        if self._contra_loss['lambda'] > 0:
            loss = loss + self._contrastive_loss(outputs, targets, task_id, old_outputs_mem)

        if self._task > 0 and self._distil_loss['lambda'] > 0 and old_outputs_mem is not None:
            loss_dis = self._distillation_loss(outputs, task_id, old_outputs_mem)
            if loss_dis is not None:
                self._metrics["dis"] += loss_dis.item()
                loss = loss + loss_dis

        return loss

    def _contrastive_loss(self, outputs, targets, task_id, old_outputs_mem):
        """Weighted contrastive term: per-modality sup-con plus optional old-task KL.

        The old-task KL is only added for ``task > 0`` when ``past_lambda``
        (default 0, i.e. disabled) is positive and an EMA teacher exists.
        """
        temperature = self._contra_loss['temperature']
        fused = F.normalize(outputs["features_fused"], dim=1)

        loss_contrastive = 0
        for modality in self._network.modalities:
            forward_proj = self._forward_projection[modality](outputs[f"features_{modality}"])
            # Two views per sample: the projected modality feature and the fused feature.
            fused_proj = torch.stack([forward_proj, fused], dim=1)
            loss_contrastive += losses.sup_con_current(
                fused_proj, task_id == self._task, labels=targets, contrast_mode='one',
                temperature=temperature, base_temperature=temperature,
            )
        loss_contrastive /= len(self._network.modalities)

        past_lambda = self._contra_loss.get('past_lambda', 0.)
        if self._task > 0 and past_lambda > 0 and old_outputs_mem is not None:
            loss_old = self._past_similarity_loss(fused, task_id, old_outputs_mem, temperature)
            if loss_old is not None:
                loss_old = loss_old * (
                    past_lambda * (self._old_seen_classes / self._seen_classes)
                )
                loss_contrastive += loss_old
                self._metrics["con_old"] += loss_old.item()

        loss_contrastive *= self._contra_loss['lambda']
        self._metrics["con"] += loss_contrastive.item()
        return loss_contrastive

    def _past_similarity_loss(self, fused, task_id, old_outputs_mem, temperature):
        """KL between student and EMA-teacher similarity distributions on old-task rows.

        For every old-task sample, its temperature-scaled similarities to the
        other old-task samples (self-similarity removed) are softmax-normalized.
        The student's log-distribution is matched to the teacher's.

        Returns:
            The unweighted KL (``batchmean``), or ``None`` when the batch holds
            fewer than two old-task samples (no off-diagonal pairs exist).
        """
        old_task = task_id < self._task
        n_old = int(old_task.sum())
        if n_old < 2:
            return None

        # True everywhere except the diagonal: drops each sample's self-similarity.
        off_diag = ~torch.eye(n_old, dtype=torch.bool, device=fused.device)

        ema_fused = F.normalize(old_outputs_mem['features_fused'][old_task], dim=1)
        ema_fused_sim = (ema_fused @ ema_fused.T / temperature)[off_diag].view(n_old, -1)

        cur_fused = fused[old_task]
        cur_fused_sim = (cur_fused @ cur_fused.T / temperature)[off_diag].view(n_old, -1)

        # NOTE: no temperature**2 rescaling is applied to this KL term.
        return F.kl_div(
            F.log_softmax(cur_fused_sim, dim=1),
            F.softmax(ema_fused_sim, dim=1).detach(),
            reduction='batchmean',
        )

    def _distillation_loss(self, outputs, task_id, old_outputs_mem):
        """MSE between student and EMA-teacher logits on old-task samples.

        The term is scaled by ``distil_loss["lambda"] * sqrt(task)``. Old-task
        samples are selected with a mask, so their position in the batch does
        not matter.

        Returns:
            The weighted loss, or ``None`` when the batch holds no old-task samples.
        """
        old_mask = task_id != self._task
        if not bool(old_mask.any()):
            return None
        weight = self._distil_loss['lambda'] * np.sqrt(self._task)
        return weight * F.mse_loss(
            outputs["logits"][old_mask], old_outputs_mem["logits"][old_mask].detach()
        )

    def _update_ema_model_variables(self):
        """Blend the online network's parameters into the EMA teacher in place.

        ``ema = alpha * ema + (1 - alpha) * online``, where
        ``alpha = min(1 - 1 / (global_step + 1), ema_model_update["alpha"])``.
        The first term keeps alpha small early on, so the teacher closely
        follows the online weights at the start. Only tensors returned by
        ``get_group_parameters`` are blended; other state (e.g. buffers) is
        not touched here.
        """
        alpha = min(1 - 1 / (self._global_step + 1), self._ema_update["alpha"])
        ema_parameters = self._ema_network.get_group_parameters()
        parameters = self._network.get_group_parameters()

        for group_name in ema_parameters.keys():
            for ema_param, param in zip(ema_parameters[group_name], parameters[group_name]):
                ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

    def _after_task_intensive(self, inc_dataset):
        """Refresh the dataset's exemplar memory after a task."""
        inc_dataset.update_exemplar()

    def save_parameters(self, directory, run_id):
        """Save the network, the EMA teacher and the projection heads.

        The projection heads' state dicts are merged into the checkpoint at
        ``net_{run_id}_task_{task}.pth`` (presumably written by the parent's
        ``save_parameters``) under the keys ``forward_proj_<modality>``.
        """
        super().save_parameters(directory, run_id)

        if self._ema_network is not None:
            path = os.path.join(directory, f"ema_net_{run_id}_task_{self._task}.pth")
            self._ema_network.save(path)

        if self._contra_loss['lambda'] > 0:
            path = os.path.join(directory, f"net_{run_id}_task_{self._task}.pth")
            save_states = torch.load(path)

            for modality in self._network.modalities:
                save_states[f"forward_proj_{modality}"] = self._forward_projection[modality].state_dict()

            torch.save(save_states, path)

    def load_parameters(self, directory, run_id):
        """Load the network, projection heads and EMA teacher for the current task.

        Missing checkpoint files are skipped silently.

        Raises:
            ValueError: if an existing checkpoint cannot be loaded; the
                underlying error is chained as ``__cause__``.
        """
        path = os.path.join(directory, f"net_{run_id}_task_{self._task}.pth")
        if not os.path.exists(path):
            return

        logger.info(f"Loading model at {path}.")
        try:
            self.network.load(path)

            if self._contra_loss['lambda'] > 0:
                save_states = torch.load(path)
                for modality in self._network.modalities:
                    self._forward_projection[modality].load_state_dict(
                        save_states[f"forward_proj_{modality}"]
                    )
        except Exception as err:
            raise ValueError("Cannot load weights") from err

        if self._ema_network is not None:
            path = os.path.join(directory, f"ema_net_{run_id}_task_{self._task}.pth")
            if not os.path.exists(path):
                return

            try:
                self._ema_network.load(path)
            except Exception as err:
                raise ValueError("Cannot load weights") from err

    def save_metadata(self, directory, run_id):
        """Pickle the training metadata (currently ``[global_step]``)."""
        path = os.path.join(directory, f"meta_{run_id}_task_{self._task}.pkl")

        logger.info("Saving metadata at {}.".format(path))
        with open(path, "wb+") as f:
            pickle.dump(
                [self._global_step],
                f
            )

    def load_metadata(self, directory, run_id):
        """Restore ``global_step`` from the pickled metadata, if the file exists.

        Uses ``pickle``; only load metadata files from a trusted location.
        """
        path = os.path.join(directory, f"meta_{run_id}_task_{self._task}.pkl")
        if not os.path.exists(path):
            return

        logger.info("Loading metadata at {}.".format(path))
        with open(path, "rb") as f:
            self._global_step = pickle.load(
                f
            )[0]

    def _eval_task(self, data_loader):
        """Predict class probabilities for every sample of ``data_loader``.

        Uses the EMA teacher (switched to eval mode) when one exists, and the
        online network otherwise.

        Returns:
            ``(ypred, ytrue, zid)``: softmax probabilities, targets and task
            ids, each concatenated over the loader into a numpy array.
        """
        if self._ema_network is not None:
            self._ema_network.eval()
            model = self._ema_network
        else:
            model = self._network

        ypred = []
        ytrue = []
        zid = []

        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for input_dict in data_loader:
                targets = input_dict.pop("target").numpy()
                task_id = input_dict.pop("task_id").numpy()

                ytrue.append(targets)
                zid.append(task_id)

                inputs = {key: item.to(self._device) for key, item in input_dict.items()}
                logits = model(inputs)["logits"].detach()

                preds = F.softmax(logits, dim=-1)
                ypred.append(preds.cpu().numpy())

        ypred = np.concatenate(ypred)
        ytrue = np.concatenate(ytrue)
        zid = np.concatenate(zid)

        return ypred, ytrue, zid
