import json
import sys
from pathlib import Path
from tqdm.auto import tqdm
from transformers import LlamaTokenizer
from utils import string_to_bool
from prompting.create_instance import add_text_inputs, PROMPT_FUNCTIONS
from argparse import ArgumentParser
import logging

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)


def main(
    instances_file,
    model_name_or_path,
    github_token,
    retrieval_dir,
    prompt_style,
    file_source,
    k,
    output_dir,
    python_only_patch,
    hf_token,
):
    """Tokenize task instances and append them to a JSONL file in ``output_dir``.

    The run is resumable: instances already present in the output file with
    both ``input_ids`` and ``labels`` are skipped, and new results are appended.

    Args:
        instances_file: Path to a JSON-lines file; each record must carry an
            ``instance_id`` and a ``patch`` field.
        model_name_or_path: Tokenizer location handed to
            ``LlamaTokenizer.from_pretrained``.
        github_token: Forwarded unchanged to ``add_text_inputs``.
        retrieval_dir: Directory with retrieval results; forwarded to
            ``add_text_inputs``.
        prompt_style: Prompt style name; forwarded to ``add_text_inputs`` and
            embedded in the output file name.
        file_source: Where the files come from (e.g. ``"oracle"`` or
            ``"bm25"``); forwarded to ``add_text_inputs`` and embedded in the
            output file name.
        k: Optional maximum number of files to use for retrieval; appended to
            the output file name when not ``None``.
        output_dir: Directory that receives the output file. Created if it
            does not exist.
        python_only_patch: Boolean forwarded to ``add_text_inputs`` as
            ``python_only``; embedded in the output file name as ``0``/``1``.
        hf_token: Token passed to ``LlamaTokenizer.from_pretrained``.

    Exits the process with status 0 when every instance is already done.
    """
    # The tokenizer is loaded up front only to obtain its class name for the
    # output file name; it is dropped again before the prompt-building step.
    tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, token=hf_token)
    output_name = (
        f"{Path(instances_file).name}"
        f".ps-{prompt_style}"
        f"__tok-{type(tokenizer).__name__}"
        f"__po-{int(python_only_patch)}"
        f"__fs-{file_source}"
    )
    del tokenizer
    if k is not None:
        output_name += f"__k-{k}"
    output_filename = Path(output_dir, output_name + ".jsonl")

    # Always ensure the output directory exists (the original only called
    # mkdir when the directory was already there, which was a no-op).
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Load the instances with a context manager so the handle is closed, and
    # tolerate blank lines (e.g. a trailing newline) in the JSONL input.
    with open(instances_file, "r", encoding="utf-8") as instances_fh:
        task_instances = [
            json.loads(line) for line in instances_fh if line.strip()
        ]
    input_instances = {
        instance["instance_id"]: instance for instance in task_instances
    }

    existing_ids = set()
    unknown_ids = set()
    if output_filename.exists():
        logger.info("Reading existing file %s", output_filename.as_posix())
        with open(output_filename, encoding="utf-8") as existing_fh:
            for line in existing_fh:
                if not line.strip():
                    continue
                data = json.loads(line)
                instance_id = data["instance_id"]
                is_complete = (
                    instance_id in input_instances
                    and "input_ids" in data
                    and "labels" in data
                )
                if is_complete:
                    existing_ids.add(instance_id)
                else:
                    unknown_ids.add(instance_id)
        logger.info("Read %d already completed ids", len(existing_ids))
        logger.info("Found %d unknown instance ids in file", len(unknown_ids))

    # Drop everything that was already tokenized in a previous run.
    for done_id in existing_ids:
        del input_instances[done_id]
    if not input_instances:
        print("Nothing left to do!")
        sys.exit(0)

    # add_text_inputs is expected to fill in instance["text_inputs"] in place;
    # its return value is not used here.
    add_text_inputs(
        input_instances=input_instances,
        retrieval_dir=retrieval_dir,
        k=k,
        github_token=github_token,
        prompt_style=prompt_style,
        file_source=file_source,
        python_only=python_only_patch,
    )

    tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, token=hf_token)
    with open(output_filename, "a") as out_fh:
        print(f"Writing to {output_filename.as_posix()}")
        for instance_id, instance in tqdm(
            input_instances.items(), total=len(input_instances), desc="Tokenizing"
        ):
            if instance["text_inputs"] is None or instance["patch"] is None:
                print(f"No text for {instance_id}")
                continue
            text_inputs = instance["text_inputs"].strip() + "\n\n"
            # Cheap whitespace-word count guard before running the tokenizer.
            strlength = len(text_inputs.split())
            if strlength > 200_000:
                print(
                    f"Skipping {instance_id}. Way too long {strlength:_} word length (not token)"
                )
                continue
            patch = "\n".join(
                ["<patch>", instance["patch"], "</patch>", tokenizer.eos_token]
            )
            instance["input_ids"] = tokenizer(
                text_inputs, add_special_tokens=False, return_attention_mask=False
            )["input_ids"]
            instance["labels"] = tokenizer(
                patch, add_special_tokens=False, return_attention_mask=False
            )["input_ids"]
            # Flush after every record so an interrupted run can be resumed.
            print(json.dumps(instance), file=out_fh, flush=True)
    print(f"Wrote to {output_filename.as_posix()}")


if __name__ == "__main__":
    # Command-line entry point: every option maps 1:1 onto a parameter of
    # main(), so the parsed namespace is passed through as keyword arguments.
    parser = ArgumentParser()
    parser.add_argument(
        "--instances_file",
        required=True,
        type=str,
        help="Path to the JSON-lines file of task instances.",
    )
    parser.add_argument(
        "--model_name_or_path", type=str, required=True, help="Path to the model."
    )
    parser.add_argument("--github_token", type=str, help="GitHub token.")
    parser.add_argument(
        "--retrieval_dir",
        type=str,
        help="Path to the directory where the retrieval results are stored.",
    )
    parser.add_argument(
        "--prompt_style",
        type=str,
        default="style-1",
        choices=PROMPT_FUNCTIONS.keys(),
        help="Prompt style to use.",
    )
    parser.add_argument(
        "--file_source",
        type=str,
        default="oracle",
        choices=["oracle", "bm25"],
        help="Where to get the files from.",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=None,
        help="Maximum number of files to use for retrieval.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        required=True,
        help="Path to the output directory.",
    )
    # nargs="?" with const=True lets a bare `--python_only_patch` mean True,
    # while an explicit value is parsed by string_to_bool.
    parser.add_argument(
        "--python_only_patch",
        type=string_to_bool,
        default=None,
        required=True,
        const=True,
        nargs="?",
        help=(
            "Boolean passed to prompt creation as `python_only`; "
            "a bare flag means true."
        ),
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        default=None,
        help="Hugging Face token passed to the tokenizer loader.",
    )
    args = parser.parse_args()
    main(**vars(args))
