import torch
from tqdm import tqdm
import pandas as pd
# from eval.utils import load_dexperts_model_and_tokenizer, load_dexperts_model_and_tokenizer_vllm
# from analysis.utils import flatten_batch_results, summarize_results, trim_output
from vllm import LLM, SamplingParams
import jsonlines
import os
# from vllm_inject import sequence_inject, sample_output_inject, model_runner_inject, llm_engine_inject, scheduler_inject, config_inject
# from vllm_inject.utils import *
import json, re
import evaluate
from transformers import AutoTokenizer
from typing import Iterable, Dict, List, Optional, Union
import gzip
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from collections import defaultdict, Counter

from codeutils.data import write_jsonl, read_problems
from codeutils.evaluation import evaluate_functional_correctness, clean_generation

# Turn off Hugging Face `tokenizers` internal parallelism for this process
# (presumably to avoid its fork-related warnings once evaluation workers start).
os.environ.update(TOKENIZERS_PARALLELISM="false")
def trim_output(output):
    """Truncate ``output`` at known stop markers.

    The markers are applied in a fixed order ("Answer the following question",
    "Question:", "Comment:"). For each marker that is present, the text from its
    first occurrence onward is discarded. Text without any marker is returned
    unchanged.
    """
    # "Comment:" is listed because Llama 13B tends to generate these comments indefinitely.
    stop_markers = ("Answer the following question", 'Question:', 'Comment:')
    for marker in stop_markers:
        output, _separator, _discarded = output.partition(marker)
    return output

# Matches a line (already stripped) whose statement is `assert` or `print` used
# as a whole keyword, so identifiers such as `printed` or `assertions` are not hit.
_DEBUG_STATEMENT_RE = re.compile(r"(?:assert|print)\b")
# Script entry guards; both quote styles are accepted.
_MAIN_GUARD_PREFIXES = ('if __name__ == "__main__":', "if __name__ == '__main__':")


def clean_generation(output):
    """Clean generated code line by line.

    * Stops at the first line that, after stripping, starts with an
      ``if __name__ == "__main__":`` guard (either quote style). That line and
      everything after it are dropped.
    * Drops lines whose statement is ``assert`` or ``print`` (whole keyword
      only, so ``printed = []`` or ``assertions += 1`` are kept).
    * Strips trailing whitespace from the joined result.

    Comments are NOT removed.

    NOTE(review): this definition shadows the ``clean_generation`` imported
    from ``codeutils.evaluation`` at the top of the file.
    """
    cleaned_lines = []
    for line in output.split('\n'):
        stripped = line.strip()
        if stripped.startswith(_MAIN_GUARD_PREFIXES):
            break
        if not _DEBUG_STATEMENT_RE.match(stripped):
            cleaned_lines.append(line)
    return '\n'.join(cleaned_lines).rstrip()

def get_equation_lhs_rhs_indices(tokens):
    """
    Returns two lists of indices, one for tokens in the LHS of equations and one for those in the RHS.

    For every ``'='`` token, walks left (collecting LHS indices, nearest first)
    and then right (collecting RHS indices). Each walk continues while the token
    is a digit string or exactly one of the symbols ``, $ € + - x * /``.

    Args:
        tokens: list of str

    Returns:
        ``(lhs_idx, rhs_idx)``: lists of int indices into ``tokens``. With
        chained equations (e.g. ``1 = 2 = 3``) an index can appear in both lists.
    """
    # Whole-token membership. A substring test against ",$€+-x*/" would also
    # accept "" and multi-character fragments such as "+-" or ",$".
    equation_symbols = frozenset({",", "$", "€", "+", "-", "x", "*", "/"})

    def is_equation_token(tok):
        return tok.isdigit() or tok in equation_symbols

    equal_indices = [i for i, x in enumerate(tokens) if x == '=']
    lhs_idx, rhs_idx = [], []

    for equal_idx in equal_indices:
        # go left until it's no longer a number or symbol
        left_idx = equal_idx - 1
        while left_idx >= 0 and is_equation_token(tokens[left_idx]):
            lhs_idx.append(left_idx)
            left_idx -= 1

        # go right until it's no longer a number or symbol
        right_idx = equal_idx + 1
        while right_idx < len(tokens) and is_equation_token(tokens[right_idx]):
            rhs_idx.append(right_idx)
            right_idx += 1

    return lhs_idx, rhs_idx

# Number of sampling iterations; each iteration draws one completion per prompt,
# so this is the number of samples per problem.
# NOTE(review): pass@k can only be estimated for k <= this value; with 1 only
# pass@1 is meaningful. Confirm how `evaluate_functional_correctness` treats k=10.
unbiased_sampling_size_n = 1
# k values passed to `evaluate_functional_correctness` by `main`.
eval_pass_at_ks = [1, 10]
@torch.inference_mode()
def get_code_output(base_model,
                   tokenizer,
                   max_tokens,
                   batch_size,
                   temperature,
                   top_p,
                   use_chat_format=False,
                   save_dir="outputs/code"):
    """Generate code completions for every problem returned by ``read_problems``.

    Runs ``unbiased_sampling_size_n`` sampling iterations. Each iteration:

    * generates one completion per prompt with ``base_model``;
    * cleans it with ``clean_generation``;
    * writes the cleaned completions to ``<save_dir>/sampling_iterations/<iter>.jsonl``.

    Args:
        base_model: vLLM ``LLM`` used for generation.
        tokenizer: accepted for interface compatibility; stop conditions are
            now given as strings, so it is not used.
        max_tokens: maximum number of generated tokens per completion.
        batch_size: number of prompts passed to ``base_model.generate`` per call.
        temperature: sampling temperature forwarded to ``SamplingParams``.
        top_p: nucleus-sampling threshold forwarded to ``SamplingParams``.
        use_chat_format: currently unused.
        save_dir: directory under which per-iteration outputs are written.

    Returns:
        ``(test_data, all_results, outputs)``:

        * ``test_data`` is the list of problems.
        * ``all_results`` holds ``{"inputs", "output"}`` dicts with the raw
          (uncleaned) text of the LAST sampling iteration only.
        * ``outputs`` is the flat list of cleaned completions ordered
          problem-major, so ``outputs[i * n + j]`` is sample ``j`` of problem ``i``.
    """
    print("Loading data...")
    test_data = list(read_problems().values())
    prompts = [example["prompt"] for example in test_data]

    # Stop once the model starts a new top-level statement. vLLM's
    # `stop_token_ids` takes a flat list of single token ids, so multi-token
    # stop sequences must go through `stop` as strings. By default the stop
    # string itself is not included in the returned text.
    stop_sequences = ["\nclass", "\ndef", "\n#", "\nif", "\nprint"]
    sampling_params = SamplingParams(temperature=temperature, top_p=top_p,
                                     max_tokens=max_tokens, stop=stop_sequences)

    iter_dir = os.path.join(save_dir, 'sampling_iterations')
    os.makedirs(iter_dir, exist_ok=True)

    outputs_per_sampling_iter = []
    all_results = []
    for sampling_iter in range(unbiased_sampling_size_n):
        print(f"Sampling iter: {sampling_iter} / {unbiased_sampling_size_n}")
        iter_save_path = os.path.join(iter_dir, f'{sampling_iter}.jsonl')
        all_results = []
        for i in tqdm(range(0, len(prompts), batch_size), desc="Batches"):
            batch_prompts = prompts[i: i + batch_size]
            base_output = base_model.generate(batch_prompts, sampling_params)
            for prompt, request_output in zip(batch_prompts, base_output):
                all_results.append(
                    {"inputs": prompt,
                     "output": request_output.outputs[0].text}
                )
        sampling_outputs = [clean_generation(o["output"]) for o in all_results]
        pd.DataFrame({'output': sampling_outputs}).to_json(iter_save_path, lines=True, orient='records')
        outputs_per_sampling_iter.append(sampling_outputs)

    # Interleave so that all samples of one problem are contiguous; this matches
    # the per-problem duplication done by the caller.
    outputs = [
        per_iter[i]
        for i in range(len(prompts))
        for per_iter in outputs_per_sampling_iter
    ]
    return test_data, all_results, outputs

@torch.inference_mode()
def main(*,
         model_name: str = "meta-llama/CodeLlama-7b-hf",
         batch_size: int = 1024,
         temperature: float = 0.1,
         top_p: float = 0.9,
         tensor_parallel_size : int = 1,
         max_num_seqs : int = 256,
         max_tokens : int = 256,
         save_dir: str = "outputs/code"):
    """Generate completions for the ``read_problems`` tasks with vLLM and score pass@k.

    :param model_name: model id or path used for both the tokenizer and the vLLM engine.
    :param batch_size: number of prompts per ``generate`` call.
    :param temperature: sampling temperature.
    :param top_p: nucleus-sampling threshold.
    :param tensor_parallel_size: vLLM tensor-parallel degree.
    :param max_num_seqs: vLLM ``max_num_seqs`` setting.
    :param max_tokens: maximum number of generated tokens per completion.
    :param save_dir: output directory for predictions and pass@k results.
    """
    # Outputs are written below; create the directory up front so a fresh
    # `save_dir` does not fail after generation has finished.
    os.makedirs(save_dir, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = LLM(model=model_name, tensor_parallel_size=tensor_parallel_size,
                     gpu_memory_utilization=0.95, enforce_eager=True, max_num_seqs=max_num_seqs)
    test_data, _all_results, outputs = get_code_output(
        base_model, tokenizer, max_tokens, batch_size, temperature, top_p, save_dir=save_dir)

    # One entry per (problem, sample), problem-major, to line up with `outputs`.
    duplicate_test_data = [
        example for example in test_data for _ in range(unbiased_sampling_size_n)
    ]
    assert len(duplicate_test_data) == len(outputs), (
        f"{len(duplicate_test_data)} problem slots vs {len(outputs)} completions")
    predictions = [
        {"task_id": example["task_id"], "prompt": example["prompt"], "completion": output}
        for example, output in zip(duplicate_test_data, outputs)
    ]
    prediction_save_path = os.path.join(save_dir, "predictions.jsonl")
    write_jsonl(prediction_save_path, predictions)

    pass_at_k_results = evaluate_functional_correctness(
        sample_file=prediction_save_path,
        k=eval_pass_at_ks,
        problems={example["task_id"]: example for example in test_data},
        n_workers=64
    )

    print(pass_at_k_results)

    # Despite the "predictions_" prefix, this file holds the pass@k results
    # dumped as a single JSON object, not the completions.
    results_path = os.path.join(save_dir, f"predictions_{model_name.replace('/', '#')}.jsonl")
    with open(results_path, "w") as fout:
        json.dump(pass_at_k_results, fout)
    
if __name__ == "__main__":
    import bdb
    import pdb
    import sys

    import defopt

    try:
        defopt.run(main)
    except SystemExit:
        # argparse (used by defopt) exits via SystemExit on --help and on bad
        # arguments; re-raise so those keep their exit status instead of opening pdb.
        raise
    except BaseException:
        # Any other failure (including KeyboardInterrupt, as before) opens a
        # post-mortem debugger on the failing traceback.
        exc_type, exc_value, exc_tb = sys.exc_info()
        if exc_type is bdb.BdbQuit:
            # BdbQuit is raised when a debugger session is quit: exit quietly.
            sys.exit()
        print(exc_type, exc_value)
        pdb.post_mortem(exc_tb)
