import os
from argparse import Namespace
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Optional, Text, Tuple, Union
import urllib.request

import torchvision
from torchvision import transforms as T
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from PIL import Image

from torch_geometric.datasets import Planetoid
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as gT

from equislt.data.datasets.mnist_fliprot.data_loader_mnist_fliprot import build_mnist_rot_loader as build_mnist_fliprot_loader
from equislt.data.datasets.mnist_rot.data_loader_mnist_rot import build_mnist_rot_loader

from equislt.data.graph_data_loader import GraphDataLoader, GRAPH_NUM_CLASSES

# Dataset names dispatched by ``prepare_data``; also used in its error message.
_PLANETOID_DATASETS = ('Cora', 'CiteSeer')
_MNIST_DATASETS = ('RotMNIST', 'FlipRotMNIST')
_GRAPH_DATASETS = ('MUTAG', 'PTC', 'PROTEINS', 'IMDBBINARY', 'NCI1')


class _RepeatedBatchLoader(object):
    """Iterable that yields the same pre-built batch ``steps_per_epoch`` times.

    The loader is its own iterator: ``iter()`` rewinds it, and ``next()`` can be
    called on it directly.
    """

    def __init__(self, batch: Any, steps_per_epoch: int = 100) -> None:
        self.batch = batch
        self.steps_per_epoch = steps_per_epoch
        self.index = 0

    def __iter__(self) -> "_RepeatedBatchLoader":
        # Each new ``for`` loop starts a fresh epoch.
        self.index = 0
        return self

    def __next__(self) -> Any:
        if self.index >= self.steps_per_epoch:
            raise StopIteration
        self.index += 1
        return self.batch

    def __len__(self) -> int:
        # Lets callers use ``len(train_loader)`` like with the other loaders.
        return self.steps_per_epoch


def _resolve_data_dir(data_dir: Optional[Union[str, Path]]) -> Path:
    """Return ``data_dir`` as a Path, defaulting to ``<grandparent of this file>/datasets``."""
    if data_dir is None:
        sandbox_dir = Path(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
        return sandbox_dir / "datasets"
    return Path(data_dir)


def _prepare_planetoid(name: str, data_dir: Path, device):
    """Build loaders for a Planetoid citation dataset (public split, normalized features)."""
    pyg_dataset = Planetoid(data_dir, name,
                            split='public',
                            transform=gT.NormalizeFeatures(),
                            )
    # Only the training batch is moved to ``device``; val/test batches come from a
    # plain DataLoader and are not moved here.
    data = pyg_dataset[0].to(device)
    train_loader = _RepeatedBatchLoader(data, steps_per_epoch=10)
    val_loader = test_loader = DataLoader(pyg_dataset, batch_size=1, num_workers=0)
    metadata = dict(num_features=pyg_dataset.num_features, num_classes=pyg_dataset.num_classes)
    return train_loader, val_loader, test_loader, metadata


def _prepare_mnist(build_loader: Callable, batch_size: int, num_workers: int):
    """Build train (augmented) / valid / test loaders with a rotated-MNIST builder."""
    train_loader, _, _ = build_loader('train', batch_size, num_workers,
                                      rot_interpol_augmentation=True,
                                      interpolation=3)
    val_loader, _, _ = build_loader('valid', batch_size, num_workers,
                                    rot_interpol_augmentation=False)
    test_loader, _, _ = build_loader('test', batch_size, num_workers,
                                     rot_interpol_augmentation=False)
    metadata = dict(feature_size=(1, 28, 28), num_classes=10)
    return train_loader, val_loader, test_loader, metadata


def _prepare_graph_benchmark(name: str, data_dir: Path, batch_size: int):
    """Build loaders for a graph-classification benchmark via GraphDataLoader (10 folds)."""
    config = Namespace(dataset_name=name, num_fold=10, batch_size=batch_size)
    train_loader = GraphDataLoader(data_dir, config, is_train=True)
    val_loader = GraphDataLoader(data_dir, config, is_train=False)
    # There is no separate test split here: the test loader is the validation loader.
    test_loader = val_loader
    metadata = dict(feature_size=train_loader.train_graphs[0].shape[0],
                    num_classes=GRAPH_NUM_CLASSES[name])
    return train_loader, val_loader, test_loader, metadata


def prepare_data(
    dataset: str,
    data_dir: Optional[Union[str, Path]] = None,
    batch_size: int = 1,
    num_workers: int = 0,
    device=None,
    download: bool = True,
) -> Tuple[Iterable[Any], Iterable[Any], Iterable[Any], Dict[Text, Any]]:
    """Build the train/val/test loaders and metadata for a named dataset.

    Args:
        dataset: One of ``'Cora'``, ``'CiteSeer'``, ``'RotMNIST'``,
            ``'FlipRotMNIST'``, ``'MUTAG'``, ``'PTC'``, ``'PROTEINS'``,
            ``'IMDBBINARY'`` or ``'NCI1'``.
        data_dir: Root directory for the data. Defaults to ``datasets/`` next to
            this file's parent package. It is passed to Planetoid and
            GraphDataLoader, but not to the MNIST builders.
        batch_size: Batch size for the MNIST and graph benchmarks. Planetoid
            ignores it.
        num_workers: Worker count for the MNIST builders only.
        device: Device the Planetoid training graph is moved to. Other datasets
            ignore it.
        download: Currently unused. It is kept for API compatibility.

    Returns:
        ``(train_loader, val_loader, test_loader, metadata)``.

        - Planetoid: ``train_loader`` yields the same full graph 10 times per
          epoch, and ``metadata`` has ``num_features`` and ``num_classes``.
        - MNIST variants and graph benchmarks: ``metadata`` has
          ``feature_size`` and ``num_classes``.

    Raises:
        NotImplementedError: If ``dataset`` is not a supported name.
    """
    data_dir = _resolve_data_dir(data_dir)

    if dataset in _PLANETOID_DATASETS:
        return _prepare_planetoid(dataset, data_dir, device)
    if dataset == 'RotMNIST':
        return _prepare_mnist(build_mnist_rot_loader, batch_size, num_workers)
    if dataset == 'FlipRotMNIST':
        return _prepare_mnist(build_mnist_fliprot_loader, batch_size, num_workers)
    if dataset in _GRAPH_DATASETS:
        return _prepare_graph_benchmark(dataset, data_dir, batch_size)

    supported = _PLANETOID_DATASETS + _MNIST_DATASETS + _GRAPH_DATASETS
    raise NotImplementedError(
        f"Unsupported dataset {dataset!r}; expected one of: {', '.join(supported)}"
    )
