import os
from typing import List
from dataclasses import dataclass
from collections import defaultdict

import torch
from torch.utils.data import Dataset
from torch.utils.data import ConcatDataset
from torch.utils.data import DataLoader
from torch_geometric.data import Data
import pytorch_lightning as pl
from omegaconf import DictConfig

import src.data_hub as data_hub


@dataclass
class Query:
    """Batch of link-prediction queries returned by the datasets' ``_collate_fn``.

    Every field defaults to ``None``. The collate functions in this file fill
    in all fields except ``negative_target_node``, which only
    ``HomogeneousDataset._collate_fn`` sets.
    """
    # Dataset indices of the batched facts. NOTE: the collate functions in this
    # file pass a tuple of 0-dim tensors here rather than a single tensor.
    data_index: torch.Tensor = None
    # 1-D tensors with the head node id, relation id and tail node id of each query.
    h: torch.Tensor = None
    r: torch.Tensor = None
    t: torch.Tensor = None
    # Text looked up for each entry of h / r / t (same order as the tensors).
    h_text: List[str] = None
    r_text: List[str] = None
    t_text: List[str] = None
    # Graph edges, one row per edge, laid out as [head, relation, tail].
    edge_index: torch.Tensor = None
    # Long tensor of shape (num_queries, num_nodes); 1 marks a node that is a
    # known tail for the query's head (and relation, for heterogeneous graphs).
    filter_mask: torch.Tensor = None
    # Stacked per-fact entries of the split's 'negative_target_node', when the
    # split dictionary provides that key; otherwise None.
    negative_target_node: torch.Tensor = None
    

class HomogenousGraph:
    """Graph with a single relation type: structural splits plus node texts.

    ``structure_dict`` is indexed by split name (``'train'``, ``'valid'``,
    ``'test'``). ``text_dict`` and the optional ``ind_text_dict`` are indexed
    by the stringified node index and hold ``'nid'`` and ``'text'`` entries.
    """

    def __init__(self,
                 name: str,
                 structure_dict: dict,
                 text_dict: dict,
                 ind_text_dict: dict = None):
        self.name = name
        self.structure_dict = structure_dict
        self.text_dict = text_dict
        self.ind_text_dict = ind_text_dict

    @classmethod
    def from_config(cls, config: DictConfig):
        """Build the graph from the files under ``./data/<config.data.name>``.

        The inductive text dictionary is only loaded when
        ``config.data.inductive`` is truthy.
        """
        root = os.path.join('./data/', config.data.name)
        structure = data_hub.load_structure_dict(root)
        texts = data_hub.load_text_dict(root)
        ind_texts = data_hub.load_ind_text_dict(root) if config.data.inductive else None
        return cls(config.data.name, structure, texts, ind_texts)

    @property
    def train_split(self):
        """The ``'train'`` entry of ``structure_dict``."""
        return self.structure_dict['train']

    @property
    def valid_split(self):
        """The ``'valid'`` entry of ``structure_dict``."""
        return self.structure_dict['valid']

    @property
    def test_split(self):
        """The ``'test'`` entry of ``structure_dict``."""
        return self.structure_dict['test']

    def get_textual_information(self, idx, ind=False):
        """Return ``(nid, 'Node Description: <text>')`` for node ``idx``.

        The inductive dictionary is consulted only when ``ind`` is truthy and
        that dictionary is non-empty; otherwise the regular one is used.
        """
        key = str(idx)
        use_inductive = ind and self.ind_text_dict
        table = self.ind_text_dict if use_inductive else self.text_dict
        record = table[key]
        return (record['nid'], 'Node Description: ' + record['text'])


class HeterogeneousGraph:
    """Multi-relational graph: structural splits plus node and relation texts.

    ``structure_dict`` is indexed by split name (``'train'``, ``'valid'``,
    ``'test'``). The text dictionaries are indexed by the stringified node or
    relation index; node entries hold ``'nid'`` and ``'text'``, relation
    entries hold ``'rid'`` and ``'text'``. The ``ind_*`` dictionaries are
    optional and only used when a lookup is made with ``ind=True``.
    """

    def __init__(self,
                 name: str,
                 structure_dict: dict,
                 node_text_dict: dict,
                 relation_text_dict: dict,
                 ind_node_text_dict: dict = None,
                 ind_relation_text_dict: dict = None):
        self.name = name
        self.structure_dict = structure_dict
        self.node_text_dict = node_text_dict
        self.relation_text_dict = relation_text_dict
        self.ind_node_text_dict = ind_node_text_dict
        # This is the attribute that get_relation_textual_information reads.
        # It used to be stored only under a misspelt name, so every
        # ``ind=True`` relation lookup raised AttributeError.
        self.ind_relation_text_dict = ind_relation_text_dict
        # Kept as an alias so code that read the old misspelt name still works.
        self.ind_relation_ind_text_dict = ind_relation_text_dict

    @classmethod
    def from_config(cls, config: 'DictConfig'):
        """Build the graph from the files under ``./data/<config.data.name>``.

        The inductive text dictionaries are only loaded when
        ``config.data.inductive`` is truthy.
        """
        base_data_path = os.path.join('./data/', config.data.name)
        structure_dict = data_hub.load_structure_dict(base_data_path)
        node_text_dict = data_hub.load_node_text_dict(base_data_path)
        relation_text_dict = data_hub.load_relation_text_dict(base_data_path)
        if config.data.inductive:
            ind_node_text_dict = data_hub.load_ind_node_text_dict(base_data_path)
            ind_relation_text_dict = data_hub.load_ind_relation_text_dict(base_data_path)
        else:
            ind_node_text_dict = None
            ind_relation_text_dict = None

        return cls(
            config.data.name,
            structure_dict,
            node_text_dict,
            relation_text_dict,
            ind_node_text_dict,
            ind_relation_text_dict,
        )

    @property
    def train_split(self):
        """The ``'train'`` entry of ``structure_dict``."""
        return self.structure_dict['train']

    @property
    def valid_split(self):
        """The ``'valid'`` entry of ``structure_dict``."""
        return self.structure_dict['valid']

    @property
    def test_split(self):
        """The ``'test'`` entry of ``structure_dict``."""
        return self.structure_dict['test']

    @staticmethod
    def _describe(table, idx, id_key, prefix):
        """Return ``(table[str(idx)][id_key], prefix + table[str(idx)]['text'])``."""
        record = table[str(idx)]
        return (record[id_key], prefix + record['text'])

    def get_node_textual_information(self, idx, ind=False):
        """Return ``(nid, 'Node Description: <text>')`` for node ``idx``.

        The inductive dictionary is used only when ``ind`` is truthy and that
        dictionary is non-empty; otherwise the regular one is used.
        """
        use_inductive = ind and self.ind_node_text_dict
        table = self.ind_node_text_dict if use_inductive else self.node_text_dict
        return self._describe(table, idx, 'nid', 'Node Description: ')

    def get_relation_textual_information(self, idx, ind=False):
        """Return ``(rid, 'Relation Description: <text>')`` for relation ``idx``.

        The inductive dictionary is used only when ``ind`` is truthy and that
        dictionary is non-empty; otherwise the regular one is used.
        """
        use_inductive = ind and self.ind_relation_text_dict
        table = self.ind_relation_text_dict if use_inductive else self.relation_text_dict
        return self._describe(table, idx, 'rid', 'Relation Description: ')


class HomogeneousDataset(Dataset):
    """Dataset over one split of a single-relation (homogeneous) graph.

    Depending on the flags it serves one of three things:

    * ``only_return_node_text``: ``(nid, text)`` for every node of the split graph;
    * ``only_return_relation_text``: ``(index, text)`` pairs alternating between
      the relation text and its reverse;
    * otherwise: the split's facts, each returned as a 1-D tensor
      ``[head, tail, index]`` and batched by ``collate_fn`` into a ``Query``.
    """

    # How many times the (relation, reverse relation) text pair is repeated in
    # relation-text mode. NOTE(review): the reason for 16 is not visible in
    # this file — confirm with the consumer of these texts.
    _RELATION_TEXT_REPEATS = 16

    def __init__(self,
                 graph: 'HomogenousGraph',
                 relation_description: str,
                 split: str,
                 only_return_node_text: bool = False,
                 only_return_relation_text: bool = False,
                 for_plm: bool = False):
        super().__init__()

        self.graph = graph
        self.split = split
        self.only_return_node_text = only_return_node_text
        self.only_return_relation_text = only_return_relation_text
        self.for_plm = for_plm

        self.relation_description = 'Relation Description: ' + relation_description
        # Fixed text: the previous literal misspelt "reverse" and glued the
        # description directly onto "of" with no space.
        self.rev_relation_description = 'Relation Description: the reverse of ' + relation_description

        self.split_data_dict = getattr(self.graph, f'{self.split}_split')

        # edge_index is used below as one [head, tail] row per edge.
        self.g = Data(edge_index=self.split_data_dict['edge_index'],
                      num_nodes=torch.max(self.split_data_dict['edge_index'].view(-1)) + 1)

        if self.only_return_node_text:
            inductive = self.split_data_dict['inductive']
            self.data = [self.graph.get_textual_information(i, ind=inductive)
                         for i in range(self.g.num_nodes)]
        elif self.only_return_relation_text:
            # Even indices carry the relation text, odd indices its reverse.
            self.data = [
                (2 * i + offset, text)
                for i in range(self._RELATION_TEXT_REPEATS)
                for offset, text in ((0, self.relation_description),
                                     (1, self.rev_relation_description))
            ]
        else:
            self.data = self.split_data_dict['data']

    @classmethod
    def from_config(cls,
                    config: 'DictConfig',
                    split: str,
                    only_return_node_text: bool = False,
                    only_return_relation_text: bool = False,
                    for_plm: bool = False):
        """Load the graph named in ``config`` and wrap the requested split."""
        graph = HomogenousGraph.from_config(config)

        return cls(
            graph,
            config.data.relation_description,
            split,
            only_return_node_text,
            only_return_relation_text,
            for_plm
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        if self.only_return_node_text or self.only_return_relation_text:
            return self.data[index]
        # Append the dataset index so the collate function can recover it.
        return torch.cat([self.data[index], torch.tensor([index])], dim=0)

    def collate_fn(self, batch):
        """Dispatch to the text or the fact collate function."""
        if self.only_return_node_text or self.only_return_relation_text:
            return self._collate_fn_fot_text(batch)
        return self._collate_fn(batch)

    def _collate_fn_fot_text(self, batch):
        """Collate ``(index, text)`` items into ``(indices, texts, is_node_text, split)``."""
        indexs, texts = zip(*batch)
        indexs = torch.tensor(indexs).view(-1)
        return indexs, texts, self.only_return_node_text, self.split

    def _build_edge_index(self):
        """Return ``[head, relation, tail]`` rows for every edge and its reverse.

        The forward edges come first with relation 0, followed by the flipped
        edges with relation 1.
        """
        forward = self.g.edge_index
        pairs = torch.cat([forward, forward[:, [1, 0]]], dim=0)
        relation = torch.zeros_like(pairs[:, :1])
        # Fixed: the reverse edges start at row ``forward.size(0)``. The old
        # code started at ``forward.size(0) // 2`` and therefore tagged half
        # of the forward edges as reverse edges too.
        relation[forward.size(0):] = 1
        return torch.cat([pairs[:, :1], relation, pairs[:, 1:]], dim=1)

    def _known_facts(self, edge_index):
        """Return the ``[head, tail]`` rows used to build the filter mask.

        For the train split these are the graph edges only. For ``'valid'``
        and ``'test'`` the facts of this split and of the other evaluation
        split are appended.
        """
        facts = edge_index[:, [0, 2]]
        if self.split == 'valid':
            other = self.graph.test_split['data']
        elif self.split == 'test':
            other = self.graph.valid_split['data']
        else:
            return facts
        return torch.cat([facts, self.data, other], dim=0)

    def _build_filter_mask(self, h, all_facts):
        """Return a long ``(len(h), num_nodes)`` mask with 1 at every known tail of each head."""
        filter_mask = torch.zeros(len(h), self.g.num_nodes).long()
        for i in range(len(h)):
            filter_mask[i, all_facts[all_facts[:, 0] == h[i], 1]] = 1
        return filter_mask

    def _collate_fn(self, batch):
        """Collate ``[head, tail, index]`` items into a ``Query``."""
        batch_size = len(batch)

        # Each item is a 1-D tensor, so zip yields tuples of 0-dim tensors.
        h, t, indexs = zip(*batch)
        r = [0] * batch_size

        inductive = self.split_data_dict['inductive']
        h_text = [self.graph.get_textual_information(i.item(), ind=inductive)[1] for i in h]
        t_text = [self.graph.get_textual_information(i.item(), ind=inductive)[1] for i in t]
        r_text = [self.relation_description] * batch_size

        h, r, t = (torch.tensor(x) for x in (h, r, t))

        edge_index = self._build_edge_index()
        filter_mask = self._build_filter_mask(h, self._known_facts(edge_index))

        negative_target_node = self.split_data_dict.get('negative_target_node', None)
        if negative_target_node is not None:
            negative_target_node = torch.stack(
                [negative_target_node[index] for index in indexs], dim=0)

        return Query(
            data_index=indexs,
            h=h,
            r=r,
            t=t,
            h_text=h_text,
            r_text=r_text,
            t_text=t_text,
            edge_index=edge_index,
            filter_mask=filter_mask,
            negative_target_node=negative_target_node
        )


class HeterogeneousDataset(Dataset):
    """Dataset over one split of a multi-relational (heterogeneous) graph.

    Depending on the flags it serves one of three things:

    * ``only_return_node_text``: ``(nid, text)`` for every node of the split graph;
    * ``only_return_relation_text``: ``(rid, text)`` for every relation id up to
      the largest one in the split graph;
    * otherwise: the first half of the split's facts, each returned as a 1-D
      tensor ``[head, relation, tail, index]`` and batched by ``collate_fn``
      into a ``Query``.
    """

    def __init__(self,
                 graph: 'HeterogeneousGraph',
                 split: str,
                 only_return_node_text: bool = False,
                 only_return_relation_text: bool = False,
                 for_plm: bool = False):
        super().__init__()

        self.graph = graph
        self.split = split
        self.only_return_node_text = only_return_node_text
        self.only_return_relation_text = only_return_relation_text
        self.for_plm = for_plm

        self.split_data_dict = getattr(self.graph, f'{self.split}_split')

        # Rows are [head, relation, tail].
        edge_index = self.split_data_dict['edge_index']
        node_pairs = edge_index[:, [0, 2]]

        # Fixed: count nodes from the head/tail columns only. The old code
        # took the max over all three columns, so a relation id larger than
        # every node id inflated num_nodes.
        self.g = Data(edge_index=node_pairs, edge_type=edge_index[:, 1],
                      num_nodes=torch.max(node_pairs) + 1)

        if self.only_return_node_text:
            inductive = self.split_data_dict['inductive']
            self.data = [self.graph.get_node_textual_information(i, ind=inductive)
                         for i in range(self.g.num_nodes)]
        elif self.only_return_relation_text:
            inductive = self.split_data_dict['inductive']
            self.data = [self.graph.get_relation_textual_information(i, ind=inductive)
                         for i in range(torch.max(self.g.edge_type) + 1)]
        else:
            self.data = self.split_data_dict['data']
            # Only the first half of the stored facts is served.
            # NOTE(review): presumably the second half holds the reversed
            # facts, which the collate function regenerates — confirm.
            self.data = self.data[:self.data.size(0)//2]

    @classmethod
    def from_config(cls,
                    config: 'DictConfig',
                    split: str,
                    only_return_node_text: bool = False,
                    only_return_relation_text: bool = False,
                    for_plm: bool = False):
        """Load the graph named in ``config`` and wrap the requested split."""
        graph = HeterogeneousGraph.from_config(config)

        return cls(
            graph,
            split,
            only_return_node_text,
            only_return_relation_text,
            for_plm,
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        if self.only_return_node_text or self.only_return_relation_text:
            return self.data[index]
        # Append the dataset index so the collate function can recover it.
        return torch.cat([self.data[index], torch.tensor([index])], dim=0)

    def collate_fn(self, batch):
        """Dispatch to the text or the fact collate function."""
        if self.only_return_node_text or self.only_return_relation_text:
            return self._collate_fn_fot_text(batch)
        return self._collate_fn(batch)

    def _collate_fn_fot_text(self, batch):
        """Collate ``(index, text)`` items into ``(indices, texts, is_node_text, split)``."""
        indexs, texts = zip(*batch)
        indexs = torch.tensor(indexs).view(-1)
        return indexs, texts, self.only_return_node_text, self.split

    def _reverse_relation(self, r):
        """Map relation ids to their reverse by shifting by half the relation count.

        The relation count is taken as ``max(edge_type) + 1`` of the split
        graph; ids in the upper half are shifted down, the others up.
        """
        half = (torch.max(self.g.edge_type) + 1) // 2
        return torch.where(r >= half, r - half, r + half)

    def _reverse_fact(self, fact):
        """Turn ``[h, r, t]`` rows into ``[t, reverse(r), h]`` rows."""
        h, r, t = fact.split(1, dim=1)
        return torch.cat([t, self._reverse_relation(r), h], dim=1)

    def _expand_queries(self, h, r, t):
        """Add reverse-direction queries to a batch of 1-D ``h``, ``r``, ``t``.

        For GNN training (``split == 'train'`` and not ``for_plm``) the batch
        size is kept: the leading ``ceil(n / 2)`` facts stay as they are and
        the remaining ones are replaced by their reverse. In every other case
        the batch is doubled with the reverse of every fact.
        """
        if self.split == 'train' and not self.for_plm:
            # Same split point as ``chunk(2)``, but slicing also copes with a
            # one-item batch, where ``chunk(2)[1]`` raised IndexError.
            keep = (h.numel() + 1) // 2
            return (
                torch.cat([h[:keep], t[keep:]], dim=0),
                torch.cat([r[:keep], self._reverse_relation(r[keep:])], dim=0),
                torch.cat([t[:keep], h[keep:]], dim=0),
            )
        return (
            torch.cat([h, t], dim=0),
            torch.cat([r, self._reverse_relation(r)], dim=0),
            torch.cat([t, h], dim=0),
        )

    def _known_facts(self, edge_index):
        """Return the ``[h, r, t]`` rows used to build the filter mask.

        For the train split these are the graph edges only. For ``'valid'``
        and ``'test'`` the facts of this split and of the other evaluation
        split are appended, each together with its reverse.
        """
        if self.split == 'valid':
            other = self.graph.test_split['data']
        elif self.split == 'test':
            other = self.graph.valid_split['data']
        else:
            return edge_index
        return torch.cat([
            edge_index,
            self.data, self._reverse_fact(self.data),
            other, self._reverse_fact(other),
        ], dim=0)

    def _build_filter_mask(self, h, r, all_facts):
        """Return a long ``(len(h), num_nodes)`` mask with 1 at every known tail of each ``(h, r)``."""
        num_nodes = self.g.num_nodes
        filter_mask = torch.zeros(len(h), num_nodes).long()
        # Encode (relation, head) as one integer so a single comparison
        # selects the matching facts.
        key = all_facts[:, 1] * num_nodes + all_facts[:, 0]
        for i in range(len(h)):
            filter_mask[i, all_facts[key == r[i] * num_nodes + h[i]][:, 2]] = 1
        return filter_mask

    def _collate_fn(self, batch):
        """Collate ``[head, relation, tail, index]`` items into a ``Query``.

        ``data_index`` keeps one entry per input item, while ``h``/``r``/``t``
        may be twice as long (see ``_expand_queries``).
        """
        # Each item is a 1-D tensor, so zip yields tuples of 0-dim tensors.
        h, r, t, indexs = zip(*batch)
        h, r, t = (torch.tensor(x).view(-1) for x in (h, r, t))
        h, r, t = self._expand_queries(h, r, t)

        inductive = self.split_data_dict['inductive']
        h_text = [self.graph.get_node_textual_information(i.item(), ind=inductive)[1] for i in h]
        t_text = [self.graph.get_node_textual_information(i.item(), ind=inductive)[1] for i in t]
        r_text = [self.graph.get_relation_textual_information(i.item(), ind=inductive)[1] for i in r]

        # Rebuild [head, relation, tail] rows from the stored graph.
        edge_index = torch.cat([self.g.edge_index, self.g.edge_type.unsqueeze(1)], dim=1)[:, [0, 2, 1]]
        filter_mask = self._build_filter_mask(h, r, self._known_facts(edge_index))

        return Query(
            data_index=indexs,
            h=h,
            r=r,
            t=t,
            h_text=h_text,
            r_text=r_text,
            t_text=t_text,
            edge_index=edge_index,
            filter_mask=filter_mask
        )


class CoSTPLMDataModule(pl.LightningDataModule):
    """Lightning data module whose datasets are built with ``for_plm=True``.

    Batch sizes come from ``config.plm.cotrain`` when ``cotrain`` is set and
    from ``config.plm.pretrain`` otherwise.
    """

    def __init__(self, config: DictConfig, cotrain: bool):
        super().__init__()
        self.config = config
        self.cotrain = cotrain

    @classmethod
    def from_config(cls, config: DictConfig, cotrain: bool = False):
        """Alternate constructor mirroring the other ``from_config`` helpers."""
        return cls(config, cotrain)

    def prepare_data(self):
        """Nothing to download or preprocess."""
        pass

    def setup(self, stage: str):
        """Create the datasets required by ``stage`` ('fit', 'validate' or 'test')."""
        dataset_cls = HomogeneousDataset if self.config.data.homogeneous else HeterogeneousDataset
        splits_for_stage = {
            'fit': ('train', 'valid'),
            'validate': ('valid',),
            'test': ('test',),
        }
        # Any other stage creates nothing.
        for split in splits_for_stage.get(stage, ()):
            built = dataset_cls.from_config(self.config, split=split, for_plm=True)
            setattr(self, f'{split}_dataset', built)

    def _dataloader(self, dataset, batch_size, shuffle):
        """Wrap ``dataset`` in a DataLoader that uses the dataset's own collate function."""
        loader_options = dict(
            shuffle=shuffle,
            num_workers=8,
            collate_fn=dataset.collate_fn,
        )
        return DataLoader(dataset, batch_size, **loader_options)

    def train_dataloader(self):
        dataset = self.train_dataset
        sizes = self.config.plm.cotrain if self.cotrain else self.config.plm.pretrain
        return self._dataloader(dataset, sizes.train_batch_size, shuffle=True)

    def val_dataloader(self):
        dataset = self.valid_dataset
        sizes = self.config.plm.cotrain if self.cotrain else self.config.plm.pretrain
        return self._dataloader(dataset, sizes.test_batch_size, shuffle=False)

    def test_dataloader(self):
        dataset = self.test_dataset
        sizes = self.config.plm.cotrain if self.cotrain else self.config.plm.pretrain
        return self._dataloader(dataset, sizes.test_batch_size, shuffle=False)


class CoSTGNNDataModule(pl.LightningDataModule):
    """Lightning data module whose datasets use the default (non-PLM) mode.

    Batch sizes come from ``config.gnn.cotrain`` when ``cotrain`` is set and
    from ``config.gnn.pretrain`` otherwise.
    """

    def __init__(self, config: DictConfig, cotrain: bool):
        super().__init__()
        self.config = config
        self.cotrain = cotrain

    @classmethod
    def from_config(cls, config: DictConfig, cotrain: bool = False):
        """Alternate constructor mirroring the other ``from_config`` helpers."""
        return cls(config, cotrain)

    def prepare_data(self):
        """Nothing to download or preprocess."""
        pass

    def setup(self, stage: str):
        """Create the datasets required by ``stage`` ('fit', 'validate' or 'test')."""
        if self.config.data.homogeneous:
            dataset_cls = HomogeneousDataset
        else:
            dataset_cls = HeterogeneousDataset

        def load(split_name):
            return dataset_cls.from_config(self.config, split=split_name)

        if stage == 'fit':
            self.train_dataset = load('train')
            self.valid_dataset = load('valid')
            return
        if stage == 'validate':
            self.valid_dataset = load('valid')
            return
        if stage == 'test':
            self.test_dataset = load('test')

    def _dataloader(self, dataset, batch_size, shuffle):
        """Wrap ``dataset`` in a DataLoader that uses the dataset's own collate function."""
        collate = dataset.collate_fn
        return DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=8, collate_fn=collate)

    def train_dataloader(self):
        dataset = self.train_dataset
        if self.cotrain:
            batch_size = self.config.gnn.cotrain.train_batch_size
        else:
            batch_size = self.config.gnn.pretrain.train_batch_size
        return self._dataloader(dataset, batch_size, shuffle=True)

    def val_dataloader(self):
        dataset = self.valid_dataset
        if self.cotrain:
            batch_size = self.config.gnn.cotrain.test_batch_size
        else:
            batch_size = self.config.gnn.pretrain.test_batch_size
        return self._dataloader(dataset, batch_size, shuffle=False)

    def test_dataloader(self):
        dataset = self.test_dataset
        if self.cotrain:
            batch_size = self.config.gnn.cotrain.test_batch_size
        else:
            batch_size = self.config.gnn.pretrain.test_batch_size
        return self._dataloader(dataset, batch_size, shuffle=False)
        

class CoSTNodeTextDataModule(pl.LightningDataModule):
    """Predict-only data module that serves the node texts of the graph.

    The train split is always served. The test split is added only when the
    loaded graph carries an inductive node-text dictionary.
    """

    def __init__(self, config: DictConfig):
        super().__init__()
        self.config = config

    @classmethod
    def from_config(cls, config: DictConfig):
        """Alternate constructor mirroring the other ``from_config`` helpers."""
        return cls(config)

    def prepare_data(self):
        """Nothing to download or preprocess."""
        pass

    def setup(self, stage: str):
        """Build the node-text datasets; only the 'predict' stage does anything."""
        dataset_cls = HomogeneousDataset if self.config.data.homogeneous else HeterogeneousDataset
        if stage != 'predict':
            return

        self.train_dataset = dataset_cls.from_config(
            self.config, split='train', only_return_node_text=True)

        # The two graph classes store the inductive node texts under
        # different attribute names, so both are probed.
        graph = self.train_dataset.graph
        inductive_texts_present = (
            getattr(graph, 'ind_node_text_dict', None) is not None
            or getattr(graph, 'ind_text_dict', None) is not None
        )
        if not inductive_texts_present:
            self.test_dataset = None
            return
        self.test_dataset = dataset_cls.from_config(
            self.config, split='test', only_return_node_text=True)

    def _dataloader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the train dataset's collate function."""
        # ``dataset`` is a ConcatDataset here, so the collate function is
        # taken from the train dataset instead.
        collate = self.train_dataset.collate_fn
        return DataLoader(dataset, shuffle=shuffle, num_workers=8, collate_fn=collate)

    def predict_dataloader(self):
        candidates = (self.train_dataset, self.test_dataset)
        available = [part for part in candidates if part is not None]
        return self._dataloader(ConcatDataset(available), shuffle=False)


class CoSTRelationTextDataModule(pl.LightningDataModule):
    """Predict-only data module that serves the relation texts of the graph.

    The train split is always served. The test split is added only when the
    loaded graph exposes a non-``None`` ``ind_relation_text_dict`` or
    ``ind_text_dict`` attribute.
    """

    def __init__(self, config: DictConfig):
        super().__init__()
        self.config = config

    @classmethod
    def from_config(cls, config: DictConfig):
        """Alternate constructor mirroring the other ``from_config`` helpers."""
        return cls(config)

    def prepare_data(self):
        """Nothing to download or preprocess."""
        pass

    def setup(self, stage: str):
        """Build the relation-text datasets; only the 'predict' stage does anything."""
        cls = HomogeneousDataset if self.config.data.homogeneous else HeterogeneousDataset
        if stage == 'predict':
            self.train_dataset = cls.from_config(self.config, split='train', only_return_relation_text=True)
            graph = self.train_dataset.graph
            has_inductive_texts = (
                getattr(graph, 'ind_relation_text_dict', None) is not None
                or getattr(graph, 'ind_text_dict', None) is not None
            )
            if has_inductive_texts:
                self.test_dataset = cls.from_config(self.config, split='test', only_return_relation_text=True)
            else:
                self.test_dataset = None

    def _dataloader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the train dataset's collate function."""
        return DataLoader(
            dataset,
            shuffle=shuffle,
            num_workers=8,
            collate_fn=self.train_dataset.collate_fn,
        )

    def predict_dataloader(self):
        # Compare against None explicitly. A bare truthiness test calls the
        # dataset's __len__, so an empty dataset would be dropped, and if
        # nothing remained ConcatDataset would refuse the empty list.
        datasets = [d for d in (self.train_dataset, self.test_dataset) if d is not None]
        return self._dataloader(
            ConcatDataset(datasets),
            shuffle=False,
        )


class CoSTPseudoFactDataModule(pl.LightningDataModule):
    """Predict-only data module serving the distinct ``(h, r)`` pairs of the train split.

    Every batch is ``(h, r, edge_index)``, where ``edge_index`` holds the
    train graph as ``[head, relation, tail]`` rows.
    """

    def __init__(self, config: DictConfig):
        super().__init__()
        self.config = config

    @classmethod
    def from_config(cls, config: DictConfig):
        """Alternate constructor mirroring the other ``from_config`` helpers."""
        return cls(config)

    def prepare_data(self):
        """Nothing to download or preprocess."""
        pass

    def setup(self, stage: str):
        """Collect the unique ``(h, r)`` queries; only the 'predict' stage does anything."""
        from torch.utils.data import TensorDataset
        cls = HomogeneousDataset if self.config.data.homogeneous else HeterogeneousDataset
        if stage == 'predict':
            train_dataset = cls.from_config(self.config, split='train')
            # NOTE(review): HomogeneousDataset items are [h, t, index] and its
            # graph is built without edge_type, so this path appears to
            # support heterogeneous data only — confirm.
            #
            # Keys are plain ints. Tensors hash by identity, so the previous
            # tuple-of-tensors keys never collapsed duplicate pairs. A dict
            # keeps the first-seen order of the pairs.
            unique_pairs = {}
            for i in range(len(train_dataset)):
                h, r, *_ = train_dataset[i]
                unique_pairs[(int(h), int(r))] = 1
            # view(-1, 2) keeps the (N, 2) shape even when there are no pairs.
            query = torch.tensor(list(unique_pairs), dtype=torch.long).view(-1, 2)
            self.train_dataset = TensorDataset(query)
            self.edge_index = torch.cat([train_dataset.g.edge_index, train_dataset.g.edge_type.unsqueeze(1)], dim=1)[:, [0, 2, 1]]

    def _collate_fn(self, batch):
        """Stack ``(pair,)`` items and split them into 1-D ``h`` and ``r`` tensors."""
        pairs = torch.stack([b[0] for b in batch], dim=0)
        h, r = (column.squeeze(-1) for column in pairs.chunk(2, dim=1))
        return h, r, self.edge_index

    def _dataloader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader with a fixed batch size of 32."""
        return DataLoader(
            dataset,
            batch_size=32,
            shuffle=shuffle,
            num_workers=8,
            collate_fn=self._collate_fn
        )

    def predict_dataloader(self):
        return self._dataloader(
            self.train_dataset,
            shuffle=False,
        )







# @dataclass
# class Data:
#     h: torch.Tensor
#     r: torch.Tensor
#     t: torch.Tensor
#     x_ent: torch.Tensor
#     x_rel: torch.Tensor
#     edge_index: torch.Tensor
#     enhanced_edge_index: torch.Tensor
#     filter_mask: torch.Tensor
    
    
# class DataLoader(pt_DataLoader):
#     def __init__(self, *args, entid2text, relid2text, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.entid2text = entid2text
#         self.relid2text = relid2text
    

# class ReasonDataset(Dataset):
#     def __init__(self,
#                  data,
#                  entity2idx,
#                  idx2entity,
#                  relation2idx,
#                  idx2relation,
#                  entid2text,
#                  relid2text,
#                  filter_dict,
#                  edge_index):
#         super().__init__()
#         self.data = data
#         self.entity2idx = entity2idx
#         self.idx2entity = idx2entity
#         self.relation2idx = relation2idx
#         self.idx2relation = idx2relation
#         self.entid2text = entid2text
#         self.relid2text = relid2text
#         self.filter_dict = filter_dict
#         self.edge_index=edge_index
        
#         self.x_ent = None
#         self.x_rel = None
    
#     @classmethod
#     def from_config(cls, config: DictConfig):
#         entity2idx, idx2entity = dict(), dict()
#         with open(config.entity_path, 'r', encoding='utf-8') as fread:
#             for idx, line in enumerate(fread):
#                 entity2idx[line.strip()] = idx
#                 idx2entity[idx] = line.strip()
                      
#         relation2idx, idx2relation = dict(), dict()
#         with open(config.relation_path, 'r', encoding='utf-8') as fread:
#             for idx, line in enumerate(fread):
#                 relation2idx[line.strip()] = idx
#                 idx2relation[idx] = line.strip()
                
#         entid2text = dict()
#         with open(config.entity_text_path, 'r', encoding='utf-8') as fread:
#             for line in fread:
#                 ent, text = line.strip().split('\t')
#                 if ent in entity2idx:
#                     entid2text[entity2idx[ent]] = text
                    
#         relid2text = dict()
#         with open(config.relation_text_path, 'r', encoding='utf-8') as fread:
#             for line in fread:
#                 rel, text = line.strip().split('\t')
#                 if rel in relation2idx:
#                     relid2text[relation2idx[rel]] = text
#                     # add reverse relations
#                     relid2text[relation2idx[rel] + len(relation2idx)] = f'reverse of {text}'
        
#         data = list()
#         with open(config.data_path, 'r', encoding='utf-8') as fread:
#             for line in fread:
#                 h, r, t = line.strip().split('\t')
#                 data.append((entity2idx[h], relation2idx[r], entity2idx[t]))
#                 data.append((entity2idx[t], relation2idx[r] + len(relation2idx), entity2idx[h]))
        
#         filter_dict = defaultdict(set)
#         for path in config.filter_path:
#                 with open(path, 'r', encoding='utf-8') as fread:
#                     for line in fread:
#                         h, r, t = line.strip().split('\t')
#                         filter_dict[(entity2idx[h], relation2idx[r])].add(entity2idx[t])
#                         filter_dict[(entity2idx[t], relation2idx[r] + len(relation2idx))].add(entity2idx[h])
        
#         edge_index = list()
#         with open(config.graph_path, 'r', encoding='utf-8') as fread:
#             for line in fread:
#                 h, r, t = line.strip().split('\t')
#                 edge_index.append((entity2idx[h], relation2idx[r], entity2idx[t]))
#                 edge_index.append((entity2idx[t], relation2idx[r] + len(relation2idx), entity2idx[h]))    
        
#         return cls(
#             data,
#             entity2idx,
#             idx2entity,
#             relation2idx,
#             idx2relation,
#             entid2text,
#             relid2text,
#             filter_dict,
#             edge_index
#         )
        
#     def __len__(self):
#         return len(self.data)
    
#     def __getitem__(self, index):
#         h, r, t = self.data[index]
        
#         filter_mask = torch.zeros(len(self.entid2idx))
#         filter_mask[list(self.filter_dict[(h, r)])] = 1.0
        
#         return h, r, t, filter_mask
    
#     def collate_fn(self, unbatched_data):
#         _h, _r, _t, _filter_mask = zip(*unbatched_data)
        
#         h, r, t = map(lambda x: torch.tensor(x), (_h, _r, _t))
#         filter_mask = torch.cat(_filter_mask, dim=0)
        
#         if self.x_ent:
#             x_ent = einops.rearrange(self.x_ent, 'n d -> b n d', b=len(unbatched_data))
#         else:
#             x_ent = None
#         if self.x_rel:
#             x_rel = einops.rearrange(self.x_rel, 'n d -> b n d', b=len(unbatched_data))
#         else:
#             x_rel = None
        
#         edge_index = torch.tensor(self.edge_index)
#         enhanced_edge_index = torch.tensor(self.edge_index + list(zip(*[_h, _r, _t])))
        
#         return Data(
#             h,
#             r,
#             t,
#             x_ent,
#             x_rel,
#             edge_index,
#             enhanced_edge_index,
#             filter_mask
#         )
    
#     def _whitening(self, emb):
#         emb = emb.cpu().numpy()
        
#         mean = np.mean(emb, axis=0, keepdims=True)
#         cov = np.cov(emb.T)
#         u, s, vh = np.linalg.svd(cov)
#         kernel, bias = np.dot(u, np.diag(1. / np.sqrt(s))), -mean
#         kernel = kernel[:, :self.config.gnn.model.input_dim]
#         emb_whitening = (emb + bias).dot(kernel)
        
#         emb_whitening = torch.from_numpy(emb_whitening).to(self.device)
#         emb_whitening = F.normalize(emb_whitening, p=2, dim=-1)
#         return emb_whitening
    
#     def get_ent_rel_embeddings(self, model):
#         model.eval()
        
#         entity_embeddings = []
#         for i in range(len(self.entid2text), 128):
#             with torch.no_grad():
#                 entity_embeddings.append(
#                     model.encode([self.entid2text[j] for j in range(i, i + 128)], model.entity_encoder)
#                 )
#         entity_embeddings = torch.cat(entity_embeddings, dim=0)
#         self.x_ent = self._whitening(entity_embeddings)
        
#         relation_embeddings = []
#         for i in range(len(self.relid2text), 128):
#             with torch.no_grad():
#                 relation_embeddings.append(
#                     model.encode([self.relid2text[j] for j in range(i, i + 128)], model.entity_encoder)
#                 )
#         relation_embeddings = torch.cat(relation_embeddings, dim=0)
#         self.x_rel = self._whitening(relation_embeddings)


# class ReasonDataModule(pl.LightningDataModule):
#     def __init__(self, config: DictConfig):
#         super().__init__()
#         self.config = config
        
#         self.train_datasets = list()
#         for train_dataset_config in self.config.data.train:
#             train_dataset = ReasonDataset.from_config(train_dataset_config)
#             self.train_datasets.append(train_dataset)
            
#         self.valid_datasets = list()
#         for valid_dataset_config in self.config.data.valid:
#             valid_dataset = ReasonDataset.from_config(valid_dataset_config)
#             self.valid_datasets.append(valid_dataset)
            
#         self.test_datasets = list()
#         for test_dataset_config in self.config.data.test:
#             test_dataset = ReasonDataset.from_config(test_dataset_config)
#             self.test_datasets.append(test_dataset)
        
#     @classmethod
#     def from_config(cls, config: DictConfig, tokenizer):
#         return cls(config, tokenizer)
        
#     def prepare_data(self):
#         pass

#     def setup(self, stage: str):
#         pass
        
#     def _dataloader(self, dataset, batch_size, shuffle):
#         return DataLoader(
#             dataset,
#             batch_size,
#             shuffle=shuffle,
#             num_workers=8,
#             collate_fn=dataset.collate_fn,
#             entid2text=dataset.entid2text,
#             relid2text=dataset.relid2text
#         )
        
#     def train_dataloader(self):
#         return CombinedLoader([self._dataloader(dataset, True) for dataset in self.train_datasets], 'sequential')
        
#     def val_dataloader(self):
#         return CombinedLoader([self._dataloader(dataset, False) for dataset in self.valid_datasets], 'sequential')
        
#     def test_dataloader(self):
#         return CombinedLoader([self._dataloader(dataset, False) for dataset in self.test_datasets], 'sequential')
    
#     def set_all_ent_rel_embeddings(self, model):
#         for dataset in self.train_datasets + self.valid_datasets + self.test_datasets:
#             dataset.get_ent_rel_embeddings(model)
