from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from dataloader.data_utils import get_dataloader, set_up_datasets
from utils import *
from spiking_vgg9 import *
import json
import argparse
import importlib

# from backbone.spiking_vgg11 import *
# from backbone.spiking_resnet20 import *
# from backbone.spiking_resnet18 import *
# from backbone.spiking_vgg11 import svgg11
# from backbone.spiking_resnet19 import *
# from backbone.spikingformer import *
# from backbone.spiking_vgg16 import *
# from backbone.spiking_resnet34 import *

class SAFA(nn.Module):
    """Trainer for spiking few-shot class-incremental learning.

    Wraps a ``MYNET`` in ``nn.DataParallel`` on CUDA, trains it on the base
    session and then extends its classifier session by session (see
    ``train``).

    NOTE(review): this class subclasses ``nn.Module`` and its ``train()``
    method (no ``mode`` argument) shadows ``nn.Module.train(mode)``, so
    ``SAFA.eval()`` / ``SAFA.train(False)`` do not behave like on a normal
    module. The name is kept because the entry point calls ``trainer.train()``.
    """

    def __init__(self, args):
        super(SAFA, self).__init__()
        self.args = args
        set_save_path(self)
        logger = log_to_file(os.path.join(args.save_path, 'log.txt'))
        self.logger = logger
        # Record the full configuration at the top of the run log.
        for arg, value in vars(args).items():
            self.logger.info(f"{arg}: {value}")
        self.model = MYNET(self.args, mode=self.args.base_mode)
        self.model = nn.DataParallel(self.model, self.args.gpu)
        self.model = self.model.cuda()
        # Average per-batch inference time (ms) written by `test`, one slot per session.
        self.infer_time_sessions = [0.0] * args.sessions
        if self.args.model_dir is not None:
            self.logger.info('Loading init parameters from: %s' % self.args.model_dir)
            self.best_model_dict = torch.load(self.args.model_dir)['params']
        else:
            self.logger.info('random init params')
            if self.args.start_session > 0:
                self.logger.info('WARNING: Random init weights for new sessions!')
            self.best_model_dict = deepcopy(self.model.state_dict())

    def train(self):
        """Run base-session training followed by every incremental session.

        Session 0 trains the whole network with ``LOSS`` (skipped when
        ``args.only_do_incre``), keeping the checkpoint with the best test
        accuracy. Unless ``args.not_data_init`` is set, it then replaces the
        base classifier rows with class-mean embeddings (``replace_base_fc``).
        Each later session writes prototypes for its new classes
        (``update_fc``), projects them with ``subspace_projection`` and
        evaluates on all classes seen so far. Accuracies and timings are
        logged, and the final encoder weights are saved under
        ``args.save_path``.

        Raises:
            ValueError: if ``args.only_do_incre`` is set without
                ``args.model_dir``.
        """
        args = self.args
        logger = self.logger
        max_acc_sessions = [0.0] * args.sessions
        unseen_acc_sessions = [0.0] * args.sessions
        seen_acc_sessions = [0.0] * args.sessions
        # Bound only when the base session is trained in this run. The summary
        # below must cope with them being None (only_do_incre / start_session > 0),
        # which previously crashed with NameError.
        max_acc_epoch = None
        train_time = None
        for session in range(args.start_session, args.sessions):
            train_set, trainloader, testloader = get_dataloader(args, session)
            self.model.load_state_dict(self.best_model_dict)
            if session == 0:  # base session
                if not args.only_do_incre:
                    logger.info(f'new classes for this session:{np.unique(train_set.targets)}')
                    optimizer, scheduler = get_optimizer_scheduler(self, self.model.parameters())
                    # Stateless, so build it once rather than once per batch.
                    criterion = nn.CrossEntropyLoss().cuda()
                    torch.cuda.synchronize()
                    start_time = time.time()
                    for epoch in range(args.epochs_base):
                        epoch_start = time.time()
                        tl, ta = 0.0, 0.0
                        self.model = self.model.train()
                        # standard classification for pretrain
                        tqdm_gen = tqdm(trainloader)
                        for batch in tqdm_gen:
                            data, train_label = [_.cuda() for _ in batch]
                            logits = self.model(data, session=session)
                            # Keep base-class columns only; axis 1 is the axis LOSS
                            # iterates over (presumably spiking time steps).
                            logits = logits[:, :, :args.base_class]
                            loss = LOSS(logits, train_label, criterion, args.means, args.lamb)
                            logits = logits.mean(1)
                            acc = count_acc(logits, train_label)
                            total_loss = loss
                            lrc = scheduler.get_last_lr()[0]
                            tqdm_gen.set_description('Session 0, epo {}/{}, lrc={:.4f},total loss={:.4f} acc={:.4f}'.format(epoch+1, args.epochs_base, lrc, total_loss.item(), acc))
                            tl += total_loss.item()
                            ta += acc
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()
                        torch.cuda.synchronize()
                        epoch_end = time.time()
                        logger.info(f"Epoch {epoch + 1}/{args.epochs_base} Cost Time: {epoch_end - epoch_start:.2f} s")
                        tl = tl / len(trainloader)
                        ta = ta / len(trainloader)
                        tsl, tsa = self.test(self.model, testloader, epoch, args, session)
                        # save better model
                        if (tsa * 100) >= max_acc_sessions[session]:
                            max_acc_sessions[session] = float('%.3f' % (tsa * 100))
                            max_acc_epoch = epoch
                            save_model_dir = os.path.join(args.save_path, 'session' + str(session) + '_max_acc.pth')
                            torch.save(dict(params=self.model.state_dict()), save_model_dir)
                            torch.save(optimizer.state_dict(), os.path.join(args.save_path, 'optimizer_best.pth'))
                            self.best_model_dict = deepcopy(self.model.state_dict())
                            logger.info('********A better model is found!!**********')
                            logger.info('Saving model to :%s' % save_model_dir)
                        logger.info('best epoch {}, best test acc={:.3f}'.format(max_acc_epoch+1, max_acc_sessions[session]))

                        lrc = scheduler.get_last_lr()[0]
                        logger.info('epoch:%d/%d, lr:%.4f, training_loss:%.5f, training_acc:%.5f, test_loss:%.5f, test_acc:%.5f' % (epoch+1,args.epochs_base, lrc, tl, ta, tsl, tsa))
                        print_config(args)
                        scheduler.step()
                    # Finish base train
                    torch.cuda.synchronize()
                    end_time = time.time()
                    train_time = end_time - start_time
                    logger.info("Total training time: {:.2f} seconds".format(train_time))
                    logger.info('>>> Finish Base Train <<<')
                    logger.info('Session {}, Test Best Epoch {}, best test Acc {:.4f}'.format(session, max_acc_epoch+1, max_acc_sessions[session]))
                else:
                    logger.info('>>> Load Model &&& Finish base train...')
                    # Validate explicitly: `assert` is stripped under `python -O`.
                    if args.model_dir is None:
                        raise ValueError('only_do_incre requires model_dir to point at a trained base model')

                if not args.not_data_init:
                    self.model.load_state_dict(self.best_model_dict)
                    self.model = self.replace_base_fc(train_set, testloader.dataset.transform, self.model, args)
                    best_model_dir = os.path.join(args.save_path, 'session' + str(session) + '_max_acc.pth')
                    logger.info('Replace the fc with average embedding, and save it to :%s' % best_model_dir)
                    self.best_model_dict = deepcopy(self.model.state_dict())
                    torch.save(dict(params=self.model.state_dict()), best_model_dir)
                    # replace_base_fc leaves the module in 'encoder' mode.
                    self.model.module.mode = 'avg_cos'
                    tsl, tsa = self.test(self.model, testloader, 0, args, session)
                    if (tsa * 100) >= max_acc_sessions[session]:
                        max_acc_sessions[session] = float('%.3f' % (tsa * 100))
                        logger.info('The new best test acc of base session={:.3f}'.format(max_acc_sessions[session]))
            # incremental learning sessions
            else:
                logger.info("training session: [%d]" % session)
                self.model.module.mode = self.args.new_mode
                self.model.eval()
                # Build prototypes from un-augmented images.
                trainloader.dataset.transform = testloader.dataset.transform
                # Prototype extraction needs no autograd graph.
                with torch.no_grad():
                    self.model.module.update_fc(trainloader, np.unique(train_set.targets), session)
                self.model.module.subspace_projection(args, session)
                tsl, (seenac, unseenac, avgac) = self.test(self.model, testloader, 0, args, session)
                # update results and save model
                logger.info(f"Seen Accuracy: {seenac * 100:.3f}%")
                logger.info(f"Unseen Accuracy: {unseenac * 100:.3f}%")

                max_acc_sessions[session] = float('%.3f' % (avgac * 100))
                unseen_acc_sessions[session] = float('%.3f' % (unseenac * 100))
                seen_acc_sessions[session] = float('%.3f' % (seenac * 100))
                self.best_model_dict = deepcopy(self.model.state_dict())
                logger.info(f"Session {session} ==> Seen Acc:{seenac * 100:.3f}% "
                            f"Unseen Acc:{unseenac * 100:.3f}% Avg Acc:{max_acc_sessions[session]}")
                logger.info('Session {}, test Acc {:.3f}'.format(session, max_acc_sessions[session]))

        save_all_sessions_model_dir = os.path.join(args.save_path, args.dataset + '_' + args.project + '_' + args.network + '_' + str(args.epochs_base) + '.pth')
        torch.save(dict(params=self.model.module.encoder.state_dict()), save_all_sessions_model_dir)

        if max_acc_epoch is not None:
            logger.info('Base Session Best Epoch {}'.format(max_acc_epoch+1))
        logger.info("max_acc:")
        logger.info(max_acc_sessions)
        logger.info("Seen acc:")
        logger.info(seen_acc_sessions)
        logger.info('Unseen acc:')
        logger.info(unseen_acc_sessions)
        hmeans = harm_mean(seen_acc_sessions, unseen_acc_sessions)
        logger.info('Harmonic mean:')
        logger.info(hmeans)
        if train_time is not None and args.epochs_base > 0:
            logger.info("Total training time: {:.2f} seconds, average training time:{:.2f}".format(train_time, train_time / args.epochs_base))
        logger.info(f'Infer time: {self.infer_time_sessions}')
        logger.info(f' average time: {sum(self.infer_time_sessions)/len(self.infer_time_sessions)}')
        logger.info(f'save path is {args.save_path}')
        print_config(args,logger,is_end=True)
        logger.info(f"beta={args.beta},theta={args.theta}")

    def replace_base_fc(self, trainset, transform, model, args):
        """Overwrite the first ``args.base_class`` fc rows with per-class mean embeddings.

        Embeds ``trainset`` (with ``transform``) in eval mode and averages each
        sample's encoder output over axis 1. Row ``c`` of ``fc.weight``
        (``c < base_class``) becomes the mean embedding of class ``c``.

        Side effects: ``trainset.transform`` is replaced permanently, and
        ``model.module.mode`` is left as ``'encoder'``.

        Returns:
            ``model`` (same object, switched to eval mode).
        """
        model = model.eval()
        trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=args.batch_size_base, num_workers=8, pin_memory=True, shuffle=False)
        trainloader.dataset.transform = transform
        model.module.mode = 'encoder'
        embedding_list = []
        label_list = []
        with torch.no_grad():
            for batch in trainloader:
                data, label = [_.cuda() for _ in batch]
                embedding_list.append(model(data).mean(1).cpu())
                label_list.append(label.cpu())
        embedding_list = torch.cat(embedding_list, dim=0)
        label_list = torch.cat(label_list, dim=0)
        proto_list = torch.stack(
            [embedding_list[label_list == class_index].mean(0) for class_index in range(args.base_class)],
            dim=0)
        model.module.fc.weight.data[:args.base_class] = proto_list
        return model

    def test(self, model, testloader, epoch, args, session):
        """Evaluate ``model`` on the first ``base_class + session * way`` classes.

        Logits are averaged over axis 1 and truncated to the classes seen so
        far. The average per-batch CUDA time (ms) is stored in
        ``self.infer_time_sessions[session]``.

        Returns:
            ``(loss, acc)`` for session 0. For later sessions it returns
            ``(loss, (seen_acc, unseen_acc, acc))``, where seen/unseen are
            means of the confusion-matrix diagonal over base and novel classes.
            This assumes ``confmatrix`` returns per-class accuracies on its
            diagonal — TODO confirm.
        """
        logger = self.logger
        test_class = args.base_class + session * args.way
        model = model.eval()
        vl, va, va5 = 0.0, 0.0, 0.0
        lgt = torch.tensor([])
        lbs = torch.tensor([])
        with torch.no_grad():
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()  # start timing
            for batch in testloader:
                data, test_label = [_.cuda() for _ in batch]
                logits = model(data)
                logits = logits.mean(1)
                logits = logits[:, :test_class]
                loss = F.cross_entropy(logits, test_label)
                acc = count_acc(logits, test_label)
                top5acc = count_acc_topk(logits, test_label)
                vl += loss.item()
                va += acc
                va5 += top5acc
                lgt = torch.cat([lgt, logits.cpu()])
                lbs = torch.cat([lbs, test_label.cpu()])
            end_event.record()  # stop timing
            torch.cuda.synchronize()  # wait for all queued GPU work before reading the events
            total_time_ms = start_event.elapsed_time(end_event)
            average_time_ms = total_time_ms / len(testloader)
            self.infer_time_sessions[session] = average_time_ms
            # Previously `session` was passed twice, so the per-batch slot printed the session index.
            logger.info("session {}: total inference time: {:.4f} ms, each batch inference time: {:.4f} ms".format(session, total_time_ms, average_time_ms))
            vl = vl / len(testloader)
            va = va / len(testloader)
            va5 = va5 / len(testloader)
            logger.info('epo {}/{}, test, loss={:.4f} acc={:.4f}, acc@5={:.4f}'.format(epoch+1, args.epochs_base, vl, va, va5))
            print_config(args)
            lgt = lgt.view(-1, test_class)
            lbs = lbs.view(-1)

            if session > 0:
                save_model_dir = os.path.join(args.save_path, 'session' + str(session) + 'confusion_matrix')
                cm = confmatrix(lgt, lbs, save_model_dir)
                perclassacc = cm.diagonal()
                seenac = np.mean(perclassacc[:args.base_class])
                unseenac = np.mean(perclassacc[args.base_class:])
                logger.info(f"Seen Acc:{seenac}  Unseen Acc:{unseenac}")
                return vl, (seenac, unseenac, va)
            return vl, va

class MYNET(nn.Module):
    """Spiking encoder plus a bias-free linear classifier whose rows act as class prototypes.

    ``self.mode`` selects the forward path:

    * ``'encoder'`` returns the raw encoder output;
    * a mode containing ``'cos'`` returns temperature-scaled cosine
      similarities to the rows of ``fc.weight``;
    * a mode containing ``'dot'`` returns temperature-scaled dot products.
    """

    # (C, H, W) input size per dataset. Only the backbones built from the
    # input size (svgg16, spikingformer) need it.
    _INPUT_SHAPES = {
        'cifar100': (3, 32, 32),
        'mini_imagenet': (3, 84, 84),
        'cub200': (3, 224, 224),
        'n_caltech101': (3, 300, 200),
    }

    def __init__(self, args, mode=None):
        super().__init__()
        self.mode = mode
        self.args = args
        C, H, W = self._INPUT_SHAPES.get(self.args.dataset, (None, None, None))
        network = self.args.network
        if network in ('svgg16', 'spikingformer') and C is None:
            raise ValueError(
                f'No input shape known for dataset {self.args.dataset!r}; '
                f'required by network {network!r}')
        if network == 'svgg9':
            self.encoder = VGG9SNN(time_step=args.time_step, args=args)
            self.num_features = 1024
        elif network == 'svgg5':
            self.encoder = VGG5SNN(time_step=args.time_step, args=args)
            self.num_features = 1024
        elif network == 'svgg16':
            self.encoder = svgg16(C=C, H=H, W=W)
            self.num_features = 4096
        elif network == 'svgg11':
            self.encoder = VGG11SNN(time_step=args.time_step, args=args)
        elif network == 'spikingformer':
            self.encoder = Spikingformer(H=H, W=W, C=C)
        elif network == 'sresnet19':
            self.encoder = resnet19(num_classes=100, time_step=args.time_step)
        elif network == 'sresnet18':
            self.encoder = resnet18(args)
        elif network == 'sresnet20':
            self.encoder = resnet20(num_classes=100, time_step=args.time_step)
        elif network == 'sresnet34':
            self.encoder = resnet34(num_classes=100, time_step=args.time_step)
        else:
            raise ValueError(f'Unknown network: {network!r}')
        # Several branches above never set the embedding width. Fail with a
        # clear message instead of an AttributeError when building `fc`.
        if not hasattr(self, 'num_features'):
            raise ValueError(
                f'num_features is not configured for network {network!r}; '
                'set it to the encoder output width')
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.num_features, self.args.num_classes, bias=False)

    def forward_metric(self, x, session=0):
        """Encode ``x`` and score it against every row of ``fc.weight``.

        Raises:
            ValueError: if ``self.mode`` contains neither ``'cos'`` nor
                ``'dot'``. Previously the raw encoder output was returned
                silently in that case.
        """
        x = self.encode(x, session=session)
        if 'cos' in self.mode:
            x = F.linear(F.normalize(x, p=2, dim=-1), F.normalize(self.fc.weight, p=2, dim=-1))
            return self.args.temperature * x
        if 'dot' in self.mode:
            return self.args.temperature * self.fc(x)
        raise ValueError(f'Unknown mode: {self.mode!r}')

    def encode(self, x, session=0):
        """Run the backbone; it receives ``session`` and ``self.args`` as keyword arguments."""
        return self.encoder(x, session=session, args=self.args)

    def forward(self, input, session=0):
        """Return encoder output in ``'encoder'`` mode, otherwise classifier scores."""
        if self.mode == 'encoder':
            return self.encode(input, session=session)
        return self.forward_metric(input, session=session)

    def update_fc(self, dataloader, class_list, session):
        """Build prototypes for ``class_list`` from every sample in ``dataloader``.

        Each batch is embedded by averaging the encoder output over axis 1.
        Unless ``args.not_data_init`` is set, each listed class's ``fc`` row
        becomes the mean embedding of all its samples (``update_fc_avg``).
        Previously only the last batch was kept, which is correct only for
        single-batch loaders.

        Returns:
            The new prototypes ``[len(class_list), num_features]``. Under
            ``not_data_init`` it returns a Kaiming-uniform tensor of that shape
            instead. NOTE(review): that random tensor is not written into
            ``fc``, and nothing visible here consumes it.

        Raises:
            ValueError: if prototypes are requested but ``dataloader`` yields
                no batches.
        """
        device = self.fc.weight.device
        embeddings, labels = [], []
        with torch.no_grad():  # prototypes only; no autograd graph needed
            for batch in dataloader:
                data, label = [_.to(device) for _ in batch]
                embeddings.append(self.encode(data, session=session).mean(1))
                labels.append(label)
        if self.args.not_data_init:
            new_fc = torch.rand(len(class_list), self.num_features, device=device)
            nn.init.kaiming_uniform_(new_fc, a=math.sqrt(5))
            return new_fc
        if not embeddings:
            raise ValueError('update_fc needs at least one batch to build class prototypes')
        return self.update_fc_avg(torch.cat(embeddings), torch.cat(labels), class_list)

    def update_fc_avg(self, data, label, class_list):
        """Set ``fc.weight[c]`` to the mean of ``data`` rows labelled ``c`` for each ``c`` in ``class_list``.

        Returns:
            The stacked prototypes, in ``class_list`` order.
        """
        new_fc = []
        for class_index in class_list:
            data_index = (label == class_index).nonzero().squeeze(-1)
            proto = data[data_index].mean(0)
            new_fc.append(proto)
            self.fc.weight.data[class_index] = proto
        return torch.stack(new_fc, dim=0)

    def get_logits(self, x, fc):
        """Score ``x`` against ``fc`` using ``args.new_mode`` ('dot' or temperature-scaled 'cos')."""
        if 'dot' in self.args.new_mode:
            return F.linear(x, fc)
        if 'cos' in self.args.new_mode:
            return self.args.temperature * F.linear(F.normalize(x, p=2, dim=-1), F.normalize(fc, p=2, dim=-1))
        raise ValueError(f'Unknown new_mode: {self.args.new_mode!r}')

    def subspace_projection(self, args, session):
        """Pull this session's new-class prototypes toward the span of the base prototypes.

        The rows ``[base_class + (session-1)*way, base_class + session*way)``
        of ``fc.weight`` are processed as follows:

        1. L2-normalise them.
        2. Project them orthogonally onto the row space of the L2-normalised
           base-class rows.
        3. Blend as ``(1 - shift_weight) * proto + shift_weight * projection``.
        4. Re-normalise and write back.

        Meant for ``session >= 1``.
        """
        base_end = args.base_class
        cur_start = args.base_class + (session - 1) * args.way
        cur_end = args.base_class + session * args.way

        # fc.weight is [num_classes, num_features] (nn.Linear(out, in) layout).
        base_protos = F.normalize(self.fc.weight.data[:base_end].detach().cpu(), p=2, dim=-1)  # [B, D]
        cur_protos = F.normalize(self.fc.weight.data[cur_start:cur_end].detach().cpu(), p=2, dim=-1)  # [C, D]

        # Orthogonal projection onto row space of B: C B^T (B B^T)^+ B.
        # pinv tolerates a rank-deficient Gram matrix (e.g. more classes than dims).
        gram_inv = torch.linalg.pinv(base_protos @ base_protos.T)  # [B, B]
        weights = cur_protos @ base_protos.T  # [C, B]
        proj = weights @ gram_inv @ base_protos  # [C, D]

        updated = (1 - args.shift_weight) * cur_protos + args.shift_weight * proj
        updated = F.normalize(updated, p=2, dim=-1)
        self.fc.weight.data[cur_start:cur_end] = updated.to(self.fc.weight.device)

def LOSS(outputs, labels, criterion, means, lamb):
    """Blend a per-step classification loss with an MSE regulariser.

    Args:
        outputs: logits indexed as ``outputs[:, t, :]``. Axis 1 is iterated,
            presumably the spiking time-step axis — confirm with the encoder.
        labels: targets handed to ``criterion`` at every step.
        criterion: classification loss applied to each step's logits.
        means: constant that every output element is pulled toward by the MSE term.
        lamb: weight of the MSE term; the averaged step loss gets ``1 - lamb``.

    Returns:
        ``(1 - lamb) * mean_t criterion(outputs[:, t], labels)
        + lamb * mse(outputs, means)``.

    NOTE(review): this has the form of the TET loss (Deng et al., 2022).
    """
    num_steps = outputs.size(1)
    step_loss = 0
    for step in range(num_steps):
        step_loss = step_loss + criterion(outputs[:, step, :], labels)
    step_loss = step_loss / num_steps
    anchor = torch.full_like(outputs, means)
    regulariser = torch.nn.functional.mse_loss(outputs, anchor)
    return (1 - lamb) * step_loss + lamb * regulariser


def set_seed(seed):
    """Configure RNG seeding and cuDNN flags.

    ``seed == 0`` means "do not seed": only cuDNN autotuning is enabled.
    Any other value seeds Python ``random``, NumPy and torch (CPU and all CUDA
    devices), and forces deterministic, non-autotuned cuDNN kernels.
    """
    if seed != 0:
        print('manual seed:', seed)
        for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
            seeder(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        return
    print(' random seed')
    torch.backends.cudnn.benchmark = True


# def set_gpu(args):
#     gpu_list = [int(x) for x in args.gpu.split(',')]
#     print('use gpu:', gpu_list)
#     os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
#     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
#     return gpu_list.__len__()

def get_optimizer_scheduler(self, optimize_parameters=None):
    """Create the Adam optimizer and LR scheduler for base-session training.

    ``self`` is the trainer; only ``self.args`` is read. ``args.schedule``
    selects one of:

    * ``'Step'`` — StepLR;
    * ``'Milestone'`` — MultiStepLR;
    * ``'Cosine'`` — CosineAnnealingLR over ``args.epochs_base`` epochs.

    Returns:
        ``(optimizer, scheduler)``.

    Raises:
        ValueError: for any other ``args.schedule``. This previously surfaced
            as an UnboundLocalError on ``scheduler``.
    """
    args = self.args
    optimizer = torch.optim.Adam(optimize_parameters, args.lr_base, weight_decay=args.decay)
    if args.schedule == 'Step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step, gamma=args.gamma)
    elif args.schedule == 'Milestone':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones,
                                                         gamma=args.gamma)
    elif args.schedule == 'Cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs_base)
    else:
        raise ValueError(
            f"Unknown schedule {args.schedule!r}; expected 'Step', 'Milestone' or 'Cosine'")
    return optimizer, scheduler


def count_acc(logits, label):
    """Return the top-1 accuracy of ``logits`` (argmax over dim 1) against ``label`` as a Python float."""
    pred = torch.argmax(logits, dim=1)
    # Average on the inputs' own device. The legacy `torch.cuda.FloatTensor`
    # cast copied CPU results to the GPU just to take a mean.
    return (pred == label).float().mean().item()

def setup_parser():
    """Build the CLI parser; its only option is ``--config``, the JSON settings path."""
    parser = argparse.ArgumentParser(
        description='Reproduce of multiple continual learning algorithms.')
    parser.add_argument(
        '--config',
        type=str,
        default='./safa.json',
        help='Json file of settings.',
    )
    return parser

class Args:
    """Attribute-style view of a settings dict, e.g. ``args.lr_base`` instead of ``args['lr_base']``."""

    def __init__(self, args_dict):
        self.__dict__.update(args_dict)

    def __repr__(self):
        # Lets `print(args)` in the entry point show the merged settings
        # instead of the default `<Args object at 0x...>`.
        fields = ', '.join(f'{key}={value!r}' for key, value in vars(self).items())
        return f'{type(self).__name__}({fields})'

if __name__ == '__main__':
    # CLI first, then the JSON file's top-level keys, then its per-dataset
    # overrides found under config[<dataset>].
    args = setup_parser().parse_args()
    with open(args.config) as data_file:
        param = json.load(data_file)
    args = vars(args)
    args.update(param)
    dataset_overrides = args["config"][args["dataset"]]
    args.update(dataset_overrides)

    args = Args(args)
    print(args)

    args = set_up_datasets(args)
    set_seed(args.seed)

    # The trainer class is named after the project module, upper-cased (e.g. safa.SAFA).
    args.class_name = args.project.upper()
    trainer_module = importlib.import_module('%s' % args.project)
    trainer = getattr(trainer_module, args.class_name)(args)
    trainer.train()
