import torch
from model import *
from dataloader import *
import numpy as np
import argparse
from pytorch_models.utils.logger import *
import random
from preprocess import *
from lstar_extraction.Extraction import extract
from utils import *
from lstar_dfa_utils import *

# Starting example strings for the Tomita grammars, keyed by grammar number.
# get_starting_examples() picks an entry using the trailing digit of the
# grammar name (e.g. "tomita-3" -> 3).
tomita_examples = {
    1: ["", "10", "01"],
    2: ["", "11", "00"],
    3: ["", "10", "0011"],
    4: ["", "0001", "1000"],
    5: ["", "1", "1110"],
    6: ["", "0", "1"],
    7: ["", "01", "101010"],
}

# Starting example strings for the Dyck grammars, keyed by the trailing digit
# of the grammar name. Only the keys 2, 3, 6 and 8 have entries.
dyck_examples = {
    2: ["([])", "(())", "[[]]", "(]", "))"],
    3: ["({[]})", "((()))", "[{}]", "({}", "}{[", "(]["],
    6: ["++--", "<+->", "[{<>}]", "++-}", "+-{>", "<)"],
    8: ["xxoo", "<<xxoo>>", "({@#})", "xa@", ">+b", "{#)"],
}

def get_starting_examples(grammar):
    """Return the starting example strings registered for ``grammar``.

    The grammar family is chosen from the name's prefix ("tomita" or "dyck")
    and the entry is selected with the integer value of the name's last
    character.

    Args:
        grammar: Grammar name, e.g. "tomita-1".

    Returns:
        The list stored in ``tomita_examples`` / ``dyck_examples`` for that
        grammar, or None when the name starts with neither prefix.

    Raises:
        ValueError: If the last character of the name is not a digit.
        KeyError: If the selected table has no entry for that number.
    """
    if grammar.startswith("tomita"):
        example_table = tomita_examples
    elif grammar.startswith("dyck"):
        example_table = dyck_examples
    else:
        # Unknown grammar family: no starting examples.
        return None

    return example_table[int(grammar[-1])]


def get_alphabet(grammar):
    """Return the alphabet string for ``grammar``.

    Args:
        grammar: Grammar name. Names starting with "tomita" use the binary
            alphabet; for names starting with "dyck" the character at index 5
            gives the number of bracket pairs to include.

    Returns:
        "01" for Tomita grammars, the concatenation of the first ``num``
        bracket pairs for Dyck grammars, or None for any other name.

    Raises:
        IndexError: If a "dyck" name is shorter than six characters.
        ValueError: If the character at index 5 is not a digit.
    """
    if grammar.startswith("tomita"):
        return "01"
    if not grammar.startswith("dyck"):
        return None

    pair_count = int(grammar[5])
    bracket_pairs = ('()', '[]', '{}', '<>', '+-', 'ab', 'xo', '@#')
    return "".join(bracket_pairs[:pair_count])

def get_params(argv=None):
    """Parse the options of the L* DFA-extraction script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. With the default (None) argparse reads
            ``sys.argv[1:]``, so existing ``get_params()`` callers are
            unaffected.

    Returns:
        argparse.Namespace holding one attribute per option declared below.
    """
    parser = argparse.ArgumentParser(description='tomita_lstm')

    # Run identification / reproducibility.
    parser.add_argument('--suffix', default="")
    parser.add_argument('--seed', type=int, default=42)

    # Model options.
    parser.add_argument('--state_dim', type=int, default=32)
    parser.add_argument('--load_model', default="")
    parser.add_argument('--model', default="rnn")
    parser.add_argument('--rnn_cell', default=None)
    parser.add_argument('--hinit', type=int, default=0)
    parser.add_argument('--mtype', default="char_prediction")
    # NOTE: the default stays the int 0 (not 0.0) so that any value derived
    # from it elsewhere is unchanged.
    parser.add_argument('--dropout', type=float, default=0)
    parser.add_argument('--nlayers', type=int, default=1)

    # Data options.
    parser.add_argument('--dataset', type=str, default="msi")
    parser.add_argument('--grammar', type=str, default="tomita-1")
    # parser.add_argument('--dfa_min_state', type = int, default=1)
    # parser.add_argument('--dfa_max_state', type = int, default=100)
    # parser.add_argument('--dfa_state_step', type = int, default=1)
    # parser.add_argument('--cluster_method', type = str, default="som")
    # The default is an int literal; the original string "1024" only worked
    # because argparse converts string defaults through ``type``.
    parser.add_argument('--batch_size', type=int, default=1024)
    # parser.add_argument('--valid_batch_size', type = int, default = -1)
    # parser.add_argument('--normalize_states', action = "store_true")
    # parser.add_argument('--tsne', action = "store_true")
    # parser.add_argument('--tsne_fname', type = str, default = None)
    # parser.add_argument('--pos_data_only', action = "store_true")
    parser.add_argument('--partial_strings', action="store_true")
    parser.add_argument('--remove_start', action="store_true")
    parser.add_argument('--remove_end', action="store_true")
    parser.add_argument('--shuffle_data', action="store_true")

    return parser.parse_args(argv)

if __name__ == "__main__":
    session_args = get_params()

    # Seed every random number generator this script touches and pin the
    # cuDNN flags so repeated runs use the same settings.
    seed = session_args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # The run name is shared by the log file and the saved DFA file.
    model_name = "dfa_lstar_" + str(session_args.model)
    suffix = "_".join([model_name, str(session_args.suffix), "seed", str(seed)])

    lfname = '../logs/' + suffix + '.log'
    logger = create_logger(suffix, lfname, seed=seed)

    logger.info(session_args)
    print(session_args)

    # ------------------------------------------------------------------
    # Dataset and vocabulary
    # ------------------------------------------------------------------
    batch_size = session_args.batch_size
    session_args.valid_batch_size = batch_size

    vocab = get_vocab(session_args)
    session_args.vocab_size = len(vocab)

    train_data, val_data = load_dataset(session_args, logger)

    train_words, train_y = words_from_dataset(
        train_data,
        vocab,
        remove_start=session_args.remove_start,
        remove_end=session_args.remove_end,
    )
    val_words, val_y = words_from_dataset(
        val_data,
        vocab,
        remove_start=session_args.remove_start,
        remove_end=session_args.remove_end,
    )

    alphabet = get_alphabet(session_args.grammar)
    starting_examples = get_starting_examples(session_args.grammar)

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    session_args.device = device

    lstar_model, test_model = load_model_for_lstar(session_args, alphabet, vocab)

    _n_params = test_model.count_parameters()
    logger.info("Model Size : {}".format(_n_params))

    # ------------------------------------------------------------------
    # Extract a DFA from the model
    # ------------------------------------------------------------------
    print("alphabet : ", alphabet)
    dfa = extract(
        lstar_model,
        time_limit=50,
        initial_split_depth=10,
        starting_examples=starting_examples,
    )

    # ------------------------------------------------------------------
    # Evaluate the DFA on the train split, then on the val split
    # ------------------------------------------------------------------
    evaluation_splits = (
        ("---------------- on train data --------------", train_words, train_y),
        ("---------------- on val data --------------", val_words, val_y),
    )
    for header, split_words, split_labels in evaluation_splits:
        # The metrics are computed before the header is written to the log.
        precision, recall, accuracy, f1, fp, fn = test_lstar_dfa(
            words=split_words, y_true=split_labels, dfa=dfa)

        logger.info(header)
        for template, value in (
            ("precision : {}", precision),
            ("recall : {}", recall),
            ("accuracy : {}", accuracy),
            ("f1 : {}", f1),
            ("fp : {}", fp),
            ("fn : {}", fn),
        ):
            logger.info(template.format(value))

    # ------------------------------------------------------------------
    # Save the DFA, then read it back from the file just written
    # ------------------------------------------------------------------
    dfa_path = "../lstar_dfa/" + suffix + ".dfa"
    save_lstar_dfa(dfa, dfa_path)

    dfa = load_lstar_dfa(dfa_path)
    #print(dfa)
