import os
import json
import fire
import tiktoken
from data.utils import load_data, construct_trainable_data
from tqdm import tqdm
from model.codex import query_with_retry, get_frist_line


def main(
        language: str = "java",
        length: str = "2k",
        max_new_tokens: int = 64,
        resume_part: str = "cross_file_first",
        resume: int = 0,
        ):
    """Query the completion model on every test sample and append results to JSONL files.

    The three dataset partitions ("cross_file_first", "cross_file_random",
    "in_file") are processed in that order.  For each sample the prompt is
    sent through ``query_with_retry``, the first line of the response is
    extracted with ``get_frist_line``, and one JSON record is appended to
    ``./results/codex-{length}/{language}/{partition}.jsonl``.

    Args:
        language: Language passed to the dataset loader; also selects the
            comment prefix ("#" for python, "//" for java) used when a prompt
            has to be shortened.
        length: Dataset length tag passed to ``load_data`` and used in the
            output directory name.
        max_new_tokens: Upper bound on the number of tokens requested per
            sample.  It is never exceeded and is re-evaluated for every
            sample independently.
        resume_part: Partition to resume from; partitions listed before it
            are skipped.  A value that is not a known partition skips nothing.
        resume: Number of samples to skip at the start of ``resume_part``.
    """
    # Token limit (prompt + completion) this script budgets against.
    context_limit = 8001
    # Below this much room for generation we try to shorten the prompt.
    min_new_tokens = 32

    # load datasets
    settings = ["cross_file_first", "cross_file_random", "in_file"]
    datasets = load_data(split="test", task="completion", language=language, length=length, settings=settings)

    # tokenizer
    tokenizer = tiktoken.encoding_for_model("code-davinci-002")

    def count_tokens(text):
        # disallowed_special=() makes special-token text (e.g. "<|endoftext|>")
        # inside source code encode as ordinary text instead of raising.
        return len(tokenizer.encode(text, disallowed_special=()))

    # The output directory only depends on the arguments, so create it once.
    directory = f"./results/codex-{length}/{language}/"
    os.makedirs(directory, exist_ok=True)

    # Partitions that come before the resume partition were already completed.
    first_part = settings.index(resume_part) if resume_part in settings else 0

    # now we can sample
    for part_idx, (data_part, dataset) in enumerate(zip(settings, datasets)):
        if part_idx < first_part:
            continue

        dataset = construct_trainable_data(dataset, language=language)

        # Only the resume partition starts part-way through.
        initial = resume if data_part == resume_part else 0
        if data_part == resume_part:
            dataset = dataset[resume:]

        output_path = f"{directory}/{data_part}.jsonl"

        # enumerate(start=initial) keeps data_idx aligned with the sample's
        # position in the full partition.
        progress = tqdm(
            enumerate(dataset, start=initial),
            total=len(dataset) + initial,
            initial=initial,
        )
        for data_idx, data in progress:

            # load data
            prompt = data['data']
            label = data['label']

            # length of the prompt as given (this is the value that is recorded)
            prompt_length = count_tokens(prompt)

            # Per-sample budget: a long prompt must not shrink the budget of
            # the samples that follow it.
            room = context_limit - prompt_length
            wanted = min(min_new_tokens, max_new_tokens)
            if room < wanted:
                # Not enough room to generate: free some by dropping leading
                # comment lines, then measure the shortened prompt for real.
                prompt = _drop_leading_comments(prompt, language, wanted - room, count_tokens)
                room = context_limit - count_tokens(prompt)
            new_tokens = min(room, max_new_tokens)
            # NOTE(review): when no comment lines can be dropped, new_tokens can
            # still be <= 0 here; how query_with_retry handles that is not
            # visible from this file.

            # query
            response = query_with_retry(prompt, max_tokens=new_tokens)

            # get first line
            pred = get_frist_line(response, language=language)

            # save
            res_dic = {
                "data_idx": data_idx,
                "label": label,
                "generated": pred,
                "prompt_length": prompt_length,
            }

            # Re-open in append mode for every record so each result is on
            # disk before the next query starts; this is what makes --resume
            # usable after an interruption.
            with open(output_path, "a") as f:
                f.write(json.dumps(res_dic) + "\n")


def _drop_leading_comments(prompt, language, tokens_needed, count_tokens):
    """Return ``prompt`` with randomly chosen leading comment lines removed.

    Only the unbroken run of comment lines at the very start of the prompt is
    eligible.  Lines are removed in random order until the removed lines add
    up to at least ``tokens_needed`` tokens or no eligible line is left.

    Args:
        prompt: Full prompt text.
        language: "python" or "java"; any other value leaves the prompt
            unchanged because no comment prefix is known for it.
        tokens_needed: Number of tokens that should be freed.
        count_tokens: Callable mapping a string to its token count.
    """
    import random

    comment_prefix = {"python": "#", "java": "//"}.get(language)
    lines = prompt.split("\n")

    # Collect the indices of the comment lines that open the prompt.
    candidates = []
    if comment_prefix is not None:
        for idx, line in enumerate(lines):
            if not line.startswith(comment_prefix):
                break
            candidates.append(idx)

    random.shuffle(candidates)

    dropped = set()
    for idx in candidates:
        if tokens_needed <= 0:
            break
        dropped.add(idx)
        tokens_needed -= count_tokens(lines[idx])

    return "\n".join(line for idx, line in enumerate(lines) if idx not in dropped)

# Script entry point: hand main() to python-fire, which builds the
# command-line interface from main's parameters.
if __name__ == "__main__":
    fire.Fire(main)
