import os
import numpy as np

from federatedscope.register import register_data

# Run with mini_graph_dt:
# python federatedscope/main.py --cfg \
# federatedscope/gfl/baseline/mini_graph_dc/fedavg.yaml --client_cfg \
# federatedscope/gfl/baseline/mini_graph_dc/fedavg_per_client.yaml
# Test Accuracy: ~0.7


def load_mini_graph_dt(config, client_cfgs=None):
    import torch
    from torch_geometric.data import InMemoryDataset, Data
    from torch_geometric.datasets import TUDataset, MoleculeNet
    from federatedscope.core.splitters.graph.scaffold_lda_splitter import \
        GenFeatures
    from federatedscope.core.data import DummyDataTranslator

    class MiniGraphDCDataset(InMemoryDataset):
        NAME = 'mini_graph_dt'
        DATA_NAME = ['BACE', 'BBBP', 'CLINTOX', 'ENZYMES', 'PROTEINS_full']
        IN_MEMORY_DATA = {}

        def __init__(self, root, splits=[0.8, 0.1, 0.1]):
            self.root = root
            self.splits = splits
            super(MiniGraphDCDataset, self).__init__(root)

        @property
        def processed_dir(self):
            return os.path.join(self.root, self.NAME, 'processed')

        @property
        def processed_file_names(self):
            return ['pre_transform.pt', 'pre_filter.pt']

        def __len__(self):
            return len(self.DATA_NAME)

        def __getitem__(self, idx):
            if idx not in self.IN_MEMORY_DATA:
                self.IN_MEMORY_DATA[idx] = {}
                for split in ['train', 'val', 'test']:
                    split_data = self._load(idx, split)
                    if split_data:
                        self.IN_MEMORY_DATA[idx][split] = split_data
            return self.IN_MEMORY_DATA[idx]

        def _load(self, idx, split):
            try:
                data = torch.load(
                    os.path.join(self.processed_dir, str(idx), f'{split}.pt'))
            except:
                data = None
            return data

        def process(self):
            np.random.seed(0)
            for idx, name in enumerate(self.DATA_NAME):
                if name in ['BACE', 'BBBP', 'CLINTOX']:
                    dataset = MoleculeNet(self.root, name)
                    featurizer = GenFeatures()
                    ds = []
                    for graph in dataset:
                        graph = featurizer(graph)
                        ds.append(
                            Data(edge_index=graph.edge_index,
                                 x=graph.x,
                                 y=graph.y))
                    dataset = ds
                    if name in ['BACE', 'BBBP']:
                        for i in range(len(dataset)):
                            dataset[i].y = dataset[i].y.long()
                    if name in ['CLINTOX']:
                        for i in range(len(dataset)):
                            dataset[i].y = torch.argmax(
                                dataset[i].y).view(-1).unsqueeze(0)
                else:
                    # Classification
                    dataset = TUDataset(self.root, name)
                    dataset = [
                        Data(edge_index=graph.edge_index, x=graph.x, y=graph.y)
                        for graph in dataset
                    ]

                # We fix train/val/test
                index = np.random.permutation(np.arange(len(dataset)))
                train_idx = index[:int(len(dataset) * self.splits[0])]
                valid_idx = index[int(len(dataset) * self.splits[0]):int(
                    len(dataset) * sum(self.splits[:2]))]
                test_idx = index[int(len(dataset) * sum(self.splits[:2])):]

                if not os.path.isdir(os.path.join(self.processed_dir,
                                                  str(idx))):
                    os.makedirs(os.path.join(self.processed_dir, str(idx)))

                train_path = os.path.join(self.processed_dir, str(idx),
                                          'train.pt')
                valid_path = os.path.join(self.processed_dir, str(idx),
                                          'val.pt')
                test_path = os.path.join(self.processed_dir, str(idx),
                                         'test.pt')

                torch.save([dataset[i] for i in train_idx], train_path)
                torch.save([dataset[i] for i in valid_idx], valid_path)
                torch.save([dataset[i] for i in test_idx], test_path)

                print(name, len(dataset), dataset[0])

        def meta_info(self):
            return {
                'BACE': {
                    'task': 'classification',
                    'input_dim': 74,
                    'output_dim': 2,
                    'num_samples': 1513,
                },
                'BBBP': {
                    'task': 'classification',
                    'input_dim': 74,
                    'output_dim': 2,
                    'num_samples': 2039,
                },
                'CLINTOX': {
                    'task': 'classification',
                    'input_dim': 74,
                    'output_dim': 2,
                    'num_samples': 1478,
                },
                'ENZYMES': {
                    'task': 'classification',
                    'input_dim': 3,
                    'output_dim': 6,
                    'num_samples': 600,
                },
                'PROTEINS_full': {
                    'task': 'classification',
                    'input_dim': 3,
                    'output_dim': 2,
                    'num_samples': 1113,
                },
            }

    dataset = MiniGraphDCDataset(config.data.root)
    # Convert to dict
    datadict = {
        client_id + 1: dataset[client_id]
        for client_id in range(len(dataset))
    }
    config.merge_from_list(['federate.client_num', len(dataset)])
    translator = DummyDataTranslator(config, client_cfgs)

    return translator(datadict), config


def call_mini_graph_dt(config, client_cfgs):
    if config.data.type == "mini-graph-dc":
        data, modified_config = load_mini_graph_dt(config, client_cfgs)
        return data, modified_config


register_data("mini-graph-dc", call_mini_graph_dt)
