import sys
sys.path.append("..")

from droptrak.output_function import BaseModelOutputClass
from model_train import ResNet9, ConvNet, SubsetSamper, get_cifar2_indices_and_adjust_labels
from droptrak.droptrak import DropTRAK
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import argparse


class CIFAR2ModelOutput(BaseModelOutputClass):
    """Model-output functions for the CIFAR-2 classifier.

    Both static methods run ``model`` on an ``(image, label)`` pair and are
    expressed in terms of ``p``, the softmax probability the model assigns to
    the labelled class. Labels are assumed to be integer class indices
    (TODO confirm against the data pipeline).
    """

    # Device the inputs are moved to; resolved once when the class is defined.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(self):
        # Zero-argument super() already binds ``self``; forwarding it again
        # would hand the parent initializer an extra positional argument.
        super().__init__()

    @staticmethod
    def model_output(data, model, *args, **kwargs):
        """Return ``log(p / (1 - p))`` for each sample.

        Args:
            data: ``(image, label)`` pair of tensors.
            model: Callable mapping ``image`` to raw class logits, with the
                class dimension last.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Tensor with the same shape as ``label``. Gradients flow back to
            the model parameters.
        """
        image, label = data
        image, label = image.to(CIFAR2ModelOutput.device), label.to(CIFAR2ModelOutput.device)
        raw_logit = model(image)

        # log(p / (1 - p)) equals z_y - logsumexp(z_j for j != y). Evaluating
        # it in this form avoids log(1 - exp(log p)), which overflows to +inf
        # (and yields non-finite gradients) once p rounds to 1.
        target = label.unsqueeze(-1)
        correct_logit = raw_logit.gather(-1, target).squeeze(-1)
        other_logits = raw_logit.scatter(-1, target, float("-inf"))

        return correct_logit - torch.logsumexp(other_logits, dim=-1)

    @staticmethod
    def get_out_to_loss_grad(data, model, *args, **kwargs):
        """Return ``1 - p`` for each sample, detached from autograd.

        Args:
            data: ``(image, label)`` pair of tensors.
            model: Callable mapping ``image`` to raw class logits.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Tensor shaped like ``label`` with one trailing dimension of size
            1 appended. It carries no gradient history.
        """
        image, label = data
        image, label = image.to(CIFAR2ModelOutput.device), label.to(CIFAR2ModelOutput.device)

        # The result is a constant for the caller, so skip building a graph.
        with torch.no_grad():
            raw_logit = model(image)
            loss_fn = nn.CrossEntropyLoss(reduction='none')
            log_p = -loss_fn(raw_logit, label)
            # 1 - exp(log p), computed via expm1 to keep precision near p = 1.
            one_minus_p = -torch.expm1(log_p)

        return one_minus_p.unsqueeze(-1)

if __name__ == "__main__":
    # Script entry point: compute DropTRAK attribution scores for a CIFAR-2
    # subset, report wall-clock time (and peak GPU memory when on CUDA), and
    # save the score tensor under ./score/.
    import os
    import time

    parser = argparse.ArgumentParser()
    parser.add_argument("--ensemble", type=int, default=1, help="ensemble number")
    parser.add_argument("--independent", type=int, default=1, help="independent number")
    parser.add_argument("--dropout", action="store_true", help="dropout or not")
    parser.add_argument("--dropout_only_Q", action="store_true", help="dropout only Q or not")
    parser.add_argument("--model", type=str, default="resnet9", help="model to train")
    args = parser.parse_args()

    # Use CUDA when available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    print("Using device:", device)

    # Load CIFAR-10 and normalize every channel with mean 0.5 and std 0.5.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    # Keep the first 5000 train / 500 test indices returned by the helper
    # (presumably the two CIFAR-2 classes with remapped labels -- see
    # model_train for the exact behavior).
    cifar2_indices_train = get_cifar2_indices_and_adjust_labels(train_dataset)
    train_index = cifar2_indices_train[:5000]
    cifar2_indices_test = get_cifar2_indices_and_adjust_labels(test_dataset)
    test_index = cifar2_indices_test[:500]

    sampler_train = SubsetSamper(train_index)
    sampler_test = SubsetSamper(test_index)

    train_loader = DataLoader(train_dataset, batch_size=1, sampler=sampler_train)
    test_loader = DataLoader(test_dataset, batch_size=1, sampler=sampler_test)

    # Build the checkpoint list: each independent run's checkpoint path is
    # repeated `ensemble` times. With --dropout_only_Q the repeats of one run
    # stay grouped in their own sub-list; otherwise the list is flat.
    checkpoint_files = []
    for independent in range(args.independent):
        repeated = [f"./checkpoint/checkpoint_{args.model}_{independent}.pt"] * args.ensemble
        if args.dropout_only_Q:
            checkpoint_files.append(repeated)
        else:
            checkpoint_files.extend(repeated)

    # Any --model value other than "resnet9" selects ConvNet.
    if args.model == "resnet9":
        model = ResNet9(dropout_rate=0.1).to(device)
    else:
        model = ConvNet(dropout_rate=0.1).to(device)

    trak = DropTRAK(model=model,
                    model_checkpoints=checkpoint_files,
                    train_loader=train_loader,
                    test_loader=test_loader,
                    model_output_class=CIFAR2ModelOutput,
                    device=device,
                    dropout=args.dropout,
                    dropout_only_Q=args.dropout_only_Q)

    # CUDA memory statistics only exist on CUDA; guarding them keeps the CPU
    # fallback selected above usable.
    if use_cuda:
        torch.cuda.reset_peak_memory_stats("cuda")

    st = time.time()
    score = trak.score()
    if use_cuda:
        # CUDA kernels run asynchronously; wait for them so the timing is real.
        torch.cuda.synchronize()
    elapsed = time.time() - st
    print(f"independent: {args.independent}, ensemble: {args.ensemble}, Time used:", elapsed)

    if use_cuda:
        peak_memory = torch.cuda.max_memory_allocated("cuda") / 1e6  # Convert to MB
        print(f"Peak memory usage: {peak_memory} MB")

    print(torch.argmax(score, dim=1))

    # Make sure the output directory exists before saving.
    os.makedirs("./score", exist_ok=True)
    torch.save(score, f"./score/score_model_{args.model}_cifar{'_dropout' if args.dropout else ''}_ensemble_{args.ensemble}_independent_{args.independent}_Q_{args.dropout_only_Q}.pt")
