import logging
import os
from typing import Any, Dict, Optional, Tuple

from lightning import seed_everything

from . import (
    DatasetLowRankGenerator,
    DatasetLowRankGeneratorIFG,
    DatasetSPNFGGenerator,
    DatasetSPNIFGGenerator,
)


def sim_fg_data(
    n_sim: int,
    n_features: int,
    n_modules: int,
    p_vertex: float,
    p_module: float,
    n_samples: int,
    n_hidden: int,
    nb_interventions: int,
    nb_test_interventions: int,
    out_dir: str,
    min_nb_target: int = 1,
    max_nb_target: int = 3,
    max_corr: float = 1.0,
    min_corr: float = 0.8,
    graph_dropout: float = 0.0,
    dependent_dropout: bool = True,
    noise_level: float = 0.1,
    enhanced: bool = True,
    hard_intv: bool = True,
    rescale: bool = False,
    conservative: bool = False,
    uniform: bool = False,
    cover: bool = True,
    enhanced_kwargs: Optional[Dict[str, Any]] = None,
    random_state: int = 0,
    verbose: bool = False,
) -> None:
    """Simulate ``n_sim`` datasets with ``DatasetLowRankGenerator`` and save them.

    The global RNG is seeded once with ``random_state`` before the generator
    is built. The ``n_sim`` datasets therefore differ from each other, but the
    whole sequence is reproducible for a given seed.

    Args:
        n_sim: Number of datasets to generate.
        out_dir: Root output directory. Dataset ``i`` is written to
            ``{out_dir}/n{n_features}_m{n_modules}_{i}``, which is created if
            missing.
        enhanced_kwargs: Extra options forwarded to the generator. When
            ``None`` (the default), a fresh dict with ``mean_par=5``,
            ``mean_chr=5``, ``min_unique_par=1``, ``min_unique_chr=1``,
            ``min_upstream=-1`` and ``min_downstream=-1`` is built on every
            call. Building it per call means a mutation in one call cannot
            leak into the next.
        random_state: Seed passed to ``lightning.seed_everything``.

    All other arguments are forwarded unchanged to ``DatasetLowRankGenerator``.
    See that class for their meaning.
    """
    logger = logging.getLogger(__name__)
    if enhanced_kwargs is None:
        # Build the defaults per call. A dict literal in the signature would be
        # a single object shared across all calls.
        enhanced_kwargs = {
            'mean_par': 5,
            'mean_chr': 5,
            'min_unique_par': 1,
            'min_unique_chr': 1,
            'min_upstream': -1,
            'min_downstream': -1,
        }
    seed_everything(random_state)
    # Forwarded positionally: this order must match the generator's constructor.
    sim_data = DatasetLowRankGenerator(
        n_features,
        n_modules,
        p_vertex,
        p_module,
        n_samples,
        hard_intv,
        n_hidden,
        enhanced,
        rescale,
        nb_interventions,
        nb_test_interventions,
        min_nb_target,
        max_nb_target,
        noise_level,
        max_corr,
        min_corr,
        graph_dropout,
        dependent_dropout,
        conservative,
        uniform,
        cover,
        enhanced_kwargs,
        verbose,
    )
    for i in range(n_sim):
        sim_data.generate()
        folder = f'{out_dir}/n{n_features}_m{n_modules}_{i}'
        os.makedirs(folder, exist_ok=True)
        sim_data.save_data(folder, i)
    logger.info("Simulated %s datasets. Saved output to %s.", n_sim, out_dir)


def sim_ifg_data(
    n_sim: int,
    n_features: int,
    n_modules: int,
    p_vertex: float,
    p_module: float,
    n_samples: int,
    n_hidden: int,
    nb_interventions: int,
    nb_test_interventions: int,
    out_dir: str,
    min_nb_target: int = 1,
    max_nb_target: int = 3,
    max_corr: float = 1.0,
    min_corr: float = 0.8,
    graph_dropout: float = 0.0,
    dependent_dropout: bool = True,
    alpha: float = 0.1,
    scale: float = 0.1,
    noise_level: float = 0.1,
    enhanced: bool = True,
    hard_intv: bool = False,
    rescale: bool = False,
    conservative: bool = False,
    uniform: bool = False,
    cover: bool = True,
    enhanced_kwargs: Optional[Dict[str, Any]] = None,
    random_state: int = 0,
    verbose: bool = False,
) -> None:
    """Simulate ``n_sim`` datasets with ``DatasetLowRankGeneratorIFG`` and save them.

    The global RNG is seeded once with ``random_state`` before the generator
    is built. The ``n_sim`` datasets therefore differ from each other, but the
    whole sequence is reproducible for a given seed.

    Args:
        n_sim: Number of datasets to generate.
        nb_interventions: Forwarded to the generator. Must be at least 1.
        out_dir: Root output directory. Dataset ``i`` is written to
            ``{out_dir}/n{n_features}_m{n_modules}_{i}``, which is created if
            missing.
        enhanced_kwargs: Extra options forwarded to the generator. When
            ``None`` (the default), a fresh dict with ``mean_par=5``,
            ``mean_chr=5``, ``min_unique_par=1``, ``min_unique_chr=1``,
            ``min_upstream=-1`` and ``min_downstream=-1`` is built on every
            call. Building it per call means a mutation in one call cannot
            leak into the next.
        random_state: Seed passed to ``lightning.seed_everything``.

    All other arguments are forwarded unchanged to
    ``DatasetLowRankGeneratorIFG``. See that class for their meaning.

    Raises:
        ValueError: If ``nb_interventions < 1``.
    """
    logger = logging.getLogger(__name__)
    if nb_interventions < 1:
        raise ValueError("Number of interventions must be at least 1.")
    if enhanced_kwargs is None:
        # Build the defaults per call. A dict literal in the signature would be
        # a single object shared across all calls.
        enhanced_kwargs = {
            'mean_par': 5,
            'mean_chr': 5,
            'min_unique_par': 1,
            'min_unique_chr': 1,
            'min_upstream': -1,
            'min_downstream': -1,
        }
    seed_everything(random_state)
    # Forwarded positionally: this order must match the generator's constructor.
    sim_data = DatasetLowRankGeneratorIFG(
        n_features,
        n_modules,
        p_vertex,
        p_module,
        n_samples,
        hard_intv,
        n_hidden,
        enhanced,
        rescale,
        nb_interventions,
        nb_test_interventions,
        min_nb_target,
        max_nb_target,
        noise_level,
        max_corr,
        min_corr,
        graph_dropout,
        dependent_dropout,
        alpha,
        scale,
        conservative,
        uniform,
        cover,
        enhanced_kwargs,
        verbose,
    )
    for i in range(n_sim):
        sim_data.generate()
        folder = f'{out_dir}/n{n_features}_m{n_modules}_{i}'
        os.makedirs(folder, exist_ok=True)
        sim_data.save_data(folder, i)
    logger.info("Simulated %s datasets. Saved output to %s.", n_sim, out_dir)


def sim_spnfg_data(
    n_sim: int,
    n_features: int,
    n_modules: int,
    n_samples: int,
    hard_intv: bool,
    max_copies: int,
    p_conn: float,
    sparsity_temp: float,
    n_hidden: int,
    nb_interventions: int,
    nb_test_interventions: int,
    out_dir: str,
    min_nb_target: int = 1,
    max_nb_target: int = 3,
    max_corr: float = 1.0,
    min_corr: float = 0.8,
    noise_level: float = 0.1,
    graph_dropout: float = 0.0,
    dependent_dropout: bool = True,
    rescale: bool = False,
    conservative: bool = False,
    uniform: bool = False,
    cover: bool = True,
    random_state: int = 0,
    verbose: bool = False,
) -> None:
    """Simulate ``n_sim`` datasets with ``DatasetSPNFGGenerator`` and save them.

    The global RNG is seeded once with ``random_state`` before the generator
    is built. The ``n_sim`` datasets therefore differ from each other, but the
    whole sequence is reproducible for a given seed.

    Args:
        n_sim: Number of datasets to generate.
        out_dir: Root output directory. Dataset ``i`` is written to
            ``{out_dir}/n{n_features}_m{n_modules}_{i}``, which is created if
            missing.
        random_state: Seed passed to ``lightning.seed_everything``.

    All other arguments are forwarded unchanged to ``DatasetSPNFGGenerator``.
    See that class for their meaning.
    """
    logger = logging.getLogger(__name__)
    seed_everything(random_state)
    # Forwarded positionally: this order must match the generator's constructor.
    sim_data = DatasetSPNFGGenerator(
        n_features,
        n_modules,
        n_samples,
        hard_intv,
        max_copies,
        p_conn,
        sparsity_temp,
        n_hidden,
        rescale,
        nb_interventions,
        nb_test_interventions,
        min_nb_target,
        max_nb_target,
        noise_level,
        max_corr,
        min_corr,
        graph_dropout,
        dependent_dropout,
        conservative,
        uniform,
        cover,
        verbose,
    )
    for i in range(n_sim):
        sim_data.generate()
        folder = f'{out_dir}/n{n_features}_m{n_modules}_{i}'
        os.makedirs(folder, exist_ok=True)
        sim_data.save_data(folder, i)
    logger.info("Simulated %s datasets. Saved output to %s.", n_sim, out_dir)


def sim_spnifg_data(
    n_sim: int,
    n_features: int,
    n_modules: int,
    n_samples: int,
    max_copies: int,
    p_conn: Tuple[float, float],
    sparsity_temp: Tuple[float, float],
    n_hidden: int,
    nb_interventions: int,
    nb_test_interventions: int,
    out_dir: str,
    hard_intv: bool = False,
    min_nb_target: int = 1,
    max_nb_target: int = 3,
    max_corr: float = 1.0,
    min_corr: float = 0.8,
    graph_dropout: float = 0.0,
    dependent_dropout: bool = True,
    alpha: float = 0.1,
    scale: float = 0.1,
    noise_level: float = 0.1,
    rescale: bool = False,
    conservative: bool = False,
    uniform: bool = False,
    cover: bool = True,
    random_state: int = 0,
    verbose: bool = False,
) -> None:
    """Simulate ``n_sim`` datasets with ``DatasetSPNIFGGenerator`` and save them.

    The global RNG is seeded once with ``random_state`` before the generator
    is built. The ``n_sim`` datasets therefore differ from each other, but the
    whole sequence is reproducible for a given seed.

    Args:
        n_sim: Number of datasets to generate.
        nb_interventions: Forwarded to the generator. Must be at least 1.
        out_dir: Root output directory. Dataset ``i`` is written to
            ``{out_dir}/n{n_features}_m{n_modules}_{i}``, which is created if
            missing.
        random_state: Seed passed to ``lightning.seed_everything``.

    All other arguments are forwarded unchanged to
    ``DatasetSPNIFGGenerator``. See that class for their meaning.

    Raises:
        ValueError: If ``nb_interventions < 1``.
    """
    logger = logging.getLogger(__name__)
    if nb_interventions < 1:
        raise ValueError("Number of interventions must be at least 1.")
    seed_everything(random_state)
    # Forwarded positionally: this order must match the generator's constructor.
    sim_data = DatasetSPNIFGGenerator(
        n_features,
        n_modules,
        n_samples,
        hard_intv,
        max_copies,
        p_conn,
        sparsity_temp,
        n_hidden,
        rescale,
        nb_interventions,
        nb_test_interventions,
        min_nb_target,
        max_nb_target,
        noise_level,
        max_corr,
        min_corr,
        graph_dropout,
        dependent_dropout,
        alpha,
        scale,
        conservative,
        uniform,
        cover,
        verbose,
    )
    for i in range(n_sim):
        sim_data.generate()
        folder = f'{out_dir}/n{n_features}_m{n_modules}_{i}'
        os.makedirs(folder, exist_ok=True)
        sim_data.save_data(folder, i)
    logger.info("Simulated %s datasets. Saved output to %s.", n_sim, out_dir)
