from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from dataloader.data_utils import get_dataloader
from utils import *
# from backbone.svgg9 import svgg9
# from backbone.svgg16 import svgg16
# from backbone.svgg11 import svgg11
# from backbone.zoosvgg5 import svgg5
# from backbone.zoosvgg9 import svgg9
# from backbone.zoo_vgg5_model import *
from backbone.sparsity_svgg5 import *
# from backbone.zoo_vgg9_model import *
from backbone.sparsity_svgg9 import *
# from backbone.zoo_vgg_model import *
from backbone.sparsity_svgg11 import *
# from backbone.zoo_sresnet20_model import *
from backbone.sparsity_sresnet20 import *
from backbone.sresnet18 import *
from backbone.zoosvgg11 import svgg11
from backbone.zoosresnet19 import *
from backbone.zoospikingformer import *
from backbone.zoosvgg16 import *
from backbone.zoosresnet34 import *
from spikingjelly.activation_based.functional import reset_net
# This file combines TEEN with pruning (sparsity).
# Note: pruning on its own did not work well.
# Update: TET loss + TEEN
# Update: LocalZO + TEEN
EPS=1E-7
class SPARSITY(nn.Module):
    def __init__(self, args):
        """Set up logging, the DataParallel-wrapped network and the initial weights.

        Args:
            args: experiment namespace. This method reads ``save_path``,
                ``base_mode``, ``gpu``, ``sessions``, ``model_dir`` and
                ``start_session``; every entry is also dumped to the log.
        """
        super().__init__()
        self.args = args
        # Must run before the logger is created: the log file is placed under
        # args.save_path (presumably prepared by set_save_path -- confirm in utils).
        set_save_path(self)
        self.logger = log_to_file(os.path.join(args.save_path, 'log.txt'))
        for name, value in vars(args).items():
            self.logger.info(f"{name}: {value}")

        network = MYNET(args, mode=args.base_mode)
        self.model = nn.DataParallel(network, args.gpu).cuda()
        # One average per-batch inference time (ms) per session, filled in by test().
        self.infer_time_sessions = [0.0 for _ in range(args.sessions)]

        if args.model_dir is None:
            self.logger.info('random init params')
            if args.start_session > 0:
                self.logger.info('WARING: Random init weights for new sessions!')
            self.best_model_dict = deepcopy(self.model.state_dict())
        else:
            self.logger.info('Loading init parameters from: %s' % args.model_dir)
            self.best_model_dict = torch.load(args.model_dir)['params']

    def train(self, ):
        """Run every session from ``args.start_session`` up to ``args.sessions - 1``.

        Session 0 trains on the base classes (unless ``args.only_do_incre``) and,
        unless ``args.not_data_init``, replaces the base rows of the classifier
        with class-mean embeddings. Later sessions switch the model to
        ``args.new_mode``, update and calibrate prototypes when
        ``args.soft_mode == 'soft_proto'``, and evaluate.

        At the end the encoder weights are saved under ``args.save_path`` and the
        per-session accuracies, harmonic means and timings are logged.
        """
        args = self.args
        logger = self.logger
        max_acc_sessions = [0.0] * args.sessions
        unseen_acc_sessions = [0.0] * args.sessions
        seen_acc_sessions = [0.0] * args.sessions
        # These stay None when base training does not run (only_do_incre,
        # start_session > 0 or epochs_base == 0). The summary at the end checks
        # for that instead of failing with a NameError.
        max_acc_epoch = None
        base_train_seconds = None
        for session in range(args.start_session, args.sessions):
            train_set, trainloader, testloader = get_dataloader(args, session)
            self.model.load_state_dict(self.best_model_dict)
            if session == 0:  # base session
                if not args.only_do_incre:
                    max_acc_epoch, base_train_seconds = self._train_base_session(
                        session, train_set, trainloader, testloader, max_acc_sessions)
                else:
                    logger.info('>>> Load Model &&& Finish base train...')
                    assert args.model_dir is not None

                if not args.not_data_init:
                    self.model.load_state_dict(self.best_model_dict)
                    self.model = self.replace_base_fc(train_set, testloader.dataset.transform, self.model, args)
                    best_model_dir = os.path.join(args.save_path, 'session' + str(session) + '_max_acc.pth')
                    logger.info('Replace the fc with average embedding, and save it to :%s' % best_model_dir)
                    self.best_model_dict = deepcopy(self.model.state_dict())
                    torch.save(dict(params=self.model.state_dict()), best_model_dir)
                    # replace_base_fc leaves the module in 'encoder' mode; switch to
                    # cosine classification before evaluating the prototypes.
                    self.model.module.mode = 'avg_cos'
                    tsl, tsa = self.test(self.model, testloader, 0, args, session)
                    if (tsa * 100) >= max_acc_sessions[session]:
                        max_acc_sessions[session] = float('%.3f' % (tsa * 100))
                        logger.info('The new best test acc of base session={:.3f}'.format(max_acc_sessions[session]))
            else:  # incremental sessions
                logger.info("training session: [%d]" % session)
                self.model.module.mode = self.args.new_mode
                self.model.eval()
                # The new-class samples are embedded with the test-time transform.
                trainloader.dataset.transform = testloader.dataset.transform
                if args.soft_mode == 'soft_proto':
                    self.model.module.update_fc(trainloader, np.unique(train_set.targets), session)
                    self.model.module.soft_calibration(args, session)
                # Any other soft_mode falls through without touching the prototypes
                # (the former NotImplementedError branch was deliberately disabled).
                tsl, (seenac, unseenac, avgac) = self.test(self.model, testloader, 0, args, session)
                # update results and keep the current weights for the next session
                logger.info(f"Seen Accuracy: {seenac * 100:.3f}%")
                logger.info(f"Unseen Accuracy: {unseenac * 100:.3f}%")
                max_acc_sessions[session] = float('%.3f' % (avgac * 100))
                unseen_acc_sessions[session] = float('%.3f' % (unseenac * 100))
                seen_acc_sessions[session] = float('%.3f' % (seenac * 100))
                self.best_model_dict = deepcopy(self.model.state_dict())
                logger.info(f"Session {session} ==> Seen Acc:{seenac * 100:.3f}%"
                             f"Unseen Acc:{unseenac * 100:.3f}% Avg Acc:{max_acc_sessions[session]}")
                logger.info('Session {}, test Acc {:.3f}'.format(session, max_acc_sessions[session]))

        # All sessions finished: save the encoder and log the summary.
        save_all_sessions_model_dir = os.path.join(args.save_path, args.dataset + '_' + args.project + '_' + args.network + '_' + str(args.epochs_base) + '.pth')
        torch.save(dict(params=self.model.module.encoder.state_dict()), save_all_sessions_model_dir)

        if max_acc_epoch is not None:
            logger.info('Base Session Best Epoch {}'.format(max_acc_epoch+1))
        else:
            logger.info('Base Session Best Epoch: n/a (base training was not run)')
        logger.info("max_acc:")
        logger.info(max_acc_sessions)
        logger.info("Seen acc:")
        logger.info(seen_acc_sessions)
        logger.info('Unseen acc:')
        logger.info(unseen_acc_sessions)
        hmeans = harm_mean(seen_acc_sessions, unseen_acc_sessions)
        logger.info('Harmonic mean:')
        logger.info(hmeans)
        if base_train_seconds is not None and args.epochs_base > 0:
            # The average used to be formatted with '{:2f}' (a width, not a
            # precision); '{:.2f}' matches the total next to it.
            logger.info("Total training time: {:.2f} seconds, average training time:{:.2f}".format(base_train_seconds, base_train_seconds / args.epochs_base))
        else:
            logger.info('Total training time: n/a (base training was not run)')
        logger.info(f'Infer time: {self.infer_time_sessions}')
        logger.info(f' average time: {sum(self.infer_time_sessions)/len(self.infer_time_sessions)}')
        logger.info(f'save path is {args.save_path}')
        print_config(args,logger,is_end=True)
        logger.info(f"ZO的样本数=5")
        logger.info(f"delta=0.5")
        logger.info(f"改成sparsity，阈值、脉冲发放率根据任务调整的，soft_calibration改成投影的")
        logger.info(f"beta={args.beta},theta={args.theta}")

    def _train_base_session(self, session, train_set, trainloader, testloader, max_acc_sessions):
        """Train on the base classes with the TET loss and keep the best checkpoint.

        Args:
            session: session index (0 when called from ``train``).
            train_set: base-session dataset; only its ``targets`` are read, for logging.
            trainloader: loader yielding ``(data, label)`` batches.
            testloader: loader passed to ``self.test`` after every epoch.
            max_acc_sessions: per-session best accuracies; entry ``session`` is
                updated in place.

        Returns:
            ``(best_epoch, elapsed_seconds)``; ``best_epoch`` is the zero-based
            epoch of the best test accuracy, or ``None`` if no epoch ran.
        """
        args = self.args
        logger = self.logger
        logger.info(f'new classes for this session:{np.unique(train_set.targets)}')
        optimizer, scheduler = get_optimizer_scheduler(self, self.model.parameters())
        # Loop invariants: the original rebuilt these for every batch.
        criterion = nn.CrossEntropyLoss().cuda()
        means = 5
        lamb = 0.05
        is_former = args.network == 'spikingformer'
        max_acc_epoch = None
        torch.cuda.synchronize()
        start_time = time.time()
        for epoch in range(args.epochs_base):
            epoch_start = time.time()
            tl, ta = 0.0, 0.0
            self.model = self.model.train()
            # standard classification for pretrain
            tqdm_gen = tqdm(trainloader)
            for batch in tqdm_gen:
                data, train_label = [_.cuda() for _ in batch]
                logits = self.model(data, session=session)
                logits = logits[:, :, :args.base_class]
                if is_former:
                    # spikingformer outputs are averaged over axis 0, the others over axis 1
                    loss = TET_loss_former(logits, train_label, criterion, means, lamb)
                    logits = logits.mean(0)
                else:
                    loss = TET_loss(logits, train_label, criterion, means, lamb)
                    logits = logits.mean(1)

                acc = count_acc(logits, train_label)
                loss_value = loss.item()
                lrc = scheduler.get_last_lr()[0]
                tqdm_gen.set_description('Session 0, epo {}/{}, lrc={:.4f},total loss={:.4f} acc={:.4f}'.format(epoch+1, args.epochs_base, lrc, loss_value, acc))
                tl += loss_value
                ta += acc
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            torch.cuda.synchronize()
            epoch_end = time.time()
            logger.info(f"Epoch {epoch + 1}/{args.epochs_base} Cost Time: {epoch_end - epoch_start:.2f} s")
            tl = tl / len(trainloader)
            ta = ta / len(trainloader)
            tsl, tsa = self.test(self.model, testloader, epoch, args, session)
            # save better model
            if (tsa * 100) >= max_acc_sessions[session]:
                max_acc_sessions[session] = float('%.3f' % (tsa * 100))
                max_acc_epoch = epoch
                save_model_dir = os.path.join(args.save_path, 'session' + str(session) + '_max_acc.pth')
                torch.save(dict(params=self.model.state_dict()), save_model_dir)
                torch.save(optimizer.state_dict(), os.path.join(args.save_path, 'optimizer_best.pth'))
                self.best_model_dict = deepcopy(self.model.state_dict())
                logger.info('********A better model is found!!**********')
                logger.info('Saving model to :%s' % save_model_dir)
            logger.info('best epoch {}, best test acc={:.3f}'.format(max_acc_epoch+1, max_acc_sessions[session]))
            lrc = scheduler.get_last_lr()[0]
            logger.info('epoch:%d/%d, lr:%.4f, training_loss:%.5f, training_acc:%.5f, test_loss:%.5f, test_acc:%.5f' % (epoch+1,args.epochs_base, lrc, tl, ta, tsl, tsa))
            print_config(args)
            scheduler.step()
        # Finish base train
        torch.cuda.synchronize()
        elapsed = time.time() - start_time
        logger.info("Total training time: {:.2f} seconds".format(elapsed))
        logger.info('>>> Finish Base Train <<<')
        if max_acc_epoch is not None:
            logger.info('Session {}, Test Best Epoch {}, best test Acc {:.4f}'.format(session, max_acc_epoch+1, max_acc_sessions[session]))
        return max_acc_epoch, elapsed

    def replace_base_fc(self, trainset, transform, model, args):
        """Overwrite the base-class classifier rows with class-mean embeddings.

        Args:
            trainset: base-session dataset; its ``transform`` attribute is
                replaced by ``transform`` as a side effect.
            transform: transform applied while extracting embeddings.
            model: DataParallel-wrapped network; left in eval mode with
                ``model.module.mode == 'encoder'`` on return.
            args: reads ``batch_size_base`` and ``base_class``.

        Returns:
            The same ``model`` with ``fc.weight[:args.base_class]`` replaced.
        """
        model = model.eval()
        loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=args.batch_size_base, num_workers=8, pin_memory=True, shuffle=False)
        loader.dataset.transform = transform
        # spikingformer embeddings are averaged over axis 0, all others over axis 1.
        time_axis = 0 if self.args.network == 'spikingformer' else 1

        feature_chunks = []
        target_chunks = []
        with torch.no_grad():
            for batch in loader:
                inputs, targets = (item.cuda() for item in batch)
                # 'encoder' mode makes the forward pass return raw features.
                model.module.mode = 'encoder'
                pooled = model(inputs).mean(time_axis)
                feature_chunks.append(pooled.cpu())
                target_chunks.append(targets.cpu())

        all_features = torch.cat(feature_chunks, dim=0)
        all_targets = torch.cat(target_chunks, dim=0)
        # One prototype per base class: the mean embedding of its samples.
        prototypes = torch.stack(
            [
                all_features[(all_targets == cls).nonzero().squeeze(-1)].mean(0)
                for cls in range(args.base_class)
            ],
            dim=0,
        )
        model.module.fc.weight.data[:args.base_class] = prototypes
        return model

    def test(self, model, testloader, epoch, args, session):
        """Evaluate ``model`` on ``testloader`` over the classes seen up to ``session``.

        Logits are averaged over the time axis and truncated to the first
        ``args.base_class + session * args.way`` classes. The average per-batch
        inference time (ms) is stored in ``self.infer_time_sessions[session]``.

        Args:
            model: network to evaluate; switched to eval mode.
            testloader: loader yielding ``(data, label)`` batches.
            epoch: zero-based epoch index, used for logging only.
            args: reads ``base_class``, ``way``, ``network``, ``epochs_base``
                and ``save_path``.
            session: current session index.

        Returns:
            ``(loss, acc)`` for session 0; otherwise
            ``(loss, (seen_acc, unseen_acc, acc))`` where seen/unseen are means of
            the confusion-matrix diagonal over base and new classes respectively.
        """
        logger = self.logger
        test_class = args.base_class + session * args.way
        model = model.eval()
        # spikingformer logits are averaged over axis 0, all others over axis 1.
        time_axis = 0 if args.network == 'spikingformer' else 1
        vl, va, va5 = 0.0, 0.0, 0.0
        # Collected per batch and concatenated once (the original re-concatenated
        # the growing tensors on every batch). The empty float tensor seed keeps
        # the resulting dtypes identical to the original accumulation.
        logit_chunks = [torch.tensor([])]
        label_chunks = [torch.tensor([])]
        with torch.no_grad():
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()  # start timing
            for batch in testloader:
                data, test_label = [_.cuda() for _ in batch]
                # NOTE(review): the model is called without ``session`` here, so the
                # encoder runs with its default session=0 even in later sessions,
                # whereas training and update_fc pass the session -- confirm intent.
                logits = model(data)
                logits = logits.mean(time_axis)
                logits = logits[:, :test_class]
                loss = F.cross_entropy(logits, test_label)
                acc = count_acc(logits, test_label)
                top5acc = count_acc_topk(logits, test_label)
                vl += loss.item()
                va += acc
                va5 += top5acc
                logit_chunks.append(logits.cpu())
                label_chunks.append(test_label.cpu())
            end_event.record()  # stop timing
            torch.cuda.synchronize()  # wait for all queued GPU work to finish
            lgt = torch.cat(logit_chunks)
            lbs = torch.cat(label_chunks)
            total_time_ms = start_event.elapsed_time(end_event)
            average_time_ms = total_time_ms / len(testloader)
            self.infer_time_sessions[session] = average_time_ms
            # Fixed: a stray extra ``session`` argument used to be printed in place
            # of the per-batch time.
            self.logger.info("session {}: total inference time: {:.4f} ms, each batch inference time: {:.4f} ms".format(session, total_time_ms, average_time_ms))
            vl = vl / len(testloader)
            va = va / len(testloader)
            va5 = va5 / len(testloader)
            self.logger.info('epo {}/{}, test, loss={:.4f} acc={:.4f}, acc@5={:.4f}'.format(epoch+1, args.epochs_base, vl, va, va5))
            print_config(args)
            lgt = lgt.view(-1, test_class)
            lbs = lbs.view(-1)

            if session > 0:
                save_model_dir = os.path.join(args.save_path, 'session' + str(session) + 'confusion_matrix')
                # presumably a row-normalised confusion matrix, so its diagonal is
                # the per-class accuracy -- confirm against utils.confmatrix
                cm = confmatrix(lgt, lbs, save_model_dir)
                perclassacc = cm.diagonal()
                seenac = np.mean(perclassacc[:args.base_class])
                unseenac = np.mean(perclassacc[args.base_class:])
                logger.info(f"Seen Acc:{seenac}  Unseen Acc:{unseenac}")
                return vl, (seenac, unseenac, va)
            else:
                return vl, va

class MYNET(nn.Module):
    def __init__(self, args, mode=None):
        """Build the spiking encoder chosen by ``args.network`` and a bias-free classifier.

        Args:
            args: experiment namespace; this method reads ``dataset``, ``network``,
                ``time_step`` and ``num_classes`` (the backbones may read more).
            mode: forward mode. ``'encoder'`` makes ``forward`` return features;
                any other value routes through ``forward_metric``.

        Raises:
            ValueError: if ``args.network`` is not a supported backbone, or if a
                backbone that needs the input shape (``svgg16``, ``spikingformer``)
                is requested for a dataset with no known shape.
        """
        super().__init__()
        self.mode = mode
        self.args = args
        network = args.network

        # (C, H, W) per dataset; only svgg16 and spikingformer consume it, so an
        # unknown dataset is an error only for those two backbones.
        input_shapes = {
            'cifar100': (3, 32, 32),
            'mini_imagenet': (3, 84, 84),
            'cub200': (3, 224, 224),
            'n_caltech101': (3, 300, 200),
        }
        if network in ('svgg16', 'spikingformer'):
            if args.dataset not in input_shapes:
                raise ValueError(
                    f"No input shape known for dataset {args.dataset!r} "
                    f"(required by network {network!r})")
            C, H, W = input_shapes[args.dataset]

        if network == 'svgg9':
            self.encoder = VGG9SNN(time_step=args.time_step, args=args)
            self.num_features = 1024
        elif network == 'svgg5':
            self.encoder = VGG5SNN(time_step=args.time_step, args=args)
            self.num_features = 1024
        elif network == 'svgg16':
            self.encoder = svgg16(C=C, H=H, W=W)
            self.num_features = 4096
        elif network == 'svgg11':
            self.encoder = VGG11SNN(time_step=args.time_step, args=args)
            self.num_features = 4096
        elif network == 'spikingformer':
            self.encoder = Spikingformer(H=H, W=W, C=C)
            self.num_features = 128
        elif network == 'sresnet19':
            self.encoder = resnet19(num_classes=100, time_step=args.time_step)
            self.num_features = 256
        elif network == 'sresnet18':
            self.encoder = resnet18(args)
            self.num_features = 512
        elif network == 'sresnet20':
            self.encoder = resnet20(num_classes=100, time_step=args.time_step)
            self.num_features = 256
        elif network == 'sresnet34':
            self.encoder = resnet34(num_classes=100, time_step=args.time_step)
            self.num_features = 256
        else:
            # Previously this fell through and failed later with an opaque
            # AttributeError on ``self.num_features``.
            raise ValueError(f"Unknown network: {network!r}")

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Each row of fc.weight acts as one class prototype.
        self.fc = nn.Linear(self.num_features, self.args.num_classes, bias=False)

    def forward_metric(self, x, session=0):
        """Encode ``x`` and score the features against the classifier weights.

        With ``'cos'`` in ``self.mode`` the scores are temperature-scaled cosine
        similarities; with ``'dot'`` they are temperature-scaled linear outputs;
        otherwise the encoded features are returned unchanged.
        """
        features = self.encode(x, session=session)
        if 'cos' in self.mode:
            unit_features = F.normalize(features, p=2, dim=-1)
            unit_weights = F.normalize(self.fc.weight, p=2, dim=-1)
            return self.args.temperature * F.linear(unit_features, unit_weights)
        if 'dot' in self.mode:
            return self.args.temperature * self.fc(features)
        return features

    def encode(self, x, session=0):
        """Run the encoder; spatially pool its output for backbones that need it.

        Backbones whose name starts with ``'svgg'``, plus ``sresnet19``,
        ``sresnet20``, ``sresnet34`` and ``spikingformer``, are returned as-is.
        Any other backbone is global-average-pooled to 1x1 and has its last two
        axes squeezed.
        """
        features = self.encoder(x, session=session, args=self.args)
        network = self.args.network
        returns_flat_features = (
            network[:4] == 'svgg'
            or network in ('sresnet19', 'spikingformer', 'sresnet20', 'sresnet34')
        )
        if returns_flat_features:
            return features
        pooled = F.adaptive_avg_pool2d(features, 1)
        return pooled.squeeze(-1).squeeze(-1)

    def forward(self, input, session=0):
        """Return encoder features in ``'encoder'`` mode, classifier scores otherwise.

        Args:
            input: batch passed to the encoder.
            session: session index forwarded to the encoder.
        """
        # The original had a third ``else: raise ValueError`` branch that could
        # never run: every mode is either 'encoder' or not.
        if self.mode == 'encoder':
            return self.encode(input, session=session)
        return self.forward_metric(input, session=session)

    def update_fc(self, dataloader, class_list, session):
        """Embed the new-class samples and derive classifier rows for ``class_list``.

        Embeddings from every batch of ``dataloader`` are collected (the original
        kept only the last batch) and averaged over the time axis. Unless
        ``args.not_data_init`` is set, per-class means are written into
        ``self.fc`` by ``update_fc_avg``.

        Args:
            dataloader: loader yielding ``(data, label)`` batches.
            class_list: class indices introduced in this session.
            session: session index forwarded to the encoder.

        Returns:
            The new prototype tensor, one row per entry of ``class_list``.
            Earlier versions returned ``None``; existing callers ignore the value.

        Raises:
            ValueError: if prototypes are requested but ``dataloader`` is empty.
        """
        # spikingformer embeddings are averaged over axis 0, all others over axis 1;
        # either way axis 0 of the result indexes samples.
        time_axis = 0 if self.args.network == 'spikingformer' else 1
        feature_chunks = []
        label_chunks = []
        for batch in dataloader:
            data, label = [_.cuda() for _ in batch]
            data = self.encode(data, session=session).detach()
            feature_chunks.append(data.mean(time_axis))
            label_chunks.append(label)
        if self.args.not_data_init:
            # NOTE(review): this random initialisation is never written into
            # self.fc, so the classifier rows stay untouched in this branch --
            # confirm whether that is intended.
            new_fc = nn.Parameter(
                torch.rand(len(class_list), self.num_features, device="cuda"),
                requires_grad=True)
            nn.init.kaiming_uniform_(new_fc, a=math.sqrt(5))
        else:
            if not feature_chunks:
                raise ValueError('update_fc got an empty dataloader; cannot compute class prototypes')
            data = torch.cat(feature_chunks, dim=0)
            label = torch.cat(label_chunks, dim=0)
            new_fc = self.update_fc_avg(data, label, class_list)
        return new_fc

    def update_fc_avg(self, data, label, class_list):
        """Write the mean embedding of each class into its classifier row.

        Args:
            data: embeddings indexed by sample along axis 0.
            label: class index of each sample.
            class_list: classes to compute prototypes for.

        Returns:
            The prototypes stacked in ``class_list`` order.
        """
        prototypes = []
        for cls in class_list:
            member_rows = (label == cls).nonzero().squeeze(-1)
            prototype = data[member_rows].mean(0)
            self.fc.weight.data[cls] = prototype
            prototypes.append(prototype)
        return torch.stack(prototypes, dim=0)

    def get_logits(self, x, fc):
        """Score ``x`` against the weight matrix ``fc`` according to ``args.new_mode``.

        ``'dot'`` gives plain linear scores, ``'cos'`` gives temperature-scaled
        cosine similarities (``'dot'`` wins if both appear); any other mode
        returns ``None``.
        """
        mode = self.args.new_mode
        if 'dot' in mode:
            return F.linear(x, fc)
        if 'cos' in mode:
            unit_x = F.normalize(x, p=2, dim=-1)
            unit_fc = F.normalize(fc, p=2, dim=-1)
            return self.args.temperature * F.linear(unit_x, unit_fc)
        return None

    # def soft_calibration(self, args, session):
    #     base_protos = self.fc.weight.data[:args.base_class].detach().cpu().data
    #     # base_protos = F.normalize(base_protos, p=2, dim=-1)
    #     cur_protos = self.fc.weight.data[args.base_class + (session - 1) * args.way: args.base_class + session * args.way].detach().cpu().data
    #     # cur_protos = F.normalize(cur_protos, p=2, dim=-1)
    #     weights = torch.mm(cur_protos, base_protos.T) * args.softmax_t
    #     norm_weights = torch.softmax(weights, dim=1)
    #     delta_protos = torch.matmul(norm_weights, base_protos)
    #     delta_protos = F.normalize(delta_protos, p=2, dim=-1)
    #     updated_protos = (1 - args.shift_weight) * cur_protos + args.shift_weight * delta_protos
    #     self.fc.weight.data[args.base_class + (session - 1) * args.way: args.base_class + session * args.way] = updated_protos

    def soft_calibration(self, args, session):
        """Pull the current session's prototypes towards the span of the base prototypes.

        The L2-normalised new prototypes are blended, with weight
        ``args.shift_weight``, with their pseudo-inverse projection onto the
        L2-normalised base prototypes. The result is re-normalised and written
        back into ``self.fc``.

        Args:
            args: reads ``base_class``, ``way`` and ``shift_weight``.
            session: current session index; selects which ``way`` rows to update.
        """
        # Row range holding the prototypes introduced in this session.
        first_new = args.base_class + (session - 1) * args.way
        last_new = args.base_class + session * args.way

        weights = self.fc.weight.data
        base = F.normalize(weights[0:args.base_class].detach().cpu(), p=2, dim=-1)  # [B, D]
        current = F.normalize(weights[first_new:last_new].detach().cpu(), p=2, dim=-1)  # [C, D]

        # Projection onto the base prototypes: cur @ B^T @ pinv(B @ B^T) @ B
        base_t = base.T  # [D, B]
        gram_pinv = torch.linalg.pinv(base @ base_t)  # [B, B]
        coefficients = current @ base_t  # [C, B]
        projected = coefficients @ gram_pinv @ base  # [C, D]

        blended = (1 - args.shift_weight) * current + args.shift_weight * projected
        calibrated = F.normalize(blended, p=2, dim=-1)
        self.fc.weight.data[first_new:last_new] = calibrated.to(self.fc.weight.device)

def TET_loss(outputs, labels, criterion, means, lamb):
    """Per-time-step classification loss blended with an MSE pull towards ``means``.

    Time is axis 1 of ``outputs``: ``criterion`` is applied to every
    ``outputs[:, t, :]`` and the results are averaged. When ``lamb`` is
    non-zero, the mean squared error between ``outputs`` and the constant
    ``means`` is mixed in with weight ``lamb``.

    Returns:
        ``(1 - lamb) * L_TET + lamb * L_mse``.
    """
    num_steps = outputs.size(1)
    step_losses = (criterion(outputs[:, t, :], labels) for t in range(num_steps))
    temporal_loss = sum(step_losses) / num_steps  # L_TET

    if lamb == 0:
        regulariser = 0
    else:
        target = torch.full_like(outputs, means)
        regulariser = torch.nn.MSELoss()(outputs, target)  # L_mse

    return (1 - lamb) * temporal_loss + lamb * regulariser  # L_Total

def TET_loss_former(outputs, labels, criterion, means, lamb):
    """TET loss for outputs whose leading axis is time.

    ``criterion`` is applied to every ``outputs[t, :, :]`` and averaged over the
    time steps. When ``lamb`` is non-zero, the mean squared error between
    ``outputs`` and the constant ``means`` is mixed in with weight ``lamb``.

    Args:
        outputs: model outputs with time on axis 0.
        labels: targets passed to ``criterion`` for every time step.
        criterion: per-step classification loss.
        means: constant target value for the MSE term.
        lamb: weight of the MSE term; 0 disables it.

    Returns:
        ``(1 - lamb) * L_TET + lamb * L_mse``.
    """
    # The leftover debug prints of the tensor shapes were removed: they ran on
    # every training batch.
    T = outputs.size(0)
    Loss_es = 0

    for t in range(T):
        Loss_es += criterion(outputs[t, :, :], labels)
    Loss_es = Loss_es / T  # L_TET
    if lamb != 0:
        MMDLoss = torch.nn.MSELoss()
        y = torch.zeros_like(outputs).fill_(means)
        Loss_mmd = MMDLoss(outputs, y)  # L_mse
    else:
        Loss_mmd = 0
    return (1 - lamb) * Loss_es + lamb * Loss_mmd  # L_Total