import os
import json
import argparse
import difflib

# "@type" values accepted by validate_jsonld for items of a sample's "@graph".
# The evaluator treats "Entity"/"Activity" items as graph nodes and
# "Usage"/"Generation" items as edges that reference nodes by "@id".
ALLOWED_TYPES = ["Entity", "Activity", "Usage", "Generation"]


def load_json(path):
    """Read the file at *path* as UTF-8 text and return the decoded JSON value."""
    with open(path, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed


def _has_label_value(item):
    """Return True if ``item["label"][0]["@value"]`` can be read from *item*."""
    label = item.get("label")
    return (
        isinstance(label, list)
        and len(label) > 0
        and isinstance(label[0], dict)
        and "@value" in label[0]
    )


def validate_jsonld(data_or_path):
    """Load (when given a path) and structurally validate JSON-LD samples.

    Args:
        data_or_path: Path to a JSON file, or already-parsed data. The data
            must be a list of samples; each sample is an object with a string
            "label" and an "@graph" list of item objects.

    Checks, per sample:
        * every graph item is an object whose "@type" is in ALLOWED_TYPES;
        * non-empty "@id" values are unique within the sample;
        * every "Entity"/"Activity" item has a readable label[0]["@value"];
        * every "Usage"/"Generation" item's "entity" and "activity" refer to
          an "@id" defined in the same sample, and that item has a label.

    The label checks mirror what the scoring code reads without guarding
    (``sample["label"]`` as a string and ``item["label"][0]["@value"]``).
    Rejecting such data here lets callers skip one bad file instead of
    crashing the whole evaluation later.

    Returns:
        The parsed data, unchanged.

    Raises:
        ValueError: if the file is not valid JSON or any check fails.
    """
    if isinstance(data_or_path, str):
        try:
            data = load_json(data_or_path)
        except json.decoder.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON-LD file: {e}") from e
    else:
        data = data_or_path

    if not isinstance(data, (list, tuple)):
        raise ValueError(
            f"Expected a list of JSON-LD samples, got {type(data).__name__}"
        )

    for d in data:
        if not isinstance(d, dict) or not isinstance(d.get("@graph"), list):
            raise ValueError("Each sample must be an object with an '@graph' list")
        if not isinstance(d.get("label"), str):
            raise ValueError(f"Sample label must be a string, got: {d.get('label')!r}")
        graph = d["@graph"]

        # First pass: item types, node labels and @id uniqueness.
        id_to_item = {}
        for item in graph:
            item_type = item.get("@type") if isinstance(item, dict) else None
            if item_type not in ALLOWED_TYPES:
                raise ValueError(f"Invalid JSON-LD type: {item_type}")
            if item_type in ("Entity", "Activity") and not _has_label_value(item):
                raise ValueError(f"Node {item.get('@id')} has no label '@value'")
            # Items without an @id are allowed; they just cannot be referenced.
            if item_id := item.get("@id"):
                if item_id in id_to_item:
                    raise ValueError(f"Duplicate @id: {item_id}")
                id_to_item[item_id] = item

        # Second pass: edges must point at existing, labelled items, because
        # scoring compares the labels of edge endpoints.
        for item in graph:
            if item["@type"] not in ("Usage", "Generation"):
                continue
            for role in ("entity", "activity"):
                ref_id = item.get(role)
                if ref_id not in id_to_item:
                    raise ValueError(
                        f"Edge {item.get('@id')} references missing {role} id: {ref_id}"
                    )
                if not _has_label_value(id_to_item[ref_id]):
                    raise ValueError(
                        f"Edge {item.get('@id')} references unlabeled {role} id: {ref_id}"
                    )
    return data


def evaluate(pred_path, true_path, verbose=False):
    """Score prediction JSON-LD files against ground-truth files.

    Each path may be a single file or a directory. For a directory, every
    entry whose name ends in ".json" is used. Predictions and ground truth
    are paired by file basename. A pair where either side fails
    validate_jsonld is skipped; skipped pairs are reported only when
    *verbose* is set.

    Returns:
        The overall score dict produced by evaluate_jsonld.

    Raises:
        ValueError: if the two sides share no basenames, or if every shared
            pair fails validation.
    """

    def by_basename(path):
        # basename -> full path for one file, or for every *.json in a directory.
        if os.path.isdir(path):
            paths = [
                os.path.join(path, entry)
                for entry in os.listdir(path)
                if entry.endswith(".json")
            ]
        else:
            paths = [path]
        return {os.path.basename(p): p for p in paths}

    pred_by_name = by_basename(pred_path)
    true_by_name = by_basename(true_path)
    shared_names = sorted(pred_by_name.keys() & true_by_name.keys())
    if not shared_names:
        raise ValueError("No matching prediction and ground truth files found")

    all_pred_data, all_true_data = [], []
    valid_files, skipped_files = [], []
    for name in shared_names:
        try:
            pred_data = validate_jsonld(pred_by_name[name])
            true_data = validate_jsonld(true_by_name[name])
        except Exception as exc:
            # Best effort: a malformed pair is recorded and skipped.
            skipped_files.append((name, str(exc)))
            continue
        all_pred_data.append(pred_data)
        all_true_data.append(true_data)
        valid_files.append(name)

    if not valid_files:
        raise ValueError("No valid files to evaluate after validation.")

    results = evaluate_jsonld(all_pred_data, all_true_data, valid_files, verbose)
    if verbose:
        print(f"\nEvaluated files: {len(valid_files)} / {len(shared_names)}")
        if skipped_files:
            print(
                "[Warning] The following files were skipped due to validation errors:"
            )
            for name, err in skipped_files:
                print(f"  {name}: {err}")
    return results


# Counter names used for per-paper and overall tallies.
_COUNT_KEYS = (
    "true_nodes",
    "pred_nodes",
    "tp_nodes",
    "true_edges",
    "pred_edges",
    "tp_edges",
    "true_params",
    "pred_params",
    "tp_params",
)

# (result key, printed tag) pairs, in report order.
_SCORE_TAGS = (
    ("node", "Node"),
    ("edge", "Edge"),
    ("node+edge", "Structural"),
    ("parameter", "Parametric"),
)


def _prf(true_positives, n_true, n_pred):
    """Return (recall, precision, f1); a ratio with a zero denominator is 0.0."""
    recall = true_positives / n_true if n_true else 0.0
    precision = true_positives / n_pred if n_pred else 0.0
    f1 = (
        2 * recall * precision / (recall + precision)
        if (recall + precision) > 0
        else 0.0
    )
    return recall, precision, f1


def _matprov_params(item):
    """Map each "matprov:*" key of *item* to its first "@value" ("" if none)."""
    return {
        k: v[0]["@value"] if v and "@value" in v[0] else ""
        for k, v in item.items()
        if k.startswith("matprov:")
    }


def _normalized_set(value):
    """Normalize *value* (or each element, if it is a list) into a set."""
    if isinstance(value, list):
        return {_normalize_label(v) for v in value}
    return {_normalize_label(value)}


def _first_label_value(item):
    """Return item["label"][0]["@value"], or "" when the label is absent/empty."""
    return item["label"][0]["@value"] if "label" in item and item["label"] else ""


def _form_in_label(true_item, pred_item, true_value, pred_value):
    """Return True when a "matprov:form" value already appears in a node label.

    The parameter is excluded from both sides' counts in either case:
    a true form is a substring of the normalized predicted label, or a
    predicted form is a substring of a normalized true label.
    """
    true_form_norms = _normalized_set(true_value)
    pred_form_norms = {_normalize_label(pred_value)}
    true_label_norms = _normalized_set(_first_label_value(true_item))
    pred_label_norm = _normalize_label(_first_label_value(pred_item))
    return any(
        form and form in pred_label_norm for form in true_form_norms
    ) or any(
        form and any(form in tln for tln in true_label_norms)
        for form in pred_form_norms
    )


def _score_params(true_item, pred_item):
    """Return (n_true, n_pred, n_correct) matprov parameters for a matched node pair."""
    true_params = _matprov_params(true_item)
    pred_params = _matprov_params(pred_item)
    n_true = n_pred = n_correct = 0
    for key in set(true_params).union(pred_params):
        true_value = true_params.get(key, "")
        pred_value = pred_params.get(key, "")
        if key == "matprov:form" and _form_in_label(
            true_item, pred_item, true_value, pred_value
        ):
            continue  # not counted in recall/precision/f1
        if true_value:
            n_true += 1
        if pred_value:
            n_pred += 1
        if true_value and pred_value:
            # A list-valued true parameter accepts any of its alternatives.
            true_norms = _normalized_set(true_value)
            if _normalize_label(pred_value) in true_norms:
                n_correct += 1
    return n_true, n_pred, n_correct


def _match_nodes(true_nodes, pred_nodes):
    """Greedily pair nodes; return (tp, n_true_params, n_pred_params, n_tp_params).

    Each true node takes the first still-unused predicted node with the same
    "@type" and a matching label. Parameters are scored only on such pairs.
    """
    # Track used predictions by list index, not "@id": validation allows
    # nodes without an "@id", and they must not all collide on None.
    used_pred = set()
    n_tp = n_true_params = n_pred_params = n_tp_params = 0
    for true_item in true_nodes:
        true_label = true_item["label"][0]["@value"]
        for p_idx, pred_item in enumerate(pred_nodes):
            pred_label = pred_item["label"][0]["@value"]
            if (
                pred_item["@type"] == true_item["@type"]
                and _labels_match(pred_label, true_label)
                and p_idx not in used_pred
            ):
                used_pred.add(p_idx)
                n_tp += 1
                p_true, p_pred, p_correct = _score_params(true_item, pred_item)
                n_true_params += p_true
                n_pred_params += p_pred
                n_tp_params += p_correct
                break
    return n_tp, n_true_params, n_pred_params, n_tp_params


def _match_edges(true_graph, pred_graph, true_edges, pred_edges):
    """Count true edges matched by a same-type predicted edge with matching endpoint labels.

    Each (entity id, activity id) pair is consumed at most once per side.
    """
    matched_pred_pairs = set()
    matched_true_pairs = set()
    n_tp = 0
    for true_item in true_edges:
        true_entity = _id2item(true_graph, true_item["entity"])
        true_activity = _id2item(true_graph, true_item["activity"])
        true_pair = (true_item.get("entity"), true_item.get("activity"))
        for pred_item in pred_edges:
            if pred_item["@type"] != true_item["@type"]:
                continue
            pred_entity = _id2item(pred_graph, pred_item["entity"])
            pred_activity = _id2item(pred_graph, pred_item["activity"])
            pred_pair = (pred_item.get("entity"), pred_item.get("activity"))
            if (
                pred_entity
                and true_entity
                and _labels_match(
                    pred_entity["label"][0]["@value"],
                    true_entity["label"][0]["@value"],
                )
                and pred_activity
                and true_activity
                and _labels_match(
                    pred_activity["label"][0]["@value"],
                    true_activity["label"][0]["@value"],
                )
                and pred_pair not in matched_pred_pairs
                and true_pair not in matched_true_pairs
            ):
                n_tp += 1
                matched_pred_pairs.add(pred_pair)
                matched_true_pairs.add(true_pair)
                break
    return n_tp


def _count_sample(counts, pred_graph, true_graph):
    """Add node, edge and parameter counts for one matched sample pair into *counts*."""
    true_nodes = [it for it in true_graph if it["@type"] in ("Entity", "Activity")]
    pred_nodes = [it for it in pred_graph if it["@type"] in ("Entity", "Activity")]
    true_edges = [it for it in true_graph if it["@type"] in ("Usage", "Generation")]
    pred_edges = [it for it in pred_graph if it["@type"] in ("Usage", "Generation")]

    n_tp_nodes, n_true_params, n_pred_params, n_tp_params = _match_nodes(
        true_nodes, pred_nodes
    )
    counts["true_nodes"] += len(true_nodes)
    counts["pred_nodes"] += len(pred_nodes)
    counts["tp_nodes"] += n_tp_nodes
    counts["true_params"] += n_true_params
    counts["pred_params"] += n_pred_params
    counts["tp_params"] += n_tp_params

    counts["true_edges"] += len(true_edges)
    counts["pred_edges"] += len(pred_edges)
    counts["tp_edges"] += _match_edges(true_graph, pred_graph, true_edges, pred_edges)


def _scores(counts):
    """Return {metric: (recall, precision, f1)} computed from a counts dict."""
    return {
        "node": _prf(counts["tp_nodes"], counts["true_nodes"], counts["pred_nodes"]),
        "edge": _prf(counts["tp_edges"], counts["true_edges"], counts["pred_edges"]),
        "node+edge": _prf(
            counts["tp_nodes"] + counts["tp_edges"],
            counts["true_nodes"] + counts["true_edges"],
            counts["pred_nodes"] + counts["pred_edges"],
        ),
        "parameter": _prf(
            counts["tp_params"], counts["true_params"], counts["pred_params"]
        ),
    }


def _print_scores(scores):
    """Print one "[Tag] Recall/Precision/F1" line per metric."""
    for key, tag in _SCORE_TAGS:
        recall, precision, f1 = scores[key]
        print(
            f"[{tag}] Recall: {recall:.4f}, Precision: {precision:.4f}, F1: {f1:.4f}"
        )


def evaluate_jsonld(pred_data, true_data, file_names, verbose=False):
    """Score predicted JSON-LD samples against ground truth, paper by paper.

    Args:
        pred_data: One entry per paper; each entry is a list of samples
            (objects with "label" and "@graph").
        true_data: Ground-truth entries, parallel to *pred_data*.
        file_names: Paper file names, parallel to the data lists. They are
            used only for verbose headings.
        verbose: Print per-paper and overall scores when True.

    Within each paper, samples are paired with sample_match. Samples left
    unmatched contribute no counts. Over the matched pairs, this function
    counts:
        * nodes ("Entity"/"Activity");
        * edges ("Usage"/"Generation");
        * "matprov:*" parameters of matched nodes.
    Overall scores come from counts summed over all papers (micro-average),
    not from averaging per-paper scores.

    Returns:
        {"node", "edge", "node+edge", "parameter"} -> (recall, precision, f1).
    """
    totals = dict.fromkeys(_COUNT_KEYS, 0)
    total_true_samples = 0
    total_matched_samples = 0

    for idx, (pred_samples, true_samples) in enumerate(zip(pred_data, true_data)):
        paper = dict.fromkeys(_COUNT_KEYS, 0)
        total_true_samples += len(true_samples)
        if verbose:
            print("-" * 80)
            print(f"# Evaluation per paper: {os.path.splitext(file_names[idx])[0]}")
            print(f"Number of true samples: {len(true_samples)}")
            print(f"Number of predicted samples: {len(pred_samples)}")

        pred_samples, true_samples = sample_match(pred_samples, true_samples)
        total_matched_samples += len(pred_samples)
        for pred, true in zip(pred_samples, true_samples):
            _count_sample(paper, pred["@graph"], true["@graph"])

        for key in _COUNT_KEYS:
            totals[key] += paper[key]

        if verbose:
            print("## Scores")
            _print_scores(_scores(paper))

    results = _scores(totals)

    if verbose:
        # Guard: with no ground-truth samples the rate is undefined; report 0.
        collection_rate = (
            total_matched_samples / total_true_samples if total_true_samples else 0.0
        )
        print("=" * 80)
        print("# Overall Evaluation")
        print(f"#True samples: {total_true_samples}")
        print(f"#Matched samples: {total_matched_samples}")
        print(f"Collection rate: {collection_rate:.4f} (matched/true)")
        _print_scores(results)

    return results


def sample_match(pred_samples, true_samples):
    """Pair predicted samples with ground-truth samples by label similarity.

    Every (true, predicted) pair is scored with
    ``difflib.SequenceMatcher(...).ratio()`` on the samples' "label" strings,
    with spaces removed. Pairs are then taken greedily from highest to lowest
    score. On equal scores, the higher true index wins, then the higher
    predicted index. Each sample is used at most once, and matching stops
    once ``min(len(pred_samples), len(true_samples))`` pairs exist.

    Returns:
        (matched_pred, matched_true): equal-length lists whose i-th entries
        form a pair, in the order the pairs were chosen. Samples left over
        on the longer side are dropped.
    """
    if not pred_samples or not true_samples:
        return [], []

    # Spaces are ignored when comparing; strip them once, not once per pair.
    pred_keys = [sample["label"].replace(" ", "") for sample in pred_samples]
    true_keys = [sample["label"].replace(" ", "") for sample in true_samples]

    scored = sorted(
        (
            (difflib.SequenceMatcher(None, p_key, t_key).ratio(), t_idx, p_idx)
            for t_idx, t_key in enumerate(true_keys)
            for p_idx, p_key in enumerate(pred_keys)
        ),
        reverse=True,
    )

    max_matches = min(len(pred_samples), len(true_samples))
    used_true, used_pred = set(), set()
    matched_pred, matched_true = [], []
    for _sim, t_idx, p_idx in scored:
        if t_idx in used_true or p_idx in used_pred:
            continue
        used_true.add(t_idx)
        used_pred.add(p_idx)
        matched_true.append(true_samples[t_idx])
        matched_pred.append(pred_samples[p_idx])
        if len(matched_pred) >= max_matches:
            break

    return matched_pred, matched_true


def _id2item(data_graph, target_id):
    for item in data_graph:
        if item.get("@id") == target_id:
            return item
    return None


def _labels_match(pred_label, true_label):
    """Return True when the two labels share at least one normalized form.

    Either argument may be a single label or a list of alternative labels.
    Every candidate is passed through _normalize_label before comparison.
    """

    def as_forms(value):
        candidates = value if isinstance(value, list) else [value]
        return {_normalize_label(candidate) for candidate in candidates}

    return not as_forms(pred_label).isdisjoint(as_forms(true_label))


def _normalize_label(label):
    for ch in [
        " ",
        "-",
        "–",
        "_",
        ".",
        "·",
        "•",
        ":",
        ",",
        "°",
        "∘",
        "/",
        "=",
        "+",
        "×",
        "*",
        "$",
        "#",
        "@",
        "®",
        "£",
    ]:
        label = label.replace(ch, "")
    label_lower = label.lower()
    return label_lower


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate JSON-LD files.")
    parser.add_argument(
        "--pred-path",
        type=str,
        required=True,
        help="Path to the LLM-generated JSON file or directory.",
    )
    parser.add_argument(
        "--true-path",
        type=str,
        required=True,
        help="Path to the ground truth file or directory.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose output.",
    )
    args = parser.parse_args()
    results = evaluate(args.pred_path, args.true_path, args.verbose)

    # evaluate() prints a report only in verbose mode. Without this, a
    # non-verbose run would compute the scores and print nothing at all.
    if not args.verbose:
        score_tags = {
            "node": "Node",
            "edge": "Edge",
            "node+edge": "Structural",
            "parameter": "Parametric",
        }
        for metric, (recall, precision, f1) in results.items():
            print(
                f"[{score_tags.get(metric, metric)}] Recall: {recall:.4f}, "
                f"Precision: {precision:.4f}, F1: {f1:.4f}"
            )
