import sys
sys.path.append("..")

from droptrak.output_function import BaseModelOutputClass
from model_train import SimpleDNN, SubsetSamper
from droptrak.droptrak import DropTRAK
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import argparse
import time

class MNISTModelOutput(BaseModelOutputClass):
    """Model-output functions for a classifier, handed to DropTRAK as a class.

    Both functions are static: they take a ``(image, label)`` batch and a
    model, move the batch to :attr:`device`, and run one forward pass.
    """

    # Resolved once at class-definition time; both static methods move their
    # inputs to this device before calling the model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(self):
        # ``super()`` already binds ``self``; passing it again would hand the
        # base initializer an unexpected extra positional argument.
        # NOTE(review): assumes BaseModelOutputClass.__init__ takes no
        # arguments besides self -- confirm against its definition.
        super().__init__()

    @staticmethod
    def model_output(data, model, *args, **kwargs):
        """Return the log-odds ``log(p / (1 - p))`` of the correct class.

        Args:
            data: An ``(image, label)`` pair; ``label`` holds class indices.
            model: Callable mapping ``image`` to raw (unnormalised) logits,
                with the class dimension at index 1 (or 0 if unbatched),
                matching the ``nn.CrossEntropyLoss`` layout.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            One log-odds value per sample, differentiable w.r.t. the model
            parameters.
        """
        image, label = data
        image, label = image.to(MNISTModelOutput.device), label.to(MNISTModelOutput.device)
        raw_logit = model(image)

        # With p = softmax(z)[y]:
        #   log p - log(1 - p) = z_y - logsumexp_{j != y} z_j.
        # Evaluating it this way never forms ``1 - p``, which rounds to 0
        # (giving +inf output and NaN gradients) once the model is confident
        # enough that p is 1 in floating point.
        class_dim = 1 if raw_logit.dim() > 1 else 0
        target_index = label.unsqueeze(class_dim)
        correct_logit = raw_logit.gather(class_dim, target_index).squeeze(class_dim)
        # Out-of-place scatter keeps the autograd graph intact for the
        # remaining (incorrect-class) logits.
        other_logits = raw_logit.scatter(class_dim, target_index, float("-inf"))

        return correct_logit - other_logits.logsumexp(dim=class_dim)

    @staticmethod
    def get_out_to_loss_grad(data, model, *args, **kwargs):
        """Return ``1 - p`` for the correct class, detached from the graph.

        Args:
            data: An ``(image, label)`` pair; ``label`` holds class indices.
            model: Callable mapping ``image`` to raw (unnormalised) logits.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A detached tensor of ``1 - p`` values with a trailing dimension
            of size 1 appended.
        """
        image, label = data
        image, label = image.to(MNISTModelOutput.device), label.to(MNISTModelOutput.device)
        raw_logit = model(image)

        # Per-sample cross entropy is -log p, so exp(-loss) recovers p.
        loss_fn = nn.CrossEntropyLoss(reduction='none')
        p = torch.exp(-loss_fn(raw_logit, label))

        return (1 - p).clone().detach().unsqueeze(-1)

def build_checkpoint_files(num_independent, num_ensemble, nested):
    """Build the checkpoint path collection handed to DropTRAK.

    Every independently trained model ``i`` contributes the path
    ``./checkpoint/checkpoint_{i}_sample_2500.pt`` repeated ``num_ensemble``
    times. The ``2500`` suffix is the same with and without ``--dropout``.

    Args:
        num_independent: Number of independently trained checkpoints.
        num_ensemble: How many times each checkpoint path is repeated.
        nested: If true, return one inner list per checkpoint; otherwise
            return a single flat list.

    Returns:
        A list of lists of paths when ``nested`` is true, else a flat list of
        ``num_independent * num_ensemble`` paths.
    """
    grouped = [
        [f"./checkpoint/checkpoint_{index}_sample_2500.pt"] * num_ensemble
        for index in range(num_independent)
    ]
    if nested:
        return grouped
    return [path for group in grouped for path in group]


if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--ensemble", type=int, default=1, help="ensemble number")
    parser.add_argument("--independent", type=int, default=1, help="independent number")
    parser.add_argument("--dropout", action="store_true", help="dropout or not")
    parser.add_argument("--dropout_only_Q", action="store_true", help="dropout only Q or not")
    args = parser.parse_args()

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Load MNIST data
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Restrict scoring to the first 5000 training and first 500 test samples.
    sampler_train = SubsetSamper(list(range(5000)))
    sampler_test = SubsetSamper(list(range(500)))

    train_loader = DataLoader(train_dataset, batch_size=1, sampler=sampler_train)
    test_loader = DataLoader(test_dataset, batch_size=1, sampler=sampler_test)

    # Collect checkpoint paths: one inner list per checkpoint when only Q
    # uses dropout, a single flat list otherwise.
    checkpoint_files = build_checkpoint_files(args.independent,
                                              args.ensemble,
                                              nested=args.dropout_only_Q)

    trak = DropTRAK(model=SimpleDNN(dropout_rate=0.1).to(device),
                    model_checkpoints=checkpoint_files,
                    train_loader=train_loader,
                    test_loader=test_loader,
                    model_output_class=MNISTModelOutput,
                    device=device,
                    dropout=args.dropout,
                    dropout_only_Q=args.dropout_only_Q)

    # The CUDA memory statistics calls fail without a GPU, so only use them
    # when the run is actually on CUDA; the CPU fallback skips the report.
    track_cuda_memory = device.type == "cuda"
    if track_cuda_memory:
        torch.cuda.reset_peak_memory_stats("cuda")
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments.
    st = time.perf_counter()
    score = trak.score()
    print(f"independent: {args.independent}, ensemble: {args.ensemble}, Time used:", time.perf_counter() - st)
    if track_cuda_memory:
        peak_memory = torch.cuda.max_memory_allocated("cuda") / 1e6  # Convert to MB
        print(f"Peak memory usage: {peak_memory} MB")

    print(torch.argmax(score, dim=1))

    # torch.save does not create missing parent directories.
    os.makedirs("./score", exist_ok=True)
    torch.save(score, f"./score/score_mnist{'_dropout' if (args.dropout or args.dropout_only_Q) else ''}_ensemble_{args.ensemble}_independent_{args.independent}_Q_{args.dropout_only_Q}.pt")
