import os
import json
import time
from datetime import datetime
import random
import sys
from tqdm import tqdm
import pandas as pd

from causal_profiler.constants import (
    QueryType,
    NoiseDistribution,
    NoiseMode,
    MechanismFamily,
    VariableDataType,
)
from causal_profiler.space_of_interest import SpaceOfInterest
from causal_profiler.sampler import Sampler
from utils_metrics import (
    get_DAG_metrics,
    get_ADMG_metrics,
    get_distrib_metrics,
    get_mechanism_metrics,
)


# ---------------------------------------------------------------------------
# Sweep configuration
# ---------------------------------------------------------------------------
# Each list below is one axis of the grid explored by main(); every
# combination of values across the five axes is visited.

# Number of nodes in each sampled SCM.
NUM_VARIABLES_LIST = [3, 4, 5]
# Fraction of the N * (N - 1) / 2 possible DAG edges used as the expected
# number of edges.
EDGES_RATIOS = [0.2, 0.4, 0.6, 0.8]
# Values passed as `proportion_of_hidden_variables`.
PROP_HIDDEN_LIST = [0.0, 0.1, 0.2, 0.3]
# Values passed as both bounds of `number_of_categories`.
NUM_CATEGORIES_LIST = [2, 3, 4, 7]
# Values passed (as strings) as `number_of_noise_regions`.
NUM_NOISE_REGIONS_LIST = [2, 5, 10, 20, 50]

# How many SCMs are drawn for every parameter combination.
SAMPLES_PER_COMBO = 10
# Value passed as `number_of_queries` to the SpaceOfInterest.
NUM_QUERIES = 10
# Forwarded to get_distrib_metrics as `n_samples_for_computation` (10k).
n_samples_for_computation = 10_000


def _build_space_of_interest(
    num_variables, expected_edges, hidden_prop, num_cats, noise_regions
):
    """Build the SpaceOfInterest for one parameter combination.

    Args:
        num_variables: Used as both bounds of ``number_of_nodes``.
        expected_edges: Numeric expected edge count; passed as a string.
        hidden_prop: Passed as ``proportion_of_hidden_variables``.
        num_cats: Used as both bounds of ``number_of_categories``.
        noise_regions: Passed as a string for ``number_of_noise_regions``.

    Returns:
        The configured ``SpaceOfInterest`` instance.
    """
    return SpaceOfInterest(
        # N nodes, single dimensional
        number_of_nodes=(num_variables, num_variables),
        variable_dimensionality=(1, 1),
        # edges: `expected_edges` is given as a numeric expression string
        expected_edges=str(expected_edges),
        # discrete mechanism
        mechanism_family=MechanismFamily.TABULAR,
        # noise & variable specs
        noise_mode=NoiseMode.ADDITIVE,
        noise_distribution=NoiseDistribution.UNIFORM,
        noise_args=[-1, 1],
        number_of_noise_regions=str(noise_regions),
        variable_type=VariableDataType.DISCRETE,
        number_of_categories=(num_cats, num_cats),
        # hidden variables
        proportion_of_hidden_variables=hidden_prop,
        # queries
        number_of_queries=NUM_QUERIES,
        query_type=QueryType.CONDITIONAL,
        # dataset
        number_of_data_points=1,
    )


def _sample_and_analyze(sampler, combo_params):
    """Sample one SCM from ``sampler`` and return its result row.

    A new dict is created on every call, so a metric key that is returned
    for one SCM but not for another can never leak a stale value into a
    later row.

    Args:
        sampler: Object exposing ``generate_scm()``.
        combo_params: Mapping of the grid parameters for this combination;
            it is copied, not mutated.

    Returns:
        A dict with the grid parameters, ``sampling_duration`` (seconds),
        ``sampling_seed`` and every key/value pair produced by the
        ``utils_metrics`` helpers.
    """
    # Draw a seed and immediately reseed with it so it can be recorded.
    # NOTE(review): this only reseeds Python's `random` module; if the
    # sampler draws from another RNG (e.g. NumPy), the recorded seed alone
    # will not reproduce the SCM -- confirm against the sampler.
    seed = random.randrange(sys.maxsize)
    random.seed(seed)

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments happening during sampling.
    t_start = time.perf_counter()
    scm = sampler.generate_scm()
    sampling_duration = time.perf_counter() - t_start

    row = dict(combo_params)
    row["sampling_duration"] = sampling_duration
    row["sampling_seed"] = seed

    # Analyse the SCM
    row.update(get_DAG_metrics(scm))
    row.update(get_ADMG_metrics(scm))
    row.update(get_mechanism_metrics(scm))
    row.update(
        get_distrib_metrics(
            scm,
            n_samples_for_computation=n_samples_for_computation,
        )
    )
    return row


def main():
    """
    Iterate over all desired parameter combinations. For each combination:
      1) Construct a SpaceOfInterest
      2) Sample SAMPLES_PER_COMBO SCMs
      3) Analyze them using the utils_metrics.py script
      4) Collect one result row per SCM

    The rows gathered for each number of variables N are written to
    ``N_<N>.pkl`` inside a timestamped sub-directory of the ``data``
    directory located next to this file.
    """
    # Local import keeps this function self-contained.
    from itertools import product

    # Ensure output directory exists
    output_dir = os.path.join(os.path.dirname(__file__), "data")
    os.makedirs(output_dir, exist_ok=True)

    # Create a timestamped subdirectory to store the analysis of the SCMs
    date_time = datetime.now().strftime("%Y_%m_%d_%H_%M")
    prefix = f"{date_time}_empirical_distrib"
    combo_dir = os.path.join(output_dir, prefix)
    os.makedirs(combo_dir, exist_ok=True)

    # product() varies its last argument fastest, which is exactly the order
    # of the equivalent nested loops (edge ratio outermost, noise regions
    # innermost). Materialised as a list so tqdm knows the total.
    inner_grid = list(
        product(
            EDGES_RATIOS,
            PROP_HIDDEN_LIST,
            NUM_CATEGORIES_LIST,
            NUM_NOISE_REGIONS_LIST,
        )
    )

    for N in tqdm(NUM_VARIABLES_LIST, desc="NUM_VARIABLES_LIST"):
        results = []

        # The maximum possible edges for a DAG on N nodes
        max_edges = N * (N - 1) // 2

        for edge_ratio, hidden_prop, num_cats, noise_regions in tqdm(
            inner_grid, leave=False, desc="PARAM_COMBOS"
        ):
            expected_edges = edge_ratio * max_edges
            combo_params = {
                "num_variables": N,
                "edge_ratio": edge_ratio,
                "hidden_prop": hidden_prop,
                "num_cats": num_cats,
                "num_noise_regions": noise_regions,
            }

            # 1. Create a SpaceOfInterest and 2. instantiate the sampler
            soi = _build_space_of_interest(
                N, expected_edges, hidden_prop, num_cats, noise_regions
            )
            sampler = Sampler(soi)

            # 3. Sample and analyse SAMPLES_PER_COMBO SCMs
            for _ in range(SAMPLES_PER_COMBO):
                results.append(_sample_and_analyze(sampler, combo_params))

        # save results for N
        df_results = pd.DataFrame(results)
        df_results.to_pickle(os.path.join(combo_dir, f"N_{N}.pkl"))

    print("SCM sampling completed.")


# Run the full parameter sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
