import pickle
import transformers
from lm import nn_generate, conditional_nn_generate
from dyck import RandomWalkDyck
import copy
import os
import argparse
import time
from config_for_dyck import config


def eval_completion_accuracy(dyck, model, prefix_list, batch_size = 32, n_print = 20):
    """Complete each prefix with a single model rollout and return the fraction accepted.

    Every prefix in ``prefix_list`` is completed once by ``conditional_nn_generate``
    (top_p=1.0, temperature=1.0). A completion counts as correct when
    ``dyck.accept`` returns a truthy value for its ``'tokens'``.

    NOTE: per the original note, ``conditional_nn_generate`` should only need
    ``dyck`` for ``dyck.length`` (the completion length). This function also uses
    ``dyck.accept`` and ``dyck.detokenize``.

    Args:
        dyck: Dyck-language object providing ``accept`` and ``detokenize``.
        model: model passed through to ``conditional_nn_generate``.
        prefix_list: non-empty list of token prefixes to complete.
        batch_size: generation batch size.
        n_print: how many of the first completions to print (with their
            acceptance result) for manual inspection.

    Returns:
        ``num_correct / len(prefix_list)``.

    Raises:
        ValueError: if ``prefix_list`` is empty. This is checked before any
            generation is attempted.
    """
    n_total = len(prefix_list)
    if n_total == 0:
        raise ValueError("eval_completion_accuracy: prefix_list is empty")
    completions = conditional_nn_generate(dyck, model, prefix_list, batch_size = batch_size, top_p=1.0, temperature=1.0)
    num_correct = 0
    for idx, seq in enumerate(completions):
        tokens = seq['tokens']
        # Evaluate acceptance once and reuse it for both counting and printing.
        accepted = dyck.accept(tokens)
        if accepted:
            num_correct += 1
        if idx < n_print:
            print(dyck.detokenize(tokens), accepted)
    # The denominator is the number of prefixes requested, as before.
    return num_correct / n_total

def generate_prefixes(dyck, prefix_length, n):
    """Sample ``n`` sequences from ``dyck`` and truncate each to its first ``prefix_length + 1`` tokens.

    ``dyck.generate()`` is called once per prefix, in order, and its ``'tokens'``
    entry is sliced. NOTE(review): the slice keeps ``prefix_length + 1`` tokens.
    TODO: confirm whether the extra token is a start marker in RandomWalkDyck.

    Returns:
        A list of ``n`` token slices.
    """
    return [dyck.generate()['tokens'][: prefix_length + 1] for _ in range(n)]

def eval_model_ood(dyck, model, prefix_length, n_eval = 1000):
    """Score ``model`` on ``n_eval`` freshly sampled prefixes of the given length.

    Prefixes come from ``generate_prefixes``, so they follow whatever distribution
    ``dyck`` was configured with. Scoring uses ``eval_completion_accuracy`` with its
    default batch size.
    """
    return eval_completion_accuracy(
        dyck,
        model,
        generate_prefixes(dyck, prefix_length, n_eval),
    )

def generate_error_prefixes(dyck, model, prefix_length, n_final, t_test = 1, single_prefix=False, initial_checkpoint = None, checkpoint_save_path = None, batch_size = 128):
    """Collect ``n_final`` prefixes on which ``model`` fails to produce an accepted completion.

    Prefixes are sampled ``batch_size`` at a time with ``generate_prefixes``. Each
    batch goes through up to ``t_test`` filtering rounds. In each round, every
    remaining prefix gets one rollout from ``conditional_nn_generate``, and only
    prefixes whose rollout is rejected by ``dyck.accept`` are kept. Survivors are
    appended to the result. There is no iteration cap, so the loop does not
    terminate if the model never fails.

    Args:
        dyck: Dyck-language object (sampling and acceptance).
        model: model passed through to ``conditional_nn_generate``.
        prefix_length: forwarded to ``generate_prefixes``.
        n_final: number of prefixes to return.
        t_test: number of independent rollouts a prefix must fail.
        single_prefix: if True, return ``n_final`` copies of the first collected
            prefix as soon as one exists.
        initial_checkpoint: previously collected prefixes to resume from. The list
            is copied and never mutated. ``None`` means start empty.
        checkpoint_save_path: if given, the collected list is pickled there each
            time it has grown by at least 500 since the last save.
        batch_size: number of prefixes sampled per round; also used as the
            generation batch size.

    Returns:
        A list of exactly ``n_final`` prefixes (for ``n_final >= 0``).
    """
    # Copy so neither a shared default nor the caller's list is mutated by extend().
    prefix_list = list(initial_checkpoint) if initial_checkpoint is not None else []
    last_checkpoint_length = len(prefix_list)
    while len(prefix_list) < n_final:
        batch_list = generate_prefixes(dyck, prefix_length, batch_size)
        for _ in range(t_test):
            if not batch_list:
                break
            completions = conditional_nn_generate(dyck, model, batch_list, batch_size = batch_size, top_p=1.0, temperature=1.0)
            batch_list = [prefix for (prefix, full) in zip(batch_list, completions) if not dyck.accept(full['tokens'])]
        prefix_list.extend(batch_list)
        # Report progress whenever the count crosses a multiple of 10.
        if len(prefix_list) // 10 > (len(prefix_list) - len(batch_list)) // 10:
            print(f"Generated {len(prefix_list)} out of {n_final} prefixes", flush=True)
        if (checkpoint_save_path is not None) and len(prefix_list) >= last_checkpoint_length + 500:
            corpus_dir = os.path.dirname(checkpoint_save_path)
            # dirname is '' for a bare filename, and os.makedirs('') raises.
            if corpus_dir:
                os.makedirs(corpus_dir, exist_ok=True)
            with open(checkpoint_save_path, 'wb') as f:
                print(f"Saving checkpoint of length {len(prefix_list)} to {checkpoint_save_path}")
                pickle.dump(prefix_list, f)
            last_checkpoint_length = len(prefix_list)
        if single_prefix and prefix_list:
            # Reaching here implies n_final >= 1, so this yields exactly n_final copies.
            return [prefix_list[0]] * n_final
    return prefix_list[:n_final]

def load_prefix_dataset(dyck, model, save_path, prefix_length, n_prefixes, t_test=1, single_prefix=False, shortcut_load_path="None"):
    """Load the prefix dataset at ``save_path``, or build it and save it there.

    Resolution order:

    1. If ``save_path`` exists, unpickle it and return it unchanged (its length
       is not checked against ``n_prefixes``).
    2. Otherwise, if ``shortcut_load_path`` is not the string ``"None"``, unpickle
       that list and fit it to ``n_prefixes``. Longer lists are truncated; shorter
       ones are padded by repeating the first element. Save the result to
       ``save_path`` and return it.
    3. Otherwise, mine error prefixes with ``generate_error_prefixes``. Resume from
       the checkpoint ``save_path + "tmp"`` if it exists, then save the result to
       ``save_path`` and return it.

    Missing parent directories of ``save_path`` are created before saving.

    NOTE: ``pickle.load`` can execute arbitrary code. Only point this at trusted
    files.

    Returns:
        A list of token prefixes.

    Raises:
        ValueError: if the shortcut dataset is empty but ``n_prefixes`` > 0.
    """
    if os.path.isfile(save_path):
        with open(save_path, 'rb') as f:
            print(f"Loading prefix dataset from {save_path}")
            prefix_list = pickle.load(f)
        return prefix_list

    corpus_dir = os.path.dirname(save_path)

    if shortcut_load_path != "None":
        with open(shortcut_load_path, 'rb') as f:
            print(f"Loading shortcut dataset from {shortcut_load_path}")
            prefix_list = pickle.load(f)
        # A shortcut shorter than requested is padded rather than rejected.
        if len(prefix_list) >= n_prefixes:
            print(f"Cutting prefix list to size {n_prefixes}")
            prefix_list = prefix_list[:n_prefixes]
        else:
            if not prefix_list:
                raise ValueError(
                    f"Shortcut dataset {shortcut_load_path} is empty; "
                    f"cannot extend it to size {n_prefixes}"
                )
            print(f"Extending prefix list (by duplication of first element) to size {n_prefixes}")
            prefix_list.extend([prefix_list[0]] * (n_prefixes - len(prefix_list)))
        # The target folder may not exist yet. dirname is '' for a bare filename.
        if corpus_dir:
            os.makedirs(corpus_dir, exist_ok=True)
        with open(save_path, 'wb') as f:
            print(f"Saving prefix dataset of length {len(prefix_list)} to {save_path}")
            pickle.dump(prefix_list, f)
        return prefix_list

    checkpoint_path = save_path + "tmp"
    checkpoint = []
    if os.path.isfile(checkpoint_path):
        with open(checkpoint_path, 'rb') as f:
            print(f"Loading checkpoint from {checkpoint_path}")
            checkpoint = pickle.load(f)
        print(f"Length of checkpoint dataset is {len(checkpoint)}")
    prefix_list = generate_error_prefixes(dyck, model, prefix_length, n_prefixes, t_test, single_prefix=single_prefix, initial_checkpoint = checkpoint, checkpoint_save_path = checkpoint_path)
    print(f"Generated prefix dataset of length {len(prefix_list)}. First 10 samples:")
    # Slice instead of indexing so fewer than 10 prefixes does not raise IndexError.
    for prefix in prefix_list[:10]:
        print(dyck.detokenize(prefix))
    if corpus_dir:
        os.makedirs(corpus_dir, exist_ok=True)
    with open(save_path, 'wb') as f:
        print(f"Saving prefix dataset to {save_path}")
        pickle.dump(prefix_list, f)
    return prefix_list


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # model_path and prefix_folder_name are always used below, so require them up front.
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--prefix_folder_name", type=str, required=True)
    # Only needed when prefixes must be generated (no existing dataset, shortcut or override).
    parser.add_argument("--prefix_length", type=int)
    parser.add_argument("--n", type=int, default=1000)
    parser.add_argument("--t_tests", type=int, default=1)
    parser.add_argument("--single_prefix", action='store_true')
    parser.add_argument("--shortcut_load_path", type=str, default="None")
    parser.add_argument("--override_str", type=str, default="None")
    args = parser.parse_args()
    # A fixed override prefix is only meaningful in single-prefix mode.
    # Use parser.error rather than assert so the check survives `python -O`.
    if args.override_str != "None" and not args.single_prefix:
        parser.error("--override_str requires --single_prefix")

    print("Transformers version", transformers.__version__)

    # Identify the training run from the unique path component containing "dyck".
    # This is validated before the (potentially expensive) model load.
    path_components = [p for p in args.model_path.split("/") if "dyck" in p]
    if len(path_components) != 1:
        parser.error(
            f"expected exactly one component of --model_path containing 'dyck', "
            f"found {len(path_components)}"
        )
    training_string = path_components[0]
    # The Dyck configuration comes from config_for_dyck. It was formerly parsed from
    # the '_'-separated components of training_string ('m<max_depth>', 'len<length>').

    # Load model. NOTE: pickle can execute arbitrary code; only load trusted model files.
    with open(args.model_path, 'rb') as f:
        model = pickle.load(f)

    print("Config", config)

    dyck = RandomWalkDyck(config)

    # Construct dataset save path
    model_string = training_string.replace("dyck", "model")
    base_path = "./prefix_datasets/" + model_string + "/" + args.prefix_folder_name + "/n" + str(args.n) + "_t" + str(args.t_tests)
    if not args.single_prefix:
        dataset_path = base_path + ".pkl"
    elif args.override_str == "None":
        dataset_path = base_path + "_single.pkl"
    else:
        # Never overwrite an earlier override dataset: try _single_override.pkl,
        # then _single_override2.pkl, _single_override3.pkl, ...
        dataset_path = base_path + "_single_override.pkl"
        i = 1
        while os.path.exists(dataset_path):
            i += 1
            dataset_path = base_path + f"_single_override{i}.pkl"

    print("Dataset save path", dataset_path)

    if args.override_str == "None":
        # Generate and save prefix dataset (or load if it already exists)
        prefix_list = load_prefix_dataset(dyck, model, dataset_path, args.prefix_length, args.n, args.t_tests, args.single_prefix, args.shortcut_load_path)
    else:
        prefix = dyck.tokenize(args.override_str)
        prefix_list = [prefix] * args.n
        corpus_dir = os.path.dirname(dataset_path)
        os.makedirs(corpus_dir, exist_ok=True)
        with open(dataset_path, 'wb') as f:
            pickle.dump(prefix_list, f)

    print("First 10 samples:")
    # Slice so that --n < 10 does not raise IndexError.
    for prefix in prefix_list[:10]:
        print(dyck.detokenize(prefix))

    # Evaluate accuracy of model on (at most the first 1000 of) prefix_list
    start_time = time.perf_counter()
    acc = eval_completion_accuracy(dyck, model, prefix_list[:1000], batch_size = 32)
    end_time = time.perf_counter()
    print(f"Accuracy on surviving prefixes is {acc}")
    print(f"Evaluation time with batch size 32 is {end_time - start_time:.6f} seconds")


    #start_time = time.perf_counter()
    #acc = eval_completion_accuracy(dyck, model, prefix_list, batch_size = 1)
    #end_time = time.perf_counter()
    #print(f"Accuracy on surviving prefixes is {acc}")
    #print(f"Evaluation time with batch size 1 is {end_time - start_time:.6f} seconds")



#config = {'length': 32, 'num_types': 2, 'max_depth': 12, 'type_probs': [0.8, 0.2]}

#prefix_list = load_prefix_dataset(config, model, "./prefix_datasets/model_k2_0.2_0.8_m12_len32/prefix_0.8_0.2_m12_len16/n1000_t1.pkl",16, 1000, t_test = 1)

#acc = eval_completion_accuracy(dyck, model, prefix_list)
        
#print(f"Accuracy on surviving prefixes is {acc}")



#output = nn_generate(dyck, model, top_p=1.0, temperature=1.0, use_template_implementation=False)
#for seq in output:
#    print(dyck.detokenize(seq['tokens']))
