import argparse
import numpy as np
import pickle
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from imagenet_ipc import ImageFolderIPC

def get_data_path(suffix, rand, number):
    """Return the directory of a distilled dataset.

    :param suffix: dataset-variant name used as the directory prefix
    :param rand: random-seed index, inserted as ``rand{rand}``
    :param number: value inserted as ``SPEC_NUMBER{number}``
    :return: relative path string under ``./syn_data``
    """
    template = "./syn_data/{}_10ipc/rand{}/SPEC_NUMBER{}_IPC10"
    return template.format(suffix, rand, number)

def get_trainset(path, ipc):
    """Wrap an image folder in ``ImageFolderIPC`` with training-time augmentation.

    :param path: root directory of the image folder
    :param ipc: forwarded unchanged to ``ImageFolderIPC`` (presumably an
        images-per-class limit — semantics defined by that class)
    :return: the ``ImageFolderIPC`` dataset
    """
    # Random 32x32 crop with 4px padding and horizontal flip, then tensor conversion
    # and per-channel normalization.
    # NOTE(review): these mean/std values look like the commonly quoted CIFAR-10
    # statistics; confirm they match what the 100-class checkpoints were trained with.
    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    return ImageFolderIPC(root=path, transform=transforms.Compose(augmentations), ipc=ipc)

def get_loader(trainset):
    """Return a non-shuffling DataLoader (batch size 100, 2 worker processes) over ``trainset``."""
    # shuffle=False keeps batch order aligned with dataset order, so collected
    # labels/logits can be concatenated in a reproducible order.
    return torch.utils.data.DataLoader(
        trainset,
        shuffle=False,
        batch_size=100,
        num_workers=2,
    )

def load_model(ckpt_path, device='cuda', num_classes=100):
    """Build a ResNet-18 with a small-image stem and load weights from a checkpoint.

    :param ckpt_path: path to a checkpoint dict containing a ``"state_dict"`` entry
    :param device: device the model (and loaded tensors) are placed on
    :param num_classes: size of the final classification layer
    :return: the model in eval mode, on ``device``

    Side effect: sets the global ``torch.backends.cudnn.benchmark = True``.
    """
    # NOTE: weights_only=False lets torch.load unpickle arbitrary objects; only
    # point this at trusted, locally produced checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    model = torchvision.models.get_model("resnet18", num_classes=num_classes)
    # Replace the ImageNet stem (7x7 stride-2 conv + max-pool) with a 3x3 stride-1
    # conv and no pooling, so small inputs are not downsampled aggressively.
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    net = model.to(device)
    cudnn.benchmark = True
    # Checkpoints saved from a DataParallel/DDP wrapper prefix keys with "module.".
    # Strip it only where present; unconditional slicing would corrupt plain keys.
    prefix = "module."
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in ckpt["state_dict"].items()
    }
    net.load_state_dict(state_dict)
    net.eval()
    return net

def get_model(ir, ipc, t, suffix, stage='save_post_cifar100', stage_name_fn=None):
    """Resolve a checkpoint path for one run and load it via ``load_model``.

    The path root is ``./{stage}/{suffix}/rand{t}``. Below that root:

    - by default the checkpoint is ``{ir}_{ipc}/ckpt.pth``;
    - when ``stage_name_fn`` is given, it is called as
      ``stage_name_fn(suffix, ir, ipc)`` and its result is used as the
      path relative to the root.

    :return: the loaded model returned by ``load_model``
    """
    ckpt_root = f"./{stage}/{suffix}/rand{t}"
    if stage_name_fn is not None:
        ckpt_path = os.path.join(ckpt_root, stage_name_fn(suffix, ir, ipc))
    else:
        ckpt_path = os.path.join(ckpt_root, f"{ir}_{ipc}", "ckpt.pth")
    return load_model(ckpt_path)

def get_info(spec_number, suffix_data, suffix_model, model_stage, model_stage_name_fn, model_ir, random_list):
    """Run a model ensemble over the distilled datasets of several seeds.

    For every seed in ``random_list`` this does two things:

    - loads the distilled set at ``get_data_path(suffix_data, rand, spec_number)``;
    - loads one model via ``get_model``.

    Every batch from every set is then fed through all models. The models'
    raw logits are averaged to form the ensemble prediction.

    :return: ``(logits, labels, softmaxs)`` numpy arrays, concatenated over all
        batches of all datasets in ``random_list`` order
    """
    # NOTE(review): get_trainset applies RandomCrop/RandomHorizontalFlip, so the
    # logits here are computed on randomly augmented images and vary between runs.
    # Confirm this is intended (main() builds its own non-augmented transform).
    trainloader_list = [
        get_loader(get_trainset(get_data_path(suffix_data, rand, spec_number), ipc=10))
        for rand in random_list
    ]

    model_list = [
        get_model(ir=model_ir, ipc=10, t=rand, suffix=suffix_model, stage=model_stage, stage_name_fn=model_stage_name_fn)
        for rand in random_list
    ]

    logits, labels, softmaxs = [], [], []

    # Pure inference: no_grad avoids building (and holding) an autograd graph per batch.
    with torch.no_grad():
        for trainloader in trainloader_list:
            for inputs, targets in trainloader:
                inputs, targets = inputs.cuda(), targets.cuda()
                # Ensemble: mean of raw logits across models; softmax is taken of the mean.
                _logit = torch.mean(torch.stack([model(inputs) for model in model_list]), dim=0)
                logits.append(_logit.cpu().numpy())
                labels.append(targets.cpu().numpy())
                softmaxs.append(torch.nn.functional.softmax(_logit, dim=1).cpu().numpy())
    logits = np.concatenate(logits, axis=0)
    labels = np.concatenate(labels, axis=0)
    softmaxs = np.concatenate(softmaxs, axis=0)
    return logits, labels, softmaxs

def softmax_entropy(logits):
    """
    Compute the Shannon entropy (in nats) of the softmax of raw logits.

    The softmax is taken along the last axis. For a 1-D vector this is the entropy
    of that single distribution. For an N-D array (e.g. ``(n_samples, n_classes)``)
    the entropy is computed per row and the mean over all rows is returned, so
    the result is always a scalar.

    :param logits: raw (unnormalized) logits, 1-D or N-D
    :return: entropy as a float; NaN when ``logits`` is empty
    """
    logits = np.atleast_1d(np.asarray(logits, dtype=np.float64))
    if logits.size == 0:
        # No samples (e.g. a class absent from the data): NaN so callers can nanmean.
        return np.nan

    # Subtract the per-row max before exp to prevent overflow. The max and the
    # normalization must be per row: reducing over the whole array would
    # normalize the softmax across different samples.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp_logits = np.exp(shifted)
    softmax_probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

    # Small epsilon guards against log(0).
    per_row_entropy = -np.sum(softmax_probs * np.log(softmax_probs + 1e-9), axis=-1)
    return float(np.mean(per_row_entropy))

def get_info_cifar100(dataloader, suffix_model, model_stage, model_stage_name_fn, model_ir, random_list):
    """Evaluate a model ensemble (one model per seed in ``random_list``) on ``dataloader``.

    Each model is loaded via ``get_model``. For every batch, the models' raw
    logits are averaged, and a softmax is taken over dim 1 of that mean.

    :return: ``(logits, labels, softmaxs)`` numpy arrays concatenated over all batches
    """
    ensemble = []
    for seed in random_list:
        net = get_model(ir=model_ir, ipc=10, t=seed, suffix=suffix_model, stage=model_stage, stage_name_fn=model_stage_name_fn)
        net.eval()
        ensemble.append(net)

    batch_logits, batch_labels, batch_probs = [], [], []
    # Inference only; no autograd graph is needed.
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.cuda()
            targets = targets.cuda()
            mean_logit = torch.stack([net(inputs) for net in ensemble]).mean(dim=0)
            batch_logits.append(mean_logit.cpu().numpy())
            batch_labels.append(targets.cpu().numpy())
            batch_probs.append(torch.nn.functional.softmax(mean_logit, dim=1).cpu().numpy())

    return (
        np.concatenate(batch_logits, axis=0),
        np.concatenate(batch_labels, axis=0),
        np.concatenate(batch_probs, axis=0),
    )

def main():
    """CLI entry point: per-class entropy of distilled sets, one run per seed.

    For each seed in ``random_list``, the distilled set for ``--spec-number`` is
    passed through that seed's model, selected by ``--model-ir``. For every
    class label, the entropy of that class's logits is then computed.

    The per-class entropies are averaged across seeds, ignoring NaN. They are
    pickled to ``cache/distilled/entropy_{spec_number}_{model_ir}.pkl`` together
    with the raw per-seed results.
    """
    parser = argparse.ArgumentParser(description='Compute entropy of a dataset')
    # Both values are interpolated into paths; a missing flag would otherwise
    # silently become "None" and fail much later.
    parser.add_argument('--spec-number', type=int, required=True)
    parser.add_argument('--model-ir', type=int, required=True)

    args = parser.parse_args()
    spec_number = args.spec_number
    model_ir = args.model_ir
    random_list = [0, 1, 2, 3, 4, 5]
    num_classes = 100
    # Deterministic evaluation transform: no random crop/flip.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )

    # Create the output directory up front so a missing folder fails fast
    # instead of after all inference has run.
    out_path = f"cache/distilled/entropy_{spec_number}_{model_ir}.pkl"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    result = {}
    entropy_list = []

    for rand in random_list:
        distset = ImageFolderIPC(root=f"./syn_data/imb_size10000_extream_lt_specific_10ipc/rand{rand}/SPEC_NUMBER{spec_number}_IPC10", transform=transform, ipc=10)

        distloader = torch.utils.data.DataLoader(distset, batch_size=100, shuffle=False, num_workers=2)

        distill_logits, distill_labels, distill_softmaxs = get_info_cifar100(distloader, "imb_size10000_extream_lt_specific", 'save', lambda _suffix, _ir, _ipc: f"E200_lr0.1_{_suffix}_{_ir}_10000/ckpt.pth", model_ir, [rand])

        # Per-class entropy. Classes with no samples in this seed's set get NaN,
        # which np.nanmean below skips (softmax_entropy cannot reduce an empty array).
        distill_entropy = []
        for c in range(num_classes):
            class_logits = distill_logits[distill_labels == c]
            distill_entropy.append(softmax_entropy(class_logits) if len(class_logits) else np.nan)
        distill_entropy = np.array(distill_entropy)
        entropy_list.append(distill_entropy)

        result[rand] = {
            "distill_logits": distill_logits,
            "distill_labels": distill_labels,
            "distill_softmaxs": distill_softmaxs,
            "distill_entropy": distill_entropy
        }

    # Stack to (len(random_list), num_classes) and average over seeds per class.
    entropy_array = np.stack(entropy_list, axis=0)
    mean_entropy = np.nanmean(entropy_array, axis=0)

    print("Mean class-wise entropy across all random seeds:", mean_entropy)

    with open(out_path, "wb") as f:
        pickle.dump({"result": result, "mean_entropy": mean_entropy}, f)


if __name__ == "__main__":
    main()