"""
Description: Main code for running the experiments for F-EDL
"""
import time
import argparse
import pickle
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm  
import warnings
warnings.filterwarnings('ignore')

from sklearn.model_selection import train_test_split

import utility, train, density_estimation, ood_detection, conf_calibration 
from utility import *
from train import *
from density_estimation import *
from ood_detection import * 
from conf_calibration import *

import json
import os

def parse_args(argv=None):
    """Parse command-line options for an F-EDL experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing an
            explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    NOTE(review): ``main`` in this file does not read ``val_seed``,
    ``loss_option``, ``reg_param_p``, ``reg_param_alpha``, ``epsilon`` or
    ``scheduler_type`` — confirm whether they are consumed elsewhere.
    """
    parser = argparse.ArgumentParser()

    # Data / training setup.
    parser.add_argument("--ID_dataset", default="CIFAR-10", choices=["MNIST", "CIFAR-10", "CIFAR-100"], help="Pick a dataset")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--val_size", type=float, default=0.05, help="Validation size")
    parser.add_argument("--val_seed", type=int, default=0, help="Validation seed")
    parser.add_argument("--imbalance_factor", type=float, default=0, help="Imbalance factor")
    parser.add_argument("--noise", action="store_true", help="Use noisy dataset (e.g., Dirty MNIST)")
    parser.add_argument("--dropout_rate", type=float, default=0, help="Dropout rate")
    parser.add_argument("--num_epochs", type=int, default=100, help="Num Epochs")

    parser.add_argument("--loss_option", type=str, default="1", help="Classification loss")

    # Regularization.
    parser.add_argument("--reg_param_p", type=float, default=1, help="Regularization parameter p")
    parser.add_argument("--reg_param_alpha", type=float, default=0, help="Regularization parameter alpha")

    # Optimizer / scheduler / runtime.
    parser.add_argument("--learning_rate", type=float, default=5e-4, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0, help="Weight decay")
    parser.add_argument("--step_size", type=int, default=30, help="Step size")
    parser.add_argument("--epsilon", type=float, default=1e-6, help="Small value for numerical stability")
    parser.add_argument("--scheduler_type", type=str, default="step", choices=["none", "step", "lambda", "exponential", "cosine"], help="Scheduler type")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device to use for computation")
    parser.add_argument("--result_dir", type=str, default="saved_results_fedl", help="Directory to save results")
    parser.add_argument("--model_dir", type=str, default="saved_models_fedl", help="Directory to save models")

    # Model architecture.
    parser.add_argument("--hidden_dim", type=int, default=256, help="Hidden dimension for MLP")
    parser.add_argument("--num_layers", type=int, default=2, help="Number of layers in MLP")
    parser.add_argument("--spect_norm", action="store_true", help="Use spectral normalization")

    # Ablations.
    parser.add_argument("--fix_tau", action="store_true", help="Fix tau=1")
    # ``None`` is only the default (no ablation); a command-line string can
    # never equal None, so only the real string values are valid choices.
    parser.add_argument("--fix_p", type=str, default=None, choices=["uniform", "dirichlet"], help="Fix p type")

    return parser.parse_args(argv)


def _num_classes(dataset):
    """Return the number of in-distribution classes for ``dataset``.

    MNIST and CIFAR-10 map to 10; any other name (in practice CIFAR-100, the
    only remaining ``--ID_dataset`` choice) maps to 100.
    """
    return 10 if dataset in ("MNIST", "CIFAR-10") else 100


def _experiment_stem(args, model_type="FEDL"):
    """Build the filename prefix shared by this run's result and model files.

    The format, including the double underscore that follows ``model_type``,
    is kept identical to earlier runs so new files sort/glob alongside
    previously written ones.
    """
    exp_tag = f"{model_type}_"
    if args.fix_tau:
        exp_tag += "_FIXTAU"
    if args.fix_p:
        exp_tag += f"_FIXP-{args.fix_p.upper()}"
    return (
        f"{args.ID_dataset}_{exp_tag}_IR{args.imbalance_factor}"
        f"_LR{args.learning_rate}_HD{args.hidden_dim}_NL{args.num_layers}"
    )


def main(args):
    """Train an F-EDL model, evaluate it, and persist metrics and weights.

    Loads the data loaders, builds and trains ``FEDL``, then reports test
    accuracy, confidence-calibration metrics (AUROC / AUPR / Brier) and OOD
    detection metrics (AUROC / AUPR). Metrics are written as indented JSON to
    ``args.result_dir`` and the model ``state_dict`` to ``args.model_dir``;
    both filenames carry the same timestamp so they can be paired.

    Args:
        args: Namespace as produced by ``parse_args``.
    """
    print(args)

    model_type = "FEDL"
    num_classes = _num_classes(args.ID_dataset)

    trainloader, validloader, testloader, ood_loader1, ood_loader2 = load_datasets(
        args.ID_dataset, args.batch_size, args.val_size, args.imbalance_factor, args.noise
    )

    model = FEDL(args.ID_dataset, args.dropout_rate, args.spect_norm, args.device, args.hidden_dim, args.num_layers)

    train_fedl(
        model, args.learning_rate, args.weight_decay, args.step_size, args.num_epochs,
        trainloader, validloader, num_classes, args.fix_tau, args.fix_p, args.device,
    )

    test_acc = eval_fedl(model, testloader, args.fix_tau, args.fix_p, args.device)
    print(f"Test Accuracy: {test_acc}")
    conf_auroc, conf_aupr, brier = conf_calibration_fedl(model, testloader, args.fix_tau, args.fix_p, args.device)
    print(f"CONF AUROC: {conf_auroc}, CONF AUPR: {conf_aupr}")

    ood_auroc, ood_aupr = ood_detection_fedl(model, testloader, ood_loader1, ood_loader2, args.fix_tau, args.fix_p, args.device)
    print(f"OOD AUROC: {ood_auroc}, OOD AUPR: {ood_aupr}")

    # Distribution-shift detection is currently disabled for every dataset, so
    # both metrics are reported as 0 placeholders. TODO: re-enable
    # dist_shift_detection_fedl(args.ID_dataset, model, testloader, args.device)
    # for MNIST / CIFAR-10 once it is ready.
    dist_auroc, dist_aupr = 0, 0
    print(f"DIST AUROC: {dist_auroc}, DIST AUPR: {dist_aupr}")

    result = {
        "Test Accuracy": test_acc,
        "CONF AUROC": conf_auroc,
        "CONF AUPR": conf_aupr,
        "BRIER": brier,
        "OOD AUROC": ood_auroc,
        "OOD AUPR": ood_aupr,
        "DIST AUROC": dist_auroc,
        "DIST AUPR": dist_aupr,
    }
    # Metric values may be numpy/torch scalars; make them JSON-serializable.
    result = {key: convert_to_native(value) for key, value in result.items()}

    # Read the clock once: two separate time.time() calls can straddle a
    # second boundary and give the result file and checkpoint different suffixes.
    timestamp = int(time.time())
    stem = _experiment_stem(args, model_type)
    result_filename = f"{stem}_results_{timestamp}.json"
    model_filename = f"{stem}_model_{timestamp}.pth"

    os.makedirs(args.result_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)

    result_filepath = os.path.join(args.result_dir, result_filename)
    with open(result_filepath, 'w') as result_file:
        json.dump(result, result_file, indent=4)

    model_filepath = os.path.join(args.model_dir, model_filename)
    torch.save(model.state_dict(), model_filepath)
    
if __name__ == "__main__":
    # Script entry point: parse the CLI options, then run one full
    # train / evaluate / save cycle via main().
    args = parse_args()
    main(args)
