#!/usr/bin/env python
# coding: utf-8

import json
import os
from tqdm.auto import tqdm
import random
import time
import torch
from argparse import ArgumentParser
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils import ContextManager, extract_diff
from prompting.extract_context import get_gold_filenames
from prompting.summarize_context import get_summarized_files

import logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
logger = logging.getLogger(__name__)


def _resolve_output_file(task_instances, output_dir, continue_from, k, model_path):
    """Return the path of the JSONL file that results are appended to.

    When ``continue_from`` is given it is looked up first as-is and then
    relative to ``output_dir``; otherwise a new name is derived from the task
    file stem, the model path, the optional sample size ``k`` and a timestamp.

    Raises:
        ValueError: if ``continue_from`` is given but cannot be found.
    """
    if continue_from is not None:
        if os.path.exists(continue_from):
            output_file = continue_from
        elif os.path.exists(os.path.join(output_dir, continue_from)):
            output_file = os.path.join(output_dir, continue_from)
        else:
            raise ValueError(f"continue_from {continue_from} not found")
        logger.info(f"Continuing from {output_file}")
        return output_file
    # NOTE: the naming scheme (including the "." before the timestamp part) is
    # kept exactly as before so that existing tooling still finds the files.
    name = task_instances.split("/")[-1].rsplit(".", 1)[0]
    name += "=" + model_path.replace("/", "__")
    name += f"=k{k}." if k is not None else "."
    name += f"={time.strftime('%Y%m%d_%H%M%S')}.jsonl"
    return os.path.join(output_dir, name)


def _load_completed_ids(output_file):
    """Return the set of ``instance_id`` values already present in ``output_file``."""
    completed = set()
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():  # tolerate blank lines in the JSONL file
                    completed.add(json.loads(line)["instance_id"])
        logger.info(f"Loaded {len(completed)} outputs from {output_file}")
    return completed


def _collect_context_text(instance_id, instance, root_dir):
    """Build the code-context part of the prompt for one instance.

    The text is made of every readable README (with 1-based line numbers)
    followed by every summarized file, each preceded by a ``Filepath:`` header.

    Returns:
        The context string, or ``None`` when the instance has no README or no
        summarized files (such instances are skipped, not counted as failures).
    """
    # Files are read while the ContextManager is active (presumably it checks
    # out ``base_commit`` in ``root_dir`` -- confirm in utils.ContextManager).
    with ContextManager(root_dir, instance["base_commit"]) as cm:
        readmes = sorted(cm.get_readme_files())
        if len(readmes) == 0:
            logger.warning(f"Warning: no README found for {instance_id}")
            return None
        summarized_files = get_summarized_files(instance, root_dir)
        if len(summarized_files) == 0:
            logger.warning(f"Warning: no gold files found for {instance_id}")
            return None
        sections = []
        for readme in readmes:
            try:
                with open(readme, "r", encoding="utf-8") as f:
                    numbered = "".join(
                        f"{ix + 1} {line}" for ix, line in enumerate(f)
                    )
            except (OSError, UnicodeDecodeError):
                # Best effort: an unreadable README is skipped as a whole, so
                # no half-written section ends up in the prompt.
                logger.warning(f"Warning: unable to read {readme} in {instance_id}")
                continue
            filename = os.path.abspath(readme)[len(root_dir) + 1 :]
            sections.append(f"Filepath: {filename}\n{numbered}\n")
        for filename in summarized_files:
            # Append (not overwrite) so earlier READMEs and files are kept.
            sections.append(f"Filepath: {filename}\n{summarized_files[filename]}\n")
        return "".join(sections)


def _build_prompt(all_text, issue):
    """Return the full model prompt for the given code context and issue text."""
    # The closing tags are </issue> and </patch>, and the "\\n" inside the
    # example printf calls stays a literal backslash-n so that the example
    # patch is well-formed.
    return (
        "I will provide you with my code base and an issue that I need you to resolve.\n"
        f"<code>\n{all_text}\n</code>\n\n"
        f"<issue>\n{issue}\n</issue>\n\n"
        " Now I need you to help solve this issue by generating a single patch file"
        " that I can apply directly to this repository using git apply. Please respond"
        " with a single patch file in the following format since I'm going to check"
        " your first code response.\n\n"
        "<patch>\n"
        "--- /path/to/original/file.c\n"
        "+++ /path/to/new/file.c\n"
        "@@ -1,5 +1,6 @@\n"
        " #include <stdio.h>\n"
        "\n"
        "\n"
        " int main() {\n"
        '-    printf("Hello, world!\\n");\n'
        '+    printf("Hello, wonderful world!\\n");\n'
        "+    return 0;\n"
        " }\n"
        "</patch>\n\n"
    )


def _generate_diffs(model, tokenizer, text_inputs, max_generation_length):
    """Sample one completion for ``text_inputs`` and extract its diff.

    Returns:
        ``(outputs, diffs)`` where ``outputs`` is the list of decoded
        completions (prompt tokens stripped) and ``diffs`` is the result of
        ``extract_diff`` applied to each of them, or ``(None, None)`` when the
        tokenized prompt leaves no room for ``max_generation_length`` tokens.
    """
    with torch.no_grad():
        inputs = tokenizer(text_inputs, return_tensors="pt")
        input_length = inputs["input_ids"].shape[1]
        max_input_length = tokenizer.model_max_length - max_generation_length
        if input_length > max_input_length:
            logger.warning(
                f"Warning: input length {input_length} is longer than max input length {max_input_length}. Skipping."
            )
            return None, None
        generated = model.generate(
            inputs["input_ids"].to(model.device),
            attention_mask=inputs["attention_mask"].to(model.device),
            max_new_tokens=max_generation_length,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=1,
        )
        # Drop the echoed prompt tokens so only the completion is decoded.
        completions = [sequence[input_length:] for sequence in generated]
        outputs = tokenizer.batch_decode(completions, skip_special_tokens=True)
    return outputs, list(map(extract_diff, outputs))


def main(task_instances, root_dir, output_dir, continue_from, k, model_path):
    """Generate patches for task instances with a local causal language model.

    Reads instances from a JSONL file, builds one prompt per instance from the
    repository READMEs and summarized files, samples a completion from the
    model at ``model_path`` and appends one JSON record per instance
    (``instance_id``, ``inputs``, ``output``, ``diff``) to the output file.
    Instances already present in the output file are skipped.

    Args:
        task_instances: Path to the JSONL file with the task instances.
        root_dir: Directory passed to ``ContextManager`` and
            ``get_summarized_files``; must exist.
        output_dir: Directory in which a new output file is created.
        continue_from: Existing output file to resume from, or ``None``.
        k: If not ``None``, only ``k`` randomly sampled instances (fixed seed)
            are processed.
        model_path: Path or identifier handed to ``from_pretrained``.

    Raises:
        FileNotFoundError: if ``root_dir`` does not exist.
        ValueError: if ``continue_from`` is given but cannot be found.
    """
    root_dir = os.path.abspath(root_dir)
    if not os.path.exists(root_dir):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise FileNotFoundError(f"root_dir {root_dir} does not exist")
    with open(task_instances, "r", encoding="utf-8") as f:
        instances = [json.loads(line) for line in f if line.strip()]
    random.seed(42)
    if k is not None:
        if k > len(instances):
            logger.warning(
                f"Warning: k is larger than the number of instances ({len(instances)}), setting k to {len(instances)}"
            )
        instances = random.sample(instances, k=min(k, len(instances)))
    instances = {instance["instance_id"]: instance for instance in instances}

    output_file = _resolve_output_file(
        task_instances, output_dir, continue_from, k, model_path
    )
    instance_outputs = _load_completed_ids(output_file)

    logger.info(f"Loading tokenizer from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    logger.info(f"Loading model from {model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    logger.info(f"Device map: {model.hf_device_map}")

    all_texts = dict()
    failed_prompts = list()
    for instance_id, instance in tqdm(
        instances.items(), total=len(instances), desc="Generating prompts", leave=False
    ):
        if instance_id in instance_outputs:
            continue
        try:
            all_text = _collect_context_text(instance_id, instance, root_dir)
        except Exception as e:
            # Keep going with the remaining instances, but say what went wrong.
            logger.error(f"Error generating prompt for {instance_id}: {str(e)}")
            failed_prompts.append(instance_id)
            continue
        if all_text is not None:
            all_texts[instance_id] = all_text

    # Whitespace-separated word counts, used as a rough size estimate.
    all_text_lengths = {
        instance_id: len(all_text.split())
        for instance_id, all_text in all_texts.items()
    }
    if all_text_lengths:
        logger.info(
            f"Average number of tokens: {sum(all_text_lengths.values()) / len(all_text_lengths)}"
        )
    else:
        logger.info("No prompts were generated.")
    logger.info(
        f'Failed to generate prompts for {len(failed_prompts)} instances:\n {", ".join(failed_prompts)}'
    )

    # TODO: add repo metadata, environment metadata (libraries and versions)
    # and a file tree of the repository to the prompt.
    max_generation_length = 1024
    failed_instances = list()
    with open(output_file, "a+") as f:
        for instance_id, all_text in tqdm(
            all_texts.items(), total=len(all_texts), desc="Invoking model", leave=True
        ):
            try:
                issue = instances[instance_id]["problem_statement"]
                text_inputs = _build_prompt(all_text, issue)
                try:
                    outputs, diff = _generate_diffs(
                        model, tokenizer, text_inputs, max_generation_length
                    )
                except Exception as e:
                    logger.error(f"Error generating output for {instance_id}: {str(e)}")
                    outputs, diff = None, None
                # A record is written even on failure so that a later
                # --continue_from run does not retry this instance.
                instance_output = {
                    "instance_id": instance_id,
                    "inputs": text_inputs,
                    "output": outputs,
                    "diff": diff,
                }
                print(json.dumps(instance_output), file=f, flush=True)
            except Exception as e:
                logger.error(f"Error generating output for {instance_id}: {str(e)}")
                failed_instances.append(instance_id)
                continue
            if outputs is None:
                failed_instances.append(instance_id)

    logger.info(
        f'Failed to produce outputs for {len(failed_instances)} instances:\n {", ".join(failed_instances)}'
    )
    logger.info(
        f"Failed prompts: {failed_prompts}, Failed outputs: {failed_instances}"
    )
    logger.info(f"Total failed: {len(failed_prompts) + len(failed_instances)} / {len(instances)}")
    logger.info("Done!")


if __name__ == "__main__":
    # Command-line entry point: declare the options in a table, register them
    # in order, then forward the parsed values to main() by keyword.
    cli_options = (
        (
            "--task_instances",
            dict(type=str, required=True, help="Path to the task instances JSONL file."),
        ),
        (
            "--root_dir",
            dict(type=str, required=True, help="Path to the root directory."),
        ),
        (
            "--output_dir",
            dict(type=str, default="./", help="Path to the output directory."),
        ),
        (
            "--continue_from",
            dict(type=str, default=None, help="Path to the output file to continue from."),
        ),
        (
            "--k",
            dict(type=int, default=None, help="If provided, only use k random instances"),
        ),
        (
            "--model_path",
            dict(
                type=str,
                default="./models/togethercomputer__LLaMA-2-7B-32K",
                help="Path to the model shards + index.",
            ),
        ),
    )
    cli = ArgumentParser()
    for flag, settings in cli_options:
        cli.add_argument(flag, **settings)
    parsed = cli.parse_args()
    main(
        task_instances=parsed.task_instances,
        root_dir=parsed.root_dir,
        output_dir=parsed.output_dir,
        continue_from=parsed.continue_from,
        k=parsed.k,
        model_path=parsed.model_path,
    )
