#%%
from numpy import loadtxt
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from loadData import loadData
from torch.utils.data import DataLoader
import random
import warnings
import torch.optim as optim
import pandas as pd
from arg import arg_parse
from tqdm import tqdm

from flow import NICE, RealNVP
from train import Flow_Trainer

from util import CosineAnnealingWarmUpRestarts, singular_value_extract, gmm_fitting
from sklearn.metrics import roc_auc_score, average_precision_score


# Base random seed for the experiment; repetition `rep_num` of each dataset
# is seeded with `seed + rep_num` in the main loop below.
seed = 42

def seed_everything(seed: int = 42):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator plus the current and all CUDA devices), stores the seed
    in the ``PYTHONHASHSEED`` environment variable, and configures cuDNN
    with ``deterministic = True`` and ``benchmark = False``.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Pure-Python and NumPy generators.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)

    # PyTorch generators: CPU first, then the current CUDA device, then
    # every CUDA device.
    for torch_seed_fn in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        torch_seed_fn(seed)

    # Select deterministic cuDNN kernels and disable the autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

if __name__ == '__main__':
    # Experiment driver.
    #
    # For every dataset file found in ./adbench, train a normalizing flow
    # (NICE or RealNVP, chosen by `args.model`) `args.repeat` times with
    # different seeds, score the test samples by negative log-likelihood,
    # and record AUROC / AUPRC. For each repetition a t-SNE plot of the test
    # latents is written to ./fig/<dataset>.png (overwritten on every
    # repetition), and the per-dataset summary is written to a CSV in
    # ./result at the end.

    # Imported once here rather than on every repetition of the inner loop.
    from sklearn.manifold import TSNE

    args = arg_parse()
    warnings.filterwarnings("ignore")
    print(args)

    batch_size = args.batch_size
    repeat = args.repeat
    norm_flow_epochs = args.num_epochs
    norm_flow_weight_decay = args.weight_decay
    norm_flow_lr = args.lr
    num_layers = args.num_layers
    num_coupling_layers = args.num_coupling_layers
    norm_flow_latent_dim = args.latent_dim
    invlayer = args.invlayer
    scaler_flag = args.scaler_flag

    # Fail fast on an unsupported model name instead of hitting a NameError
    # on `flow_model` after the first dataset has already been loaded.
    model_classes = {'nice': NICE, 'realnvp': RealNVP}
    if args.model not in model_classes:
        raise ValueError(
            f"Unsupported model '{args.model}'; expected one of {sorted(model_classes)}"
        )
    flow_model_cls = model_classes[args.model]

    # Create the output directories up front so that a missing directory does
    # not abort the run only after training has finished.
    os.makedirs('./fig', exist_ok=True)
    os.makedirs('./result', exist_ok=True)

    stat_list = []

    # Prefer the second GPU as before, but fall back to the first GPU when
    # only one is visible (requesting 'cuda:1' there would fail), and to the
    # CPU when CUDA is unavailable.
    if torch.cuda.is_available():
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device(f'cuda:{gpu_index}')
    else:
        device = torch.device('cpu')
    print(f"Device : {device}")

    # CosineAnnealingWarmRestarts requires a positive integer T_0; clamp so
    # that a run with a single epoch does not produce T_0 == 0.
    restart_period = max(1, int(norm_flow_epochs / 2))

    datasetname_list = sorted(os.listdir('./adbench'))
    print(datasetname_list)
    for datasetname in datasetname_list:
        print(datasetname)
        # Keep only the part of the file name before the first '.'.
        datasetname = datasetname.split('.')[0]
        auroclist = np.zeros([repeat, 1])
        auprclist = np.zeros([repeat, 1])

        for rep_num in tqdm(range(repeat)):
            # A different but reproducible seed for every repetition.
            seed_everything(seed + rep_num)

            (train_x, train_y, test_x, test_y,
             test_normal_x, test_normal_y,
             test_anomaly_x, test_anomaly_y) = loadData(datasetname, scaler_flag)

            # Pad an odd feature count with one zero column so that the
            # feature dimension handed to the flow model is even.
            if train_x.shape[1] % 2 == 1:
                pad_width = ((0, 0), (0, 1))
                train_x = np.pad(train_x, pad_width, 'constant', constant_values=0)
                test_x = np.pad(test_x, pad_width, 'constant', constant_values=0)
                test_normal_x = np.pad(test_normal_x, pad_width, 'constant', constant_values=0)
                test_anomaly_x = np.pad(test_anomaly_x, pad_width, 'constant', constant_values=0)

            train_dataset = torch.utils.data.TensorDataset(torch.tensor(train_x), torch.tensor(train_y))
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            # The combined test loader is built but is not passed to
            # Flow_Trainer below.
            test_dataset = torch.utils.data.TensorDataset(torch.tensor(test_x), torch.tensor(test_y))
            test_loader = DataLoader(test_dataset, batch_size=batch_size)

            test_dataset_normal = torch.utils.data.TensorDataset(torch.tensor(test_normal_x),
                                                                 torch.tensor(test_normal_y))
            test_normal_loader = DataLoader(test_dataset_normal, batch_size=batch_size)
            test_dataset_anomaly = torch.utils.data.TensorDataset(torch.tensor(test_anomaly_x),
                                                                  torch.tensor(test_anomaly_y))
            test_anomaly_loader = DataLoader(test_dataset_anomaly, batch_size=batch_size)

            # Feature count after the optional padding above.
            input_dim = train_x.shape[1]

            flow_model = flow_model_cls(input_dim, norm_flow_latent_dim, num_layers,
                                        num_coupling_layers, invlayer)

            flow_optimizer = optim.AdamW(flow_model.parameters(), lr=norm_flow_lr,
                                         weight_decay=norm_flow_weight_decay)
            flow_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                flow_optimizer, T_0=restart_period, T_mult=1, eta_min=1e-6)

            flow_model = flow_model.to(device)
            flow_trainer = Flow_Trainer(flow_model, norm_flow_epochs, train_loader,
                                        test_normal_loader, test_anomaly_loader,
                                        flow_optimizer, flow_scheduler, device)

            flow_trainer.fit()

            (train_each_ll_list, test_each_normal_ll_list, test_each_anomaly_ll_list) = flow_trainer.valid()
            train_latent, test_normal_latent, test_anomaly_latent = flow_trainer.extract_latent()

            # t-SNE of at most 500 normal and 500 anomalous test latents.
            normal_latent_np = test_normal_latent.cpu().detach().numpy()[:500]
            anomaly_latent_np = test_anomaly_latent.cpu().detach().numpy()[:500]
            n_normal = len(normal_latent_np)

            tsne_model = TSNE(n_components=2)
            tsne_result = tsne_model.fit_transform(
                np.concatenate([normal_latent_np, anomaly_latent_np]))

            # Rows [0, n_normal) are the normal samples, the rest anomalies.
            plt.scatter(tsne_result[:n_normal, 0], tsne_result[:n_normal, 1], label='normal', alpha=0.5)
            plt.scatter(tsne_result[n_normal:, 0], tsne_result[n_normal:, 1], label='anomaly', alpha=0.5)
            plt.legend(prop={'size': 12})
            plt.title(datasetname, fontsize=20)
            plt.savefig(f'./fig/{datasetname}.png')
            plt.clf()

            # Per-sample log-likelihoods, normal samples first so that the
            # order matches `label` below.
            ll = []
            ll.extend(test_each_normal_ll_list)
            ll.extend(test_each_anomaly_ll_list)

            print(len(test_each_normal_ll_list))
            print(len(test_each_anomaly_ll_list))
            label = np.concatenate((test_normal_y, test_anomaly_y))

            # The anomaly score is the negated log-likelihood.
            anomaly_score = np.array(ll) * (-1)
            auroclist[rep_num] = roc_auc_score(label, anomaly_score)
            auprclist[rep_num] = average_precision_score(label, anomaly_score)

        auroc_mean, auroc_std = np.mean(auroclist), np.std(auroclist)
        auprc_mean, auprc_std = np.mean(auprclist), np.std(auprclist)

        # NOTE: the values labelled "Var" (and the '*_var' CSV columns below)
        # are standard deviations (np.std). The labels are kept unchanged so
        # that existing logs and result files remain comparable.
        print(f"Dataset : {datasetname}")
        print(f"AUROC Mean : {auroc_mean}")
        print(f"AUROC Var : {auroc_std}")
        print(f"AUPRC Mean : {auprc_mean}")
        print(f"AUPRC Var : {auprc_std}")

        stat_list.append([datasetname, auroc_mean, auroc_std, auprc_mean, auprc_std])

    col = ['dataset', 'auc_mean', 'auc_var', 'auprc_mean', 'auprc_var']
    df = pd.DataFrame(stat_list, columns=col)
    print(args)
    print(df)
    df.to_csv(f'./result/stat_{args.model}_{norm_flow_lr}_BS_{batch_size}_latent_dim_{args.latent_dim}_num_layers_{args.num_layers}_scaler_{scaler_flag}.csv')
## %%

# %%
