import os
import sys
import warnings
import torch.nn.functional as F
warnings.filterwarnings("ignore")
# from torch.distributed.pipeline.sync.checkpoint import checkpoint
from tqdm import tqdm
import math
from torch import optim, nn
import logging
import copy
from torch.utils.data import DataLoader
import torch
import numpy as np
from utils import factory
from utils.data_manager import DataManager
from torch.distributions.multivariate_normal import MultivariateNormal
from utils.toolkit import count_parameters
from pathlib import Path
from ESN.networks import IncrementalViTOOD, _create_vision_transformer


def train_TAP(args):
    """Run ``_train`` once for every seed listed in ``args['seed']``.

    ``args`` is mutated in place: before each run ``args['seed']`` is set to the
    current seed and ``args['device']`` is reset to a deep copy of its value at
    entry, because ``_train`` replaces ``args['device']`` (via ``_set_device``).
    After the loop, ``args['seed']`` holds the last seed rather than the list.
    """
    seeds = copy.deepcopy(args['seed'])
    initial_devices = copy.deepcopy(args['device'])

    for current_seed in seeds:
        args['seed'], args['device'] = current_seed, initial_devices
        _train(args)


def _train(args):
    """Run one TAP (task-adaptive prediction) fine-tuning pass for a single seed.

    Loads a fully pickled model, collects (or loads cached) per-class feature
    statistics into the module-level ``cls_mean`` / ``cls_cov`` dicts, unfreezes
    only parameters whose names contain ``'clas_w'`` or ``'task_tokens'``, trains
    them with ``train_task_adaptive_prediction`` and saves the whole model.

    Args:
        args: experiment config dict. Besides the keys read by the helpers it
            must provide ``'checkpoint'`` (file passed to ``torch.load``) and
            ``'save_path'`` (where the tuned model is written). ``args['device']``
            is converted in place to a list of ``torch.device`` by ``_set_device``.

    Raises:
        KeyError: if the checkpoint path or the save path cannot be resolved.
    """
    # Resolve input/output paths from args. Module-level variables named
    # ``checkpoint`` / ``path`` are honoured as a fallback for callers that set
    # them; neither name is defined in this module itself.
    checkpoint_path = args.get('checkpoint', globals().get('checkpoint'))
    save_path = args.get('save_path', globals().get('path'))
    if checkpoint_path is None or save_path is None:
        raise KeyError("args must provide 'checkpoint' (model file to load) "
                       "and 'save_path' (where to save the TAP-tuned model)")

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(filename)s] => %(message)s',
    )
    # Shared with _compute_mean / train_task_adaptive_prediction via module globals.
    global cls_mean
    global cls_cov
    cls_mean = dict()
    cls_cov = dict()
    _set_random()
    _set_device(args)
    print_args(args)

    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment'])
    # NOTE(security): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    model = torch.load(checkpoint_path)
    model._network.eval()
    device = args['device'][0]
    model._network.to(device)
    model._device = device

    # Per-class statistics are cached on disk and reused whenever the file exists.
    # NOTE(review): the cache is not keyed on dataset / ca_storage_efficient_method,
    # so delete it when those settings change.
    file_path = './cls_save_imr_multi.pth'
    if os.path.exists(file_path):
        cls_save = torch.load(file_path)
        cls_mean = cls_save['cls_mean']
        cls_cov = cls_save['cls_cov']
        print("字典 cls_save 已成功加载")
    else:
        # NOTE(review): taskid=9 is hard-coded (presumably the last task of a 10-task split) - confirm.
        _compute_mean(model=model, data=data_manager, device=device, args=args, taskid=9)

    # Train only the classifier heads and task tokens; everything else stays frozen.
    for name, param in model._network.named_parameters():
        param.requires_grad_('clas_w' in name or 'task_tokens' in name)

    train_task_adaptive_prediction(model, args, device, taskid=9, data_manager=data_manager)
    torch.save(model, save_path)
    print(f"模型保存在{save_path}")

def _set_device(args):
    device_type = args['device']
    gpus = []

    for device in device_type:
        if device_type == -1:
            device = torch.device('cpu')
        else:
            device = torch.device('cuda:{}'.format(device))

        gpus.append(device)
    args['device'] = gpus


def _set_random():
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print("Already set random")


def print_args(args):
    """Log each ``key: value`` entry of the config dict at INFO level, in dict order."""
    for name, setting in args.items():
        logging.info(f'{name}: {setting}')


def _compute_mean(model: torch.nn.Module, data, device: torch.device, args=None, taskid=0):
    """Collect per-class feature statistics for classes ``0 .. (taskid+1)*init_cls - 1``.

    Every training image of a class is encoded with the task prompts
    ``ts_prompts_1/2[cls_id // init_cls]`` and token 0 of the pre-pool output is
    used as its feature. Depending on ``args["ca_storage_efficient_method"]``:

    * ``'covariance'``: ``cls_mean[c]`` = mean, ``cls_cov[c]`` = covariance + 1e-4 * I.
    * ``'variance'``: same mean, ``cls_cov[c]`` = the diagonal of that matrix.
    * ``'multi-centroid'``: KMeans with 10 clusters; ``cls_mean[c]`` / ``cls_cov[c]``
      are lists of per-cluster float64 means / per-dimension variances.

    Results are written into the module-level ``cls_mean`` / ``cls_cov`` dicts
    (which must already exist, e.g. created by ``_train``). Both dicts are then
    saved to ``./cls_save_imr_multi.pth``.

    Args:
        model: wrapper exposing ``_network`` with ``image_encoder``,
            ``ts_prompts_1`` and ``ts_prompts_2``.
        data: data manager providing ``get_dataset``.
        device: device the features are computed on.
        args: config dict (``init_cls``, ``batch_size``, ``ca_storage_efficient_method``).
        taskid: index of the last task whose classes are processed.

    Raises:
        NotImplementedError: unknown ``ca_storage_efficient_method``. Raised
            before any work so an empty/stale stats cache is never written.
    """
    method = args["ca_storage_efficient_method"]
    if method not in ('covariance', 'variance', 'multi-centroid'):
        raise NotImplementedError(method)
    if method == 'multi-centroid':
        from sklearn.cluster import KMeans
    n_clusters = 10

    model._network.eval()
    with torch.no_grad():
        for cls_id in tqdm(range((taskid + 1) * args['init_cls']),
                           desc="Processing classes"):
            train_dataset = data.get_dataset(np.arange(cls_id, cls_id + 1), source='train',
                                             mode='train', appendent=None)
            # shuffle=True does not matter for mean/cov, but changing it would alter
            # the sample order KMeans sees and the global torch RNG consumption.
            data_loader_cls = DataLoader(train_dataset, batch_size=args["batch_size"], shuffle=True, num_workers=8)
            # Prompts are indexed per task; assumes every task has init_cls classes.
            prompt_idx = cls_id // args['init_cls']
            batch_features = []
            for _, inputs, _targets in data_loader_cls:
                image_features = model._network.image_encoder(
                    inputs.to(device),
                    instance_tokens=model._network.ts_prompts_1[prompt_idx].weight,
                    second_pro=model._network.ts_prompts_2[prompt_idx].weight,
                    returnbeforepool=True,
                )
                batch_features.append(image_features[:, 0, :])
            features_per_cls = torch.cat(batch_features, dim=0)

            if method in ('covariance', 'variance'):
                mean = features_per_cls.mean(dim=0)
                # Small ridge keeps the matrix positive-definite for MultivariateNormal.
                reg_cov = torch.cov(features_per_cls.T) + (torch.eye(mean.shape[-1]) * 1e-4).to(device)
                cls_mean[cls_id] = mean
                cls_cov[cls_id] = reg_cov if method == 'covariance' else torch.diag(reg_cov)
            else:  # 'multi-centroid'
                features_np = features_per_cls.cpu().numpy()
                cluster_labels = KMeans(n_clusters=n_clusters).fit(features_np).labels_
                cluster_means = []
                cluster_vars = []
                for k in range(n_clusters):
                    members = features_np[cluster_labels == k]
                    cluster_means.append(torch.tensor(np.mean(members, axis=0), dtype=torch.float64).to(device))
                    cluster_vars.append(torch.tensor(np.var(members, axis=0), dtype=torch.float64).to(device))
                cls_mean[cls_id] = cluster_means
                cls_cov[cls_id] = cluster_vars

    file_path = './cls_save_imr_multi.pth'
    torch.save({'cls_mean': cls_mean, 'cls_cov': cls_cov}, file_path)
    print(f"字典 cls_save 已成功保存到 {file_path}")


def _sample_pseudo_features(args, device, taskid, num_sampled_pcls):
    """Draw Gaussian pseudo-features for all classes seen up to ``taskid``.

    Uses the module-level ``cls_mean`` / ``cls_cov`` statistics. For
    'covariance'/'variance', ``num_sampled_pcls`` samples are drawn per class. For
    'multi-centroid', that many are drawn per non-degenerate cluster.

    Returns:
        (features, labels): float32 features and int64 class ids on ``device``,
        grouped by class (not shuffled).

    Raises:
        NotImplementedError: unknown ``args["ca_storage_efficient_method"]``.
    """
    method = args["ca_storage_efficient_method"]
    num_classes = (taskid + 1) * args['init_cls']
    sampled_data = []
    sampled_label = []

    if method in ['covariance', 'variance']:
        for c_id in range(num_classes):
            mean = torch.as_tensor(cls_mean[c_id], dtype=torch.float64).to(device)
            cov = cls_cov[c_id].to(device)
            if method == 'variance':
                # Only the diagonal was stored; rebuild a diagonal covariance matrix.
                cov = torch.diag(cov)
            m = MultivariateNormal(mean.float(), cov.float())
            sampled_data.append(m.sample(sample_shape=(num_sampled_pcls,)))
            sampled_label.extend([c_id] * num_sampled_pcls)

    elif method == 'multi-centroid':
        for c_id in range(num_classes):
            for mean, var in zip(cls_mean[c_id], cls_cov[c_id]):
                if var.mean() == 0:
                    # Degenerate cluster (e.g. a single member): nothing to sample from.
                    continue
                cov = (torch.diag(var) + 1e-4 * torch.eye(mean.shape[0]).to(mean.device)).float()
                m = MultivariateNormal(mean.float(), cov)
                sampled_data.append(m.sample(sample_shape=(num_sampled_pcls,)))
                sampled_label.extend([c_id] * num_sampled_pcls)
    else:
        raise NotImplementedError

    features = torch.cat(sampled_data, dim=0).float().to(device)
    labels = torch.tensor(sampled_label).long().to(device)
    return features, labels


def _margin_masked_ce(logits, tgt, margin):
    """Sum the per-sample cross-entropy over samples with p(true class) >= ``margin``.

    The probability is taken from a detached softmax, so the selection itself is
    not differentiated. Samples below the margin contribute 0.

    Returns:
        (weighted_sum, num_selected): a scalar loss tensor and a Python float.
    """
    probabilities = torch.softmax(logits, dim=1).detach()
    p_t = probabilities[torch.arange(len(tgt), device=logits.device), tgt]
    mask = p_t >= margin
    ce_loss = F.cross_entropy(logits, tgt, reduction='none')
    weight_loss = mask.to(ce_loss.dtype) * ce_loss
    return weight_loss.sum(), float(mask.sum().item())


def train_task_adaptive_prediction(model, args, device, taskid=0, data_manager=None):
    """Fine-tune the trainable classifier parameters on sampled pseudo-features.

    Each epoch, pseudo-features are sampled from the stored class statistics and
    shuffled. Then ``taskid * init_cls`` mini-batches of ``5 * batch_size``
    samples are trained with SGD on a margin-masked cross-entropy. The logits of
    ``clas_w[0..taskid]`` are concatenated along the class dimension.

    Args:
        model: wrapper exposing ``_network`` (with ``clas_w`` heads). Only
            parameters with ``requires_grad=True`` are optimised.
        args: config dict (``tap_epochs``, ``lr``, ``weight_decay``, ``batch_size``,
            ``init_cls``, ``margin_tap``, ``ca_storage_efficient_method``).
        device: device for sampling and training.
        taskid: index of the last task. With ``taskid=0`` no batches are run.
        data_manager: unused; kept for call compatibility.
    """
    model._network.train()
    run_epochs = args["tap_epochs"]
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model._network.parameters()), momentum=0.9,
                          lr=args["lr"], weight_decay=args["weight_decay"])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=run_epochs)

    crct_num = taskid * args['init_cls']
    print(f"TAP设置的crct_num={crct_num}")
    num_sampled_pcls = args["batch_size"] * 5
    # TODO: efficiency may be improved by encapsulating sampled data into Datasets class and using distributed sampler.
    for epoch in tqdm(range(run_epochs), desc="Processing epochs"):
        model._network.train()
        inputs, targets = _sample_pseudo_features(args, device, taskid, num_sampled_pcls)
        print(inputs.shape)

        sf_indexes = torch.randperm(inputs.size(0))
        inputs = inputs[sf_indexes]
        targets = targets[sf_indexes]

        # NOTE(review): only the first crct_num shuffled batches are used. For
        # covariance/variance this leaves out roughly one task's worth of samples
        # each epoch; confirm this is intended. Capped so no batch is ever empty.
        num_batches = min(crct_num, math.ceil(inputs.size(0) / num_sampled_pcls))
        running_loss = 0.0
        correct = 0
        total = 0
        for _iter in range(num_batches):
            inp = inputs[_iter * num_sampled_pcls:(_iter + 1) * num_sampled_pcls]
            tgt = targets[_iter * num_sampled_pcls:(_iter + 1) * num_sampled_pcls]
            logits = torch.cat(
                [model._network.clas_w[i](inp)['logits'] for i in range(taskid + 1)], dim=1)

            weighted_sum, num_selected = _margin_masked_ce(logits, tgt, args["margin_tap"])
            if num_selected == 0:
                # No sample reaches the margin, so there is no training signal.
                # Dividing would give 0/0 = NaN and abort the run; skip the update.
                batch_loss = 0.0
            else:
                loss = weighted_sum / num_selected
                if not math.isfinite(loss.item()):
                    print("Loss is {}, stopping training".format(loss.item()))
                    sys.exit(1)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                batch_loss = loss.item()

            running_loss += batch_loss * inp.size(0)
            _, predicted = logits.max(1)
            total += inp.size(0)
            correct += predicted.eq(tgt).sum().item()

        if total == 0:
            print("No TAP batches were run this epoch (crct_num=0); nothing to train.")
        else:
            epoch_loss = running_loss / total
            epoch_acc = 100. * correct / total
            print(f"Train Loss: {epoch_loss:.4f} - Train Acc: {epoch_acc:.2f}%")
        scheduler.step()


class GHMC(nn.Module):
    """Gradient Harmonizing Mechanism (GHM-C) loss for multi-class classification.

    Each sample's "gradient norm" is g = 1 - p(true class), taken from a detached
    softmax. Samples are bucketed into ``bins`` equal-width bins over [0, 1].
    A sample's cross-entropy is weighted by ``tot / count`` of its bin, where
    ``count`` is the bin population, or a running average when ``momentum > 0``.
    The weights are divided by the number of non-empty bins. The loss is
    ``sum(weight * ce) / tot * loss_weight``, with ``tot`` the number of samples
    having ``target >= 0``.

    Args:
        bins: number of gradient bins.
        momentum: smoothing factor for the per-bin counts across calls (0 disables).
        loss_weight: final multiplier on the loss.
    """

    def __init__(self, bins=100, momentum=0, loss_weight=1.0):
        super(GHMC, self).__init__()
        self.bins = bins
        self.momentum = momentum
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        # Bins are half-open [edges[i], edges[i+1]); nudge the last edge past 1.0
        # so samples with g == 1 still fall into the last bin.
        self.edges[-1] += 1e-6

        if momentum > 0:
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)

        self.loss_weight = loss_weight

    def forward(self, pred, target, *args, **kwargs):
        """Return the GHM-weighted cross-entropy of logits ``pred`` (N, C) vs class ids ``target`` (N,)."""
        edges = self.edges
        mmt = self.momentum
        # Allocate on pred's device: the masks used for assignment below are derived
        # from pred, and indexing a CPU tensor with CUDA masks fails.
        weights = torch.zeros(pred.size(0), device=pred.device)

        probabilities = torch.softmax(pred, dim=1).detach()
        g = 1 - probabilities[torch.arange(len(target), device=pred.device), target]
        valid = target >= 0
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # number of non-empty bins

        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    self.acc_sum[i] = mmt * self.acc_sum[i] + (1 - mmt) * num_in_bin
                    weights[inds] = (tot / self.acc_sum[i]).item()
                else:
                    weights[inds] = tot / num_in_bin
                n += 1

        if n > 0:
            weights = weights / n

        loss = F.cross_entropy(pred, target, reduction='none')
        weighted_loss = (loss * weights).sum() / tot
        return weighted_loss * self.loss_weight

class FocalLoss(nn.Module):
    """Cross-entropy with a smooth down-weighting of low-confidence samples.

    This is not the classic focal loss: ``gamma`` is stored but unused. Each
    sample's cross-entropy is multiplied by w(p_t), where p_t is the detached
    softmax probability of the true class:

        w(p) = sin(pi * p / (2 * margin) - pi / 2) + 1   for p <  margin
        w(p) = 1                                          for p >= margin

    w rises smoothly from 0 at p = 0 to 1 at p = margin. The weighted losses are
    averaged over the batch and scaled by ``loss_weight``.

    Args:
        gamma: unused; kept for API compatibility.
        loss_weight: final multiplier on the loss.
        margin: confidence above which samples get full weight (default 0.45).
    """

    def __init__(self, gamma=1.0, loss_weight=1.0, margin=0.45):
        super(FocalLoss, self).__init__()
        self.gamma = gamma  # not used by forward()
        self.loss_weight = loss_weight
        self.margin = margin

    def forward(self, pred, target, *args, **kwargs):
        """Return the margin-weighted mean cross-entropy of logits ``pred`` vs class ids ``target``."""
        probabilities = torch.softmax(pred, dim=1).detach()
        # Probability assigned to each sample's true class.
        p_t = probabilities[torch.arange(len(target), device=pred.device), target]

        weights = torch.sin(torch.pi * p_t / (2 * self.margin) - torch.pi / 2) + 1
        weights[p_t >= self.margin] = 1

        ce_loss = F.cross_entropy(pred, target, reduction='none')
        weighted_loss = (weights * ce_loss).sum() / float(len(target))
        return weighted_loss * self.loss_weight

