import sys
import warnings
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
from tqdm import tqdm
import copy
from torch.utils.data import DataLoader
import torch
import numpy as np
from timm.models import create_model
# from HiDe.engines.hide_tii_engine import *
# import HiDe.vits.hide_prompt_vision_transformer as hide_prompt_vision_transformer
from utils.data_manager import DataManager



def eval_TIIcp(args):
    """Run ``test`` once for every seed listed in ``args['seed']``.

    The seed list and the device specification are snapshotted before the
    first run. Ahead of each run ``args['seed']`` is set to the current seed
    and ``args['device']`` is reset to the snapshot, because ``test`` rewrites
    ``args['device']`` through ``_set_device``.

    Args:
        args: configuration mapping; must provide an iterable under ``'seed'``
            and the device specification under ``'device'``.
    """
    seeds = copy.deepcopy(args['seed'])
    device_spec = copy.deepcopy(args['device'])

    for current_seed in seeds:
        # Keyword order matters only for readability: seed first, then device.
        args.update(seed=current_seed, device=device_spec)
        test(args)


def _gather_task_prompts(network, task_ids):
    """Build the per-sample prompt list for the given task ids.

    For every entry of ``task_ids`` the ``weight`` of the matching element of
    ``network.ts_prompts_1`` and of ``network.ts_prompts_2`` is looked up, and
    the per-sample weights are stacked along a new leading dimension.

    Args:
        network: object exposing indexable ``ts_prompts_1`` / ``ts_prompts_2``
            whose elements carry a ``weight`` tensor.
        task_ids: 1-D integer tensor holding one task id per sample.

    Returns:
        ``[P1, P2]``: the stacked weights taken from the first and second
        prompt bank respectively.
    """
    # A single host transfer for the whole batch instead of one per sample.
    ids = task_ids.tolist()
    return [
        torch.stack([bank[task].weight.detach() for task in ids], dim=0)
        for bank in (network.ts_prompts_1, network.ts_prompts_2)
    ]


def _target_probabilities(logits, targets):
    """Return, as a list of floats, the softmax probability of each target class."""
    probabilities = torch.softmax(logits, dim=1)
    rows = torch.arange(len(targets), device=probabilities.device)
    return probabilities[rows, targets].cpu().tolist()


def _show_pt_histogram(values, color, title):
    """Display a 50-bin histogram of ``values`` over the range [0, 1]."""
    plt.figure(figsize=(12, 6))
    plt.hist(values, bins=50, alpha=0.7, color=color, range=(0, 1))
    plt.title(title)
    plt.xlabel('PT Value')
    plt.ylabel('Count')
    plt.grid(True)
    plt.show()


def test(args):
    """Evaluate prompt selection with predicted vs. ground-truth task ids.

    Steps:
      1. Resolve devices, build the ``DataManager`` and load the TII
         (task-identity inference) ViT from a checkpoint; return early when
         the checkpoint file does not exist.
      2. For every batch, classify twice with ``model._network``: once with
         prompts chosen from the TII-predicted task id and once with prompts
         chosen from the task id derived from the label.
      3. Print both accuracies and plot the distribution of the target-class
         probability ("pt") obtained with the ground-truth task id.
      4. Re-evaluate the samples that BOTH passes classified correctly, save
         their pt values to ``./hard_cp_pt.pt`` and plot them.

    Args:
        args: configuration mapping. Keys read here: ``'dataset'``,
            ``'shuffle'``, ``'seed'``, ``'init_cls'``, ``'increment'``,
            ``'batch_size'`` and ``'device'`` (rewritten by ``_set_device``).

    Returns:
        None.
    """
    # `os` is not imported at file level; import it here so the checkpoint
    # existence check does not raise NameError.
    import os

    _set_device(args)
    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment'])
    device = args['device'][0]
    increment = args["increment"]

    # NOTE(review): `model` is not defined in the visible part of this file;
    # it must be provided as a module-level learner exposing `_network`
    # before this function is called -- confirm where it is constructed.
    network = model._network
    network.eval()
    network.to(device)
    model._device = device

    model_TII = create_model(
        'vit_base_patch16_224',
        pretrained=True,
        num_classes=200,
        drop_rate=0.0,
        drop_path_rate=0.0,
        drop_block_rate=None,
        mlp_structure=[2],
    )
    model_TII.to(device)

    # Placeholder: fill in the path of the trained TII checkpoint.
    checkpoint_path = ''
    if not os.path.exists(checkpoint_path):
        print('No checkpoint found at:', checkpoint_path)
        return
    print('Loading checkpoint from:', checkpoint_path)
    # map_location keeps loading independent of the device the checkpoint was
    # saved from. NOTE: torch.load unpickles data -- only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_TII.load_state_dict(checkpoint['model'])
    model_TII.eval()

    # Hard-coded index of the last task to include (classes of tasks 0..task).
    task = 9
    # NOTE(review): evaluation intentionally(?) runs on the *train* split in
    # 'train' mode; the original kept a commented-out test-split variant:
    # data_manager.get_dataset(np.arange(0, (task+1)*args["increment"]), source='test', mode='test')
    test_dataset = data_manager.get_dataset(np.arange(0, (task + 1) * increment), source='train', mode='train', appendent=None)
    test_loader = DataLoader(test_dataset, batch_size=args["batch_size"], shuffle=False, num_workers=8)

    total = 0
    cor = 0        # correct predictions using the TII-predicted task id
    cor_true = 0   # correct predictions using the ground-truth task id
    all_pt = []        # pt value of every sample (ground-truth task id)
    error_pt = []      # pt values of the re-evaluated subset
    kept_inputs = []   # per-batch CPU tensors of the samples selected below
    kept_targets = []

    is_tty = sys.stdout.isatty()
    val_bar = tqdm(test_loader, desc="测试 [Val]", leave=False, disable=not is_tty)
    for _, inputs, targets in val_bar:
        inputs, targets = inputs.to(device), targets.to(device)

        with torch.no_grad():
            # Pass 1: task id predicted by the TII model.
            tii_logits = model_TII(inputs)['logits']
            pred_task_id = tii_logits.argmax(dim=1) // increment
            preds_final = network(inputs, _gather_task_prompts(network, pred_task_id), train=False).argmax(dim=1)

            # Pass 2: task id derived from the ground-truth label.
            true_task_id = targets // increment
            out_logits = network(inputs, _gather_task_prompts(network, true_task_id), train=False)
        preds_true = out_logits.argmax(dim=1)

        cor += (preds_final == targets).sum().item()
        cor_true += (preds_true == targets).sum().item()
        total += len(targets)

        all_pt.extend(_target_probabilities(out_logits, targets))

        # Keep the samples that are classified correctly in BOTH passes
        # (this matches the "both-Corrected" plot title used further down).
        mask = (preds_final == targets) & (preds_true == targets)
        if mask.any():
            kept_inputs.append(inputs[mask].cpu())
            kept_targets.append(targets[mask].cpu())

        val_bar.set_postfix(acc=100. * cor / total if total > 0 else 0.0)
    val_bar.close()

    # Guard against an empty loader so the report does not divide by zero.
    acc_predicted = 100. * cor / total if total > 0 else 0.0
    acc_true = 100. * cor_true / total if total > 0 else 0.0
    print(f"TII+CP正确率为:{acc_predicted:.2f}")
    print(f"给定task_id的正确率为:{acc_true:.2f}")

    # Distribution of pt over all samples.
    _show_pt_histogram(all_pt, 'blue', 'PT Value Distribution (All Samples)')

    if kept_inputs:
        # Re-evaluate the kept samples with ground-truth prompts. The stored
        # tensors stay on the CPU and only one mini-batch at a time is moved
        # to the device, so GPU memory use is bounded by `batch_size`.
        batch_size = 32  # tune to the available GPU memory
        for stored_inputs, stored_targets in zip(kept_inputs, kept_targets):
            chunks = zip(stored_inputs.split(batch_size), stored_targets.split(batch_size))
            for batch_inputs, batch_targets in chunks:
                batch_inputs = batch_inputs.to(device)
                batch_targets = batch_targets.to(device)
                batch_task_id = batch_targets // increment
                with torch.no_grad():
                    logits = network(batch_inputs, _gather_task_prompts(network, batch_task_id), train=False)
                error_pt.extend(_target_probabilities(logits, batch_targets))

        save_path = './hard_cp_pt.pt'
        torch.save(error_pt, save_path)
        # Distribution of pt over the kept samples.
        _show_pt_histogram(error_pt, 'red', 'PT Value Distribution (both-Corrected Samples)')
    else:
        print("No error-corrected samples found.")

    return

def _set_device(args):
    """Replace ``args['device']`` with a list of ``torch.device`` objects.

    Every entry of ``args['device']`` is interpreted as a CUDA device index,
    except ``-1`` (given as an int or as the string ``"-1"``), which selects
    the CPU. Entries that already are ``torch.device`` instances are kept
    unchanged, so calling this function twice on the same mapping is safe.

    Args:
        args: configuration mapping with an iterable under ``'device'``;
            the key is overwritten in place with the resolved device list.
    """
    gpus = []

    for device in args['device']:
        if isinstance(device, torch.device):
            resolved = device
        elif device in (-1, '-1'):
            # The check must look at the current entry, not at the whole
            # list; comparing the list to -1 is never true and made the CPU
            # branch unreachable.
            resolved = torch.device('cpu')
        else:
            resolved = torch.device('cuda:{}'.format(device))

        gpus.append(resolved)
    args['device'] = gpus