#!/usr/bin/env python
# coding: utf-8

import json
import os
from tqdm.auto import tqdm
import random
import time
from argparse import ArgumentParser
from anthropic import HUMAN_PROMPT, AI_PROMPT, Anthropic
from utils import ContextManager, extract_diff
from prompting.extract_context import get_gold_filenames


# One client per API key, reused across calls instead of constructing a fresh
# client object for every request.
_ANTHROPIC_CLIENTS = {}


def _get_anthropic_client(api_key):
    """Return the cached ``Anthropic`` client for *api_key*, creating it on first use."""
    client = _ANTHROPIC_CLIENTS.get(api_key)
    if client is None:
        client = Anthropic(api_key=api_key)
        _ANTHROPIC_CLIENTS[api_key] = client
    return client


def anthropic_call(api_key, model_name, inputs, history, *args, **kwargs):
    """Send one legacy Text Completions request and return the completion text.

    Args:
        api_key: Anthropic API key.
        model_name: Model identifier, e.g. ``"claude-2"``.
        inputs: Full prompt string, already formatted with the
            ``HUMAN_PROMPT`` / ``AI_PROMPT`` turn markers.
        history: Unused; kept so existing call sites keep working.
        *args: Not supported. The SDK's ``completions.create`` takes
            keyword-only parameters, so positional extras raise ``TypeError``.
        **kwargs: Extra keyword arguments forwarded to
            ``completions.create`` (e.g. ``temperature``). May override
            ``max_tokens_to_sample``, which defaults to 2000.

    Returns:
        The ``completion`` attribute of the API response.

    Raises:
        TypeError: If positional extras are passed via ``*args``.
    """
    if args:
        raise TypeError(
            "anthropic_call() forwards options to completions.create(), which "
            "accepts keyword arguments only; pass them as keywords "
            f"(got {len(args)} positional extra argument(s))"
        )
    kwargs.setdefault("max_tokens_to_sample", 2000)
    completion = _get_anthropic_client(api_key).completions.create(
        model=model_name,
        prompt=inputs,
        **kwargs,
    )
    return completion.completion



def _resolve_output_file(task_instances, output_dir, continue_from, k):
    """Return the JSONL path results are appended to (an existing file when resuming).

    Raises:
        ValueError: If ``continue_from`` is given but exists neither as-is nor
            inside ``output_dir``.
    """
    if continue_from is not None:
        if os.path.exists(continue_from):
            output_file = continue_from
        elif os.path.exists(os.path.join(output_dir, continue_from)):
            output_file = os.path.join(output_dir, continue_from)
        else:
            raise ValueError(f"continue_from {continue_from} not found")
        print(f"Continuing from {output_file}")
        return output_file
    # <output_dir>/<task file name>.claude2[_k<k>].<YYYYmmdd_HHMMSS>.jsonl
    output_file = os.path.join(output_dir, os.path.basename(task_instances) + ".claude2")
    output_file += f"_k{k}." if k is not None else "."
    output_file += f"{time.strftime('%Y%m%d_%H%M%S')}.jsonl"
    return output_file


def _load_completed_ids(output_file):
    """Return the ``instance_id`` of every record already in *output_file*.

    Returns an empty set when the file does not exist yet.
    """
    completed = set()
    if not os.path.exists(output_file):
        return completed
    with open(output_file, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():  # tolerate blank lines
                completed.add(json.loads(line)["instance_id"])
    print(f"Loaded {len(completed)} outputs from {output_file}")
    return completed


def _format_numbered_file(path, root_dir):
    """Render *path* as a ``Filepath:`` header followed by 1-based numbered lines.

    The whole file is read before anything is returned, so a read/decode
    failure never leaves a partial listing behind.

    Raises:
        OSError: If the file cannot be opened or read.
        ValueError: On a decode error (``UnicodeDecodeError``) or if
            ``os.path.relpath`` cannot relate the path to ``root_dir``.
    """
    with open(path, "r", encoding="utf-8") as f:
        # Drop the newline the file iterator keeps so every numbered line ends
        # with exactly one newline.
        lines = [line.rstrip("\n") for line in f]
    relative = os.path.relpath(os.path.abspath(path), root_dir)
    parts = [f"Filepath: {relative}\n"]
    parts.extend(f"{number} {line}\n" for number, line in enumerate(lines, start=1))
    parts.append("\n")
    return "".join(parts)


def _build_code_context(cm, instance_id, instance, root_dir):
    """Concatenate the numbered README and gold files for one instance.

    Must be called while ``cm`` (a ``ContextManager``) is active. Returns
    ``None`` when the instance should be skipped because it has no README or
    no gold files. Unreadable individual files are reported and left out.
    """
    readmes = sorted(cm.get_readme_files())
    if not readmes:
        print(f"Warning: no README found for {instance_id}")
        return None
    gold_files = sorted(get_gold_filenames(instance))
    if not gold_files:
        print(f"Warning: no gold files found for {instance_id}")
        return None
    sections = []
    for path in [*readmes, *gold_files]:
        try:
            sections.append(_format_numbered_file(path, root_dir))
        except (OSError, ValueError):
            print(f"Warning: unable to read {path} in {instance_id}")
    return "".join(sections)


def _build_prompt(code_context, issue):
    """Wrap the code context and issue text in the Human/Assistant completion prompt."""
    # TODO: add repo metadata, environment metadata (libraries and versions)
    # and a file tree (e.g. generated with the ast module) to the prompt.
    return (
        f"{HUMAN_PROMPT} I will provide you with my code base and an issue that I need you to resolve.  <code>\n"
        f"{code_context}\n</code>\n\n"
        f"<issue>\n{issue}\n</issue>\n\n"
        f"{HUMAN_PROMPT} Now I need you to help solve this issue by generating a single patch file "
        "that I can apply directly to this repository using git apply. "
        "Please respond with a single patch file since I'm going to check your first code response.\n\n"
        f"{AI_PROMPT}"
    )


def main(task_instances, root_dir, api_key, output_dir, continue_from, k):
    """Build code-context prompts for task instances, query Claude 2, and append results.

    Each JSONL output record has the keys ``instance_id``, ``inputs``,
    ``output`` and ``diff``. Instances already recorded in the output file are
    skipped, so an interrupted run can be resumed via ``continue_from``.

    Args:
        task_instances: Path to a JSONL file of task instances. Each record
            is read for ``instance_id``, ``base_commit`` and
            ``problem_statement``, and is passed to ``get_gold_filenames``.
        root_dir: Repository directory handed to ``ContextManager``
            (presumably where ``base_commit`` gets checked out).
        api_key: Anthropic API key.
        output_dir: Directory for a newly named output file; also searched
            for ``continue_from``.
        continue_from: Existing output file to resume, or ``None`` to start a
            fresh timestamped file.
        k: If not ``None``, sample ``k`` instances (capped at the number
            available) after seeding ``random`` with 42.

    Raises:
        ValueError: If ``root_dir`` or ``continue_from`` does not exist.
    """
    root_dir = os.path.abspath(root_dir)
    if not os.path.exists(root_dir):
        raise ValueError(f"root_dir {root_dir} not found")
    with open(task_instances, "r", encoding="utf-8") as f:
        instances = [json.loads(line) for line in f if line.strip()]
    random.seed(42)
    if k is not None:
        if k > len(instances):
            print(f"Warning: k is larger than the number of instances ({len(instances)}), setting k to {len(instances)}")
        instances = random.sample(instances, k=min(k, len(instances)))
    instances = {instance["instance_id"]: instance for instance in instances}

    output_file = _resolve_output_file(task_instances, output_dir, continue_from, k)
    instance_outputs = _load_completed_ids(output_file)

    # Phase 1: build the code context for every instance not yet processed.
    all_texts = {}
    failed = []
    for instance_id, instance in tqdm(
        instances.items(), total=len(instances), desc="Generating prompts", leave=False
    ):
        if instance_id in instance_outputs:
            continue
        try:
            with ContextManager(root_dir, instance["base_commit"]) as cm:
                code_context = _build_code_context(cm, instance_id, instance, root_dir)
        except Exception as e:
            print(f"Warning: failed to build prompt for {instance_id}: {e}")
            failed.append(instance_id)
            continue
        if code_context is not None:
            all_texts[instance_id] = code_context

    if all_texts:
        # Whitespace-separated word count, used as a rough token estimate.
        lengths = [len(text.split()) for text in all_texts.values()]
        print(f"Average number of tokens: {sum(lengths) / len(lengths)}")
    else:
        print("No new prompts to run.")
    print(f'Failed to generate prompts for {len(failed)} instances:\n {", ".join(failed)}')
    failed_prompts = len(failed)

    # Phase 2: query the model; each record is flushed immediately so an
    # interrupted run keeps everything finished so far.
    with open(output_file, "a", encoding="utf-8") as f:
        for instance_id, code_context in tqdm(
            all_texts.items(), total=len(all_texts), desc="Invoking model", leave=True
        ):
            try:
                text_inputs = _build_prompt(
                    code_context, instances[instance_id]["problem_statement"]
                )
                try:
                    outputs = anthropic_call(api_key, "claude-2", text_inputs, None)
                except Exception as e:
                    print(e)
                    outputs = None
                diff = extract_diff(outputs)
                instance_output = {
                    "instance_id": instance_id,
                    "inputs": text_inputs,
                    "output": outputs,
                    "diff": diff,
                }
                print(json.dumps(instance_output), file=f, flush=True)
            except Exception as e:
                print(e)
                failed.append(instance_id)

    output_failures = failed[failed_prompts:]
    print(f'Failed to produce outputs for {len(output_failures)} instances:\n {", ".join(output_failures)}')
    print(f"Failed prompts: {failed_prompts}, Failed outputs: {len(output_failures)}")
    print(f"Total failed: {len(failed)} / {len(instances)}")
    print("Done!")


if __name__ == "__main__":
    # Command-line entry point: parse arguments and forward them to main().
    parser = ArgumentParser()
    parser.add_argument(
        "--task_instances",
        type=str,
        required=True,
        help="Path to the task instances JSONL file.",
    )
    parser.add_argument(
        "--root_dir", type=str, required=True, help="Path to the root directory."
    )
    # Fall back to the environment so the key need not be typed on the command
    # line, where it ends up in shell history and process listings.
    parser.add_argument(
        "--api_key",
        type=str,
        default=os.environ.get("ANTHROPIC_API_KEY"),
        help="Anthropic API key (defaults to the ANTHROPIC_API_KEY environment variable).",
    )
    parser.add_argument(
        "--output_dir", type=str, default="./", help="Path to the output directory."
    )
    parser.add_argument(
        "--continue_from",
        type=str,
        default=None,
        help="Path to the output file to continue from.",
    )
    parser.add_argument(
        "--k", type=int, default=None, help="If provided, only use k random instances"
    )
    args = parser.parse_args()
    if not args.api_key:
        parser.error("--api_key is required unless ANTHROPIC_API_KEY is set")
    main(**vars(args))
