import argparse
import json
import logging
import os
from typing import Dict, List

import numpy as np
import pandas as pd
from sklearn.metrics import cohen_kappa_score
import krippendorff
from statsmodels.stats.inter_rater import fleiss_kappa

logger = logging.getLogger(__name__)

# Values of an LLM assessment's ``llm_prompt_type`` field (the same strings are
# also used as keys under ``llms_to_compare`` in the config). When rows are
# flattened, the deprecation prompt is paired only with the "d-triples"
# operation and the assertion prompt with every other operation.
ACTION_CATEGORY_ASSERT = "triple_assertion"
ACTION_CATEGORY_DEPRECATE = "triple_deprecation"


def load_annotated_instances(annotated_instances: List[Dict]) -> pd.DataFrame:
    """Flatten annotated instances into one row per matched human/LLM judgement.

    For every triple in ``inst["tkgu_triples"]`` that has both human and LLM
    assessments, each human assessment is paired with every LLM assessment
    that used the same prompt type (``prompt_type`` == ``llm_prompt_type``).
    The pair is emitted once per entry of the triple's ``tkgu_operations``,
    but only when the prompt type fits the operation: ``"d-triples"`` goes
    with ``ACTION_CATEGORY_DEPRECATE``, every other operation with
    ``ACTION_CATEGORY_ASSERT``.

    Triples without human assessments are skipped silently; triples with
    human but without LLM assessments are skipped and logged as errors.

    Args:
        annotated_instances: Parsed annotation records. Each must have a
            ``hash_id``; ``passage`` and ``tkgu_triples`` are optional.

    Returns:
        A DataFrame with the columns listed in ``columns`` below. When nothing
        matches it is empty but still carries those columns, so downstream
        column renames/selections do not raise ``KeyError``.
    """
    columns = [
        "hash_id",
        "passage",
        "human_readable_triple",
        "definition_relation",
        "annotator_name",
        "llm_name",
        "tkgu_operation",
        "prompt_type",
        "llm_assessment",
        "llm_prompt",
        "human_assessment",
    ]
    rows = []
    counts: Dict = {}

    for inst in annotated_instances:
        for triple in inst.get("tkgu_triples", []):
            human_ass = triple.get("human_assessment", [])
            llm_ass = triple.get("llm_assessment", [])

            if not human_ass:
                continue
            if not llm_ass:
                logger.error(
                    "Missing llm_assessment in %s for triple %s", inst["hash_id"], triple
                )
                continue

            operations = triple.get("tkgu_operations", [])
            for human in human_ass:
                for llm in llm_ass:
                    prompt_type = llm["llm_prompt_type"]
                    if human["prompt_type"] != prompt_type:
                        continue

                    for operation in operations:
                        # Deprecation ("d-triples") is judged with the deprecation
                        # prompt; all other operations with the assertion prompt.
                        expected = (
                            ACTION_CATEGORY_DEPRECATE
                            if operation == "d-triples"
                            else ACTION_CATEGORY_ASSERT
                        )
                        if prompt_type != expected:
                            continue

                        key = (human["annotator_name"], operation)
                        counts[key] = counts.get(key, 0) + 1

                        rows.append(
                            {
                                "hash_id": inst["hash_id"],
                                "passage": inst.get("passage"),
                                "human_readable_triple": human.get("human_readable_triple"),
                                "definition_relation": human.get("definition_relation"),
                                "annotator_name": human["annotator_name"],
                                "llm_name": llm["llm_name"],
                                "tkgu_operation": operation,
                                "prompt_type": prompt_type,
                                "llm_assessment": llm["llm_assessment"],
                                "llm_prompt": llm.get("llm_prompt"),
                                "human_assessment": human["assessment"],
                            }
                        )

    logger.info("Annotation counts per human/operation: %s", counts)
    return pd.DataFrame(rows, columns=columns)


def _rating_matrix(data: np.ndarray, n_categories: int) -> np.ndarray:
    """Convert an (items x raters) label matrix into Fleiss' (items x categories) counts."""
    # One shared ``minlength`` keeps every row the same width even when a row
    # never uses the highest label; otherwise ``np.array`` would get ragged rows.
    return np.array([np.bincount(row, minlength=n_categories) for row in data])


def _agreement_row(
    label: str, frame: pd.DataFrame, human_names: List[str], n_categories: int
) -> Dict:
    """Compute pairwise Cohen's kappa and 3-rater Fleiss/Krippendorff for one slice."""
    first, second = human_names
    data = frame.to_numpy()
    return {
        "Operation": label,
        "H-H cohen": cohen_kappa_score(frame[first], frame[second]),
        "H1-LLM cohen": cohen_kappa_score(frame[first], frame["llm_assessment"]),
        "H2-LLM cohen": cohen_kappa_score(frame[second], frame["llm_assessment"]),
        "H+LLM fleiss": fleiss_kappa(_rating_matrix(data, n_categories)),
        "H+LLM kripp": krippendorff.alpha(
            reliability_data=data.T, level_of_measurement="nominal"
        ),
    }


def return_paper_stats(df: pd.DataFrame, human_names: List[str]) -> pd.DataFrame:
    """Compute agreement statistics between two humans and an LLM.

    Only items (operation, hash_id, triple) annotated by exactly the two given
    humans and with a non-missing LLM assessment are kept. Assessments are
    cast to ``int`` and must be non-negative integer labels (required by
    ``np.bincount``). Each item is assumed to have one row per human;
    ``DataFrame.pivot`` raises on duplicates.

    Args:
        df: Must contain ``human_name``, ``human_assessment``,
            ``llm_assessment``, ``tkgu_operation``, ``hash_id`` and
            ``human_readable_triple`` columns.
        human_names: Exactly two human annotator names.

    Returns:
        One row per TKGU operation plus a final ``"Overall"`` row, with
        columns ``Operation``, ``H-H cohen``, ``H1-LLM cohen``,
        ``H2-LLM cohen``, ``H+LLM fleiss`` and ``H+LLM kripp``.

    Raises:
        ValueError: If ``human_names`` does not contain exactly two names.
    """
    if len(human_names) != 2:
        raise ValueError(f"Expected exactly two human annotators, got {human_names!r}")

    keys = ["tkgu_operation", "hash_id", "human_readable_triple"]
    wanted = set(human_names)

    subset = df.groupby(keys).filter(
        lambda g: set(g["human_name"]) == wanted and g["llm_assessment"].notna().all()
    )

    human_wide = subset.pivot(index=keys, columns="human_name", values="human_assessment")
    # The LLM label is shared by both human rows of an item; keep one copy.
    llm_wide = subset.drop_duplicates(subset=keys).set_index(keys)[["llm_assessment"]]

    dfw = human_wide.join(llm_wide).astype(int)

    # Use the same category count for every slice so rating-matrix rows always
    # have equal width, also for non-binary labels (binary data -> 2, as before).
    values = dfw.to_numpy()
    n_categories = max(2, int(values.max()) + 1) if values.size else 2

    results = [
        _agreement_row(op, group, human_names, n_categories)
        for op, group in dfw.groupby("tkgu_operation")
    ]
    results.append(_agreement_row("Overall", dfw, human_names, n_categories))

    return pd.DataFrame(results)


def get_annotation_statistics(annotated_instances: List[Dict], annotator_names: List[str]) -> pd.DataFrame:
    """Return the flattened annotation table for ``annotated_instances``.

    This is a thin wrapper around ``load_annotated_instances``.
    ``annotator_names`` is accepted for call-site compatibility only; no
    filtering by annotator happens here.
    """
    flattened = load_annotated_instances(annotated_instances=annotated_instances)
    return flattened


def _merge_unique(target: Dict, source: Dict, field: str, key_fields: List[str]) -> None:
    """Append entries of ``source[field]`` to ``target[field]`` whose key is new."""
    incoming = source.get(field) or []
    if not incoming:
        return
    merged = target.get(field)
    if merged is None:
        # Covers both a missing key and an explicit ``None`` value.
        merged = target[field] = []
    seen = {tuple(entry[k] for k in key_fields) for entry in merged}
    for entry in incoming:
        key = tuple(entry[k] for k in key_fields)
        if key not in seen:
            merged.append(entry)
            # Track keys added from ``source`` too, so duplicates inside
            # ``source`` itself are not appended twice.
            seen.add(key)


def merge_annotations(instance1: Dict, instance2: Dict) -> Dict:
    """Merge human + LLM assessments from ``instance2`` into ``instance1``.

    Both instances must list the same triples in the same order, with equal
    ``triple`` and ``tkgu_operations`` values. Human assessments are
    deduplicated by ``(annotator_name, prompt_type)`` and LLM assessments by
    ``(llm_name, llm_prompt_type)``. Entries already in ``instance1`` win.

    ``instance1`` is modified in place and returned.

    Raises:
        ValueError: If the triple lists do not line up. All checks run before
            any mutation, so ``instance1`` is left untouched on error.
    """
    t1, t2 = instance1["tkgu_triples"], instance2["tkgu_triples"]
    hash_id = instance1.get("hash_id")
    if len(t1) != len(t2):
        raise ValueError(
            f"Instance {hash_id}: triple count mismatch ({len(t1)} vs {len(t2)})"
        )

    for idx, (a, b) in enumerate(zip(t1, t2)):
        if a["triple"] != b["triple"] or a["tkgu_operations"] != b["tkgu_operations"]:
            raise ValueError(f"Instance {hash_id}: triple #{idx} differs between files")

    for a, b in zip(t1, t2):
        _merge_unique(a, b, "human_assessment", ["annotator_name", "prompt_type"])
        _merge_unique(a, b, "llm_assessment", ["llm_name", "llm_prompt_type"])

    return instance1


def _read_instances(input_paths: List[str]) -> Dict[str, Dict]:
    """Read every JSON-lines file in ``input_paths`` and merge records by ``hash_id``."""
    instances: Dict[str, Dict] = {}
    for directory in input_paths:
        # Sorted so the merge order (which file's record becomes the base) does
        # not depend on the filesystem's listing order.
        for fn in sorted(os.listdir(directory)):
            path = os.path.join(directory, fn)
            if not os.path.isfile(path):
                continue  # skip sub-directories and other non-files
            with open(path, "rt", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank / trailing lines in JSONL
                    obj = json.loads(line)
                    hid = obj["hash_id"]
                    if hid in instances:
                        instances[hid] = merge_annotations(instances[hid], obj)
                    else:
                        instances[hid] = obj
    return instances


def _format_operation(op: str) -> str:
    """Render an operation id for LaTeX, e.g. ``"ee-kg-triples"`` -> ``"EE-KG-Triples"``."""
    acronyms = {"EE", "KG", "D", "E", "X"}
    return "-".join(
        part.upper() if part.upper() in acronyms else part.capitalize()
        for part in op.split("-")
    )


def _order_operations(stats: pd.DataFrame) -> pd.DataFrame:
    """Sort rows in paper order; operations absent from the data are skipped."""
    preferred = ["x-triples", "e-triples", "ee-triples", "ee-kg-triples", "d-triples"]
    present = list(stats["Operation"])
    ordered = [op for op in preferred if op in present]
    # Keep unexpected operations instead of silently dropping them.
    ordered += [op for op in present if op not in preferred and op != "Overall"]
    ordered.append("Overall")
    return stats.set_index("Operation").loc[ordered].reset_index()


def _to_latex(stats: pd.DataFrame) -> str:
    """Build the booktabs LaTeX table body for the agreement statistics."""
    # Each row must end with a LaTeX line break ``\\`` (two backslashes).
    rows = [
        f"{r['Operation_latex']} & "
        f"{r['H-H cohen']:.3f} & {r['H1-LLM cohen']:.3f} & {r['H2-LLM cohen']:.3f} & "
        f"{r['H+LLM fleiss']:.3f} & {r['H+LLM kripp']:.3f} \\\\"
        for _, r in stats.iterrows()
    ]
    # Raw strings so that ``\\`` reaches LaTeX unchanged.
    header = r"""
\begin{tabular}{lccccc}
\toprule
\shortstack{TKGU \\ Operation} &
\shortstack{H-H \\ Cohen's $\kappa$} &
\shortstack{H1-LLM \\ Cohen's $\kappa$} &
\shortstack{H2-LLM \\ Cohen's $\kappa$} &
\shortstack{H+LLM \\ Fleiss' $\kappa$} &
\shortstack{H+LLM \\ Kripp. $\alpha$} \\
\midrule
"""
    footer = r"""
\bottomrule
\end{tabular}
"""
    return header + "\n".join(rows) + footer


def main() -> None:
    """Load annotations, compute agreement statistics and print them."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        type=str,
        default="anno_stats_config.json",
        help="Path to config file.",
    )
    args = parser.parse_args()

    # Without a configured handler, logger.info(...) output is discarded
    # (the effective root level defaults to WARNING).
    logging.basicConfig(level=logging.INFO)

    pd.set_option("display.max_columns", None)
    pd.set_option("display.max_colwidth", 20)
    pd.set_option("display.width", 200)

    with open(args.config_file, "rt", encoding="utf-8") as f:
        config = json.load(f)
    input_paths = config["input_annotation_paths"]
    human_names = [h["annotator"] for h in config["human_annotators"]]

    logger.info("Loading annotations...")
    instances = _read_instances(input_paths)

    df = get_annotation_statistics(list(instances.values()), human_names)
    df = df.rename(columns={"annotator_name": "human_name"})

    llms_cfg = config["llms_to_compare"]
    llms_d = set(llms_cfg[ACTION_CATEGORY_DEPRECATE])
    llms_a = set(llms_cfg[ACTION_CATEGORY_ASSERT])

    # Keep only the LLMs selected for each operation family.
    is_deprecation = df["tkgu_operation"] == "d-triples"
    df = df[
        (is_deprecation & df["llm_name"].isin(llms_d))
        | (~is_deprecation & df["llm_name"].isin(llms_a))
    ]

    df = df[[
        "human_name",
        "human_assessment",
        "llm_assessment",
        "tkgu_operation",
        "hash_id",
        "human_readable_triple",
    ]]

    stats = return_paper_stats(df, human_names).round(3)
    stats["Operation_latex"] = stats["Operation"].apply(_format_operation)
    stats = _order_operations(stats)

    print(_to_latex(stats))
    # Friendly human-readable output
    print("\n=== Agreement Summary (Human-Friendly) ===")
    print(stats.to_string(index=False))
    print("==========================================\n")
    logger.info("Finished.")


if __name__ == "__main__":
    main()
