from pathlib import Path
import pandas as pd
import numpy as np
import torch
import pickle
from tqdm.autonotebook import tqdm
from typing import (
    cast,
    Iterable,
)

from pykeen.evaluation import RankBasedEvaluator
from pykeen.evaluation.evaluator import optional_context_manager
from pykeen.utils import split_list_in_batches_iter
from pykeen.typing import LABEL_HEAD, LABEL_TAIL
from pykeen.constants import TARGET_TO_INDEX

from SEPAL import SEPAL_DIR
from SEPAL.baselines.run_baselines import PYKEEN_MODELS
from SEPAL.downstream_evaluation import DATASETS_NAMES
from SEPAL.visualization import MODEL_NAMES


# Adapted from pykeen.evaluation evaluate() function to scale to large datasets
def evaluate(
    model,
    mapped_triples,
    num_entities,
    num_negatives=10000,
    batch_size=256,
    device=None,
    use_tqdm=True,
    targets=(LABEL_HEAD, LABEL_TAIL),
):
    """Compute unfiltered, negative-sampled rank-based metrics for ``model``.

    Instead of ranking each true triple against every entity, each batch is
    ranked against ``num_negatives`` randomly sampled entity ids (see
    ``_evaluate_batch``), which keeps memory bounded on large graphs.

    Args:
        model: pykeen model to evaluate. It is moved to ``device`` when one is
            given and switched to eval mode.
        mapped_triples: tensor of id-mapped triples to evaluate, one per row.
        num_entities: exclusive upper bound for the sampled negative ids.
        num_negatives: number of random candidate ids drawn per batch and target.
        batch_size: number of triples per batch; ``None`` falls back to 32.
        device: optional device to move the model to before evaluating.
        use_tqdm: whether to display a progress bar.
        targets: which triple positions to rank (head and/or tail).

    Returns:
        The value of ``RankBasedEvaluator.finalize()``. The evaluator is built
        with ``filtered=False``, ``HitsAtK`` at ``k=50`` and ``add_defaults=True``.
    """
    if device is not None:
        model = model.to(device)
    # From here on, follow whatever device the model actually lives on.
    device = model.device
    model.eval()

    mapped_triples = mapped_triples.to(device=device)
    num_triples = mapped_triples.shape[0]

    if batch_size is None:
        batch_size = 32
    # split_list_in_batches_iter yields row slices of mapped_triples (tensors).
    batches = cast(
        Iterable[torch.Tensor],
        split_list_in_batches_iter(input_list=mapped_triples, batch_size=batch_size),
    )

    progress_kwargs = dict(
        desc=f"Evaluating on {model.device}",
        total=num_triples,
        unit="triple",
        unit_scale=True,
        # Choosing no progress bar (use_tqdm=False) would still show the initial progress bar without disable=True
        disable=not use_tqdm,
    )

    evaluator = RankBasedEvaluator(
        filtered=False,
        metrics=["HitsAtK"],
        metrics_kwargs=[{"k": 50}],
        add_defaults=True,
    )

    # inference_mode() disables gradient tracking for the whole loop.
    with optional_context_manager(
        use_tqdm, tqdm(**progress_kwargs)
    ) as progress_bar, torch.inference_mode():
        for batch in batches:
            for target in targets:
                _evaluate_batch(
                    batch=batch,
                    model=model,
                    target=target,
                    evaluator=evaluator,
                    num_negatives=num_negatives,
                    num_entities=num_entities,
                )

            if use_tqdm:
                # The last batch may be shorter than batch_size.
                progress_bar.update(batch.shape[0])

        result = evaluator.finalize()
    return result


def _evaluate_batch(
    batch,
    model,
    target,
    evaluator,
    num_negatives,
    num_entities,
):
    """Score one batch against sampled negatives and feed the scores to ``evaluator``.

    All triples in ``batch`` share the same ``num_negatives`` candidate ids,
    drawn uniformly from ``[0, num_entities)``. Sampled ids are not filtered,
    so a negative may coincide with a triple's true entity.

    Args:
        batch: id-mapped triples, one per row, on the model's device.
        model: pykeen model exposing ``predict(hrt_batch=..., target=..., ids=...)``.
        target: which position of the triple to rank (a pykeen target label).
        evaluator: evaluator whose ``process_scores_`` accumulates the results.
        num_negatives: number of random candidate ids to draw.
        num_entities: exclusive upper bound for the sampled ids.
    """
    n = batch.shape[0]

    # Allocate the negatives on the batch's device. torch.randint defaults to
    # the CPU, and torch.cat with a CUDA batch would raise a device mismatch.
    random_ids = torch.randint(
        0, num_entities, (num_negatives,), device=batch.device
    )

    # Candidate ids: the batch's own true entities first, then the negatives.
    column = TARGET_TO_INDEX[target]
    ids = torch.cat((batch[:, column], random_ids))

    # One row per triple, one column per candidate id
    # (expected shape (n, n + num_negatives)).
    scores = model.predict(hrt_batch=batch, target=target, ids=ids)

    # Row i's own true entity is candidate column i, i.e. the diagonal.
    diag = torch.arange(0, n, device=scores.device)
    true_scores = scores[diag, diag].unsqueeze(dim=-1)

    # Drop the other triples' positives: keep the num_negatives sampled
    # columns, plus this row's positive appended as the last column.
    scores = torch.hstack((scores[:, n:], true_scores))

    evaluator.process_scores_(
        hrt_batch=batch,
        target=target,
        true_scores=true_scores,
        scores=scores,
        dense_positive_mask=None,
    )
    return


def _build_state_dict(model_name, dim, embeddings, rel_embeddings):
    """Map pretrained numpy parameters onto pykeen state-dict keys for ``model_name``.

    ``embeddings``/``rel_embeddings`` are laid out as loaded by the script:
    a list ``[entities, core_tensor]`` for "tucker", lists
    ``[entities, bias_1, bias_2]`` / ``[relations_0, relations_1]`` for
    "mure", and plain arrays otherwise.
    """
    if model_name == "tucker":
        return {
            "interaction.core_tensor": torch.Tensor(embeddings[1]),
            # Batch norms get weight 1, running mean 0 and running var 1
            # (no statistics from training are available here).
            "interaction.head_batch_norm.weight": torch.ones(dim),
            "interaction.head_batch_norm.running_mean": torch.zeros(dim),
            "interaction.head_batch_norm.running_var": torch.ones(dim),
            "interaction.head_batch_norm.num_batches_tracked": torch.tensor(0),
            "interaction.head_relation_batch_norm.weight": torch.ones(dim),
            "interaction.head_relation_batch_norm.bias": torch.zeros(dim),
            "interaction.head_relation_batch_norm.running_mean": torch.zeros(dim),
            "interaction.head_relation_batch_norm.running_var": torch.ones(dim),
            "interaction.head_relation_batch_norm.num_batches_tracked": torch.tensor(
                0
            ),
            "entity_representations.0._embeddings.weight": torch.Tensor(
                embeddings[0]
            ),
            "relation_representations.0._embeddings.weight": torch.Tensor(
                rel_embeddings
            ),
        }
    if model_name == "mure":
        return {
            "entity_representations.0._embeddings.weight": torch.Tensor(
                embeddings[0]
            ),
            # Bias vectors are stored flat; the model expects a trailing dim.
            "entity_representations.1._embeddings.weight": torch.Tensor(
                embeddings[1]
            ).unsqueeze(dim=-1),
            "entity_representations.2._embeddings.weight": torch.Tensor(
                embeddings[2]
            ).unsqueeze(dim=-1),
            "relation_representations.0._embeddings.weight": torch.Tensor(
                rel_embeddings[0]
            ),
            "relation_representations.1._embeddings.weight": torch.Tensor(
                rel_embeddings[1]
            ),
        }
    return {
        "entity_representations.0._embeddings.weight": torch.Tensor(embeddings),
        "relation_representations.0._embeddings.weight": torch.Tensor(
            rel_embeddings
        ),
        "relation_representations.0.regularizer.weight": torch.Tensor([0.0]),
        "relation_representations.0.regularizer.regularization_term": torch.Tensor(
            [0.0]
        ),
    }


def transductive_link_prediction(
    data,
    subset,
    model_name,
    dim,
    embeddings,
    rel_embeddings,
    id,
    results_file,
    num_negatives,
):
    """Evaluate pretrained embeddings on a held-out split and append the metrics.

    Args:
        data: dataset directory name under ``datasets/knowledge_graphs``.
        subset: "val" or "test"; selects the pickled triples factory to load.
        model_name: key into ``PYKEEN_MODELS``.
        dim: embedding dimension of the pretrained model.
        embeddings: entity-side parameters (layout depends on ``model_name``).
        rel_embeddings: relation-side parameters including inverse relations;
            every other row is kept (see below).
        id: identifier of the trained run, stored in the results row.
        results_file: parquet file the results row is appended to (created
            if missing).
        num_negatives: number of sampled negatives per batch and target.

    Raises:
        ValueError: if ``subset`` is not "val" or "test".
    """
    subset_files = {"val": "validation_tf.pkl", "test": "testing_tf.pkl"}
    if subset not in subset_files:
        raise ValueError(
            f"subset must be one of {sorted(subset_files)}, got {subset!r}"
        )

    # Load the triples factory of the requested split.
    # NOTE: pickle.load can execute arbitrary code; these files are assumed
    # to be produced locally by this project.
    data_dir = SEPAL_DIR / f"datasets/knowledge_graphs/{data}"
    with open(data_dir / subset_files[subset], "rb") as f:
        tf = pickle.load(f)

    model = PYKEEN_MODELS[model_name](
        triples_factory=tf, embedding_dim=dim, random_seed=0
    )

    # Remove inverse relations by keeping every other row. This assumes rows
    # alternate relation / inverse relation — TODO confirm against training.
    if model_name == "mure":
        rel_embeddings = [r[::2] for r in rel_embeddings]
    else:
        rel_embeddings = rel_embeddings[::2]

    # strict=False: only the listed parameters are overwritten.
    model.load_state_dict(
        state_dict=_build_state_dict(model_name, dim, embeddings, rel_embeddings),
        strict=False,
    )

    results = evaluate(
        model=model,
        mapped_triples=tf.mapped_triples,
        num_entities=tf.num_entities,
        num_negatives=num_negatives,
        batch_size=256,
    ).to_dict()

    # Append this run's row to the results parquet file.
    results["id"] = id
    results["filtered"] = False
    results["sampled"] = True
    results["num_negatives"] = num_negatives
    new_df_res = pd.DataFrame([results])
    if Path(results_file).is_file():
        df_res = pd.read_parquet(results_file)
        df_res = pd.concat([df_res, new_df_res]).reset_index(drop=True)
    else:
        df_res = new_df_res
    df_res.to_parquet(results_file, index=False)

    return


def _load_embeddings(embed_method, run_id):
    """Load the saved parameters of one trained run from ``SEPAL_DIR/embeddings``.

    Returns ``(embeddings, rel_embeddings)`` in the layout that
    ``transductive_link_prediction`` expects for ``embed_method``:
    ``([entities, core_tensor], relations)`` for "tucker",
    ``([entities, bias_1, bias_2], [relations_0, relations_1])`` for "mure",
    and ``(entities, relations)`` otherwise.
    """
    emb_dir = SEPAL_DIR / "embeddings"
    if embed_method == "tucker":
        embeddings = [
            np.load(emb_dir / f"{run_id}.npy"),
            np.load(emb_dir / f"{run_id}_tensor.npy"),
        ]
        rel_embeddings = np.load(emb_dir / f"{run_id}_relations.npy")
    elif embed_method == "mure":
        embeddings = [
            np.load(emb_dir / f"{run_id}.npy"),
            np.load(emb_dir / f"{run_id}_biais_1.npy"),
            np.load(emb_dir / f"{run_id}_biais_2.npy"),
        ]
        rel_embeddings = [
            np.load(emb_dir / f"{run_id}_relations_0.npy"),
            np.load(emb_dir / f"{run_id}_relations_1.npy"),
        ]
    else:
        embeddings = np.load(emb_dir / f"{run_id}.npy")
        rel_embeddings = np.load(emb_dir / f"{run_id}_relations.npy")
    return embeddings, rel_embeddings


if __name__ == "__main__":
    # The checkpoint list does not depend on the subset: load it once.
    checkpoints = pd.read_parquet(SEPAL_DIR / "checkpoints_sepal.parquet")
    checkpoints = checkpoints[checkpoints.subset == "train"]

    for subset in ["val", "test"]:
        # Skip runs that already have a row in this subset's scores file.
        scores_path = SEPAL_DIR / f"{subset}_lp_scores.parquet"
        if scores_path.is_file():
            evaluated = pd.read_parquet(scores_path)["id"].tolist()
        else:
            evaluated = []
        to_evaluate = checkpoints[~checkpoints["id"].isin(evaluated)]

        # Order by dataset name (intended to start with smaller models).
        to_evaluate = to_evaluate.sort_values(by=["data"])

        # Current restriction: yago4_lcc dataset, distmult models only.
        to_evaluate = to_evaluate[to_evaluate.data.isin(["yago4_lcc"])]
        to_evaluate = to_evaluate[to_evaluate.embed_method == "distmult"]

        for _, row in to_evaluate.iterrows():
            print(
                f"---------------- Evaluating: {MODEL_NAMES[row.embed_method]} {DATASETS_NAMES[row.data]} embeddings on {subset} set ----------------"
            )
            embeddings, rel_embeddings = _load_embeddings(
                row.embed_method, row["id"]
            )
            transductive_link_prediction(
                data=row.data,
                subset=subset,
                model_name=row.embed_method,
                dim=int(row.embed_dim),
                embeddings=embeddings,
                rel_embeddings=rel_embeddings,
                id=row["id"],
                results_file=scores_path,
                num_negatives=10000,
            )
