"""
Description: Entry-point script that trains and evaluates the baseline EDL methods (EDL (NeurIPS 2018), I-EDL (ICML 2023), R-EDL (ICLR 2024) and DAEDL (ICML 2024)) and saves the resulting metrics and model weights.
"""
import time
import argparse
import json
import os
import torch
import warnings
from tqdm import tqdm

import utility, train, ood_detection, conf_calibration
from utility import *
from train import *
from ood_detection import *
from conf_calibration import *

warnings.filterwarnings('ignore')

def parse_args():
    """Define the command-line options for a single run and parse ``sys.argv``.

    Options cover the in-distribution dataset and loading, the EDL variant and
    its regularisation weights, the optimiser/scheduler, and run bookkeeping
    (device, output directories, pretrained / spectral-norm switches).

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    fallback_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # (flag, add_argument keyword options), registered in this exact order so
    # the --help listing is unchanged.
    option_table = (
        # Data
        ("--ID_dataset", dict(default="CIFAR-10", choices=["MNIST", "CIFAR-10", "CIFAR-100"], help="Pick a dataset")),
        ("--batch_size", dict(type=int, default=64, help="Batch size")),
        ("--val_size", dict(type=float, default=0.05, help="Validation size")),
        ("--val_seed", dict(type=int, default=20250507, help="Validation seed")),
        ("--imbalance_factor", dict(type=float, default=0, help="Imbalance factor")),
        ("--noise", dict(action="store_true", help="Use noisy dataset")),
        ("--dropout_rate", dict(type=float, default=0, help="Dropout rate")),
        # EDL variant and regularisation
        ("--edl_type", dict(type=str, default="EDL", help="EDL type")),
        ("--reg_param_kl", dict(type=float, default=1e-3, help="KL divergence regularization")),
        ("--reg_param_fisher", dict(type=float, default=1e-3, help="Fisher information regularization")),
        ("--lamb", dict(type=float, default=0.1, help="Lambda for R-EDL")),
        # Optimisation
        ("--learning_rate", dict(type=float, default=1e-4, help="Learning rate")),
        ("--step_size", dict(type=int, default=30, help="Step size for scheduler")),
        ("--weight_decay", dict(type=float, default=0, help="Weight decay")),
        ("--epsilon", dict(type=float, default=1e-6, help="Numerical stability")),
        ("--scheduler_type", dict(type=str, default="step", help="Scheduler type")),
        # Run bookkeeping
        ("--index", dict(type=int, default=0, help="Model index")),
        ("--device", dict(type=str, default=fallback_device, help="Computation device")),
        ("--result_dir", dict(type=str, default="saved_results_edl", help="Results directory")),
        ("--model_dir", dict(type=str, default="saved_models_edl", help="Models directory")),
        ("--pretrained", dict(action="store_true", help="Use pretrained model")),
        ("--spect_norm", dict(action="store_true", help="Use spectral normalization")),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def _dataset_config(dataset_name):
    """Return ``(embedding_dim, num_classes, num_epochs)`` for an ID dataset name.

    Any name other than "MNIST" or "CIFAR-10" gets the CIFAR-100 settings,
    matching the three ``--ID_dataset`` choices.
    """
    if dataset_name == "MNIST":
        return 576, 10, 50
    if dataset_name == "CIFAR-10":
        return 512, 10, 100
    # NOTE(review): CIFAR-100 trains for a single epoch versus 50/100 for the
    # other datasets; this looks like a debugging leftover -- confirm before
    # reporting CIFAR-100 numbers.
    return 512, 100, 1


def main(args):
    """Train one EDL-family model, evaluate it, and save its metrics and weights.

    Loads the data and model, trains with ``train_edl``, prints test accuracy,
    then computes confidence-calibration and OOD-detection metrics. The
    distribution-shift metrics are computed only for non-DAEDL runs on MNIST or
    CIFAR-10; otherwise they are stored as 0. Metrics are written to a JSON
    file in ``args.result_dir``, and the model ``state_dict`` is written to
    ``args.model_dir``. Both files share the same filename stem and timestamp.

    Args:
        args: Namespace produced by ``parse_args``.
    """
    embedding_dim, num_classes, num_epochs = _dataset_config(args.ID_dataset)

    trainloader, validloader, testloader, ood_loader1, ood_loader2 = load_datasets(
        args.ID_dataset, args.batch_size, args.val_size, args.imbalance_factor, args.noise)

    model = load_model(args.ID_dataset, args.pretrained, args.index, args.imbalance_factor,
                       args.spect_norm, args.device)

    train_edl(model, args.edl_type, args.learning_rate, args.step_size, args.reg_param_kl,
              args.reg_param_fisher, args.lamb, num_epochs, trainloader, validloader,
              num_classes, args.device)

    test_acc = eval_edl(model, args.edl_type, testloader, args.device)
    print(f"Test Accuracy: {test_acc}")

    if args.edl_type == "DAEDL":
        gda, p_z_train = fit_gda(model, trainloader, num_classes, embedding_dim, args.device)
        # NOTE(review): this unpacks (aupr, auroc, brier) while the
        # conf_calibration_edl branch unpacks (auroc, aupr, brier). The order is
        # kept as-is; confirm it matches conf_calibration_daedl's return order.
        conf_aupr, conf_auroc, brier = conf_calibration_daedl(
            model, gda, p_z_train, testloader, num_classes, args.device)
        print(f"CONF AUROC: {conf_auroc}, CONF AUPR: {conf_aupr}")

        ood_auroc, ood_aupr = ood_detection_daedl(
            model, gda, p_z_train, testloader, ood_loader1, ood_loader2, num_classes, args.device)
        print(f"OOD AUROC: {ood_auroc}, OOD AUPR: {ood_aupr}")

        # Distribution-shift detection is not run for DAEDL; 0 is a placeholder.
        dist_auroc, dist_aupr = 0, 0
    else:
        conf_auroc, conf_aupr, brier = conf_calibration_edl(
            model, args.edl_type, trainloader, testloader, args.lamb, args.device)
        print(f"CONF AUROC: {conf_auroc}, CONF AUPR: {conf_aupr}")

        ood_auroc, ood_aupr = ood_detection_edl(
            model, args.edl_type, trainloader, testloader, ood_loader1, ood_loader2,
            args.lamb, args.device)
        print(f"OOD AUROC: {ood_auroc}, OOD AUPR: {ood_aupr}")

        if args.ID_dataset in ("MNIST", "CIFAR-10"):
            dist_auroc, dist_aupr = dist_shift_detection_edl(
                args.ID_dataset, model, args.edl_type, trainloader, testloader,
                args.lamb, args.device)
        else:
            # Distribution-shift detection is only run for MNIST / CIFAR-10.
            dist_auroc, dist_aupr = 0, 0

    result = {
        "Test Accuracy": test_acc,
        "CONF AUROC": conf_auroc,
        "CONF AUPR": conf_aupr,
        "OOD AUROC": ood_auroc,
        "OOD AUPR": ood_aupr,
        "DIST AUROC": dist_auroc,
        "DIST AUPR": dist_aupr,
        # Previously computed by both calibration branches but never saved.
        "BRIER": brier,
    }
    result = {key: convert_to_native(value) for key, value in result.items()}

    run_stem = (f"{args.ID_dataset}_ir_{args.imbalance_factor}_{args.edl_type}"
                f"_lr_{args.learning_rate}")
    # Take a single timestamp so the results JSON and the weights file of one
    # run always share it; two separate time.time() calls can straddle a
    # second boundary and produce mismatched names.
    run_timestamp = int(time.time())

    os.makedirs(args.result_dir, exist_ok=True)
    result_filepath = os.path.join(args.result_dir, f"{run_stem}_results_{run_timestamp}.json")
    with open(result_filepath, 'w') as result_file:
        json.dump(result, result_file, indent=4)

    os.makedirs(args.model_dir, exist_ok=True)
    model_filepath = os.path.join(args.model_dir, f"{run_stem}_models_{run_timestamp}.pth")
    torch.save(model.state_dict(), model_filepath)

if __name__ == "__main__":
    # Script entry point: read the CLI options, then run the full
    # train -> evaluate -> save pipeline once.
    args = parse_args()
    main(args=args)
