# ConvPrompt implementation
import logging
import numpy as np
import torch
import wandb
import copy
import math
import matplotlib.pyplot as plt
from torch import nn
from tqdm import tqdm
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from utils_online.inc_net_online import ConvPromptVitNetOnline
from models_online.base import BaseLearner
from utils.toolkit import tensor2numpy
from utils_online.si_blurry import IndexedDataset, OnlineSampler, OnlineTestSampler
from sklearn.metrics import confusion_matrix
from torch.distributions import Categorical
from buffer.buffer import ProtoBuffer, Reservoir, QueryReservoir

import json
import os
from sentence_transformers import SentenceTransformer, util
import pickle as pkl
import wandb
import sys


def get_dataset_name(args=None):
    """Translate the configured dataset key into the descriptor-file suffix.

    Args:
        args: mapping with a ``"dataset"`` entry. It is subscripted directly,
            so the ``None`` default is not actually usable.

    Returns:
        ``"cub"``, ``"cifar100"``, ``"imr"`` or ``"imagenet1000"`` for the
        recognised dataset keys, and ``"others"`` for anything else.
    """
    requested = args["dataset"]
    # (config key, descriptor-file suffix) pairs, checked in order with ==.
    known_suffixes = (
        ("cub", "cub"),
        ("cifar224", "cifar100"),
        ("imagenetr", "imr"),
        ("imagenet1000", "imagenet1000"),
    )
    for config_key, file_suffix in known_suffixes:
        if requested == config_key:
            return file_suffix
    return "others"

def getUnion(list1, list2):
    """Return the distinct elements found in either input, as a new list.

    The element order follows set iteration order and is not guaranteed.
    """
    merged = set(list1) | set(list2)
    return [*merged]

# Loaded SentenceTransformer models keyed by checkpoint name, so repeated
# similarity queries reuse the weights instead of reloading them every call.
_SENTENCE_MODELS = {}

def _load_sentence_model(model_name):
    """Return a cached ``SentenceTransformer`` for ``model_name``, loading it on first use."""
    model = _SENTENCE_MODELS.get(model_name)
    if model is None:
        model = SentenceTransformer(model_name)
        _SENTENCE_MODELS[model_name] = model
    return model

def getSimilarity(desc_list, desc_list2, getEmbeddings=False, model_name='whaleloops/phrase-bert'):
    """Mean best-match cosine similarity of ``desc_list`` against ``desc_list2``.

    Every phrase of ``desc_list`` is embedded and compared with every phrase
    of ``desc_list2``. For each phrase of ``desc_list`` only its highest
    similarity is kept, and the mean of those maxima is returned.

    Args:
        desc_list: phrases to score (rows of the similarity matrix).
        desc_list2: reference phrases (columns of the similarity matrix).
        getEmbeddings: unused; kept only for backward compatibility.
        model_name: SentenceTransformer checkpoint used for the embeddings.

    Returns:
        0-d tensor holding the mean of the per-row maximum cosine similarity.
    """
    model = _load_sentence_model(model_name)
    encodings1 = model.encode(desc_list, convert_to_tensor=False)
    encodings2 = model.encode(desc_list2, convert_to_tensor=False)

    # cos_sim yields a (len(desc_list), len(desc_list2)) matrix.
    similarity = util.cos_sim(encodings1, encodings2)
    best_match, _ = torch.max(similarity, dim=1)

    return torch.mean(best_match)

def image_embedding_similarity(class_mask, task_id, args=None):
    """Scale the per-task prompt budget by how novel the new classes look.

    Only for ``task_id > 0`` with ``args.variable_num_prompts`` set: loads
    per-class prototype embeddings from
    ``<cwd>/<args.dataset>_class_prototypes.pkl``. For every class of the
    current task it takes the best cosine similarity to any class of an
    earlier task, averages those values, and shrinks
    ``args.num_prompts_per_task`` by ``(1 - mean similarity)``, rounding up.

    Args:
        class_mask: per-task sequences of class indices.
        task_id: index of the current task.
        args: attribute-style config with ``num_prompts_per_task``,
            ``variable_num_prompts`` and ``dataset``.

    Returns:
        The prompt count for this task (unchanged budget when not rescaled).
    """
    budget = args.num_prompts_per_task
    if not (task_id > 0 and args.variable_num_prompts):
        return budget

    current_classes = class_mask[task_id]
    print(current_classes)

    prototype_path = os.getcwd() + "/" + args.dataset + "_class_prototypes.pkl"
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # trusted, locally generated prototype files.
    with open(prototype_path, 'rb') as handle:
        prototypes = pkl.load(handle)

    previous = torch.stack([
        torch.tensor(prototypes[cls])
        for earlier_task in range(task_id)
        for cls in class_mask[earlier_task]
    ])
    current = torch.stack([torch.tensor(prototypes[cls]) for cls in class_mask[task_id]])

    # Best match of every current class among all earlier classes, then averaged.
    best_match, _ = torch.max(util.cos_sim(current, previous), dim=1)
    mean_similarity = np.mean(np.array(best_match), axis=0)

    print('Image-based Similarity: ', mean_similarity, ' Task: ', task_id)
    return math.ceil((1 - mean_similarity) * budget)

def class_label_similarity(class_mask, task_id, args=None):
    k = args.num_prompts_per_task
    if task_id>0 and args.variable_num_prompts:
        curr_mask = class_mask[task_id]
        print(curr_mask)

        dataset_name = get_dataset_name(args)
        path = os.getcwd()+"/../ConvPrompt/descriptors/descriptors_"+dataset_name+".json"

        if not os.path.exists(path):
            print('Class List not found for ', args.dataset)
            raise NotImplementedError
        
        f = open(path)
        desc = json.load(f)

        class_labels = list(desc.keys())
        names_list1 = []

        for id in range(task_id):
            for item in class_mask[id]:
                names_list1.append(class_labels[item])

        names_list2 = []
        for item in class_mask[task_id]:
            names_list2.append(class_labels[item])

        similarity = getSimilarity(desc_list=names_list1, desc_list2=names_list2)
        print('Similarity: ', similarity, ' Task: ', task_id)

        k = math.ceil((1-similarity)*k)
    
    return k

def num_new_prompts(class_mask, task_id, args=None):
    """Return how many new prompts to add for ``task_id``.

    Task 0, and datasets without descriptor files (``get_dataset_name`` gives
    ``"others"``), get the full ``args["num_prompts_per_task"]``. Otherwise the
    union of descriptor phrases of all earlier tasks' classes is compared with
    the union for the current task's classes via ``getSimilarity``. The budget
    is scaled by ``(1 - similarity)`` and rounded up.

    Args:
        class_mask: per-task sequences of class indices (positions into the
            descriptor JSON's entries, in file order).
        task_id: index of the current task.
        args: config mapping with ``num_prompts_per_task`` and ``dataset``.

    Returns:
        Number of prompts to add for this task.

    Raises:
        NotImplementedError: if the descriptor JSON for the dataset is missing.
    """
    k = args["num_prompts_per_task"]
    if not task_id > 0:
        return k

    dataset_name = get_dataset_name(args)
    if dataset_name == "others":
        return k
    path = os.getcwd()+"/../ConvPrompt/descriptors/descriptors_"+dataset_name+".json"

    if not os.path.exists(path):
        print("path: ", path)
        print('Descriptor List not found for ', args["dataset"])
        raise NotImplementedError

    with open(path, encoding="utf-8") as f:
        desc = json.load(f)

    # Descriptor lists indexed by class position; built once instead of
    # re-materialising desc.items() for every class.
    descriptors_by_class = list(desc.values())

    prev_descriptors = set()
    for prev_task in range(task_id):
        for item in class_mask[prev_task]:
            prev_descriptors.update(descriptors_by_class[item])

    curr_descriptors = set()
    for item in class_mask[task_id]:
        curr_descriptors.update(descriptors_by_class[item])

    similarity = getSimilarity(desc_list=list(prev_descriptors), desc_list2=list(curr_descriptors))
    print('Similarity: ', similarity, ' Task: ', task_id)

    return math.ceil((1-similarity)*k)

# Online ConvPrompt learner: the prompt pool is created once at the first task
# and the trainable parameters are then updated on every task.
num_workers = 8  # worker processes for every DataLoader the Learner builds

class Learner(BaseLearner):
    """Online ConvPrompt learner with FC recalibration and gradient reweighting.

    Trains a ``ConvPromptVitNetOnline`` one task at a time. With
    ``args['blurry']`` every task draws from the full label space through
    ``OnlineSampler``; otherwise each task only brings its new classes.
    Pre-logit features are stored in a ``ProtoBuffer`` and replayed through
    ``backbone.forward_head2`` as a recalibration loss. Gradients are rescaled
    by per-parameter weights driven by Adam-style moment estimates.
    """

    def __init__(self, args):
        """Build the network, read hyper-parameters and optionally freeze the ViT.

        Args:
            args: config mapping (batch_size, init_lr, weight_decay, min_lr,
                recab_coef, nb_classes, freeze, optimizer, ...).
        """
        super().__init__(args)

        self._network = ConvPromptVitNetOnline(args, True)

        self.batch_size = args["batch_size"]
        self.init_lr = args["init_lr"]
        self.weight_decay = args["weight_decay"] if args["weight_decay"] is not None else 0.0005
        self.min_lr = args["min_lr"] if args["min_lr"] is not None else 1e-8
        self.recab_coef = args["recab_coef"]
        self.args = args
        self.old_num_k = 0

        # The optimizer is built once at task 0 (see _train) and reused afterwards.
        self.optim = None
        self.scheduler = None

        # Gradient-reweighting ("FGH") state: moment decay rates and a global
        # step counter used for bias correction.
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.step = 0

        # Per-parameter state keyed by str(position in named_parameters()).
        self.grad_weight = {}
        self.old_grad = {}
        self.m = {}
        self.v = {}

        # Feature buffer replayed for FC recalibration; 768 is presumably the
        # ViT-B pre-logit width — TODO confirm against the backbone.
        self.buffer = ProtoBuffer(
            max_size=args['nb_classes'],
            shape=(768,),
            device=self._device
        )

        # Freeze the parameters for ViT.
        if self.args["freeze"]:
            # freeze args.freeze[blocks, patch_embed, cls_token] parameters
            ### note: if use simclr, rep_token is NOT frozen
            for n, p in self._network.backbone.named_parameters():
                if n.startswith(tuple(self.args["freeze"])):
                    p.requires_grad = False

        total_params = sum(p.numel() for p in self._network.backbone.parameters())
        logging.info(f'{total_params:,} model total parameters.')
        total_trainable_params = sum(p.numel() for p in self._network.backbone.parameters() if p.requires_grad)
        logging.info(f'{total_trainable_params:,} model training parameters.')

        # if some parameters are trainable, print the key name and corresponding parameter number
        if total_params != total_trainable_params:
            for name, param in self._network.backbone.named_parameters():
                if param.requires_grad:
                    logging.info("{}: {}".format(name, param.numel()))

    def _unwrapped_network(self):
        """Return the network without its ``nn.DataParallel`` wrapper, if any."""
        if isinstance(self._network, nn.DataParallel):
            return self._network.module
        return self._network

    def after_task(self):
        """Mark every class trained so far as known."""
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager, **kwargs):
        """Build the loaders for the next task and train on it.

        Args:
            data_manager: source of datasets and per-task class counts.
            **kwargs: ``n_tasks`` (default 10), forwarded to ``OnlineSampler``
                in the blurry setting.
        """
        n_tasks = kwargs.get('n_tasks', 10)
        self._cur_task += 1
        if self.args['blurry']:
            # Si-Blurry: any class may appear in any task.
            self._total_classes = self.args['nb_classes']
            train_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="train", mode="train")
        else:
            self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
            logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))
            train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train", mode="train")
        # Evaluation covers every class up to _total_classes; loaded once per task.
        self.test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
        self.train_dataset = train_dataset

        if self.args['blurry']:
            self.train_dataset = IndexedDataset(self.train_dataset)
            # NOTE(review): the positional OnlineSampler arguments (10, 50, False, 1)
            # are hard-coded; confirm their meaning in utils_online.si_blurry.
            self.train_sampler = OnlineSampler(self.train_dataset, n_tasks, 10, 50, self.args['seed'], False, 1)
            self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=self.train_sampler, num_workers=num_workers, pin_memory=True)
            self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
            self.train_sampler.set_task(self._cur_task)
        else:
            self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
            self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, drop_last=False, num_workers=num_workers)

        self.data_manager = data_manager

        if len(self._multiple_gpus) > 1:
            print('Multiple GPUs')
            self._network = nn.DataParallel(self._network, self._multiple_gpus)

        # A fixed pool of 5 prompts is created at the first task and reused,
        # never frozen, afterwards. The variable per-task allocation of
        # num_new_prompts is currently not used.
        if self._cur_task == 0:
            old_num_k = 0
            new_num_k = 5
            self._unwrapped_network().backbone.e_prompt.process_new_task(old_num_k, new_num_k)

        self._train(self.train_loader, self.test_loader)

        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def _train(self, train_loader, test_loader):
        """Move the network to the device, build the optimizer at task 0, and train."""
        self._network.to(self._device)

        # One optimizer for the whole run; no LR scheduler is used.
        if self._cur_task == 0:
            self.optim = self.get_optimizer()
            self.scheduler = None

        self._init_train(train_loader, test_loader)

    def _adam_param_groups(self):
        """Adam parameter groups: norm layers get 0.005 * init_lr after the first task."""
        if self._cur_task > 0:
            not_n_params = []
            n_params = []
            for n, p in self._network.named_parameters():
                if n.find('norm1') >= 0 or n.find('norm2') >= 0 or n.startswith('norm') or n.find('fc_norm') >= 0:
                    n_params.append(p)
                else:
                    not_n_params.append(p)
            return [{'params': not_n_params, 'lr': self.init_lr},
                    {'params': n_params, 'lr': 0.005 * self.init_lr}]
        return [{'params': list(self._network.parameters()), 'lr': self.init_lr}]

    def get_optimizer(self):
        """Build the optimizer named by ``args['optimizer']``.

        ``'sgd'`` and ``'adamw'`` optimise only parameters with
        ``requires_grad``. ``'adam'`` uses the groups from ``_adam_param_groups``.

        Raises:
            ValueError: if ``args['optimizer']`` is not sgd, adam or adamw.
        """
        name = self.args['optimizer']
        if name == 'sgd':
            return optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                momentum=0.9,
                lr=self.init_lr,
                weight_decay=self.weight_decay
            )
        if name == 'adam':
            return optim.Adam(
                self._adam_param_groups(),
                lr=self.init_lr,
                weight_decay=self.weight_decay
            )
        if name == 'adamw':
            return optim.AdamW(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                lr=self.init_lr,
                weight_decay=self.weight_decay
            )
        raise ValueError("Unsupported optimizer: {}".format(name))

    def get_scheduler(self, optimizer):
        """Build the LR scheduler named by ``args['scheduler']``.

        Returns:
            A scheduler, or None for ``'constant'``.

        Raises:
            ValueError: if ``args['scheduler']`` is not cosine, steplr or constant.
        """
        kind = self.args["scheduler"]
        if kind == 'cosine':
            return optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=self.args['tuned_epoch'], eta_min=self.min_lr)
        if kind == 'steplr':
            return optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=self.args["init_milestones"], gamma=self.args["init_lr_decay"])
        if kind == 'constant':
            return None
        raise ValueError("Unsupported scheduler: {}".format(kind))

    def _init_prompt(self, optimizer):
        """Placeholder hook for new-task prompt initialisation; currently does nothing."""
        args = self.args
        model = self._network.backbone
        task_id = self._cur_task

    @staticmethod
    def _mask_absent_classes(logits, present, num_classes):
        """In place, set ``logits[:, c] = -inf`` for each ``c < num_classes`` not in ``present``.

        Same columns as ``[c for c in range(num_classes) if c not in present]``,
        but the index is built with one vectorised op instead of one tensor
        comparison per class.
        """
        keep = torch.zeros(num_classes, dtype=torch.bool, device=logits.device)
        labels = present.long().to(logits.device)
        labels = labels[(labels >= 0) & (labels < num_classes)]
        keep[labels] = True
        absent = torch.nonzero(~keep, as_tuple=True)[0]
        logits[:, absent] = float('-inf')

    def _reweight_gradients(self):
        """Rescale each parameter's gradient by an adaptive per-parameter weight.

        The first time a parameter has a gradient, its weight is set to 1.0 and
        its moments to 0. Afterwards:
        - the gradient is normalised with bias-corrected first/second moments,
          using ``self.step`` for the bias correction;
        - the weight grows by ``gamma * normalised_grad * previous_grad`` and
          is clamped to ``[0, args['clamp']]``;
        - ``param.grad`` is multiplied by the weight.

        Returns:
            dict mapping ``"grad_w<idx>"`` to the mean weight of each reweighted parameter.
        """
        stats = {}
        for idx, (_, param) in enumerate(self._network.named_parameters()):
            curr_grad = param.grad
            if curr_grad is None:
                # Keep the last real gradient: storing None here would make the
                # product below fail once this parameter gets a gradient again.
                continue
            key = str(idx)
            if key in self.grad_weight:
                self.m[key] = self.beta1 * self.m[key] + (1 - self.beta1) * curr_grad
                self.v[key] = self.beta2 * self.v[key] + (1 - self.beta2) * curr_grad ** 2
                m_hat = self.m[key] / (1 - self.beta1 ** self.step)
                v_hat = self.v[key] / (1 - self.beta2 ** self.step)
                curr_grad = m_hat / (torch.sqrt(v_hat) + 1e-8)

                weight = self.grad_weight[key] + self.args['gamma'] * curr_grad * self.old_grad[key]
                weight = torch.clamp(weight, 0, self.args['clamp'])
                self.grad_weight[key] = weight

                param.grad = weight * param.grad
                stats["grad_w" + key] = weight.mean().item()
            else:
                self.grad_weight[key] = 1.0
                self.m[key] = 0.0
                self.v[key] = 0.0
            self.old_grad[key] = curr_grad
        return stats

    def _init_train(self, train_loader, test_loader):
        """Train on the current task for ``args['tuned_epoch']`` epochs.

        Each step:
        - computes cross-entropy over the logits of classes present in the batch;
        - adds ``recab_coef`` times a recalibration loss on buffered features;
        - reweights and optionally clips the gradients, then applies them;
        - pushes the batch features into the buffer.

        Test accuracy is computed after the last epoch. In blurry mode the test
        loader is rebuilt from the classes found in the buffer.
        """
        info = None
        prog_bar = tqdm(range(self.args['tuned_epoch']))
        for epoch in prog_bar:
            self._unwrapped_network().backbone.train()

            losses = 0.0
            correct, total = 0, 0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)

                output = self._network(inputs, task_id=-1, train=True)
                logits = output["logits"][:, :self._total_classes]
                f = output['pre_logits']

                # Only classes present in this batch compete in the softmax.
                self._mask_absent_classes(logits, targets.unique(), self._total_classes)

                mem_y, mem_f = self.buffer.random_retrieve_f(n_imgs=self.args['nb_classes'])

                loss_recab = torch.zeros((), device=self._device)

                # FC recalibration on buffered features, restricted to buffered classes.
                if len(mem_y) > 0:
                    logits_ori = self._unwrapped_network().backbone.forward_head2(mem_f)
                    self._mask_absent_classes(logits_ori, mem_y.unique(), self.args['nb_classes'])
                    logits_ori = logits_ori[:, :self._total_classes]
                    loss_recab = F.cross_entropy(logits_ori, mem_y.long()).mean()

                loss = F.cross_entropy(logits, targets.long()).mean() + self.recab_coef * loss_recab

                self.optim.zero_grad()
                loss.backward()

                # Gradient reweighting; logged once per step rather than per parameter.
                grad_stats = self._reweight_gradients()
                if grad_stats:
                    grad_stats["step"] = self.step
                    wandb.log(grad_stats)
                self.step += 1

                if self.args["use_clip_grad"]:
                    torch.nn.utils.clip_grad_norm_(self._network.parameters(), 1.0)

                self.optim.step()

                loss_value = loss.item()
                losses += loss_value

                if math.isnan(loss_value):
                    # NOTE(review): a diverged run still exits with status 0, as
                    # before; confirm a success code is intended (e.g. for sweeps).
                    logging.error("NaN loss at step %d; stopping.", self.step)
                    sys.exit(0)

                wandb.log({
                    'loss': loss_value,
                    'loss_recab': loss_recab.item(),
                    'step': self.step
                    })

                # buffer update
                self.buffer.update(queries=None, keys=None, values=None, labels=targets.detach(), features=f.detach())

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            if self.scheduler:
                self.scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            if self.args['blurry']:
                # Presumably restricts evaluation to classes present in the buffer.
                _, _, _, mem_y, _, _ = self.buffer.random_retrieve(n_imgs=self.args['nb_classes'])
                test_sampler = OnlineTestSampler(self.test_dataset, mem_y.unique())
                test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, sampler=test_sampler, num_workers=num_workers)
                self.test_loader = test_loader

            if (epoch + 1) == self.args['tuned_epoch']:
                test_acc = self._compute_accuracy(self._network, test_loader)

                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.args['tuned_epoch'],
                    losses / len(train_loader),
                    train_acc,
                    test_acc,
                )
                wandb.log({"task": self._cur_task, "epoch": epoch, "train_acc": train_acc, "test_acc": test_acc})

            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.args['tuned_epoch'],
                    losses / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)

        # info stays None when tuned_epoch is 0.
        if info is not None:
            logging.info(info)

    def _eval_cnn(self, loader):
        """Collect top-``self.topk`` predictions and true labels over ``loader``.

        Returns:
            ``(y_pred, y_true)`` NumPy arrays; ``y_pred`` has shape ``[N, topk]``.
        """
        self._network.eval()
        y_pred, y_true = [], []
        for _, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = self._network(inputs, task_id=self._cur_task)["logits"][:, :self._total_classes]
            predicts = torch.topk(
                outputs, k=self.topk, dim=1, largest=True, sorted=True
            )[
                1
            ]  # [bs, topk]
            y_pred.append(predicts.cpu().numpy())
            y_true.append(targets.cpu().numpy())

        return np.concatenate(y_pred), np.concatenate(y_true)  # [N, topk]

    def _compute_accuracy(self, model, loader):
        """Top-1 accuracy (%) of ``model`` on ``loader``, rounded to 2 decimals.

        Also draws a confusion matrix and a prediction-entropy histogram. Their
        wandb logging is disabled, so both figures are closed right away to
        avoid accumulating open matplotlib figures across calls.
        """
        model.eval()
        correct, total = 0, 0
        cnn_entropy = []
        pred_chunks, true_chunks = [], []
        for _, inputs, targets in loader:
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = model(inputs, task_id=self._cur_task)["logits"][:, :self._total_classes]
            prob = torch.softmax(outputs, dim=1)
            cnn_entropy += Categorical(probs=prob).entropy().tolist()
            predicts = torch.max(outputs, dim=1)[1]
            correct += (predicts.cpu() == targets).sum()
            total += len(targets)
            pred_chunks.append(predicts.cpu())
            true_chunks.append(targets.cpu())

        # Concatenate once instead of growing a tensor every batch.
        pred_all = torch.cat(pred_chunks) if pred_chunks else torch.tensor([])
        true_all = torch.cat(true_chunks) if true_chunks else torch.tensor([])

        cm = confusion_matrix(true_all, pred_all)
        cm_image = plt.matshow(cm)
        # wandb.log({"cnn_cm": cm_image, "task_id": self._cur_task})
        plt.close(cm_image.figure)

        fig = plt.figure()
        plt.hist(cnn_entropy, bins=100, alpha=0.5)
        # wandb.log({"Entropy_dist_cnn": wandb.Image(fig)})
        plt.close(fig)

        return np.around(tensor2numpy(correct) * 100 / total, decimals=2)