import torch
import torch.nn as nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

import logging
import numpy as np
from tqdm import tqdm
from sklearn.cluster import KMeans

from methods.base import BaseLearner
from utils.toolkit import tensor2numpy, accuracy_domain

from models.clip_prefix_one_prompt_tuning.net import PrefixOnePromptNet

class PrefixPromptTuning(BaseLearner):

    def __init__(self, args):
        super().__init__(args)
        self._network = PrefixOnePromptNet(args)
        
        self.args = args
        self.query_type = args["query_type"]
        self.EPSILON = args["EPSILON"]# 浮点数下限
        self.init_epoch = args["init_epoch"]#第一轮训练的轮数
        self.init_lr = args["init_lr"]#第一轮训练的学习率
        self.init_lr_decay = args["init_lr_decay"]#
        self.init_weight_decay = args["init_weight_decay"]#第一轮训练的权重衰减（正则化 防止过拟合）
        self.epochs = args["epochs"] # 训练总次数
        self.lrate = args["lrate"] # 学习率
        self.lrate_decay = args["lrate_decay"]#
        #self.batch_size = args["batch_size"] # 数据集批次大小
        self.weight_decay = args["weight_decay"] # 权重衰减
        self.num_workers = args["num_workers"] # 线程数量
        self.knn_k=args["knn_k"]
        self.topk = 1  # origin is 5 预测排序在前topk
        self.class_num = self._network.class_num # 每个域的类数量=increment
        self.all_keys = []

    def after_task(self):
        #self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes#更新当前已训练的类数量
        
        # logging.info('Exemplar size: {}'.format(self.exemplar_size))
    def begin_incremental(self, data_manager):
        # 初始_cur_task = -1, _known_classes=0, _total_classes=0
        self._cur_task += 1 
        # 不同域的相同类也先按不同类``
        self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
        # self._network.update_fc(self._total_classes)
        logging.info('Learning on {}-{}'.format(self._known_classes, self._total_classes))
    def incremental_train(self, data_manager):
        
        # 得到当前任务所需全部训练数据集，第一个任务只有一个域，2个分类各3000张图片，训练集有6000，第二个任务为2400，为第二个域的real和fake
        # get_dataset函数包含了很多图片预处理的transform
        train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source='train', mode='train')
        # self.train_loader仍然还没有将数据加载，还是str表示图片存储地址
        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        # 测试数据集，从0到_total_classes
        
        test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source='test', mode='test')
        self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        # 多卡执行，全部损失和除以数据量，执行在divice0上，所以压力大一点https://zhuanlan.zhihu.com/p/102697821
        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus) 
        self._train(self.train_loader, self.test_loader)
        try:
            if self._network.module.prefix_prompt is not None:
                self._network.module.prefix_prompt.process_task_count()
        except:
            if self._network.prefix_prompt is not None:
                self._network.prefix_prompt.process_task_count()
        self._network.update_fc()
        if self.query_type=='share_p_query':
            self.shareP_clustering(data_manager)
        elif self.query_type=='vit_query':
            self.vit_clustering(self.train_loader)
        else:
            return
        # 训练后的模型参数更新到最新，便于下一轮
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module


    def _train(self, train_loader, test_loader):
        self._network.to(self._device)
        # if self._old_network is not None:
        #     self._old_network.to(self._device)
        paramGrad=0
        # 定向更新分类器和提示器相关参数,设置只有当前任务,设定的分类器和prompt层会更新,其他不更新。str(self._network.numtask - 1)绑定了任务
        for name, param in self._network.named_parameters():
            param.requires_grad_(False)
            #clip_model模型的
            if len(self._multiple_gpus) > 1:
                numtask=self._network.module.numtask
            else:
                numtask=self._network.numtask
            # param.requires_grad_(True)让 backward 可以追踪这个参数并且计算它的梯度，否则模型参数不更新
            # s-prompts是只更新 当前任务的classifier和prompt
            # classifier池, 如果是第一个域，那么更新这个池子中第一个分类器的参数
            #'classifier_pool.0.ctx' 待优化的上下文在PromptLearner里有self.ctx=nn.Parameter
            if "classifier_pool" + "." + str(numtask) in name:
                param.requires_grad_(True)
                paramGrad+=param.numel() #8192=16*512
            # prompt池，如果是第一个域，那么更新这个池子中第一个prompts的参数
            #'prompt_pool.0.weight'、'prompt_pool.1.weight'
            if "share_prompt" in name or "prefix_prompt" in name:
                param.requires_grad_(True)
                paramGrad+=param.numel()#7680=768*10


        # Double check 二次检查谁被更新，easy
        enabled = set()
        for name, param in self._network.named_parameters():
            if param.requires_grad:
                enabled.add(name)
        print(f"Parameters to be updated: {enabled},count:{paramGrad}")
        #{'prompt_pool.0.weight', 'classifier_pool.0.ctx'}

        if self._cur_task==0: # 第一轮训练
            optimizer = optim.SGD(self._network.parameters(), momentum=0.9, lr=self.init_lr, weight_decay=self.init_weight_decay)
            # 学习率衰减控制 余弦退火 调整学习率方法 因为懒… 这样就不用像使用其他类似于StepLR策略 进行调参了，而且总会取得不错的结果
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=self.init_epoch)
            self.run_epoch = self.init_epoch
            self.train_function(train_loader,test_loader,optimizer,scheduler)
        else:
            optimizer = optim.SGD(self._network.parameters(), momentum=0.9, lr=self.lrate, weight_decay=self.weight_decay)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=self.epochs)
            self.run_epoch = self.epochs
            self.train_function(train_loader, test_loader, optimizer, scheduler)


    def train_function(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(self.run_epoch)) # tqdm是进度条，可以监测每一轮的进度
        for _, epoch in enumerate(prog_bar):
            self._network.eval() #　不启用 Batch Normalization 和 Dropout。保证BN层能够用全部训练数据的均值和方差，即测试过程中要保证BN层的均值和方差不变。对于Dropout，model.eval()是利用到了所有网络连接，即不进行随机舍弃神经元
            losses = 0.
            correct, total = 0, 0
            # 第一个参数 (_,)是用来表示输入变量是一个丢弃的变量。每个batchsize=128，取出的inputs就是128，target也是128。最后一个size不足怎么办：单独打包最后一个batch依然送入训练（Pytorch默认方式）
            for i, (_, inputs, targets) in enumerate(train_loader): 
                # 输入特征和标签加载到GPU中 inputs.shape[128,3,224,224]，真实图片经过处理后变成CXHXW,值在（0,1）。targer.shape[128]
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                # 选出标签大于已完成的，nonzero是判定非0
                mask = (targets >= self._known_classes).nonzero().view(-1)
                # torch.index_select 按照给定索引筛选向量，inputs为输入向量，0为在第0维度进行筛选，mask是选取的索引ID列表
                inputs = torch.index_select(inputs, 0, mask) 
                # 这里减去是因为真实标签是0-50
                targets = torch.index_select(targets, 0, mask) - self._known_classes
                # 打印网络信息
                # from torchsummary import summary
                # summary(self._network, inputs)
                # 神经网络前向传播输出的概率值
                logits = self._network(inputs)['logits']
                # 计算交叉熵损失
                loss = F.cross_entropy(logits, targets)#计算交叉熵损失
                # 优化参数 3件套 梯度清零，计算梯度值，权重更新计算的梯度
                optimizer.zero_grad()
                loss.backward() 
                optimizer.step()
                losses += loss.item() # 累计损失
                # 计算准确率
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step() # 调用scheduler.step(),则会改变optimizer中的学习率
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            # 每一个epoch训练结束后都计算准确率
            test_acc = self._compute_accuracy_domain(self._network, test_loader)
            # 进度条内容显示格式
            info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}'.format(
                self._cur_task, epoch + 1, self.run_epoch, losses / len(train_loader), train_acc, test_acc)
            prog_bar.set_description(info)

        logging.info(info)

    
    def shareP_clustering(self,data_manager):
        self.all_keys = []
        for task in range(self._cur_task+1):
            features = []
            train_dataset = data_manager.get_dataset(np.arange(self.class_num*task, self.class_num*(task+1)), source='train', mode='train')
            # self.train_loader仍然还没有将数据加载，还是str表示图片存储地址
            train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                # mask = (targets >= self._known_classes).nonzero().view(-1)
                # inputs = torch.index_select(inputs, 0, mask)
                with torch.no_grad():
                    if isinstance(self._network, nn.DataParallel):
                        feature = self._network.module.extract_share_prompt_vector(inputs)
                    else:
                        feature = self._network.extract_share_prompt_vector(inputs)
            # 每个输入图片转为归一化的向量
            
                features.append(feature)
            features = torch.cat(features, 0).cpu().detach().numpy()
            # 聚类得到五个中心点向量，是每个域五个中心点，这里的dataloader是对一个任务的数据集。
            clustering = KMeans(n_clusters=self.knn_k, random_state=0).fit(features)
            self.all_keys.append(torch.tensor(clustering.cluster_centers_).to(feature.device))

    def vit_clustering(self, dataloader):
        features = []
        for i, (_, inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(self._device), targets.to(self._device)
            mask = (targets >= self._known_classes).nonzero().view(-1)
            inputs = torch.index_select(inputs, 0, mask)
            with torch.no_grad():
                
                if isinstance(self._network, nn.DataParallel):
                    feature = self._network.module.extract_vector(inputs)
                else:
                    feature = self._network.extract_vector(inputs)
            # 每个输入图片转为归一化的向量
            feature = feature / feature.norm(dim=-1, keepdim=True)
            features.append(feature)
        features = torch.cat(features, 0).cpu().detach().numpy()
        # 聚类得到五个中心点向量，是每个域五个中心点，这里的dataloader是对一个任务的数据集。
        clustering = KMeans(n_clusters=self.knn_k, random_state=0).fit(features)
        self.all_keys.append(torch.tensor(clustering.cluster_centers_).to(feature.device))

    def _evaluate(self, y_pred, y_true):
        ret = {}
        grouped = accuracy_domain(y_pred.T[0], y_true, self._known_classes, class_num=self.class_num)
        ret['grouped'] = grouped
        ret['top1'] = grouped['total']
        #ret['top{}'.format(self.topk)] = np.around((y_pred.T == np.tile(y_true, (self.topk, 1))).sum()*100/len(y_true), decimals=2)
        return ret

    def _eval_cnn(self, loader):
        self._network.eval()
        y_pred, y_true = [], []
        for _, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            targets = targets.to(self._device)
            with torch.no_grad():
                if self.query_type=='share_p_query':
                    if isinstance(self._network, nn.DataParallel):
                        feature = self._network.module.extract_share_prompt_vector(inputs)
                    else:
                        feature = self._network.extract_share_prompt_vector(inputs)
                elif self.query_type=='vit_query':
                    if isinstance(self._network, nn.DataParallel):
                        feature = self._network.module.extract_vector(inputs)
                    else:
                        feature = self._network.extract_vector(inputs)
                else:
                    return
                taskselection = []
                # 和各任务的中心对比，找到最近距离的任务去预测
                # self.all_keys存储
                for task_centers in self.all_keys:
                    tmpcentersbatch = []
                    # 计算每一个任务中5个聚类的特征，哪一个最小，添加到taskselection[]
                    for center in task_centers: # 一个任务有五个中心点,计算特征与这5个中心点的差值
                        tmpcentersbatch.append((((feature - center) ** 2) ** 0.5).sum(1))
                    # torch.vstack所有行拼成一个列
                    # min(0)返回该矩阵中每一列的最小值：找出来最近的这个向量
                    taskselection.append(torch.vstack(tmpcentersbatch).min(0)[0])
                # 从每一个任务最小距离中再选出最小
                selection = torch.vstack(taskselection).min(0)[1]

                if isinstance(self._network, nn.DataParallel):
                    outputs = self._network.module.interface(inputs, selection)
                else:
                    outputs = self._network.interface(inputs, selection)
            # torch.topk() 求这个tensor前topk
            # input：一个tensor数据
            # k：指明是得到前k个数据以及其index
            # dim： 指定在哪个维度上排序， 默认是最后一个维度
            # largest：如果为True，按照大到小排序； 如果为False，按照小到大排序
            # sorted：返回的结果按照顺序返回
            # out：可缺省，不要
            predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1]  # [bs, topk]
            y_pred.append(predicts.cpu().numpy())
            y_true.append(targets.cpu().numpy())

        return np.concatenate(y_pred), np.concatenate(y_true)  # [N, topk]

    def _compute_accuracy_domain(self, model, loader):
        
        model.eval()
        correct, total = 0, 0
        for i, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = model(inputs)['logits']

            predicts = torch.max(outputs, dim=1)[1]
            #不同域相同类算预测成功，这里默认选好了域
            correct += ((predicts % self.class_num).cpu() == (targets % self.class_num)).sum()
            total += len(targets)

        return np.around(tensor2numpy(correct) * 100 / total, decimals=2)
