import torch
import json 
from model import *
from dataloader import *
import numpy as np
import argparse
from pytorch_models.utils.logger import *
import random
from sklearn.metrics import classification_report, roc_auc_score, precision_score, recall_score, f1_score
from preprocess import *
from utils import *

def get_params(argv=None):
    """Parse the command-line options for evaluating a trained sequence model.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, exactly as before;
            passing a list makes the parser usable from tests/notebooks.

    Returns:
        argparse.Namespace holding the model, data and result options.
    """
    parser = argparse.ArgumentParser(description='tomita_lstm')
    parser.add_argument('--suffix', default="")
    parser.add_argument('--seed', type=int, default=42)
    # model
    parser.add_argument('--state_dim', type=int, default=32)
    parser.add_argument('--load_model', default="")
    parser.add_argument('--model', default="lstm")
    parser.add_argument('--mtype', default="char_prediction")
    parser.add_argument('--rnn_cell', default=None)
    parser.add_argument('--hinit', type=int, default=0)
    parser.add_argument('--dropout', type=float, default=0)
    parser.add_argument('--nlayers', type=int, default=1)
    # data
    parser.add_argument('--dataset', type=str, default="msi")
    parser.add_argument('--grammar', type=str, default="tomita-1")
    # The evaluation script treats "train" and "val" as splits of the training
    # data; any other value selects the held-out test set.
    parser.add_argument('--dtype', type=str, default="test")
    parser.add_argument('--nbin', type=int, default=0)
    # Integer default to match type=int (a string default only worked because
    # argparse re-parses string defaults through `type`).
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--partial_strings', action="store_true")
    # results
    # Path of the JSON file that evaluation records are appended to.
    parser.add_argument('--result_file', type=str, default=None)
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Imported explicitly: the rest of the file only receives `os` through
    # star imports, which silently breaks if those modules change.
    import os

    session_args = get_params()

    # Seed every RNG in use so evaluation runs are reproducible.
    torch.manual_seed(session_args.seed)
    np.random.seed(session_args.seed)
    random.seed(session_args.seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model_name = "dfa_stability_{}".format(session_args.model)
    suffix = "{0}_{1}_seed_{2}".format(model_name, session_args.suffix, session_args.seed)
    lfname = '../logs/test_{0}.log'.format(suffix)

    logger = create_logger(suffix, lfname, seed = session_args.seed)
    logger.info(session_args)

    #================================================================================================================================
    # Dataset, vocab
    #================================================================================================================================

    vocab = get_vocab(session_args)
    session_args.vocab_size = len(vocab)
    session_args.valid_batch_size = session_args.batch_size

    # "train" / "val" evaluate on the corresponding split returned by
    # load_dataset; any other --dtype evaluates on the held-out test set.
    if session_args.dtype == "train":
        test_data, _ = load_dataset(session_args, logger)
    elif session_args.dtype == "val":
        _, test_data = load_dataset(session_args, logger)
    else:
        test_data = load_test_dataset(session_args, logger)

    #================================================================================================================================
    # Model
    #================================================================================================================================

    session_args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model, model_iter = load_model_for_test(session_args)
    # Evaluate in inference mode so any dropout layers are disabled. The
    # isinstance guard keeps this safe if the loader returns a non-Module.
    if isinstance(model, torch.nn.Module):
        model.eval()

    _n_params = model.count_parameters()
    logger.info("Model Size : {}".format(_n_params))

    #================================================================================================================================
    # Results
    #================================================================================================================================

    result_doc = {
       "partial strings": session_args.partial_strings,
       "model_path" : session_args.load_model,
       "n_parameters": int(_n_params),
       "seed" : session_args.seed,
       "state_dim" : session_args.state_dim,
       "dataset": session_args.dataset,
       "grammar": session_args.grammar,
       "dtype": session_args.dtype,
       "nbin":session_args.nbin,
       "mtype":session_args.mtype,
       "model": session_args.model,
       "rnn_cell": session_args.rnn_cell,
       "model_iter": int(model_iter),
       "result":{}
    }

    true_y = []
    pred_y = []
    score_y = []

    with torch.no_grad():
        for batch_data in test_data:
            logits, y = model.run_model(batch_data)
            # Flatten everything to 1-D: squeeze() would turn a size-1 batch
            # into a 0-d tensor (which torch.cat rejects) and could leave y and
            # the predictions with mismatched shapes.
            scores = logits.detach().reshape(-1).cpu()
            true_y.append(y.detach().reshape(-1).cpu())
            # NOTE(review): rounding thresholds at 0.5, which assumes run_model
            # returns probabilities in [0, 1] despite the name "logits" — confirm.
            pred_y.append(scores.round())
            score_y.append(scores)

    if not true_y:
        raise RuntimeError("Evaluation data yielded no batches; nothing to score.")

    true_y = torch.cat(true_y).numpy()
    pred_y = torch.cat(pred_y).numpy()
    score_y = torch.cat(score_y).numpy()

    accuracy = float(np.mean(true_y == pred_y))

    precision = precision_score(true_y, pred_y)
    recall = recall_score(true_y, pred_y)
    f1 = f1_score(true_y, pred_y)
    report = classification_report(true_y, pred_y)

    # ROC AUC is a ranking metric, so it is computed from the continuous
    # scores rather than the rounded 0/1 predictions. It is undefined (sklearn
    # raises) when only one class occurs in the labels.
    if np.unique(true_y).size < 2:
        roc_auc = None
        logger.info("Only one class present in the labels; ROC AUC is undefined.")
    else:
        roc_auc = float(roc_auc_score(true_y, score_y))

    print(report)
    print("roc_auc : ", roc_auc)
    print("avg : ", accuracy)

    result_doc["result"] = {
        "accuracy" : accuracy,
        "roc_auc" : roc_auc,
        "recall" : float(recall),
        "precision" : float(precision),
        "f1" : float(f1),
    }

    print(result_doc)

    #================================================================================================================================
    # Save results to file
    #================================================================================================================================

    # --result_file defaults to None; os.path.exists(None) would raise, so
    # saving is skipped explicitly in that case.
    if session_args.result_file is None:
        logger.info("No --result_file given; results were not saved.")
    else:
        # Append this run's record to the JSON list in the file, creating it
        # if it does not exist yet.
        if os.path.exists(session_args.result_file):
            with open(session_args.result_file, "r") as ff:
                res_data = json.load(ff)
            res_data.append(result_doc)
        else:
            res_data = [result_doc]

        with open(session_args.result_file, "w") as ff:
            json.dump(res_data, ff)






