import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import logging
import numpy as np
from tqdm import tqdm

from methods.base import BaseLearner
from utils.toolkit import tensor2numpy, accuracy
from models.sinet_lora import SiNet_GR_LoRA
from models.vit_lora import Attention_GR_LoRA
from copy import deepcopy
from utils.schedulers import CosineSchedule
import ipdb
import optimgrad
import re
from collections import defaultdict
from utils.losses import AugmentedTripletLoss
from scipy.spatial.distance import cdist

from torch import optim

import sklearn
import os
import random
import matplotlib.pyplot as plt

from torch.distributions.multivariate_normal import MultivariateNormal
from utils.losses import CrossEntropyLoss


class GR_Lora(BaseLearner):
    """Continual learner training per-task GR-LoRA adapters on a SiNet backbone.

    For every task it:
      1. unfreezes the newest task slot's LoRA / ``*_fit`` adapters and
         ``normal_fc`` head and trains them with an ``optimgrad`` optimizer
         (for tasks > 0, the optimizer first receives ``cur_matrix`` statistics
         collected on the shared forward pass, via ``get_eigens``/``get_transforms``);
      2. stores, per class, the mean/covariance of specific-space features and
         the mean of shared-space features;
      3. for tasks > 0, re-aligns ``classifier_pool`` on features sampled from
         those stored Gaussians.
    """

    def __init__(self, args):
        super().__init__(args)

        if args["net_type"] == "sip":
            self._network = SiNet_GR_LoRA(args)
        else:
            raise ValueError('Unknown net: {}.'.format(args["net_type"]))

        self.args = args
        self.optim = args["optim"]
        self.EPSILON = args["EPSILON"]
        self.init_epoch = args["init_epoch"]
        self.init_lr = args["init_lr"]
        self.init_lr_decay = args["init_lr_decay"]
        self.init_weight_decay = args["init_weight_decay"]
        self.epochs = args["epochs"]
        self.lrate = args["lrate"]
        self.lrate_decay = args["lrate_decay"]
        self.batch_size = args["batch_size"]
        self.weight_decay = args["weight_decay"]
        self.num_workers = args["num_workers"]
        self.total_sessions = args["total_sessions"]
        self.dataset = args["dataset"]
        self.fc_lrate = args["fc_lrate"]
        self.eval = args['eval']
        self._protos = []

        self.topk = 1  # origin is 5
        # Per-task class count, as reported by the network.
        self.class_num = self._network.class_num
        self.debug = False
        # Maps a trainable adapter weight -> input statistic handed to the optimizer.
        self.fea_in = defaultdict(dict)

        # Per-class statistics, indexed by global class id (filled by _compute_mean).
        self.cls_mean = []
        self.cls_cov = []

        self.cls_mean_shift = []
        self.cls_mean_share = []

        self.loss_fn = CrossEntropyLoss()

        for module in self._network.modules():
            if isinstance(module, Attention_GR_LoRA):
                module.init_param()

    def after_task(self):
        """Mark the classes of the task just learned as known."""
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager):
        """Learn the next task from ``data_manager`` and update class statistics.

        Builds train (current-task classes) and test (all seen classes)
        loaders, trains unless ``eval`` is set, records class statistics and,
        from the second task on, re-aligns the classifier.
        """
        self.data_manager = data_manager
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
        self._network.update_fc(self._total_classes)

        logging.info('Learning on {}-{}'.format(self._known_classes, self._total_classes))

        train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train',
                                                 mode='train')
        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True,
                                       num_workers=self.num_workers)
        test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source='test', mode='test')
        self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False,
                                      num_workers=self.num_workers)

        # NOTE(review): _train/train_function access model-specific attributes
        # (numtask, forward_shared, forward_specific, ...) directly on
        # self._network, which nn.DataParallel does not forward — the
        # multi-GPU path looks broken; confirm before relying on it.
        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        if not self.eval:
            self._train(self.train_loader, self.test_loader)

        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

        self._compute_mean()
        if self._cur_task > 0:
            self.classifer_align()

    def _extract_features(self, loader, extractor):
        """Apply ``extractor`` to every batch of ``loader``; concatenate outputs on dim 0."""
        return torch.cat([extractor(inputs.to(self._device)) for _, inputs, _ in loader], dim=0)

    @torch.no_grad()
    def _compute_mean(self):
        """Append statistics for every class of the current task.

        For each class in ``[_known_classes, _total_classes)``:
          * ``cls_mean`` / ``cls_cov`` get the mean and ridge-regularised
            covariance of ``extract_vector`` features;
          * ``cls_mean_share`` gets the mean of ``extract_vector_shared`` features.
        The lists are indexed by global class id elsewhere, so this must run
        once per task, in task order.
        """
        self._network.eval()
        for class_idx in range(self._known_classes, self._total_classes):
            _, _, idx_dataset = self.data_manager.get_dataset(
                np.arange(class_idx, class_idx + 1),
                source="train",
                mode="test",
                ret_data=True,
            )
            idx_loader = DataLoader(
                idx_dataset, batch_size=self.batch_size * 3, shuffle=False,
                num_workers=self.num_workers,
            )

            # Specific-space prototype: mean and covariance (sampled from in classifer_align).
            features = self._extract_features(idx_loader, self._network.extract_vector)
            mean = features.mean(dim=0)
            # Ridge term keeps the covariance positive definite for MultivariateNormal.
            ridge = torch.eye(mean.shape[-1], device=self._device) * 1e-2
            self.cls_mean.append(mean.to(self._device))
            self.cls_cov.append(torch.cov(features.T) + ridge)

            # Shared-space prototype: mean only.
            shared = self._extract_features(idx_loader, self._network.extract_vector_shared)
            self.cls_mean_share.append(shared.mean(dim=0).to(self._device))

    def prototype_shift_add(self, old_class_idx, temperature=1.0, task_id=0):
        """Shift a class's shared-space prototype towards task ``task_id``'s specific space.

        The shift is a weighted average, over the classes of task ``task_id``,
        of ``cls_mean[i] - cls_mean_share[i]``. Weights are
        ``softmax(cos_sim(shared prototype of old_class_idx, cls_mean_share[i]) / temperature)``.

        Assumes every task has as many classes as the current one
        (``_total_classes - _known_classes``).

        Returns:
            ``cls_mean_share[old_class_idx]`` plus the weighted shift.
        """
        shared_query = self.cls_mean_share[old_class_idx]
        increment = self._total_classes - self._known_classes
        task_classes = range(increment * task_id, increment * (task_id + 1))
        shared_task = torch.stack([self.cls_mean_share[i] for i in task_classes])
        specific_task = torch.stack([self.cls_mean[i] for i in task_classes])

        individual_shifts = specific_task - shared_task
        similarities = F.cosine_similarity(shared_query.unsqueeze(0), shared_task, dim=1)
        weights = F.softmax(similarities / temperature, dim=0)
        proto_shift = torch.sum(weights.unsqueeze(1) * individual_shifts, dim=0)

        return shared_query + proto_shift

    def classifer_align(self):
        """Fine-tune ``classifier_pool`` on features sampled from stored class Gaussians.

        Called only for ``_cur_task > 0``. Each epoch:
          * ``batch_size * 5`` samples per class are drawn from
            N(cls_mean[c], cls_cov[c]);
          * for every class ``c`` and every task other than its own,
            ``ceil(batch_size * 5 / _cur_task)`` samples are drawn from
            N(prototype_shift_add(c, task_id=t), cls_cov[c]), labelled ``c``;
          * ``_total_classes`` SGD steps are taken, each on one chunk of the
            shuffled samples, minimising entropy + cross-entropy.
        """
        for p in self._network.classifier_pool.parameters():
            p.requires_grad = True

        run_epochs = 20
        network_params = [
            {'params': self._network.classifier_pool.parameters(), 'lr': self.fc_lrate, 'weight_decay': self.weight_decay}]

        optimizer = optim.SGD(network_params, lr=self.fc_lrate, momentum=0.9, weight_decay=5e-4)
        # NOTE(review): T_max=5 with run_epochs=20 makes the lr oscillate with
        # period 2*T_max rather than decay once; confirm T_max=run_epochs was not intended.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=5)

        increment = self._total_classes - self._known_classes
        num_sampled_pcls = self.batch_size * 5
        # Ceil-divide: each class gets one wrong-branch group per other task
        # (_cur_task of them), totalling at least num_sampled_pcls samples.
        num_sampled_pcls_wrong = (num_sampled_pcls + self._cur_task - 1) // self._cur_task

        # The Gaussians do not change across epochs, so build them once. Each
        # class's Cholesky factor is shared by its true and shifted distributions
        # (identical to MultivariateNormal(mean, covariance_matrix=cov)).
        true_dists, wrong_dists = [], []
        for class_idx in range(self._total_classes):
            scale_tril = torch.linalg.cholesky(self.cls_cov[class_idx].to(self._device).float())
            mean = self.cls_mean[class_idx].to(self._device).float()
            true_dists.append(MultivariateNormal(mean, scale_tril=scale_tril))

            own_task = class_idx // increment
            for task_wrong in range(self._cur_task + 1):
                if task_wrong == own_task:
                    continue
                shifted_mean = self.prototype_shift_add(class_idx, task_id=task_wrong).float()
                wrong_dists.append((class_idx, MultivariateNormal(shifted_mean, scale_tril=scale_tril)))

        true_labels = torch.arange(self._total_classes).repeat_interleave(num_sampled_pcls).long().to(self._device)
        wrong_labels = torch.tensor([c for c, _ in wrong_dists]).repeat_interleave(
            num_sampled_pcls_wrong).long().to(self._device)

        prog_bar = tqdm(range(run_epochs))
        self._network.eval()

        for epoch in prog_bar:
            # Draw every sample first, then shuffle (keeps the original RNG order).
            true_samples = torch.cat([d.sample(sample_shape=(num_sampled_pcls,)) for d in true_dists], dim=0)
            wrong_samples = torch.cat(
                [d.sample(sample_shape=(num_sampled_pcls_wrong,)) for _, d in wrong_dists], dim=0)

            inputs = true_samples.float().to(self._device)
            perm = torch.randperm(inputs.size(0))
            inputs, targets = inputs[perm], true_labels[perm]

            inputs_wrong = wrong_samples.float().to(self._device)
            perm_wrong = torch.randperm(inputs_wrong.size(0))
            inputs_wrong, targets_wrong = inputs_wrong[perm_wrong], wrong_labels[perm_wrong]

            losses = 0.0
            correct, total = 0, 0
            for step in range(self._total_classes):
                chunk = slice(step * num_sampled_pcls, (step + 1) * num_sampled_pcls)
                inp, tgt = inputs[chunk], targets[chunk]
                inp_wrong, tgt_wrong = inputs_wrong[chunk], targets_wrong[chunk]

                logits = self._network.forward_only_fc(inp)

                # Mean prediction entropy on the true-space samples.
                probs_ent = F.softmax(logits, dim=1)
                log_probs_ent = F.log_softmax(logits, dim=1)
                loss = (-(probs_ent * log_probs_ent).sum(dim=1)).mean()

                if inp_wrong.size(0) > 0:
                    logits_wrong = self._network.forward_only_fc(inp_wrong)
                    loss = loss + (self.loss_fn(logits, tgt) + self.loss_fn(logits_wrong, tgt_wrong))
                else:
                    loss = loss + self.loss_fn(logits, tgt)

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(tgt.expand_as(preds)).cpu().sum()
                total += len(tgt)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # .item(): accumulating the tensor would keep every step's graph alive.
                losses += loss.item()

            scheduler.step()
            ca_acc = np.round(tensor2numpy(correct) * 100 / total, decimals=2)
            info = "Task {}, Epoch {}/{} => Loss {:.3f}, CA_accy {:.2f}".format(
                self._cur_task,
                epoch + 1,
                run_epochs,
                losses / self._total_classes,
                ca_acc,
            )
            prog_bar.set_description(info)

        logging.info(info)

    def _train(self, train_loader, test_loader):
        """Unfreeze the newest task slot, prepare the optimizer, then train.

        For ``_cur_task > 0`` the shared branch is first run over
        ``train_loader`` with ``get_cur_x=True``. For every
        ``Attention_GR_LoRA`` module, a copy of its ``cur_matrix`` is then
        stored in ``fea_in`` for each of the current task's adapter weights.
        The optimizer consumes these via ``update_optim_transforms``, and the
        module's accumulator is reset.
        """
        self._network.to(self._device)

        # Only parameters belonging to the newest task slot (numtask - 1) are trained.
        task_slot = str(self._network.numtask - 1)
        trainable_tags = tuple(
            "{}.{}.".format(prefix, task_slot)
            for prefix in (
                "normal_fc",
                "lora_A_k", "lora_A_v", "lora_B_k", "lora_B_v",
                "lora_B_k_fit", "lora_B_v_fit", "lora_A_k_fit", "lora_A_v_fit",
            )
        )
        for name, param in self._network.named_parameters():
            param.requires_grad_(any(tag in name for tag in trainable_tags))

        # Double check
        enabled = {name for name, param in self._network.named_parameters() if param.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        adapter_lists = ("lora_A_k", "lora_A_v", "lora_B_k", "lora_B_v",
                         "lora_A_k_fit", "lora_A_v_fit", "lora_B_k_fit", "lora_B_v_fit")
        with torch.no_grad():
            if self._cur_task > 0:
                for _, inputs, _ in train_loader:
                    self._network.forward_shared(inputs.to(self._device), get_cur_x=True)

                for module in self._network.modules():
                    if isinstance(module, Attention_GR_LoRA):
                        for attr in adapter_lists:
                            weight = getattr(module, attr)[self._cur_task].weight
                            # Separate copy per weight, as the optimizer may modify entries.
                            self.fea_in[weight] = deepcopy(module.cur_matrix).to(self._device)
                        module.cur_matrix.zero_()
                        module.n_cur_matrix = 0

            self.init_model_optimizer()
            if self._cur_task == 0:
                self.run_epoch = self.init_epoch
            else:
                self.update_optim_transforms()
                self.run_epoch = self.epochs

        self.train_function(train_loader, test_loader)

    def train_function(self, train_loader, test_loader):
        """Run ``self.run_epoch`` epochs over the current task's classes.

        Logits are restricted to columns ``[_known_classes, _total_classes)``
        and targets are shifted to start at 0. For ``_cur_task > 0`` a second
        loss term uses ``extract_vector_by_auxid(inputs, _cur_task, aux_id)``
        for a uniformly random earlier task ``aux_id``, with the same targets.
        """
        prog_bar = tqdm(range(self.run_epoch))
        info = None

        for epoch in prog_bar:
            # NOTE(review): training runs with the network in eval() mode;
            # presumably intentional (frozen backbone) — confirm.
            self._network.eval()
            losses = 0.
            correct, total = 0, 0

            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                # Keep only current-task samples and re-index their labels from 0.
                mask = (targets >= self._known_classes).nonzero().view(-1)
                if mask.numel() == 0:
                    # Nothing to learn from; an empty batch would give loss_fn no
                    # samples (likely NaN for a mean-reduced loss).
                    continue
                inputs = torch.index_select(inputs, 0, mask)
                targets = torch.index_select(targets, 0, mask) - self._known_classes

                ret = self._network.forward_specific(inputs)
                logits = ret['logits'][:, self._known_classes:self._total_classes]

                if self._cur_task > 0:
                    aux_id = random.randint(0, self._cur_task - 1)
                    features_neg = self._network.extract_vector_by_auxid(inputs, self._cur_task, aux_id)
                    logits_neg = self._network.forward_only_fc(features_neg)[:, self._known_classes:self._total_classes]
                    loss = self.loss_fn(logits, targets) + self.loss_fn(logits_neg, targets)
                else:
                    loss = self.loss_fn(logits, targets)

                self.model_optimizer.zero_grad()
                loss.backward()
                self.model_optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            self.model_scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) if total else 0.0

            info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}'.format(
                self._cur_task, epoch + 1, self.run_epoch, losses / len(train_loader), train_acc)
            prog_bar.set_description(info)

        # info stays None when run_epoch == 0 (previously an UnboundLocalError).
        if info is not None:
            logging.info(info)

    def _evaluate(self, y_pred, y_true):
        """Compute grouped accuracy; ``top1`` is the ``'total'`` entry of ``accuracy``."""
        ret = {}
        print(len(y_pred), len(y_true))
        grouped = accuracy(y_pred, y_true, self._known_classes, self.class_num)
        ret['grouped'] = grouped
        ret['top1'] = grouped['total']
        return ret

    def _eval_cnn(self, loader):
        """Predict on ``loader`` both task-agnostically and with the true task given.

        Task ids are ``label // class_num``.

        Returns:
            y_pred: top-``topk`` class indices, flattened (numpy).
            y_pred_with_task: argmax over the ground-truth task's ``class_num``
                logits, mapped back to global class ids (numpy).
            y_true: ground-truth labels (numpy).
            y_pred_task: task id of each entry of ``y_pred`` (CPU tensor).
            y_true_task: ground-truth task ids (CPU tensor).
        """
        self._network.to(self._device)
        self._network.eval()
        net = self._network.module if isinstance(self._network, nn.DataParallel) else self._network
        num_new_classes = self._total_classes - self._known_classes

        y_pred, y_true = [], []
        y_pred_with_task = []
        y_pred_task, y_true_task = [], []

        for _, inputs, targets in loader:
            inputs = inputs.to(self._device)
            targets = targets.to(self._device)
            task_ids = targets // self.class_num
            y_true_task.append(task_ids.cpu())

            with torch.no_grad():
                outputs = net.interface(inputs, num_new_classes)

            # Flattened to [bs * topk] (== [bs] while topk is 1).
            predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1].view(-1)
            y_pred_task.append((predicts // self.class_num).cpu())

            # Pick each sample's ground-truth task block of class_num logits in one gather.
            offsets = torch.arange(self.class_num, device=outputs.device)
            cols = task_ids.unsqueeze(1) * self.class_num + offsets
            predicts_with_task = outputs.gather(1, cols).argmax(dim=1) + task_ids * self.class_num

            y_pred.append(predicts.cpu().numpy())
            y_pred_with_task.append(predicts_with_task.cpu().numpy())
            y_true.append(targets.cpu().numpy())

        return np.concatenate(y_pred), np.concatenate(y_pred_with_task), np.concatenate(y_true), torch.cat(y_pred_task), torch.cat(y_true_task)

    def init_model_optimizer(self):
        """Build ``model_optimizer`` (an ``optimgrad`` optimizer) and its cosine schedule.

        Trainable parameters whose name contains ``'lora'`` form the group
        flagged ``svd=True`` (``thres=0.99``) at the task learning rate.
        Trainable parameters whose name contains ``'_pool'`` use ``fc_lrate``.
        Task 0 uses ``init_lr`` over ``init_epoch`` epochs; later tasks use
        ``lrate`` over ``epochs`` epochs.
        """
        if self._cur_task == 0:
            lr, num_epochs = self.init_lr, self.init_epoch
        else:
            lr, num_epochs = self.lrate, self.epochs

        named_params = list(self._network.named_parameters())
        fea_params = [p for n, p in named_params if 'lora' in n and p.requires_grad]
        cls_params_normal = [p for n, p in named_params if '_pool' in n and p.requires_grad]

        param_name_map = {p: n for n, p in named_params}
        name_param_map = {n: p for n, p in named_params}

        model_optimizer_arg = {'params': [{'params': fea_params, 'svd': True, 'lr': lr, 'thres': 0.99},
                                          {'params': cls_params_normal, 'weight_decay': self.weight_decay, 'lr': self.fc_lrate}],
                               'weight_decay': self.weight_decay,
                               'param_name_map': param_name_map,
                               'name_param_map': name_param_map
                               }
        if self.args['optim'] == 'Adam':
            model_optimizer_arg.update({'betas': (0.9, 0.999)})
        elif self.args['optim'] == 'Sgd':
            model_optimizer_arg.update({'momentum': 0.9})

        self.model_optimizer = getattr(optimgrad, self.args['optim'])(**model_optimizer_arg)
        # Schedule length matches the epochs this task actually runs (was always
        # self.epochs). Assumes CosineSchedule's K is the epoch span — confirm.
        self.model_scheduler = CosineSchedule(self.model_optimizer, K=num_epochs)

    def update_optim_transforms(self):
        """Hand the collected ``fea_in`` statistics to the optimizer, then clear them."""
        self.model_optimizer.get_eigens(self.fea_in)
        self.model_optimizer.get_transforms()
        self.fea_in = defaultdict(dict)
        




