import random
import torch
import os
import time
import sys
import io

import numpy as np
import pprint as pprint
import logging
_utils_pp = pprint.PrettyPrinter()
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import matplotlib
from collections import OrderedDict, defaultdict
from torch.utils.data import DataLoader, Sampler
from modules.switch_module import *
from modules.rotated_modules import *
from tqdm import tqdm
import torch.nn as nn
# import numpy as np
# import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

def pprint(x):
    """Pretty-print ``x`` through the module-level ``PrettyPrinter`` instance.

    This function shadows the ``pprint`` module name imported at the top of the
    file, which is why the printer object is created before this definition.
    """
    printer = _utils_pp
    printer.pprint(x)


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs, or leave them unseeded.

    Args:
        seed: ``0`` leaves every generator unseeded and turns on the cuDNN
            autotuner. Any other value seeds ``random``, ``numpy.random``,
            ``torch`` and all CUDA generators, and switches cuDNN to
            deterministic mode with the autotuner disabled.
    """
    if seed == 0:
        print(' random seed')
        torch.backends.cudnn.benchmark = True
        return

    print('manual seed:', seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def set_gpu(args):
    """Expose the GPUs listed in ``args.gpu`` to CUDA and return how many there are.

    Args:
        args: Object whose ``gpu`` attribute is a comma-separated string of
            integer device ids (e.g. ``"0,1"``).

    Returns:
        The number of ids parsed from ``args.gpu``.
    """
    visible = args.gpu
    # Parsing with int() also rejects malformed ids before the environment is touched.
    gpu_list = list(map(int, visible.split(',')))
    print('use gpu:', gpu_list)
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = visible
    return len(gpu_list)

class Timer():
    """Stopwatch that starts when it is constructed."""

    def __init__(self):
        # Start timestamp from time.time().
        self.o = time.time()

    def measure(self, p=1):
        """Return the elapsed time divided by ``p`` as a short human-readable string.

        The value is truncated to whole seconds and rendered as ``'<n>s'`` below
        one minute, ``'<n>m'`` (rounded minutes) below one hour, and ``'<x.y>h'``
        otherwise.
        """
        elapsed = int((time.time() - self.o) / p)
        if elapsed < 60:
            return f'{elapsed}s'
        if elapsed < 3600:
            return f'{round(elapsed / 60)}m'
        return f'{elapsed / 3600:.1f}h'


def count_acc_(logits, label, fp, fn, tp, tn, FPR, FNR, is_last=False, threshold=60):
    """Compute top-1 accuracy and, optionally, update base/novel confusion counters.

    Classes with index below ``threshold`` are treated as "base" and the rest as
    "novel". The counters describe the base-vs-novel decision only (a base
    sample predicted as a different base class still counts as ``tp``):

    * ``tp``: base sample predicted as some base class
    * ``fn``: base sample predicted as a novel class
    * ``fp``: novel sample predicted as a base class
    * ``tn``: novel sample predicted as some novel class

    Args:
        logits: Tensor of per-class scores; ``argmax`` is taken over ``dim=1``.
        label: Tensor of ground-truth class indices on the same device as ``logits``.
        fp, fn, tp, tn: Running counters to be incremented.
        FPR, FNR: Unused; kept only so existing call sites keep working.
        is_last: When true, the counters are updated and returned along with
            the accuracy. When false (default) only the accuracy is returned.
        threshold: First class index that counts as "novel".

    Returns:
        ``acc`` (float) when ``is_last`` is false, otherwise the tuple
        ``(acc, fp, fn, tp, tn)``. The shape of the result does not depend on
        whether CUDA is available.
    """
    pred = torch.argmax(logits, dim=1)
    # .float() keeps the tensor on its current device, so no CUDA check is needed.
    acc = (pred == label).float().mean().item()
    if not is_last:
        return acc

    true_is_base = label < threshold
    pred_is_base = pred < threshold
    tp += int((true_is_base & pred_is_base).sum().item())
    fn += int((true_is_base & ~pred_is_base).sum().item())
    fp += int((~true_is_base & pred_is_base).sum().item())
    tn += int((~true_is_base & ~pred_is_base).sum().item())
    return acc, fp, fn, tp, tn

def count_acc(logits, label):
    """Return the top-1 accuracy of ``logits`` against ``label`` as a Python float.

    Args:
        logits: Tensor of per-class scores; ``argmax`` is taken over ``dim=1``.
        label: Tensor of ground-truth class indices on the same device as ``logits``.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.
    """
    pred = torch.argmax(logits, dim=1)
    # .float() casts in place on the tensor's own device, so CPU tensors are not
    # copied to the GPU merely because CUDA happens to be available.
    return (pred == label).float().mean().item()

def save_list_to_txt(name, input_list):
    """Append every item of ``input_list`` to the text file ``name``, one per line.

    Args:
        name: Path of the file; it is opened in append mode and created if missing.
        input_list: Iterable of objects; each is written as ``str(item)``.
    """
    # The context manager closes the handle even if str(item) or write() raises.
    with open(name, mode='a') as f:
        f.writelines(str(item) + '\n' for item in input_list)

def get_optimizer_scheduler(self, optimize_parameters=None):
    """Build an Adam optimizer and the LR scheduler selected by ``self.args.schedule``.

    Args:
        self: Object exposing an ``args`` namespace with ``lr_base``, ``decay``,
            ``schedule``, ``gamma`` and, depending on the schedule, ``step``,
            ``milestones`` or ``epochs_base``.
        optimize_parameters: Parameters (or parameter groups) handed to Adam.

    Returns:
        ``(optimizer, scheduler)``.

    Raises:
        ValueError: If ``self.args.schedule`` is not ``'Step'``, ``'Milestone'``
            or ``'Cosine'``.
    """
    args = self.args
    optimizer = torch.optim.Adam(optimize_parameters, args.lr_base, weight_decay=args.decay)
    if args.schedule == 'Step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step, gamma=args.gamma)
    elif args.schedule == 'Milestone':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones,
                                                         gamma=args.gamma)
    elif args.schedule == 'Cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs_base)
    else:
        # Fail with a clear message; otherwise the return below would hit an unbound local.
        raise ValueError(
            "Unsupported schedule %r; expected 'Step', 'Milestone' or 'Cosine'" % (args.schedule,))
    return optimizer, scheduler

def set_save_path(self):
    """Derive ``self.args.save_path`` from the run configuration and create the folder.

    The path is ``checkpoint/<dataset>/<project>/<network>_<base_mode>_<new_mode>/<run tag>``
    where the run tag encodes epochs, batch sizes, time step, seed, base learning
    rate, way and shot. The directory is created when it does not exist yet.

    Returns:
        None.
    """
    cfg = self.args
    mode = cfg.base_mode + '_' + cfg.new_mode
    template = '%s/%s/%s_%s/epo_b%d-epo_n%d-bs_b%d_-bs_n%d-bs_t%d_step%s_seed_%d_lr_b_%.4f_way_%d_shot_%d'
    fields = (
        cfg.dataset, cfg.project, cfg.network, mode,
        cfg.epochs_base, cfg.epochs_new,
        cfg.batch_size_base, cfg.batch_size_new, cfg.test_batch_size,
        cfg.time_step, cfg.seed, cfg.lr_base,
        cfg.way, cfg.shot,
    )
    cfg.save_path = os.path.join('checkpoint', template % fields)
    if not os.path.exists(cfg.save_path):
        print('create folder:', cfg.save_path)
        os.makedirs(cfg.save_path)
    return None

# def log(out, log_str):
#     out.write(log_str + '\n')
#     out.flush()
#     print(log_str)

def count_acc_topk(x,y,k=5):
    """Return the top-``k`` accuracy of scores ``x`` against labels ``y``.

    A sample counts as a hit when its label is among the ``k`` highest-scoring
    indices along the last dimension of ``x``.
    """
    top_indices = torch.topk(x, k, dim=-1)[1]
    num_samples = y.size(0)
    # Column vector of labels broadcasts against the (N, k) index matrix.
    hits = (y.view(-1, 1) == top_indices).sum().item()
    return float(hits / num_samples)

def log_to_file(log_name):
    """Configure and return the ``'my_logger'`` logger writing to a file and a stream.

    A ``FileHandler`` (append mode, DEBUG level) on ``log_name`` and a default
    ``StreamHandler`` (INFO level) are attached, both sharing one formatter, and
    a ``"hello"`` record is emitted at INFO level.

    Each call attaches a fresh pair of handlers to the same named logger, so
    calling this more than once adds further handlers.

    Args:
        log_name: Path of the log file.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger('my_logger')
    # DEBUG on the logger itself lets every record reach the handlers.
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler_levels = (
        (logging.FileHandler(log_name, mode='a'), logging.DEBUG),
        (logging.StreamHandler(), logging.INFO),
    )
    for handler, level in handler_levels:
        handler.setLevel(level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    logger.info("hello")
    return logger

class BufferedLogger:
    """Logger that mirrors every record to stdout and to an in-memory text buffer.

    The buffer can later be read, cleared, or written out to a file.
    """

    def __init__(self, name='buffered_logger', level=logging.INFO,
                 fmt='%(asctime)s - %(levelname)s - %(message)s',
                 datefmt='%Y-%m-%d %H:%M:%S'):
        # In-memory sink that collects the formatted records.
        self.buffer = io.StringIO()

        # One handler feeds the buffer, the other echoes to the console.
        self.buffer_handler = logging.StreamHandler(self.buffer)
        self.console_handler = logging.StreamHandler(sys.stdout)

        self.logger = logging.getLogger(name)
        self.logger.setLevel(level)

        shared_formatter = logging.Formatter(fmt, datefmt)
        for handler in (self.buffer_handler, self.console_handler):
            handler.setFormatter(shared_formatter)
            self.logger.addHandler(handler)

        # Do not forward records to ancestor loggers (avoids duplicate console output).
        self.logger.propagate = False

    def get_logger(self):
        """Return the underlying ``logging.Logger``."""
        return self.logger

    def write_to_file(self, filepath, mode='a'):
        """Write the buffered log text to ``filepath`` (appends by default)."""
        with open(filepath, mode, encoding='utf-8') as out:
            out.write(self.buffer.getvalue())

    def get_log_text(self):
        """Return the log text currently held in memory."""
        return self.buffer.getvalue()

    def clear(self):
        """Empty the in-memory log buffer."""
        self.buffer.truncate(0)
        self.buffer.seek(0)

    def close(self):
        """Detach both handlers from the logger and close the buffer."""
        for handler in (self.buffer_handler, self.console_handler):
            self.logger.removeHandler(handler)
        self.buffer.close()

def confmatrix(logits, label, filename):
    """Compute a row-normalized confusion matrix and save it as two PDF heat maps.

    ``<filename>.pdf`` is drawn without a color bar and ``<filename>_cbar.pdf``
    with one. Matplotlib's global font settings are switched to DejaVu Sans,
    size 18, as a side effect.

    Args:
        logits: Tensor of per-class scores; ``argmax`` is taken over ``dim=1``.
        label: Ground-truth labels passed to ``sklearn.metrics.confusion_matrix``.
        filename: Output path prefix (without extension).

    Returns:
        The confusion matrix normalized over the true labels.
    """
    font = {'family': 'DejaVu Sans', 'size': 18}
    matplotlib.rc('font', **font)
    matplotlib.rcParams.update({'font.family': 'DejaVu Sans', 'font.size': 18})
    plt.rcParams["font.family"] = "DejaVu Sans"

    pred = torch.argmax(logits, dim=1)
    cm = confusion_matrix(label, pred, normalize='true')
    clss = len(cm)

    # Six ticks whose spacing depends on the number of classes.
    if clss <= 100:
        step = 20
    elif clss <= 200:
        step = 40
    else:
        step = 200
    tick_labels = [step * i for i in range(6)]
    tick_positions = [0] + [value - 1 for value in tick_labels[1:]]

    def _render(suffix, with_colorbar):
        # Draw one heat map of ``cm`` and save it to ``filename + suffix``.
        fig = plt.figure()
        axes = fig.add_subplot(111)
        image = axes.imshow(cm, cmap=plt.cm.jet)
        if with_colorbar:
            cbar = plt.colorbar(image)
            cbar.ax.tick_params(labelsize=16)
        plt.yticks(tick_positions, tick_labels, fontsize=16)
        plt.xticks(tick_positions, tick_labels, fontsize=16)
        plt.xlabel('Predicted Label', fontsize=20)
        plt.ylabel('True Label', fontsize=20)
        plt.tight_layout()
        plt.savefig(filename + suffix, bbox_inches='tight')
        plt.close()

    _render('.pdf', with_colorbar=False)
    _render('_cbar.pdf', with_colorbar=True)

    return cm

def harm_mean(seen, unseen):
    """Return the element-wise harmonic means of ``seen`` and ``unseen``.

    Each result is ``2*s*u / (s + u + 1e-12)`` rounded to three decimals; the
    small constant guards against division by zero when both values are 0.

    Args:
        seen: Sequence of values.
        unseen: Sequence of values with the same length as ``seen``.

    Returns:
        List of floats, one per input pair.
    """
    assert len(seen) == len(unseen)

    def _rounded_hmean(s, u):
        raw = (2 * s * u) / (s + u + 1e-12)
        return float('%.3f' % (raw))

    return [_rounded_hmean(s, u) for s, u in zip(seen, unseen)]

# class Averager():
#
#     def __init__(self):
#         self.v = 0
#         self.acc_v = 0
#         self.n = 0
#
#     def add(self, x, current_n=1):
#         # self.acc_v += x * current_n
#         self.acc_v += x
#         self.n += current_n
#         self.v = self.acc_v / self.n
#
#     def item(self):
#         try:
#             return self.acc_v / self.n
#         except ZeroDivisionError:
#             return 0.

# class Averager_Loss():
#
#     def __init__(self):
#         self.v = 0
#         self.acc_v = 0
#         self.n = 0
#
#     def add(self, x, current_n=1):
#         self.acc_v += x * current_n
#         # self.acc_v += x
#         self.n += current_n
#         self.v = self.acc_v / self.n
#
#     def item(self):
#         return self.v

def identify_importance(args, model, trainset, batchsize=60, keep_ratio=0.1, session=0, way=10, new_labels=None):
    """Accumulate gradient-magnitude importances per ``WaRPModule`` and update their masks.

    For every batch of ``trainset`` the cross-entropy loss is back-propagated and
    ``|basis_coeff.grad|`` is summed per ``WaRPModule``. The sums are converted to
    masks through ``flatten_importances_module``, ``fraction_threshold`` and
    ``importance_masks_module``, then merged with each module's previous mask.
    Finally every parameter of each non-WaRP module that has a ``weight``
    attribute gets ``requires_grad = False``.

    Args:
        args: Namespace; only ``args.base_class`` is read here.
        model: Network containing ``WaRPModule`` layers. It is put in eval mode
            and moved to CUDA.
        trainset: Dataset wrapped in an unshuffled ``DataLoader``.
        batchsize: Batch size of that loader.
        keep_ratio: Forwarded to ``fraction_threshold``.
        session: ``0`` means each batch is an ``(x, y)`` pair; otherwise each
            batch is the input alone and labels come from ``new_labels``.
        way: Together with ``session`` decides how many logit columns are kept
            (``args.base_class + session * way``).
        new_labels: Label tensor for ``session != 0``, sliced per batch in
            loader order.

    Returns:
        The same ``model`` object.
    """
    importances = OrderedDict()
    dl = DataLoader(trainset, shuffle=False, batch_size=batchsize)
    model.eval().cuda()
    # Remember the current masks and zero the live ones while gradients are collected.
    for module in model.modules():
        if isinstance(module, WaRPModule):
            module.coeff_mask_prev = module.coeff_mask.data
            module.coeff_mask.data = torch.zeros(module.coeff_mask.shape).cuda().data
    # NOTE(review): this is read after model.eval(), so it is always False and the
    # assignment near the end of the function does not restore the caller's mode.
    training = model.training
    epoch_iter = tqdm(dl)
    for i, batch in enumerate(epoch_iter):
        if session == 0:
            x, y = [_.cuda() for _ in batch]
        else:
            x, y = batch.cuda(), new_labels.cuda()[i * batchsize:(i+1) * batchsize]
        # Only the logits of classes seen up to this session take part in the loss.
        yhat = model(x)[:, :args.base_class + session * way]
        loss = nn.CrossEntropyLoss()(yhat, y)
        model.zero_grad()
        loss.backward()
        # Sum the absolute coefficient gradients over all batches.
        for module in model.modules():
            if isinstance(module, WaRPModule):
                grad_magnitude = module.basis_coeff.grad.abs().detach().cpu().numpy().copy()
                if module not in importances:
                    importances[module] = grad_magnitude
                else:
                    importances[module] += grad_magnitude
    flat_importances = flatten_importances_module(importances)
    threshold = fraction_threshold(flat_importances, keep_ratio)
    masks = importance_masks_module(importances, threshold)
    for module in model.modules():
        if isinstance(module, WaRPModule):
            coeff_mask = masks[module]
            # NOTE(review): gradients above are read from `basis_coeff`, while this
            # line uses `basis_coefficients` -- confirm WaRPModule exposes both.
            coeff_mask = same_device(ensure_tensor(coeff_mask), module.basis_coefficients)
            # For 0/1 masks this is a logical OR of the new and the previous mask.
            module.coeff_mask.data = 1 - (1 - coeff_mask.data) * (1 - module.coeff_mask_prev.data)
    # -------------------------- get accumulative mask ratio ---------------------------------
    for module in model.modules():
        if isinstance(module, WaRPModule):
            masks[module] = module.coeff_mask.data.detach().cpu().numpy().copy()
    print(flatten_importances_module(masks).mean())
    # ----------------------------------------------------------------------------------------
    model.zero_grad()
    model.training = training
    for module in model.modules():
        if hasattr(module, 'weight') and not isinstance(module, WaRPModule):
            for param in module.parameters():
                param.requires_grad = False
    return model

def flatten_importances_module(importances):
    """Concatenate every array in ``importances`` into a single 1-D array.

    The type and shape of each value are printed for debugging before the
    flattened arrays are joined in the mapping's iteration order.

    Args:
        importances: Mapping whose values expose ``.shape`` and ``.flatten()``
            (e.g. NumPy arrays).

    Returns:
        One-dimensional ``numpy.ndarray`` holding all values.
    """
    flattened = []
    for params in importances.values():
        print("type(params)=", type(params))  # show what kind of object is stored
        print("params.shape=",params.shape)  # and its shape, when it is array-like
        flattened.append(params.flatten())
    return np.concatenate(flattened)

def get_gacc(ratio, all_acc):
    """Evaluate ``generalised_avg_acc`` on ``0..ratio`` and the normalized area under it.

    Args:
        ratio: Largest alpha value; the alphas are the integers ``0, 1, ..., ratio``.
        all_acc: Forwarded unchanged to ``generalised_avg_acc``.

    Returns:
        ``(g_acc, area)`` where ``area`` is the trapezoidal integral of ``g_acc``
        over the alphas divided by ``(alpha range) * 100``.
    """
    alpha = list(range(ratio + 1))
    g_acc = generalised_avg_acc(alpha, all_acc)
    integral = np.trapz(g_acc, x=alpha)
    normalizer = (alpha[-1] - alpha[0]) * 100
    area = integral / normalizer
    return g_acc, area

def print_config(args, logger=None, is_end=False):
    """Report the run configuration, either via ``logger.info`` or ``print``.

    Args:
        args: Namespace holding the run settings read below.
        logger: Object with an ``info`` method; used only when ``is_end`` is true.
        is_end: When true the summary (ending with ``lr_new`` and ``time_step``)
            goes to ``logger``; otherwise the summary (ending with ``time_step``
            and ``lr_base``) is printed.
    """
    # Both variants share everything up to and including k_shot.
    common = (
        f'method: {args.project}, dataset: {args.dataset}, backbone: {args.network}, epochs_base: {args.epochs_base}, epochs_new: {args.epochs_new},'
        f' bs_base: {args.batch_size_base}, bs_new: {args.batch_size_new}, test_bs: {args.test_batch_size},'
        f'base_mode: {args.base_mode}, new_mode: {args.new_mode}, seed: {args.seed}, n_way: {args.way}, k_shot: {args.shot}'
    )
    if not is_end:
        print(common + f', time_step: {args.time_step}, lr_base: {args.lr_base}')
        return
    logger.info(common + f', lr_new: {args.lr_new}, time_step: {args.time_step}')

def debug_pr(x, name):
    """Print ``<name>.shape = <x.shape>`` as a quick debugging aid."""
    shape = x.shape
    print(f'{name}.shape = {shape}')



def plot_tsne(features, labels, title='t-SNE Visualization', save_path=None):
    """Visualize feature vectors with t-SNE (reduced to 2-D and colored by label).

    The features are standardized with ``StandardScaler`` before t-SNE is run
    with PCA initialization, ``perplexity=30`` and a fixed ``random_state``.

    Args:
        features: numpy.ndarray, shape = [N, D]
        labels: numpy.ndarray, shape = [N]
        title: Title of the figure.
        save_path: Optional path for saving the image. When omitted the figure
            is shown with ``plt.show()`` instead.
    """
    print("Standardizing features...")
    features_std = StandardScaler().fit_transform(features)

    print("Running t-SNE...")
    tsne = TSNE(n_components=2, init='pca', random_state=42, perplexity=30)
    tsne_result = tsne.fit_transform(features_std)

    print("Plotting...")
    fig = plt.figure(figsize=(8, 6))
    scatter = plt.scatter(tsne_result[:, 0], tsne_result[:, 1], c=labels, cmap='tab10', s=10, alpha=0.7)
    plt.colorbar(scatter, ticks=np.unique(labels))
    plt.title(title)
    plt.xlabel('t-SNE Dim 1')
    plt.ylabel('t-SNE Dim 2')
    plt.grid(True)

    if save_path:
        plt.savefig(save_path, dpi=300)
        print(f"Saved t-SNE plot to {save_path}")
        # Release the figure once it is on disk; otherwise repeated calls keep
        # accumulating open figures in pyplot's global registry.
        plt.close(fig)
    else:
        plt.show()


import torch

def analyze_classification(logits, labels, threshold=60):
    """Summarize base/novel confusion statistics for a batch of predictions.

    Classes below ``threshold`` are "base", the rest are "novel".

    Args:
        logits: Tensor of per-class scores; ``argmax`` is taken over ``dim=1``.
        labels: Tensor of ground-truth class indices.
        threshold: First class index that counts as novel.

    Returns:
        Dict with the keys
        ``mistake_b`` (base samples predicted as novel),
        ``mistake_n`` (novel samples predicted as base),
        ``base_correct`` / ``novel_correct`` (exactly correct predictions per group),
        ``FBR`` = mistake_b / (base_correct + mistake_n),
        ``FNR`` = mistake_n / (novel_correct + mistake_b)
        (each ratio is 0.0 when its denominator is 0), and
        ``acc`` (overall top-1 accuracy).
    """
    pred = torch.argmax(logits, dim=1)

    counts = {
        'mistake_b': 0,      # base (< threshold) -> novel (>= threshold)
        'mistake_n': 0,      # novel (>= threshold) -> base (< threshold)
        'base_correct': 0,
        'novel_correct': 0,
    }

    for idx in range(len(labels)):
        true_cls = labels[idx].item()
        pred_cls = pred[idx].item()
        true_is_base = true_cls < threshold
        pred_is_base = pred_cls < threshold
        if true_is_base != pred_is_base:
            # The prediction landed in the other group.
            counts['mistake_b' if true_is_base else 'mistake_n'] += 1
        elif pred_cls == true_cls:
            counts['base_correct' if true_is_base else 'novel_correct'] += 1

    def _ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0.0

    fbr = _ratio(counts['mistake_b'], counts['base_correct'] + counts['mistake_n'])
    fnr = _ratio(counts['mistake_n'], counts['novel_correct'] + counts['mistake_b'])

    return {
        'mistake_b': counts['mistake_b'],
        'mistake_n': counts['mistake_n'],
        'base_correct': counts['base_correct'],
        'novel_correct': counts['novel_correct'],
        'FBR': fbr,
        'FNR': fnr,
        'acc': (pred == labels).float().mean().item()
    }
