import hydra
from omegaconf import OmegaConf

import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
import argparse
import random
from generate_prompts import compute_reward, load_dataset
from value_functions import ValueFunction
from policies import LMPolicy
from code_lm_benchmark import GenerateTestCaseTask
from transformers import AutoModelForCausalLM, LlamaForCausalLM, AutoTokenizer
import pickle
from samplers import *
import time

def generate_multiple_completions_per_prefix(sampler_class, value_function, model, tokenizer, eval_prefixes, K, cfg_rep, max_new_tokens, device, use_rewards):
    """Sample ``K`` completions for every evaluation prefix and score each one.

    Args:
        sampler_class: Callable invoked as ``sampler_class(piref=...,
            value_function=..., cfg_rep=..., device=...)``. The object it
            returns must provide ``sample_multiple_completions(prefix_ids, K)``
            whose result is indexable and yields
            ``(full_sequence_token_ids, step_count)`` pairs.
        value_function: Passed through to the sampler unchanged.
        model: Wrapped, together with ``tokenizer``, ``max_new_tokens`` and
            ``device``, into the ``LMPolicy`` given to the sampler as ``piref``.
        tokenizer: Also forwarded to ``compute_reward``.
        eval_prefixes: Sequence of mappings providing at least the keys
            ``'function_name'`` and ``'prompt_ids'``.
        K: Number of completions requested per prefix.
        cfg_rep: Passed through to the sampler unchanged.
        max_new_tokens: See ``model``.
        device: See ``model``; also passed to the sampler.
        use_rewards: When true, a per-prefix ``reward_func`` attribute is set
            on the sampler before sampling.

    Returns:
        ``(results, accuracy)``. ``results`` holds one dict per completion
        with the keys ``prefix_idx``, ``completion_idx``, ``prefix_ids``,
        ``sequence_ids``, ``steps`` and ``reward``. ``accuracy`` is the sum of
        all rewards divided by ``len(eval_prefixes) * K``, or ``0.0`` when
        that product is zero.

    Raises:
        AssertionError: If a returned sequence does not start with its prefix.
    """
    # Instantiate LMPolicy once for all samples
    piref = LMPolicy(model, tokenizer, max_new_tokens, device)

    sampler = sampler_class(
        piref=piref,
        value_function=value_function,
        cfg_rep=cfg_rep,
        device=device,
    )

    banner = "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
    results = []
    valid_count = 0
    total_samples = len(eval_prefixes)
    start_time = time.perf_counter()
    for prefix_idx, prefix in enumerate(eval_prefixes):
        function_name = prefix['function_name']
        prefix_token_ids = prefix['prompt_ids']

        if use_rewards:
            # The sampler gets a fresh reward function for each prefix; it is
            # replaced on the next iteration before sampling starts again.
            def reward_func(token_ids):
                return compute_reward(function_name, prefix_token_ids, token_ids, tokenizer, verbose=False)
            sampler.reward_func = reward_func

        if prefix_idx % 10 == 0:
            print(f"Completed {prefix_idx} prefixes out of {total_samples}", flush=True)
            print("Current valid count:", valid_count)
        print(f"Current index: {prefix_idx}", flush=True)

        prefix_results = sampler.sample_multiple_completions(prefix_token_ids, K)
        for completion_idx in range(K):
            full_sequence_token_ids, step_count = prefix_results[completion_idx]

            # Sanity check: the sampler must return prefix + continuation.
            # NOTE(review): the `!=` comparison presumes plain Python
            # sequences rather than tensors -- confirm against the samplers.
            if full_sequence_token_ids[:len(prefix_token_ids)] != prefix_token_ids:
                print(banner)
                print("ERROR ERROR ERROR")
                print(banner)
                print(prefix_token_ids)
                print("--------------------------------------------------------------------")
                print(full_sequence_token_ids)
                print(banner)
                # Raised explicitly so the check survives `python -O`.
                raise AssertionError(
                    f"completion {completion_idx} of prefix {prefix_idx} does not start with the prefix"
                )

            reward = compute_reward(function_name, prefix_token_ids, full_sequence_token_ids, tokenizer, verbose=True)

            # Store results
            results.append({
                'prefix_idx': prefix_idx,
                'completion_idx': completion_idx,
                'prefix_ids': prefix_token_ids,
                'sequence_ids': full_sequence_token_ids,
                'steps': step_count,
                'reward': reward,
            })

            # Every completion (not just the first per prefix) counts
            # towards the accuracy.
            valid_count += reward

        end_time = time.perf_counter()
        avg_duration = (end_time - start_time) / (prefix_idx + 1)
        print(f"Average duration: {avg_duration:.4f}", flush=True)

    # Guard against an empty evaluation list or K == 0.
    total_completions = total_samples * K
    accuracy = valid_count / total_completions if total_completions else 0.0
    return results, accuracy


def compute_mean_length(results, tokenizer):
    """Print summary statistics for the generations that received reward 1.

    For every entry of ``results`` whose ``'reward'`` equals 1, the generated
    part of ``'sequence_ids'`` (everything after ``'prefix_ids'``) is decoded
    with ``tokenizer`` and parsed by ``GenerateTestCaseTask``. Exactly one
    parsed test case per generation is expected to pass ``task.check``.

    Printed statistics:
        * fraction of entries with reward 1,
        * mean length of the ``'results'`` list of the passing test cases,
        * how often each ``int`` appears as the first element of that list,
        * how often each Python type appears as the first element,
        * fraction of passing test cases whose ``'results'`` list is not
          sorted (lists that cannot be sorted count as unsorted).

    Ratios whose denominator is zero are printed as ``nan`` instead of
    raising ``ZeroDivisionError``.

    Args:
        results: Sequence of dicts with the keys ``'reward'``,
            ``'prefix_idx'``, ``'prefix_ids'`` and ``'sequence_ids'``.
        tokenizer: Object providing ``decode(token_ids)``.

    Raises:
        AssertionError: If the expected function name is missing from a
            generation, or a generation does not contain exactly one passing
            test case.
    """
    lengths = []
    items = {}
    types = {}
    ncorrect = 0
    unsorted = 0
    # NOTE(review): the name is recovered from `prefix_idx % 10`, which
    # presumes the evaluation prefixes cycle through these ten names in this
    # order -- confirm against how the evaluation list is built.
    ood_function_names = ['pop','add','sub','mul','div','max','min','std','avg','exp']
    for res in results:
        if res['reward'] != 1:
            continue
        ncorrect += 1
        function_name = ood_function_names[res['prefix_idx'] % 10]
        generation_ids = res['sequence_ids'][len(res['prefix_ids']):]
        generation = tokenizer.decode(generation_ids)
        assert function_name in generation
        task = GenerateTestCaseTask(1, 1, len(function_name), function_name=function_name)
        test_cases = []
        task.parse(generation, test_cases, verbose=0)
        nc = 0
        for test_case in test_cases:
            if task.check([test_case], verbose=0)['num_correct'] != 1:
                continue
            outputs = test_case['results']
            lengths.append(len(outputs))
            nc += 1
            try:
                if outputs != sorted(outputs):
                    unsorted += 1
            except TypeError:
                # Elements that cannot be ordered against each other.
                unsorted += 1
            for ob in outputs[:1]:
                # Exact type check on purpose: bool must not count as int.
                if type(ob) is int:
                    items[ob] = items.get(ob, 0) + 1
                types[type(ob)] = types.get(type(ob), 0) + 1
        assert nc == 1

    nan = float("nan")
    n_results = len(results)
    print("Fraction of correct generations:", ncorrect / n_results if n_results else nan)
    print("Average length of RHS list in generations with reward 1:", np.mean(lengths) if lengths else nan)
    print(sorted(items.items(), key = lambda item: -item[1]))
    print(types)
    print("Fraction unsorted:", unsorted / len(lengths) if lengths else nan)


def evaluate_correlations(results):
    """Print how the stored sampler weights correlate with the rewards.

    Only entries whose ``'relevant_weight'`` is present and not ``None`` are
    used. Entries without that key (for example results saved without weight
    diagnostics) are skipped rather than raising ``KeyError``.

    Prints the number of usable entries, followed by the ``np.corrcoef``
    matrices of ``relevant_weight`` vs. reward and ``last_weight`` vs. reward.
    With fewer than two usable entries a correlation is undefined, so a short
    notice is printed instead.

    Args:
        results: Iterable of dicts with the key ``'reward'`` and, optionally,
            ``'relevant_weight'`` and ``'last_weight'``.
    """
    relevant_weights = []
    last_weights = []
    rewards = []
    for d in results:
        if d.get('relevant_weight') is None:
            continue
        relevant_weights.append(d['relevant_weight'])
        last_weights.append(d['last_weight'])
        rewards.append(int(d['reward']))
    print(len(rewards))
    if len(rewards) < 2:
        print("Not enough entries with weights to compute correlations")
        return
    print("Correlation for relevant weights:", np.corrcoef(relevant_weights, rewards))
    print("Correlation for last weights:", np.corrcoef(last_weights, rewards))


CONFIG_NAME = None
if "--config-name" in sys.argv:
        CONFIG_NAME = sys.argv[sys.argv.index("--config-name") + 1]
else:
    CONFIG_NAME = "main_codellama"

def _select_sampler(sampler_spec):
    """Translate the ``cfg.eval.sampler`` string into a sampler.

    A leading ``"Oracle"`` requests that rewards be made available to the
    sampler; it may only be combined with ``JS``, ``TW`` or ``LM``. Several
    names carry numeric parameters as a suffix (e.g. ``LJS<min_steps>``,
    ``BBoN<num_blocks>_<block_length>``).

    Returns:
        ``(sampler, use_rewards_at_final_step)``.

    Raises:
        ValueError: If the name is not recognised, or a numeric suffix
            cannot be parsed.
    """
    if sampler_spec.startswith("Oracle"):
        use_rewards_at_final_step = True
        sampler_name = sampler_spec[len("Oracle"):]
        assert sampler_name in ['JS', 'TW', 'LM']
    else:
        use_rewards_at_final_step = False
        sampler_name = sampler_spec

    # The order of the checks matters: exact names are tested before the
    # prefixes they could be confused with.
    if sampler_name == "JS":
        sampler = GenericJSSampler()
    elif sampler_name == "MJS":
        sampler = MomentumJSSampler
    elif sampler_name.startswith("LJS"):
        min_steps = int(sampler_name[3:])
        sampler = GenericJSSampler(min_steps=min_steps)
    elif sampler_name.startswith("WJS"):
        up_prob = float(sampler_name[3:])
        sampler = GenericJSSampler(up_prob)
        print(f"Weighted JS Sampler with up probability {up_prob}",flush=True)
    elif sampler_name == "TW":
        sampler = TokenwiseSamplerWithCompletions
    elif sampler_name == "TWMAXVAL":
        sampler = TokenwiseArgmaxValueSampler
    elif sampler_name == "LM":
        sampler = UnguidedLMSampler
    elif sampler_name.startswith("LMBoN"):  # LM BoN
        best_of = int(sampler_name[5:])
        sampler = GenericEstimatedBestOfNSampler(UnguidedLMSampler, best_of)
    elif sampler_name.startswith("BoN"):  # TW BoN
        best_of = int(sampler_name[3:])
        sampler = GenericEstimatedBestOfNSampler(TokenwiseSamplerWithCompletions, best_of)
    elif sampler_name.startswith("BBoN"):  # Block BoN
        params = sampler_name[4:].split("_")
        num_blocks = int(params[0])
        block_length = int(params[1])
        sampler = GenericBlockBoNSampler(num_blocks, block_length)
    elif sampler_name.startswith("BProp"):  # Block Proportional Sampler
        params = sampler_name[5:].split("_")
        num_blocks = int(params[0])
        block_length = int(params[1])
        sampler = GenericBlockPropSampler(num_blocks, block_length)
    elif sampler_name == "SeqProp":
        sampler = SequencePropSampler
    elif sampler_name.startswith("FakeJS"):
        sampler = GenericJSSampler(fixed_up_weight = float(sampler_name[6:]))
    elif sampler_name == "FakeLayerJS":
        sampler = GenericJSSampler(fake_layer=True)
    else:
        # Fail here instead of continuing with `sampler` unbound.
        raise ValueError(f"Error: Sampler name {sampler_spec} not valid")

    return sampler, use_rewards_at_final_step


@hydra.main(config_path='../hydra_configs', config_name=CONFIG_NAME, version_base=None)
def main(cfg):
    """Evaluate one sampler on the held-out prompts and pickle the results.

    Steps:
        1. Record the original/run directories and the chosen device (the CUDA
           device with the most free memory, else CPU) under ``cfg.meta``, and
           seed ``torch``, ``random`` and ``numpy`` with ``cfg.seed``.
        2. Resolve ``cfg.eval.sampler`` to a sampler.
        3. Load the pickled value function for epoch ``cfg.eval.epoch``.
        4. Build the evaluation list from the prompts after the first
           ``int(0.9 * nbatches)`` batches, optionally filtered by
           ``cfg.eval.function_name`` and capped at 1000 entries.
        5. If ``cfg.fs.eval_save_path`` already exists, only report that and
           stop; otherwise load the language model, generate ``cfg.eval.K``
           completions per prefix and pickle the results to that path.

    Raises:
        ValueError: If ``cfg.eval.sampler`` is not a known sampler name.
        AssertionError: If the configuration sanity checks fail.
    """
    # add runtime info to cfg
    OmegaConf.set_struct(cfg, False)
    cfg.meta = OmegaConf.create({})
    cfg.meta.original_dir = hydra.utils.get_original_cwd()
    cfg.meta.run_dir = os.getcwd()
    if torch.cuda.is_available():
        free_mem = [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())]
        best_gpu = free_mem.index(max(free_mem))
        cfg.meta.device = f"cuda:{best_gpu}"
    else:
        cfg.meta.device = "cpu"
    print(cfg)
    torch.manual_seed(cfg.seed)
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)

    # NOTE(review): torch_dtype is a model argument; it presumably has no
    # effect on a tokenizer -- confirm and drop if so.
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.name,torch_dtype=torch.float16)

    sampler, use_rewards_at_final_step = _select_sampler(cfg.eval.sampler)

    # Value-function and evaluation paths must both be "testing" paths or
    # neither of them.
    assert ("testing" in cfg.fs.vf_save_dir) == ("testing" in cfg.fs.eval_save_path)
    vf_save_path = cfg.fs.vf_save_dir + "/e" + str(cfg.eval.epoch) + ".pkl"
    # SECURITY: pickle.load executes arbitrary code from the file; only load
    # value-function checkpoints from trusted locations.
    with open(vf_save_path, "rb") as f:
        value_function = pickle.load(f).to(cfg.meta.device)
        value_function.eval()

    prompts_and_completions = load_dataset(cfg.fs.generation_save_path)
    assert len(prompts_and_completions) == cfg.generation_configs.n
    # Number of batches, rounded up.
    nbatches = cfg.generation_configs.n // cfg.vf.batch_size
    if nbatches * cfg.vf.batch_size < cfg.generation_configs.n:
        nbatches += 1
    # Everything after the first 90% of the batches is used for evaluation.
    ntrain_batches = int(0.9 * nbatches)
    first_eval_idx = cfg.vf.batch_size * ntrain_batches

    held_out = prompts_and_completions[first_eval_idx:]
    if cfg.eval.function_name == "all":
        eval_list = list(held_out)
    else:
        eval_list = [d for d in held_out if d['function_name'] == cfg.eval.function_name]
    eval_list = eval_list[:1000]
    print("Length of prefix evaluation list:", len(eval_list))

    output_path = cfg.fs.eval_save_path
    if os.path.exists(output_path):
        # Existing results are never overwritten. The analysis of saved
        # results below is currently disabled.
        print(f"Path already exists; loading results from {output_path}")
        #with open(output_path, 'rb') as f:
        #    results = pickle.load(f)
        #evaluate_correlations(results)
        #compute_mean_length(results, tokenizer)
        return

    if "llama" in CONFIG_NAME:
        model = LlamaForCausalLM.from_pretrained(cfg.model.name,torch_dtype=torch.float16).to(cfg.meta.device)
    else:
        model = AutoModelForCausalLM.from_pretrained(cfg.model.name,torch_dtype=torch.float16).to(cfg.meta.device)
    model.eval()
    print(f"Generating {cfg.eval.K} completions per prefix...")
    results, accuracy = generate_multiple_completions_per_prefix(
        sampler, value_function, model, tokenizer, eval_list, cfg.eval.K, cfg.rep, cfg.generation_configs.max_new_tokens, cfg.meta.device, use_rewards_at_final_step
    )
    print(f"Accuracy of {cfg.eval.sampler} is {accuracy}")
    print("Saving results...")
    # dirname is "" for a bare file name, and os.makedirs("") raises; skip
    # the call then so a finished run is not lost at save time.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path,'wb') as f:
        pickle.dump(results, f)


# Script entry point: main() is wrapped by @hydra.main, so it is called
# without arguments here and receives ``cfg`` from the decorator.
if __name__ == "__main__":
    main()
