import numpy as np
import jax.numpy as jnp
from jax.tree_util import tree_map
from torch.utils.data import Dataset, DataLoader, default_collate


class HierarchyDatasetGenerator():
    """
    Generate a dataset from the root-to-leaf walks of a complete binary tree.

    Nodes are numbered in heap order: node ``i`` has children ``2*i+1`` and
    ``2*i+2``. Each leaf yields one sample. Its input is a one-hot vector
    identifying the leaf. Its target is the 0/1 indicator vector of the nodes
    on the path from the root to that leaf.

    Args:
        include_control (bool): Stored on the instance but currently not used
            by ``create_dataset`` (control-item support is not implemented).
        include_headnode (bool): If True, keep the root node as a target
            column; otherwise it is dropped (the root is on every walk).
        tree_depth (int): Number of levels in the tree; must be >= 1.
        include_bias_input (bool): If True, prepend a constant column of ones
            to the inputs.

    Attributes:
        include_control (bool): See Args.
        include_headnode (bool): See Args.
        include_bias_input (bool): See Args.
        tree_depth (int): See Args.
        num_nodes (int): Total number of nodes, ``2**tree_depth - 1``.
        num_endnodes (int): Number of leaves, ``2**(tree_depth - 1)``.

    Methods:
        create_dataset(): Build the ``(inputs, targets)`` arrays.
        binary_tree_adj_matrix(): Build the symmetric adjacency matrix of the tree.
        depth_first_search(node, visited, adj_matrix): Collect every walk from
            ``node`` down to a leaf.

    Raises:
        ValueError: If ``tree_depth`` is less than 1.
    """

    def __init__(self,
                 include_control: bool = False,
                 include_headnode: bool = False,
                 tree_depth: int = 4,
                 include_bias_input: bool = False):
        """
        Args:
            include_control (bool): stored but currently unused.
            include_headnode (bool): keep the root column in the targets.
            tree_depth (int): number of tree levels, must be >= 1.
            include_bias_input (bool): prepend a column of ones to the inputs.

        Raises:
            ValueError: if ``tree_depth`` is less than 1.
        """
        # A depth below 1 gives an empty (or fractional-size) tree that later
        # fails deep inside create_dataset; reject it up front instead.
        if tree_depth < 1:
            raise ValueError(f"tree_depth must be >= 1, got {tree_depth}")
        self.include_control = include_control
        self.include_headnode = include_headnode
        self.include_bias_input = include_bias_input
        self.tree_depth = tree_depth
        self.num_nodes = 2 ** self.tree_depth - 1
        self.num_endnodes = 2 ** (self.tree_depth - 1)

    def create_dataset(self):
        """
        Build the inputs and targets for every leaf of the tree.

        Returns:
            tuple: ``(inputs, targets)``.

            - ``inputs`` has one row per leaf. It is an identity matrix,
              preceded by a column of ones when ``include_bias_input`` is set.
            - ``targets`` has one row per leaf, holding the 0/1 path indicator
              over the tree nodes. The root column is omitted unless
              ``include_headnode`` is set.

            Rows of both arrays are ordered by ascending leaf index.
        """
        adj_matrix = self.binary_tree_adj_matrix()
        visited = np.zeros(self.num_nodes)
        # Start the search at the root (node 0).
        walk_vectors = self.depth_first_search(0, visited, adj_matrix)
        # One column per walk -> shape (num_nodes, num_endnodes).
        walk_matrix = np.column_stack(walk_vectors)
        if not self.include_headnode:
            # The root is 1 in every walk, so its row carries no information.
            walk_matrix = walk_matrix[1:, :]

        # The walks are the targets y; the inputs x are one-hot leaf codes.
        inputs = np.eye(self.num_endnodes)
        if self.include_bias_input:
            inputs = np.column_stack((np.ones(self.num_endnodes), inputs))

        return inputs, walk_matrix.T

    def binary_tree_adj_matrix(self) -> np.ndarray:
        """Generate the adjacency matrix of the binary tree.

        Returns:
            np.ndarray: symmetric ``(num_nodes, num_nodes)`` float matrix with a
            1 wherever two nodes are joined by a parent/child edge.
        """
        adj_matrix = np.zeros((self.num_nodes, self.num_nodes))
        # Every non-root node has exactly one parent, (child - 1) // 2, so
        # linking each child to its parent visits every edge exactly once.
        for child in range(1, self.num_nodes):
            parent = (child - 1) // 2
            adj_matrix[parent, child] = 1
            adj_matrix[child, parent] = 1
        return adj_matrix

    def depth_first_search(self, node, visited, adj_matrix):
        """Collect every walk from ``node`` down to a leaf using depth-first search.

        Args:
            node (int): current node.
            visited (np.array): 0/1 flags of the nodes on the current path, one
                entry per tree node. Modified during the search but restored
                before returning.
            adj_matrix (np.array): adjacency matrix of the tree.

        Returns:
            list: one copy of ``visited`` per leaf reached, i.e. the indicator
            vector of the path that ends at that leaf. Leaves are reached in
            ascending neighbour-index order.
        """
        visited[node] = 1
        # Neighbours not already on the path. In the tree built by
        # binary_tree_adj_matrix these are exactly the node's children.
        # Computing them once is safe because each recursive call restores
        # ``visited`` before returning.
        next_nodes = np.flatnonzero(
            (np.asarray(adj_matrix[node]) == 1) & (np.asarray(visited) == 0)
        )

        if next_nodes.size == 0:
            # Nothing left to visit: this is a leaf, record the path.
            walks = [visited.copy()]
        else:
            walks = []
            for child in next_nodes:
                walks += self.depth_first_search(child, visited, adj_matrix)

        visited[node] = 0  # backtrack so the caller sees its original path
        return walks
    
class HierarchyDataset(Dataset):
    """
    A map-style ``torch.utils.data.Dataset`` over paired input/target arrays.

    Sample ``i`` is ``(input_array[0][i], input_array[1][i])``. If a
    ``transform`` is given, it is applied separately to the input and to the
    target of each sample.

    Args:
        input_array (tuple): ``(inputs, targets)``. Both must expose ``.shape``
            and have the same length along the first axis.
        transform (callable, optional): Function applied to both the input and
            the target of every sample. Defaults to None (no transform).

    Raises:
        ValueError: If inputs and targets have different first-axis lengths.

    Methods:
        __getitem__(index): Return ``(x, y)`` for the given index.
        __len__(): Return the number of samples.
    """

    def __init__(self, input_array: tuple, transform: callable = None):
        inputs, targets = input_array[0], input_array[1]
        # Use an explicit check rather than ``assert``, which is stripped
        # when Python runs with -O.
        if inputs.shape[0] != targets.shape[0]:
            raise ValueError(
                "inputs and targets must have the same number of samples, got "
                f"{inputs.shape[0]} and {targets.shape[0]}"
            )
        self.input_array = input_array
        self.transform = transform

    def __getitem__(self, index: int) -> tuple:
        x = self.input_array[0][index]
        y = self.input_array[1][index]
        # Compare against None so that callables which happen to be falsy
        # are still applied.
        if self.transform is not None:
            x, y = self.transform(x), self.transform(y)
        return x, y

    def __len__(self) -> int:
        return self.input_array[0].shape[0]

class Cast(object):
    """
    Callable transform that converts its argument to a float32 NumPy array.

    Args:
        input: Array-like value to convert.
    Returns:
        numpy.ndarray: A new float32 array built from ``input``.
    """
    def __call__(self, input):
        target_dtype = jnp.float32
        converted = np.array(input, dtype=target_dtype)
        return converted
  
def numpy_collate(batch):
    """
    Collate a list of samples into NumPy arrays.

    The samples are first combined by ``torch.utils.data.default_collate``.
    Every leaf of the resulting structure is then converted to a NumPy array
    with ``np.asarray``.

    Args:
        batch (list): The samples to collate.
    Returns:
        The collated structure, with NumPy arrays at the leaves.
    """
    collated = default_collate(batch)
    return tree_map(np.asarray, collated)

class NumpyLoader(DataLoader):
    """
    A ``DataLoader`` whose batches are collated into NumPy arrays.

    It behaves like ``torch.utils.data.DataLoader`` but always uses
    ``numpy_collate`` as its ``collate_fn``.

    Args:
        dataset: The dataset to load the data from.
        batch_size (int, optional): The batch size. Defaults to 1.
        shuffle (bool, optional): Whether to shuffle the data. Defaults to False.
        sampler (Sampler, optional): The sampler used for sampling data. Defaults to None.
        batch_sampler (Sampler, optional): The batch sampler used for sampling batches of data. Defaults to None.
        num_workers (int, optional): The number of worker processes. Defaults to 0.
        pin_memory (bool, optional): Whether to pin memory for faster data transfer. Defaults to False.
        drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to False.
        timeout (float, optional): The timeout value for collecting a batch from workers. Defaults to 0.
        worker_init_fn (Callable, optional): A function to initialize worker processes. Defaults to None.
        **kwargs: Any further ``DataLoader`` keyword arguments (e.g.
            ``generator``, ``persistent_workers``, ``prefetch_factor``),
            forwarded unchanged. ``collate_fn`` may not be passed, since this
            class fixes it to ``numpy_collate``.
    """

    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=False,
                 sampler=None,
                 batch_sampler=None,
                 num_workers=0,
                 pin_memory=False,
                 drop_last=False,
                 timeout=0,
                 worker_init_fn=None,
                 **kwargs):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=numpy_collate,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            **kwargs)
