import json
from pathlib import Path
from typing import Tuple

import torch


def _parse_fact(fact: str) -> Tuple[str, list]:
    """Split a fact string such as ``"p(a1, b1)"`` into ``("p", ["a1", "b1"])``.

    Raises:
        IndexError: if ``fact`` contains no ``"("``.
    """
    pieces = fact.split("(")
    return pieces[0], pieces[1].split(")")[0].split(", ")


def _constant_sort_key(const: str) -> tuple:
    """Sort key ordering constants by numeric suffix, then by leading letter.

    The suffix is compared as an integer so that ``a2`` precedes ``a10``.
    Constants whose suffix is not a plain decimal number are placed after all
    numeric ones and ordered by their suffix string.
    """
    suffix = const[1:]
    if suffix.isdecimal():
        return (0, int(suffix), const[0])
    return (1, suffix, const)


def _assert_single_group(parsed_facts) -> None:
    """Assert that all constants of each fact share one numeric suffix.

    Raises:
        AssertionError: if a fact mixes constants with different suffixes.
        ValueError: if a constant's suffix (everything after its first
            character) is not an integer.
    """
    for name, consts in parsed_facts:
        groups = {int(c[1:]) for c in consts}
        # Constants are grouped in GeoILP, so a fact never spans two groups.
        assert len(groups) == 1, f"Fact {name}({', '.join(consts)}) mixes constant groups"


def _zero_relations(signatures, num_constants: int, device: torch.device) -> tuple:
    """Return one all-False tensor of shape ``[num_constants] * arity`` per signature."""
    return tuple(
        torch.zeros([num_constants] * arity, dtype=torch.bool, device=device)
        for _, arity in signatures
    )


def _read_data(data: dict, device: torch.device, is_train: bool):
    """Convert one split of a GeoILP task into boolean relation tensors.

    Args:
        data: Mapping with the keys ``"background_knowledge"`` and
            ``"positive_examples"``, each a list of fact strings of the form
            ``"pred(c1, c2)"`` (arguments separated by ``", "``).
        device: Device on which the tensors are allocated.
        is_train: When true, additionally checks that the constants of every
            fact share the same numeric suffix.

    Returns:
        A 6-tuple ``(background_knowledge, targets_label, targets_init,
        background_knowledge_names, targets_names, num_constants)``:

        * ``background_knowledge``: one bool tensor per background predicate.
        * ``targets_label``: one bool tensor per target predicate, true for
          every fact of that predicate in either input list.
        * ``targets_init``: same shapes as ``targets_label``, true only for
          target facts listed under ``"background_knowledge"``.
        * ``background_knowledge_names`` / ``targets_names``: predicate names
          (lists) in the same order as the tensors, sorted by (arity, name).
        * ``num_constants``: number of distinct constants; the size of every
          tensor dimension.

    Raises:
        AssertionError: if a predicate name is used with two arities or as
            both background and target with different arities, or (training
            only) if a fact mixes constant groups.
        ValueError: if a fact refers to a predicate that is neither background
            nor target.
    """
    # Parse every fact exactly once; everything below works on (name, consts).
    bk_facts = [_parse_fact(fact) for fact in data["background_knowledge"]]
    pos_facts = [_parse_fact(fact) for fact in data["positive_examples"]]

    if is_train:
        _assert_single_group(bk_facts)
        _assert_single_group(pos_facts)

    # WARN: In GeoILP, target predicates can also be in background knowledge,
    # so anything that appears among the positive examples counts as a target.
    tgt_signatures = {(name, len(consts)) for name, consts in pos_facts}
    bk_signatures = {(name, len(consts)) for name, consts in bk_facts} - tgt_signatures

    all_signatures = bk_signatures | tgt_signatures
    assert len({name for name, _ in all_signatures}) == len(all_signatures)

    bk_signatures = sorted(bk_signatures, key=lambda sig: (sig[1], sig[0]))
    tgt_signatures = sorted(tgt_signatures, key=lambda sig: (sig[1], sig[0]))

    # First sort by number, then by alphabet.
    constants = sorted(
        {c for _, consts in bk_facts + pos_facts for c in consts},
        key=_constant_sort_key,
    )
    const_index = {const: i for i, const in enumerate(constants)}
    num_constants = len(constants)

    background_knowledge = _zero_relations(bk_signatures, num_constants, device)
    targets_label = _zero_relations(tgt_signatures, num_constants, device)
    targets_init = _zero_relations(tgt_signatures, num_constants, device)

    background_knowledge_names = [name for name, _ in bk_signatures]
    targets_names = [name for name, _ in tgt_signatures]
    bk_slot = {name: i for i, name in enumerate(background_knowledge_names)}
    tgt_slot = {name: i for i, name in enumerate(targets_names)}

    for name, consts in bk_facts:
        index_consts = tuple(const_index[c] for c in consts)
        if name in bk_slot:
            background_knowledge[bk_slot[name]][index_consts] = True
        elif name in tgt_slot:
            # A target fact given as background is both a label and part of
            # the initial target valuation.
            targets_label[tgt_slot[name]][index_consts] = True
            targets_init[tgt_slot[name]][index_consts] = True
        else:
            raise ValueError(f"Unknown predicates: {name}")

    for name, consts in pos_facts:
        index_consts = tuple(const_index[c] for c in consts)
        if name in bk_slot:
            background_knowledge[bk_slot[name]][index_consts] = True
        elif name in tgt_slot:
            targets_label[tgt_slot[name]][index_consts] = True
        else:
            raise ValueError(f"Unknown predicates: {name}")

    return background_knowledge, targets_label, targets_init, background_knowledge_names, targets_names, num_constants


class GeoILPDataset:
    """A single GeoILP task loaded from ``data/<task_name>.json``.

    The JSON file is looked up three directory levels above this module. Its
    ``"train"`` and ``"eval"`` splits are each converted by ``_read_data`` into
    background-knowledge tensors, target label tensors and initial target
    tensors, which are kept as attributes prefixed with ``train_`` / ``eval_``.
    """

    def __init__(self, task_name: str, device: torch.device):
        """Load the task and build tensors for both splits on ``device``.

        Raises:
            AssertionError: if the task is multitask, or if the predicate
                names found in the data disagree between the splits or with
                the targets declared in the task's metadata.
        """
        self.task_name = task_name

        dataset_path = Path(__file__).parent.parent.parent / "data" / (self.task_name + ".json")
        # Use a context manager so the file handle is closed deterministically.
        with dataset_path.open(encoding="utf-8") as fp:
            task = json.load(fp)
        assert task["meta_data"]["is_multitask"] is False
        targets_names = tuple(task["meta_data"]["target_predicates"])
        train_data, eval_data = task["data"]["train"], task["data"]["eval"]

        (
            self.train_background_knowledge,
            self.train_targets_label,
            self.train_targets_init,
            train_background_knowledge_names,
            train_targets_names,
            self.train_num_constants,
        ) = _read_data(train_data, device, True)
        (
            self.eval_background_knowledge,
            self.eval_targets_label,
            self.eval_targets_init,
            eval_background_knowledge_names,
            eval_targets_names,
            self.eval_num_constants,
        ) = _read_data(eval_data, device, False)

        # _read_data orders names by (arity, name); compare against the
        # declared targets irrespective of order.
        assert sorted(targets_names) == sorted(train_targets_names)
        assert sorted(targets_names) == sorted(eval_targets_names)
        # Both splits must expose the same predicates in the same order.
        assert train_background_knowledge_names == eval_background_knowledge_names
        assert train_targets_names == eval_targets_names

        self.background_knowledge_names = tuple(train_background_knowledge_names)
        self.targets_names = tuple(train_targets_names)

    def predicate_names(self) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
        """Return ``(background_knowledge_names, targets_names)``."""
        return self.background_knowledge_names, self.targets_names

    def generate_data(
        self, is_train: bool
    ) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        """Return ``(background_knowledge, targets_label, targets_init)``.

        Args:
            is_train: Selects the training split when true, otherwise the
                evaluation split.
        """
        if is_train:
            return self.train_background_knowledge, self.train_targets_label, self.train_targets_init
        return self.eval_background_knowledge, self.eval_targets_label, self.eval_targets_init
