import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

from imagenet_ipc import ImageFolderIPC
def get_trainset(path, ipc):
    """Build the augmented, normalized ``ImageFolderIPC`` dataset rooted at *path*.

    Args:
        path: Dataset root directory, forwarded as ``root``.
        ipc: Forwarded unchanged to ``ImageFolderIPC`` (presumably
            "images per class" -- confirm against ``imagenet_ipc``).

    Returns:
        The constructed ``ImageFolderIPC`` dataset.

    Note:
        The pipeline contains random augmentations (crop + horizontal flip),
        so repeated reads of the same sample are not deterministic.
    """
    # Per-channel normalization statistics (presumably CIFAR -- TODO confirm).
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    pipeline = transforms.Compose(augmentations + to_normalized_tensor)
    return ImageFolderIPC(root=path, transform=pipeline, ipc=ipc)

def get_loader(trainset, batch_size=100, num_workers=2):
    """Wrap *trainset* in a sequential (non-shuffled) ``DataLoader``.

    Args:
        trainset: Any map-style or iterable dataset accepted by
            ``torch.utils.data.DataLoader``.
        batch_size: Number of samples per batch. Defaults to 100, the value
            that used to be hard-coded.
        num_workers: Number of loader worker processes. Defaults to 2, the
            value that used to be hard-coded; pass 0 to load in-process.

    Returns:
        A ``DataLoader`` that yields batches in dataset order.
    """
    return torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=False,  # keep dataset order so outputs line up with samples
        num_workers=num_workers,
    )

def load_model(ckpt_path, device='cuda'):
    """Load a 100-class ResNet-18 checkpoint and return it in eval mode.

    The torchvision ResNet-18 stem is modified before the weights are loaded:
    ``conv1`` becomes a 3x3 / stride-1 convolution and ``maxpool`` is replaced
    by an identity, so the checkpoint must have been trained with the same
    modification.

    Args:
        ckpt_path: Path to a checkpoint file whose top-level object is a
            mapping containing a ``"state_dict"`` entry.
        device: Device the model is moved to. Defaults to ``'cuda'``, the
            value that used to be hard-coded.

    Returns:
        The model on *device*, with weights loaded and ``eval()`` applied.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: If the state dict does not match the model
            (raised by ``load_state_dict``).
    """
    CLASS_NUM = 100
    prefix = "module."

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source.
    # map_location="cpu" avoids failures when the checkpoint was saved on a
    # GPU index that does not exist here; load_state_dict copies the tensors
    # onto the model's device afterwards.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    model = torchvision.models.get_model("resnet18", num_classes=CLASS_NUM)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    net = model.to(device)
    cudnn.benchmark = True  # global cuDNN autotuner switch

    # Strip the "module." prefix (presumably left by nn.DataParallel) only
    # from keys that actually carry it; slicing unconditionally would corrupt
    # the keys of checkpoints saved without the wrapper.
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt["state_dict"].items()
    }
    net.load_state_dict(state_dict)
    net.eval()
    return net

def get_model(ir, ipc, t, suffix, stage='save_post_cifar100', stage_name_fn=None):
    """Resolve a checkpoint path from the run identifiers and load the model.

    The checkpoint lives under ``./{stage}/{suffix}/rand{t}``. Below that
    directory the path is either ``{ir}_{ipc}/ckpt.pth`` (default layout) or
    whatever ``stage_name_fn(suffix, ir, ipc)`` returns.

    Args:
        ir: Run identifier used in the checkpoint path.
        ipc: Run identifier used in the checkpoint path.
        t: Index inserted into the ``rand{t}`` directory name.
        suffix: Experiment-name component of the checkpoint directory.
        stage: Top-level directory containing the checkpoints.
        stage_name_fn: Optional callable ``(suffix, ir, ipc) -> str`` giving
            the checkpoint path relative to the ``rand{t}`` directory.

    Returns:
        The network returned by ``load_model`` for the resolved path.
    """
    run_dir = f"./{stage}/{suffix}/rand{t}"
    if stage_name_fn is not None:
        custom_relative_path = stage_name_fn(suffix, ir, ipc)
        return load_model(os.path.join(run_dir, custom_relative_path))
    return load_model(os.path.join(run_dir, f"{ir}_{ipc}", "ckpt.pth"))



def get_data_path(suffix, rand, number):
    """Return the synthetic-data directory for one (suffix, rand, number) triple.

    Args:
        suffix: Experiment-name component; ``_10ipc`` is appended to it.
        rand: Index inserted into the ``rand{rand}`` directory name.
        number: Value inserted after ``SPEC_NUMBER`` in the leaf directory.

    Returns:
        A relative path string under ``./syn_data``.
    """
    data_dir = f"./syn_data/{suffix}_10ipc/rand{rand}/SPEC_NUMBER{number}_IPC10"
    return data_dir


import matplotlib.pyplot as plt

def bar(ax, logits, labels, class_id, err=False, suffix_title=""):
    """Draw a bar chart of the mean logit vector of one ground-truth class.

    Rows of *logits* whose entry in *labels* equals *class_id* are selected
    and averaged over axis 0; one bar is drawn per remaining column.

    Args:
        ax: Object exposing ``bar``, ``set_title``, ``set_xlabel`` and
            ``set_ylabel`` (e.g. a matplotlib ``Axes``).
        logits: Array indexable by a boolean row mask, with ``mean``/``std``
            methods accepting ``axis``.
        labels: Array compared element-wise with *class_id* to build the mask.
        class_id: Class whose samples are summarized.
        err: When truthy, the per-column standard deviation is passed as
            ``yerr``.
        suffix_title: Text placed before ``": class {class_id}"`` in the title.
    """
    selected = logits[labels == class_id]
    column_mean = selected.mean(axis=0)
    column_std = selected.std(axis=0)

    positions = range(len(column_mean))
    # Only forward yerr when error bars were requested.
    error_kwargs = {"yerr": column_std} if err else {}
    ax.bar(positions, column_mean, **error_kwargs)

    ax.set_title(f"{suffix_title}: class {class_id}")
    ax.set_xlabel("class")
    ax.set_ylabel("logit")

# Coarsen the class axis for plotting: average the logits within 5 consecutive groups of 20 classes each.
def bin5(logits):
    """Average the first 100 columns of *logits* in 5 consecutive groups of 20.

    Args:
        logits: 2-D array; column ``j`` contributes to bin ``j // 20`` for
            ``j < 100``. Columns beyond index 99 are ignored.

    Returns:
        A float64 array of shape ``(logits.shape[0], 5)`` where each entry is
        the sum of the corresponding 20-column slice divided by 20.
    """
    n_bins, bin_width = 5, 20
    binned = np.zeros((logits.shape[0], n_bins))
    for bin_idx, start in enumerate(range(0, n_bins * bin_width, bin_width)):
        bin_total = np.sum(logits[:, start:start + bin_width], axis=1)
        binned[:, bin_idx] = bin_total / bin_width
    return binned



from tqdm import tqdm
def get_info(spec_number, suffix_data, suffix_model, model_stage, model_stage_name_fn, model_ir, random_list, verbose=True):
    """Run an ensemble of models over several datasets and collect the outputs.

    One dataset and one model are built per entry of *random_list*. Every
    batch of every dataset is passed through ALL models; the per-model
    outputs are averaged, and the averaged logits, the targets and the
    softmax of the averaged logits are collected.

    Args:
        spec_number: Forwarded to ``get_data_path`` as ``number``.
        suffix_data: Experiment-name component of the data path.
        suffix_model: Experiment-name component of the checkpoint path.
        model_stage: Top-level checkpoint directory (``stage`` of ``get_model``).
        model_stage_name_fn: Optional callable ``(suffix, ir, ipc) -> str``
            forwarded to ``get_model``.
        model_ir: Forwarded to ``get_model`` as ``ir``.
        random_list: Iterable of ``rand`` indices; each selects one dataset
            and one model.
        verbose: When false, the tqdm progress bar is disabled.

    Returns:
        Tuple ``(logits, labels, softmaxs)`` of NumPy arrays concatenated
        along axis 0 over all batches of all datasets.

    Raises:
        ValueError: If *random_list* is empty.

    Note:
        The datasets come from ``get_trainset``, whose transform includes
        random augmentation, so results are not bit-reproducible across runs
        unless the RNG is seeded by the caller.
    """
    # Materialize once: the sequence is walked twice (data, then models).
    random_list = list(random_list)
    if not random_list:
        raise ValueError("random_list must contain at least one entry")

    trainloader_list = [
        get_loader(get_trainset(get_data_path(suffix_data, rand, spec_number), ipc=10))
        for rand in random_list
    ]
    model_list = [
        get_model(
            ir=model_ir,
            ipc=10,
            t=rand,
            suffix=suffix_model,
            stage=model_stage,
            stage_name_fn=model_stage_name_fn,
        )
        for rand in random_list
    ]

    logits = []
    labels = []
    softmaxs = []

    # Pure inference: disabling autograd avoids building (and holding) a graph
    # for every model on every batch.
    with torch.no_grad():
        for trainloader in trainloader_list:
            for inputs, targets in tqdm(trainloader, total=len(trainloader), disable=not verbose):
                inputs = inputs.cuda()
                per_model_outputs = [model(inputs) for model in model_list]
                # Ensemble by averaging raw logits across models.
                ensemble_logit = torch.mean(torch.stack(per_model_outputs), dim=0)
                logits.append(ensemble_logit.cpu().numpy())
                # Targets are only recorded, so they never need to visit the GPU.
                labels.append(targets.cpu().numpy())
                softmaxs.append(torch.nn.functional.softmax(ensemble_logit, dim=1).cpu().numpy())

    logits = np.concatenate(logits, axis=0)
    labels = np.concatenate(labels, axis=0)
    softmaxs = np.concatenate(softmaxs, axis=0)
    return logits, labels, softmaxs


import argparse
import pickle
if __name__ == "__main__":
    # Command-line entry point: compute ensemble logits for one
    # (spec_number, model_ir) pair and pickle them under ./cache/logits.
    parser = argparse.ArgumentParser()
    parser.add_argument("--spec_number", type=int, default=0)
    parser.add_argument("--model_ir", type=int, default=0)
    args = parser.parse_args()

    spec_number = args.spec_number
    model_ir = args.model_ir

    random_list = list(range(7))
    suffix_data = "imb_size10000_extream_lt_specific"
    suffix_model = "imb_size10000_extream_lt_specific"
    model_stage = 'save'

    def model_stage_name_fn(_suffix, _ir, _ipc):
        """Return the checkpoint path relative to the ``rand{t}`` directory."""
        return f"E200_lr0.1_{_suffix}_{_ir}_10000/ckpt.pth"

    # Create the output directory BEFORE the expensive inference pass so a
    # missing directory cannot discard the results at the very end.
    output_path = f"./cache/logits/{suffix_data}_{suffix_model}_{model_ir}_{spec_number}.pkl"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    logits, labels, softmaxs = get_info(spec_number, suffix_data, suffix_model, model_stage, model_stage_name_fn, model_ir, random_list, verbose=False)

    # save
    with open(output_path, "wb") as f:
        pickle.dump({"logits": logits, "labels": labels, "softmaxs": softmaxs}, f)
