from utils.graph import get_dag_from_causal_graph
from utils.utils import get_params


import networkx as nx
import numpy as np
import pandas as pd
from dowhy import CausalModel
from omegaconf import DictConfig


import random
import time
from typing import Any, Dict, Optional


def _evidence_mask(df: pd.DataFrame, evidence: Dict[str, Any]) -> pd.Series:
    """Return a 1-D boolean Series selecting the rows of ``df`` where every
    ``column == value`` pair in ``evidence`` holds."""
    mask = pd.Series(True, index=df.index)
    for column, value in evidence.items():
        mask &= df[column] == value
    return mask


def _timed(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and return ``(result, elapsed)``, where
    ``elapsed`` is CPU time in seconds measured with ``time.process_time``."""
    start = time.process_time()
    result = func(*args, **kwargs)
    return result, time.process_time() - start


def run_dowhy_experiment(
    seed: int,
    causal_graph: nx.Graph,
    experiment_cfg: DictConfig,
    treatment_list: list[dict[str, Any]],
    outcome: dict[str, Any],
    data: Optional[pd.DataFrame] = None,
    int_table: Optional[pd.DataFrame] = None,
    evidence: Optional[dict[str, Any]] = None,
    save_dir: Optional[str] = None,
) -> tuple[list[Dict[str, Any]], list[Dict[str, Any]]]:
    """
    Run a DoWhy experiment estimating causal effects for each treatment in ``treatment_list``.

    For every treatment a DoWhy ``CausalModel`` is built on ``data`` and ``causal_graph``,
    the effect is identified, and then the ATE and/or CATE are estimated depending on
    ``experiment_cfg.estimate_ate`` / ``experiment_cfg.estimate_cate``.

    Args:
        seed (int): Seed applied to NumPy's global RNG and Python's ``random`` module.
        causal_graph (nx.Graph): Causal graph. If it is not a DAG it is converted with
            ``get_dag_from_causal_graph`` before use.
        experiment_cfg (DictConfig): Experiment configuration. Reads ``method_name``,
            ``estimate_ate``, ``estimate_cate``, and whatever ``get_params`` extracts.
        treatment_list (list[dict[str, Any]]): One entry per treatment, each with keys
            ``"treatment_var"``, ``"treatment_value"`` and ``"control_value"``.
        outcome: Passed as-is to ``CausalModel(outcome=...)`` and echoed under ``"target"``
            in the results. NOTE(review): annotated as a dict, but DoWhy expects a variable
            name (or list of names) here — confirm against callers.
        data (Optional[pd.DataFrame]): Data the causal model is built on.
        int_table (Optional[pd.DataFrame]): Currently unused by this function.
        evidence (Optional[dict[str, Any]]): Mapping ``column -> value`` defining the
            subpopulation for CATE estimation. Required (non-empty) when
            ``experiment_cfg.estimate_cate`` is true. Defaults to no evidence.
        save_dir (Optional[str]): Currently unused by this function.

    Returns:
        tuple[list[dict], list[dict]]: ``(results_list, runtime_list)``, each with one entry
            per treatment. Result entries hold ``"Seed"``, ``"target"``, ``"ATE"``,
            ``"evidence"`` and ``"CATE"``; runtime entries hold ``"Seed"``,
            ``"Identification Time"``, ``"Estimation Time ATE"`` and
            ``"Estimation Time CATE"`` (CPU seconds). Disabled estimates are ``None``.

    Raises:
        ValueError: If CATE estimation is enabled but ``evidence`` is empty or missing.
    """
    # Avoid a shared mutable default; ``None`` means "no evidence".
    if evidence is None:
        evidence = {}

    # Seed the global RNGs for reproducibility.
    np.random.seed(seed)
    random.seed(seed)

    method = experiment_cfg.method_name
    # Estimator parameters depend only on the config, so read them once.
    method_params = get_params(experiment_cfg)

    # Fail fast, before any (possibly expensive) identification/estimation work.
    if experiment_cfg.estimate_cate and not evidence:
        raise ValueError("Evidence not provided, cannot run CATE estimation.")

    # Those methods work only with DAGs
    if not nx.is_directed_acyclic_graph(causal_graph):
        causal_graph = get_dag_from_causal_graph(causal_graph)

    results_list, runtime_list = [], []
    for k, treatment in enumerate(treatment_list):
        treatment_vars = treatment["treatment_var"]
        treatment_values = treatment["treatment_value"]
        control_values = treatment["control_value"]

        # Step 1: Create a causal model from the data and given graph.
        print(f"Running experiment for treatment {k} and seed: {seed}")
        model = CausalModel(
            data=data,
            treatment=treatment_vars,
            outcome=outcome,
            graph=causal_graph,
        )

        # Step 2: Identify causal effect and return target estimands.
        identified_estimand, delta_identification_time = _timed(model.identify_effect)
        print(f"Time for identifying effect: {delta_identification_time} CPU seconds")

        # Step 3: Estimate the target estimand using a statistical method.
        ate_value = delta_estimate_time_ate = None
        if experiment_cfg.estimate_ate:
            ate_estimate, delta_estimate_time_ate = _timed(
                model.estimate_effect,
                identified_estimand,
                method_name=method,
                method_params=method_params,
                target_units="ate",
                treatment_value=treatment_values,
                control_value=control_values,
            )
            print(f"Time for estimating ATE: {delta_estimate_time_ate} CPU seconds")
            ate_value = ate_estimate.value

        cate_value = delta_estimate_time_cate = None
        if experiment_cfg.estimate_cate:
            cate_estimate, delta_estimate_time_cate = _timed(
                model.estimate_effect,
                identified_estimand,
                method_name=method,
                # DoWhy indexes the data with this callable's result, so it must be a
                # 1-D boolean row mask (one flag per row), not a 2-D boolean DataFrame.
                target_units=lambda df: _evidence_mask(df, evidence),
                treatment_value=treatment_values,
                control_value=control_values,
                method_params=method_params,
            )
            print(f"Time for estimating CATE: {delta_estimate_time_cate} CPU seconds")
            cate_value = cate_estimate.value

        results_list.append(
            {
                "Seed": seed,
                "target": outcome,
                "ATE": ate_value,
                "evidence": evidence if experiment_cfg.estimate_cate else None,
                "CATE": cate_value,
            }
        )
        runtime_list.append(
            {
                "Seed": seed,
                "Identification Time": delta_identification_time,
                "Estimation Time ATE": delta_estimate_time_ate,
                "Estimation Time CATE": delta_estimate_time_cate,
            }
        )

    return results_list, runtime_list


