#!/usr/bin/env python
# coding: utf-8

import json
import os
import chardet
from tqdm.auto import tqdm
import random
import argparse
from anthropic import HUMAN_PROMPT, AI_PROMPT
from utils import anthropic_call, AutoContextManager, ingest_directory_contents
from tempfile import TemporaryDirectory
import re


def detect_encoding(filename):
    """
    Guess the text encoding of a file.

    The whole file is read in binary mode and handed to ``chardet.detect``;
    the value stored under the ``"encoding"`` key of its result is returned.
    """
    with open(filename, "rb") as handle:
        guess = chardet.detect(handle.read())
    return guess["encoding"]


def extract_diff(response):
    """
    Extracts the diff from a response formatted in different ways.

    Candidates are tried in this order, and the first hit is returned:

    1. The first fenced code block tagged ``diff`` or ``patch``.
    2. The first fenced code block with any other (or no) language tag.
    3. The first ``<patch>...</patch>`` element.
    4. The first ``<tag>...</tag>`` element with any other tag name.
    5. The raw response, truncated at the first ``</s>`` marker.

    Args:
        response: The model output as a string, or None.

    Returns:
        The extracted text (without fences/tags), or None when ``response``
        is None.
    """
    if response is None:
        return None

    # Fenced blocks: ```lang\n ... ``` (the language tag is optional and is
    # reported as "" by findall when absent).
    fenced = re.findall(r"```(\w+)?\n(.*?)```", response, re.DOTALL)
    for lang, body in fenced:
        if lang in {"diff", "patch"}:
            return body
    if fenced:
        # No diff/patch fence at all, so the first fence is the best guess.
        return fenced[0][1]

    # XML-style elements: <tag> ... </tag> with a matching closing tag.
    tagged = re.findall(r"\<([\w-]+)\>(.*?)\<\/\1\>", response, re.DOTALL)
    for tag, body in tagged:
        if tag == "patch":
            # Return the FIRST <patch> element, consistent with the fenced
            # case above (previously the last one won when several existed).
            return body
    if tagged:
        return tagged[0][1]

    return response.split("</s>")[0]


def main(task_instances, k, root_dir, api_key, output_file):
    """
    Generate a model patch for each task instance and dump the results to JSON.

    For every instance, the repository is materialised through
    ``AutoContextManager``, its files are concatenated into one text blob
    (skipping any path containing ``"test"``), and that blob plus the issue
    text is sent via ``anthropic_call``. The diff is pulled out of the reply
    with ``extract_diff``.

    Args:
        task_instances: Path to a JSONL file, one instance per line. Each
            instance must provide ``instance_id`` and ``problem_statement``.
        k: If truthy, number of instances to sample (seeded with 42);
            otherwise all instances are used.
        root_dir: Must exist. NOTE(review): the value is currently only
            validated; repositories are set up in a temporary directory
            instead -- confirm whether that is intended.
        api_key: Passed through to ``anthropic_call``.
        output_file: Path of the JSON file to write. The file maps each
            instance id to its inputs, raw output and extracted diff.
    """
    root_dir = os.path.abspath(root_dir)
    assert os.path.exists(root_dir)
    with open(task_instances) as instances_file:
        instances = [json.loads(line) for line in instances_file]
    random.seed(42)
    if k:
        instances = random.sample(instances, k=k)
    instances = {instance["instance_id"]: instance for instance in instances}

    # The temporary directory replaces the user-supplied root_dir (as the
    # original code did) and is removed as soon as ingestion is finished.
    # assumes ingest_directory_contents returns in-memory text keyed by
    # file path, so nothing below needs the files on disk -- TODO confirm.
    contents = dict()
    with TemporaryDirectory() as root_dir:
        for instance_id, instance in instances.items():
            with AutoContextManager(instance, root_dir) as repo:
                contents[instance_id] = ingest_directory_contents(repo.repo_path)

    all_texts = dict()
    for instance_id, root_contents in contents.items():
        parts = []
        for filepath in sorted(root_contents):
            # Substring match: drops any path that mentions "test" anywhere.
            if "test" in filepath:
                continue
            parts.append("filepath: " + filepath + "\n" + root_contents[filepath] + "\n\n")
        all_texts[instance_id] = "".join(parts)

    # Whitespace-separated word count is used as a rough token estimate.
    all_text_lengths = {instance_id: len(all_text.split()) for instance_id, all_text in all_texts.items()}
    if all_text_lengths:  # avoid ZeroDivisionError when there are no instances
        print(f"Average number of tokens: {sum(all_text_lengths.values()) / len(all_text_lengths)}")

    instance_outputs = dict()
    for instance_id, all_text in tqdm(all_texts.items(), total=len(all_texts)):
        issue = instances[instance_id]["problem_statement"]
        # TODO: add repo metadata, like README, etc.
        # TODO: add environment metadata, like libraries and versions
        # TODO: add file tree from repo generated by python module ast
        text_inputs = (
            f"{HUMAN_PROMPT} I will provide you with my code base and an issue that I need you to resolve.  "
            f"<code>\n{all_text}\n</code>\n\n<issue>\n{issue}\n</issue>\n\n"
            f"{HUMAN_PROMPT} Now I need you to help solve this issue by generating a single patch file that "
            f"I can apply directly to this repository using git apply. Please respond with a single patch file "
            f"since I'm going to check your first code response.\n\n{AI_PROMPT}"
        )
        try:
            outputs = anthropic_call(api_key, "claude-2", text_inputs, None)
        except Exception as e:
            # Best effort: a failed call is reported and recorded as None so
            # the remaining instances are still processed.
            print(e)
            outputs = None
        diff = extract_diff(outputs)
        instance_outputs[instance_id] = {
            "instance_id": instance_id,
            "text_inputs": text_inputs,
            "all_text": all_text,
            "issue": issue,
            "output": outputs,
            "diff": diff,
        }
    with open(output_file, "w") as out_file:
        json.dump(instance_outputs, out_file)
    print("Done!")


if __name__ == "__main__":
    # Command-line entry point: every parsed option is forwarded to main()
    # as a keyword argument, so option names must match main's parameters.
    cli = argparse.ArgumentParser()
    option_specs = (
        (
            "--task-instances",
            dict(type=str, required=True, help="Path to the task instances JSONL file."),
        ),
        (
            "--k",
            dict(type=int, help="Number of instances to sample randomly."),
        ),
        (
            "--root-dir",
            dict(type=str, required=True, help="Path to the root directory."),
        ),
        (
            "--api-key",
            dict(type=str, required=True, help="Anthropic API key."),
        ),
        (
            "--output-file",
            dict(type=str, required=True, help="Path to output JSON file."),
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    main(**vars(cli.parse_args()))
