#%%
# Import required packages
import random
import os
import cv2
import torch
import torchvision as tv
import numpy as np
import pandas as pd
import normflows as nf
import importlib
from datetime import datetime

import arg
import models
from train import Flow_Trainer_realnvp, Flow_Trainer_nsf, Flow_Trainer_cnn
from util import *
from evaluation import *

from torchinfo import summary
from matplotlib import pyplot as plt
from tqdm import tqdm
from copy import deepcopy
from sklearn.metrics import roc_auc_score
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA


import warnings

warnings.filterwarnings('ignore')

def seed_everything(seed: int = 42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, the current CUDA device and all CUDA devices). It also switches cuDNN
    to deterministic, non-benchmarking mode so convolution algorithms do not
    vary between runs.

    ``PYTHONHASHSEED`` is exported as well. The running interpreter reads that
    variable only at start-up, so it influences child processes, not this one.

    Args:
        seed: Seed value applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

# Detector columns written to every result CSV, in this order; the
# perturbation AUCs (one per alpha) are appended after them.
RESULT_BASE_COLUMNS = ('likelihood', 'complexity', 'typicality',
                       'typicality_entropy', 'likelihood_ratio', 'gmm')


def glow_checkpoint_path(K, dataset, gray_scale):
    """Return the cache path of the Glow model trained on ``dataset``."""
    return f'./models/glow_{K}_{dataset}_gray_{gray_scale}.pt'


def glow_background_path(K, dataset):
    """Return the path of the pre-trained background Glow for ``dataset``."""
    return f'./models/glow_{K}_{dataset}_background.pt'


def ood_result_path(K, id_name, ood_name, gray_scale, feature_extractor, date):
    """Return the CSV path for an (in-dist ``id_name``, OOD ``ood_name``) run.

    The same naming pattern is used for both evaluation directions, so the
    forward and reverse result files are named consistently.
    """
    return (f'./result/glow_{K}_{id_name}_{ood_name}_gray_{gray_scale}'
            f'_{feature_extractor}_{date}.csv')


def make_flow_optimizer(model, lr, weight_decay, epochs, use_scheduler):
    """Build an Adam optimizer and, if requested, a cosine warm-restart scheduler.

    Returns:
        ``(optimizer, scheduler)``. ``scheduler`` is ``None`` when
        ``use_scheduler`` is falsy.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    if not use_scheduler:
        return optimizer, None
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=int(epochs), T_mult=1, eta_min=1e-6)
    return optimizer, scheduler


def load_flow_weights(model, path, device):
    """Load the state dict stored at ``path`` into ``model`` and move it to ``device``.

    ``map_location`` lets a checkpoint saved on one device (e.g. ``cuda:0``)
    load on another, including a CPU-only machine.
    """
    model.load_state_dict(torch.load(path, map_location=device))
    return model.to(device)


def prepare_flow_model(model, optimizer, scheduler, ckpt_path, use_pretrained, device,
                       epochs, train_loader, test_loader, ood_test_loader,
                       id_name, ood_name, n_dims):
    """Reuse a cached flow from ``ckpt_path`` if allowed, otherwise train and cache it.

    Returns:
        ``(model, trainer)``. The model is on ``device`` and the
        ``Flow_Trainer_cnn`` is bound to it.
    """
    reuse = use_pretrained and os.path.exists(ckpt_path)
    if reuse:
        model = load_flow_weights(model, ckpt_path, device)
    else:
        model = model.to(device)
    trainer = Flow_Trainer_cnn(model, epochs, train_loader, test_loader,
                               ood_test_loader, optimizer, scheduler, device,
                               id_name, ood_name, n_dims)
    if not reuse:
        trainer.fit()
        # Create the directory first so a long training run is not lost to a
        # missing ./models directory at save time.
        ckpt_dir = os.path.dirname(ckpt_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(model.state_dict(), ckpt_path)
    return model, trainer


def build_result_frame(base_aucs, perturb_aucs, alpha_list):
    """Pack detector AUCs into a one-row DataFrame.

    Columns are ``RESULT_BASE_COLUMNS`` followed by ``perturb_alpha_<alpha>``
    for each entry in ``alpha_list``.
    """
    columns = list(RESULT_BASE_COLUMNS) + [f"perturb_alpha_{x}" for x in alpha_list]
    row = list(base_aucs) + list(perturb_aucs)
    return pd.DataFrame(np.array([row]), columns=columns)


def run_ood_benchmarks(trainer, model, background_model, train_loader,
                       id_test_loader, ood_test_loader, id_complexity, ood_complexity,
                       id_name, ood_name, K, device, batch_size, feature_extractor, seed):
    """Run every OOD detector for a flow trained on ``id_name`` against ``ood_name``.

    Returns:
        One-row DataFrame built by ``build_result_frame``.
    """
    (in_ll, in_z_ll, _in_log_det, in_z_tensor,
     out_ll, out_z_ll, _out_log_det, out_z_tensor) = trainer.extract_features()
    print("Evaluation Start")
    likelihood_auc = eval_likelhiood(in_ll, out_ll, id_name, ood_name)
    complexity_auc = eval_complexity(in_ll, out_ll, id_complexity, ood_complexity,
                                     id_name, ood_name)
    typicality_auc = eval_typicality_latent(in_z_tensor, out_z_tensor, id_name, ood_name)
    typicality_entropy_auc = eval_typicality_entropy(model, device, train_loader, in_ll,
                                                     out_ll, id_name, ood_name, False)
    ratio_auc = likelihood_ratio(background_model, in_ll, out_ll, K, device,
                                 id_test_loader, ood_test_loader, id_name, ood_name)
    perturb_aucs, alpha_list = perturb_pretrained(model, device, batch_size, train_loader,
                                                  id_test_loader, ood_test_loader,
                                                  id_name, ood_name, feature_extractor, False)
    gmm_auc = statistic_gmm(in_z_ll, id_complexity, out_z_ll, ood_complexity,
                            id_name, ood_name, seed)
    base_aucs = [likelihood_auc, complexity_auc, typicality_auc,
                 typicality_entropy_auc, ratio_auc, gmm_auc]
    return build_result_frame(base_aucs, perturb_aucs, alpha_list)


if __name__ == '__main__':

    args = arg.arg_parse()
    print(args)
    seed = 42
    hidden_channels = args.latent_dim
    split_mode = 'channel'
    scale = True
    feature_extractor = args.feature_extractor
    batch_size = args.batch_size
    epochs = args.num_epochs
    lr = args.lr
    weight_decay = args.weight_decay
    K = args.K
    L = args.L
    pretrain_flag = args.pretrain
    scheduler_flag = args.scheduler
    in_dist_dataset = args.in_dataset
    out_dist_dataset = args.out_dataset
    input_size = args.input_size
    pca_dim = args.pca_dim
    gray_scale = args.gray

    # The previous dataset/grayscale branch assigned identical values in both
    # arms, so the flow input is always 3-channel.
    # NOTE(review): presumably dataload_gray_2d yields 3-channel (replicated)
    # gray images; confirm before changing this to 1 channel.
    channels = 3
    input_shape = (channels, input_size, input_size)
    n_dims = np.prod(input_shape)

    seed_everything(seed)
    device = torch.device(f'cuda:{args.device}' if args.device != 'cpu' else 'cpu')
    print(n_dims)
    load_data = dataload_gray_2d if gray_scale else dataload
    (in_dist_train_complexity, in_dist_test_complexity,
     in_dist_train_loader, in_dist_test_loader) = load_data(in_dist_dataset, batch_size, input_size)
    (out_dist_train_complexity, out_dist_test_complexity,
     out_dist_train_loader, out_dist_test_loader) = load_data(out_dist_dataset, batch_size, input_size)

    def build_glow():
        """Construct a freshly initialised Glow with this run's hyper-parameters."""
        return models.making_model_glow(input_shape, channels, hidden_channels,
                                        split_mode, L, K, scale)

    # Both models are built before any training so their random initialisation
    # consumes the seeded RNG in the same order as before.
    p_model = build_glow()
    q_model = build_glow()
    p_optimizer, p_scheduler = make_flow_optimizer(p_model, lr, weight_decay, epochs, scheduler_flag)
    q_optimizer, q_scheduler = make_flow_optimizer(q_model, lr, weight_decay, epochs, scheduler_flag)

    p_model, p_trainer = prepare_flow_model(
        p_model, p_optimizer, p_scheduler,
        glow_checkpoint_path(K, in_dist_dataset, gray_scale), pretrain_flag, device, epochs,
        in_dist_train_loader, in_dist_test_loader, out_dist_test_loader,
        in_dist_dataset, out_dist_dataset, n_dims)

    date = datetime.now().strftime('%Y%m%d%H%M%S')

    q_model, q_trainer = prepare_flow_model(
        q_model, q_optimizer, q_scheduler,
        glow_checkpoint_path(K, out_dist_dataset, gray_scale), pretrain_flag, device, epochs,
        out_dist_train_loader, out_dist_test_loader, in_dist_test_loader,
        out_dist_dataset, in_dist_dataset, n_dims)

    p_background_model = load_flow_weights(build_glow(), glow_background_path(K, in_dist_dataset), device)
    q_background_model = load_flow_weights(build_glow(), glow_background_path(K, out_dist_dataset), device)

    os.makedirs('./result', exist_ok=True)

    forward_result = run_ood_benchmarks(
        p_trainer, p_model, p_background_model, in_dist_train_loader,
        in_dist_test_loader, out_dist_test_loader,
        in_dist_test_complexity, out_dist_test_complexity,
        in_dist_dataset, out_dist_dataset, K, device, batch_size, feature_extractor, seed)
    forward_result.to_csv(ood_result_path(K, in_dist_dataset, out_dist_dataset,
                                          gray_scale, feature_extractor, date))

    # Reverse evaluation: the q flow treats the OOD dataset as in-distribution.
    reverse_result = run_ood_benchmarks(
        q_trainer, q_model, q_background_model, out_dist_train_loader,
        out_dist_test_loader, in_dist_test_loader,
        out_dist_test_complexity, in_dist_test_complexity,
        out_dist_dataset, in_dist_dataset, K, device, batch_size, feature_extractor, seed)
    reverse_result.to_csv(ood_result_path(K, out_dist_dataset, in_dist_dataset,
                                          gray_scale, feature_extractor, date))
    
# %%
