import os

import numpy as np
import torch
import wandb

from data.Dataloaders import *
from models.DCGAN import Discriminator
from utils.util import parse_args_DCGAN
from config import data_raw_dir

args = parse_args_DCGAN()

# Evaluate on the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Per-dataset evaluation setup: input resolution, channel count, the near-/far-OOD
# benchmark datasets, and the corruption names used for covariate-shift detection.
if args.dataset == 'mnist':
    img_size = 32
    channels = 1
    near_ood = ['fashionmnist']
    far_ood = ['cifar10', 'tinyimagenet', 'dtd', 'places365']
    corruptions = os.listdir(os.path.join(data_raw_dir, 'mnist_c'))
    # Excluded from the covariate-shift sweep; 'identity' is presumably the
    # uncorrupted copy of the test set (TODO confirm why 'translate' is skipped).
    corruptions.remove('identity')
    corruptions.remove('translate')

elif args.dataset == 'cifar10':
    img_size = 32
    channels = 3
    near_ood = ['cifar100', 'tinyimagenet']
    far_ood = ['mnist', 'svhn', 'dtd', 'places365']
    # Directory entries are files; strip the extension to get corruption names,
    # then drop the shared 'labels' entry, which is not a corruption.
    corruptions = [c.split('.')[0] for c in os.listdir(os.path.join(data_raw_dir, 'CIFAR-10-C'))]
    corruptions.remove('labels')

elif args.dataset == 'tinyimagenet':
    img_size = 64
    channels = 3
    near_ood = ['ssb-hard', 'ninco']
    far_ood = ['inaturalist', 'dtd', 'openimageo']
    corruptions = os.listdir(os.path.join(data_raw_dir, 'Tiny-ImageNet-C'))

else:
    # Fail early with a clear message instead of a later NameError on img_size/channels.
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; expected 'mnist', 'cifar10' or 'tinyimagenet'"
    )

# The in-distribution test loader is built the same way for every dataset.
in_loader = pick_dataset(name=args.dataset, train=False, batch_size=args.batch_size, img_size=img_size, id_dataset=args.dataset)

# Initialize the discriminator (used as the OOD scorer) and optionally restore its weights.
model = Discriminator(channels=channels, d=args.d, imgSize=img_size).to(device)

if args.discriminator_checkpoint is not None:
    # map_location lets a checkpoint saved on a GPU be loaded when only the CPU
    # is available (and vice versa); without it torch.load restores tensors to
    # the device they were saved from.
    state_dict = torch.load(args.discriminator_checkpoint, map_location=device)
    model.load_state_dict(state_dict)

# In-distribution scores returned by model.outlier_detection and passed back on
# subsequent calls, presumably so the ID test set is scored only once.
in_array = None


def _covariate_shift_loaders():
    """Yield ``(label, loader)`` pairs for each corrupted variant of ``args.dataset``.

    MNIST-C corruptions are evaluated once each; CIFAR-10-C and Tiny-ImageNet-C
    corruptions are evaluated at every severity level 1-5. Corruptions are visited
    in sorted order and loaders are created lazily, one per iteration.
    """
    for corruption in sorted(corruptions):
        if args.dataset == 'mnist':
            yield corruption, mnistc_dataloader(args.batch_size, corruption)
        elif args.dataset in ('cifar10', 'tinyimagenet'):
            make_loader = cifar10c_dataloader if args.dataset == 'cifar10' else tinyimagenetc_dataloader
            for severity in range(1, 6):
                yield f"{corruption} ({severity})", make_loader(args.batch_size, corruption, severity)


# Evaluate OOD detection
if args.ood_task in ('near', 'far'):
    ood_datasets = near_ood if args.ood_task == 'near' else far_ood
    print(f"{args.ood_task.capitalize()} OOD Detection for {args.dataset}\n")
    for ood in ood_datasets:
        out_loader = pick_dataset(name=ood, train=False, batch_size=args.batch_size, img_size=img_size, id_dataset=args.dataset)
        auroc, fpr95, in_array, _ = model.outlier_detection(in_loader, out_loader, display=False, device=device, in_array=in_array)
        print(f"OOD: {ood}\nAUROC: {auroc:.4f}\nFPR95: {fpr95:.4f}\n\n")

elif args.ood_task == 'covar':
    print(f"Covariate Shift Detection for {args.dataset}\n")
    # Keep the original spacing: MNIST-C results end with a blank line, the
    # per-severity CIFAR/Tiny-ImageNet results do not.
    separator = "\n\n" if args.dataset == 'mnist' else "\n"
    aurocs = []
    fpr95s = []
    for label, out_loader in _covariate_shift_loaders():
        auroc, fpr95, in_array, _ = model.outlier_detection(in_loader, out_loader, display=False, device=device, in_array=in_array)
        aurocs.append(auroc)
        fpr95s.append(fpr95)
        print(f"OOD: {label}\nAUROC: {auroc:.4f}\nFPR95: {fpr95:.4f}{separator}")
    print(f"Mean AUROC: {np.mean(aurocs):.4f}\nMean FPR95: {np.mean(fpr95s):.4f}\n")

else:
    # Previously an unknown task silently did nothing.
    raise ValueError(f"Unsupported ood_task {args.ood_task!r}; expected 'near', 'far' or 'covar'")