import os

import numpy as np
import torch
import wandb

from data.Dataloaders import *
from models.PrescribedGAN import PresGAN
from utils.util import parse_args_PresGAN
from config import data_raw_dir

# Command-line options (dataset, checkpoints, PresGAN hyper-parameters, OOD task).
args = parse_args_PresGAN()

# Prefer the current CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Per-dataset image geometry, near/far OOD benchmark lists and the names of the
# covariate-shift corruptions used by the 'covar' task.
if args.dataset == 'mnist':
    img_size = 32
    channels = 1
    near_ood = ['fashionmnist']
    far_ood = ['cifar10', 'tinyimagenet', 'dtd', 'places365']
    corruptions = os.listdir(os.path.join(data_raw_dir, 'mnist_c'))
    # 'identity' and 'translate' are deliberately excluded from the covariate-shift sweep.
    corruptions.remove('identity')
    corruptions.remove('translate')

elif args.dataset == 'cifar10':
    img_size = 32
    channels = 3
    near_ood = ['cifar100', 'tinyimagenet']
    far_ood = ['mnist', 'svhn', 'dtd', 'places365']
    # Entries presumably carry a file extension (e.g. '.npy' — confirm); keep only the
    # stem so it matches what cifar10c_dataloader expects, and drop the 'labels' entry,
    # which is not a corruption.
    corruptions = [c.split('.')[0] for c in os.listdir(os.path.join(data_raw_dir, 'CIFAR-10-C'))]
    corruptions.remove('labels')

elif args.dataset == 'tinyimagenet':
    img_size = 64
    channels = 3
    near_ood = ['ssb-hard', 'ninco']
    far_ood = ['inaturalist', 'dtd', 'openimageo']
    corruptions = os.listdir(os.path.join(data_raw_dir, 'Tiny-ImageNet-C'))

else:
    # Fail fast: otherwise img_size/channels/in_loader stay undefined and the script
    # crashes later with an unrelated NameError.
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; expected 'mnist', 'cifar10' or 'tinyimagenet'"
    )

# In-distribution test split, shared by every OOD comparison below.
in_loader = pick_dataset(
    name=args.dataset,
    train=False,
    batch_size=args.batch_size,
    img_size=img_size,
    id_dataset=args.dataset,
)

# Build the PresGAN: geometry and device come from this script, every other
# hyper-parameter is forwarded verbatim from the identically named CLI option.
model = PresGAN(
    imgSize=img_size,
    nc=channels,
    device=device,
    **{
        hparam: getattr(args, hparam)
        for hparam in (
            'nz', 'ngf', 'ndf', 'beta1', 'lrD', 'lrG', 'n_epochs', 'sigma_lr',
            'num_gen_images', 'restrict_sigma', 'sigma_min', 'sigma_max',
            'stepsize_num', 'lambda_', 'burn_in', 'num_samples_posterior',
            'leapfrog_steps', 'hmc_learning_rate', 'hmc_opt_accept', 'flag_adapt',
            'sample_and_save_freq', 'dataset',
        )
    },
)

# Restore generator, discriminator and sigma weights from the given checkpoint paths.
model.load_checkpoints(
    generator_checkpoint=args.checkpoint,
    discriminator_checkpoint=args.discriminator_checkpoint,
    sigma_checkpoint=args.sigma_checkpoint,
)

# Value returned by model.outlier_detection and passed back into every later call —
# presumably cached in-distribution scores so in_loader is scored only once
# (confirm against PresGAN.outlier_detection).
in_array = None


def _evaluate_ood(named_loaders, in_array, sep="\n\n"):
    """Run ``model.outlier_detection`` against each OOD loader and print its metrics.

    Args:
        named_loaders: Iterable of ``(label, out_loader)`` pairs. Passing a
            generator keeps loader construction lazy, one loader at a time.
        in_array: Value returned by a previous ``outlier_detection`` call, or ``None``.
        sep: Text printed directly after each FPR95 value (before ``print``'s own
            newline), so callers can control the spacing between entries.

    Returns:
        Tuple ``(aurocs, fpr95s, in_array)``: per-loader AUROC and FPR95 values in
        iteration order, plus the latest ``in_array`` for reuse by later calls.
    """
    aurocs, fpr95s = [], []
    for label, out_loader in named_loaders:
        auroc, fpr95, in_array, _ = model.outlier_detection(
            in_loader, out_loader, display=False, in_array=in_array
        )
        aurocs.append(auroc)
        fpr95s.append(fpr95)
        print(f"OOD: {label}\nAUROC: {auroc:.4f}\nFPR95: {fpr95:.4f}{sep}")
    return aurocs, fpr95s, in_array


# Evaluate OOD detection
if args.ood_task in ('near', 'far'):
    is_near = args.ood_task == 'near'
    print(f"{'Near' if is_near else 'Far'} OOD Detection for {args.dataset}\n")
    ood_names = near_ood if is_near else far_ood
    _, _, in_array = _evaluate_ood(
        (
            (
                ood,
                pick_dataset(name=ood, train=False, batch_size=args.batch_size,
                             img_size=img_size, id_dataset=args.dataset),
            )
            for ood in ood_names
        ),
        in_array,
    )

elif args.ood_task == 'covar':
    print(f"Covariate Shift Detection for {args.dataset}\n")
    # os.listdir order is arbitrary; sort for a reproducible report order.
    corruptions.sort()
    if args.dataset == 'mnist':
        # mnistc_dataloader takes no severity level: one evaluation per corruption.
        shifts = ((c, mnistc_dataloader(args.batch_size, c)) for c in corruptions)
        sep = "\n\n"
    elif args.dataset in ('cifar10', 'tinyimagenet'):
        # Each corruption is evaluated at severity levels 1 through 5.
        make_loader = cifar10c_dataloader if args.dataset == 'cifar10' else tinyimagenetc_dataloader
        shifts = (
            (f"{c} ({i})", make_loader(args.batch_size, c, i))
            for c in corruptions
            for i in range(1, 6)
        )
        sep = "\n"
    else:
        raise ValueError(f"No covariate-shift benchmark for dataset {args.dataset!r}")
    aurocs, fpr95s, in_array = _evaluate_ood(shifts, in_array, sep)
    print(f"Mean AUROC: {np.mean(aurocs):.4f}\nMean FPR95: {np.mean(fpr95s):.4f}\n")

else:
    # Previously an unknown task silently did nothing; surface the misconfiguration.
    raise ValueError(
        f"Unsupported ood_task {args.ood_task!r}; expected 'near', 'far' or 'covar'"
    )