from __future__ import annotations

import logging
from pathlib import Path

import pandas as pd


_SPLITS = ("train", "valid", "test")
_KNOWN_EXTENSIONS = ("tsv", "csv", "txt")


def _infer_extension(input_folder: Path) -> str:
    """Return the extension of the first ``train.<ext>`` file found in *input_folder*."""
    for candidate in _KNOWN_EXTENSIONS:
        if (input_folder / f"train.{candidate}").exists():
            return candidate
    msg = (
        "No extension provided. Tried to infer `tsv` `csv` or `txt` but could "
        "not find any `train.tsv`, `train.csv` or `train.txt` in input folder."
    )
    logging.error(msg)
    raise ValueError(msg)


def _read_triples(file_path: Path) -> pd.DataFrame:
    """Read a tab-separated triple file, keeping every field as a literal string."""
    return pd.read_csv(
        file_path,
        sep="\t",
        header=None,
        names=["subject", "relation", "object"],
        # Identifiers are opaque labels: without these two options pandas would
        # turn names such as "NA"/"null" into NaN and "10" into an integer,
        # which breaks both the sorting and the ID lookup below.
        dtype=str,
        na_filter=False,
    )


def create_mappings(
    input_folder: str | Path,
    output_folder: str | Path,
    extension: str | None = None,
) -> None:
    """Process KG triple files and create ID mappings for entities and relations.

    For example, given the source data
        src/
        - train.tsv
        - valid.tsv
        - test.tsv

    where each file contains tab-separated triples of the form:
        bob knows alice
        bob knows charlie
        alice knows bob

    This will produce
        mapped/
        - train.tsv
        - valid.tsv
        - test.tsv
        - entity_ids.txt
        - relation_ids.txt
        - entities_not_in_train.tsv  (only if valid/test contain unseen entities)

    where the train.tsv, valid.tsv, and test.tsv files contain triples of the form:
        1   0   0
        1   0   2
        0   0   1

    and the entity_ids.txt and relation_ids.txt files contain the mappings of entities and relations
    to their respective IDs: IDs are assigned in sorted (string) order over all splits, and
    the name with ID ``i`` is written on line ``i`` (0-based).
    For entity_ids.txt:
        alice
        bob
        charlie
    For relation_ids.txt:
        knows

    Entity and relation names are always treated as strings, so purely numeric names are
    ordered lexicographically ("10" sorts before "2").

    A split whose file cannot be read is logged and skipped; the remaining splits are
    still processed.

    Args:
        input_folder: Path to folder containing train/valid/test TSV files
        output_folder: Path to save mapped files and mappings
        extension: Extension of the input files (a leading dot is ignored). If None, tries to
            infer it from the file it finds.

    Raises:
        ValueError: If ``extension`` is None and no ``train.tsv``, ``train.csv`` or ``train.txt``
            exists in ``input_folder``.

    """
    logging.info("Processing folder %s", input_folder)
    input_folder = Path(input_folder)
    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    if extension is None:
        extension = _infer_extension(input_folder)
        logging.info("Inferred extension: %s", extension)
    else:
        # Accept both "tsv" and ".tsv".
        extension = extension.lstrip(".")

    # Parse every split exactly once and keep the frames for the later passes.
    triples_by_split: dict[str, pd.DataFrame] = {}
    entities_by_split: dict[str, set[str]] = {split: set() for split in _SPLITS}
    entities: set[str] = set()
    relations: set[str] = set()

    for split in _SPLITS:
        file_path = input_folder / f"{split}.{extension}"
        try:
            logging.info("Processing file %s", file_path)
            df = _read_triples(file_path)
        except Exception as e:
            # Deliberately best-effort: a missing or unreadable split must not
            # prevent the other splits from being mapped.
            logging.exception("Error processing file %s: %s", file_path, e)
            continue
        logging.info("Number of %s triples: %d", split, len(df))
        triples_by_split[split] = df
        split_entities = set(df["subject"]) | set(df["object"])
        entities_by_split[split] = split_entities
        entities.update(split_entities)
        relations.update(df["relation"])

    # Create mappings, storing entities and relations alphabetically
    sorted_entities = sorted(entities)
    sorted_relations = sorted(relations)
    entity_to_id: dict[str, int] = {ent: idx for idx, ent in enumerate(sorted_entities)}
    relation_to_id: dict[str, int] = {rel: idx for idx, rel in enumerate(sorted_relations)}

    logging.info("Number of entities: %d", len(entity_to_id))
    logging.info("Number of relations: %d", len(relation_to_id))

    # Find entities in valid/test but not in train
    entities_not_in_train: set[str] = set()
    entity_counts: dict[str, dict[str, int]] = {"valid": {}, "test": {}}
    affected_triples: dict[str, int] = {"valid": 0, "test": 0}

    for split in ("valid", "test"):
        missing = entities_by_split[split] - entities_by_split["train"]
        if not missing:
            continue
        entities_not_in_train.update(missing)
        df = triples_by_split[split]
        # One pass over both entity columns instead of two boolean masks per entity.
        occurrences = pd.concat([df["subject"], df["object"]]).value_counts().to_dict()
        entity_counts[split] = {ent: int(occurrences[ent]) for ent in missing}
        # Count rows, not occurrences: a triple whose subject and object are
        # both unseen is still a single affected triple.
        missing_list = list(missing)
        is_affected = df["subject"].isin(missing_list) | df["object"].isin(missing_list)
        affected_triples[split] = int(is_affected.sum())

    if entities_not_in_train:
        # Create a DataFrame with missing entities, their mapped IDs, and occurrence counts
        sorted_missing = sorted(entities_not_in_train)
        missing_entities_df = pd.DataFrame(
            {
                "entity": sorted_missing,
                "mapped_id": [entity_to_id[ent] for ent in sorted_missing],
                "valid_occurrences": [entity_counts["valid"].get(ent, 0) for ent in sorted_missing],
                "test_occurrences": [entity_counts["test"].get(ent, 0) for ent in sorted_missing],
            },
        )

        # Save to file
        missing_entities_path = output_folder / "entities_not_in_train.tsv"
        missing_entities_df.to_csv(missing_entities_path, sep="\t", index=False)

        logging.warning(
            "Found %d entities in valid/test that don't appear in train (affecting %d valid and %d test triples). "
            "Models that do not support entities not seen during training will struggle on this split. "
            "Full list saved to %s.",
            len(entities_not_in_train),
            affected_triples["valid"],
            affected_triples["test"],
            missing_entities_path,
        )

    # Write the ID-mapped version of every split that was read successfully.
    for split, df in triples_by_split.items():
        output_file_path = output_folder / f"{split}.tsv"
        try:
            mapped_df = pd.DataFrame(
                {
                    "subject": df["subject"].map(entity_to_id),
                    "relation": df["relation"].map(relation_to_id),
                    "object": df["object"].map(entity_to_id),
                },
            )
            mapped_df.to_csv(
                output_file_path,
                sep="\t",
                header=False,
                index=False,
            )
        except Exception as e:
            logging.exception("Error processing file %s: %s", output_file_path, e)
            continue
        logging.info("Mapped and saved file %s", output_file_path)

    # Save entity and relation mappings: one concept by line, being at the line that corresponds
    # to its index. UTF-8 is explicit so the files match what pandas reads and writes,
    # whatever the platform's default encoding is.
    logging.info("Saving entity and relation mappings")
    with open(output_folder / "entity_ids.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{ent}\n" for ent in sorted_entities)

    with open(output_folder / "relation_ids.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{rel}\n" for rel in sorted_relations)
