import pickle, random
import torch
from copy import deepcopy
from typing import Tuple

import numpy as np
from torch_geometric.utils.convert import to_networkx
from graphxai_local.utils import Explanation

from torch_geometric.data import Dataset, data
from torch_geometric.loader import DataLoader

# from torch.utils.data.dataloader import DataLoader
from torch_geometric.utils import k_hop_subgraph
from sklearn.model_selection import train_test_split

from typing import List, Optional, Callable, Union, Any, Tuple

from graphxai_local.utils.explanation import EnclosingSubgraph
from graphxai_local.utils import to_networkx_conv
import graphxai_local.datasets as gxai_data


def get_dataset(dataset, download=False):
    """
    Base function for loading all datasets, looked up by name.

    Args:
        dataset (str): Name of dataset to retrieve. Only ``"MUTAG"`` is
            currently recognised.
        download (bool): Accepted for API compatibility; this function does
            not read it.

    Returns:
        The object produced by ``gxai_data.real_world.MUTAG.MUTAG()``.

    Raises:
        NameError: If ``dataset`` is not a recognised name.
    """
    is_known = dataset == "MUTAG"
    if not is_known:
        raise NameError("Cannot find dataset '{}'.".format(dataset))

    # MUTAG is the only dataset wired up here so far.
    return gxai_data.real_world.MUTAG.MUTAG()


class NodeDataset:
    """Base class for datasets consisting of a single graph with node-level tasks.

    This base class does not build any data itself. Subclasses are expected to set:
        - ``self.graph``: read here for ``x``, ``edge_index``, ``y``, ``num_nodes``
          and the ``train_mask`` / ``valid_mask`` / ``test_mask`` attributes.
        - ``self.explanations``: indexed by node id.
        - ``self.G``: only for the shape helpers; ``G.nodes[n]["shape"]`` is
          compared against 0 to decide shape membership.
        - ``self.fixed_train_mask`` / ``fixed_valid_mask`` / ``fixed_test_mask``:
          only for ``get_graph(use_fixed_split=True)``.
    """

    def __init__(
        self,
        name,
        num_hops: int,
        download: Optional[bool] = False,
        root: Optional[str] = None,
    ):
        """
        Args:
            name (str): Dataset name; used in error messages and as the default
                file stem in :meth:`dump`.
            num_hops (int): Number of hops used by :meth:`get_enclosing_subgraph`.
            download (bool, optional): Not used by this base class.
            root (str, optional): Not used by this base class.
        """
        self.name = name
        self.num_hops = num_hops

    @staticmethod
    def _index_mask(indices, num_nodes: int) -> torch.Tensor:
        """Return a boolean mask of length ``num_nodes`` that is True at ``indices``."""
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[torch.as_tensor(list(indices), dtype=torch.long)] = True
        return mask

    def get_graph(
        self,
        use_fixed_split: bool = True,
        split_sizes: Tuple = (0.7, 0.2, 0.1),
        stratify: bool = True,
        seed: int = None,
    ):
        """
        Gets graph object for training/validation/testing purposes
            - Sets masks within torch_geometric.data.Data object
        Args:
            use_fixed_split (bool, optional): If true, uses the fixed train/val/test
                mask defined by the child dataset class. (:default: :obj:`True`)
            split_sizes (tuple, length 2 or 3, optional): If length 2, index 0 is train
                size, index 1 is test size and the validation mask is empty. If
                length 3, index 2 becomes val size. Does not need to sum to 1, just
                needs to capture relative proportions of each split. Only relevant
                if `use_fixed_split = False`. (:default: :obj:`(0.7, 0.2, 0.1)`)
            stratify (bool, optional): If True, stratifies the splits by class label.
                Only relevant if `use_fixed_split = False`. (:default: :obj:`True`)
            seed (int, optional): Seed for splitting. (:default: :obj:`None`)

        :rtype: torch_geometric.data.Data
        Returns:
            graph: Data object containing masks over the splits (graph.train_mask,
                graph.valid_mask, graph.test_mask) and the full data for the graph.

        Raises:
            AssertionError: If ``split_sizes`` does not have length 2 or 3.
        """
        if use_fixed_split:
            # Use the static masks defined by the child dataset class.
            self.graph.train_mask = self.fixed_train_mask
            self.graph.valid_mask = self.fixed_valid_mask
            self.graph.test_mask = self.fixed_test_mask
            return self.graph

        split_sizes = tuple(split_sizes)
        assert len(split_sizes) in (
            2,
            3,
        ), "split_sizes must contain (train_size, test_size[, valid_size])"
        if len(split_sizes) == 2:
            split_sizes = split_sizes + (0,)

        total = sum(split_sizes)
        if total != 1:  # Normalize so only the relative proportions matter
            split_sizes = tuple(size / total for size in split_sizes)
        _, test_size, valid_size = split_sizes  # train size is implied by the rest

        num_nodes = self.graph.num_nodes
        labels = self.graph.y
        all_nodes = list(range(num_nodes))
        holdout_size = test_size + valid_size

        # First split: train vs. held-out (test + valid) nodes.
        if holdout_size > 0:
            train_idx, holdout_idx = train_test_split(
                all_nodes,
                test_size=holdout_size,
                random_state=seed,
                stratify=labels.tolist() if stratify else None,
            )
        else:
            train_idx, holdout_idx = all_nodes, []

        # Second split of the held-out nodes into valid vs. test. ``test_size`` in
        # train_test_split is the share that goes to the *second* output (test),
        # i.e. test / (test + valid) of the holdout.
        if test_size > 0 and valid_size > 0:
            valid_idx, test_idx = train_test_split(
                holdout_idx,
                test_size=test_size / holdout_size,
                random_state=seed,
                stratify=labels[holdout_idx].tolist() if stratify else None,
            )
        elif valid_size > 0:
            valid_idx, test_idx = holdout_idx, []
        else:
            valid_idx, test_idx = [], holdout_idx

        self.graph.train_mask = self._index_mask(train_idx, num_nodes)
        self.graph.valid_mask = self._index_mask(valid_idx, num_nodes)
        self.graph.test_mask = self._index_mask(test_idx, num_nodes)

        return self.graph

    def download(self):
        """TODO: Implement"""
        pass

    def get_enclosing_subgraph(self, node_idx: int):
        """
        Args:
            node_idx (int): Node index for which to get subgraph around

        Returns:
            EnclosingSubgraph: Built from the ``k_hop_subgraph`` result using
                ``self.num_hops`` hops over ``self.graph.edge_index``.
        """
        k_hop_tuple = k_hop_subgraph(
            node_idx, num_hops=self.num_hops, edge_index=self.graph.edge_index
        )
        return EnclosingSubgraph(*k_hop_tuple)

    def nodes_with_label(self, label=0, mask=None) -> torch.Tensor:
        """
        Get all nodes that are a certain label
        Args:
            label (int, optional): Label for which to find nodes.
                (:default: :obj:`0`)
            mask (torch.Tensor, optional): Boolean node mask; if given, only
                nodes where the mask is True are returned.

        Returns:
            torch.Tensor: Indices of nodes that are of the label
        """
        if mask is not None:
            return ((self.graph.y == label) & (mask)).nonzero(as_tuple=True)[0]
        return (self.graph.y == label).nonzero(as_tuple=True)[0]

    def choose_node_with_label(self, label=0, mask=None) -> "Tuple[int, Explanation]":
        """
        Choose a random node with a given label
        Args:
            label (int, optional): Label for which to find node.
                (:default: :obj:`0`)
            mask (torch.Tensor, optional): Boolean node mask restricting the choice.

        Returns:
            tuple(int, Explanation):
                int: Node index found
                Explanation: explanation corresponding to that node index
        """
        nodes = self.nodes_with_label(label=label, mask=mask)
        node_idx = random.choice(nodes).item()

        return node_idx, self.explanations[node_idx]

    def nodes_in_shape(self, inshape=True, mask=None):
        """
        Get a group of nodes by shape membership.

        Args:
            inshape (bool, optional): If the nodes are in a shape.
                :obj:`True` means that the nodes returned are in a shape.
                :obj:`False` means that the nodes are not in a shape.
            mask (torch.Tensor, optional): Boolean node mask; if given, only
                nodes where the mask is True are returned.

        Returns:
            torch.Tensor: All node indices for nodes in or not in a shape.
        """
        if inshape:
            in_group = lambda n: self.G.nodes[n]["shape"] > 0
        else:
            in_group = lambda n: self.G.nodes[n]["shape"] == 0

        # Keep the base predicate in its own name: rebinding the same name inside
        # a lambda would make the lambda call itself forever.
        selected = [
            n
            for n in self.G.nodes
            if in_group(n) and (mask is None or mask[n].item())
        ]
        return torch.tensor(selected).long()

    def choose_node_in_shape(self, inshape=True, mask=None) -> "Tuple[int, Explanation]":
        """
        Gets a random node by shape membership.

        Args:
            inshape (bool, optional): If the node is in a shape.
                :obj:`True` means that the node returned is in a shape.
                :obj:`False` means that the node is not in a shape.
            mask (torch.Tensor, optional): Boolean node mask restricting the choice.

        Returns:
            Tuple[int, Explanation]
                int: Node index found
                Explanation: Explanation corresponding to that node index
        """
        nodes = self.nodes_in_shape(inshape=inshape, mask=mask)
        node_idx = random.choice(nodes).item()

        return node_idx, self.explanations[node_idx]

    def choose_node(self, inshape=None, label=None, split=None):
        """
        Chooses random nodes in the graph. Has support for multiple logical
            indexing.

        Args:
            inshape (bool, optional): If the node is in a shape.
                :obj:`True` means that the node returned is in a shape.
                :obj:`False` means that the node is not in a shape.
                :obj:`None` means no shape constraint. (:default: :obj:`None`)
            label (int, optional): Label for which to find node; :obj:`None`
                means no label constraint. (:default: :obj:`None`)
            split (str, optional): One of ``"train"``, ``"val"`` (also
                ``"valid"``/``"validation"``) or ``"test"``, case-insensitive;
                restricts the choice to that split's mask on ``self.graph``.

        Returns:
            Tuple[int, Explanation]: Chosen node index and its explanation.

        Raises:
            KeyError: If ``split`` is not a recognised split name.
            AssertionError: If no node satisfies all the given constraints.
        """
        split = split.lower() if split is not None else None

        if split in ("validation", "valid", "val"):
            split = "val"

        # Look up only the requested mask so unused masks need not exist.
        split_to_attr = {
            "train": "train_mask",
            "val": "valid_mask",
            "test": "test_mask",
        }
        mask = None if split is None else getattr(self.graph, split_to_attr[split])

        if inshape is None and label is None:
            # No node-level constraint: any node, restricted to the split if given.
            if mask is None:
                to_choose = torch.arange(end=self.graph.num_nodes)
            else:
                to_choose = mask.nonzero(as_tuple=True)[0]

        elif inshape is None:
            to_choose = self.nodes_with_label(label=label, mask=mask)

        elif label is None:
            to_choose = self.nodes_in_shape(inshape=inshape, mask=mask)

        else:
            # Joint masking over shapes and labels:
            in_shape_set = set(
                self.nodes_in_shape(inshape=inshape, mask=mask).tolist()
            )
            t_label = self.nodes_with_label(label=label, mask=mask)
            to_choose = torch.as_tensor(
                [n for n in t_label.tolist() if n in in_shape_set], dtype=torch.long
            )

        assert_fmt = "Could not find a node in {} with inshape={}, label={}"
        assert to_choose.nelement() > 0, assert_fmt.format(self.name, inshape, label)

        node_idx = random.choice(to_choose).item()
        return node_idx, self.explanations[node_idx]

    def __len__(self) -> int:
        return 1  # There is always just one graph

    def dump(self, fname=None):
        """Serialize this dataset object with ``torch.save``.

        Args:
            fname (str, optional): Output path; defaults to ``<name>.pickle``.
        """
        fname = self.name + ".pickle" if fname is None else fname
        with open(fname, "wb") as f:  # close the handle even if saving fails
            torch.save(self, f)

    @property
    def x(self):
        return self.graph.x

    @property
    def edge_index(self):
        return self.graph.edge_index

    @property
    def y(self):
        return self.graph.y

    def __getitem__(self, idx):
        """Return ``(graph, explanations)``; only index 0 is valid."""
        assert idx == 0, "Dataset has only one graph"
        # Every other method reads ``self.explanations``; fall back to the
        # singular ``self.explanation`` in case a subclass sets that instead.
        if hasattr(self, "explanations"):
            return self.graph, self.explanations
        return self.graph, self.explanation


class GraphDataset:
    """Base class for datasets made of many graphs with graph-level labels.

    This class does not load data itself. ``__init__`` reads ``self.graphs``
    before setting it, so subclasses must populate ``self.graphs`` (indexable,
    each element with a ``y`` attribute and a ``.to(device)`` method) first.
    Most accessors also read ``self.explanations``, indexed in parallel with
    ``self.graphs``.
    """

    def __init__(self, name, split_sizes=(0.7, 0.2, 0.1), seed=None, device=None):
        """
        Args:
            name (str): Name of the dataset.
            split_sizes (tuple, optional): ``(train, test[, valid])`` sizes. The
                test + valid share is split off first, then divided between test
                and valid. A zero test or valid size leaves ``test_index`` or
                ``val_index`` as ``None``; a 2-tuple means no validation split.
                (:default: :obj:`(0.7, 0.2, 0.1)`)
            seed (int, optional): ``random_state`` passed to ``train_test_split``.
            device (optional): Device that graphs and labels are moved to.
        """
        self.name = name

        self.seed = seed
        self.device = device

        # Graph-level labels; used for stratified splitting and label sampling.
        self.Y = torch.tensor([self.graphs[i].y for i in range(len(self.graphs))])

        test_size = split_sizes[1]
        valid_size = split_sizes[2] if len(split_sizes) > 2 else 0
        holdout_size = test_size + valid_size
        all_index = torch.arange(start=0, end=len(self.graphs))

        self.test_index = None
        self.val_index = None

        if holdout_size > 0:
            # First split: train vs. held-out (test + valid) graphs.
            self.train_index, holdout_index = train_test_split(
                all_index,
                test_size=holdout_size,
                random_state=self.seed,
                shuffle=True,
                stratify=self.Y,
            )
            if test_size > 0 and valid_size > 0:
                # The second output of train_test_split becomes the val split.
                self.test_index, self.val_index = train_test_split(
                    holdout_index,
                    test_size=valid_size / holdout_size,
                    random_state=self.seed,
                    shuffle=True,
                    stratify=self.Y[holdout_index],
                )
            elif test_size > 0:
                self.test_index = holdout_index
            else:
                # Only a validation split was requested.
                self.val_index = holdout_index
        else:
            self.train_index = all_index

        self.Y = self.Y.to(self.device)

    def get_data_list(
        self,
        index,
    ):
        """Return ``(graphs moved to self.device, explanations)`` for ``index``."""
        data_list = [self.graphs[i].to(self.device) for i in index]
        exp_list = [self.explanations[i] for i in index]

        return data_list, exp_list

    def get_loader(self, index, batch_size=16, **kwargs):
        """
        Build a shuffling ``DataLoader`` over the graphs selected by ``index``.

        Each graph gets an ``exp_key`` attribute holding its position in the
        returned explanation list. Extra ``**kwargs`` are accepted but ignored.

        Returns:
            tuple: ``(loader, exp_list)``.
        """
        data_list, exp_list = self.get_data_list(index)

        for position, graph in enumerate(data_list):
            graph.exp_key = [position]

        loader = DataLoader(data_list, batch_size=batch_size, shuffle=True)

        return loader, exp_list

    def get_train_loader(self, batch_size=16):
        return self.get_loader(index=self.train_index, batch_size=batch_size)

    def get_train_list(self):
        return self.get_data_list(index=self.train_index)

    def get_test_loader(self):
        assert self.test_index is not None, "test_index is None"
        return self.get_loader(index=self.test_index, batch_size=1)

    def get_test_list(self):
        assert self.test_index is not None, "test_index is None"
        return self.get_data_list(index=self.test_index)

    def get_val_loader(self):
        # Check val_index itself: a None index would otherwise fail later with
        # an unrelated TypeError while iterating it.
        assert self.val_index is not None, "val_index is None"
        return self.get_loader(index=self.val_index, batch_size=1)

    def get_val_list(self):
        assert self.val_index is not None, "val_index is None"
        return self.get_data_list(index=self.val_index)

    def _sample_with_label(self, index, label):
        """Return ``(graph, explanation)`` for a random graph in ``index`` with ``label``."""
        positions = (self.Y[index] == label).nonzero(as_tuple=True)[0]
        assert positions.numel() > 0, "No graph with label {} in this split".format(
            label
        )
        position = positions[
            torch.randint(low=0, high=positions.shape[0], size=(1,))
        ]
        chosen = int(index[position.item()])

        return self.graphs[chosen], self.explanations[chosen]

    def get_train_w_label(self, label):
        """Random ``(graph, explanation)`` from the train split with the given label."""
        return self._sample_with_label(self.train_index, label)

    def get_test_w_label(self, label):
        """Random ``(graph, explanation)`` from the test split with the given label."""
        assert self.test_index is not None, "test_index is None"
        return self._sample_with_label(self.test_index, label)

    def get_graph_as_networkx(self, graph_idx):
        """
        Get a given graph as networkx graph
        """

        g = self.graphs[graph_idx]
        return to_networkx_conv(g, node_attrs=["x"], to_undirected=True)

    def download(self):
        pass

    def __getitem__(self, idx):
        """Return ``(graph, explanation)``, or just the graph if there are no explanations."""
        try:
            return self.graphs[idx], self.explanations[idx]
        except AttributeError:
            return self.graphs[idx]

    def __len__(self):
        return len(self.graphs)
