import torch
import numpy as np
import networkx as nx
from torch_geometric.utils import to_networkx, from_networkx
from torch_geometric.seed import seed_everything

import pytorch_lightning as pl
from torch.utils.data import Dataset
from torch_geometric.data import Data, Batch


from .voc_superpixel import VOCSuperpixels
from .recGNN import PrefixSumK, Trees, DistanceK
from .convergecast import MajoritySubTree, TopK, Broadcast, BroadcastK, MajorityTree
from .expressivity import LimitsOne, LimitsTwo, LCC, FourCycles, SkipCircles, Triangles

class CustomGraphDataset(torch.utils.data.DataLoader):
    """Data loader that collates graph samples with ``Batch.from_data_list``.

    Despite its name, this class is a ``torch.utils.data.DataLoader``
    subclass, not a dataset: it wraps ``dataset`` and yields batches.

    Args:
        dataset: The dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the samples every epoch.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``
            (e.g. ``num_workers``). A ``collate_fn`` given here overrides
            the default graph collation; ``None`` selects the default.
    """

    def __init__(
        self,
        dataset,
        batch_size: int = 1,
        shuffle: bool = False,
        **kwargs,
    ):
        # Passing ``collate_fn`` explicitly next to ``**kwargs`` would raise
        # "got multiple values for keyword argument" whenever a caller
        # supplies their own, so the default is filled in only when absent.
        if kwargs.get("collate_fn") is None:
            kwargs["collate_fn"] = self.collate
        super().__init__(dataset, batch_size, shuffle, **kwargs)

    def collate(self, data_list):
        """Merge a list of samples into one ``Batch`` object."""
        return Batch.from_data_list(data_list)

class GraphDataModule(pl.LightningDataModule):
    """Lightning data module serving pre-built train/val/test graph datasets.

    Args:
        train_list: Dataset used for training.
        val_list: Dataset used for validation.
        test_list: Dataset used for testing.
        batch_size: Batch size shared by all three loaders.
        num_workers: Number of loader worker processes for each loader.
            Defaults to 4, the value that was previously hard-coded.
    """

    def __init__(self, train_list, val_list, test_list, batch_size=1, num_workers=4):
        super().__init__()
        self.train_list = train_list
        self.val_list = val_list
        self.test_list = test_list

        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_loader(self, data, shuffle):
        """Build one loader over ``data`` with the module's shared settings."""
        return CustomGraphDataset(
            data,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def setup(self, stage=None):
        """Create the three loaders; ``stage`` is accepted but not used."""
        # Only the training loader shuffles; val/test keep dataset order.
        self.train_data = self._make_loader(self.train_list, shuffle=True)
        self.val_data = self._make_loader(self.val_list, shuffle=False)
        self.test_data = self._make_loader(self.test_list, shuffle=False)

    def train_dataloader(self):
        """Return the training loader created in ``setup``."""
        return self.train_data

    def val_dataloader(self):
        """Return the validation loader created in ``setup``."""
        return self.val_data

    def test_dataloader(self):
        """Return the test loader created in ``setup``."""
        return self.test_data


def get_lightning_dataset(dataset_name, config):
    """Build the named dataset and wrap its splits in a ``GraphDataModule``.

    The splits come from ``get_dataset(dataset_name, config)``; the batch
    size is read from ``config.batch_size``.
    """
    train_split, val_split, test_split = get_dataset(dataset_name, config)
    data_module = GraphDataModule(
        train_split,
        val_split,
        test_split,
        batch_size=config.batch_size,
    )
    return data_module

def _sized_splits(dataset, config):
    """Generate train/val/test splits from ``dataset`` using sizes in ``config``.

    Calls ``dataset.makedata(size, graph_size)`` once per split, in train,
    val, test order, reading ``config.<split>_size`` and
    ``config.<split>_graph_size``.
    """
    train = dataset.makedata(config.train_size, config.train_graph_size)
    val = dataset.makedata(config.val_size, config.val_graph_size)
    test = dataset.makedata(config.test_size, config.test_graph_size)
    return train, val, test


def _file_splits(dataset, train_path, test_path, num_train=800):
    """Load file-backed splits: train/val are slices of the training file.

    The first ``num_train`` items of ``makedata(train_path)`` form the
    training split and the remainder the validation split; the test split
    is ``makedata(test_path)``.
    """
    # NOTE(review): the training file is loaded twice, once per slice. If
    # makedata is deterministic it could be loaded once and sliced twice;
    # kept as two calls here so behaviour is unchanged - confirm.
    train = dataset.makedata(train_path)[:num_train]
    val = dataset.makedata(train_path)[num_train:]
    test = dataset.makedata(test_path)
    return train, val, test


def get_dataset(dataset_name, config):
    """Return the ``(train, val, test)`` splits for ``dataset_name``.

    ``seed_everything(1)`` is called before any data is built, for every
    supported dataset.

    Args:
        dataset_name: One of "prefix", "path-finding", "distance",
            "majority-subtree", "majority-tree", "topk", "broadcast",
            "broadcast-k", "voc-superpixel", "Limits-1", "Limits-2", "LCC",
            "4-Cycles", "Skip-Circles" or "Triangles".
        config: Object providing the attributes the chosen dataset reads:
            ``train_size``/``val_size``/``test_size`` and the matching
            ``*_graph_size`` for the generated datasets, ``mod_k`` for
            "prefix", "distance", "topk" and "broadcast-k", and ``datadir``
            for "voc-superpixel".

    Returns:
        A tuple ``(train, val, test)``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported name. (Previously
            this surfaced as an ``UnboundLocalError`` at the return.)
    """

    def path_finding():
        print("path-finding")
        return _sized_splits(Trees(), config)

    def voc_superpixel():
        train = VOCSuperpixels(root=config.datadir, split='train')
        val = VOCSuperpixels(root=config.datadir, split='val')
        test = VOCSuperpixels(root=config.datadir, split='test')
        return train, val, test

    def limits(dataset):
        # Training uses size=10; val and test use makedata's defaults.
        train = dataset.makedata(size=10)
        val = dataset.makedata()
        test = dataset.makedata()
        return train, val, test

    def four_cycles():
        dataset = FourCycles()
        train = dataset.makedata(size=800, p=4, vary_sizes=False)
        val = dataset.makedata(size=200, p=4)
        test = dataset.makedata(size=200, p=4)
        return train, val, test

    def skip_circles():
        dataset = SkipCircles()
        train = dataset.makedata()
        val = dataset.makedata()
        test = dataset.makedata()
        return train, val, test

    # Every entry is a zero-argument callable, so a dataset class is only
    # instantiated for the requested name and only after seeding.
    builders = {
        "prefix": lambda: _sized_splits(PrefixSumK(k=config.mod_k), config),
        "path-finding": path_finding,
        "distance": lambda: _sized_splits(DistanceK(k=config.mod_k), config),
        "majority-subtree": lambda: _sized_splits(MajoritySubTree(), config),
        "majority-tree": lambda: _sized_splits(MajorityTree(), config),
        "topk": lambda: _sized_splits(TopK(k=config.mod_k), config),
        "broadcast": lambda: _sized_splits(Broadcast(), config),
        "broadcast-k": lambda: _sized_splits(BroadcastK(k=config.mod_k), config),
        "voc-superpixel": voc_superpixel,
        # Note: the following expressivity datasets often consist of just a
        # single graph, so the train, val and test datasets are the same.
        "Limits-1": lambda: limits(LimitsOne()),
        "Limits-2": lambda: limits(LimitsTwo()),
        "LCC": lambda: _file_splits(
            LCC(),
            'flood_echo/datasets/LCC_dataset/LCC_train.txt',
            'flood_echo/datasets/LCC_dataset/LCC_test.txt',
        ),
        "4-Cycles": four_cycles,
        "Skip-Circles": skip_circles,
        "Triangles": lambda: _file_splits(
            Triangles(),
            'flood_echo/datasets/triangles_dataset/triangle_train.txt',
            'flood_echo/datasets/triangles_dataset/triangle_test.txt',
        ),
    }

    if dataset_name not in builders:
        supported = ", ".join(builders)
        raise ValueError(
            f"Unknown dataset name {dataset_name!r}; expected one of: {supported}"
        )

    seed_everything(1)
    return builders[dataset_name]()





