import os
import pickle
import numpy as np
import torch
import torch_geometric
from torch_geometric.data import DataLoader, InMemoryDataset
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import Compose
from trainable_scattering.transforms.transforms import (
    Eccentricity,
    ClusteringCoefficient,
)
from trainable_scattering.models.fast_scatter import (
    FastScatterTransform,
    FastScatterTransformSort,
)
from joblib import Parallel, delayed

# Dataset names that get_dataset() loads through torch_geometric's TUDataset.
# "PROTIENS" (sic) and "PTC" are accepted aliases: get_dataset() maps them to
# "PROTEINS" and "PTC_MR" before loading.
TU_DATASETS = [
    "NCI1", "NCI109", "DD", "PROTIENS", "PROTEINS", "MUTAG", "PTC",
    "ENZYMES", "REDDIT-BINARY", "REDDIT-MULTI-12K", "IMDB-BINARY",
    "IMDB-MULTI", "COLLAB", "REDDIT-MULTI-5K",
]


class CaspDataset(InMemoryDataset):
    """CASP 11/12/13 decoy graphs as a single in-memory PyG dataset.

    :meth:`process` reads ``casp_11_decoys.pkl``, ``casp_12_decoys.pkl`` and
    ``casp_13_decoys.pkl`` from :attr:`raw_dir`, concatenates them in that
    order and records how many graphs came from each file in ``self.splits``;
    :meth:`get_split` cuts the dataset at those boundaries. Each graph's
    ``gdt`` attribute is exposed as the target ``y``.
    """

    def __init__(
        self,
        root,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        raw_dir=None,
        targets=False,
    ):
        """
        Args:
            root: dataset root; processed data lives in ``root/processed``.
            transform, pre_transform, pre_filter: forwarded to
                ``InMemoryDataset``. NOTE: :meth:`process` does not apply
                ``pre_filter``.
            raw_dir: directory holding the input pickles; falls back to
                ``root/raw`` when not given.
            targets: stored as ``self.targets`` for subclasses' ``process``.
        """
        # Set before super().__init__(): the raw_dir property and process()
        # read these, and the base constructor presumably triggers processing.
        self.supplied_raw_dir = raw_dir
        self.targets = targets
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices, self.splits = torch.load(self.processed_paths[0])
        # Use the GDT score as the target ``y``.
        self.data.y = self.data.gdt
        self.slices["y"] = self.slices["gdt"]
        if self.data.x.dim() == 1:
            # Only one feature: reshape to N x 1 and cast to float.
            self.data.x = self.data.x[:, None].float()

    def process(self):
        """Load the three CASP decoy pickles, collate them and cache to disk."""
        parts = []
        for casp in (11, 12, 13):
            path = os.path.join(self.raw_dir, "casp_%d_decoys.pkl" % casp)
            with open(path, "rb") as f:
                # NOTE: pickle can run arbitrary code; only load trusted files.
                parts.append(pickle.load(f))
        self.splits = [len(part) for part in parts]
        data_list = [graph for part in parts for graph in part]

        # all elements should be tensors; as_tensor avoids a copy (and torch's
        # copy-construct warning) when x already is a tensor
        for graph in data_list:
            graph.x = torch.as_tensor(graph.x)
        self.data, self.slices = self.collate(data_list)
        if self.pre_transform is not None:
            data_list = [self.get(idx) for idx in range(len(self))]
            data_list = [self.pre_transform(graph) for graph in data_list]
            self.data, self.slices = self.collate(data_list)
        torch.save((self.data, self.slices, self.splits), self.processed_paths[0])

    def get_split(self):
        """Cut the dataset at the boundaries stored in ``self.splits``.

        Returns:
            Three consecutive slices: the first ``splits[0]`` graphs, the
            next ``splits[1]`` graphs, and all remaining graphs. Unlike a
            negative slice, the last part is correctly empty when
            ``splits[2] == 0``.
        """
        n_first, n_second = self.splits[0], self.splits[1]
        return (
            self[:n_first],
            self[n_first : n_first + n_second],
            self[n_first + n_second :],
        )

    @property
    def raw_dir(self):
        """Directory holding the input pickles (``root/raw`` unless overridden)."""
        if self.supplied_raw_dir:
            return self.supplied_raw_dir
        return os.path.join(self.root, "raw")

    @property
    def processed_dir(self):
        """Directory where the collated dataset is cached."""
        return os.path.join(self.root, "processed")

    @property
    def num_node_labels(self):
        """Width of the widest trailing block of ``x`` columns that is one-hot (0 if none)."""
        if self.data.x is None:
            return 0
        for i in range(self.data.x.size(1)):
            x = self.data.x[:, i:]
            if ((x == 0) | (x == 1)).all() and (x.sum(dim=1) == 1).all():
                return self.data.x.size(1) - i
        return 0

    @property
    def num_node_features(self):
        """Number of ``x`` columns that are not part of the trailing one-hot block."""
        if self.data.x is None:
            return 0
        return self.data.x.size(1) - self.num_node_labels

    @property
    def processed_file_names(self):
        return "data.pt"

    def __repr__(self):
        return "{}({})".format("CASP", len(self))


class CaspDatasetV2(CaspDataset):
    """CASP12-only variant of :class:`CaspDataset` with a shuffled random split.

    :meth:`process` reads ``casp_12_targets.pkl`` when ``targets`` is true and
    ``casp_12_decoys.pkl`` otherwise, shuffles the graphs and stores split
    sizes ``[n - 2 * int(n * 0.1), int(n * 0.1), int(n * 0.1)]`` (roughly
    80/10/10), which the inherited ``get_split`` uses.
    """

    # Seed for the shuffle in process(); fixed so that re-processing yields the
    # same split. Set to None for a non-reproducible (OS-entropy) shuffle.
    shuffle_seed = 0

    def process(self):
        """Load the CASP12 pickle, shuffle it, collate it and cache it to disk."""
        fname = "casp_12_targets.pkl" if self.targets else "casp_12_decoys.pkl"
        with open(os.path.join(self.raw_dir, fname), "rb") as f:
            # NOTE: pickle can run arbitrary code; only load trusted files.
            data_list = pickle.load(f)
        n = len(data_list)
        n_holdout = int(n * 0.1)
        self.splits = [n - 2 * n_holdout, n_holdout, n_holdout]
        # Local RNG: reproducible, and leaves NumPy's global random state alone.
        np.random.RandomState(self.shuffle_seed).shuffle(data_list)
        # all elements should be tensors; as_tensor avoids a copy (and torch's
        # copy-construct warning) when x already is a tensor
        for graph in data_list:
            graph.x = torch.as_tensor(graph.x)
        self.data, self.slices = self.collate(data_list)
        if self.pre_transform is not None:
            data_list = [self.get(idx) for idx in range(len(self))]
            data_list = [self.pre_transform(graph) for graph in data_list]
            self.data, self.slices = self.collate(data_list)
        torch.save((self.data, self.slices, self.splits), self.processed_paths[0])


def get_transform(name, device):
    """Look up the graph pre-transform called ``name``.

    Args:
        name: "eccentricity", "clustering_coefficient", "scatter",
            "scatter_cat", "fast_scatter", "fast_scatter_cat",
            "fast_scatter_sort" or "none".
        device: passed to the fast-scatter transforms.

    Returns:
        A freshly built transform, or ``None`` for ``"none"``.

    Raises:
        NotImplementedError: if ``name`` is not one of the names above.
    """
    # Builders are lambdas so that only the requested transform is created.
    builders = (
        ("eccentricity", lambda: Eccentricity()),
        ("clustering_coefficient", lambda: ClusteringCoefficient()),
        (
            "scatter",
            lambda: Compose([Eccentricity(), ClusteringCoefficient(cat=True)]),
        ),
        (
            "scatter_cat",
            lambda: Compose(
                [Eccentricity(cat=True), ClusteringCoefficient(cat=True)]
            ),
        ),
        (
            "fast_scatter",
            lambda: Compose(
                [
                    Eccentricity(),
                    ClusteringCoefficient(cat=True),
                    FastScatterTransform(device),
                ]
            ),
        ),
        (
            "fast_scatter_cat",
            lambda: Compose(
                [
                    Eccentricity(cat=True),
                    ClusteringCoefficient(cat=True),
                    FastScatterTransform(device),
                ]
            ),
        ),
        (
            "fast_scatter_sort",
            lambda: Compose(
                [
                    Eccentricity(),
                    ClusteringCoefficient(cat=True),
                    FastScatterTransformSort(device),
                ]
            ),
        ),
        ("none", lambda: None),
    )
    for key, build in builders:
        if name == key:
            return build()
    raise NotImplementedError("Unknown transform %s" % name)


def old_split_dataset(
    dataset, splits=(0.8, 0.1), test_seed=0, val_seed=0, cv=True, shuffle=True
):
    """Legacy fractional train/val/test split (see ``split_dataset`` for the fold-based one).

    Only the cross-validation mode (``cv=True``) is implemented. The indices
    are shuffled once with seed 42 (when ``shuffle``). The ``test_seed``-th
    window of width ``(1 - splits[0] - splits[1]) * n`` becomes the test set.
    The ``val_seed``-th window of width ``splits[1] * len(rest)`` over the
    remaining indices becomes the validation set.

    TODO: (anonymous) sometimes produces unexpected splits due to rounding errors
    """

    def window_mask(size, frac, k):
        # True for positions in [floor(frac*size*k), floor(frac*size*(k+1))).
        lo = int(np.floor(frac * size * k))
        hi = int(np.floor(frac * size * (k + 1)))
        pos = np.arange(size)
        return (pos >= lo) & (pos < hi)

    n = len(dataset)
    if not cv:
        raise NotImplementedError
    order = np.arange(n)
    if shuffle:
        np.random.seed(42)
        np.random.shuffle(order)

    in_test = window_mask(n, 1 - splits[0] - splits[1], test_seed)
    rest, test_ind = order[~in_test], order[in_test]

    in_val = window_mask(len(rest), splits[1], val_seed)
    train_ind, val_ind = rest[~in_val], rest[in_val]

    assert len(val_ind) > 0
    assert len(train_ind) > 0
    assert len(test_ind) > 0
    return dataset[list(train_ind)], dataset[list(val_ind)], dataset[list(test_ind)]


def _roll_folds(folds, shift):
    """Rotate the list ``folds`` right by ``shift`` places (same order as ``np.roll``)."""
    if not folds:
        return []
    cut = len(folds) - shift % len(folds)
    return folds[cut:] + folds[:cut]


def _concat_folds(folds):
    """Concatenate a list of index folds; selecting no folds gives an empty int array."""
    if not folds:
        return np.array([], dtype=int)
    return np.concatenate(folds)


def split_dataset(
    dataset,
    splits=(8, 1, 1),
    test_seed=0,
    val_seed=0,
    shuffle=True,
    rand_test_splits=False,
    rand_val_splits=False,
):
    """Split ``dataset`` into non-overlapping train / val / test subsets.

    The (optionally shuffled) indices are cut into 10 folds with
    ``np.array_split``. ``splits`` gives how many folds go to train, val and
    test respectively.

    Args:
        dataset: anything supporting ``len`` and indexing by a list of ints
            (e.g. a PyG dataset or a NumPy array).
        splits: ``(train, val, test)`` fold counts; must sum to 10. Any entry
            may be 0, giving an empty subset.
        test_seed: selects the test folds. Without ``rand_test_splits`` the
            folds are rotated by ``test_seed * splits[2]``, so consecutive
            seeds put every example in the test set exactly once. With it,
            ``test_seed`` seeds a random fold shuffle.
        val_seed: same as ``test_seed``, but picks the validation folds among
            the folds left after removing the test folds.
        shuffle: shuffle the indices (NumPy global seed 42) before folding.
        rand_test_splits: pick test folds randomly instead of by rotation.
        rand_val_splits: pick validation folds randomly instead of by rotation.

    Returns:
        ``(train, val, test)`` subsets of ``dataset``.
    """
    assert sum(splits) == 10
    n_train, n_val, n_test = splits
    n = len(dataset)
    rand_ind = np.arange(n)
    if shuffle:
        np.random.seed(42)
        np.random.shuffle(rand_ind)
    # Partition the indices into 10 folds, then distribute these to train, val
    # and test. Folds differ in size by one when n % 10 != 0. Keep them in a
    # plain list: np.roll would coerce the ragged list to an array, which
    # recent NumPy rejects.
    folds = list(np.array_split(rand_ind, 10))
    if rand_test_splits:
        np.random.seed(test_seed)
        np.random.shuffle(folds)
    else:
        # preserves the behavior of fully covering the dataset with test
        # e.g. for 7, 1, 2, with test seeds [0 ... 4] each datapoint will occur
        # in the test set exactly once
        folds = _roll_folds(folds, test_seed * n_test)

    # partition out the test folds first, then split the rest into train/val
    test_ind = _concat_folds(folds[:n_test])
    folds = folds[n_test:]
    if rand_val_splits:
        np.random.seed(val_seed)
        # legacy np.random.shuffle has no ``axis`` argument; it shuffles the
        # list of folds in place
        np.random.shuffle(folds)
    else:
        folds = _roll_folds(folds, val_seed * n_val)

    train_ind = _concat_folds(folds[:n_train])
    # slice from the front: ``folds[-n_val:]`` would take every fold when n_val == 0
    val_ind = _concat_folds(folds[n_train:])

    train_set, val_set, test_set = set(train_ind), set(val_ind), set(test_ind)
    assert (train_set | val_set | test_set) == set(np.arange(n))
    assert not (train_set & val_set or train_set & test_set or val_set & test_set)
    return dataset[list(train_ind)], dataset[list(val_ind)], dataset[list(test_ind)]

def test_split_dataset():
    """Print the train/val/test splits of a toy 10-element array for manual inspection."""
    toy = np.arange(10)
    cases = [
        (fold_counts, t_seed, v_seed)
        for fold_counts in [(2, 1, 7)]
        for t_seed in range(10)
        for v_seed in range(1)
    ]
    for fold_counts, t_seed, v_seed in cases:
        print(split_dataset(toy, fold_counts, t_seed, v_seed))



def _host_prefix():
    """Return the mount prefix for shared data paths on this machine.

    On host ``galloway`` the data lives under ``/orkney``; elsewhere the
    paths are used unchanged.
    """
    import socket

    return "/orkney" if socket.gethostname() == "galloway" else ""


def get_dataset(args, device):
    """Build the train/val/test splits and a training loader for ``args["dataset"]``.

    Args:
        args: configuration mapping. Reads ``"dataset"`` and, optionally,
            ``"transform"``. The transform name is passed to
            :func:`get_transform` and also names the cache sub-directory;
            it defaults to ``"none"``. The TU and MNIST datasets also read
            ``"splits"``, ``"test_seed"`` and ``"val_seed"``.
        device: forwarded to :func:`get_transform`.

    Returns:
        ``(train_ds, val_ds, test_ds, train_loader, dataset)``. Here
        ``train_loader`` shuffles ``train_ds`` in batches of 32 using
        8 workers.

    Raises:
        NotImplementedError: for an unknown dataset name.
    """
    transform_name = args["transform"] if "transform" in args else "none"
    transform = get_transform(transform_name, device)
    name = args["dataset"]

    if name in TU_DATASETS:
        # Munge exact names
        tu_name = {"PROTIENS": "PROTEINS", "PTC": "PTC_MR"}.get(name, name)
        dataset = TUDataset(
            root="%s/data/anonymous/tu/%s" % (_host_prefix(), transform_name),
            name=tu_name,
            pre_transform=transform,
            use_node_attr=True,
        )
        train_ds, val_ds, test_ds = split_dataset(
            dataset,
            args["splits"],
            test_seed=args["test_seed"],
            val_seed=args["val_seed"],
        )
    # TODO: fix this dataset
    elif name == "ogbg-molhiv":
        from ogb.graphproppred import PygGraphPropPredDataset

        dataset = PygGraphPropPredDataset(name=name)
        dataset.num_edge_attributes = 3
        split_idx = dataset.get_idx_split()
        train_ds = dataset[split_idx["train"]]
        val_ds = dataset[split_idx["valid"]]
        test_ds = dataset[split_idx["test"]]
    elif name == "MNISTSuperpixels":
        # was used without ever being imported (NameError)
        from torch_geometric.datasets import MNISTSuperpixels

        dataset = MNISTSuperpixels(
            root="/data/anonymous/pytorch_geometric_datasets/MNIST/%s" % transform_name,
            pre_transform=transform,
        )
        train_ds, val_ds, test_ds = split_dataset(
            dataset,
            args["splits"],
            test_seed=args["test_seed"],
            val_seed=args["val_seed"],
        )
    elif name == "CASP":
        dataset = CaspDataset(
            root="/data/anonymous/casp/%s" % transform_name,
            raw_dir="/home/egbert/data/protein_data/casp/processed_data4/",
            pre_transform=transform,
        )
        train_ds, val_ds, test_ds = dataset.get_split()
    elif name in ("CASP2", "CASP2_targets"):
        targets = name == "CASP2_targets"
        cache_dir = "casp2_targets" if targets else "casp2"
        prefix = _host_prefix()
        dataset = CaspDatasetV2(
            root="%s/data/anonymous/%s/%s" % (prefix, cache_dir, transform_name),
            raw_dir=prefix + "/home/egbert/data/protein_data/casp/processed_data4/",
            pre_transform=transform,
            targets=targets,
        )
        train_ds, val_ds, test_ds = dataset.get_split()
    else:
        raise NotImplementedError("Dataset %s not implemented" % name)

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=8)
    return train_ds, val_ds, test_ds, train_loader, dataset


if __name__ == "__main__":
    # Manual smoke check: print the splits produced for a toy dataset.
    test_split_dataset()
