import os
import sys
import warnings

warnings.filterwarnings("ignore")
# from torch.distributed.pipeline.sync.checkpoint import checkpoint
from tqdm import tqdm
import math
from torch import optim, nn
import logging
import copy
from torch.utils.data import DataLoader
import torch
import numpy as np
from utils import factory
from utils.data_manager import DataManager
from torch.distributions.multivariate_normal import MultivariateNormal
from utils.toolkit import count_parameters
from pathlib import Path
import torch.nn.functional as F
def train(args):
    """Entry point: normalise seed/device in ``args`` and launch training.

    ``args['seed']`` and ``args['device']`` are expected to be sequences.
    Only their first elements are used: the seed becomes a scalar and the
    device becomes a one-element list holding ``torch.device('cuda:<id>')``.
    ``_train`` is then run once per gate margin, with ``gate1`` and ``gate2``
    both set to that margin.
    """
    seeds = copy.deepcopy(args['seed'])
    device_ids = copy.deepcopy(args['device'])

    # Multi-seed runs are disabled: train with the first seed on the first GPU.
    args['seed'] = seeds[0]
    args['device'] = [torch.device('cuda:{}'.format(device_ids[0]))]

    # Hyper-parameter sweep hook; currently a single margin value.
    gate_margins = [0.2]
    for gate in gate_margins:
        for gate_key in ("gate1", "gate2"):
            args[gate_key] = gate
        print(f"本轮gate={args['gate1']}")
        _train(args)
def _print_trainable_parameters(network):
    """Print the name and element count of every parameter that requires grad."""
    for name, param in network.named_parameters():
        if param.requires_grad:
            print(f"Parameter name: {name}, Number of elements: {param.numel()}")


def _is_tap_parameter(name, task):
    """Return True if parameter ``name`` must be trainable during TAP of ``task``.

    Trainable are the task tokens and the classifier heads ``clas_w.0`` ..
    ``clas_w.<task>``. The head index is matched as a complete dotted
    component: a plain substring test would let ``clas_w.1`` also match
    ``clas_w.10`` .. ``clas_w.19`` and unfreeze heads of future tasks.
    """
    if 'task_tokens' in name:
        return True
    for head in range(task + 1):
        head_path = f'clas_w.{head}'
        if f'{head_path}.' in name or name.endswith(head_path):
            return True
    return False


def _train(args):
    """Run the full incremental-learning loop for one configuration.

    For every task: train the model incrementally, record per-class feature
    statistics with ``_compute_mean``, and (from the second task on) run the
    task-adaptive prediction (TAP) stage on the classifier heads and task
    tokens. The model is evaluated after every task, and the whole model
    object is saved after the last task (only when there is more than one).

    Side effects: (re)creates the module-level ``cls_mean`` and ``cls_cov``
    dictionaries, configures the root logger, and writes a checkpoint under
    ``args['output_dir']``.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(filename)s] => %(message)s',
    )
    # Per-class statistics shared with _compute_mean and the TAP stage.
    global cls_mean
    global cls_cov
    cls_mean = dict()
    cls_cov = dict()
    _set_random()
    print_args(args)

    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment'])
    model = factory.get_model(args['model_name'], args)
    cnn_curve = {'top1': [], 'top5': []}

    for task in range(data_manager.nb_tasks):
        # 1) Regular incremental training of the main model on this task.
        model.incremental_train(data_manager)
        network = model._network

        print("----TAP前的可训练参数----")
        _print_trainable_parameters(network)

        # 2) Store feature statistics of the classes introduced by this task.
        _compute_mean(model=model, data=data_manager, device=args['device'][0], args=args, taskid=task)

        if task > 0:
            # 3) TAP: train only the task tokens and the heads seen so far.
            print("----TAP前调整可训练参数----")
            for name, param in network.named_parameters():
                param.requires_grad_(_is_tap_parameter(name, task))
            _print_trainable_parameters(network)

            train_task_adaptive_prediction(model, args, args['device'][0], taskid=task, data_manager=data_manager)

            # 4) Restore the set of parameters trained by the main stage.
            print("----TAP后恢复可训练参数----")
            for name, param in network.named_parameters():
                param.requires_grad_('task_tokens' in name or 'keys' in name or 'aux_cla' in name)
            _print_trainable_parameters(network)

            if task == data_manager.nb_tasks - 1:
                # Persist the final model; the directory name encodes the run's hyper-parameters.
                tap_dir = os.path.join(args['output_dir'], f'pl{args["prompt_length"]}_tapshan_margin{args["gate2"]}_ep{args["epochs"]}_tapep{args["tap_epochs"]}')
                Path(tap_dir).mkdir(parents=True, exist_ok=True)
                tap_ckpt_path = os.path.join(tap_dir, f"task{task + 1}_checkpoint.pth")
                torch.save(model, tap_ckpt_path)

        # 5) Evaluate after every task; only the classifier ("CNN") accuracy is reported.
        cnn_accy, _ = model.eval_task()
        logging.info('CNN: {}'.format(cnn_accy['grouped']))
        cnn_curve['top1'].append(cnn_accy['top1'])

        logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
        print('old:{}, new:{}'.format(cnn_accy['grouped']['old'], cnn_accy['grouped']['new']))
        top1_history = np.array(cnn_curve['top1'])
        print("TLO:{}".format(top1_history[-1]))
        print("MEAN:{}".format(round(np.mean(top1_history), 2)))
        model.after_task()

def _set_device(args):
    """Replace ``args['device']`` with a list of ``torch.device`` objects.

    ``args['device']`` is an iterable of device ids. An id of ``-1`` (int or
    the string ``"-1"``) selects the CPU; any other id ``k`` becomes
    ``cuda:k``. The resulting list is stored back into ``args['device']``.
    """
    devices = []
    for device_id in args['device']:
        # Test the individual entry: comparing the whole list with -1 is
        # always False and would make the CPU branch unreachable.
        if device_id in (-1, '-1'):
            devices.append(torch.device('cpu'))
        else:
            devices.append(torch.device('cuda:{}'.format(device_id)))
    args['device'] = devices

def _set_random(seed=1):
    """Seed PyTorch's CPU and CUDA generators and force deterministic cuDNN.

    Args:
        seed: Integer seed passed to the PyTorch generators. Defaults to 1,
            the value that used to be hard-coded, so ``_set_random()`` keeps
            its previous behaviour.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN auto-tuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print("Already set random")

def print_args(args):
    """Log every entry of the configuration mapping as ``key: value`` (INFO level)."""
    for name, setting in args.items():
        logging.info(f'{name}: {setting}')


def _compute_mean(model: torch.nn.Module, data, device: torch.device, args=None, taskid=0):
    """Collect feature statistics for every class of task ``taskid``.

    Class ids ``taskid * init_cls`` .. ``(taskid + 1) * init_cls - 1`` are
    processed one at a time. The class's ``'train'`` split is pushed through
    ``model._network.image_encoder`` together with the two prompt weights
    selected by ``cls_id // init_cls``; the slice ``[:, 0, :]`` of the output
    is used as the feature (presumably the class token -- confirm against the
    encoder). Results are written to the module-level ``cls_mean`` and
    ``cls_cov`` dictionaries, keyed by class id, depending on
    ``args["ca_storage_efficient_method"]``:

    * ``'covariance'``     -- mean vector and full covariance (+1e-4 * I);
    * ``'variance'``       -- mean vector and the diagonal of that covariance;
    * ``'multi-centroid'`` -- lists of 10 k-means cluster means / variances.

    Any other method value stores nothing. ``args`` is required in practice
    even though it defaults to ``None``.
    """
    network = model._network
    network.eval()
    first_cls = taskid * args['init_cls']
    stop_cls = (taskid + 1) * args['init_cls']

    with torch.no_grad():
        for cls_id in tqdm(range(first_cls, stop_cls), desc="Processing classes"):
            cls_dataset = data.get_dataset(np.arange(cls_id, cls_id + 1), source='train',
                                           mode='train', appendent=None)
            loader = DataLoader(cls_dataset, batch_size=args["batch_size"], shuffle=True, num_workers=8)

            # Encode every batch and keep only index 0 along dim 1 as the feature.
            batch_feats = [
                network.image_encoder(
                    images.to(device),
                    instance_tokens=network.ts_prompts_1[cls_id // args['init_cls']].weight,
                    second_pro=network.ts_prompts_2[cls_id // args['init_cls']].weight,
                    returnbeforepool=True,
                )[:, 0, :]
                for _, images, _ in loader
            ]
            feats = torch.cat(batch_feats, dim=0)

            method = args["ca_storage_efficient_method"]
            if method in ('covariance', 'variance'):
                cls_mean[cls_id] = feats.mean(dim=0)
                # A small ridge on the diagonal keeps the covariance well conditioned.
                ridge = (torch.eye(cls_mean[cls_id].shape[-1]) * 1e-4).to(device)
                regularised_cov = torch.cov(feats.T) + ridge
                if method == 'covariance':
                    cls_cov[cls_id] = regularised_cov
                else:
                    cls_cov[cls_id] = torch.diag(regularised_cov)
            elif method == 'multi-centroid':
                from sklearn.cluster import KMeans
                n_clusters = 10
                feats_np = feats.cpu().numpy()
                kmeans = KMeans(n_clusters=n_clusters)
                kmeans.fit(feats_np)
                assignments = kmeans.labels_

                centroid_means, centroid_vars = [], []
                for cluster_idx in range(n_clusters):
                    members = feats_np[assignments == cluster_idx]
                    centroid_means.append(
                        torch.tensor(np.mean(members, axis=0), dtype=torch.float64).to(device))
                    centroid_vars.append(
                        torch.tensor(np.var(members, axis=0), dtype=torch.float64).to(device))

                cls_mean[cls_id] = centroid_means
                cls_cov[cls_id] = centroid_vars

def _sample_tap_features(args, device, num_classes, num_per_class):
    """Draw pseudo-features for classes ``0 .. num_classes - 1`` from the stored statistics.

    Reads the module-level ``cls_mean`` / ``cls_cov`` dictionaries and samples
    ``num_per_class`` vectors per class (per non-degenerate cluster for
    ``'multi-centroid'``) from the corresponding Gaussian. Returns the float
    features and long labels, both on ``device``.

    Raises:
        NotImplementedError: for an unknown ``ca_storage_efficient_method``.
    """
    method = args["ca_storage_efficient_method"]
    feats = []
    labels = []

    if method in ('covariance', 'variance'):
        for c_id in range(num_classes):
            # as_tensor avoids copy-constructing a tensor from a tensor.
            mean = torch.as_tensor(cls_mean[c_id]).detach().to(device=device, dtype=torch.float32)
            cov = cls_cov[c_id].to(device)
            if method == 'variance':
                # Stored as a vector of variances; expand to a diagonal matrix.
                cov = torch.diag(cov)
            dist = MultivariateNormal(mean, cov.float())
            feats.append(dist.sample(sample_shape=(num_per_class,)))
            labels.extend([c_id] * num_per_class)
    elif method == 'multi-centroid':
        for c_id in range(num_classes):
            for mean, var in zip(cls_mean[c_id], cls_cov[c_id]):
                if var.mean() == 0:
                    # Degenerate cluster (no spread): nothing to sample from.
                    continue
                cov = torch.diag(var) + 1e-4 * torch.eye(mean.shape[0]).to(mean.device)
                dist = MultivariateNormal(mean.float(), cov.float())
                feats.append(dist.sample(sample_shape=(num_per_class,)))
                labels.extend([c_id] * num_per_class)
    else:
        raise NotImplementedError

    return (torch.cat(feats, dim=0).float().to(device),
            torch.tensor(labels).long().to(device))


def _tap_loss_weights(p_t, gate1, gate2, gamma):
    """Compute gated per-sample loss weights from target-class probabilities.

    Samples with ``p_t >= gate2`` get weight 1, samples with ``p_t < gate1``
    get weight 0, and samples in between get
    ``(sin(pi * p_t / (2 * gate2) - pi / 2) + 1) ** gamma``.

    Returns:
        ``(weights, admitted)`` where ``admitted`` is the number of samples
        with ``p_t >= gate1`` counted through the two upper bands; it is the
        normaliser of the weighted loss.
    """
    weights = (torch.sin(torch.pi * p_t / (2 * gate2) - torch.pi / 2) + 1) ** gamma
    full_band = p_t >= gate2
    ramp_band = (p_t < gate2) & (p_t >= gate1)
    weights[full_band] = 1
    weights[p_t < gate1] = 0
    admitted = int(full_band.sum().item() + ramp_band.sum().item())
    return weights, admitted


def train_task_adaptive_prediction(model, args, device,taskid=0,data_manager=None):
    """Fine-tune the trainable parameters on features sampled from class statistics.

    Every epoch draws ``5 * batch_size`` pseudo-features per class for all
    ``(taskid + 1) * init_cls`` classes, shuffles them, and runs
    ``taskid * init_cls`` SGD steps on consecutive chunks of that size. The
    logits are the concatenation of ``model._network.clas_w[0..taskid]``
    applied to the features. The loss is a cross-entropy re-weighted by
    ``_tap_loss_weights`` and averaged over the admitted samples.

    Args:
        model: object whose ``_network`` exposes ``parameters()``, ``train()``
            and an indexable ``clas_w`` of heads returning ``{'logits': ...}``.
        args: configuration dict (``tap_epochs``, ``lr``, ``weight_decay``,
            ``init_cls``, ``batch_size``, ``gate1``, ``gate2``, ``gamma``,
            ``ca_storage_efficient_method``).
        device: device on which sampling and training take place.
        taskid: zero-based index of the task that has just been learned.
        data_manager: unused; kept for call-site compatibility.

    A batch in which no sample reaches the gate carries no training signal and
    is skipped; previously it produced a 0/0 loss and aborted the process.
    A non-finite loss still terminates the process with exit code 1.
    """
    network = model._network
    network.train()
    run_epochs = args["tap_epochs"]
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, network.parameters()), momentum=0.9,
                          lr=args["lr"], weight_decay=args["weight_decay"])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=run_epochs)

    # Number of optimisation steps per epoch: one per previously learned class.
    crct_num = max(taskid, 0) * args['init_cls']
    print(f"TAP设置的crct_num={crct_num}")

    num_classes = (taskid + 1) * args['init_cls']
    num_sampled_pcls = args["batch_size"] * 5

    # TODO: efficiency may be improved by wrapping the sampled data in a Dataset
    # and using a distributed sampler.
    for epoch in tqdm(range(run_epochs), desc="Processing epochs"):
        network.train()
        inputs, targets = _sample_tap_features(args, device, num_classes, num_sampled_pcls)
        print(inputs.shape)

        sf_indexes = torch.randperm(inputs.size(0))
        inputs = inputs[sf_indexes]
        targets = targets[sf_indexes]

        running_loss = 0.0
        correct = 0
        total = 0
        skipped = 0
        for batch_idx in range(crct_num):
            start = batch_idx * num_sampled_pcls
            inp = inputs[start:start + num_sampled_pcls]
            tgt = targets[start:start + num_sampled_pcls]

            # One logit block per head learned so far, concatenated class-wise.
            logits = torch.cat(
                [network.clas_w[head](inp)['logits'] for head in range(taskid + 1)], dim=1)

            # Probability assigned to the true class; detached so that the
            # weights act as constants in the backward pass.
            probabilities = torch.softmax(logits, dim=1).detach()
            p_t = probabilities[torch.arange(len(tgt), device=tgt.device), tgt]
            weights, admitted = _tap_loss_weights(p_t, args["gate1"], args["gate2"], args["gamma"])

            if admitted > 0:
                ce_loss = F.cross_entropy(logits, tgt, reduction='none')
                loss = (weights * ce_loss).sum() / float(admitted)
                loss_value = loss.item()

                if not math.isfinite(loss_value):
                    print("Loss is {}, stopping training".format(loss_value))
                    sys.exit(1)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            else:
                # Every sample is gated out (or the chunk is empty): no update.
                skipped += 1
                loss_value = 0.0

            running_loss += loss_value * inp.size(0)
            _, predicted = logits.max(1)
            total += inp.size(0)
            correct += predicted.eq(tgt).sum().item()

        # Guard the averages: no batch runs when crct_num is 0 (e.g. taskid=0).
        epoch_loss = running_loss / total if total else 0.0
        epoch_acc = 100. * correct / total if total else 0.0
        print(f"Train Loss: {epoch_loss:.4f} - Train Acc: {epoch_acc:.2f}%")
        if skipped:
            print(f"TAP: {skipped} batch(es) had no sample above the gate and were skipped")
        scheduler.step()
