import argparse, os
from dataset.gamedata import GameDataset
from dataset.json_graph import JsonToGraph
# from torch_geometric.data.dataloader import DataLoader as GraphDataLoader
from torch_geometric.loader import DataLoader as GraphDataLoader
import torch
import numpy as np
from sklearn.metrics import roc_auc_score
from Autoregressive_model import AutoregressiveModel
# from gan_model import GANModel
from train import train
from test_games import test

def calculate_auc(test_predictions, all_labels):
    '''
    Calculate, print and return the ROC AUC of per-datapoint anomaly scores.

        test_predictions: sequence of n predictions; the first entry of each
            item (``item[0]``) is taken as that datapoint's score
            (e.g. items of shape 1x1 or 1).
        all_labels: length-n binary labels (anomaly = 1)

    Returns the AUC as computed by sklearn's roc_auc_score. It is also printed
    as "AUC: <value>", matching the previous output.
    '''
    # Take the first entry of each prediction and flatten to a 1-D score
    # vector, the shape roc_auc_score documents for binary y_score.
    all_preds = np.array([i[0] for i in test_predictions]).ravel()
    # NOTE(review): scores are used as-is (higher score => treated as more
    # anomalous). An older comment claimed a negation "since normal data gets
    # high scores", but no negation is applied — confirm the score direction
    # produced by the model.
    auc = roc_auc_score(all_labels, all_preds)
    print("AUC:", auc)
    return auc

def load_model_configs(model_type, config):
    '''
    Load the JSON hyper-parameter file for one model variant.

        model_type: model name used as the file-name prefix (e.g. "mlp")
        config: namespace exposing ``n_masks`` and ``task`` attributes
    Returns the parsed JSON content of
    ./models/model_configs/<model_type>_<n_masks>_<task>.json
    '''
    import json

    config_path = "./models/model_configs/{}_{}_{}.json".format(
        model_type, config.n_masks, config.task)
    with open(config_path, 'r') as handle:
        return json.load(handle)

def main(args):
    '''
    Script entry point: create output directories and, in 'train' mode,
    build the train/validation graph loaders and train an AutoregressiveModel.

        args: parsed argparse namespace (flags defined in the __main__ block).
    '''
    # Create output directories if they do not exist. exist_ok avoids the
    # check-then-create race and tolerates several flags pointing at the same
    # path (all three default to ./results).
    for directory in (args.log_dir, args.model_save_dir, args.result_dir):
        os.makedirs(directory, exist_ok=True)

    if args.mode == 'train':
        # ignore_intermediate_nodes is the inverse of --use_json_graph.
        dataset = GameDataset(data_path=args.train_dataset_dir, concat_steps=args.winsize, mode='training', ignore_intermediate_nodes=not args.use_json_graph, task=args.task)
        print('size', len(dataset))
        trainloader = GraphDataLoader(dataset, batch_size=args.bsz, shuffle=False)

        val_dataset = GameDataset(data_path=args.val_dataset_dir, concat_steps=args.winsize, mode='validation', ignore_intermediate_nodes=not args.use_json_graph, task=args.task)
        valloader = GraphDataLoader(val_dataset, batch_size=args.bsz, shuffle=False)

        hyper_params = load_model_configs(args.model_type, args)
        model = AutoregressiveModel(dataset.num_nodes, dataset.node_feature_dim, dataset.node_info, model_type=args.model_type, config=args, device=args.gpu, train_dataset=dataset.all_node_array, hyper_params=hyper_params)
        # Honour --gpu: the model is configured with device=args.gpu, so move
        # its parameters to that same CUDA device instead of the default 'cuda'.
        model.to(f'cuda:{args.gpu}')

        # NOTE: the epoch count is fixed here; there is no CLI flag for it.
        train(model, trainloader, valloader=valloader, epochs=10)

        # test(args)
    else:
        print(f"Mode {args.mode!r} is not implemented; nothing to do.")

def str2bool(value):
    '''
    Convert a command-line string to a bool, for use as argparse ``type=``.

    ``type=bool`` is a trap: argparse calls ``bool("False")``, which is True
    for any non-empty string. This accepts true/false, t/f, yes/no, y/n and
    1/0 (case-insensitive, surrounding whitespace ignored).

        value: the raw command-line string (a bool is passed through unchanged)
    Returns the parsed bool; raises argparse.ArgumentTypeError otherwise.
    '''
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='graph-anomaly-detection')
    # model parameters

    parser.add_argument("--model_type", type=str, default="mlp", help="mlp or GCN or GAT or GraphSAGE")
    parser.add_argument("--n_masks", type=int, default=1, help="number of masks")
    parser.add_argument("--winsize", type=int, default=10, help="window size")
    parser.add_argument("--reconstructing", type=str2bool, default=False, help="Whether the model acts as an encoder")

    parser.add_argument("--emb-dim", type=int, default=32, help="dimension of string embeddings")

    # training parameters
    parser.add_argument("--mode", type=str, default='train', help="run mode (only 'train' is currently implemented)")
    parser.add_argument("--gpu", type=int, default=0, help="gpu")
    parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
    parser.add_argument("--bsz", type=int, default=16, help="batch size")
    parser.add_argument('--use_gamma_generator', type=str2bool, default=False, help='train the generator with the gamma method')
    parser.add_argument('--validation_step', type=int, default=100, help='validation every number of training steps')
    parser.add_argument('--use_json_graph', type=str2bool, default=False, help='Whether to use graph structure derived from json')

    parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this')
    parser.add_argument("--ckpt", type=str, default="checkpoint_gan/")

    parser.add_argument("--net_type", type=str, default="GraphSAGE")

    parser.add_argument("--is_continue", type=str2bool, default=False)
    parser.add_argument("--slide_win", type=int, default=2)  # for wadi
    parser.add_argument("--slide_stride", type=int, default=1)  # for wadi
    parser.add_argument("--use_data_diff", type=str2bool, default=False)

    parser.add_argument("--task", type=str, default="gridworld", help="what task we are training")
    parser.add_argument('--dropout_ratio', type=float, default=0.0, help='dropout ratio')

    # Directories.
    parser.add_argument("--train_dataset_dir", type=str, default=".")
    parser.add_argument("--val_dataset_dir", type=str, default="./")

    parser.add_argument("--test_dataset_dir", type=str, default=".")
    parser.add_argument('--log_dir', type=str, default='./results')
    parser.add_argument('--model_save_dir', type=str, default='./results')
    parser.add_argument('--result_dir', type=str, default='./results')
    parser.add_argument('--inject_novelty_type', type=str, default='cash_change')

    # Step size.
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=1000)
    parser.add_argument('--model_save_step', type=int, default=1000)
    parser.add_argument('--lr_update_step', type=int, default=1000)

    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--mname', type=str, default=None)

    args = parser.parse_args()

    # Seed the RNGs used by torch and numpy for reproducible runs.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    main(args)
