import argparse
from preprocess import *
from extract_dfa import *
from utils import *
import json

def get_params(argv=None):
    """Parse the command-line options of the DFA evaluation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on; passing a list allows the parser to
            be driven programmatically.

    Returns:
        argparse.Namespace: The parsed options. ``--dfa_path`` and
        ``--result_file`` are required; every other option has a default.
    """
    parser = argparse.ArgumentParser(description='tomita_lstm')

    # Run identification and input/output locations.
    parser.add_argument('--suffix', default="")
    parser.add_argument('--seed', type=int, default=1234,
                        help="seed applied to the torch, numpy and random RNGs")
    parser.add_argument('--dfa_path', type=str, required=True,
                        help="path of the DFA to load and evaluate")
    parser.add_argument('--result_file', type=str, required=True,
                        help="JSON file the result record is appended to")

    # Description of the model the DFA was extracted from.
    parser.add_argument('--state_dim', type=int, default=32)
    parser.add_argument('--model', default="lstm")
    parser.add_argument('--mtype', default="char_prediction")
    parser.add_argument('--rnn_cell', default=None)

    # Dataset selection and extraction settings.
    parser.add_argument('--dataset', type=str, default="msi")
    parser.add_argument('--grammar', type=str, default="tomita-1")
    parser.add_argument('--dtype', type=str, default="test")
    parser.add_argument('--nbin', type=int, default=0)
    parser.add_argument('--cluster_method', type=str, default="som")

    # Flags controlling which boundary symbols are dropped from the alphabet.
    parser.add_argument('--remove_start', action="store_true")
    parser.add_argument('--remove_end', action="store_true")

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: load one batch of the requested data split, run a
    # previously extracted DFA on it via test_dfa, and append the resulting
    # metrics to the JSON file given by --result_file.
    session_args = get_params()

    # Seed every RNG used here so repeated runs are comparable.
    torch.manual_seed(session_args.seed)
    np.random.seed(session_args.seed)
    random.seed(session_args.seed)

    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    # NOTE: torch.use_deterministic_algorithms(True) was left disabled in the
    # original script and is still not enabled here.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Build the run identifier used for the logger name and log file path.
    model_name = "dfa_stability_{}".format(session_args.model)
    if session_args.rnn_cell is not None:
        suffix = "{0}_{1}_{2}_seed_{3}".format(model_name, session_args.rnn_cell, session_args.suffix, session_args.seed)
    else:
        suffix = "{0}_{1}_seed_{2}".format(model_name, session_args.suffix, session_args.seed)
    lfname = '../logs/train_{0}.log'.format(suffix)

    logger = create_logger(suffix, lfname, seed = session_args.seed, resume = False )

    #================================================================================================================================
    # Dataset
    #================================================================================================================================
    vocab = get_vocab(session_args)

    vocab_size = len(vocab)
    session_args.vocab_size = vocab_size
    # Large fixed batch size; only the first batch is consumed below.
    session_args.batch_size = 20000
    session_args.valid_batch_size = session_args.batch_size

    # Pick the split to evaluate on; anything other than "train"/"val" falls
    # back to the test set.
    if session_args.dtype == "train":
        train_data, val_data = load_dataset(session_args, logger)
        test_data = train_data
    elif session_args.dtype == "val":
        train_data, val_data = load_dataset(session_args, logger)
        test_data = val_data
    else:
        test_data = load_test_dataset(session_args, logger)

    # NOTE(review): only the first batch is evaluated. This presumably assumes
    # the whole split fits in one batch of 20000 -- nothing here checks it.
    input_seq, out_seq, seq_len = next(iter(test_data))

    # Boolean mask that is True only at column (seq_len - 2) of each row, used
    # to pick one target per sequence when out_seq holds per-position targets.
    mask = torch.arange(input_seq.shape[1])
    mask = mask[None] == (seq_len-2)[:, None]
    if out_seq.dim() > 1:
        y = torch.masked_select(out_seq, mask)
    else:
        # out_seq is already one value per sequence.
        y = out_seq

    # Invert the vocab mapping so it is keyed by the vocab's values (the
    # integer ids that are deleted below).
    alphabet = {v : k for k,v in vocab.items()}

    if session_args.grammar.startswith("tomita"):
        # remove PAD
        del alphabet[0]
        # NOTE(review): ids 3 and 4 are presumably the start/end symbols of the
        # tomita vocab; unlike the dyck branch, input_seq and seq_len are left
        # untouched here -- confirm this is what test_dfa expects.
        if session_args.remove_start:
            del alphabet[3]
        if session_args.remove_end:
            del alphabet[4]
    elif session_args.grammar.startswith("dyck"):
        # remove PAD
        del alphabet[0]
        # remove UNK
        del alphabet[3]
        if session_args.remove_start:
            del alphabet[1]
            # Drop the first column of every sequence and shorten the lengths
            # to match (seq_len is modified in place).
            input_seq = input_seq[:, 1:]
            seq_len -= 1
        if session_args.remove_end:
            del alphabet[2]
            # Only the length shrinks; input_seq keeps its trailing columns.
            seq_len -= 1

    print(alphabet)
    #================================================================================================================================
    # Load DFA
    #================================================================================================================================

    dfa = load_dfa(session_args.dfa_path)
    [precision, recall, accuracy, f1, fp, fn] =  test_dfa(logger, data=input_seq.numpy(), inp_len=seq_len.numpy(),
                                                            y_true=y.numpy(), dfa=dfa, alphabet=alphabet)

    # Record the run configuration alongside the metrics so entries in the
    # result file are self-describing.
    result_doc = {
        "dfa" : session_args.dfa_path,
        "seed" : session_args.seed,
        "n_states" : len(dfa.states),
        "state_dim" : session_args.state_dim,
        "dataset": session_args.dataset,
        "grammar": session_args.grammar,
        "dtype": session_args.dtype,
        "nbin":session_args.nbin,
        "mtype":session_args.mtype,
        "model": session_args.model,
        "rnn_cell": session_args.rnn_cell,
        "cluster_method" : session_args.cluster_method
    }

    # Convert to builtin Python numbers so json.dump can serialize them.
    # NOTE(review): total_pos/total_neg presume y holds numeric 0/1 labels.
    result_doc["result"] = {
        "precision" :float(precision),
        "recall" : float(recall),
        "accuracy" : float(accuracy),
        "f1": float(f1),
        "fp": int(fp),
        "fn": int(fn),
        "total_pos" : torch.sum(y).item(),
        "total_neg" : torch.sum(1-y).item()
    }

    print(result_doc)

    #================================================================================================================================
    # Save results to file
    #================================================================================================================================

    # The result file holds a JSON list; append this run to it, creating the
    # list when the file does not exist yet.
    if os.path.exists(session_args.result_file):
        # Read-only mode: reading must not require write permission.
        with open(session_args.result_file, "r") as ff:
            existing = ff.read()

        # An existing but empty file (e.g. one created with `touch`) counts as
        # "no results yet" instead of crashing after the evaluation finished.
        # Malformed content still raises so earlier results are not overwritten.
        res_data = json.loads(existing) if existing.strip() else []
        res_data.append(result_doc)
    else:
        res_data = [result_doc]

    with open(session_args.result_file, "w") as ff:
        json.dump(res_data, ff)