from pathlib import Path
import json
import re
import traceback
from copy import deepcopy
from tqdm.auto import tqdm
from argparse import ArgumentParser
import logging

# Configure root logging once for the script: INFO and above, each record
# prefixed with its timestamp and level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Module-level logger used by the functions below.
logger = logging.getLogger(__name__)


def extract_diff(response):
    """
    Pull the patch text out of a model response.

    The response is searched in this order, and the first hit is returned:

    1. Markdown fenced code blocks. The first block labelled ``diff`` or
       ``patch`` wins; if no block carries such a label, the first fenced
       block of any kind is used.
    2. XML-style elements of the form ``<tag>...</tag>``. If any ``<patch>``
       element is present, the last one is used; otherwise the first
       element of any name is used.
    3. Otherwise, everything before the first ``</s>`` marker (the whole
       response when the marker is absent).

    Returns None when ``response`` is None.
    """
    if response is None:
        return None

    # Pass 1: fenced code blocks, each captured as (language label, body).
    fenced = re.findall(r"```(\w+)?\n(.*?)```", response, flags=re.DOTALL)
    if fenced:
        labelled = [body for lang, body in fenced if lang in {"diff", "patch"}]
        return labelled[0] if labelled else fenced[0][1]

    # Pass 2: paired tags, each captured as (tag name, inner text).
    tagged = re.findall(r"\<([\w-]+)\>(.*?)\<\/\1\>", response, flags=re.DOTALL)
    if tagged:
        patch_bodies = [inner for tag, inner in tagged if tag == "patch"]
        # Among several <patch> elements the last one takes priority;
        # without any, fall back to the first element found.
        return patch_bodies[-1] if patch_bodies else tagged[0][1]

    # Pass 3: no recognizable wrapper, so keep the text up to the end marker.
    head, _, _ = response.partition("</s>")
    return head


def main(predictions_path, output_file):
    """
    Re-extract ``model_patch`` for every prediction in a JSONL file.

    Each non-blank line of ``predictions_path`` is parsed as a JSON object
    that must contain a ``full_output`` key. The patch is re-derived from
    ``full_output`` with ``extract_diff``. When extraction found nothing
    more specific than the raw output (text before ``</s>``), an existing,
    non-trivial ``model_patch`` value is kept instead.

    Args:
        predictions_path: Path of the JSONL file with the predictions.
        output_file: Path to write the updated predictions to, one JSON
            object per line. When None, ``predictions_path`` is overwritten
            in place.

    Raises:
        Exception: Any error raised while reading, extracting or writing is
            propagated. If the failure happens while overwriting
            ``predictions_path`` in place, the originally loaded predictions
            are written back to it before the error is re-raised.
    """
    predictions = []
    with open(predictions_path, "r") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if line.strip():
                predictions.append(json.loads(line))

    in_place = output_file is None
    if in_place:
        output_file = predictions_path

    # Build every output line BEFORE opening the destination for writing, so
    # that an extraction error can never leave a truncated file behind.
    output_lines = []
    for prediction in tqdm(predictions):
        full_output = prediction["full_output"]
        model_patch = extract_diff(full_output)
        if model_patch is not None and model_patch == full_output.split("</s>")[0]:
            # Extraction fell back to the raw output: prefer a previously
            # stored patch unless it is missing, empty, or itself just the
            # raw output.
            existing = prediction.get("model_patch")
            if existing not in (None, "", full_output, model_patch):
                model_patch = existing
        output_lines.append(json.dumps({**prediction, "model_patch": model_patch}))

    try:
        with open(output_file, "w") as f:
            f.writelines(line + "\n" for line in output_lines)
    except Exception:
        # Only an in-place run has modified the input file, so only then is
        # there anything to restore.
        if in_place:
            logger.exception("Error while extracting patch, restoring original file")
            with open(predictions_path, "w") as f:
                f.writelines(json.dumps(p) + "\n" for p in predictions)
        raise
    logger.info("Completed patch extraction, wrote to %s", output_file)


if __name__ == "__main__":
    # Command-line entry point: the input file is mandatory; the output file
    # is optional and defaults to None.
    cli = ArgumentParser()
    cli.add_argument("--predictions_path", type=str, required=True)
    cli.add_argument("--output_file", type=str, default=None)
    options = cli.parse_args()
    main(
        predictions_path=options.predictions_path,
        output_file=options.output_file,
    )