import os

import numpy as np
import torch
import wandb

# The star import is placed before the other project imports so that the
# explicit names below take precedence over any same-named re-exports.
from data.Dataloaders import *
from config import data_raw_dir
from models.DisCoNet import DisCoNet
from utils.util import parse_args_DisCoNet

# Command-line options for this evaluation run.
args = parse_args_DisCoNet()

# Pick the compute device: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Per-dataset settings:
#   img_size / channels - input resolution and channel count passed to DisCoNet
#   near_ood / far_ood  - dataset names handed to pick_dataset for the
#                         'near' and 'far' tasks
#   corruptions         - corruption names found under data_raw_dir, used by
#                         the 'covar' task
if args.dataset == 'mnist':
    img_size = 32
    channels = 1
    near_ood = ['fashionmnist']
    far_ood = ['cifar10', 'tinyimagenet', 'dtd', 'places365']
    corruptions = os.listdir(os.path.join(data_raw_dir, 'mnist_c'))
    # Excluded from the corruption sweep; presumably 'identity' holds the
    # uncorrupted images — NOTE(review): confirm why 'translate' is dropped.
    corruptions.remove('identity')
    corruptions.remove('translate')

elif args.dataset == 'cifar10':
    img_size = 32
    channels = 3
    near_ood = ['cifar100', 'tinyimagenet']
    far_ood = ['mnist', 'svhn', 'dtd', 'places365']
    # Keep only the part of each entry name before the first '.' (e.g. drop a
    # file extension), then remove 'labels', presumably the shared label file
    # rather than a corruption.
    corruptions = [c.split('.')[0] for c in os.listdir(os.path.join(data_raw_dir, 'CIFAR-10-C'))]
    corruptions.remove('labels')

elif args.dataset == 'tinyimagenet':
    img_size = 64
    channels = 3
    near_ood = ['ssb-hard', 'ninco']
    far_ood = ['inaturalist', 'dtd', 'openimageo']
    corruptions = os.listdir(os.path.join(data_raw_dir, 'Tiny-ImageNet-C'))

elif args.dataset == 'imagenet':
    img_size = 128
    channels = 3
    near_ood = ['ssb-hard', 'ninco']
    far_ood = ['inaturalist', 'dtd', 'openimageo']
    corruptions = os.listdir(os.path.join(data_raw_dir, 'ImageNet-C'))

else:
    # Fail fast: without this, img_size / channels / in_loader stay undefined
    # and the script crashes later with an unhelpful NameError.
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; expected 'mnist', 'cifar10', 'tinyimagenet' or 'imagenet'"
    )

# In-distribution test loader; identical construction for every dataset, so
# it is built once after the per-dataset settings are known.
in_loader = pick_dataset(name=args.dataset, train=False, batch_size=args.batch_size, img_size=img_size, id_dataset=args.dataset)

# Initialize model and load checkpoint
model = DisCoNet(input_shape = img_size, device = device, input_channels = channels, latent_dim = args.latent_dim, n_epochs = args.n_epochs, hidden_dims = args.hidden_dims, lr = args.lr, batch_size = args.batch_size, gen_weight = args.gen_weight, recon_weight=args.recon_weight, sample_and_save_frequency = args.sample_and_save_frequency) 

if args.discriminator_checkpoint is not None:
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved on a GPU can still be loaded when `device` fell back to the CPU.
    state_dict = torch.load(args.discriminator_checkpoint, map_location=device)
    model.discriminator.load_state_dict(state_dict)

# outlier_detection returns an `in_array` that is passed back into every later
# call; presumably this lets it reuse the in-distribution scores instead of
# re-scoring in_loader each time (None on the first call) — confirm in DisCoNet.
in_array = None

# Evaluate OOD detection
if args.ood_task in ('near', 'far'):
    # Each entry of near_ood / far_ood is a dataset name loaded via pick_dataset.
    ood_datasets = near_ood if args.ood_task == 'near' else far_ood
    # Prints "Near OOD Detection ..." / "Far OOD Detection ...".
    print(f"{args.ood_task.capitalize()} OOD Detection for {args.dataset}\n")
    for ood in ood_datasets:
        out_loader = pick_dataset(name=ood, train=False, batch_size=args.batch_size, img_size=img_size, id_dataset=args.dataset)
        auroc, fpr95, in_array, _ = model.outlier_detection(in_loader, out_loader, display=False, in_array=in_array)
        print(f"OOD: {ood}\nAUROC: {auroc:.4f}\nFPR95: {fpr95:.4f}\n\n")

elif args.ood_task == 'covar':
    # Score every corruption listed in `corruptions` against the
    # in-distribution test loader, then report the mean AUROC / FPR95.
    print(f"Covariate Shift Detection for {args.dataset}\n")
    corruptions.sort()  # os.listdir order is arbitrary; sort for reproducible output
    aurocs = []
    fpr95s = []

    if args.dataset == 'mnist':
        # mnistc_dataloader takes no level argument, so each corruption is
        # evaluated exactly once.
        for c in corruptions:
            out_loader = mnistc_dataloader(args.batch_size, c)
            auroc, fpr95, in_array, _ = model.outlier_detection(in_loader, out_loader, display=False, in_array=in_array)
            aurocs.append(auroc)
            fpr95s.append(fpr95)
            print(f"OOD: {c}\nAUROC: {auroc:.4f}\nFPR95: {fpr95:.4f}\n\n")
    else:
        # The other corrupted-dataset loaders share the signature
        # (batch_size, corruption, level); select it only for the active
        # dataset so the other loader names are never looked up.
        if args.dataset == 'cifar10':
            corrupted_loader = cifar10c_dataloader
        elif args.dataset == 'tinyimagenet':
            corrupted_loader = tinyimagenetc_dataloader
        elif args.dataset == 'imagenet':
            corrupted_loader = imagenetc_dataloader
        else:
            raise ValueError(f"No covariate-shift benchmark for dataset {args.dataset!r}")

        for c in corruptions:
            # Levels 1-5, presumably the corruption severity — confirm in data.Dataloaders.
            for level in range(1, 6):
                out_loader = corrupted_loader(args.batch_size, c, level)
                auroc, fpr95, in_array, _ = model.outlier_detection(in_loader, out_loader, display=False, in_array=in_array)
                print(f"OOD: {c} ({level})\nAUROC: {auroc:.4f}\nFPR95: {fpr95:.4f}\n")
                aurocs.append(auroc)
                fpr95s.append(fpr95)

    print(f"Mean AUROC: {np.mean(aurocs):.4f}\nMean FPR95: {np.mean(fpr95s):.4f}\n")

else:
    # Previously an unknown task silently did nothing; surface the typo instead.
    raise ValueError(f"Unknown ood_task {args.ood_task!r}; expected 'near', 'far' or 'covar'")