import os
import json
import re
import pandas as pd

from collections import defaultdict


def parse_broken_json(x):
    """Recover what it can from an MQM answer whose JSON failed to parse.

    The answer is expected to resemble
    ``{"improved translation": "...", "errors": {...}}``. The text is split
    on the ``", "errors": `` separator; the left part yields the improved
    translation and the right part is decoded as JSON. When that decoding
    fails too, the right part is scanned word by word: every occurrence of
    the word ``class`` that follows a severity keyword is recorded as one
    placeholder error ``{"class": "other"}`` under that severity.

    Args:
        x: Raw answer text.

    Returns:
        A dict with two keys: ``"improved translation"`` (a string, empty
        when it could not be recovered) and ``"errors"`` (the decoded JSON
        value, or a ``defaultdict(list)`` mapping severity to placeholder
        errors; empty when nothing could be recovered).
    """
    improved_translation = ""
    errors = defaultdict(list)

    if '"errors": ' not in x or "improved translation" not in x:
        return {"improved translation": improved_translation, "errors": errors}

    data = x.split('", "errors": ')
    if len(data) != 2:
        return {"improved translation": improved_translation, "errors": errors}

    # The guard above only checks for the bare words "improved translation",
    # so the exact quoted marker may still be missing (e.g. no space after the
    # colon). Leave the translation empty in that case instead of raising
    # IndexError, and still try to recover the errors.
    pieces = data[0].split('"improved translation": "')
    if len(pieces) > 1:
        improved_translation = pieces[1]

    # Drop the final character, which is assumed to be the closing brace of
    # the outer object, so that only the value of "errors" remains.
    payload = data[1][:-1]

    try:
        errors = json.loads(payload)
    except (json.JSONDecodeError, RecursionError):
        # Still not valid JSON: count the errors per severity instead.
        words = re.findall(r"\b\w+\b", payload.lower())
        keywords = ["critical", "major", "minor"]

        last_key = None
        for word in words:
            if word in keywords:
                last_key = word
            elif last_key is not None and word == "class":
                errors[last_key].append({"class": "other"})

    return {"improved translation": improved_translation, "errors": errors}


def parse_error_class(error):
    """Map an error description to an MQM class label.

    Categories are tried in a fixed order (accuracy, fluency, locale
    convention, style, terminology, non-translation, other) and the first one
    found in ``error`` wins. If the category has subclasses, the label becomes
    ``"<category>-<subclass>"`` for the last listed subclass that also occurs
    in ``error``. When no category is found the label is ``"unknown"``.

    Args:
        error: The error description; tested with the ``in`` operator.

    Returns:
        The class label as a string.
    """
    # (category, subclasses) pairs; the order defines category precedence.
    taxonomy = (
        ("accuracy", ("addition", "mistranslation", "omission", "untranslated text")),
        (
            "fluency",
            (
                "character encoding",
                "grammar",
                "inconsistency",
                "punctuation",
                "register",
                "spelling",
            ),
        ),
        ("locale convention", ("currency", "date", "name", "telephone", "time")),
        ("style", ()),
        ("terminology", ("inappropriate", "inconsistent")),
        ("non-translation", ()),
        ("other", ()),
    )

    for category, subclasses in taxonomy:
        if category not in error:
            continue
        label = category
        # No early exit: a later subclass overrides an earlier match.
        for subclass in subclasses:
            if subclass in error:
                label = f"{category}-{subclass}"
        return label

    return "unknown"


def _plaintext_mqm_errors(text):
    """Group the lines of a plain-text MQM answer by severity level.

    The text is lower-cased and read line by line. A line that exactly equals
    one of the recognised severity headers switches the current level; every
    other non-skipped line is stored under the current level, except that a
    line mentioning "non-translation" is always stored as critical.

    Returns:
        A dict with the keys "critical", "major" and "minor", each a list of
        lower-cased, stripped error lines.
    """
    levels = ("critical", "major", "minor")
    # Every header spelling recognised for a level, mapped back to that level.
    headers = {
        form: level
        for level in levels
        for form in (
            f"{level}:",
            f"**{level}:**",
            f"```{level}:",
            f"### {level}:",
            f"**{level} errors:**",
        )
    }

    errors = {level: [] for level in levels}
    current_level = None
    for raw_line in text.lower().split("\n"):
        line = raw_line.strip()
        # Blank lines and explicit "nothing to report" lines carry no error.
        if not line or "no-error" in line or "no error" in line or "none" in line:
            continue
        if line in headers:
            current_level = headers[line]
            continue
        if current_level is None:
            print(f"No error level for {line}")
            continue
        target = "critical" if "non-translation" in line else current_level
        errors[target].append(line)
    return errors


def parse_mqm_answer(x, list_mqm_errors=False, full_desc=True, list_all=False):
    """Parse one MQM answer and turn it into a (negative) penalty score.

    Two answer formats are handled: a JSON object starting with
    ``{"improved translation"`` that carries an ``"errors"`` mapping, and a
    plain-text listing with ``critical:`` / ``major:`` / ``minor:`` headers.
    Duplicate errors within a level are dropped (first occurrence kept).
    At most the first five errors, taken in the order critical, major, minor,
    are scored with 25 / 5 / 1 points respectively, and the total is capped
    at 25.

    Args:
        x: The answer; converted with ``str``. ``None`` yields ``None``.
        list_mqm_errors: Return only the de-duplicated errors, no score.
        full_desc: Keep each error as reported; when false, reduce it to a
            class label via ``parse_error_class``.
        list_all: Return ``(errors, score)`` instead of the score alone
            (ignored when ``list_mqm_errors`` is true).

    Returns:
        ``None`` when ``x`` is ``None``; otherwise a ``defaultdict(list)``
        with the keys "critical", "major" and "minor", or a tuple of that
        mapping and the score, or the score alone. The score is an integer
        between -25 and 0.
    """
    if x is None:
        return None

    levels = ("critical", "major", "minor")
    x = str(x)
    if x.startswith('{"improved translation"'):
        try:
            parsed = json.loads(x)
        except (json.JSONDecodeError, RecursionError):
            parsed = parse_broken_json(x)
        # A syntactically valid answer may still lack "errors" or carry a
        # non-mapping value there; treat both as "no errors reported".
        errors = parsed.get("errors") if isinstance(parsed, dict) else None
        if not isinstance(errors, dict):
            errors = {}
    else:
        errors = _plaintext_mqm_errors(x)

    clean_error_classes = defaultdict(list)
    for level in levels:
        # Touch the key so that all three levels are present in the result.
        unique = clean_error_classes[level]
        reported = errors.get(level)
        if not isinstance(reported, list):
            continue
        for error in reported:
            # NOTE(review): parse_error_class tests membership with `in`; for
            # dict entries from the JSON format that checks the keys — confirm
            # this is intended when full_desc is false.
            entry = error if full_desc else parse_error_class(error)
            # List membership rather than a set: JSON-format entries are
            # dicts, which are unhashable and would raise TypeError in a set.
            if entry not in unique:
                unique.append(entry)

    if list_mqm_errors:
        return clean_error_classes

    weights = {"critical": 25, "major": 5, "minor": 1}
    # Only the first five errors count, most severe first; total capped at 25.
    penalties = [
        weights[level] for level in levels for _ in clean_error_classes[level]
    ][:5]
    final_score = min(sum(penalties), 25)

    if list_all:
        return clean_error_classes, -final_score
    return -final_score


def process_single_json(file_path):
    """Score every segment stored in one JSON result file.

    The file is expected to hold a sequence of items. Items without an
    "output" key are skipped. For the others, everything up to and including
    the last ``</think>`` delimiter is removed, the text is cut at the first
    trailing-commentary marker (in list order) that occurs in it, and the
    remainder is scored with ``parse_mqm_answer``.

    Args:
        file_path: Path of the JSON file to read (UTF-8).

    Returns:
        A dict with the document score (mean of the segment scores, 0 when
        there are none), the errors of all segments merged per level, the
        number of scored segments, the cleaned items, and the list of segment
        scores. ``None`` when any exception occurs; the exception is printed.
    """
    reasoning_end = "</think>\n\n"
    # Markers that introduce commentary after the actual answer. Only the
    # first marker of this list that is present is used for cutting.
    trailing_markers = (
        "\n\nExplanation",
        "\n\n**Explanation",
        "\n\n**Reasoning",
        "\n\n### Explanation",
        "\n\n### Output",
        "\n\n---",
        "\n**Strictly",
        "\n\n**Strict",
        "\n\n**Step-by-Step",
        "\n\n**Note",
        "\n\n**Rationale",
        "\n\n**formatted",
        "\n\n**Formatted",
    )

    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            records = json.load(handle)

        scores = []
        merged_errors = defaultdict(list)
        cleaned_items = []

        for record in records:
            if "output" not in record:
                continue

            answer = record["output"]
            if reasoning_end in answer:
                # Keep only what follows the last reasoning block.
                cut_from = answer.rfind(reasoning_end) + len(reasoning_end)
                answer = answer[cut_from:]

            marker = next((m for m in trailing_markers if m in answer), None)
            if marker is not None:
                answer = answer[: answer.find(marker)]

            segment_errors, segment_score = parse_mqm_answer(
                answer, list_mqm_errors=False, full_desc=True, list_all=True
            )

            scores.append(segment_score)
            for level, found in segment_errors.items():
                merged_errors[level].extend(found)
            cleaned_items.append({"output": answer, "error_classes": segment_errors})

        if scores:
            document_score = sum(scores) / len(scores)
        else:
            document_score = 0

        return {
            "score": document_score,
            "errors": dict(merged_errors),
            "segment_count": len(scores),
            "processed_items": cleaned_items,
            "segment_scores": scores,
        }

    except Exception as e:
        # Any failure makes the whole file count as unprocessable.
        print(f"Error in {file_path} as: {str(e)}")
        return None


def process_directory(
    directory_path,
    table_directory,
    clean_dictionary,
    output_directory=None,
    mo=None,
    lg=None,
):
    """Score every ``.json`` file below a directory and write the results.

    Each file is handled by ``process_single_json``. For every file that
    yields a result, its cleaned items are written as JSON under
    ``clean_dictionary`` using the original file name. When both ``mo`` and
    ``lg`` are given, two tab-separated files are written to
    ``output_directory``: ``<mo>_<lg>.sys.score`` with one line per file and
    ``<mo>_<lg>.seg.score`` with one line per segment.

    Args:
        directory_path: Root directory that is walked recursively.
        table_directory: Not used by this function (the spreadsheet export
            is disabled).
        clean_dictionary: Directory receiving the cleaned JSON files.
        output_directory: Directory for the score files; defaults to
            ``<directory_path>/processed``. It is created in any case.
        mo: First component of the score file names.
        lg: Second component of the score file names.

    Returns:
        A pandas DataFrame with one row per processed file.
    """
    if output_directory is None:
        output_directory = os.path.join(directory_path, "processed")

    summary_rows = []
    sys_entries = []
    seg_entries = []

    os.makedirs(output_directory, exist_ok=True)

    # Lazily enumerate (file name, full path) for every JSON file in the tree.
    json_files = (
        (name, os.path.join(folder, name))
        for folder, _, names in os.walk(directory_path)
        for name in names
        if name.endswith(".json")
    )

    for name, source_path in json_files:
        result = process_single_json(source_path)
        if not result:
            continue

        cleaned_path = os.path.join(clean_dictionary, name)
        os.makedirs(os.path.dirname(cleaned_path), exist_ok=True)
        with open(cleaned_path, "w", encoding="utf-8") as f:
            json.dump(result["processed_items"], f, ensure_ascii=False, indent=4)

        found = result["errors"]
        summary_rows.append(
            {
                "file_name": name,
                "path": source_path,
                "score": result["score"],
                "segment_count": result["segment_count"],
                "critical_errors": len(found.get("critical", [])),
                "major_errors": len(found.get("major", [])),
                "minor_errors": len(found.get("minor", [])),
            }
        )

        stem = os.path.splitext(name)[0]
        sys_entries.append((stem, result["score"]))
        seg_entries.extend((stem, value) for value in result["segment_scores"])

    df = pd.DataFrame(summary_rows)

    if mo and lg:
        prefix = os.path.join(output_directory, f"{mo}_{lg}")

        with open(f"{prefix}.sys.score", "w", encoding="utf-8") as f_sys:
            f_sys.writelines(f"{stem}\t{value}\n" for stem, value in sys_entries)

        with open(f"{prefix}.seg.score", "w", encoding="utf-8") as f_seg:
            f_seg.writelines(f"{stem}\t{str(value)}\n" for stem, value in seg_entries)

    return df


# Language pairs and models whose raw results are post-processed when this
# file is executed as a script.
lang_pairs = ["en-de", "ja-zh", "en-es"]
models = ["qwq32b"]

# Guarded so that importing this module (e.g. to reuse the parsers) does not
# start the batch job and write files as a side effect.
if __name__ == "__main__":
    for lg in lang_pairs:
        for mo in models:
            input_dir = f"lrm_results/{mo}/{lg}"
            clean_dir = f"lrm_results/clean/{mo}/{lg}"
            output_dir = f"lrm_results/scores/{mo}"
            table_dir = f"lrm_results/clean/{mo}_{lg}_quality_report.xlsx"

            results_df = process_directory(
                input_dir, table_dir, clean_dir, output_dir, mo=mo, lg=lg
            )
