from model_train import SimpleDNN, SubsetSamper
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from main_droptrak import MNISTModelOutput


if __name__ == "__main__":
    # Pick the GPU when one is available; fall back to CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Number of checkpoints to evaluate (checkpoint_0 .. checkpoint_{n-1}).
    num_checkpoints = 50
    # Number of leading test-set images to score, one image per batch.
    num_test_samples = 500
    # How many times each checkpoint is re-loaded and scored; the saved tensor
    # is the mean over these repeats. A single constant drives both the loop
    # count and the divisor so they cannot drift apart.
    num_repeats = 1

    # MNIST test split; ToTensor gives [0, 1] and Normalize(0.5, 0.5) maps
    # that to [-1, 1].
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))])
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Restrict scoring to the first `num_test_samples` test images.
    sampler_test = SubsetSamper(list(range(num_test_samples)))
    test_loader = DataLoader(test_dataset, batch_size=1, sampler=sampler_test)

    checkpoint_files = [
        f"./checkpoint/checkpoint_{i}_sample_2500.pt" for i in range(num_checkpoints)
    ]

    for checkpoint_id, checkpoint_file in enumerate(tqdm(checkpoint_files)):
        model_output_tensor_all = 0
        for _ in range(num_repeats):
            model = SimpleDNN().to(device)
            # map_location lets checkpoints saved on a GPU load on a CPU-only
            # host (and vice versa) instead of failing during deserialization.
            model.load_state_dict(torch.load(checkpoint_file, map_location=device))
            model.eval()
            # NOTE(review): batches are passed exactly as the loader yields
            # them (CPU tensors) and no torch.no_grad() is used; presumably
            # MNISTModelOutput.model_output handles device placement and may
            # need autograd — confirm before changing either.
            model_output_list = [
                MNISTModelOutput.model_output(data, model)
                for data in tqdm(test_loader, leave=False)
            ]
            # Stack per-sample outputs along a new leading sample axis.
            model_output_tensor_all += torch.stack(model_output_list, dim=0)
        # Average over repeats.
        model_output_tensor_all /= num_repeats
        torch.save(model_output_tensor_all, f"./checkpoint/model_output_checkpoint_{checkpoint_id}_sample_2500_M.pt")
