from pathlib import Path
import pandas as pd
import numpy as np
import torch
import pickle
from tqdm.autonotebook import tqdm
from typing import (
    cast,
    Iterable,
)
import json

from pykeen.evaluation import RankBasedEvaluator
from pykeen.evaluation.evaluator import optional_context_manager
from pykeen.utils import split_list_in_batches_iter
from pykeen.typing import LABEL_HEAD, LABEL_TAIL
from pykeen.constants import TARGET_TO_INDEX

from SEPAL import SEPAL_DIR
from SEPAL.baselines.run_baselines import PYKEEN_MODELS
from SEPAL.downstream_evaluation import DATASETS_NAMES
from SEPAL.baselines.PBG.utils import load_PBG_embeddings, load_PBG_relation_embeddings


# Adapted from pykeen.evaluation evaluate() function to scale to large datasets
def evaluate(
    model,
    mapped_triples,
    num_entities,
    num_negatives=10000,
    batch_size=256,
    device=None,
    use_tqdm=True,
    targets=(LABEL_HEAD, LABEL_TAIL),
):
    """Run a sampled, unfiltered rank-based evaluation of ``model``.

    The triples are processed batch by batch; for every batch and every
    prediction target, ``_evaluate_batch`` scores the positives against
    ``num_negatives`` randomly drawn entity ids and feeds the scores to a
    ``RankBasedEvaluator`` (Hits@50 in addition to the evaluator defaults).

    Parameters
    ----------
    model
        Model to evaluate. It is moved to ``device`` when one is given and
        switched to evaluation mode.
    mapped_triples
        Tensor of triples to evaluate; it is moved to the model's device.
    num_entities
        Exclusive upper bound for the sampled negative entity ids.
    num_negatives
        Number of negative ids sampled per batch and target.
    batch_size
        Number of triples per batch; ``None`` falls back to 32.
    device
        Optional device to move the model to before evaluating.
    use_tqdm
        Whether to display a progress bar.
    targets
        Prediction targets to evaluate for each batch.

    Returns
    -------
    The object returned by ``RankBasedEvaluator.finalize()``.
    """
    # Optionally relocate the model, then use its device as the reference
    if device is not None:
        model = model.to(device)
    device = model.device

    # Evaluation mode
    model.eval()

    # The triples have to live on the same device as the model
    mapped_triples = mapped_triples.to(device=device)
    num_triples = mapped_triples.shape[0]

    # Split the triples into batches (small default when no size is given)
    chunk_size = 32 if batch_size is None else batch_size
    triple_batches = cast(
        Iterable[np.ndarray],
        split_list_in_batches_iter(input_list=mapped_triples, batch_size=chunk_size),
    )

    # Progress bar settings; disable=True is needed so that use_tqdm=False
    # does not still display the initial bar
    progress_options = {
        "desc": f"Evaluating on {model.device}",
        "total": num_triples,
        "unit": "triple",
        "unit_scale": True,
        "disable": not use_tqdm,
    }

    # Unfiltered rank-based evaluator with Hits@50 on top of the defaults
    evaluator = RankBasedEvaluator(
        filtered=False,
        metrics=["HitsAtK"],
        metrics_kwargs=[{"k": 50}],
        add_defaults=True,
    )

    # Gradient tracking is disabled for the whole evaluation
    with optional_context_manager(
        use_tqdm, tqdm(**progress_options)
    ) as progress_bar, torch.inference_mode():
        for triple_batch in triple_batches:
            for target in targets:
                _evaluate_batch(
                    batch=triple_batch,
                    model=model,
                    target=target,
                    evaluator=evaluator,
                    num_negatives=num_negatives,
                    num_entities=num_entities,
                )

            # Advance the bar by the number of triples just processed
            if use_tqdm:
                progress_bar.update(triple_batch.shape[0])

        # Aggregate the accumulated ranks into the final metrics
        result = evaluator.finalize()
    return result


def _evaluate_batch(
    batch,
    model,
    target,
    evaluator,
    num_negatives,
    num_entities,
):
    """Score one batch against sampled negatives and pass the scores to ``evaluator``.

    A single set of ``num_negatives`` random entity ids is drawn and shared by
    all triples of the batch. The candidate ids handed to ``model.predict``
    are the batch's own target-column ids followed by the sampled ids.

    Parameters
    ----------
    batch
        2D tensor of triples; the target column is selected through
        ``TARGET_TO_INDEX[target]``.
    model
        Model exposing ``predict(hrt_batch=..., target=..., ids=...)``.
    target
        Side of the triple being predicted.
    evaluator
        Evaluator whose ``process_scores_`` accumulates the ranking results.
    num_negatives
        Number of random entity ids to sample.
    num_entities
        Exclusive upper bound for the sampled ids.

    Returns
    -------
    None
    """
    num_positives = batch.shape[0]

    # Sample random ids for the negatives. They are created on the batch's
    # device: concatenating a CPU tensor with a batch that lives on another
    # device raises an error.
    random_ids = torch.randint(
        0, num_entities, (num_negatives,), device=batch.device
    )

    # Add the batch ids for the positives
    column = TARGET_TO_INDEX[target]
    ids = torch.cat((batch[:, column], random_ids))

    # Predict scores
    # NOTE(review): presumably one row per triple and one column per entry of
    # ``ids`` -- confirm against the model's ``predict`` contract.
    scores = model.predict(hrt_batch=batch, target=target, ids=ids)

    # Get the scores of true triples: triple i's own id sits at position i of
    # ``ids``, so its score is the i-th diagonal entry of the leading block.
    diagonal = torch.arange(num_positives, device=scores.device)
    true_scores = scores[diagonal, diagonal].unsqueeze(dim=-1)

    # Keep scores of num_negatives negatives, and the positive
    scores = torch.hstack((scores[:, num_positives:], true_scores))

    # Process scores
    evaluator.process_scores_(
        hrt_batch=batch,
        target=target,
        true_scores=true_scores,
        scores=scores,
        dense_positive_mask=None,
    )


def transductive_link_prediction(
    data,
    subset,
    model_name,
    dim,
    embeddings,
    rel_embeddings,
    id,
    results_file,
    num_negatives,
):
    """Evaluate pre-trained embeddings on link prediction and log the scores.

    The pickled triples factory of the requested subset is loaded, a model
    from ``PYKEEN_MODELS`` is built on it, its entity and relation embedding
    weights are replaced with the given ones, and the sampled evaluation from
    ``evaluate`` is run. The metrics are appended as one row to the parquet
    file ``results_file`` (created if it does not exist).

    Parameters
    ----------
    data
        Name of the dataset directory under ``datasets/knowledge_graphs``.
    subset
        ``"val"`` or ``"test"``.
    model_name
        Key into ``PYKEEN_MODELS``.
    dim
        Embedding dimension used to build the model.
    embeddings
        Entity embeddings, convertible with ``torch.Tensor``.
    rel_embeddings
        Relation embeddings, convertible with ``torch.Tensor``.
    id
        Identifier stored in the ``"id"`` field of the result row.
    results_file
        Path of the parquet file collecting the results.
    num_negatives
        Number of negatives sampled per batch during evaluation.

    Raises
    ------
    ValueError
        If ``subset`` is neither ``"val"`` nor ``"test"``.
    """
    # Validate the subset up front so an unknown value fails with a clear
    # message instead of an unbound ``tf`` further down.
    tf_filenames = {"val": "validation_tf.pkl", "test": "testing_tf.pkl"}
    if subset not in tf_filenames:
        raise ValueError(
            f"Unknown subset {subset!r}; expected one of {sorted(tf_filenames)}"
        )

    # Load triple factory
    data_dir = SEPAL_DIR / f"datasets/knowledge_graphs/{data}"
    with open(data_dir / tf_filenames[subset], "rb") as f:
        tf = pickle.load(f)

    # Build embedding model
    model = PYKEEN_MODELS[model_name](
        triples_factory=tf, embedding_dim=dim, random_seed=0
    )

    # Load trained embeddings. strict=False lets this partial state dict be
    # loaded even when its keys do not exactly match the model's own.
    # NOTE(review): the regularizer entries look like placeholders for models
    # that register a relation regularizer -- confirm.
    model.load_state_dict(
        state_dict={
            "entity_representations.0._embeddings.weight": torch.Tensor(embeddings),
            "relation_representations.0._embeddings.weight": torch.Tensor(
                rel_embeddings
            ),
            "relation_representations.0.regularizer.weight": torch.Tensor([0.0]),
            "relation_representations.0.regularizer.regularization_term": torch.Tensor(
                [0.0]
            ),
        },
        strict=False,
    )

    # Evaluate model
    results = evaluate(
        model=model,
        mapped_triples=tf.mapped_triples,
        num_entities=tf.num_entities,
        num_negatives=num_negatives,
        batch_size=256,
    ).to_dict()

    # Save results, appending to the existing file if there is one
    results["id"] = id
    results["filtered"] = False
    results["sampled"] = True
    results["num_negatives"] = num_negatives
    new_df_res = pd.DataFrame([results])
    if Path(results_file).is_file():
        df_res = pd.read_parquet(results_file)
        df_res = pd.concat([df_res, new_df_res]).reset_index(drop=True)
    else:
        df_res = new_df_res
    df_res.to_parquet(results_file, index=False)

    return


def get_pbg2sepal_reordering(data):
    """Compute the index lists that put PBG embeddings in SEPAL order.

    SEPAL and PBG index entities and relations differently. For each name in
    the SEPAL training triples factory mappings (visited in the mappings'
    iteration order), the position of that name in the corresponding PBG name
    list is looked up.

    Parameters
    ----------
    data
        Dataset name, used to locate both the SEPAL and the PBG data folders.

    Returns
    -------
    tuple of (list, list)
        PBG positions for the SEPAL entities, and for the SEPAL relations.

    Raises
    ------
    KeyError
        If a SEPAL name is absent from the PBG name lists.
    """

    def _pbg_positions(pbg_names, sepal_names):
        # Position of every PBG name, then one lookup per SEPAL name
        position_of = {name: position for position, name in enumerate(pbg_names)}
        return [position_of[name] for name in sepal_names]

    # SEPAL side: the pickled training triples factory
    sepal_dir = SEPAL_DIR / f"datasets/knowledge_graphs/{data}"
    with open(sepal_dir / "training_tf.pkl", "rb") as tf_file:
        training_tf = pickle.load(tf_file)

    # PBG side: name lists stored as JSON
    pbg_dir = SEPAL_DIR / f"baselines/PBG/data/{data}"
    with open(pbg_dir / "entity_names_all_0.json") as names_file:
        pbg_entity_names = json.load(names_file)
    with open(pbg_dir / "dynamic_rel_names.json") as names_file:
        pbg_relation_names = json.load(names_file)

    entity_reordering = _pbg_positions(pbg_entity_names, training_tf.entity_to_id)
    relation_reordering = _pbg_positions(
        pbg_relation_names, training_tf.relation_to_id
    )
    return entity_reordering, relation_reordering



if __name__ == "__main__":
    # Evaluate the PBG embeddings of every dataset on the validation and test
    # sets, appending the scores to one parquet file per subset.
    for data in [
        "mini_yago3_lcc",
        "yago3_lcc",
        "yago4_lcc",
        "yago4.5_lcc",
        "yago4_with_full_ontology",
    ]:
        entity_reordering, relation_reordering = get_pbg2sepal_reordering(data)

        # The trained embeddings do not depend on the evaluation subset, so
        # they are loaded and reordered once per dataset instead of once per
        # subset.
        embeddings = load_PBG_embeddings(data, subset="train")
        embed_dim = embeddings.shape[1]
        rel_embeddings = load_PBG_relation_embeddings(data, subset="train")

        # Reorder embeddings to follow the SEPAL indexing
        embeddings = embeddings[entity_reordering]
        rel_embeddings = rel_embeddings[relation_reordering]

        for subset in ["val", "test"]:
            print(f"Evaluating {DATASETS_NAMES[data]} embeddings on {subset} set")
            scores_path = SEPAL_DIR / f"{subset}_lp_scores_pbg.parquet"

            # Evaluate embeddings
            transductive_link_prediction(
                data=data,
                subset=subset,
                model_name="distmult",
                dim=embed_dim,
                embeddings=embeddings,
                rel_embeddings=rel_embeddings,
                id=f"PBG - {DATASETS_NAMES[data]} train",
                results_file=scores_path,
                num_negatives=10000,
            )
