#!/usr/bin/env python
# coding: utf-8

import json
import os
from tqdm.auto import tqdm
import random
import time
import torch
from argparse import ArgumentParser
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils import ContextManager, extract_diff
from prompting.extract_context import get_gold_filenames


def _resolve_output_file(task_instances, output_dir, continue_from, k):
    """Return the path of the JSONL file that results are appended to.

    When ``continue_from`` is given it is looked up first as-is and then
    relative to ``output_dir``. Otherwise a fresh, timestamped file name is
    derived from the task-instances file name.

    Raises:
        ValueError: if ``continue_from`` is given but cannot be found.
    """
    if continue_from is not None:
        for candidate in (continue_from, os.path.join(output_dir, continue_from)):
            if os.path.exists(candidate):
                print(f"Continuing from {candidate}")
                return candidate
        raise ValueError(f"continue_from {continue_from} not found")
    # NOTE: the ".claude2" tag is kept unchanged so that file names stay
    # compatible with outputs written by earlier runs of this script.
    name = os.path.basename(task_instances) + ".claude2"
    name += f"_k{k}." if k is not None else "."
    name += f"{time.strftime('%Y%m%d_%H%M%S')}.jsonl"
    return os.path.join(output_dir, name)


def _load_completed_ids(output_file):
    """Return the set of instance ids already recorded in ``output_file``."""
    if not os.path.exists(output_file):
        return set()
    with open(output_file, "r") as f:
        # Blank lines are skipped so stray newlines do not abort a resume.
        completed = {json.loads(line)["instance_id"] for line in f if line.strip()}
    print(f"Loaded {len(completed)} outputs from {output_file}")
    return completed


def _build_context_text(cm, instance, instance_id, root_dir):
    """Concatenate README and gold files into a line-numbered text blob.

    Returns ``None`` (after printing a warning) when the instance has no
    README or no gold files. Files that cannot be read are skipped with a
    warning and contribute nothing to the result.
    """
    readmes = sorted(cm.get_readme_files())
    if not readmes:
        print(f"Warning: no README found for {instance_id}")
        return None
    gold_files = sorted(get_gold_filenames(instance))
    if not gold_files:
        print(f"Warning: no gold files found for {instance_id}")
        return None
    parts = []
    for path in [*readmes, *gold_files]:
        try:
            with open(path, "r") as f:
                lines = f.readlines()
        except (OSError, ValueError) as e:
            # ValueError covers UnicodeDecodeError raised for non-text files.
            print(f"Warning: unable to read {path} in {instance_id} ({e})")
            continue
        filename = os.path.relpath(os.path.abspath(path), root_dir)
        parts.append(f"Filepath: {filename}\n")
        # Each line keeps its own newline and gets one more appended, exactly
        # as in the prompts produced by earlier runs.
        parts.extend(f"{ix + 1} {line}\n" for ix, line in enumerate(lines))
        parts.append("\n")
    return "".join(parts)


def main(task_instances, root_dir, api_key, output_dir, continue_from, k, model_path):
    """Generate a patch for every task instance with a local causal LM.

    For each instance a prompt is built from the repository's README files
    and the instance's gold files, the model is sampled once, and a JSON
    record with keys ``instance_id``, ``inputs``, ``output`` and ``diff`` is
    appended to the output file. Instances already present in the output
    file are skipped, so an interrupted run can be resumed.

    Args:
        task_instances: Path to a JSONL file; each record needs at least
            ``instance_id``, ``base_commit`` and ``problem_statement``.
        root_dir: Directory handed to ``ContextManager`` together with the
            instance's base commit.
        api_key: Not used by this function; kept for interface compatibility.
        output_dir: Directory in which a new output file is created.
        continue_from: Existing output file to resume from, or ``None``.
        k: If not ``None``, only ``k`` randomly sampled instances
            (fixed seed 42) are processed.
        model_path: Path or identifier passed to ``from_pretrained``.

    Raises:
        FileNotFoundError: if ``root_dir`` does not exist.
        ValueError: if ``continue_from`` is given but cannot be found.
    """
    root_dir = os.path.abspath(root_dir)
    if not os.path.exists(root_dir):
        # An explicit raise instead of `assert`, which is stripped under -O.
        raise FileNotFoundError(f"root_dir {root_dir} does not exist")
    with open(task_instances, "r") as f:
        instances = [json.loads(line) for line in f if line.strip()]
    random.seed(42)
    if k is not None:
        if k > len(instances):
            print(
                f"Warning: k is larger than the number of instances ({len(instances)}), setting k to {len(instances)}"
            )
        instances = random.sample(instances, k=min(k, len(instances)))
    instances = {instance["instance_id"]: instance for instance in instances}

    output_file = _resolve_output_file(task_instances, output_dir, continue_from, k)
    instance_outputs = _load_completed_ids(output_file)
    # Create the target directory up front so a missing directory is not
    # discovered only after the model has been loaded.
    os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True)

    print(f"Loading tokenizer from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    print(f"Loading model from {model_path}")
    # NOTE(review): older transformers releases spell this keyword
    # `torch_dtype` -- confirm against the pinned version.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        dtype=torch.float16,
        device_map="auto",
    )
    model.eval()

    all_texts = dict()
    failed = list()
    for instance_id, instance in tqdm(
        instances.items(), total=len(instances), desc="Generating prompts", leave=False
    ):
        if instance_id in instance_outputs:
            continue
        try:
            with ContextManager(root_dir, instance["base_commit"]) as cm:
                all_text = _build_context_text(cm, instance, instance_id, root_dir)
            if all_text is not None:
                all_texts[instance_id] = all_text
        except Exception as e:
            # Best effort: record the failure and move on to the next instance.
            print(f"Warning: failed to build prompt for {instance_id}: {e}")
            failed.append(instance_id)

    if all_texts:
        # Whitespace-separated words, used as a rough proxy for token count.
        total_words = sum(len(all_text.split()) for all_text in all_texts.values())
        print(f"Average number of tokens: {total_words / len(all_texts)}")
    else:
        # Happens e.g. when every instance is already in the output file.
        print("No prompts were generated.")
    print(
        f'Failed to generate prompts for {len(failed)} instances:\n {", ".join(failed)}'
    )
    failed_prompts = len(failed)

    # TODO: add repo metadata (README, etc.), environment metadata (libraries
    # and versions) and a file tree generated with the ast module.
    with open(output_file, "a+") as f:
        for instance_id, all_text in tqdm(
            all_texts.items(), total=len(all_texts), desc="Invoking model", leave=True
        ):
            try:
                issue = instances[instance_id]["problem_statement"]
                text_inputs = f"I will provide you with my code base and an issue that I need you to resolve.  <code>\n{all_text}\n</code>\n\n<issue>\n{issue}\n</issue>\n\n Now I need you to help solve this issue by generating a single patch file that I can apply directly to this repository using git apply. Please respond with a single patch file since I'm going to check your first code response.\n\n"
                try:
                    with torch.no_grad():
                        inputs = tokenizer(text_inputs, return_tensors="pt")
                        input_ids = inputs["input_ids"].to(model.device)
                        generated = model.generate(
                            input_ids,
                            attention_mask=inputs["attention_mask"].to(model.device),
                            # `max_length` would also count the prompt, which
                            # here contains whole files; bound only the
                            # completion instead.
                            max_new_tokens=1024,
                            do_sample=True,
                            top_k=50,
                            top_p=0.95,
                            num_return_sequences=1,
                        )
                        # A causal LM returns prompt + completion; keep only
                        # the completion, since the prompt is already stored
                        # under "inputs".
                        outputs = tokenizer.decode(
                            generated[0][input_ids.shape[1] :],
                            skip_special_tokens=True,
                        )
                except Exception as e:
                    print(e)
                    outputs = None
                diff = extract_diff(outputs)
                instance_output = {
                    "instance_id": instance_id,
                    "inputs": text_inputs,
                    "output": outputs,
                    "diff": diff,
                }
                # Flush after every record so a crash loses at most one result.
                print(json.dumps(instance_output), file=f, flush=True)
            except Exception as e:
                print(e)
                failed.append(instance_id)
                continue

    # Entries appended after the prompt phase are the generation failures.
    failed_outputs = failed[failed_prompts:]
    print(
        f'Failed to produce outputs for {len(failed_outputs)} instances:\n {", ".join(failed_outputs)}'
    )
    print(
        f"Failed prompts: {failed_prompts}, Failed outputs: {len(failed_outputs)}"
    )
    print(f"Total failed: {len(failed)} / {len(instances)}")
    print("Done!")


def build_arg_parser():
    """Build the command-line parser; option names match ``main``'s parameters."""
    parser = ArgumentParser()
    parser.add_argument(
        "--task_instances",
        type=str,
        required=True,
        help="Path to the task instances JSONL file.",
    )
    parser.add_argument(
        "--root_dir", type=str, required=True, help="Path to the root directory."
    )
    # `main` never reads the key, so it is accepted for backward compatibility
    # with existing invocations but no longer required.
    parser.add_argument(
        "--api_key",
        type=str,
        default=None,
        help="API key (accepted for compatibility; not used by this script).",
    )
    parser.add_argument(
        "--output_dir", type=str, default="./", help="Path to the output directory."
    )
    parser.add_argument(
        "--continue_from",
        type=str,
        default=None,
        help="Path to the output file to continue from.",
    )
    parser.add_argument(
        "--k", type=int, help="If provided, only use k random instances", default=None
    )
    # `default` used to be passed twice here, which is a SyntaxError; only the
    # concrete checkpoint path is kept.
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to the model shards + index.",
        default="./models/togethercomputer__LLaMA-2-7B-32K",
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(**vars(args))
