import os
import sys
import warnings

warnings.filterwarnings("ignore")
# from torch.distributed.pipeline.sync.checkpoint import checkpoint
from tqdm import tqdm
import math
from torch import optim, nn
import logging
from utils.toolkit import target2onehot, tensor2numpy, accuracy
import copy
from torch.utils.data import DataLoader
import torch
import numpy as np
from utils import factory
from utils.data_manager import DataManager
from torch.distributions.multivariate_normal import MultivariateNormal
from utils.toolkit import count_parameters
from pathlib import Path
import torch.nn.functional as F


def joint_training(args):
    """Run one joint-training pass per configured seed.

    ``args['seed']`` is read as an iterable of seeds. Before each run the
    current seed and the originally configured device entry are written back
    into ``args``; this matters because ``_train`` (through ``_set_device``)
    overwrites ``args['device']`` during a run.

    Args:
        args: Mutable configuration mapping. It is modified in place and, after
            the call, holds the values of the last run.
    """
    seeds = copy.deepcopy(args['seed'])
    configured_device = copy.deepcopy(args['device'])

    for run_seed in seeds:
        # Restore the per-run entries that the previous run may have replaced.
        args['seed'], args['device'] = run_seed, configured_device
        _train(args)


def _joint_train_epoch(model, optimizer):
    """Run one optimisation pass over ``model.train_loader``.

    Every batch is encoded with the prompts stored at index 0 of
    ``ts_prompts_1`` / ``ts_prompts_2`` and classified by the last head in
    ``clas_w``.

    Returns:
        ``(mean_loss, train_acc)``: the loss averaged over batches and the
        training accuracy in percent, rounded to two decimals.
    """
    network = model._network
    network.train()
    loss_sum = 0.
    correct, total = 0, 0
    for _, inputs, targets in model.train_loader:
        inputs, targets = inputs.to(model._device), targets.to(model._device)
        image_features = network.image_encoder(
            inputs,
            instance_tokens=network.ts_prompts_1[0].weight,
            second_pro=network.ts_prompts_2[0].weight,
            returnbeforepool=True,
        )
        # Token at position 0 of the un-pooled output is used as the feature
        # (presumably the class token -- confirm against the encoder).
        feature = image_features[:, 0, :]
        head = network.clas_w[len(network.clas_w) - 1]
        logits = head(feature)['logits']
        loss = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()

        preds = torch.max(logits, dim=1)[1]
        correct += preds.eq(targets).sum().item()
        total += len(targets)

    mean_loss = loss_sum / len(model.train_loader)
    train_acc = np.around(correct * 100 / total, decimals=2)
    return mean_loss, train_acc


def _joint_evaluate(model):
    """Evaluate ``model`` on ``model.test_loader`` without tracking gradients.

    The prompt pair for each sample is selected with
    ``target // args['init_cls']``, i.e. from the ground-truth label.

    Returns:
        ``(preds, labels, correct, total)`` where the first two are
        concatenated numpy arrays over the whole test set.
    """
    network = model._network
    network.eval()
    all_preds, all_labels = [], []
    correct, total = 0, 0
    val_bar = tqdm(model.test_loader, desc="测试 [Val]", leave=False)
    with torch.no_grad():
        for _, inputs, targets in val_bar:
            inputs, targets = inputs.to(model._device), targets.to(model._device)

            task_ids = (targets // model.args["init_cls"]).unsqueeze(1)
            # One stacked prompt tensor per prompt bank, one row per sample.
            gen_p = [
                torch.cat([bank[j].weight.detach().clone().unsqueeze(0) for j in task_ids], dim=0)
                for bank in (network.ts_prompts_1, network.ts_prompts_2)
            ]
            out_logits = network(inputs, gen_p, train=False)

            preds = torch.max(out_logits, dim=1)[1]
            correct += preds.eq(targets).sum().item()
            total += len(targets)
            all_preds.append(preds.cpu().numpy())
            all_labels.append(targets.cpu().numpy())
            val_bar.set_postfix(acc=100. * correct / total if total > 0 else 0.0)
    val_bar.close()
    return np.concatenate(all_preds), np.concatenate(all_labels), correct, total


def _grouped_accuracy(preds, labels, width, legacy_stride=20):
    """Return accuracy (percent, 3 decimals) per consecutive group of classes.

    Groups are ``[start, start + width)``. The stride between groups equals
    ``width`` so that groups neither overlap nor skip classes; when ``width``
    is not positive the historical stride of 20 is kept so the loop stays
    well-defined.
    """
    stride = width if width > 0 else legacy_stride
    accs = []
    # ``+ 1`` so that the group containing the highest label is not dropped
    # when that label is an exact multiple of the stride.
    for start in range(0, int(np.max(labels)) + 1, stride):
        idxes = np.where(np.logical_and(labels >= start, labels < start + width))[0]
        accs.append(np.around((preds[idxes] == labels[idxes]).sum() * 100 / len(idxes), decimals=3))
    return accs


def _train(args):
    """Train a single model on the first task's classes and save it.

    Steps, in order:
      1. Configure logging, seed torch via ``_set_random()`` (its default
         seed -- ``args['seed']`` is only passed to ``DataManager``), convert
         ``args['device']`` via ``_set_device`` and log all arguments.
      2. Build the ``DataManager`` and the model from ``factory``; advance the
         model to its next task and grow the classifier with ``update_fc``.
      3. Train with SGD + cosine annealing for ``model.args['epochs']`` epochs,
         evaluating on all classes seen so far after every epoch and printing
         the overall and per-class-group accuracy.
      4. Call ``model.after_task()`` and pickle the whole model object to
         ``<output_dir>/joint_training/<dataset>/checkpoint.pth``.

    Args:
        args: Configuration mapping. ``args['device']`` is replaced in place
            by ``_set_device``.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(filename)s] => %(message)s',
    )
    _set_random()
    _set_device(args)
    print_args(args)

    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment'])
    model = factory.get_model(args['model_name'], args)
    model.data_manager = data_manager
    model._cur_task += 1
    cur_task_nbclasses = data_manager.get_task_size(model._cur_task)
    model._total_classes = model._known_classes + cur_task_nbclasses
    model._network.update_fc(model._total_classes, cur_task_nbclasses)
    for name, param in model._network.named_parameters():
        if param.requires_grad:
            print(f"Parameter name: {name}, Number of elements: {param.numel()}")
    logging.info('Learning on {}-{}'.format(model._known_classes, model._total_classes))

    logging.info('All params: {}'.format(count_parameters(model._network)))
    logging.info('Trainable params: {}'.format(count_parameters(model._network, True)))

    train_dataset = data_manager.get_dataset(np.arange(model._known_classes, model._total_classes), source='train',
                                             mode='train', appendent=None)
    test_dataset = data_manager.get_dataset(np.arange(0, model._total_classes), source='test', mode='test')

    model.train_loader = DataLoader(train_dataset, batch_size=model.args["batch_size"], shuffle=True, num_workers=8)
    model.test_loader = DataLoader(test_dataset, batch_size=model.args["batch_size"], shuffle=False, num_workers=8)
    model._network.to(model._device)
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model._network.parameters()), momentum=0.9,
                          lr=model.args["lr"], weight_decay=model.args["weight_decay"])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=model.args["epochs"])

    prog_bar = tqdm(range(model.args["epochs"]))
    for epoch in prog_bar:
        mean_loss, train_acc = _joint_train_epoch(model, optimizer)
        scheduler.step()

        info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}'.format(
            model._cur_task, epoch + 1, model.args["epochs"], mean_loss, train_acc)
        prog_bar.set_description(info)
        print(info)

        preds, labels, correct, total = _joint_evaluate(model)
        group_acc = _grouped_accuracy(preds, labels, model.args["increment"])
        print(f"总的准确率为{100. * correct / total :.2f}%")
        print(group_acc)

    model.after_task()
    save_dir = os.path.join(args['output_dir'], 'joint_training', args['dataset'])
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    ckpt_path = os.path.join(save_dir, "checkpoint.pth")
    torch.save(model, ckpt_path)


def _set_device(args):
    device_type = args['device']
    gpus = []

    for device in device_type:
        if device_type == -1:
            device = torch.device('cpu')
        else:
            device = torch.device('cuda:{}'.format(device))

        gpus.append(device)
    args['device'] = gpus


def _set_random():
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print("Already set random")


def print_args(args):
    """Log each ``key: value`` pair of ``args`` at INFO level, one per record."""
    for option, setting in args.items():
        logging.info(f'{option}: {setting}')


# Per-class feature statistics keyed by class id. `_compute_mean` fills these
# dicts and `train_task_adaptive_prediction` samples from them. Both functions
# resolve the names as module globals, so they have to be bound at module level.
cls_mean = {}
cls_cov = {}


def _compute_mean(model: torch.nn.Module, data, device: torch.device, args=None, taskid=0):
    """Collect feature statistics for every class of task ``taskid``.

    For each class id in ``[taskid * init_cls, (taskid + 1) * init_cls)`` the
    class's training samples are encoded with the prompts stored at index
    ``cls_id // init_cls`` and the token at position 0 of the un-pooled output
    is kept as the feature. The results are written to the module-level dicts
    ``cls_mean`` and ``cls_cov`` according to
    ``args['ca_storage_efficient_method']``:

    * ``'covariance'``: mean vector and full covariance matrix (plus 1e-4 on
      the diagonal).
    * ``'variance'``: mean vector and only the diagonal of that matrix.
    * ``'multi-centroid'``: lists of 10 k-means cluster means and per-dimension
      variances (float64 tensors).

    Any other method value stores nothing.

    Args:
        model: Object exposing ``_network`` with ``image_encoder``,
            ``ts_prompts_1`` and ``ts_prompts_2``.
        data: Object exposing ``get_dataset(...)`` as used below.
        device: Device the inputs and the stored statistics are moved to.
        args: Configuration mapping; required despite the ``None`` default
            (``init_cls``, ``batch_size``, ``ca_storage_efficient_method``).
        taskid: Index of the task whose classes are processed.
    """
    network = model._network
    network.eval()
    classes_per_task = args['init_cls']
    method = args["ca_storage_efficient_method"]
    first_cls = taskid * classes_per_task

    with torch.no_grad():
        for cls_id in tqdm(range(first_cls, first_cls + classes_per_task),
                           desc="Processing classes"):
            train_dataset = data.get_dataset(np.arange(cls_id, cls_id + 1), source='train',
                                             mode='train', appendent=None)
            data_loader_cls = DataLoader(train_dataset, batch_size=args["batch_size"], shuffle=True, num_workers=8)
            prompt_idx = cls_id // classes_per_task

            features_per_cls = []
            for _, inputs, _ in data_loader_cls:
                inputs = inputs.to(device)
                image_features = network.image_encoder(
                    inputs,
                    instance_tokens=network.ts_prompts_1[prompt_idx].weight,
                    second_pro=network.ts_prompts_2[prompt_idx].weight,
                    returnbeforepool=True,
                )
                features_per_cls.append(image_features[:, 0, :])
            features_per_cls = torch.cat(features_per_cls, dim=0)

            if method in ('covariance', 'variance'):
                feature_mean = features_per_cls.mean(dim=0)
                # Small diagonal term keeps the covariance usable as a
                # MultivariateNormal parameter later on.
                regularised_cov = (torch.cov(features_per_cls.T)
                                   + torch.eye(feature_mean.shape[-1], device=device) * 1e-4)
                cls_mean[cls_id] = feature_mean
                cls_cov[cls_id] = regularised_cov if method == 'covariance' else torch.diag(regularised_cov)
            elif method == 'multi-centroid':
                from sklearn.cluster import KMeans
                n_clusters = 10
                features_np = features_per_cls.cpu().numpy()
                kmeans = KMeans(n_clusters=n_clusters)
                kmeans.fit(features_np)
                cluster_labels = kmeans.labels_

                cluster_means = []
                cluster_vars = []
                for cluster_idx in range(n_clusters):
                    cluster_data = features_np[cluster_labels == cluster_idx]
                    cluster_means.append(
                        torch.tensor(np.mean(cluster_data, axis=0), dtype=torch.float64).to(device))
                    cluster_vars.append(
                        torch.tensor(np.var(cluster_data, axis=0), dtype=torch.float64).to(device))

                cls_mean[cls_id] = cluster_means
                cls_cov[cls_id] = cluster_vars


def _tap_loss_weights(p_t, margin):
    """Return per-sample loss weights from the true-class probabilities.

    Below ``margin`` the weight is ``sin(pi * p / (2 * margin) - pi / 2) + 1``
    (0 at ``p == 0``, 1 at ``p == margin``); at or above ``margin`` it is 1.
    """
    weights = torch.sin(torch.pi * p_t / (2 * margin) - torch.pi / 2) + 1
    weights[p_t >= margin] = 1
    return weights


def _sample_tap_features(args, device, num_classes, num_sampled_pcls):
    """Draw pseudo-features for classes ``0 .. num_classes - 1``.

    Samples come from Gaussians parameterised by the module-level ``cls_mean``
    / ``cls_cov`` dicts. ``num_sampled_pcls`` samples are drawn per class, or
    per cluster for the ``'multi-centroid'`` method.

    Returns:
        ``(features, labels)`` tensors on ``device``.

    Raises:
        NotImplementedError: for an unknown ``ca_storage_efficient_method``.
    """
    method = args["ca_storage_efficient_method"]
    sampled_data = []
    sampled_label = []

    if method in ('covariance', 'variance'):
        for c_id in range(num_classes):
            mean = torch.as_tensor(cls_mean[c_id]).detach().to(device)
            cov = cls_cov[c_id].to(device)
            if method == 'variance':
                # Only the diagonal was stored; rebuild a diagonal matrix.
                cov = torch.diag(cov)
            m = MultivariateNormal(mean.float(), cov.float())
            sampled_data.append(m.sample(sample_shape=(num_sampled_pcls,)))
            sampled_label.extend([c_id] * num_sampled_pcls)
    elif method == 'multi-centroid':
        for c_id in range(num_classes):
            for mean, var in zip(cls_mean[c_id], cls_cov[c_id]):
                if var.mean() == 0:
                    # Zero-variance cluster: skipped, no samples are drawn.
                    continue
                cov = torch.diag(var) + 1e-4 * torch.eye(mean.shape[0]).to(mean.device)
                m = MultivariateNormal(mean.float(), cov.float())
                sampled_data.append(m.sample(sample_shape=(num_sampled_pcls,)))
                sampled_label.extend([c_id] * num_sampled_pcls)
    else:
        raise NotImplementedError

    features = torch.cat(sampled_data, dim=0).float().to(device)
    labels = torch.tensor(sampled_label).long().to(device)
    return features, labels


def train_task_adaptive_prediction(model, args, device, taskid=0, data_manager=None):
    """Tune the trainable parameters of ``model._network`` on sampled features.

    Each epoch draws ``batch_size * 5`` pseudo-features per class (or cluster)
    for all classes of tasks ``0 .. taskid`` from the statistics stored in the
    module-level ``cls_mean`` / ``cls_cov`` dicts, shuffles them, and runs
    ``taskid * init_cls`` SGD steps on consecutive slices of that pool. Logits
    are the concatenation of heads ``clas_w[0 .. taskid]`` applied directly to
    the sampled features. The loss is cross-entropy weighted per sample by
    ``_tap_loss_weights``.

    When ``taskid`` is 0 there are no optimisation steps to run, so the
    function returns right after switching the network to train mode.

    Args:
        model: Object exposing ``_network`` with ``clas_w`` heads.
        args: Configuration mapping (``tap_epochs``, ``lr``, ``weight_decay``,
            ``init_cls``, ``batch_size``, ``ca_storage_efficient_method``,
            ``margin_tap``).
        device: Device the sampled features are placed on.
        taskid: Index of the most recent task.
        data_manager: Unused; kept for interface compatibility.

    Raises:
        NotImplementedError: for an unknown ``ca_storage_efficient_method``.
        SystemExit: if the loss becomes non-finite.
    """
    model._network.train()

    # Number of SGD steps per epoch: one per class of the previous tasks.
    crct_num = max(taskid, 0) * args['init_cls']
    print(f"TAP设置的crct_num={crct_num}")
    if crct_num == 0:
        # No steps would run; bailing out here avoids dividing the epoch
        # statistics by a zero sample count.
        print("TAP: no previous-task classes, nothing to train")
        return

    run_epochs = args["tap_epochs"]
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model._network.parameters()), momentum=0.9,
                          lr=args["lr"], weight_decay=args["weight_decay"])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=run_epochs)

    num_sampled_pcls = args["batch_size"] * 5
    num_seen_classes = (taskid + 1) * args['init_cls']
    margin = args["margin_tap"]

    # TODO: efficiency may be improved by encapsulating sampled data into Datasets class and using distributed sampler.
    for epoch in tqdm(range(run_epochs), desc="Processing epochs"):
        model._network.train()
        inputs, targets = _sample_tap_features(args, device, num_seen_classes, num_sampled_pcls)
        print(inputs.shape)

        sf_indexes = torch.randperm(inputs.size(0))
        inputs = inputs[sf_indexes]
        targets = targets[sf_indexes]

        running_loss = 0.0
        correct = 0
        total = 0
        for _iter in range(crct_num):
            batch = slice(_iter * num_sampled_pcls, (_iter + 1) * num_sampled_pcls)
            inp, tgt = inputs[batch], targets[batch]
            logits = torch.cat(
                [model._network.clas_w[i](inp)['logits'] for i in range(taskid + 1)], dim=1)

            # Probability assigned to the true class, treated as a constant.
            probabilities = torch.softmax(logits, dim=1).detach()
            p_t = probabilities[torch.arange(len(tgt)), tgt]
            weights = _tap_loss_weights(p_t, margin)

            # Weighted cross-entropy averaged over the batch.
            ce_loss = F.cross_entropy(logits, tgt, reduction='none')
            loss = (weights * ce_loss).sum() / float(len(tgt))

            if not math.isfinite(loss.item()):
                print("Loss is {}, stopping training".format(loss.item()))
                sys.exit(1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inp.size(0)
            _, predicted = logits.max(1)
            total += inp.size(0)
            correct += predicted.eq(tgt).sum().item()

        epoch_loss = running_loss / total
        epoch_acc = 100. * correct / total
        print(f"Train Loss: {epoch_loss:.4f} - Train Acc: {epoch_acc:.2f}%")
        scheduler.step()
