import collections
import copy
import logging
import os
import pickle

import numpy as np
import torch
import scipy
import scipy.linalg
from scipy.spatial.distance import cdist
from torch import nn
from torch.nn import functional as F
from torch.utils.data import WeightedRandomSampler
from tqdm import tqdm

from inclearn.lib import factory, herding, losses, network, schedulers, utils
from inclearn.lib.network import hook
from inclearn.lib.network.mlp import MLP
from inclearn.models.finetune import Finetune

EPSILON = 1e-8

logger = logging.getLogger(__name__)


class Co2L(Finetune):
    """Contrastive continual learner with similarity-distribution distillation.

    Each task is trained in two stages (see ``_train_task``):

    1. Representation stage: every network parameter group except
       ``classifier`` plus a projection head (``self._proj``) are trained with
       ``losses.sup_con_current`` on two stacked views of each sample. From the
       second task on, a distillation term is added: it matches the current
       model's pairwise-similarity distribution to the one produced by a frozen
       copy of the previous network and projection head.
    2. Linear stage: the encoders and fusion module are put in eval mode and
       only ``self._network.classifier`` is trained with cross-entropy on
       class-balanced resampled data.

    Options read from ``args`` here (others are consumed by ``Finetune``):
        temp: temperature of the contrastive loss (default 0.5).
        current_temp: temperature for the current model's similarities (default 0.2).
        past_temp: temperature for the past model's similarities (default 0.01).
        distill_lambda: weight of the distillation loss (default 1.0).
        linear_config: dict for the linear stage; keys read are ``batch_size``,
            ``epochs``, ``lr``, ``scheduling`` and optionally ``lr_decay``.
        workers: number of DataLoader workers for the linear stage.
    """

    def __init__(self, args):
        super().__init__(args)
        self._args = args

        # Projection head used only by the contrastive/distillation losses.
        # normalize=True is passed, so outputs are presumably L2-normalized
        # (TODO confirm in MLP).
        feat_dim = self._network.features_dim
        self._proj = MLP(
            input_dim=feat_dim,
            hidden_dims=[feat_dim, 128],
            use_bn=False,
            normalize=True,
        ).to(self._device)
        # Frozen copy of the projection head, created in _after_task.
        self._old_proj = None

        self._temp = args.get("temp", 0.5)
        self._current_temp = args.get("current_temp", 0.2)
        self._past_temp = args.get("past_temp", 0.01)
        self._distill_lambda = args.get("distill_lambda", 1.0)

        self._linear_config = args.get("linear_config")

    def _before_task(self):
        """Build the optimizer and LR scheduler for the representation stage.

        All network parameter groups except ``classifier`` are optimized, since
        the classifier is trained separately in ``_linear_finetune``; the
        projection head is added with the base learning rate. When
        ``self._groupwise_factors`` is set, a group's learning rate is
        ``self._lr * factor``. A factor may be a ``[first_task, later_tasks]``
        pair. A factor of 0 leaves the group out of the optimizer entirely.
        """
        groupwise = self._groupwise_factors
        params = []
        for group_name, group_params in self._network.get_group_parameters().items():
            if group_name == "classifier":
                continue
            lr = self._lr
            if groupwise:
                factor = groupwise.get(group_name, 1.0)
                if isinstance(factor, list):
                    # [factor for task 0, factor for every later task]
                    factor = factor[0] if self._task == 0 else factor[1]
                if factor == 0.:
                    continue
                lr = self._lr * factor
            params.append({"params": group_params, "lr": lr})
            logger.info(f"Group: {group_name}, lr: {lr}.")
        params.append({"params": self._proj.parameters(), "lr": self._lr})

        self._optimizer = factory.get_optimizer(
            params, self._opt_name, self._lr, self._weight_decay
        )

        self._scheduler = factory.get_lr_scheduler(
            self._scheduling,
            self._optimizer,
            nb_epochs=self._n_epochs,
            lr_decay=self._lr_decay,
            task=self._task
        )

    @staticmethod
    def _balanced_sample_weights(targets):
        """Return per-sample weights equal to 1 / (number of samples of its class).

        Sampling with these weights draws every class with equal probability.
        ``targets`` is converted with ``np.asarray`` so that plain lists work.
        On a list, ``targets == t`` would be a single bool and the per-class
        assignment would silently do nothing.
        """
        targets = np.asarray(targets).reshape(-1)
        _, inverse, counts = np.unique(targets, return_inverse=True, return_counts=True)
        return 1.0 / counts[inverse.reshape(-1)]

    def _train_task(self, train_loader):
        """Run the representation stage, then the linear-classifier stage."""
        logger.debug("nb {}.".format(len(train_loader.dataset)))
        self._training_step(train_loader, 0, self._n_epochs)

        cur_dataset = train_loader.dataset
        weights = self._balanced_sample_weights(cur_dataset.targets)
        train_sampler = WeightedRandomSampler(
            torch.as_tensor(weights, dtype=torch.float32), len(weights)
        )
        # A sampler and shuffle=True are mutually exclusive in DataLoader.
        linear_train_loader = torch.utils.data.DataLoader(
            cur_dataset,
            batch_size=self._linear_config["batch_size"],
            shuffle=False,
            num_workers=self._args["workers"],
            pin_memory=True,
            sampler=train_sampler,
        )
        self._linear_finetune(linear_train_loader)

    def _forward_loss(
        self,
        training_network,
        inputs,
        targets,
        task_id,
        **kwargs
    ):
        """Forward both views through ``training_network`` and compute the loss.

        Args:
            training_network: network returning a dict with ``features_fused``.
            inputs: dict whose values are sequences of view tensors. The views
                are concatenated along dim 0, so the batch becomes
                ``[view_1; view_2]``.
            targets: class labels, one per original sample.
            task_id: task id of each original sample.
            **kwargs: forwarded to ``_compute_loss``.

        Returns:
            The scalar training loss.

        Raises:
            ValueError: if the loss is NaN.
        """
        inputs = {key: torch.cat(item, dim=0).to(self._device) for key, item in inputs.items()}
        targets = targets.to(self._device)

        outputs = training_network(inputs)

        loss = self._compute_loss(inputs, outputs, targets, task_id, **kwargs)
        if torch.isnan(loss).item():
            raise ValueError("A loss is NaN: {}".format(self._metrics))

        self._metrics["loss"] += loss.item()

        return loss

    def _relation_distillation(self, inputs, features):
        """Cross-entropy between past and current pairwise-similarity distributions.

        For each row of the batch, the similarities to every *other* row (the
        diagonal is dropped) are turned into a softmax distribution.
        ``self._current_temp`` is used for the current projected ``features``.
        ``self._past_temp`` is used for the frozen past network and projection
        head, evaluated under ``no_grad``.

        Returns:
            ``mean_i sum_j -p_past[i, j] * log p_cur[i, j]``, without
            ``distill_lambda`` applied.
        """
        n = features.size(0)
        # Built on the features' device (was hard-coded to .cuda()).
        off_diag = ~torch.eye(n, dtype=torch.bool, device=features.device)

        # log_softmax is shift-invariant and uses the true row max, so it
        # matches exp(x)/sum(exp(x)) but cannot underflow to 0/0 or log(0).
        sim_cur = torch.matmul(features, features.T) / self._current_temp
        log_p_cur = F.log_softmax(sim_cur[off_diag].view(n, -1), dim=1)

        with torch.no_grad():
            past = self._old_proj(self._old_network(inputs)["features_fused"])
            sim_past = torch.matmul(past, past.T) / self._past_temp
            p_past = F.softmax(sim_past[off_diag].view(n, -1), dim=1)

        return (-p_past * log_p_cur).sum(dim=1).mean()

    def _compute_loss(self, inputs, outputs, targets, task_id, **kwargs):
        """Contrastive loss plus (after the first task) distillation loss.

        Args:
            inputs: the concatenated two-view inputs, used to query the past network.
            outputs: dict whose ``features_fused`` holds view-1 features in
                the first ``len(task_id)`` rows and view-2 features after them.
            targets: class labels, one per original sample.
            task_id: task id per original sample. ``task_id == self._task`` is
                passed to ``losses.sup_con_current`` as its current-task mask.
            **kwargs: accepted for compatibility with ``_forward_loss``; unused.

        Returns:
            ``con + distill_lambda * dis``, where ``dis`` is added only when
            ``self._task > 0``. Both terms are accumulated into ``self._metrics``.
        """
        bsz = len(task_id)
        features = self._proj(outputs["features_fused"])

        loss_distill = 0
        if self._task > 0:
            loss_distill = self._distill_lambda * self._relation_distillation(inputs, features)
            self._metrics["dis"] += loss_distill.item()

        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        # Shape (bsz, 2, dim): the two views of each sample side by side.
        features = torch.stack([f1, f2], dim=1)
        loss = losses.sup_con_current(features, task_id == self._task, labels=targets, temperature=self._temp)
        self._metrics["con"] += loss.item()

        return loss + loss_distill

    def _after_task(self):
        """Snapshot the network and projection head as the frozen past model."""
        self._old_network = self._network.copy().freeze().to(self._device)
        self._old_proj = copy.deepcopy(self._proj).to(self._device)
        # Only used under no_grad; also make sure it never accumulates grads.
        self._old_proj.eval()
        self._old_proj.requires_grad_(False)
        self._network.on_task_end()

    def _after_task_intensive(self, inc_dataset):
        inc_dataset.update_exemplar()

    def _linear_finetune(self, train_loader):
        """Train only ``self._network.classifier`` on fixed extracted features.

        The fusion module and encoders are put in eval mode, and features are
        extracted under ``no_grad``. Only the first entry (view) of each input
        is used. The classifier is left in eval mode afterwards.
        """
        logger.info("Train linear classifier")

        self._network.fusion.eval()
        for encoder in self._network.encoders.values():
            encoder.eval()
        self._network.classifier.train()

        nb_epochs = self._linear_config["epochs"]
        optimizer = factory.get_optimizer(
            self._network.classifier.parameters(), self._opt_name, self._linear_config["lr"], self._weight_decay
        )
        # Previously the scheduling *name* was passed as lr_decay. Read a
        # dedicated key and fall back to the representation stage's decay.
        scheduler = factory.get_lr_scheduler(
            self._linear_config.get("scheduling"), optimizer,
            nb_epochs, lr_decay=self._linear_config.get("lr_decay", self._lr_decay)
        )

        for epoch in range(nb_epochs):
            self._metrics = collections.defaultdict(float)

            prog_bar = tqdm(
                train_loader,
                disable=self._disable_progressbar,
                ascii=True,
                bar_format="{desc}: {percentage:3.0f}% | {n_fmt}/{total_fmt} | {rate_fmt}{postfix}"
            )
            i = 0  # stays 0 if the loader is empty
            for i, input_dict in enumerate(prog_bar, start=1):
                targets = input_dict.pop("target").to(self._device)
                input_dict.pop("task_id")  # not needed by the linear stage
                inputs = {key: item[0].to(self._device) for key, item in input_dict.items()}

                with torch.no_grad():
                    features = self._network.extract(inputs)["features_fused"]

                optimizer.zero_grad()
                logits = self._network.classifier(features)
                loss = F.cross_entropy(logits, targets)
                loss.backward()
                optimizer.step()

                self._metrics["clf"] += loss.item()
                self._metrics["acc"] += (logits.argmax(dim=1) == targets).float().mean().item()

                self._print_metrics(prog_bar, epoch, nb_epochs, i)

            if self._disable_progressbar and i > 0:
                self._print_metrics(None, epoch, nb_epochs, i)

            if scheduler:
                scheduler.step()
        self._network.classifier.eval()




