import networkx as nx
import numpy as np
import torch
import os

from numpy.typing import NDArray
from typing import Tuple, Any, List
from torch import Tensor
from torch.utils.data import Dataset
#from utils.graph_dataset import GraphDatasetGenerator # TODO: remove


# * Utility functions *
def generate_and_store_dataset(
    data_output_file: str,
    gt_output_file: str,
    model: Any,
    standardize: bool
):
    """Generate a dataset with `model`, and store the resulting dataset and groundtruth.

    Parameters
    ----------
    data_output_file: str
        Path for storage of the dataset as `.npy` file.
    gt_output_file: str
        Path for storage of the groundtruth adjacency matrix as `.npy` file.
    model: BaseStructuralCausalModel
        Instance of `BaseStructuralCausalModel` generating the data and the groundtruth
        through its `sample()` method, which returns `(dataset, groundtruth)`.
    standardize: bool
        If True, rescale every column of the dataset so that its empirical standard
        deviation is one. Columns with zero standard deviation (constant columns) are
        left unscaled instead of being turned into NaN/inf.
    """
    dataset, groundtruth = model.sample()
    if standardize:
        marginal_std = np.std(dataset, axis=0)
        # A constant column has std 0; dividing by it would produce NaN/inf values.
        safe_std = np.where(marginal_std > 0, marginal_std, 1.0)
        # Out-of-place true division: yields a float array even when the sampled data
        # is integer-typed (in-place assignment would truncate), and leaves the array
        # returned by the model untouched.
        dataset = np.asarray(dataset) / safe_std
    np.save(data_output_file, dataset)
    np.save(gt_output_file, groundtruth)


def prepare_model_input(data: Tensor, target: Tensor, device: str, start_token: int) -> Tuple[Tensor, Tensor, Tensor]:
    """Move a batch to `device` as float tensors and build the shifted decoder input.

    Parameters
    ----------
    data: Tensor
        Model input batch.
    target: Tensor
        Target sequences; dimension 1 is treated as the sequence axis.
    device: str
        Device the tensors are moved to (e.g. "cpu").
    start_token: int
        Value written at position 0 of every shifted sequence to mark its start.

    Returns
    -------
    Tuple[Tensor, Tensor, Tensor]
        `(data, target, shifted_target)`, all converted to float. `shifted_target[:, k]`
        equals `target[:, k - 1]` for k >= 1 and `start_token` for k == 0, so the model
        learns to predict token n from token n-1.
    """
    data_f = data.to(device).float()
    target_f = target.to(device).float()

    # Rotate one step to the right along the sequence axis (torch.roll returns a new
    # tensor), then overwrite the element that wrapped around from the end with the
    # start-of-sequence marker.
    decoder_input = torch.roll(target_f, shifts=1, dims=1)
    decoder_input[:, 0] = start_token

    return data_f, target_f, decoder_input


# # TODO: remove if we use causally data generation flow
# def get_dataloaders(num_graphs, num_samples, batch_size, num_nodes):
#     train_set = GraphDatasetGenerator(num_graphs=num_graphs, num_samples=num_samples, num_nodes=num_nodes)
#     test_set = GraphDatasetGenerator(num_graphs=num_graphs, num_samples=num_samples, num_nodes=num_nodes)
#     train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)
#     test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)
#     return train_loader, test_loader



def max_edges_in_dag(num_nodes: int) -> int:
    """Compute the maximum number of edges allowed for a directed acyclic graph.

    A DAG on `num_nodes` nodes has at most ``num_nodes * (num_nodes - 1) / 2`` edges.

    Integer floor division is exact here because ``n * (n - 1)`` is always even.
    True division (``/``) would round through a float and return a wrong value
    once the product exceeds 2**53.
    """
    return int(num_nodes * (num_nodes - 1) // 2)


def topological_order(adjacency: NDArray) -> List[int]:
    """Return a topological order (sources first) of the DAG encoded by `adjacency`.

    `adjacency[i, j] != 0` means that node i is a parent of node j. The order is
    built by repeatedly removing a leaf, i.e. a node with no children among the
    nodes not yet removed. When several leaves are available, the lowest index is
    removed first. The reversed removal sequence is a valid topological order.

    Parameters
    ----------
    adjacency: NDArray
        Square adjacency matrix; only the non-zero pattern is used.

    Returns
    -------
    List[int]
        Node indices such that every parent appears before all of its children.

    Raises
    ------
    ValueError
        If `adjacency` is not a square matrix, or if the graph contains a cycle.
    """
    edges = np.asarray(adjacency) != 0  # edge presence only; weights are irrelevant
    if edges.ndim != 2 or edges.shape[0] != edges.shape[1]:
        raise ValueError(f"The adjacency matrix must be square, got shape {edges.shape}.")

    num_nodes = edges.shape[0]
    remaining = np.ones(num_nodes, dtype=bool)
    order = list()
    for _ in range(num_nodes):
        # Children must be recounted among the remaining nodes at every step: removing
        # a leaf turns its parents into candidate leaves.
        children_per_node = edges[:, remaining].sum(axis=1)
        leaves = np.flatnonzero(remaining & (children_per_node == 0))
        if leaves.size == 0:
            # Every remaining node has a remaining child, which implies a cycle.
            raise ValueError("The input adjacency matrix is not acyclic.")
        leaf = int(leaves[0])
        remaining[leaf] = False  # select node only once
        order.append(leaf)

    return order[::-1]  # source first



# * Utility classes *
class GraphDataset(Dataset):
    """Map-style dataset of (data, groundtruth) pairs stored as `.npy` files.

    Each directory in `path_to_datasets` is expected to contain files named
    `data_<id>.npy` with a matching `groundtruth_<id>.npy` next to them. Only the
    file paths are indexed at construction time; arrays are loaded lazily in
    `__getitem__`.
    """

    def __init__(
        self, path_to_datasets: List[str], dataset_ratios: List[float], n_samples: int
    ):
        """Index up to `n_samples` (data, groundtruth) file pairs.

        Args:
            path_to_datasets (List[str]): list of the dataset directory paths.
            dataset_ratios (List[float]): list of float values summing to 1 (up to
                floating point tolerance), of the same length as path_to_datasets.
                For every sample, the source directory is drawn at random with these
                probabilities.
            n_samples (int): maximum number of samples in the final dataset. Fewer
                are indexed if a drawn directory has no `data_*` file left; indexing
                stops at that point.

        Raises:
            ValueError: if the two lists have different lengths, or if the ratios
                do not sum to 1.
            FileNotFoundError: if a `data_<id>` file has no matching
                `groundtruth_<id>` file in the same directory.
        """
        if len(dataset_ratios) != len(path_to_datasets):
            raise ValueError(f"dataset_ratios list of length {len(dataset_ratios)}, " +
                             f"path_to_datasets list of length {len(path_to_datasets)}, " +
                             "while they must have same length.")

        # Exact float equality would reject valid ratios such as [0.7, 0.2, 0.1], whose
        # floating-point sum is 0.9999999999999999. The tolerance is kept below the one
        # np.random.choice applies to `p`, so accepted ratios are also accepted there.
        if not np.isclose(np.sum(dataset_ratios), 1.0, rtol=0.0, atol=1e-8):
            raise ValueError("The sum of the dataset ratios does not sum to 1.")

        self.dataset = list()

        # One name-sorted iterator over the files of each dataset directory.
        dataset_iterators = [iter(sorted(os.listdir(path))) for path in path_to_datasets]

        for _ in range(n_samples):
            dataset_idx = np.random.choice(a=len(dataset_iterators), p=dataset_ratios)  # Choose iterator
            iterator = dataset_iterators[dataset_idx]

            # Advance to the next data file, skipping e.g. groundtruth files.
            data_filename = next(
                (name for name in iterator if name.startswith('data_')), None
            )
            if data_filename is None:
                # The drawn directory is exhausted: stop indexing altogether.
                # TODO: consider continuing with the other directories instead.
                break

            sample_id = data_filename.split('_')[1].split('.')[0]  # Extract the id value from the filename
            dataset_dir = path_to_datasets[dataset_idx]
            data_path = os.path.join(dataset_dir, data_filename)
            groundtruth_filename = 'groundtruth_' + data_filename.split("_")[-1]
            groundtruth_path = os.path.join(dataset_dir, groundtruth_filename)

            # Check if the corresponding groundtruth file exists
            if not os.path.isfile(groundtruth_path):
                raise FileNotFoundError(f"{groundtruth_path} file for id={sample_id} does not exist.")
            self.dataset.append((data_path, groundtruth_path))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load the idx-th pair as float32 tensors; the groundtruth is flattened to 1-D."""
        data_path, gt_path = self.dataset[idx]
        try:
            return self._load_pair(data_path, gt_path)
        except FileNotFoundError:
            # Retry once. This is presumably a workaround for files that are transiently
            # unavailable (e.g. on a network filesystem); TODO confirm. The diagnostic
            # line reports whether each file is currently visible.
            print(idx, data_path, gt_path, os.path.isfile(data_path), os.path.isfile(gt_path))
            return self._load_pair(data_path, gt_path)

    @staticmethod
    def _load_pair(data_path, gt_path):
        # Read data and groundtruth files as float tensors.
        data = torch.from_numpy(np.load(data_path)).float()
        groundtruth = torch.from_numpy(np.load(gt_path).flatten()).float()
        return data, groundtruth
        