"""
Code for Softmax, Dropout (ICML 2016), and DDU (CVPR 2023)
"""

import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import time
import argparse
import json

import utility, train, ood_detection, conf_calibration
from utility import *
from train import *
from ood_detection import *
from conf_calibration import *

# Silence every Python warning for the whole process (all categories, all
# modules). NOTE(review): this also hides deprecation/user warnings raised by
# torch and the project modules — narrow it if those become relevant.
warnings.simplefilter("ignore")


def parse_args(argv=None):
    """Build the experiment's command-line parser and parse the arguments.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is what the script
            entry point relies on. Passing an explicit list lets the parser be
            driven programmatically (tests, sweep scripts, notebooks).

    Returns:
        argparse.Namespace holding the experiment configuration.

    NOTE(review): ``--val_seed``, ``--weight_decay``, ``--epsilon`` and
    ``--scheduler_type`` are parsed here but never read by ``main`` in this
    file — confirm whether they should be forwarded to the data/training code.
    """
    parser = argparse.ArgumentParser()

    # In-distribution data and loader configuration.
    parser.add_argument("--ID_dataset", default="CIFAR-100", choices=["MNIST", "CIFAR-10", "CIFAR-100"], help="Pick a dataset")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--val_size", type=float, default=0.05, help="Validation size")
    parser.add_argument("--val_seed", type=int, default=20241120, help="Validation seed")
    parser.add_argument("--imbalance_factor", type=float, default=0, help="Imbalance factor")

    # Label noise and dropout regularisation.
    parser.add_argument("--noise", action="store_true", help="Use noisy dataset")
    parser.add_argument("--dropout_rate", type=float, default=0.3, help="Dropout rate")

    # Optimisation schedule.
    parser.add_argument("--learning_rate", type=float, default=5e-4, help="Learning rate")
    parser.add_argument("--step_size", type=int, default=30, help="Step size for scheduler")
    parser.add_argument("--num_epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--weight_decay", type=float, default=0, help="Weight decay")
    parser.add_argument("--epsilon", type=float, default=1e-6, help="Numerical stability")
    parser.add_argument("--scheduler_type", type=str, default="step", help="Scheduler type")

    # Method selection, runtime and output locations.
    parser.add_argument("--sm_type", type=str, default="Dropout", choices=["softmax", "Dropout", "DDU"], help="Softmax type")
    parser.add_argument("--index", type=int, default=0, help="Model index")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Computation device")
    parser.add_argument("--result_dir", type=str, default="saved_results_softmax", help="Results directory")
    parser.add_argument("--model_dir", type=str, default="saved_models_softmax", help="Models directory")
    parser.add_argument("--pretrained", action="store_true", help="Use pretrained model")
    parser.add_argument("--spect_norm", action="store_true", help="Use spectral normalization")
    parser.add_argument("--num_passes", type=int, default=10, help="Number of forward passes in MC Dropout")

    # argparse treats argv=None as "use sys.argv[1:]", preserving CLI behaviour.
    return parser.parse_args(argv)


def main(args):
    """Train one model, evaluate it, and save its metrics and weights.

    Steps:
      1. Build the train/valid/test loaders and the two OOD loaders with
         ``load_datasets``, and the network with ``load_model``.
      2. Train with ``train_softmax`` and report test accuracy from
         ``eval_softmax``.
      3. Run confidence-calibration evaluation (``conf_calibration_softmax``)
         and OOD-detection evaluation:
           * ``sm_type == "DDU"``: ``fit_gda`` is called on the training loader
             and its outputs are passed to ``ood_detection_ddu``.
           * otherwise: ``ood_detection_softmax`` is used, and a
             distribution-shift evaluation (``dist_shift_detection_softmax``)
             is also run and printed.
      4. Write the metrics as JSON into ``args.result_dir`` and the model
         ``state_dict`` into ``args.model_dir``. Both file names share one run
         tag and timestamp, so a results file can be matched to its weights.

    Args:
        args: ``argparse.Namespace`` as produced by ``parse_args``.
    """
    # Feature size and class count are only consumed by the DDU branch
    # (fit_gda / ood_detection_ddu). Any dataset other than MNIST / CIFAR-10
    # falls back to the CIFAR-100 setting.
    if args.ID_dataset == "MNIST":
        embedding_dim, num_classes = 576, 10
    elif args.ID_dataset == "CIFAR-10":
        embedding_dim, num_classes = 512, 10
    else:
        embedding_dim, num_classes = 512, 100

    trainloader, validloader, testloader, ood_loader1, ood_loader2 = load_datasets(
        args.ID_dataset, args.batch_size, args.val_size, args.imbalance_factor, args.noise
    )

    model = load_model(args.ID_dataset, args.pretrained, args.index, args.dropout_rate, args.spect_norm, args.device)
    print(model)

    train_softmax(model, args.learning_rate, args.step_size, args.num_epochs, trainloader, validloader, args.sm_type, args.device)

    test_acc = eval_softmax(model, testloader, args.sm_type, args.num_passes, args.device)
    print(f"Test Accuracy: {test_acc}")

    # NOTE(review): the DDU branch previously unpacked conf_calibration_softmax
    # as (aupr, auroc, brier) while the other branch used (auroc, aupr, brier)
    # for the very same function. Both now use (auroc, aupr, brier), matching
    # the other metric helpers in this script — confirm against
    # conf_calibration_softmax's actual return order.
    if args.sm_type == "DDU":
        gda, p_z_train = fit_gda(model, trainloader, num_classes, embedding_dim, args.device)
        conf_auroc, conf_aupr, brier = conf_calibration_softmax(model, trainloader, testloader, args.sm_type, args.num_passes, args.device)
        print(f"CONF AUROC: {conf_auroc}, CONF AUPR: {conf_aupr}, Brier Score: {brier}")

        ood_auroc, ood_aupr = ood_detection_ddu(model, gda, p_z_train, testloader, ood_loader1, ood_loader2, num_classes, args.device)
        print(f"OOD AUROC: {ood_auroc}, OOD AUPR: {ood_aupr}")
    else:
        conf_auroc, conf_aupr, brier = conf_calibration_softmax(model, trainloader, testloader, args.sm_type, args.num_passes, args.device)
        print(f"CONF AUROC: {conf_auroc}, CONF AUPR: {conf_aupr}, Brier Score: {brier}")

        ood_auroc, ood_aupr = ood_detection_softmax(model, args.sm_type, trainloader, testloader, ood_loader1, ood_loader2, args.num_passes, args.device)
        print(f"OOD AUROC: {ood_auroc}, OOD AUPR: {ood_aupr}")

        # Previously computed and then silently discarded; surface it.
        dist_auroc, dist_aupr = dist_shift_detection_softmax(model, args.sm_type, testloader, args.num_passes, args.device)
        print(f"DIST AUROC: {dist_auroc}, DIST AUPR: {dist_aupr}")

    result = {
        "Test Accuracy": test_acc,
        "CONF AUROC": conf_auroc,
        "CONF AUPR": conf_aupr,
        "Brier Score": brier,
        "OOD AUROC": ood_auroc,
        "OOD AUPR": ood_aupr,
    }
    # convert_to_native is applied before json.dump — presumably to turn
    # tensor/numpy scalars into plain Python numbers (TODO confirm).
    result = {key: convert_to_native(value) for key, value in result.items()}

    # Take the timestamp once: two separate time.time() calls could straddle a
    # second boundary and give the results and weights files different suffixes.
    run_tag = f"{args.sm_type}_{args.ID_dataset}_{args.imbalance_factor}_lr_{args.learning_rate}_p_drop_{args.dropout_rate}"
    stamp = int(time.time())

    os.makedirs(args.result_dir, exist_ok=True)
    result_filepath = os.path.join(args.result_dir, f"{run_tag}_results_{stamp}.json")
    with open(result_filepath, "w") as f:
        json.dump(result, f, indent=4)

    os.makedirs(args.model_dir, exist_ok=True)
    model_filepath = os.path.join(args.model_dir, f"{run_tag}_weights_{stamp}.pth")
    torch.save(model.state_dict(), model_filepath)


if __name__ == "__main__":
    # The command line is identical for every repetition, so parse it once;
    # main() only reads from the namespace and never mutates it.
    cli_args = parse_args()
    # Repeat the full train/evaluate/save pipeline five times with the same
    # configuration; each call to main() saves its own results and weights.
    for _run in range(5):
        main(cli_args)
