from data.Dataloaders import *
from models.DisCoNet import DisCoNet
from utils.util import parse_args_DisCoNet
import torch
import wandb
import os
from config import data_raw_dir, models_dir

# Command-line configuration for this evaluation run.
args = parse_args_DisCoNet()

# Run on the GPU whenever torch can see one; fall back to the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

import numpy as np  # explicit: np was previously only reachable through the star import

# Per-dataset evaluation setup: image size, channels, OOD datasets, corruption
# benchmark and discriminator checkpoints. The defaults below apply to any
# branch that does not override them, so every name used later always exists.
near_ood = []
far_ood = []
corruptions = []
high_freq = []
# Without a preset checkpoint list, evaluate the single checkpoint given on the
# command line (None if that option was not set).
base_dir = ''
checkpoints = [getattr(args, 'discriminator_checkpoint', None)]

if args.dataset == 'mnist':
    img_size = 32
    channels = 1
    near_ood = ['fashionmnist']
    far_ood = ['cifar10', 'tinyimagenet', 'dtd', 'places365']
    # Held-out split, consistent with every other in-distribution dataset below.
    in_loader = pick_dataset(name=args.dataset, train=False, batch_size=args.batch_size, img_size=img_size, id_dataset=args.dataset)
    corruptions = os.listdir(os.path.join(data_raw_dir, 'mnist_c'))
    corruptions.remove('identity')
    corruptions.remove('translate')

elif args.dataset == 'cifar100':
    img_size = 32
    channels = 3
    far_ood = ['svhn', 'dtd', 'places365']
    near_ood = ['cifar10', 'tinyimagenet']
    base_dir = os.path.join(models_dir, 'DisCoNet', 'cifar100_1')
    checkpoints = ['Discriminator_cifar100_249.pt', 'Discriminator_cifar100_239.pt', 'Discriminator_cifar100_229.pt', 'Discriminator_cifar100_219.pt', 'Discriminator_cifar100_209.pt']
    in_loader = pick_dataset(name=args.dataset, train=False, batch_size=args.batch_size, img_size=img_size, id_dataset=args.dataset)

elif args.dataset == 'cifar10':
    img_size = 32
    channels = 3
    near_ood = ['cifar100', 'tinyimagenet']
    far_ood = ['mnist', 'svhn', 'dtd', 'places365']
    in_loader = pick_dataset(name=args.dataset, train=False, batch_size=args.batch_size, img_size=img_size, id_dataset=args.dataset)
    # CIFAR-10-C entries are files; strip the extension and drop the labels file.
    corruptions = [c.split('.')[0] for c in os.listdir(os.path.join(data_raw_dir, 'CIFAR-10-C'))]
    corruptions.remove('labels')
    base_dir = os.path.join(models_dir, 'DisCoNet', f'cifar10_{args.latent_dim}')
    #base_dir = os.path.join(models_dir, 'DisCoNet', 'cifar10_blur_1')
    checkpoints = ['Discriminator_cifar10_249.pt', 'Discriminator_cifar10_239.pt', 'Discriminator_cifar10_229.pt', 'Discriminator_cifar10_219.pt', 'Discriminator_cifar10_209.pt']
    high_freq = ['brightness', 'frost', 'gaussian_noise', 'glass_blur', 'impulse_noise', 'pixelate', 'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise']

elif args.dataset == 'tinyimagenet':
    img_size = 64
    channels = 3
    # NOTE(review): the in-distribution dataset itself is listed as near OOD;
    # confirm this is an intentional sanity check and not a copy-paste leftover.
    near_ood = ['tinyimagenet', 'ssb-hard', 'ninco']
    far_ood = ['inaturalist', 'dtd', 'openimageo']
    in_loader = pick_dataset(name=args.dataset, train=False, batch_size=args.batch_size, img_size=img_size, id_dataset=args.dataset)
    corruptions = os.listdir(os.path.join(data_raw_dir, 'Tiny-ImageNet-C'))
    #base_dir = os.path.join(models_dir, 'DisCoNet', f'tinyimagenet_{args.latent_dim}')
    base_dir = os.path.join(models_dir, 'DisCoNet', 'tinyimagenet_blur_1')
    checkpoints = ['Discriminator_tinyimagenet_129.pt', 'Discriminator_tinyimagenet_19.pt', 'Discriminator_tinyimagenet_79.pt', 'Discriminator_tinyimagenet_39.pt', 'Discriminator_tinyimagenet_29.pt']
    high_freq = ['brightness', 'frost', 'gaussian_noise', 'glass_blur', 'impulse_noise', 'pixelate', 'shot_noise', 'snow']

elif args.dataset == 'imagenet':
    img_size = 128
    channels = 3
    near_ood = ['ssb-hard', 'ninco']
    far_ood = ['inaturalist', 'dtd', 'openimageo']
    base_dir = os.path.join(models_dir, 'DisCoNet', 'imagenet_special')
    # Evaluate every .pt file in the directory. os.listdir order is arbitrary,
    # so sort it for a reproducible evaluation order.
    checkpoints = sorted(f for f in os.listdir(base_dir) if f.endswith('.pt'))
    in_loader = pick_dataset(name=args.dataset, train=False, batch_size=args.batch_size, img_size=img_size, id_dataset=args.dataset)
    corruptions = os.listdir(os.path.join(data_raw_dir, 'ImageNet-C'))

else:
    raise ValueError(f"Unsupported dataset: {args.dataset}")

# Score accumulators; column 0 holds summed AUROC, column 1 summed FPR95.
near_ood_scores = np.zeros((len(near_ood), 2), dtype=np.float32)
far_ood_scores = np.zeros((len(far_ood), 2), dtype=np.float32)
# One row per corruption severity level (1..5).
corruption_scores = np.zeros((5, 2), dtype=np.float32)
high_freq_scores = np.zeros(2, dtype=np.float32)
low_freq_scores = np.zeros(2, dtype=np.float32)

# Initialize the model once; each checkpoint below only swaps the discriminator weights.
model = DisCoNet(
    input_shape=img_size,
    device=device,
    input_channels=channels,
    latent_dim=args.latent_dim,
    n_epochs=args.n_epochs,
    hidden_dims=args.hidden_dims,
    lr=args.lr,
    batch_size=args.batch_size,
    gen_weight=args.gen_weight,
    recon_weight=args.recon_weight,
    sample_and_save_frequency=args.sample_and_save_frequency,
)

# Datasets that have a corruption (covariate-shift) benchmark wired up below.
_COVARIATE_DATASETS = ('mnist', 'cifar10', 'tinyimagenet', 'imagenet')


def _corruption_loader(corruption, severity):
    """Return the corrupted-data loader for ``args.dataset``.

    ``severity`` is ignored for MNIST-C, whose loader takes no severity argument.

    Raises:
        ValueError: if ``args.dataset`` has no corruption benchmark.
    """
    if args.dataset == 'mnist':
        return mnistc_dataloader(args.batch_size, corruption)
    if args.dataset == 'cifar10':
        return cifar10c_dataloader(args.batch_size, corruption, severity)
    if args.dataset == 'tinyimagenet':
        return tinyimagenetc_dataloader(args.batch_size, corruption, severity)
    if args.dataset == 'imagenet':
        return imagenetc_dataloader(args.batch_size, corruption, severity)
    raise ValueError(f"No corruption benchmark for {args.dataset}")


for d in checkpoints:
    # A None entry means "no checkpoint given": keep the model's current weights.
    args.discriminator_checkpoint = None if d is None else os.path.join(base_dir, d)
    if args.discriminator_checkpoint is not None:
        # map_location lets CUDA-saved checkpoints load on a CPU-only machine.
        state_dict = torch.load(args.discriminator_checkpoint, map_location=device)
        model.discriminator.load_state_dict(state_dict)

    # in_array is returned by model.outlier_detection and fed back into later
    # calls; reset it per checkpoint so nothing computed with one set of
    # weights is reused with another.
    in_array = None

    if args.ood_task in ('near', 'far'):
        if args.ood_task == 'near':
            title, ood_names, ood_scores = 'Near', near_ood, near_ood_scores
        else:
            title, ood_names, ood_scores = 'Far', far_ood, far_ood_scores
        print(f"{title} OOD Detection for {args.dataset}\n")
        for row, ood in enumerate(ood_names):
            # The OOD set is always paired with the actual in-distribution
            # dataset (the far-OOD loop used to hard-code 'cifar10').
            out_loader = pick_dataset(name=ood, train=False, batch_size=args.batch_size, img_size=img_size, id_dataset=args.dataset)
            auroc, fpr95, in_array, _ = model.outlier_detection(in_loader, out_loader, display=False, in_array=in_array)
            print(f"OOD: {ood}\nAUROC: {auroc:.4f}\nFPR95: {fpr95:.4f}\n\n")
            ood_scores[row, 0] += auroc
            ood_scores[row, 1] += fpr95

    elif args.ood_task == 'covar':
        print(f"Covariate Shift Detection for {args.dataset}\n")
        if args.dataset in _COVARIATE_DATASETS:
            # MNIST-C is evaluated once per corruption; the other *-C sets at severities 1..5.
            severities = [None] if args.dataset == 'mnist' else range(1, 6)
            # The high/low-frequency split is only configured for CIFAR-10-C and Tiny-ImageNet-C.
            track_freq = args.dataset in ('cifar10', 'tinyimagenet')
            corruptions.sort()
            aurocs = []
            fpr95s = []
            for c in corruptions:
                for severity in severities:
                    out_loader = _corruption_loader(c, severity)
                    auroc, fpr95, in_array, _ = model.outlier_detection(in_loader, out_loader, display=False, in_array=in_array)
                    aurocs.append(auroc)
                    fpr95s.append(fpr95)
                    if severity is None:
                        print(f"OOD: {c}\nAUROC: {auroc:.4f}\nFPR95: {fpr95:.4f}\n\n")
                        continue
                    print(f"OOD: {c} ({severity})\nAUROC: {auroc:.4f}\nFPR95: {fpr95:.4f}\n")
                    corruption_scores[severity - 1, 0] += auroc
                    corruption_scores[severity - 1, 1] += fpr95
                    if track_freq:
                        freq_scores = high_freq_scores if c in high_freq else low_freq_scores
                        freq_scores[0] += auroc
                        freq_scores[1] += fpr95
            print(f"Mean AUROC: {np.mean(aurocs):.4f}\nMean FPR95: {np.mean(fpr95s):.4f}\n")
        else:
            print(f"No corruption benchmark configured for {args.dataset}; skipping.\n")

# Print the averaged results for the task that was run. The evaluation loop sums
# AUROC/FPR95 over all checkpoints (and over corruptions/severities for the
# covariate task); dividing by the checkpoint count and scaling by 100 yields
# mean percentages. This replaces the old fixed factors (*20, *4), which assumed
# exactly 5 checkpoints.
pct = 100.0 / max(len(checkpoints), 1)

if args.ood_task == 'near':
    for name, (auroc_sum, fpr_sum) in zip(near_ood, near_ood_scores):
        print(f"Dataset: {name}, Mean AUROC: {auroc_sum * pct:.1f}\nMean FPR95: {fpr_sum * pct:.1f}\n")

elif args.ood_task == 'far':
    for name, (auroc_sum, fpr_sum) in zip(far_ood, far_ood_scores):
        print(f"Dataset: {name}, Mean AUROC: {auroc_sum * pct:.1f}\nMean FPR95: {fpr_sum * pct:.1f}\n")

elif args.ood_task == 'covar' and args.dataset in ('cifar10', 'tinyimagenet', 'imagenet'):
    n_corr = len(corruptions)
    if n_corr:
        # Each severity row was summed over every corruption.
        for severity, (auroc_sum, fpr_sum) in enumerate(corruption_scores, start=1):
            print(f"Intensity: {severity}, Mean AUROC: {auroc_sum * pct / n_corr:.1f}\nMean FPR95: {fpr_sum * pct / n_corr:.1f}\n")
    if args.dataset in ('cifar10', 'tinyimagenet'):
        n_severities = len(corruption_scores)
        # Count only corruptions that were actually evaluated, so a high_freq
        # name missing from the corruption directory does not skew the mean.
        n_high = sum(c in high_freq for c in corruptions)
        n_low = n_corr - n_high
        for label, scores, count in (('High', high_freq_scores, n_high), ('Low', low_freq_scores, n_low)):
            if count == 0:
                continue
            # Each frequency bucket was summed over its corruptions x severities.
            scale = pct / (n_severities * count)
            print(f"{label} Frequency, Mean AUROC: {scores[0] * scale:.1f}\nMean FPR95: {scores[1] * scale:.1f}\n")