from torch.utils.data import DataLoader, Dataset as ptDataset
from graph_learning.tasker import Tasker, TaskerConfig, DataloaderTasker
from graph_learning.utils import merge_metrics
from graph_learning.dataset.graph import gl_batch

@TaskerConfig.register('static-graph-dl',
                       help='[Dataloader] for graph list.')
class GraphDLConfig(TaskerConfig):
    """Config registered as ``static-graph-dl``; its builder is StaticGraphDL."""

    @classmethod
    def define_parser(cls, parser):
        """Add the base options, then an integer ``--batch-size`` option.

        The flag is optional; when it is omitted argparse applies whatever
        default the parser is configured with.
        """
        super().define_parser(parser)
        parser.add_argument('--batch-size', type=int)

    @property
    def builder(self):
        """Return the tasker class constructed from this config."""
        return StaticGraphDL

class StaticGraphDL(DataloaderTasker):
    """Dataloader tasker serving one graph object or a list of them.

    ``self.data`` and ``self.transformer`` are not assigned in this class;
    they are presumably provided by ``DataloaderTasker`` or the code that
    builds the tasker -- TODO confirm. When ``self.data`` exposes
    ``train_index`` / ``valid_index`` it is sliced into train / valid / test
    splits; otherwise each split uses the whole of ``self.data``.
    """

    class Dataset(ptDataset):
        """Map-style dataset over a single graph object or a list of them.

        Each item's ``gdata['mode']`` is set to the split name and a
        ``'mode'`` batch schema is registered on it.

        Args:
            data: a list of graph objects, or a single graph object (which
                is then exposed as a dataset of length 1).
            mode: split name stored in ``gdata['mode']`` ('train', 'valid'
                or 'test' when used by StaticGraphDL).
            transformer: object whose ``transform(item)`` produces the
                sample returned by ``__getitem__``.
            use_batch: if True, ``collate_fn`` merges samples with
                ``gl_batch``; otherwise it unwraps the single sample.
        """

        def __init__(self, data, mode, transformer,
                     use_batch):
            self.data = data
            self.use_batch = use_batch
            self.transformer = transformer

            def flag_batch(_, batch):
                # Presumably receives the per-item 'mode' values of a batch.
                # Every item here is given the same mode, so the first one
                # stands for the whole batch.
                return batch[0]

            # NOTE: this mutates the caller's graph objects in place.
            for item in self._items():
                item.gdata['mode'] = mode
                item.add_batch_schema('mode', flag_batch)

        def _items(self):
            """Return the data as a list; a lone graph becomes ``[graph]``."""
            if isinstance(self.data, list):
                return self.data
            return [self.data]

        def __len__(self):
            return len(self._items())

        def __getitem__(self, idx):
            # A lone (non-list) graph is returned regardless of ``idx``.
            item = self.data[idx] if isinstance(self.data, list) else self.data
            return self.transformer.transform(item)

        def collate_fn(self, batch):
            """Merge samples with ``gl_batch``, or unwrap the single sample."""
            if self.use_batch:
                return gl_batch(batch)
            # StaticGraphDL pairs use_batch=False with batch_size=1, so
            # ``batch`` holds exactly one sample in that case.
            return batch[0]

    def __init__(self, batch_size):
        # ``None`` disables batching: each loader step yields one sample.
        self.batch_size = batch_size

    def _build_dataloader(self, data, mode, shuffle, **loader_kwargs):
        """Wrap ``data`` in a ``Dataset`` for ``mode`` and return its DataLoader.

        Extra keyword arguments are forwarded to ``DataLoader``.
        """
        use_batch = self.batch_size is not None
        dataset = self.Dataset(data, mode, self.transformer,
                               use_batch=use_batch)
        return DataLoader(dataset,
                          batch_size=self.batch_size if use_batch else 1,
                          shuffle=shuffle,
                          collate_fn=dataset.collate_fn,
                          **loader_kwargs)

    def train_dataloader(self):
        """Shuffled loader over ``data[:train_index]`` (or all data)."""
        data = self.data
        if hasattr(data, 'train_index'):
            data = data[:data.train_index]
        return self._build_dataloader(data, 'train', shuffle=True,
                                      pin_memory=True)

    def valid_dataloader(self):
        """Ordered loader over ``data[train_index:valid_index]`` (or all data)."""
        data = self.data
        # NOTE(review): guarded by ``train_index`` but also reads
        # ``valid_index``; assumes both attributes are present together.
        if hasattr(data, 'train_index'):
            data = data[data.train_index : data.valid_index]
        return self._build_dataloader(data, 'valid', shuffle=False)

    def test_dataloader(self):
        """Ordered loader over ``data[valid_index:]`` (or all data)."""
        data = self.data
        if hasattr(data, 'valid_index'):
            data = data[data.valid_index:]
        return self._build_dataloader(data, 'test', shuffle=False)

    # Epoch end
    def valid_end(self, outputs):
        """Merge per-step validation outputs with ``merge_metrics``."""
        return merge_metrics(outputs)

    def test_end(self, outputs):
        """Merge per-step test outputs with ``merge_metrics``."""
        return merge_metrics(outputs)

class GraphDataMixin:
    """Mixin providing accessors for graph node/edge data.

    ``graph`` maps an incoming ``data`` object to a graph; by default the
    object is the graph itself. Subclasses may override it, and the other
    accessors read from the ``ndata`` / ``edata`` mappings of whatever
    ``graph`` returns.
    """

    # Interfaces
    def graph(self, data):
        """Return the graph for ``data`` (identity by default)."""
        return data

    def node_seed_labels(self, data):
        """Return ``ndata['seed_labels']`` of the graph for ``data``."""
        node_store = self.graph(data).ndata
        return node_store['seed_labels']

    def node_feature(self, data):
        """Return ``ndata['x']`` of the graph for ``data``."""
        node_store = self.graph(data).ndata
        return node_store['x']

    def edge_feature(self, data):
        """Return ``edata['x']`` of the graph for ``data``."""
        edge_store = self.graph(data).edata
        return edge_store['x']
