import os
import sys
import math
import time
import shutil
import argparse
import numpy as np
# import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import LambdaLR
from PIL import Image
from tqdm import tqdm


# from imagenet_ipc import ImageFolderIPC
# from utils_auto_logit import AverageMeter, accuracy, get_parameters, compute_adjustment
sys.path.append('..')
from relabel.utils_fkd import mix_aug


class LT_Dataset(Dataset):
    """Image dataset described by a plain-text index file.

    Each non-blank line of the index file holds whitespace-separated fields:
    the first is an image path relative to ``root`` and the second is an
    integer class label. Any further fields on the line are ignored.
    """

    def __init__(self, root, txt, transform=None):
        """Read the index file and record image paths and labels.

        Args:
            root: Directory that the paths in the index file are joined to.
            txt: Path to the index file.
            transform: Optional callable applied to each loaded PIL image.
        """
        self.img_path = []
        self.labels = []
        self.transform = transform
        with open(txt) as f:
            for line in f:
                fields = line.split()
                # Blank lines (e.g. a trailing newline at the end of the
                # file) carry no sample; skip them instead of crashing.
                if not fields:
                    continue
                self.img_path.append(os.path.join(root, fields[0]))
                self.labels.append(int(fields[1]))

    def __len__(self):
        """Return the number of samples listed in the index file."""
        return len(self.labels)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            tuple: ``(sample, label, path)`` where ``sample`` is the RGB image
            (after ``transform`` when one was given), ``label`` is the integer
            class label and ``path`` is the image's file path.
        """
        path = self.img_path[index]
        label = self.labels[index]

        # convert() forces the pixel data to be read, so the file handle can
        # be closed as soon as the with-block exits.
        with open(path, 'rb') as f:
            sample = Image.open(f).convert('RGB')

        if self.transform is not None:
            sample = self.transform(sample)

        return sample, label, path


# Validation data: images are read from val_dir with torchvision's ImageFolder.
val_dir = "/data0/ImageNet/val"
batch_size = 256

# Per-channel mean/std applied to the image tensors after ToTensor().
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Deterministic preprocessing: resize, center crop to 224, tensorize, normalize.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
val_dataset = datasets.ImageFolder(val_dir, val_transform)

# Fixed (unshuffled) iteration order over the validation set.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=int(batch_size),
    shuffle=False,
    num_workers=8,
)

def evaluate(model, loader=None, device=None):
    """
    Evaluate the model on a validation set.

    Args:
        model (torch.nn.Module): The model to evaluate. It is switched to
            eval mode as a side effect.
        loader (iterable, optional): Yields ``(data, target)`` batches.
            Defaults to the module-level ``val_loader``.
        device (torch.device or str, optional): Device the input batches are
            moved to. Defaults to the device of the model's first parameter,
            or CUDA when the model has no parameters.

    Returns:
        numpy.ndarray: Predicted class indices (argmax over dim 1 of the
            model output).
        numpy.ndarray: Ground truth labels.
        float: Fraction of predictions that equal their target.
    """
    if loader is None:
        loader = val_loader
    if device is None:
        # Follow the model instead of assuming the current CUDA device, so a
        # model on another GPU (or on the CPU) gets its inputs where it lives.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cuda")

    model.eval()  # Set model to evaluation mode
    all_preds = []
    all_targets = []

    with torch.inference_mode():  # Disable gradient tracking for efficiency
        for data, target in tqdm(loader, desc="Evaluating"):
            output = model(data.to(device))
            preds = output.argmax(dim=1)  # Predicted class indices

            # Results are accumulated on the CPU; the targets are only ever
            # compared there, so they are not sent to the model's device.
            all_preds.append(preds.cpu())
            all_targets.append(target.cpu())

    # Concatenate the per-batch tensors into single arrays
    all_preds = torch.cat(all_preds).numpy()
    all_targets = torch.cat(all_targets).numpy()

    # Compute accuracy
    accuracy = (all_preds == all_targets).mean()

    return all_preds, all_targets, accuracy

# Command-line interface: which checkpoint to evaluate.
parser = argparse.ArgumentParser(description="Evaluate model and save logits")
parser.add_argument("--ipc", type=int, required=True, help="IPC value")
parser.add_argument("--auto", action="store_true", help="Use auto mode")
args = parser.parse_args()

ipc = args.ipc
auto = args.auto

# Checkpoint location. The run directory name depends on --auto, and the
# IPC=50 runs are stored under a different directory layout than the others.
save_root = "/data0/xxx/imagenet-lt-pretrain/sre2l_imagenet/validate/save"
run_name = "rn18_[4K]_T20_logit_auto" if auto else "rn18_[4K]_T20_logit_tau_0_checkpoint"
if ipc == 50:
    ckpt_path = f"{save_root}/IPC{ipc}/val_rn18_kd/{run_name}/checkpoint.pth.tar"
else:
    ckpt_path = f"{save_root}/val_rn18_kd/{run_name}_IPC{ipc}/checkpoint.pth.tar"

# Load onto the CPU first so the checkpoint can be read regardless of which
# GPU index it was saved from; the model is moved to CUDA below.
ckpt = torch.load(ckpt_path, map_location="cpu")

state = ckpt["state_dict"]
# Strip only a leading "module." from each key (presumably left by a
# DataParallel/DDP wrapper at training time -- confirm), not every occurrence.
prefix = "module."
new_state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in state.items()
}
model = torchvision.models.__dict__["resnet18"](pretrained=False)
model.load_state_dict(new_state_dict)
# model = torch.nn.DataParallel(model, device_ids=[0, 3, 5])
model.cuda().eval()

preds, targets, acc = evaluate(model)

name = f"ipc{ipc}_tau_{'auto' if auto else 0}_logits_targets"

# Save the data with explicit key names.
# NOTE: evaluate() returns argmax class predictions, not raw logits; the
# "logits" key and file name are kept so existing consumers keep working.
np.savez(name, logits=preds, targets=targets, acc=acc)

