from __future__ import annotations

import logging
import math
import time
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass

import torch
from pykeen.metrics.ranking import ArithmeticMeanRank, HitsAtK, InverseHarmonicMeanRank
from tqdm import tqdm

from kge.dataset.dataset import TripleDataset
from kge.models import KGModel
from kge.types import EntityID, RelationID

# Dictionary keys under which each ranking metric is stored and reported.
HITSAT1 = "hits@1"
HITSAT3 = "hits@3"
HITSAT10 = "hits@10"
MRR = "mrr"
MR = "mr"


@dataclass
class RankMetrics:
    """Bundle of the five rank-based evaluation scores.

    The dictionary form uses the module-level key constants in the order
    ``hits@1``, ``hits@3``, ``hits@10``, ``mrr``, ``mr``.
    """

    # Stored under the HITSAT1 / HITSAT3 / HITSAT10 keys.
    hits_at_1: float
    hits_at_3: float
    hits_at_10: float
    # Stored under the MRR key.
    mrr: float
    # Stored under the MR key.
    mr: float

    @classmethod
    def from_dict(cls, metrics: dict[str, float]) -> RankMetrics:
        """Build an instance from a mapping keyed by the metric-name constants.

        Keys other than the five metric names are ignored; a missing metric name
        raises ``KeyError``.
        """
        # The key order mirrors the field order, so the values can be passed positionally.
        ordered_keys = (HITSAT1, HITSAT3, HITSAT10, MRR, MR)
        values = [metrics[key] for key in ordered_keys]
        return cls(*values)

    def to_dict(self) -> dict[str, float]:
        """Return the metrics as a new dictionary keyed by the metric-name constants."""
        keys = (HITSAT1, HITSAT3, HITSAT10, MRR, MR)
        values = (self.hits_at_1, self.hits_at_3, self.hits_at_10, self.mrr, self.mr)
        return dict(zip(keys, values))

    def to_string(self) -> str:
        """Render the metrics as ``key: value`` pairs (4 decimals) joined by ``" | "``."""
        parts = []
        for name, value in self.to_dict().items():
            parts.append(f"{name}: {value:.4f}")
        return " | ".join(parts)


class KGEvaluator:
    def __init__(
        self,
        *,
        filtered: bool = True,
    ):
        """Initialize evaluator with PyKEEN metrics.

        Args:
            all_entities: Set of all entities in the dataset

        """
        self.rank_metrics: dict[str, Callable] = {
            MR: ArithmeticMeanRank(),
            MRR: InverseHarmonicMeanRank(),
            HITSAT1: HitsAtK(k=1),
            HITSAT3: HitsAtK(k=3),
            HITSAT10: HitsAtK(k=10),
        }
        self.filtered = filtered
        self._initialized_filters = False

    def initialize_filters(
        self,
        datasets: list[TripleDataset],
    ) -> None:
        """Initialize lhs and rhs filters for the ranking metrics."""
        self.skip_objects_train: dict[tuple[EntityID, RelationID], set[EntityID]] = defaultdict(set)
        self.skip_objects_valid: dict[tuple[EntityID, RelationID], set[EntityID]] = defaultdict(set)
        self.skip_objects_test: dict[tuple[EntityID, RelationID], set[EntityID]] = defaultdict(set)
        logging.info("Initializing filters for ranking metrics...")
        for ds in datasets:
            if ds.split == "train":
                self.skip_objects_train = ds.sr_to_objects
            elif ds.split == "valid":
                self.skip_objects_valid = ds.sr_to_objects
            elif ds.split == "test":
                self.skip_objects_test = ds.sr_to_objects
        logging.info("Filters initialized.")
        self._initialized_filters = True

    def evaluate_object_nll(
        self,
        model: KGModel,
        dataset: TripleDataset,
        batch_size: int = 256,
        device: str = "cuda",
        sample_size: int = 1000,
        filtered: bool = True,
    ) -> float:
        """Evaluate object likelihood on dataset.

        Args:
            model: KGModel to evaluate
            dataset: TripleDataset to evaluate on
            batch_size: Batch size for scoring
            device: Device to use for scoring
            sample_size: Number of triples to sample for evaluation. -1 means all triples.
            filtered: Whether to set the probabilities of objects in train to 0.
                Unlike for ranking metrics, we do not set the probabilities of objects in valid
                and test to 0.

        Returns:
            Average negative log likelihood per triple

        """
        model.eval()
        if sample_size == -1:
            sample_size = len(dataset)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
        )
        total_nll = 0.0
        total_samples = 0
        with torch.no_grad():
            for s, r, o in tqdm(
                dataloader,
                total=math.ceil(sample_size / batch_size),
                desc=f"Evaluating object likelihood on {dataset.split}",
            ):
                s = s.to(device)
                r = r.to(device)
                o = o.to(device)
                scores = model.score_o(s, r)
                if filtered and dataset.split != "train":
                    for i in range(len(o)):
                        for o_true in self.skip_objects_train[(int(s[i]), int(r[i]))]:
                            if o_true != int(o[i]):
                                scores[i, o_true] = -float("inf")
                if model.return_log_prob:
                    if filtered and dataset.split != "train":
                        # renormalize by subtracting the log-sum-exp over the allowed classes
                        log_probs = scores - torch.logsumexp(scores, dim=-1, keepdim=True)
                    else:
                        log_probs = scores
                else:
                    log_probs = torch.nn.functional.log_softmax(scores, dim=1)
                # Get the log prob for the true objects
                batch_nll = -log_probs[torch.arange(len(o)), o].sum().item()
                total_nll += batch_nll
                total_samples += len(o)
                if total_samples >= sample_size:
                    break
        return total_nll / total_samples

    def evaluate_score_matrix_rank(
        self,
        model: KGModel,
        dataset: TripleDataset,
        batch_size: int = 256,
        device: str = "cuda",
        *,
        log_prob: bool = True,
    ) -> float:
        """Get the rank of the score matrix evaluated on all entities and relations.

        In theory, we would want to compute the rank of the score matrix for all subject-relation
        pairs. I.e., of the size (num_entities * num_relations, num_entities).
        However, this is too large to store in memory.
        Instead, we only compute it for the subject-relation pairs which appear in the training set.

        Args:
            model: KGModel to evaluate
            dataset: TripleDataset to evaluate on
            batch_size: Batch size for scoring
            device: Device to use for scoring
            log_prob: Whether to convert the scores to log probabilities before computing the rank

        """
        if not log_prob and model.return_log_prob:
            msg = "Cannot evaluate the rank of logits scores because the model already returns"
            msg += " log probabilities. Ignoring log_prob argument."
            logging.warning(msg)
            log_prob = True
        model.eval()
        if dataset.split == "train":
            sr_pairs = list(self.skip_objects_train.keys())
        elif dataset.split == "valid":
            sr_pairs = list(self.skip_objects_valid.keys())
        elif dataset.split == "test":
            sr_pairs = list(self.skip_objects_test.keys())
        else:
            raise ValueError(f"Invalid dataset split: {dataset.split}")

        sr_pairs = torch.tensor(sr_pairs, dtype=torch.long)
        dataloader = torch.utils.data.DataLoader(
            sr_pairs,
            batch_size=batch_size,
            shuffle=False,
        )

        all_scores = []
        desc = "Evaluating score matrix" if not log_prob else "Evaluating log prob matrix"
        with torch.no_grad():
            for sr in tqdm(dataloader, desc=desc):
                batch_scores = model.score_o(sr[:, 0].to(device), sr[:, 1].to(device))
                if log_prob and not model.return_log_prob:
                    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
                all_scores.append(batch_scores.cpu())
            all_scores = torch.cat(all_scores)  # (sr_pairs, num_entities)
            msg = f"Evaluating rank of matrix of shape: {all_scores.shape}..."
            logging.info(msg)
            start_time = time.time()
            singular_values = torch.linalg.svdvals(all_scores)
            rank = torch.sum(singular_values > 0.1).item()
            msg = f"Rank of matrix: {rank} calculated in {time.time() - start_time:.2f} seconds"
            logging.info(msg)
        return rank

    def evaluate_object_ranks(
        self,
        model: KGModel,
        dataset: TripleDataset,
        batch_size: int = 256,
        device: str = "cuda",
        sample_size: int = 1000,
    ) -> RankMetrics:
        """Evaluate object ranks on test dataset.

        Args:
            model: KGModel to evaluate
            dataset: TripleDataset to evaluate on
            batch_size: Batch size for scoring
            device: Device to use for scoring
            sample_size: Number of triples to sample for evaluation. -1 means all triples.

        """
        if self.filtered and not self._initialized_filters:
            msg = "Filters not initialized. Call initialize_filters first."
            raise ValueError(msg)
        model.eval()
        if sample_size == -1:
            sample_size = len(dataset)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
        )
        ranks = torch.ones(sample_size)
        samples = 0
        with torch.no_grad():
            for s_batch, r_batch, o_batch in tqdm(
                dataloader,
                total=math.ceil(sample_size / batch_size),
                desc=f"Evaluating object ranks on {dataset.split}",
            ):
                scores = model.score_o(s_batch.to(device), r_batch.to(device)).cpu()
                # Compute ranks element by element
                for i in range(len(s_batch)):
                    samples += 1
                    # Apply filter set. Iterate over all skip objects from train, valid, or test set
                    for o_true in self.skip_objects_train[(int(s_batch[i]), int(r_batch[i]))]:
                        if o_true != int(o_batch[i]):
                            scores[i, o_true] = -float("inf")
                    for o_true in self.skip_objects_valid[(int(s_batch[i]), int(r_batch[i]))]:
                        if o_true != int(o_batch[i]):
                            scores[i, o_true] = -float("inf")
                    for o_true in self.skip_objects_test[(int(s_batch[i]), int(r_batch[i]))]:
                        if o_true != int(o_batch[i]):
                            scores[i, o_true] = -float("inf")
                    ranks[samples - 1] += torch.sum(scores[i] > scores[i, o_batch[i]])
                    if samples >= sample_size:
                        break
                if samples >= sample_size:
                    logging.info("Reached sample size limit.")
                    break
        return RankMetrics.from_dict(
            {metric: self.rank_metrics[metric](ranks) for metric in self.rank_metrics},
        )

    def plot_ranks_vs_nll(
        self,
        model: KGModel,
        dataset: TripleDataset,
        filepath: str,
        batch_size: int = 256,
        device: str = "cuda",
        sample_size: int = 1000,
    ) -> None:
        """Create a scatter plot of filtered ranks vs negative log likelihood.

        Args:
            model: KGModel to evaluate
            dataset: TripleDataset to evaluate on
            filepath: Path where to save the plot
            batch_size: Batch size for scoring
            device: Device to use for scoring
            sample_size: Number of triples to sample for evaluation. -1 means all triples.

        """
        from pathlib import Path

        import matplotlib.pyplot as plt
        import numpy as np

        # Set publication-quality plot settings
        plt.style.use("seaborn-v0_8-whitegrid")
        plt.rcParams.update(
            {
                "font.family": "serif",
                "font.size": 11,
                "axes.labelsize": 12,
                "axes.titlesize": 12,
                "xtick.labelsize": 10,
                "ytick.labelsize": 10,
                "legend.fontsize": 10,
                "figure.figsize": (5, 4),
                "figure.dpi": 300,
            },
        )

        model.eval()
        if sample_size == -1:
            sample_size = len(dataset)

        ranks = []
        nlls = []

        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
        )

        with torch.no_grad():
            for s, r, o in tqdm(
                dataloader,
                total=math.ceil(sample_size / batch_size),
                desc="Computing ranks and NLL",
            ):
                s = s.to(device)
                r = r.to(device)
                o = o.to(device)

                # Compute NLL
                scores = model.score_o(s, r)
                log_probs = torch.nn.functional.log_softmax(scores, dim=1)
                batch_nlls = -log_probs[torch.arange(len(o)), o].cpu()

                # Compute ranks
                scores = scores.cpu()
                batch_ranks = torch.ones(len(o))

                for i in range(len(s)):
                    if self.filtered:
                        for o_true in self.skip_objects[(int(s[i]), int(r[i]))]:
                            if o_true != int(o[i]):
                                scores[i, o_true] = -float("inf")
                    batch_ranks[i] += torch.sum(scores[i] > scores[i, o[i]])

                ranks.extend(batch_ranks.tolist())
                nlls.extend(batch_nlls.tolist())

                if len(ranks) >= sample_size:
                    ranks = ranks[:sample_size]
                    nlls = nlls[:sample_size]
                    break

        # Create the scatter plot
        fig, ax = plt.subplots()
        ax.scatter(ranks, nlls, alpha=0.5, s=20)
        ax.set_xlabel("Filtered Rank")
        ax.set_ylabel("Negative Log-Likelihood")
        ax.set_xscale("log")

        # Compute correlation coefficient
        log_ranks = np.log(ranks)
        correlation = np.corrcoef(log_ranks, nlls)[0, 1]
        ax.set_title(f"Log Rank vs NLL (correlation: {correlation:.3f})")

        # Ensure the output directory exists
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)

        # Save the plot
        plt.tight_layout()
        plt.savefig(filepath, bbox_inches="tight")
        plt.close()
