import numpy as np
from lm import conditional_nn_generate
from utils import dyck_reward, one_hot_encode
import argparse
import pickle
import os
from dyck import RandomWalkDyck
from config_for_dyck import config

def generate_training_data_from_prefix_list(dyck, model, prefix_list):
    """Sample one completion per prefix and turn it into per-horizon value targets.

    All prefixes are completed in batches of 32 by ``conditional_nn_generate``.
    Each completed sequence is scored once with ``dyck_reward``, and that single
    reward is paired with every intermediate context ``seq[:A + h]`` for
    h = 1..B.

    Args:
        dyck: Dyck grammar object, passed to the generator and to the reward
            (and used to detokenize the debug printout).
        model: Language model used to sample completions.
        prefix_list: Non-empty list of token prefixes. A is taken from
            ``len(prefix_list[0])``. NOTE(review): all prefixes are presumably
            the same length — confirm with whatever builds the prefix dataset.

    Returns:
        A list of B lists, where B = ``len(completions[0]) - A``. Entry
        ``h - 1`` holds ``(context, reward)`` pairs whose context is the first
        ``A + h`` tokens of a completed sequence.

    Raises:
        ValueError: if ``prefix_list`` is empty.
    """
    if not prefix_list:
        # Without this guard the failure is an opaque IndexError on prefix_list[0].
        raise ValueError("prefix_list must contain at least one prefix")

    completions = conditional_nn_generate(dyck, model, prefix_list, batch_size=32, top_p=1.0, temperature=1.0)
    completions = [seq['tokens'] for seq in completions]
    A = len(prefix_list[0])        # prefix length
    B = len(completions[0]) - A    # number of tokens generated after the prefix
    print(f"Prefix length: {A}; suffix length: {B}")

    training_data = [[] for _ in range(B)]
    num_samples = len(prefix_list)
    for i in range(num_samples):
        seq = completions[i]
        # One terminal reward per completed sequence, shared by all its horizons.
        reward = dyck_reward(seq, dyck)
        for horizon in range(1, B + 1):
            context = seq[:A + horizon]
            training_data[horizon - 1].append((context, reward))
            # Sanity-check printout: the full context for sample indices 0..10.
            if i <= 10 and horizon == B:
                print(f"Sequence: {dyck.detokenize(context)}, reward: {reward}")
    return training_data

def load_training_dataset(dyck, model, prefix_list, save_path):
    """Return the training data cached at ``save_path``, generating it if absent.

    If ``save_path`` is an existing file it is unpickled and returned as-is;
    the other arguments are then unused. Otherwise the data is built with
    ``generate_training_data_from_prefix_list(dyck, model, prefix_list)``,
    pickled to ``save_path`` (creating parent directories as needed), and
    returned.

    Args:
        dyck: Dyck grammar object forwarded to the generator.
        model: Language model forwarded to the generator.
        prefix_list: Prefixes forwarded to the generator.
        save_path: Cache file path.

    Returns:
        The loaded or freshly generated training data.
    """
    if os.path.isfile(save_path):
        with open(save_path, 'rb') as f:
            print(f"Loading training dataset from {save_path}")
            # NOTE: pickle.load can execute arbitrary code; only load trusted caches.
            return pickle.load(f)

    training_data = generate_training_data_from_prefix_list(dyck, model, prefix_list)
    print(f"Generated training dataset of length {len(training_data)}.")
    corpus_dir = os.path.dirname(save_path)
    # dirname() is '' for a bare filename, and os.makedirs('') raises.
    if corpus_dir:
        os.makedirs(corpus_dir, exist_ok=True)
    # Dump to a sibling temp file and atomically rename it into place, so an
    # interrupted write never leaves a truncated cache for the branch above.
    tmp_path = save_path + ".tmp"
    with open(tmp_path, 'wb') as f:
        print(f"Saving training dataset to {save_path}")
        pickle.dump(training_data, f)
    os.replace(tmp_path, save_path)
    return training_data

def generate_training_data(num_samples, prefix_length, completion_length, vocab_size, dyck_gen, model):
    """
    Build one-hot value-function training data from freshly sampled Dyck prefixes.

    Each of the ``num_samples`` iterations:
      1. draws a sequence from ``dyck_gen`` and keeps the ``prefix_length``
         tokens that follow its BOS position;
      2. asks ``model`` to continue BOS + prefix for up to
         ``completion_length`` new tokens (one prefix per call);
      3. scores BOS + prefix + completion + EOS with ``dyck_reward``;
      4. records one ``(one_hot_input, reward)`` example per horizon.

    Args:
        num_samples: Number of sampled prefixes (one example per horizon each)
        prefix_length: Length of prefix (excluding BOS)
        completion_length: Maximum number of new tokens; also the number of horizons B
        vocab_size: Size of vocabulary, forwarded to ``one_hot_encode``
        dyck_gen: Dyck grammar generator (provides ``generate``, ``bos``, ``eos``)
        model: Language model for generating completions

    Returns:
        List of B lists; list ``h - 1`` holds pairs
        ``(one_hot_encode(BOS + prefix + completion[:h], 1 + prefix_length + h, vocab_size), reward)``.
    """
    per_horizon = [[] for _ in range(completion_length)]

    for _ in range(num_samples):
        # Ground-truth sequence; index 0 is BOS, so the prefix starts at index 1.
        ground_truth = dyck_gen.generate()['tokens']
        prefix = ground_truth[1:1 + prefix_length]

        # Fresh BOS-prefixed lists are built at each use below, in case a
        # callee mutates the list it receives.
        sampled = conditional_nn_generate(
            dyck=dyck_gen,
            model=model,
            prefix_sequences=[[dyck_gen.bos] + prefix],
            batch_size=1,
            temperature=1.0,
            max_new_tokens=completion_length
        )

        # Strip BOS + prefix from the front and the final token (taken to be EOS).
        # NOTE(review): if generation can stop at max_new_tokens without an EOS,
        # the [:-1] end of this slice drops a real token — confirm the
        # conditional_nn_generate contract.
        generated = sampled[0]['tokens']
        completion = generated[len(prefix) + 1:-1]

        score = dyck_reward([dyck_gen.bos] + prefix + completion + [dyck_gen.eos], dyck_gen)

        # Horizon h sees BOS + prefix + the first h completion tokens.
        for h, bucket in enumerate(per_horizon, start=1):
            encoded = one_hot_encode([dyck_gen.bos] + prefix + completion[:h], 1 + prefix_length + h, vocab_size)
            bucket.append((encoded, score))

    return per_horizon


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build (or load a cached) value-function training dataset "
                    "from a pickled model and a pickled list of prefixes."
    )
    parser.add_argument("--model_path", type=str, required=True,
                        help="Pickled language model; exactly one path component must contain 'dyck'.")
    parser.add_argument("--prefix_path", type=str, required=True,
                        help="Pickled prefix list; its path must contain 'prefix_datasets'.")
    args = parser.parse_args()

    # Sanity-check that the model path names exactly one "dyck" training run.
    # (The grammar config used to be parsed from this component via its
    # 'm<max_depth>' / 'len<length>' fields; it now comes from config_for_dyck.)
    dyck_components = [p for p in args.model_path.split("/") if "dyck" in p]
    if len(dyck_components) != 1:
        parser.error(
            f"--model_path must contain exactly one path component with 'dyck'; "
            f"found {dyck_components}"
        )

    # The cache path mirrors the prefix path. If the substitution is a no-op,
    # load_training_dataset would find the prefix file itself and return it as
    # "training data", so reject that up front (before any expensive loading).
    training_path = args.prefix_path.replace("prefix_datasets", "value_datasets")
    if training_path == args.prefix_path:
        parser.error("--prefix_path must contain 'prefix_datasets' so the output path differs from it")

    # NOTE: pickle.load executes arbitrary code; only use trusted files.
    with open(args.model_path, 'rb') as f:
        model = pickle.load(f)

    with open(args.prefix_path, 'rb') as f:
        prefix_list = pickle.load(f)

    print("Config", config)

    dyck = RandomWalkDyck(config)

    training_data = load_training_dataset(dyck, model, prefix_list, training_path)
