import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
import argparse

from dyck import RandomWalkDyck
from lm import conditional_nn_generate, nn_log_probs
from config_for_dyck import config
from utils import dyck_reward, one_hot_encode
from value_functions import ValueFunction
from policies import LMPolicy

import pickle
from samplers import *
import multiprocessing
import time
from tqdm import tqdm

def run_value_guided_generation_example(dyck_gen, value_functions, model, vocab_size, B, A, eval_prefixes):
    """
    Run value-guided generation example comparing tokenwise vs JS sampling.

    For each evaluation prefix, builds a prompt of BOS followed by the first ``A``
    prefix tokens, completes it once with ``tokenwise_sampler_with_completions`` and
    once with a reused ``JSSamplerWithCompletions``, prints both detokenized
    completions with their validity, and counts how many are valid Dyck sequences
    (``dyck_reward(...) == 1.0``).

    Args:
        dyck_gen: Dyck generator (provides ``bos`` and ``detokenize``)
        value_functions: List of trained value functions
        model: Language model
        vocab_size: Vocabulary size
        B: Completion length
        A: Prefix length (number of prefix tokens kept after BOS)
        eval_prefixes: List of evaluation prefixes to use

    Returns:
        tuple: (results, tokenwise_accuracy, js_accuracy), where ``results`` is a
        list of ``(tokenwise_sequence, js_sequence)`` pairs, one per prefix, and each
        accuracy is the fraction of prefixes whose completion is valid
        (0.0 when ``eval_prefixes`` is empty).
    """
    # NOTE(review): LMPolicy is called here as (model, dyck_gen, vocab_size), while
    # generate_multiple_completions_per_prefix calls LMPolicy(model, vocab_size).
    # One of the two call signatures is likely stale -- confirm against policies.LMPolicy.
    piref = LMPolicy(model, dyck_gen, vocab_size)

    js_sampler = JSSamplerWithCompletions(
        piref=piref,
        prefix=[],  # Replaced with the per-sample prompt inside the loop
        value_functions=value_functions,
        completion_length=B
    )

    results = []
    tokenwise_valid_count = 0
    js_valid_count = 0
    total_samples = len(eval_prefixes)

    for sample_idx, prefix_sequence in enumerate(eval_prefixes):
        # Prompt = BOS + first A prefix tokens; reuse the prefix's own BOS if present.
        # An explicit len() check (not truthiness) avoids IndexError on empty prefixes
        # without breaking array-like inputs.
        if len(prefix_sequence) > 0 and prefix_sequence[0] == dyck_gen.bos:
            prompt = prefix_sequence[:1+A]
        else:
            prompt = [dyck_gen.bos] + prefix_sequence[:A]

        full_sequence_tokenwise, steps_tokenwise = tokenwise_sampler_with_completions(
            prompt, value_functions, list(range(vocab_size)), piref, B
        )

        # The JS sampler is reused across samples; only its prefix changes.
        js_sampler.prefix = prompt
        full_sequence_js, steps_js = js_sampler.sample()

        tokenwise_detokenized = dyck_gen.detokenize(full_sequence_tokenwise)
        js_detokenized = dyck_gen.detokenize(full_sequence_js)

        # dyck_reward returns 1.0 for a valid sequence and 0.0 otherwise.
        tokenwise_valid = dyck_reward(full_sequence_tokenwise, dyck_gen) == 1.0
        js_valid = dyck_reward(full_sequence_js, dyck_gen) == 1.0

        if tokenwise_valid:
            tokenwise_valid_count += 1
        if js_valid:
            js_valid_count += 1

        print(f"trial {sample_idx + 1}: Tokenwise sequence: {tokenwise_detokenized}, {tokenwise_valid}, JS sampler: {js_detokenized}, {js_valid}")

        results.append((full_sequence_tokenwise, full_sequence_js))

    # Guard against ZeroDivisionError when no prefixes were supplied.
    if total_samples == 0:
        return results, 0.0, 0.0

    tokenwise_accuracy = tokenwise_valid_count / total_samples
    js_accuracy = js_valid_count / total_samples
    return results, tokenwise_accuracy, js_accuracy


def generate_multiple_completions_per_prefix(dyck_gen, sampler_class, value_functions, model, vocab_size, B, A, eval_prefixes, K=20, batch_generations = True, use_rewards = False):
    """
    Generate K independent completions for each prefix with the given sampler.

    Args:
        dyck_gen: Dyck generator (used by ``dyck_reward`` and for ``detokenize``)
        sampler_class: Callable that builds the sampler; invoked as
            ``sampler_class(piref=..., value_functions=..., completion_length=B+1)``
        value_functions: List of trained value functions
        model: Language model
        vocab_size: Vocabulary size
        B: Completion length (the sampler is built with ``completion_length=B+1``)
        A: Prefix length; in the non-batched path every prefix must have length ``A+1``
        eval_prefixes: List of evaluation prefixes
        K: Number of independent completions per prefix
        batch_generations: If True, submit all (prefix, completion) jobs at once via
            ``sampler.sample_from_jobs``; otherwise loop over prefixes calling
            ``sampler.sample_multiple_completions``.
        use_rewards: If True (non-batched path only), attach the Dyck reward to the
            sampler as ``sampler.reward_func``.

    Returns:
        tuple: (results, accuracy)
        ``results`` is a list of dicts. In the non-batched path each has keys
        'prefix_idx', 'completion_idx', 'sequence', 'steps', 'reward'; in the batched
        path they are the dicts returned by ``sample_from_jobs`` (which must carry
        'sequence' and 'completion_idx') with a boolean 'reward' added.
        ``accuracy`` is the fraction of prefixes whose first completion
        (completion_idx == 0) is valid, or 0.0 if ``eval_prefixes`` is empty.

    Raises:
        ValueError: in the non-batched path, if a prefix does not have length ``A+1``.
    """
    # Instantiate LMPolicy once for all samples
    piref = LMPolicy(model, vocab_size)

    sampler = sampler_class(
        piref=piref,
        value_functions=value_functions,
        completion_length=B+1
    )

    def reward_func(token_ids):
        # Raw Dyck reward exposed to "oracle" samplers.
        return dyck_reward(token_ids, dyck_gen)

    def is_valid(token_ids):
        # A completion counts as correct only when the Dyck reward is exactly 1.0.
        return dyck_reward(token_ids, dyck_gen) == 1.0

    valid_count = 0
    total_samples = len(eval_prefixes)

    if batch_generations:
        if use_rewards:
            # This path never attached reward_func; say so instead of silently
            # running a non-oracle evaluation.
            print("Warning: use_rewards has no effect when batch_generations=True "
                  "(reward_func is not attached in batched mode)", flush=True)
        jobs = [
            {
                'prefix_idx': prefix_idx,
                'completion_idx': completion_idx,
                'prefix': eval_prefixes[prefix_idx],
            }
            for prefix_idx in range(total_samples)
            for completion_idx in range(K)
        ]
        results = sampler.sample_from_jobs(jobs)
        for res in results:
            res['reward'] = is_valid(res['sequence'])
            if res['completion_idx'] == 0:
                valid_count += res['reward']
    else:
        if use_rewards:
            # Attach once; the function is identical for every prefix.
            sampler.reward_func = reward_func

        results = []
        for prefix_idx, prefix_sequence in enumerate(eval_prefixes):
            if prefix_idx % 10 == 0:
                print(f"Completed {prefix_idx} prefixes out of {total_samples}", flush=True)
            if len(prefix_sequence) != A + 1:
                raise ValueError(
                    f"Prefix {prefix_idx} has length {len(prefix_sequence)}, expected A+1={A + 1}"
                )

            # Generate K completions for this prefix; each entry is (sequence, step_count).
            prefix_results = sampler.sample_multiple_completions(prefix_sequence, K)
            for completion_idx in range(K):
                full_sequence, step_count = prefix_results[completion_idx]
                reward = is_valid(full_sequence)
                results.append({
                    'prefix_idx': prefix_idx,
                    'completion_idx': completion_idx,
                    'sequence': full_sequence,
                    'steps': step_count,
                    'reward': reward
                })

                # Accuracy is measured on the first completion of each prefix.
                if completion_idx == 0:
                    if prefix_idx < 10:
                        print(f"Completion: {dyck_gen.detokenize(full_sequence)}; reward: {reward}", flush=True)
                    valid_count += reward

    # Guard against ZeroDivisionError when no prefixes were supplied.
    accuracy = valid_count / total_samples if total_samples else 0.0
    return results, accuracy


def save_completions_to_pickle(tokenwise_results, js_results, output_dir="completion_results"):
    """
    Pickle tokenwise and JS completion results into ``output_dir``.

    Creates ``output_dir`` if needed, then writes ``tokenwise_completions.pkl``
    followed by ``js_completions.pkl``, printing each path after it is saved.

    Args:
        tokenwise_results: Object holding the tokenwise completion results
        js_results: Object holding the JS completion results
        output_dir: Directory that receives the two pickle files
    """
    os.makedirs(output_dir, exist_ok=True)

    outputs = (
        ("tokenwise", "tokenwise_completions.pkl", tokenwise_results),
        ("JS", "js_completions.pkl", js_results),
    )
    for label, basename, payload in outputs:
        target = os.path.join(output_dir, basename)
        with open(target, 'wb') as handle:
            pickle.dump(payload, handle)
        print(f"Saved {label} completions to: {target}")


def compute_average_ll(model, results):
    """
    Report how many results are correct and their mean log-probability under ``model``.

    Args:
        model: Language model passed through to ``nn_log_probs``.
        results: Iterable of dicts with 'sequence' and 'reward' keys; only entries
            whose reward equals 1 (i.e. True) are scored.

    Returns:
        The mean (``np.mean``) of ``nn_log_probs(model, sequences)`` over the correct
        sequences, or None when there are no correct sequences to score.
    """
    sequences = [res['sequence'] for res in results if res['reward'] == 1]
    num_correct = len(sequences)
    print("Number of correct generations:", num_correct)
    if num_correct == 0:
        # Nothing to score: skip nn_log_probs on an empty batch and np.mean([]),
        # which emits a RuntimeWarning and returns nan.
        print("Average log prob: n/a (no correct generations)")
        return None
    log_probs = nn_log_probs(model, sequences)
    average = np.mean(log_probs)
    print("Average log prob:", average)
    return average

def _build_sampler(sampler_name):
    """Return the sampler class/factory named by ``sampler_name`` (without "Oracle" prefix), or None if unknown."""
    if sampler_name == "JS":
        return GenericJSSampler()
    if sampler_name == "MJS":
        return MomentumJSSampler
    if sampler_name.startswith("WJS"):  # Weighted JS: "WJS<up_prob>"
        up_prob = float(sampler_name[3:])
        weighted_sampler = GenericJSSampler(up_prob)
        print(f"Weighted JS Sampler with up probability {up_prob}", flush=True)
        return weighted_sampler
    if sampler_name == "TW":
        return TokenwiseSamplerWithCompletions
    if sampler_name == "TWMAXVAL":
        return TokenwiseArgmaxValueSampler
    if sampler_name == "LM":
        return UnguidedLMSampler
    if sampler_name.startswith("LMBoN"):  # LM best-of-N: "LMBoN<N>"
        return GenericEstimatedBestOfNSampler(UnguidedLMSampler, int(sampler_name[5:]))
    if sampler_name.startswith("TWBoN"):  # Tokenwise best-of-N: "TWBoN<N>"
        return GenericEstimatedBestOfNSampler(TokenwiseSamplerWithCompletions, int(sampler_name[5:]))
    if sampler_name.startswith("TWProp"):  # "TWProp<num_candidates>"
        return GenericPropSampler(TokenwiseSamplerWithCompletions, int(sampler_name[6:]))
    if sampler_name.startswith("BBoN"):  # Block BoN: "BBoN<num_blocks>_<block_length>"
        params = sampler_name[4:].split("_")
        return GenericBlockBoNSampler(int(params[0]), int(params[1]))
    if sampler_name.startswith("BProp"):  # Block proportional: "BProp<num_blocks>_<block_length>"
        params = sampler_name[5:].split("_")
        return GenericBlockPropSampler(int(params[0]), int(params[1]))
    return None


def main():
    """Parse CLI args, then evaluate a sampler's completions (or reload cached ones) and report results."""
    parser = argparse.ArgumentParser(description='Test prefix completion pipeline')

    parser.add_argument('--model_path', type=str, required=True, help='Pickled language model')
    parser.add_argument('--value_path', type=str, required=True, help='Pickled list of value functions')
    parser.add_argument('--prefix_eval_path', type=str, required=True, help='Pickled list of evaluation prefixes')
    parser.add_argument('--prefix_length', type=int, required=True)
    parser.add_argument('--total_length', type=int, required=True)
    parser.add_argument('--sampler', type=str, required=True,
                        help='Sampler name, optionally prefixed with "Oracle" (JS, MJS, TW, LM only)')
    parser.add_argument('--K', type=int, default=5, help='Number of independent completions per prefix')
    parser.add_argument('--first_n', type=int, default=-1, help='Use only the first N prefixes (-1 = all)')
    parser.add_argument("--batched", action='store_true')
    args = parser.parse_args()

    # Resolve the sampler up front so a bad name fails before the expensive loads.
    if args.sampler.startswith("Oracle"):
        use_rewards_at_final_step = True
        sampler_name = args.sampler[len("Oracle"):]
        if sampler_name not in ('JS', 'MJS', 'TW', 'LM'):
            parser.error(f"Oracle sampler must be one of JS, MJS, TW, LM; got {sampler_name!r}")
    else:
        use_rewards_at_final_step = False
        sampler_name = args.sampler

    sampler = _build_sampler(sampler_name)
    if sampler is None:
        # Previously this only printed and fell through to a NameError later.
        parser.error(f"Sampler name {sampler_name} not valid")

    A = args.prefix_length
    B = args.total_length - A
    vocab_size = 2 * config['num_types'] + 4

    dyck_gen = RandomWalkDyck(config)

    print("CPU count", multiprocessing.cpu_count())

    # Load model, value functions and prefixes. pickle.load can execute arbitrary
    # code, so only point these paths at trusted artifacts.
    with open(args.model_path, 'rb') as f:
        model = pickle.load(f)
    model.eval()
    model.cuda()
    with open(args.value_path, 'rb') as f:
        value_functions = pickle.load(f)

    with open(args.prefix_eval_path, 'rb') as f:
        prefix_eval_list = pickle.load(f)

    if args.first_n != -1:
        prefix_eval_list = prefix_eval_list[:args.first_n]

    # Output path: <value_path without extension, trained_values -> evals>/evalset_<prefix tag>/<sampler>_K<K>.pkl
    value_dir = os.path.splitext(args.value_path)[0].replace("trained_values", "evals")
    prefix_tag = os.path.splitext(os.path.basename(args.prefix_eval_path))[0]
    output_dir = os.path.join(value_dir, "evalset_" + prefix_tag)
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{args.sampler}_K{args.K}.pkl")

    if os.path.exists(output_path):
        # Reuse cached completions instead of regenerating them.
        print(f"Loading from {output_path}")
        with open(output_path, 'rb') as f:
            results = pickle.load(f)
        compute_average_ll(model, results)
    else:
        print("Will save to", output_path)

        print(f"Generating {args.K} completions per prefix...")
        results, accuracy = generate_multiple_completions_per_prefix(
            dyck_gen, sampler, value_functions, model, vocab_size, B, A, prefix_eval_list,
            args.K, args.batched, use_rewards=use_rewards_at_final_step
        )
        print(f"Accuracy of {args.sampler} is {accuracy}")
        print("Saving results...")
        with open(output_path, 'wb') as f:
            pickle.dump(results, f)


if __name__ == "__main__":
    main()
