import sys
sys.path.append("..")

from dropgraddot.output_function import BaseModelOutputClass
from model_train import SimpleDNN, SubsetSamper
from dropgraddot.dropgraddot import DropGradDot
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import argparse


class MNISTModelOutput(BaseModelOutputClass):
    """Model-output function used by DropGradDot for the MNIST classifier.

    ``model_output`` returns, per example, the log-odds of the correct class,
    ``log(p / (1 - p))``, where ``p`` is the softmax probability the model
    assigns to ``label``.
    """

    # Device that ``model_output`` moves its inputs to; chosen once at import.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(self):
        # A bound ``super()`` already supplies ``self``; do not pass it again.
        # NOTE(review): assumes BaseModelOutputClass.__init__ takes no extra
        # positional arguments — confirm against dropgraddot.output_function.
        super().__init__()

    @staticmethod
    def model_output(data, model, *args, **kwargs):
        """Return the correct-class log-odds for ``data``.

        Args:
            data: ``(image, label)`` pair. Both tensors are moved to
                ``MNISTModelOutput.device``; ``label`` holds integer class ids.
            model: callable mapping ``image`` to raw logits, with classes on
                the last dimension.
            *args, **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``z_y - logsumexp_{j != y}(z_j)``, which equals ``log(p / (1 - p))``.
            Its shape is the logits' shape without the class dimension.
        """
        image, label = data
        image = image.to(MNISTModelOutput.device)
        label = label.to(MNISTModelOutput.device)
        raw_logit = model(image)

        # log(p / (1 - p)) == z_y - logsumexp over the *other* classes.
        # Evaluating it this way stays finite even when the model is very
        # confident. In that case log_softmax rounds to 0 and the naive
        # log(1 - exp(logp)) becomes log(0) = -inf (and NaN gradients).
        label_idx = label.unsqueeze(-1)
        class_ids = torch.arange(raw_logit.shape[-1], device=raw_logit.device)
        is_target = class_ids == label_idx  # broadcasts to raw_logit's shape
        correct_logit = raw_logit.gather(-1, label_idx).squeeze(-1)
        other_logits = raw_logit.masked_fill(is_target, float("-inf"))
        return correct_logit - torch.logsumexp(other_logits, dim=-1)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ensemble", type=int, default=1, help="ensemble number")
    parser.add_argument("--independent", type=int, default=1, help="independent number")
    parser.add_argument("--dropout", action="store_true", help="dropout or not")
    parser.add_argument("--cos", action="store_true", help="cosine similarity or not")
    args = parser.parse_args()

    if args.ensemble * args.independent > 300:
        exit(0)

    # First, check if CUDA is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Load MNIST data
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    sampler_train = SubsetSamper([i for i in range(5000)])
    sampler_test = SubsetSamper([i for i in range(500)])

    train_loader = DataLoader(train_dataset, batch_size=1, sampler=sampler_train)
    test_loader = DataLoader(test_dataset, batch_size=1, sampler=sampler_test)

    # Initialize the model, loss function, and optimizer
    checkpoint_files = []
    for independent in range(args.independent):
        checkpoint_files += [
            f"./checkpoint/checkpoint_{independent}_sample_{'5000' if args.dropout else '5000'}.pt" for _ in range(args.ensemble)
        ]

    trak = DropGradDot(model=SimpleDNN(dropout_rate=0.1).to(device),
                       model_checkpoints=checkpoint_files,
                       train_loader=train_loader,
                       test_loader=test_loader,
                       model_output_class=MNISTModelOutput,
                       device=device,
                       dropout=args.dropout,
                       cos=args.cos)
    score = trak.score()

    print(torch.argmax(score, dim=1))

    torch.save(score, f"./score/score_grad_cos_{args.cos}_mnist{'_dropout' if args.dropout else ''}_ensemble_{args.ensemble}_independent_{args.independent}_noproject_full.pt")
