import logging
import math
from collections import defaultdict
from collections.abc import Callable

import torch
from pykeen.metrics.ranking import ArithmeticMeanRank, HitsAtK, InverseHarmonicMeanRank
from tqdm import tqdm

from kge.dataset import TripleDataset, TripleDatasetWithNeg
from kge.models import KGModel
from kge.types import EntityID, RelationID

from .evaluator import HITSAT1, HITSAT3, HITSAT10, MR, MRR, RankMetrics

# Map OGB evaluator metric keys (e.g. "mrr_list") to the metric names defined in
# `.evaluator`.
metrics_name_mapping = dict(
    [
        ("hits@1_list", HITSAT1),
        ("hits@3_list", HITSAT3),
        ("hits@10_list", HITSAT10),
        ("mrr_list", MRR),
    ],
)


class OGBEvaluator:
    """Evaluate object prediction on OGB-style link-prediction datasets.

    Provides rank metrics (MR, MRR, Hits@{1,3,10}) computed against the negative
    candidates that come with each triple, and the average negative
    log-likelihood of the true object under a softmax over the scores returned
    by ``model.score_o``.
    """

    def __init__(self, dataset_name: str):
        """Create the evaluator.

        Args:
            dataset_name: Name of the OGB dataset. Currently unused; it was only
                needed by the official OGB ``Evaluator``, which is not used here
                because it only supports MRR, not MR.
        """
        self.rank_metrics: dict[str, Callable] = {
            MR: ArithmeticMeanRank(),
            MRR: InverseHarmonicMeanRank(),
            HITSAT1: HitsAtK(k=1),
            HITSAT3: HitsAtK(k=3),
            HITSAT10: HitsAtK(k=10),
        }
        # (subject, relation) -> objects observed in training; set by initialize_filters.
        self.skip_objects_train: dict[tuple[EntityID, RelationID], set[EntityID]] = defaultdict(set)
        self._initialized_filters = False

    def initialize_filters(
        self,
        datasets: list[TripleDataset],
    ) -> None:
        """Initialize the train-set filter used when computing filtered NLL.

        In OGB datasets only the training-set filter is needed: the
        ``sr_to_objects`` mapping of the dataset whose ``split`` is ``"train"`` is
        stored as ``skip_objects_train``. If several datasets are marked
        ``"train"`` the last one wins; if none is, the filter stays empty and a
        warning is logged.

        Args:
            datasets: Datasets to scan for the training split.
        """
        logging.info("Initializing filters for ranking metrics...")
        self.skip_objects_train = defaultdict(set)
        found_train = False
        for ds in datasets:
            if ds.split == "train":
                self.skip_objects_train = ds.sr_to_objects
                found_train = True
        if not found_train:
            logging.warning("No dataset with split 'train' given; the train filter is empty.")
        logging.info("Filters initialized.")
        self._initialized_filters = True

    @staticmethod
    def _resolve_sample_size(dataset, sample_size: int) -> int:
        """Map ``-1`` to the whole dataset and cap ``sample_size`` at its length.

        Raises:
            ValueError: If the resulting number of triples to evaluate is not positive.
        """
        dataset_len = len(dataset)
        resolved = dataset_len if sample_size == -1 else min(sample_size, dataset_len)
        if resolved <= 0:
            msg = f"Nothing to evaluate: sample_size={sample_size}, dataset length={dataset_len}."
            raise ValueError(msg)
        return resolved

    @staticmethod
    def _realistic_ranks(
        scores: torch.Tensor,
        o: torch.Tensor,
        o_neg: torch.Tensor,
    ) -> torch.Tensor:
        """Return the 1-based rank of each true object among its negatives.

        Ties are resolved as in the official OGB evaluator: the rank is the mean
        of the optimistic rank (ties placed below the target) and the pessimistic
        rank (ties placed above it). With strict ``>`` only, a model that gives
        every entity the same score would get a perfect rank.

        Args:
            scores: ``(b, num_objects)`` scores from ``model.score_o``.
            o: ``(b,)`` true object ids.
            o_neg: negative object ids per triple, expected ``(b, k)``; a ``(b,)``
                tensor is treated as one negative per triple.

        Returns:
            ``(b,)`` float tensor of ranks.
        """
        target = o.long().view(-1, 1)  # (b, 1)
        negatives = o_neg.long()
        if negatives.dim() == 1:
            negatives = negatives.unsqueeze(1)  # (b, 1)
        pos = scores.gather(1, target)  # (b, 1)
        neg = scores.gather(1, negatives)  # (b, k)
        # A negative slot holding the true object itself ties trivially; don't count it.
        is_other = negatives != target
        optimistic = (neg > pos).sum(dim=1)
        pessimistic = ((neg >= pos) & is_other).sum(dim=1)
        return 1.0 + 0.5 * (optimistic + pessimistic)

    def _mask_train_objects(
        self,
        scores: torch.Tensor,
        s: torch.Tensor,
        r: torch.Tensor,
        o: torch.Tensor,
    ) -> None:
        """Set, in place, scores of train objects of each (s, r) to -inf, sparing the true object."""
        rows: list[int] = []
        cols: list[int] = []
        # ``tolist`` copies the ids to the host once instead of once per element.
        for i, (s_i, r_i, o_i) in enumerate(zip(s.tolist(), r.tolist(), o.tolist())):
            # ``get`` avoids KeyError for unseen pairs and does not grow a defaultdict.
            for o_true in self.skip_objects_train.get((s_i, r_i), ()):
                if o_true != o_i:
                    rows.append(i)
                    cols.append(o_true)
        if rows:
            scores[rows, cols] = -float("inf")

    def evaluate_object_ranks(
        self,
        model: KGModel,
        dataset: TripleDatasetWithNeg,
        batch_size: int = 32,
        device: str = "cuda",
        sample_size: int = 1000,
    ) -> RankMetrics:
        """Evaluate object ranks on dataset.

        Each true object is ranked against the negative objects the dataset
        provides for its triple; ties count half (see ``_realistic_ranks``).

        Args:
            model: KGE model to evaluate; it is put into eval mode.
            dataset: Triple dataset with negative samples to evaluate on.
            batch_size: Batch size for scoring.
            device: Device to use for scoring.
            sample_size: Number of (shuffled) triples to evaluate. -1 means all
                triples; larger values are capped at the dataset length.

        Returns:
            RankMetrics object containing MR, MRR and Hits@{1,3,10}.

        Raises:
            ValueError: If there is nothing to evaluate (empty dataset or
                non-positive ``sample_size`` other than -1).
        """
        model.eval()
        sample_size = self._resolve_sample_size(dataset, sample_size)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
        )
        ranks = torch.empty(sample_size)
        num_ranked = 0
        with torch.no_grad():
            for s, r, o, o_neg in tqdm(
                dataloader,
                total=math.ceil(sample_size / batch_size),
                desc=f"Evaluating object ranks on {dataset.split}",
            ):
                # Only score as many triples as are still needed.
                take = min(len(o), sample_size - num_ranked)
                scores = model.score_o(s[:take].to(device), r[:take].to(device)).cpu()
                ranks[num_ranked : num_ranked + take] = self._realistic_ranks(
                    scores,
                    o[:take],
                    o_neg[:take],
                )
                num_ranked += take
                if num_ranked >= sample_size:
                    logging.info("Reached sample size limit.")
                    break
        # Only ranks actually computed are aggregated; pykeen metrics take numpy arrays.
        ranks_np = ranks[:num_ranked].numpy()
        return RankMetrics.from_dict(
            {name: metric(ranks_np) for name, metric in self.rank_metrics.items()},
        )

    def evaluate_object_nll(
        self,
        model: KGModel,
        dataset: TripleDataset | TripleDatasetWithNeg,
        batch_size: int = 1,
        device: str = "cuda",
        sample_size: int = 1000,
        *,
        filtered: bool = False,
    ) -> float:
        """Evaluate the object negative log-likelihood on dataset.

        The log-probability of each true object is ``log_softmax`` over the
        scores returned by ``model.score_o``. Since ``log_softmax(x) == x -
        logsumexp(x)``, this also correctly renormalizes models that already
        return log-probabilities, which is required once filtered objects are
        removed.

        Args:
            model: KGE model to evaluate; it is put into eval mode.
            dataset: Triple dataset, or triple dataset with negative samples
                (the negatives are ignored).
            batch_size: Batch size for scoring.
            device: Device to use for scoring.
            sample_size: Number of (shuffled) triples to evaluate. -1 means all
                triples; larger values are capped at the dataset length.
            filtered: Whether to set the probabilities of objects seen with the
                same (subject, relation) in train to 0, except the true object.
                Has no effect when ``dataset.split == "train"``.

        Returns:
            Average negative log-likelihood per evaluated triple.

        Raises:
            ValueError: If ``filtered`` is requested before ``initialize_filters``,
                if the dataset type is unsupported, or if there is nothing to
                evaluate.
        """
        if filtered and not self._initialized_filters:
            msg = "Filters not initialized. Call initialize_filters first."
            raise ValueError(msg)
        if not isinstance(dataset, (TripleDatasetWithNeg, TripleDataset)):
            msg = f"Unknown dataset type: {type(dataset)}"
            raise ValueError(msg)
        model.eval()
        sample_size = self._resolve_sample_size(dataset, sample_size)
        apply_filter = filtered and dataset.split != "train"
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
        )
        total_nll = 0.0
        total_samples = 0
        with torch.no_grad():
            for batch in tqdm(
                dataloader,
                total=math.ceil(sample_size / batch_size),
                desc=f"Evaluating object likelihood on {dataset.split}",
            ):
                # Both dataset types yield (s, r, o, ...); negatives, if present, are unused.
                s, r, o = batch[0], batch[1], batch[2]
                take = min(len(o), sample_size - total_samples)
                s = s[:take].to(device)
                r = r[:take].to(device)
                o = o[:take].to(device)
                scores = model.score_o(s, r)
                if apply_filter:
                    self._mask_train_objects(scores, s, r, o)
                log_probs = torch.nn.functional.log_softmax(scores, dim=1)
                # Log-probability of each true object; gather keeps indices on one device.
                total_nll -= log_probs.gather(1, o.long().view(-1, 1)).sum().item()
                total_samples += take
                if total_samples >= sample_size:
                    break
        return total_nll / total_samples
