from model_train import ConvNet, SubsetSamper, ResNet9
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from main import CIFAR2ModelOutput
from model_train import get_cifar2_indices_and_adjust_labels
import argparse

if __name__ == "__main__":
    # Evaluate every saved training checkpoint on a fixed subset of the CIFAR-2
    # test split and save the stacked per-sample outputs, one file per checkpoint.

    # Prefer the GPU when one is available; otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="resnet9",
                        help="model architecture whose checkpoints are evaluated")
    parser.add_argument("--num-checkpoints", type=int, default=50,
                        help="evaluate checkpoints with indices 0..N-1")
    parser.add_argument("--num-samples", type=int, default=500,
                        help="number of CIFAR-2 test samples to evaluate")
    arg = parser.parse_args()

    # Pick the architecture that matches --model. Previously ResNet9 was always
    # built, so checkpoints of any other architecture could not be loaded.
    # Unknown names keep the old ResNet9 behaviour.
    # NOTE(review): assumes ConvNet() takes no constructor arguments — confirm.
    model_classes = {"resnet9": ResNet9, "convnet": ConvNet}
    model_cls = model_classes.get(arg.model, ResNet9)

    # Load CIFAR data (CIFAR-10 test split, restricted to the CIFAR-2 classes).
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    cifar2_indices = get_cifar2_indices_and_adjust_labels(test_dataset)
    all_index = cifar2_indices[:arg.num_samples]
    sampler_test = SubsetSamper(all_index)
    # batch_size=1: one model output is computed per test sample.
    test_loader = DataLoader(test_dataset, batch_size=1, sampler=sampler_test)

    checkpoint_files = [
        f"./checkpoint/checkpoint_{arg.model}_{i}.pt" for i in range(arg.num_checkpoints)
    ]

    for checkpoint_id, checkpoint_file in enumerate(tqdm(checkpoint_files)):
        model = model_cls().to(device)
        # map_location remaps saved tensors onto the current device. Without it,
        # GPU-saved checkpoints fail to load on CPU-only machines.
        model.load_state_dict(torch.load(checkpoint_file, map_location=device))
        model.eval()  # evaluation mode: disables dropout
        model_output_list = []
        for data in tqdm(test_loader):
            model_output = CIFAR2ModelOutput.model_output(data, model)
            # Detach in case the output is attached to the autograd graph.
            # Otherwise every stored sample would keep its whole graph alive.
            model_output_list.append(model_output.detach())
        model_output_tensor = torch.stack(model_output_list, dim=0)
        print(model_output_tensor.shape)
        print(model_output_tensor[0:20])
        torch.save(model_output_tensor, f"./checkpoint/model_{arg.model}_output_checkpoint_{checkpoint_id}.pt")
