import collections
import copy
import logging
import os
import pickle

import numpy as np
from sklearn.cluster import KMeans
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from inclearn.lib import factory, herding, losses, network, schedulers, utils
from inclearn.lib.network import hook
from inclearn.lib.network.linear import SplitLSCLinear
from inclearn.models.finetune import Finetune

EPSILON = 1e-8

logger = logging.getLogger(__name__)


class AFC(Finetune):
    """Incremental learner with a proxy-based split classifier and feature consolidation.

    The classifier is a ``SplitLSCLinear`` with ``proxy_per_class`` proxies per
    class (``fc1`` for old classes, ``fc2`` for new ones). Training uses an NCA
    loss; from the second task on, an extra term penalises drift of the
    (L2-normalised) per-modality and fused features away from a frozen copy of
    the previous network, weighted by how much each feature mattered to the old
    network's NCA loss.
    """

    def __init__(self, args):
        super().__init__(args)

        self._nca_config = args.get("nca", {})
        self._importance_loss = args.get("importance_loss", {'lambda': 1.0})

        # Frozen snapshot of the network; (re)created in ``_before_task``.
        self._old_network = None

        self._classifier_config = args.get("classifier_config", {})
        self._proxy_per_class = self._classifier_config.get("proxy_per_class", 10)
        self._network.classifier = SplitLSCLinear(
            self._network.features_dim, self._proxy_per_class, self._device
        )

    def _group_lr(self, group_name):
        """Return the learning rate for a parameter group, or None to skip it.

        With ``self._groupwise_factors`` set, the base lr is scaled by the
        group's factor (default 1.0). A list factor means ``[task0, later]``.
        A factor of exactly 0 excludes the group from the optimizer.
        """
        if not self._groupwise_factors:
            return self._lr
        factor = self._groupwise_factors.get(group_name, 1.0)
        if isinstance(factor, list):
            factor = factor[0] if self._task == 0 else factor[1]
        if factor == 0.:
            return None
        return self._lr * factor

    def _before_task(self):
        """Grow the classifier, snapshot the old network, and build optimizer/scheduler.

        For tasks after the first, the new-class proxies are imprinted from
        data *before* the snapshot, so the frozen copy includes them. Only
        ``fc2`` (new classes) and the classifier scale ``factor`` are optimised
        in the classifier group from task 1 on; ``fc1`` (old classes) is left
        out of the optimizer.
        """
        self._network.classifier.add_class(self._seen_classes - self._old_seen_classes)
        if self._task > 0:
            self._imprint_weights(
                self.inc_dataset.get_cur_train_loader(shuffle=False, num_workers=8, drop_last=False)
            )
        self._old_network = self._network.copy().freeze().to(self._device)

        params = []
        for group_name, group_params in self._network.get_group_parameters().items():
            lr = self._group_lr(group_name)
            if lr is None:
                continue

            if group_name == 'classifier' and self._task > 0:
                classifier = self._network.classifier
                params.append({"params": [classifier.fc2.weight, classifier.factor], "lr": lr})
                logger.info(f"Group: {group_name} (new class), lr: {lr}.")
                logger.info(f"Group: {group_name} (old class), lr: 0.")
            else:
                params.append({"params": group_params, "lr": lr})
                logger.info(f"Group: {group_name}, lr: {lr}.")

        self._optimizer = factory.get_optimizer(
            params, self._opt_name, self._lr, self._weight_decay
        )

        self._scheduler = factory.get_lr_scheduler(
            self._scheduling,
            self._optimizer,
            nb_epochs=self._n_epochs,
            lr_decay=self._lr_decay,
            task=self._task
        )

    def _imprint_weights(self, data_loader):
        """Initialise the new-class proxies (``fc2``) from k-means centroids.

        For each class in ``[self._old_seen_classes, self._seen_classes)``, the
        L2-normalised fused features are clustered into ``proxy_per_class``
        centroids. Each centroid is scaled by the mean row norm of ``fc1`` (old
        classes) and becomes one row of ``fc2``, in class order.

        NOTE(review): leaves ``self._network`` in eval mode, as the original
        did; presumably the training loop switches back to train() — confirm.

        Raises:
            ValueError: if a new class has fewer samples than proxies.
        """
        self._network.eval()

        all_features = []
        all_targets = []
        with torch.no_grad():
            for input_dict in data_loader:
                all_targets.append(input_dict['target'])
                inputs = {
                    key: item.to(self._device)
                    for key, item in input_dict.items()
                    if key not in ("target", "task_id")
                }
                all_features.append(self._network(inputs)['features_fused'].detach().cpu())
        all_features = F.normalize(torch.cat(all_features), p=2, dim=1)
        all_targets = torch.cat(all_targets)

        weights_norm = self._network.classifier.fc1.weight.data.norm(dim=1, keepdim=True).cpu()
        avg_weights_norm = torch.mean(weights_norm, dim=0)

        new_weights = []
        for cls in range(self._old_seen_classes, self._seen_classes):
            class_features = all_features[all_targets == cls]
            if class_features.shape[0] < self._proxy_per_class:
                raise ValueError(
                    f"Class {cls} has {class_features.shape[0]} samples, fewer than "
                    f"proxy_per_class={self._proxy_per_class}; cannot imprint weights."
                )
            clusterizer = KMeans(n_clusters=self._proxy_per_class, n_init='auto')
            clusterizer.fit(class_features.numpy())
            # One row per centroid, scaled to the average old-class weight norm.
            new_weights.append(torch.as_tensor(clusterizer.cluster_centers_) * avg_weights_norm)

        fc2 = self._network.classifier.fc2
        # Cast so a float64 centroid array cannot change the parameter dtype.
        fc2.weight.data = torch.cat(new_weights).to(device=self._device, dtype=fc2.weight.dtype)

    def _compute_loss(self, inputs, outputs, targets, task_id):
        """Compute the training loss for one batch.

        Task 0: NCA classification loss only. Later tasks: NCA loss plus an
        importance-weighted distance between current and old (detached)
        normalised features for every modality and for the fused features.

        Args:
            inputs: dict of model inputs (already on ``self._device``).
            outputs: forward outputs of ``self._network`` for ``inputs``; read
                keys are ``"logits"`` and ``"features_<modality|fused>"``.
            targets: class labels of the batch.
            task_id: unused here; kept for interface compatibility.

        Returns:
            Scalar loss tensor.
        """
        loss_clf = losses.nca(
            outputs["logits"], targets, scale=self._network.classifier.factor, **self._nca_config
        )
        if self._task == 0:
            return loss_clf

        old_outputs = self._old_network(inputs)
        keys = list(self._network.modalities) + ['fused']
        old_features = [old_outputs[f'features_{key}'] for key in keys]

        # NOTE(review): unlike loss_clf, self._nca_config is not passed here —
        # confirm this is intended.
        prev_loss = losses.nca(old_outputs["logits"], targets, scale=self._old_network.classifier.factor)
        # autograd.grad returns d(prev_loss)/d(features) directly, without
        # accumulating .grad into the old network's parameters or keeping the
        # old graph alive.
        feature_grads = torch.autograd.grad(prev_loss, old_features)

        # Per-sample importance of each feature, normalised to sum to 1 across
        # keys; EPSILON guards against 0/0 when all gradients vanish.
        importances = [torch.norm(grad, p='fro', dim=-1) for grad in feature_grads]
        total_importance = sum(importances) + EPSILON

        loss_imp = 0
        for key, importance, old_feat in zip(keys, importances, old_features):
            drift = torch.norm(
                F.normalize(outputs[f"features_{key}"]) - F.normalize(old_feat.detach()),
                p='fro', dim=-1,
            )
            loss_imp = loss_imp + torch.mean(importance / total_importance * drift)
        loss_imp = loss_imp * (
            np.sqrt(self._seen_classes / (self._seen_classes - self._old_seen_classes))
            * self._importance_loss['lambda']
        )

        self._metrics["clf"] += loss_clf.item()
        self._metrics["imp"] += loss_imp.item()

        return loss_imp + loss_clf

    def _after_task_intensive(self, inc_dataset):
        """Update exemplar memory by herding on the fused features of the current train set.

        The network is put in eval mode for extraction (so batch-norm running
        statistics are not updated and dropout is off) and its previous mode is
        restored afterwards.
        """
        targets = []
        features = []
        was_training = self._network.training
        self._network.eval()
        try:
            with torch.no_grad():
                loader = inc_dataset.get_cur_train_loader(shuffle=False, num_workers=8, drop_last=False)
                for input_dict in loader:
                    targets.append(input_dict.pop("target").numpy())
                    input_dict.pop("task_id")  # not a model input

                    inputs = {key: item.to(self._device) for key, item in input_dict.items()}
                    features.append(self._network.extract(inputs)['features_fused'].detach().cpu().numpy())
        finally:
            self._network.train(was_training)
        inc_dataset.update_exemplar_by_herding(np.concatenate(features), np.concatenate(targets))

    def _eval_task(self, data_loader):
        """Run inference over ``data_loader``.

        Returns:
            Tuple ``(ypred, ytrue, zid)`` of numpy arrays concatenated over
            batches: softmax of the factor-scaled logits, targets, and task ids.
        """
        ypred = []
        ytrue = []
        zid = []

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for input_dict in data_loader:
                ytrue.append(input_dict.pop("target").numpy())
                zid.append(input_dict.pop("task_id").numpy())

                inputs = {key: item.to(self._device) for key, item in input_dict.items()}
                logits = self._network.classifier.factor * self._network(inputs)["logits"]
                ypred.append(F.softmax(logits, dim=-1).cpu().numpy())

        return np.concatenate(ypred), np.concatenate(ytrue), np.concatenate(zid)




