import argparse, os

from PIL.Image import new
from dataset.data_gan import MyDataset
from dataset.sensordata import SensorDataset 
from dataset.json_graph import JsonToGraph
from torch.utils.data import DataLoader,Dataset
import pickle 
from torch_geometric.loader import DataLoader as GraphDataLoader
import torch
import numpy as np
import pickle as pkl
from sklearn.metrics import roc_auc_score
import numpy as np
import sys
from natsort import natsorted
from Autoregressive_model import AutoregressiveModel
from gan_model import GANModel
import json
import utils
import copy
from dataset.TabularDataset import make_tabular_datasets
from torch_geometric.loader import DataLoader as GraphDataLoader


def calculate_auc(test_predictions, all_labels):
    '''
    Print and return the ROC AUC of the predictions against the labels.

        test_predictions: one score per datapoint; shapes such as n, nx1 or
            nx1x1 are accepted (trailing singleton dimensions are collapsed)
        all_labels: nx1 binary labels (anomaly = 1)

    Scores are used as given (no sign flip), so a higher score must mean
    "more likely to be labelled 1".

    Returns the value computed by roc_auc_score.
    '''
    scores = np.asarray(test_predictions)
    # roc_auc_score rejects arrays with more than two dimensions, so collapse
    # nx1 / nx1x1 inputs to a flat vector. Genuine multi-column score
    # matrices (size != n) are passed through untouched.
    if scores.ndim > 1 and scores.size == scores.shape[0]:
        scores = scores.reshape(scores.shape[0])

    auc_value = roc_auc_score(all_labels, scores)
    print("AUC:", auc_value)
    return auc_value



def _report_detection_metrics(title, labels, scores):
    '''
    Print threshold / F1 / precision / recall / AUC for one set of scores.

        title: heading printed before the metrics
        labels: binary ground-truth labels, one per score
        scores: anomaly scores; higher means "more likely positive"

    The threshold is the one maximising F1 along the precision-recall curve;
    precision and recall are then recomputed at that same operating point.
    '''
    from sklearn.metrics import (precision_recall_curve, precision_score,
                                 recall_score, roc_auc_score)

    scores = np.asarray(scores)

    print(title)
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    f1_scores = 2*recall*precision/(recall+precision+1e-5)

    # precision/recall carry one trailing entry (precision=1, recall=0) that
    # has no matching threshold, so it is excluded from the search.
    best = int(np.argmax(f1_scores[:-1]))
    best_threshold = thresholds[best]

    print('Best threshold: ', best_threshold)
    print('Best F1-Score: ', f1_scores[best])

    # precision_recall_curve treats a sample as positive when
    # score >= threshold; use the same rule so the precision/recall printed
    # below belong to the operating point whose F1 was just reported.
    thresholded = scores >= best_threshold
    print('Precision:',precision_score(labels,thresholded))
    print('Recall:',recall_score(labels,thresholded))
    print('AUC: ',roc_auc_score(labels,scores))


def main(args):
    '''
    Train an AutoregressiveModel on a tabular task and report test metrics.

        args: parsed command-line namespace; this function reads log_dir,
            model_save_dir, result_dir, task, bsz and model_type, and passes
            the whole namespace to the model as its config.

    Steps: create the output directories, build the train/val/test datasets
    with make_tabular_datasets, train the model, score every test batch with
    model.compute_novelty_score, then print detection metrics twice: once
    for the returned losses and once for percentiles['graph'][0].
    '''
    import tqdm

    # Create directories if not exist.
    for directory in (args.log_dir, args.model_save_dir, args.result_dir):
        os.makedirs(directory, exist_ok=True)

    train_dataset, val_dataset, test_dataset = make_tabular_datasets(name = args.task, train_test_split = 0.5)

    trainloader = GraphDataLoader(train_dataset, batch_size=args.bsz, shuffle=False, drop_last=True)
    valloader = GraphDataLoader(val_dataset, batch_size=args.bsz, shuffle=False)

    model = AutoregressiveModel(train_dataset.num_nodes, train_dataset.node_feature_dim, train_dataset.node_info,model= args.model_type, config = args)
    # NOTE(review): the device and the epoch count are hard-coded here, so
    # --gpu has no effect in this function and a CPU-only host will fail.
    # Left as-is because the model keeps its own `device` attribute whose
    # initialisation is not visible from this file -- confirm before changing.
    model.to('cuda')

    model.train(trainloader, valloader=valloader, epochs=6)

    testloader = GraphDataLoader(test_dataset, batch_size=args.bsz, shuffle=False)

    testlosses = []
    testpercs = []

    for batch in tqdm.tqdm(testloader, total = len(testloader)):
        batch = batch.to(model.device)
        percentiles, loss, _ = model.compute_novelty_score(batch, return_prediction=True)
        testlosses.append(loss.detach().cpu().numpy())
        # presumably percentiles['graph'][0] holds the per-graph percentile
        # scores for this batch -- verify against compute_novelty_score.
        testpercs.append(torch.tensor(percentiles['graph'][0]).detach().cpu().numpy())

    testlosses = np.concatenate(testlosses)
    testpercs = np.concatenate(testpercs)

    _report_detection_metrics('losses:', test_dataset.labels, testlosses)
    _report_detection_metrics("percentiles:", test_dataset.labels, testpercs)



def str2bool(value):
    '''
    Convert a command-line token to a bool for use as an argparse `type`.

    `type=bool` would turn every non-empty string (including "False") into
    True; this parser accepts true/false, t/f, yes/no, y/n and 1/0
    (case-insensitive). An empty string maps to False.

    Raises argparse.ArgumentTypeError for any other token.
    '''
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def build_parser():
    '''
    Build and return the command-line argument parser for this script.
    '''
    parser = argparse.ArgumentParser(description='graph-anomaly-detection')
    # model parameters

    parser.add_argument("--model_type", type=str, default="mlp", help="mlp or GCN or GAT or GraphSAGE")
    parser.add_argument("--n_masks", type=int, default=1, help="number of masks")

    parser.add_argument("--emb-dim", type=int, default=32, help="dimension of string embeddings")

    # training parameters
    parser.add_argument("--mode", type=str, default='train', help="run mode")
    parser.add_argument("--gpu", type=int, default=0, help="gpu")
    parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
    parser.add_argument("--bsz", type=int, default=16,help="batch size")
    parser.add_argument('--use_gamma_generator', type=str2bool, default=False, help='train the generator with the gamma method')
    parser.add_argument('--validation_step', type=int, default=100, help='validation every number of training steps')

    parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this')
    parser.add_argument("--ckpt", type=str, default="checkpoint_gan/")

    parser.add_argument("--net_type",type=str,default="GraphSAGE")

    parser.add_argument("--is_continue",type=str2bool,default=False)
    parser.add_argument("--slide_win",type=int,default=2) # for wadi
    parser.add_argument("--slide_stride",type=int,default=1) # for wadi
    parser.add_argument("--use_data_diff",type=str2bool,default=False)

    parser.add_argument("--task", type=str, default="gridworld", help="what task we are training")
    parser.add_argument('--dropout_ratio', type=float, default=0.0, help='dropout ratio')

    # Directories.
    parser.add_argument("--train_dataset_dir",type=str,default=".")
    parser.add_argument("--val_dataset_dir",type=str,default="./")

    parser.add_argument("--test_dataset_dir",type=str,default=".")
    parser.add_argument('--log_dir', type=str, default='./results')
    parser.add_argument('--model_save_dir', type=str, default='./results')
    parser.add_argument('--result_dir', type=str, default='./results')

    # Step size.
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=1000)
    parser.add_argument('--model_save_step', type=int, default=1000)
    parser.add_argument('--lr_update_step', type=int, default=1000)

    parser.add_argument('--seed',type=int, default=0)
    parser.add_argument('--mname',type=str, default=None)

    return parser


if __name__ == '__main__':

    args = build_parser().parse_args()

    # Seed torch and numpy from --seed before any dataset/model code runs.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    main(args)
