"""
@author: lxy
@email: linxy59@mail2.sysu.edu.cn
@date: 2021/10/26
@description: null
"""
from typing import List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset


class AlignDataset(Dataset):
    """
    Dataset of aligned entity pairs (seeds) with on-the-fly negative sampling.

    Every item couples one seed ``(head, tail)`` with ``negative_sample_size``
    corrupted entity ids:

    * mode ``'align-head-batch'``: ids drawn from ``kg1_entity_list`` that are
      not aligned with ``tail`` in ``seeds``;
    * mode ``'align-tail-batch'``: ids drawn from ``kg2_entity_list`` that are
      not aligned with ``head`` in ``seeds``.

    Sampling uses the global ``np.random`` state.

    :param seeds: aligned ``(head, tail)`` id pairs
    :param kg1_entity_list: candidate ids for corrupting the head
    :param kg2_entity_list: candidate ids for corrupting the tail
    :param nentity: stored on the instance; not used by this class itself
    :param negative_sample_size: number of negatives returned per item
    :param mode: ``'align-head-batch'`` or ``'align-tail-batch'``
    """

    def __init__(self,
                 seeds: List[Tuple[int, int]],
                 kg1_entity_list: List[int],
                 kg2_entity_list: List[int],
                 nentity, negative_sample_size, mode):
        self.seeds = seeds
        self.len = len(seeds)
        self.kg1_entity_list = kg1_entity_list
        self.kg2_entity_list = kg2_entity_list
        self.kg1_entity_size = len(kg1_entity_list)
        self.kg2_entity_size = len(kg2_entity_list)
        # Array views of the candidate pools, built once, so sampled positions
        # can be turned into entity ids with a single vectorised lookup.
        self._kg1_entity_array = np.asarray(kg1_entity_list)
        self._kg2_entity_array = np.asarray(kg2_entity_list)
        self.nentity = nentity
        self.negative_sample_size = negative_sample_size
        self.mode = mode
        self.count = self.count_frequency(seeds)
        self.true_head, self.true_tail = self.get_true_head_and_tail(seeds)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """
        Return ``(positive_sample, negative_sample, subsampling_weight, mode)``.

        ``positive_sample`` is the seed pair as a ``LongTensor``,
        ``negative_sample`` a ``LongTensor`` of ``negative_sample_size`` ids and
        ``subsampling_weight`` a one-element tensor equal to
        ``sqrt(1 / (count[head] + count[tail]))``.

        :raises ValueError: if ``self.mode`` is not a supported mode
        """
        positive_sample = self.seeds[idx]
        head, tail = positive_sample

        subsampling_weight = self.count[head] + self.count[tail]
        subsampling_weight = torch.sqrt(1 / torch.Tensor([subsampling_weight]))

        if self.mode == 'align-head-batch':
            candidates = self._kg1_entity_array
            true_entities = self.true_head[tail]
        elif self.mode == 'align-tail-batch':
            candidates = self._kg2_entity_array
            true_entities = self.true_tail[head]
        else:
            raise ValueError('Training batch mode %s not supported' % self.mode)

        negative_sample = torch.LongTensor(self._sample_negatives(candidates, true_entities))
        positive_sample = torch.LongTensor(positive_sample)

        return positive_sample, negative_sample, subsampling_weight, self.mode

    def _sample_negatives(self, candidates, true_entities):
        """
        Draw ``negative_sample_size`` ids from ``candidates`` (with replacement)
        that do not occur in ``true_entities``.

        NOTE: the loop only ends once enough valid ids were collected, so it
        does not terminate if every candidate is contained in ``true_entities``.
        """
        wanted = self.negative_sample_size
        if wanted <= 0:
            # Nothing requested: avoid np.concatenate failing on an empty list.
            return np.empty(0, dtype=np.int64)

        collected = []
        n_collected = 0
        while n_collected < wanted:
            # Oversample by 2x so one round is usually enough after filtering.
            positions = np.random.randint(len(candidates), size=wanted * 2)
            drawn = candidates[positions]
            # The draw is with replacement and therefore contains duplicates,
            # so the membership test must not assume unique inputs.
            keep = np.isin(drawn, true_entities, invert=True)
            drawn = drawn[keep]
            collected.append(drawn)
            n_collected += drawn.size

        return np.concatenate(collected)[:wanted]

    @staticmethod
    def collate_fn(data):
        """
        Merge a list of items into one batch.

        Positive and negative samples are stacked along a new first dimension,
        the one-element weights are concatenated, and the mode is taken from
        the first item.
        """
        positive_sample = torch.stack([item[0] for item in data], dim=0)
        negative_sample = torch.stack([item[1] for item in data], dim=0)
        subsample_weight = torch.cat([item[2] for item in data], dim=0)
        mode = data[0][3]
        return positive_sample, negative_sample, subsample_weight, mode

    def count_frequency(self, seeds: List[Tuple[int, int]], start=4):
        """
        Count how often each entity id occurs in ``seeds`` (on either side).

        The first occurrence of an id is counted as ``start``; every further
        occurrence adds one. The counts feed the subsampling weight computed
        in ``__getitem__``.
        """
        count = {}
        for a, b in seeds:
            for entity in (a, b):
                count[entity] = count[entity] + 1 if entity in count else start
        return count

    @staticmethod
    def get_true_head_and_tail(seeds):
        """
        Build the lookup tables used to filter true pairs out of the negatives.

        :return: ``(true_head, true_tail)`` where ``true_head[b]`` is an array
            of the distinct heads paired with ``b`` and ``true_tail[a]`` an
            array of the distinct tails paired with ``a``
        """
        heads_of = {}
        tails_of = {}

        for a, b in seeds:
            tails_of.setdefault(a, []).append(b)
            heads_of.setdefault(b, []).append(a)

        true_head = {b: np.array(list(set(heads))) for b, heads in heads_of.items()}
        true_tail = {a: np.array(list(set(tails))) for a, tails in tails_of.items()}

        return true_head, true_tail


class TrainDataset(Dataset):
    """
    Dataset of ``(head, relation, tail)`` triples with on-the-fly negative sampling.

    Every item couples one triple with ``negative_sample_size`` corrupted ids:

    * mode ``'head-batch'``: ids drawn from ``range(nentity)`` that are not a
      known head for ``(relation, tail)``;
    * mode ``'tail-batch'``: ids drawn from ``range(nvalue)`` that are not a
      known tail for ``(head, relation)``.

    Sampling uses the global ``np.random`` state.

    :param triples: training triples as ``(head, relation, tail)`` ids
    :param nentity: size of the id range used to corrupt heads
    :param nrelation: stored on the instance; not used by this class itself
    :param nvalue: size of the id range used to corrupt tails
    :param negative_sample_size: number of negatives returned per item
    :param mode: ``'head-batch'`` or ``'tail-batch'``
    """

    def __init__(self,
                 triples: List[Tuple[int, int, int]],
                 nentity, nrelation, nvalue,
                 negative_sample_size, mode):
        self.len = len(triples)
        self.triples = triples
        self.triple_set = set(triples)
        self.nentity = nentity
        self.nrelation = nrelation
        self.nvalue = nvalue
        self.negative_sample_size = negative_sample_size
        self.mode = mode
        self.count = self.count_frequency(triples)
        self.true_head, self.true_tail = self.get_true_head_and_tail(self.triples)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """
        Return ``(positive_sample, negative_sample, subsampling_weight, mode)``.

        ``positive_sample`` is the triple as a ``LongTensor``,
        ``negative_sample`` a ``LongTensor`` of ``negative_sample_size`` ids and
        ``subsampling_weight`` a one-element tensor equal to
        ``sqrt(1 / (count[(head, relation)] + count[(tail, -relation - 1)]))``.

        :raises ValueError: if ``self.mode`` is not a supported mode
        """
        positive_sample = self.triples[idx]
        head, relation, tail = positive_sample

        subsampling_weight = self.count[(head, relation)] + self.count[(tail, -relation - 1)]
        subsampling_weight = torch.sqrt(1 / torch.Tensor([subsampling_weight]))

        if self.mode == 'head-batch':
            candidate_count = self.nentity
            true_ids = self.true_head[(relation, tail)]
        elif self.mode == 'tail-batch':
            candidate_count = self.nvalue
            true_ids = self.true_tail[(head, relation)]
        else:
            raise ValueError('Training batch mode %s not supported' % self.mode)

        negative_sample = torch.LongTensor(self._sample_negatives(candidate_count, true_ids))
        positive_sample = torch.LongTensor(positive_sample)

        return positive_sample, negative_sample, subsampling_weight, self.mode

    def _sample_negatives(self, candidate_count, true_ids):
        """
        Draw ``negative_sample_size`` ids from ``range(candidate_count)`` (with
        replacement) that do not occur in ``true_ids``.

        NOTE: the loop only ends once enough valid ids were collected, so it
        does not terminate if every id in the range is contained in ``true_ids``.
        """
        wanted = self.negative_sample_size
        if wanted <= 0:
            # Nothing requested: avoid np.concatenate failing on an empty list.
            return np.empty(0, dtype=np.int64)

        collected = []
        n_collected = 0
        while n_collected < wanted:
            # Oversample by 2x so one round is usually enough after filtering.
            drawn = np.random.randint(candidate_count, size=wanted * 2)
            # The draw is with replacement and therefore contains duplicates,
            # so the membership test must not assume unique inputs.
            keep = np.isin(drawn, true_ids, invert=True)
            drawn = drawn[keep]
            collected.append(drawn)
            n_collected += drawn.size

        return np.concatenate(collected)[:wanted]

    @staticmethod
    def collate_fn(data):
        """
        Merge a list of items into one batch.

        Positive and negative samples are stacked along a new first dimension,
        the one-element weights are concatenated, and the mode is taken from
        the first item.
        """
        positive_sample = torch.stack([item[0] for item in data], dim=0)
        negative_sample = torch.stack([item[1] for item in data], dim=0)
        subsample_weight = torch.cat([item[2] for item in data], dim=0)
        mode = data[0][3]
        return positive_sample, negative_sample, subsample_weight, mode

    @staticmethod
    def count_frequency(triples, start=4):
        """
        Count the partial triples ``(head, relation)`` and ``(tail, -relation - 1)``.

        The first occurrence of a key is counted as ``start``; every further
        occurrence adds one. The frequencies are used for subsampling like
        word2vec.
        """
        count = {}
        for head, relation, tail in triples:
            # -relation - 1 keeps the tail-side keys apart from the head-side
            # keys, which use the non-negated relation id.
            for key in ((head, relation), (tail, -relation - 1)):
                count[key] = count[key] + 1 if key in count else start
        return count

    @staticmethod
    def get_true_head_and_tail(triples):
        """
        Build the lookup tables used to filter true triples out of the negatives.

        :return: ``(true_head, true_tail)`` where ``true_head[(relation, tail)]``
            is an array of the distinct known heads and
            ``true_tail[(head, relation)]`` an array of the distinct known tails
        """
        heads_of = {}
        tails_of = {}

        for head, relation, tail in triples:
            tails_of.setdefault((head, relation), []).append(tail)
            heads_of.setdefault((relation, tail), []).append(head)

        true_head = {key: np.array(list(set(heads))) for key, heads in heads_of.items()}
        true_tail = {key: np.array(list(set(tails))) for key, tails in tails_of.items()}

        return true_head, true_tail


class TestDataset(Dataset):
    """
    Evaluation dataset that ranks every candidate id against one test triple.

    Depending on ``mode`` the head (``'head-batch'``), the tail
    (``'tail-batch'``) or the relation (``'relation-batch'``) of the test
    triple is replaced by every candidate id. Head and tail candidates range
    over ``nentity``; relation candidates range over ``nrelation``.

    :param triples: triples to evaluate, as ``(head, relation, tail)`` ids
    :param all_true_triples: every known true triple, used for filtering
    :param nentity: number of candidate ids for heads and tails
    :param nrelation: number of candidate ids for relations
    :param mode: ``'head-batch'``, ``'tail-batch'`` or ``'relation-batch'``
    """

    def __init__(self, triples, all_true_triples, nentity, nrelation, mode):
        self.len = len(triples)
        self.triple_set = set(all_true_triples)
        self.triples = triples
        self.nentity = nentity
        self.nrelation = nrelation
        self.mode = mode

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """
        Return ``(positive_sample, negative_sample, filter_bias, mode)``.

        ``negative_sample`` holds one candidate id per position. A candidate
        that would form another known true triple is replaced by the positive
        triple's own id and gets a ``filter_bias`` of ``-1``; all other
        positions (including the positive's own position) get ``0``.

        :raises ValueError: if ``self.mode`` is not a supported mode
        """
        head, relation, tail = self.triples[idx]

        if self.mode == 'head-batch':
            tmp = [(0, rand_head) if (rand_head, relation, tail) not in self.triple_set
                   else (-1, head) for rand_head in range(self.nentity)]
            # The test triple itself is in triple_set; restore its own slot.
            tmp[head] = (0, head)
        elif self.mode == 'tail-batch':
            tmp = [(0, rand_tail) if (head, relation, rand_tail) not in self.triple_set
                   else (-1, tail) for rand_tail in range(self.nentity)]
            tmp[tail] = (0, tail)
        elif self.mode == 'relation-batch':
            # Candidates here are relation ids, so they must range over
            # nrelation rather than nentity.
            tmp = [(0, rand_relation) if (head, rand_relation, tail) not in self.triple_set
                   else (-1, relation) for rand_relation in range(self.nrelation)]
            tmp[relation] = (0, relation)
        else:
            raise ValueError('negative batch mode %s not supported' % self.mode)

        tmp = torch.LongTensor(tmp)
        filter_bias = tmp[:, 0].float()
        negative_sample = tmp[:, 1]

        positive_sample = torch.LongTensor((head, relation, tail))

        return positive_sample, negative_sample, filter_bias, self.mode

    @staticmethod
    def collate_fn(data):
        """
        Merge a list of items into one batch by stacking each tensor along a
        new first dimension; the mode is taken from the first item.
        """
        positive_sample = torch.stack([item[0] for item in data], dim=0)
        negative_sample = torch.stack([item[1] for item in data], dim=0)
        filter_bias = torch.stack([item[2] for item in data], dim=0)
        mode = data[0][3]
        return positive_sample, negative_sample, filter_bias, mode


def one_shot_iterator(dataloader):
    """
    Turn a PyTorch Dataloader into an endless python generator.

    The dataloader is iterated to exhaustion and then started over again,
    indefinitely, so the generator never finishes by itself.
    """
    while True:
        # One full pass over the dataloader, then restart.
        yield from dataloader


class BidirectionalOneShotIterator(object):
    """
    Endless iterator that alternates between two dataloaders.

    Batches are taken in turn from ``dataloader_tail`` (odd steps, i.e. the
    very first call) and ``dataloader_head`` (even steps). Each dataloader is
    restarted whenever it is exhausted.
    """

    def __init__(self, dataloader_head, dataloader_tail):
        self.iterator_head = one_shot_iterator(dataloader_head)
        self.iterator_tail = one_shot_iterator(dataloader_tail)
        # Number of batches handed out so far; its parity selects the source.
        self.step = 0

    def __next__(self):
        """Return the next batch, alternating between tail and head loaders."""
        self.step += 1
        if self.step % 2 == 0:
            data = next(self.iterator_head)
        else:
            data = next(self.iterator_tail)
        return data

    def __iter__(self):
        """
        Return the iterator itself, as the iterator protocol requires.

        Returning a batch here would make ``iter(obj)`` and ``for`` loops fail
        and would silently consume one step.
        """
        return self


class SingledirectionalOneShotIterator(object):
    """
    Endless iterator over a single dataloader.

    The dataloader is restarted whenever it is exhausted; ``step`` counts the
    batches handed out so far.
    """

    def __init__(self, dataloader):
        self.iterator = one_shot_iterator(dataloader)
        self.step = 0

    def __next__(self):
        """Return the next batch and advance the step counter."""
        self.step += 1
        return next(self.iterator)

    def __iter__(self):
        """Return the iterator itself so the object works with ``iter()`` and ``for``."""
        return self
