import argparse
from preprocess import *
from extract_dfa import *
from utils import *
from lstar_dfa_utils import *
import json

def get_alphabet(grammar):
    """Return the input alphabet string for a grammar name.

    * Names starting with ``"tomita"`` use the binary alphabet ``"01"``.
    * Names starting with ``"dyck"`` read the number of bracket pairs from the
      single character at index 5 (e.g. ``"dyck-2"`` -> 2 pairs) and return the
      concatenation of that many pairs, in a fixed order.
    * Any other name yields ``None``.
    """
    if grammar.startswith("tomita"):
        return "01"
    if grammar.startswith("dyck"):
        pair_count = int(grammar[5])
        bracket_pairs = ("()", "[]", "{}", "<>", "+-", "ab", "xo", "@#")
        return "".join(bracket_pairs[:pair_count])
    return None

def get_params(argv=None):
    """Build the command-line parser for this evaluation script and parse it.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default, used by the script
            entry point) makes argparse read ``sys.argv[1:]``; passing a list
            lets other code or tests reuse the parser without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed options. ``--dfa_path`` and
        ``--result_file`` are required; everything else has a default.
    """
    parser = argparse.ArgumentParser(description='tomita_lstm')
    # Run identification: these feed the log-file name and the RNG seeds.
    parser.add_argument('--suffix', default="")
    parser.add_argument('--seed', type = int, default=1234)
    parser.add_argument('--dfa_path', type=str, required = True)
    parser.add_argument('--result_file', type = str, required = True)

    # Model description; recorded in the result document.
    parser.add_argument('--state_dim', type=int, default=32)
    parser.add_argument('--model', default="lstm")
    parser.add_argument('--mtype', default="char_prediction")
    parser.add_argument('--rnn_cell', default=None)

    # Dataset selection. The __main__ block treats any --dtype other than
    # "train"/"val" as the test split.
    parser.add_argument('--dataset', type = str, default="msi")
    parser.add_argument('--grammar', type = str, default="tomita-1")
    parser.add_argument('--dtype', type = str, default="test")
    parser.add_argument('--nbin', type=int, default=0)

    # Sequence-preprocessing switches, forwarded to words_from_dataset.
    parser.add_argument('--remove_start', action = "store_true")
    parser.add_argument('--remove_end', action = "store_true")
    parser.add_argument('--shuffle_data', action = "store_true")

    args = parser.parse_args(argv)
    return args

def _label_counts(labels):
    """Return ``(n_positive, n_negative)`` for a sequence of binary labels.

    The labels are cast to ``bool`` before negating. On an integer tensor,
    ``~`` is a bitwise NOT (``~0 == -1``, ``~1 == -2``), which would make the
    negative count meaningless if the labels ever arrive as 0/1 integers
    instead of booleans.
    NOTE(review): this assumes labels are binary (bool or 0/1); confirm
    against words_from_dataset.
    """
    y = torch.tensor(labels, dtype=torch.bool)
    return int(y.sum().item()), int((~y).sum().item())


def _append_result(path, record):
    """Append ``record`` to the JSON list stored at ``path``.

    If ``path`` does not exist, a new file holding ``[record]`` is written.
    The updated list is first written to ``<path>.tmp`` and then moved over
    ``path`` with ``os.replace``. If serialization fails part-way, the
    previously saved results are left intact instead of being truncated.

    Raises:
        json.JSONDecodeError: if an existing file at ``path`` is not valid JSON.
    """
    import os

    if os.path.exists(path):
        # Read-only access is enough; the rewrite goes through the temp file.
        with open(path, "r") as fh:
            records = json.load(fh)
    else:
        records = []
    records.append(record)

    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "w") as fh:
            json.dump(records, fh)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a half-written temp file behind; the original is untouched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


if __name__ == "__main__":
    session_args = get_params()

    # Seed every RNG used by the loaders/evaluation so runs are reproducible.
    torch.manual_seed(session_args.seed)
    np.random.seed(session_args.seed)
    random.seed(session_args.seed)

    torch.backends.cudnn.deterministic = True
    #torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

    # Log-file name encodes the model, optional RNN cell, user suffix and seed.
    model_name = "dfa_lstar_{}".format(session_args.model)
    if session_args.rnn_cell is not None:
        suffix = "{0}_{1}_{2}_seed_{3}".format(model_name, session_args.rnn_cell, session_args.suffix, session_args.seed)
    else:
        suffix = "{0}_{1}_seed_{2}".format(model_name, session_args.suffix, session_args.seed)
    lfname = '../logs/test_{0}.log'.format(suffix)

    logger = create_logger(suffix, lfname, seed = session_args.seed, resume = False )

    #================================================================================================================================
    # Dataset
    #================================================================================================================================
    vocab = get_vocab(session_args)

    vocab_size = len(vocab)
    session_args.vocab_size = vocab_size
    # A single large batch; the loaders read these from session_args.
    session_args.batch_size = 10000
    session_args.valid_batch_size = session_args.batch_size

    # "train"/"val" evaluate on the corresponding split; anything else uses the test set.
    if session_args.dtype == "train":
        train_data, val_data = load_dataset(session_args, logger)
        test_data = train_data
    elif session_args.dtype == "val":
        train_data, val_data = load_dataset(session_args, logger)
        test_data = val_data
    else:
        test_data = load_test_dataset(session_args, logger)

    test_words, test_y = words_from_dataset(test_data, vocab, remove_start=session_args.remove_start, remove_end=session_args.remove_end)
    alphabet = get_alphabet(session_args.grammar)

    print(alphabet)
    #================================================================================================================================
    # Load DFA
    #================================================================================================================================

    dfa = load_lstar_dfa(session_args.dfa_path)
    precision, recall, accuracy, f1, fp, fn = test_lstar_dfa(words=test_words, y_true=test_y, dfa=dfa)

    result_doc = {
        "dfa" : session_args.dfa_path,
        "seed" : session_args.seed,
        "n_states" : len(dfa.Q),
        "state_dim" : session_args.state_dim,
        "dataset": session_args.dataset,
        "grammar": session_args.grammar,
        "dtype": session_args.dtype,
        "nbin":session_args.nbin,
        "mtype":session_args.mtype,
        "model": session_args.model,
        "rnn_cell": session_args.rnn_cell
    }

    total_pos, total_neg = _label_counts(test_y)
    # Cast to plain Python numbers so the document is JSON-serializable.
    result_doc["result"] = {
        "precision" :float(precision),
        "recall" : float(recall),
        "accuracy" : float(accuracy),
        "f1": float(f1),
        "fp": int(fp),
        "fn": int(fn),
        "total_pos" : total_pos,
        "total_neg" : total_neg
    }

    print(result_doc)

    #================================================================================================================================
    # Save results to file
    #================================================================================================================================

    # The results file accumulates one document per run as a JSON list.
    _append_result(session_args.result_file, result_doc)