import hydra
from omegaconf import OmegaConf

import collections
from collections.abc import Iterable
import copy
# import ctranslate2
import datasets
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import random
import sys
import torch
import time
from transformers import AutoModelForCausalLM, LlamaForCausalLM, AutoTokenizer
from policies import LMPolicy
from samplers import UnguidedLMSampler
import wandb
import pandas as pd
from dataset_utils import load_dataset

from code_lm_benchmark import (
    GenerateTestCaseTask,
    GenerateTestCaseTaskBenchmark,
    GenerateTestCaseTaskBenchmarkParallel,
    ErrorType,
)
from lm import (
    conditional_nn_generate,
    predict_error,
    sample_top_p,
)



def compute_reward(function_name, prompt_ids, output_ids, tokenizer, verbose = True):
    """Score one sampled completion by parsing and checking its test case.

    The generated part of ``output_ids`` (everything after the prompt) is
    decoded with ``tokenizer``, parsed by a ``GenerateTestCaseTask`` built for
    ``function_name`` and then checked.

    Args:
        function_name: Name of the function under test; must have length 3.
        prompt_ids: Token ids of the prompt; must be a prefix of ``output_ids``.
        output_ids: Token ids of the prompt followed by the generation.
        tokenizer: Object exposing ``decode(token_ids)``.
        verbose: When truthy, print the reward, the decoded prompt and the
            decoded generation; also forwarded to ``task.parse``/``task.check``.

    Returns:
        The result of ``results['num_correct'] == 1`` when parsing and
        checking succeed, or ``0`` when either step raises an ``Exception``.

    Raises:
        AssertionError: If ``prompt_ids`` is not a prefix of ``output_ids`` or
            ``function_name`` does not have length 3.
    """
    assert output_ids[:len(prompt_ids)] == prompt_ids
    generation_ids = output_ids[len(prompt_ids):]

    num_test_cases = 1
    num_demonstrations = 1
    assert len(function_name) == 3
    task = GenerateTestCaseTask(num_demonstrations, num_test_cases, len(function_name), function_name=function_name)
    test_cases = []
    generation_str = tokenizer.decode(generation_ids)
    try:
        task.parse(generation_str, test_cases, verbose=verbose)
        results = task.check(test_cases, verbose=verbose)
        reward = (results['num_correct'] == num_test_cases)
    except Exception as exc:
        # Malformed generations are expected and simply score zero.
        # KeyboardInterrupt/SystemExit are deliberately NOT caught so that a
        # long generation run can still be interrupted.
        reward = 0
        if verbose:
            print(f"Reward evaluation failed: {exc!r}")
    if verbose:
        print("++++++++++++++++++++++++++++++++++++")
        print(f"Evaluating Reward: {reward}")
        print("++++++++++++++++++++++++++++++++++++")
        print(tokenizer.decode(prompt_ids))
        print("------------------------------------")
        print(generation_str)
        print("++++++++++++++++++++++++++++++++++++", flush=True)
    return reward


def resolve_config_name(argv, default="main_codellama"):
    """Return the Hydra config name requested on the command line.

    The name is needed at import time (before Hydra parses anything) because
    it is passed to the ``hydra.main`` decorator and later inspected to pick
    the model class.

    Recognised spellings are ``--config-name NAME``, ``--config-name=NAME``
    and the short forms ``-cn NAME`` / ``-cn=NAME``. The first occurrence
    wins. A flag given without a value falls back to ``default`` and leaves
    the error reporting to Hydra.

    Args:
        argv: Command-line argument list (typically ``sys.argv``).
        default: Name returned when no config-name flag is present.

    Returns:
        The config name as a string.
    """
    flags = ("--config-name", "-cn")
    for position, argument in enumerate(argv):
        if argument in flags:
            if position + 1 < len(argv):
                return argv[position + 1]
            return default
        for flag in flags:
            if argument.startswith(flag + "="):
                return argument[len(flag) + 1:]
    return default


CONFIG_NAME = resolve_config_name(sys.argv)

@hydra.main(config_path='../hydra_configs', config_name=CONFIG_NAME, version_base=None)
def main(cfg):
    """Generate test-case completions with the reference LM and save them.

    If ``cfg.fs.generation_save_path`` already exists, the saved dataset is
    loaded and its first ten entries are printed. Otherwise
    ``cfg.generation_configs.n`` prompts are built (cycling through a fixed
    list of three-letter function names), one completion is sampled for each,
    every completion is scored with ``compute_reward``, and the results are
    written to ``cfg.fs.generation_save_path`` as CSV.

    Args:
        cfg: Hydra/OmegaConf configuration. Reads ``seed``, ``model.name``,
            ``rep``, ``fs.generation_save_path`` and ``generation_configs``
            (``max_new_tokens``, ``function_name_type``, ``n``, ``tests``);
            a ``meta`` node with runtime information is added to it.
    """
    # add runtime info to cfg
    OmegaConf.set_struct(cfg, False)
    cfg.meta = OmegaConf.create({})
    cfg.meta.original_dir = hydra.utils.get_original_cwd()
    cfg.meta.run_dir = os.getcwd()
    if torch.cuda.is_available():
        # Run on the GPU that currently reports the most free memory.
        free_mem = [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())]
        best_gpu = free_mem.index(max(free_mem))
        cfg.meta.device = f"cuda:{best_gpu}"
    else:
        cfg.meta.device = "cpu"
    print(cfg)
    torch.manual_seed(cfg.seed)
    random.seed(cfg.seed)

    if "llama" in CONFIG_NAME:
        model = LlamaForCausalLM.from_pretrained(cfg.model.name,torch_dtype=torch.float16).to(cfg.meta.device)
        newline_or_null = ""
    else:
        model = AutoModelForCausalLM.from_pretrained(cfg.model.name,torch_dtype=torch.float16).to(cfg.meta.device)
        newline_or_null = "\n"

    model.eval()
    # NOTE(review): torch_dtype is a model argument and is probably
    # unnecessary for the tokenizer - confirm before removing.
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.name,torch_dtype=torch.float16)

    if os.path.exists(cfg.fs.generation_save_path):
        print("Loading existing prompts")
        data = load_dataset(cfg.fs.generation_save_path)
        # NOTE(review): assumes the saved dataset has at least 10 entries.
        for i in range(10):
            print("++++++++++++++++++++++++++++++++++++++++++")
            print(tokenizer.decode(data[i]['prompt_ids']))
            print("------------------------------------------")
            print(tokenizer.decode(data[i]['output_ids']))
            print("------------------------------------------")
            print("Reward", data[i]['reward'])
    else:
        print('generating new prompts')

        piref = LMPolicy(model, tokenizer, cfg.generation_configs.max_new_tokens, cfg.meta.device)

        sampler = UnguidedLMSampler(
            piref=piref,
            value_function=None,
            cfg_rep = cfg.rep,
            device=cfg.meta.device
        )

        if cfg.generation_configs.function_name_type == "random":
            ood_function_names = ['ovs', 'cyk', 'mcl', 'heh', 'fgu', 'knk', 'zmf', 'bgz', 'cub', 'dfn']
        else:
            ood_function_names = ['pop','add','sub','mul','div','max','min','std','avg','exp']
        prompts_and_completions = []
        # Per-function-name counters, reported once generation has finished.
        total_generations = dict.fromkeys(ood_function_names, 0)
        correct_generations = dict.fromkeys(ood_function_names, 0)

        for i in range(cfg.generation_configs.n):
            if i%10 == 0:
                print(f"On generation {i}", flush=True)
            task = GenerateTestCaseTask(
                1,  # num_demonstrations
                cfg.generation_configs.tests,  # num_test_cases_per_function
                function_name_length=3,
                # Cycle through the names so each one is used equally often.
                function_name=ood_function_names[i % len(ood_function_names)],
            )
            token_ids_prompt = tokenizer(task.prompt+newline_or_null)['input_ids']
            output_ids, _ = sampler.sample(token_ids_prompt)

            reward = compute_reward(task.function_name, token_ids_prompt, output_ids, tokenizer, verbose=True)

            # Count every attempt as well as the successes; without this the
            # "Total generations" summary would always report zeros.
            total_generations[task.function_name] += 1
            correct_generations[task.function_name] += reward

            prompts_and_completions.append({"function_name": task.function_name, "prompt_ids": token_ids_prompt, "output_ids": output_ids, "reward": reward})

        print("Total generations:", total_generations)
        print("Correct generations:",correct_generations)
        print(f'Saving to {cfg.fs.generation_save_path}')
        save_dir = os.path.dirname(cfg.fs.generation_save_path)
        # dirname() is "" for a bare file name, and os.makedirs("") raises.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        df = pd.DataFrame(prompts_and_completions)
        df.to_csv(cfg.fs.generation_save_path, index=False)


# Script entry point: main is wrapped by hydra.main, so it is called without
# arguments and receives its cfg from Hydra.
if __name__ == "__main__":
    main()
