from __future__ import annotations

import logging
from collections import defaultdict
from functools import cached_property
from pathlib import Path
from typing import Generic, TypeVar

import pandas as pd
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from kge.types import EntityID, RelationID

# Type of the per-split dataset wrapped by SplitDataset.
T = TypeVar("T", bound="TripleDataset | OneToManyDataset")


class SplitDataset(Dataset, Generic[T]):
    """Bundle the train/valid/test splits of a knowledge graph.

    On construction this warns about test entities that never occur in the
    training split, optionally adds inverse triples to every split, and
    records the union of entities and relations over all three splits.

    Attributes:
        train: Training split.
        valid: Validation split.
        test: Test split.
        entities: Union of the entity IDs of all splits.
        relations: Union of the relation IDs of all splits (including inverse
            relation IDs when ``add_inverse`` is true).
        num_entities: ``len(entities)``.
        num_relations: ``len(relations)``.

    """

    def __init__(
        self,
        train: T,
        valid: T,
        test: T,
        *,
        add_inverse: bool = False,  # whether to add inverse relations to all datasets
    ):
        self.train = train
        self.valid = valid
        self.test = test

        self._warn_unseen_test_entities()

        if add_inverse:
            original_relations = (
                self.train.relations | self.valid.relations | self.test.relations
            )
            # Offset by the largest relation ID + 1 rather than the number of
            # relations: with non-contiguous IDs the count can be <= an existing
            # ID, which would make inverse IDs collide with original ones.
            # For contiguous IDs 0..R-1 both values are identical.
            relation_offset = max(original_relations, default=-1) + 1
            logging.info("Adding inverse relations to the dataset...")
            self.train.add_inverse_triples(relation_offset)
            self.valid.add_inverse_triples(relation_offset)
            self.test.add_inverse_triples(relation_offset)

        self.entities = self.train.entities | self.valid.entities | self.test.entities
        self.relations = self.train.relations | self.valid.relations | self.test.relations
        self.num_entities = len(self.entities)
        self.num_relations = len(self.relations)

    def _warn_unseen_test_entities(self) -> None:
        """Log a warning if the test split contains entities absent from train."""
        unseen = self.test.entities - self.train.entities
        if not unseen:
            return

        triples = self.test.triples
        unseen_ids = torch.tensor(sorted(unseen), dtype=triples.dtype, device=triples.device)
        # One vectorized pass; a triple counts once even when both its subject
        # and its object are unseen.
        affected = torch.isin(triples[:, 0], unseen_ids) | torch.isin(triples[:, 2], unseen_ids)
        num_affected = int(affected.sum().item())

        msg = (
            f"Test set contains {len(unseen)} entities not in train set.\n"
            f"They appear in a total of {num_affected} test triples. "
            "Results will likely be random for these entities and their appearances."
        )

        logging.warning(msg)


class TripleDataset(Dataset):
    """Dataset of pointwise triples (s,r,o).

    Each item is one triple, returned as three scalar tensors.

    Attributes:
        triples: Tensor of shape (n_triples, 3)
        split: Split of the dataset
        entities: Set of entities in the split
        relations: Set of relations in the split

    """

    def __init__(
        self,
        triples: torch.Tensor,
        split: str = "train",
    ):
        """Wrap an in-memory tensor of triples.

        Args:
            triples: Integer tensor of (subject, relation, object) rows; the
                column indexing below expects shape (n_triples, 3).
            split: Name of the split, used for bookkeeping and logging.

        """
        self.split = split

        self.triples = triples
        # Entities are every ID seen in the subject or object column.
        entities: torch.Tensor = torch.unique(torch.cat([self.triples[:, 0], self.triples[:, 2]]))
        self.entities: set[EntityID] = set(entities.tolist())
        relations: torch.Tensor = torch.unique(self.triples[:, 1])
        self.relations: set[RelationID] = set(relations.tolist())

        msg = f"Loaded {split} dataset with {len(self)} triples"
        logging.info(msg)

    @classmethod
    def from_tsv(cls, data_path: str | Path, split: str = "train") -> TripleDataset:
        """Load triples from a header-less, tab-separated file of integer IDs.

        Each row must hold ``subject<TAB>relation<TAB>object``.
        """
        triples_df = pd.read_csv(
            data_path,
            sep="\t",
            header=None,
            names=["subject", "relation", "object"],
        )
        triples = torch.tensor(triples_df.to_numpy(), dtype=torch.long)
        return cls(triples, split=split)

    def add_inverse_triples(self, relation_offset: int) -> None:
        """Append the inverse triple (o, r + relation_offset, s) of every triple.

        Original triples keep their positions; inverse triples follow them in
        the same order. The cached ``sr_to_objects`` is invalidated.
        """
        # Column permutation swaps subject and object. Advanced indexing
        # returns a copy, so the in-place offset below leaves self.triples intact.
        inverse_triples = self.triples[:, [2, 1, 0]]
        inverse_triples[:, 1] += relation_offset
        self.triples = torch.cat([self.triples, inverse_triples], dim=0)
        self.relations = self.relations | set(inverse_triples[:, 1].tolist())
        # Drop the cached property (if it was computed) so it is rebuilt lazily.
        self.__dict__.pop("sr_to_objects", None)

    def __len__(self) -> int:
        return len(self.triples)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return (
            self.triples[idx, 0],
            self.triples[idx, 1],
            self.triples[idx, 2],
        )

    @cached_property
    def sr_to_objects(self) -> dict[tuple[EntityID, RelationID], set[EntityID]]:
        """Map each (subject, relation) pair to the set of its objects (cached)."""
        sr_to_objects = defaultdict(set)
        # Iterate plain Python ints: far cheaper than indexing the tensor per
        # triple through __getitem__ and calling .item() three times.
        for s, r, o in tqdm(self.triples.tolist(), desc="Building sr_to_objects"):
            sr_to_objects[(s, r)].add(o)
        return sr_to_objects


class OneToManyDataset(Dataset):
    """Dataset where each (s, r) maps to multiple objects (1-vs-N).

    Each item is a distinct (subject, relation) pair plus a boolean mask over
    all entities that marks the objects observed for that pair.

    Attributes:
        triples: Tensor of shape (n_triples, 3)
        split: Split of the dataset
        entities: Set of entities in the split
        relations: Set of relations in the split
        num_entities: Width of every object mask
        unique_sr_pairs: Tensor of shape (n_pairs, 2) with the distinct (s, r) pairs
        o_masks: Bool tensor of shape (n_pairs, num_entities); ``o_masks[i, o]``
            is True iff ``(*unique_sr_pairs[i], o)`` is a triple of the split

    """

    def __init__(self, triples: torch.Tensor, num_entities: int, split: str = "train"):
        """Group ``triples`` by (subject, relation) into object masks.

        Args:
            triples: Integer tensor of (subject, relation, object) rows, shape
                (n_triples, 3).
            num_entities: Number of mask columns; every object ID must be
                smaller than this because it is used as a column index.
            split: Name of the split, used for bookkeeping and logging.

        """
        self.split = split
        self.triples = triples
        self.entities = set(triples[:, 0].tolist()) | set(triples[:, 2].tolist())
        self.relations = set(triples[:, 1].tolist())
        self.num_entities = num_entities

        self.unique_sr_pairs, self.o_masks = self._group_objects(triples, num_entities)
        logging.info(
            f"Loaded {split} OneToManyDataset with {len(self.unique_sr_pairs)} (s, r) pairs.",
        )

    @staticmethod
    def _group_objects(
        triples: torch.Tensor,
        num_entities: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the sorted distinct (s, r) pairs of ``triples`` and their object masks."""
        unique_sr_pairs, sr_indices = torch.unique(triples[:, :2], dim=0, return_inverse=True)
        o_masks = torch.zeros(len(unique_sr_pairs), num_entities, dtype=torch.bool)
        # A single scatter over all triples: row = index of the triple's (s, r)
        # pair, column = its object. This replaces a per-pair boolean filter
        # that rescanned every triple for each pair (O(pairs * triples)).
        o_masks[sr_indices, triples[:, 2]] = True
        return unique_sr_pairs, o_masks

    def add_inverse_triples(self, relation_offset: int) -> None:
        """Add inverse triples (o, r + relation_offset, s) and their relations.

        The inverse triples are appended to ``triples``, matching
        ``TripleDataset``. Their (s, r) pairs and object masks are appended after
        the existing ones, so the indices of the original items do not change.
        The cached ``sr_to_objects`` is invalidated.
        """
        # Column permutation swaps subject and object. Advanced indexing
        # returns a copy, so the in-place offset below leaves self.triples intact.
        inverse_triples = self.triples[:, [2, 1, 0]]
        inverse_triples[:, 1] += relation_offset
        inverse_pairs, inverse_masks = self._group_objects(inverse_triples, self.num_entities)
        self.triples = torch.cat([self.triples, inverse_triples], dim=0)
        self.unique_sr_pairs = torch.cat([self.unique_sr_pairs, inverse_pairs], dim=0)
        self.o_masks = torch.cat([self.o_masks, inverse_masks], dim=0)
        self.relations = self.relations | set(inverse_triples[:, 1].tolist())
        # Drop the cached property (if it was computed) so it is rebuilt lazily.
        self.__dict__.pop("sr_to_objects", None)

    def __len__(self) -> int:
        return len(self.unique_sr_pairs)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.unique_sr_pairs[idx, 0], self.unique_sr_pairs[idx, 1], self.o_masks[idx]

    @classmethod
    def from_tsv(
        cls,
        data_path: str | Path,
        num_entities: int,
        split: str = "train",
    ) -> OneToManyDataset:
        """Load triples from a header-less, tab-separated file of integer IDs.

        Each row must hold ``subject<TAB>relation<TAB>object``.
        """
        triples_df = pd.read_csv(
            data_path,
            sep="\t",
            header=None,
            names=["subject", "relation", "object"],
        )
        triples = torch.tensor(triples_df.to_numpy(), dtype=torch.long)
        return cls(triples, num_entities, split=split)

    @cached_property
    def sr_to_objects(self) -> dict[tuple[EntityID, RelationID], set[EntityID]]:
        """Map each (subject, relation) pair to the set of its objects (cached)."""
        sr_to_objects = defaultdict(set)
        pairs = self.unique_sr_pairs.tolist()
        for (s, r), o_mask in tqdm(
            zip(pairs, self.o_masks),
            total=len(pairs),
            desc="Building sr_to_objects",
        ):
            objects = torch.nonzero(o_mask, as_tuple=True)[0]
            sr_to_objects[(s, r)].update(objects.tolist())
        return sr_to_objects
