"""Generation and storage of data for large scale experiments
"""

import networkx as nx
import numpy as np
import torch
import os

from numpy.typing import NDArray
from typing import Tuple, Any
from torch import Tensor
from torch.utils.data import Dataset
from csiva.utils.graph_dataset import GraphDatasetGenerator # TODO: remove


# * Utility functions *
def generate_and_store_dataset(
    data_output_file: str,
    gt_output_file: str,
    model: Any
):
    """Sample one dataset from `model` and persist it together with its groundtruth.

    Parameters
    ----------
    data_output_file: str
        Destination of the sampled dataset, written with `np.save` (`.npy` format).
    gt_output_file: str
        Destination of the groundtruth adjacency matrix, written with `np.save`
        (`.npy` format).
    model: BaseStructuralCausalModel
        Object whose `sample()` method returns the pair `(dataset, groundtruth)`.
    """
    samples, adjacency = model.sample()

    # The dataset is written first, then the groundtruth.
    for destination, array in ((data_output_file, samples), (gt_output_file, adjacency)):
        np.save(destination, array)


def prepare_model_input(data: Tensor, target: Tensor, device: str, start_token: int) -> Tuple[Tensor, Tensor, Tensor]:
    """Move a batch to `device`, cast it to float, and build the shifted target.

    The shifted target is the target moved one position to the right along
    dimension 1, with `start_token` written in position 0. It is meant to be fed
    to the model during training so that entry n is predicted from entry n-1.

    Parameters
    ----------
    data: Tensor
        Input data of the batch.
    target: Tensor
        Target of the batch. It must have at least two dimensions, since the
        shift is applied along dimension 1.
    device: str
        Device where the tensors are moved.
    start_token: int
        Value written in the first position (along dimension 1) of the shifted target.

    Returns
    -------
    Tuple[Tensor, Tensor, Tensor]
        `(data, target, shifted_target)`, all of float dtype and on `device`.
    """
    data = data.to(device).float()
    target = target.to(device).float()

    # Rolling by one along dim 1 puts target[:, n-1] in position n; the entry that
    # wraps around into position 0 is overwritten by the start-of-sequence token.
    shifted_target = torch.roll(target, shifts=1, dims=1)
    shifted_target[:, 0] = start_token

    return data, target, shifted_target


# TODO: remove if we use causally data generation flow
def get_dataloaders(num_graphs, num_samples, batch_size, num_nodes):
    """Build a train and a test dataloader over two separately created generators.

    Two `GraphDatasetGenerator` instances are created with the same arguments,
    and each one is wrapped in a `DataLoader` with the given `batch_size`.

    Parameters
    ----------
    num_graphs
        Forwarded to `GraphDatasetGenerator` as `num_graphs`.
    num_samples
        Forwarded to `GraphDatasetGenerator` as `num_samples`.
    batch_size
        Batch size of both dataloaders.
    num_nodes
        Forwarded to `GraphDatasetGenerator` as `num_nodes`.

    Returns
    -------
    tuple
        `(train_loader, test_loader)`.
    """
    generator_kwargs = dict(num_graphs=num_graphs, num_samples=num_samples, num_nodes=num_nodes)

    # Both datasets are instantiated before any loader, train split first.
    splits = [GraphDatasetGenerator(**generator_kwargs) for _ in range(2)]
    train_loader, test_loader = (
        torch.utils.data.DataLoader(split, batch_size=batch_size) for split in splits
    )
    return train_loader, test_loader



def max_edges_in_dag(num_nodes: int) -> int:
    """Compute the maximum number of edges allowed in a directed acyclic graph.

    The maximum number of edges is computed as `num_nodes * (num_nodes - 1) / 2`.

    Parameters
    ----------
    num_nodes: int
        Number of nodes of the graph.

    Returns
    -------
    int
        Maximum number of edges of a DAG with `num_nodes` nodes.
    """
    # Floor division keeps the computation in integer arithmetic: true division
    # goes through a float and loses precision once the product exceeds 2**53.
    # The product of two consecutive integers is always even, so nothing is lost.
    return int(num_nodes * (num_nodes - 1) // 2)


def topological_order(adjacency: NDArray):
    """Compute a topological order of the nodes of a directed acyclic graph.

    The order is built one leaf at a time: at each step a node with no remaining
    children is removed from the graph, and the removal order is finally reversed
    so that sources come first.

    Parameters
    ----------
    adjacency: NDArray
        Square adjacency matrix of the graph. A non-zero `adjacency[i, j]`
        means that node `i` is a parent of node `j`.

    Returns
    -------
    list
        Node indices (NumPy integers) sorted such that each node appears before
        all of its children.

    Raises
    ------
    ValueError
        If the adjacency matrix is not square, or if the graph contains a cycle.
    """
    adjacency = np.asarray(adjacency)
    if adjacency.ndim != 2 or adjacency.shape[0] != adjacency.shape[1]:
        raise ValueError("The input adjacency matrix is not square.")

    # Work on a 0/1 edge indicator, so that edge weights (possibly negative or
    # cancelling each other) do not affect the children count.
    edges = (adjacency != 0).astype(float)
    num_nodes = len(edges)
    remaining_children = edges.sum(axis=1)  # edges[i, j] = 1 --> i parent of j

    # Define toporder one leaf at the time
    order = list()
    for _ in range(num_nodes):
        leaf = np.argmin(remaining_children)  # find leaf as node with no children left
        if remaining_children[leaf] > 0:
            # Every node still in the graph has a child in the graph: there is a cycle.
            raise ValueError("The input adjacency matrix is not acyclic.")
        order.append(leaf)  # update order

        # Remove the leaf from the graph: each of its parents loses one child.
        # Without this update nodes would merely be sorted by out-degree, which
        # is not a topological order in general.
        remaining_children -= edges[:, leaf]
        remaining_children[leaf] = np.inf  # select node only once

    return order[::-1]  # source first



# * Utility classes *
class GraphDataset(Dataset):
    """Dataset of (data, groundtruth) pairs stored as `.npy` files in a directory.

    Every file of the directory whose name starts with `data_` is paired with the
    file `groundtruth_<suffix>`, where `<suffix>` is the part of the data filename
    after its last underscore (e.g. `data_3.npy` goes with `groundtruth_3.npy`).
    Files are only loaded from disk when an item is accessed.

    Parameters
    ----------
    path_to_dataset: str or os.PathLike
        Path to the directory with the data and groundtruth files.

    Raises
    ------
    FileNotFoundError
        If a data file has no corresponding groundtruth file.
    """

    def __init__(self, path_to_dataset):
        path_to_dataset = os.fspath(path_to_dataset)

        # List of (data_path, groundtruth_path) pairs, one per sample.
        self.dataset = list()
        # Name of the dataset directory. normpath drops trailing separators,
        # which would otherwise yield an empty name.
        self.name = os.path.basename(os.path.normpath(path_to_dataset))

        # os.listdir returns entries in arbitrary order: sort them so that the
        # index of each sample is reproducible across runs and filesystems.
        for data_filename in sorted(os.listdir(path_to_dataset)):
            if not data_filename.startswith('data_'):
                continue

            # Extract the id value from the filename (only used for error reporting)
            sample_id = data_filename.split('_')[1].split('.')[0]
            data_path = os.path.join(path_to_dataset, data_filename)
            groundtruth_filename = 'groundtruth_' + data_filename.split("_")[-1]
            groundtruth_path = os.path.join(path_to_dataset, groundtruth_filename)

            # Check if the corresponding groundtruth file exists
            if not os.path.isfile(groundtruth_path):
                raise FileNotFoundError(f"{groundtruth_path} file for id={sample_id} does not exist.")
            self.dataset.append((data_path, groundtruth_path))

    def __len__(self):
        """Return the number of (data, groundtruth) pairs found in the directory."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load the `idx`-th sample from disk.

        Returns
        -------
        Tuple[Tensor, Tensor]
            The data as a float tensor, and the groundtruth flattened to a
            one-dimensional float tensor.
        """
        # Read data and groundtruth files
        data_path, gt_path = self.dataset[idx]
        data = torch.from_numpy(np.load(data_path)).float()
        groundtruth = torch.from_numpy(np.load(gt_path).flatten()).float()

        return data, groundtruth
        