from collections import defaultdict
from functools import cached_property
from pathlib import Path

import numpy as np
import torch
from ogb.linkproppred import LinkPropPredDataset
from tqdm import tqdm

from kge.types import EntityID, RelationID

from .dataset import TripleDataset


class OGBDataset:
    def __init__(self, dataset_name: str, root: Path, *, add_inverse: bool = False):
        """Initialize the OGB dataset.

        Args:
            dataset_name (str): Name of the OGB dataset
            root (Path): Root directory for the dataset
            add_inverse (bool): Whether to add inverse relations to the dataset

        """
        self.dataset_name = dataset_name
        self.dataset = LinkPropPredDataset(name=dataset_name, root=root)
        edge_split = self.dataset.get_edge_split()
        do_offset_type = "num_nodes_dict" in self.dataset[0]
        if do_offset_type:
            num_nodes_per_type: dict[str, int] = self.dataset[0]["num_nodes_dict"]
            offsets_per_type = {}
            types = list(num_nodes_per_type.keys())
            types.sort()
            for t in types:
                offsets_per_type[t] = sum(
                    num_nodes_per_type[t_prev] for t_prev in types if t_prev < t
                )
        else:
            offsets_per_type = None
        edges = edge_split["train"]
        if do_offset_type:
            edges["head"] += [offsets_per_type[t] for t in edges["head_type"]]
            edges["tail"] += [offsets_per_type[t] for t in edges["tail_type"]]
        train_triples = torch.tensor(
            np.stack([edges["head"], edges["relation"], edges["tail"]], axis=1),
            dtype=torch.long,
        )
        self.train = TripleDataset(triples=train_triples, split="train")

        def _get_triples(
            split: str,
            offsets_per_type: dict[str, int] | None = None,
        ) -> TripleDatasetWithNeg:
            edges = edge_split[split]
            if do_offset_type:
                if offsets_per_type is None:
                    msg = "offsets_per_type must be provided if do_offset_type is True"
                    raise ValueError(msg)
                head_offsets = [offsets_per_type[t] for t in edges["head_type"]]
                tail_offsets = [offsets_per_type[t] for t in edges["tail_type"]]
                edges["head"] += head_offsets
                edges["tail"] += tail_offsets
                edges["head_neg"] += np.stack([head_offsets], axis=1)
                edges["tail_neg"] += np.stack([tail_offsets], axis=1)
            triples = torch.tensor(
                np.stack([edges["head"], edges["relation"], edges["tail"]], axis=1),
                dtype=torch.long,
            )
            head_neg = torch.tensor(edges["head_neg"], dtype=torch.long)
            tail_neg = torch.tensor(edges["tail_neg"], dtype=torch.long)
            return TripleDatasetWithNeg(triples, head_neg, tail_neg, split=split)

        self.valid = _get_triples("valid", offsets_per_type)
        self.test = _get_triples("test", offsets_per_type)

        if add_inverse:
            num_original_relations = len(
                self.train.relations | self.valid.relations | self.test.relations,
            )
            self.train.add_inverse_triples(num_original_relations)
            self.valid.add_inverse_triples(num_original_relations)
            self.test.add_inverse_triples(num_original_relations)

        self.entities = self.train.entities | self.valid.entities | self.test.entities
        self.relations = self.train.relations | self.valid.relations | self.test.relations
        self.num_entities = len(self.entities)
        self.num_relations = len(self.relations)


class TripleDatasetWithNeg(TripleDataset):
    """TripleDataset with negative samples for both head and tail prediction.

    The negatives are stored in ``negative_samples``; row ``i`` is returned together with
    triple ``i`` by ``__getitem__``.
    """

    def __init__(
        self,
        triples: torch.Tensor,
        head_negative_samples: torch.Tensor,
        tail_negative_samples: torch.Tensor,
        split: str = "test",
        *,
        prepare_inverse_negatives: bool = True,
    ):
        """Initialize the TripleDatasetWithNeg.

        Args:
            triples (torch.Tensor): Tensor of triples with shape (n, 3)
            head_negative_samples (torch.Tensor): Tensor of head negative samples with shape (n, m)
            tail_negative_samples (torch.Tensor): Tensor of tail negative samples with shape (n, m)
            split (str, optional): Split of the dataset. Defaults to "test".
            prepare_inverse_negatives (bool, optional): If False, only tail negative samples are
                used. Defaults to True.

        """
        super().__init__(triples, split=split)
        if prepare_inverse_negatives:
            # The regular triples are for object (tail) prediction.
            # The inverse triples, which are (optionally) concatenated after the regular triples,
            # are for subject (head) prediction.
            # Therefore, the negative samples are the concatenation of the tail negative samples
            # and the head negative samples, in the same order as the concatenation of the triples
            # and their inverse triples.
            self.negative_samples = torch.cat([tail_negative_samples, head_negative_samples], dim=0)
        else:
            self.negative_samples = tail_negative_samples

    def __getitem__(
        self,
        idx: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return columns 0, 1 and 2 of triple ``idx`` followed by its row of negatives."""
        return (
            self.triples[idx, 0],
            self.triples[idx, 1],
            self.triples[idx, 2],
            self.negative_samples[idx],
        )

    @cached_property
    def sr_to_objects(self) -> dict[tuple[EntityID, RelationID], set[EntityID]]:
        """Map every (subject, relation) pair to the set of objects it occurs with.

        The mapping is built once from ``self.triples`` and then cached on the instance. It is
        a ``defaultdict``, so looking up an unseen pair yields (and stores) an empty set.
        """
        sr_to_objects = defaultdict(set)
        # Read the triples directly instead of going through ``__getitem__``: the index does
        # not need the negatives, and a single ``tolist()`` replaces three ``.item()`` calls
        # per triple.
        rows = self.triples[:, :3].tolist()
        for s, r, o in tqdm(rows, desc="Building sr_to_objects"):
            sr_to_objects[(s, r)].add(o)
        return sr_to_objects
