#!/usr/bin/env python3

from __future__ import annotations

from dowhy import gcm
from dowhy.gcm.auto import AssignmentQuality
import networkx as nx
import pandas as pd


def fit_SCM(
    causal_graph: nx.DiGraph,
    observational_data: pd.DataFrame,
    perf_summary_causal_model: bool = False,
    causal_mechanism_quality: AssignmentQuality = AssignmentQuality.BETTER,
    verbose: bool = True,
) -> gcm.StructuralCausalModel:
    """Fit a structural causal model (SCM) to a causal graph and data.

    The function wraps the causal graph in a
    ``gcm.StructuralCausalModel``, lets ``dowhy`` automatically assign
    a causal mechanism to each node, and then fits those mechanisms on
    the observational data. The fitted causal model can be used to
    generate data from the interventional or observational
    distribution.

    Check the `dowhy` (gcm.StructuralCausalModel) documentation
    for more details.

    Args:
        causal_graph (nx.DiGraph): The learned Directed Acyclic
            Graph (DAG) representing the causal relationships
            between variables.
        observational_data (pd.DataFrame): The observed data used
            both to select the causal mechanisms and to fit them.
        perf_summary_causal_model (bool): Forwarded to ``gcm.fit``
            as its third positional argument. If True, the
            performance summary that ``gcm.fit`` hands back is
            printed to standard output; it is NOT part of this
            function's return value. If False, nothing is printed.
            NOTE: Check `dowhy` doc.
        causal_mechanism_quality (AssignmentQuality): AssignmentQuality
            for automatic model selection and accuracy. Controls the
            type of model used and time spent on selection.
            NOTE: Check `dowhy` doc.
        verbose (bool, optional): If True, shows progress bars
            during the fitting process. Default is True.
            NOTE: this is written to ``gcm.config.show_progress_bars``,
            a module-wide setting that is not restored afterwards.

    Returns:
        gcm.StructuralCausalModel: The fitted causal model.
    """
    # Module-wide dowhy switch: it stays at this value after the
    # function returns, so later gcm calls are affected as well.
    gcm.config.show_progress_bars = verbose

    # Build the (still unfitted) causal model from the DAG.
    causal_model = gcm.StructuralCausalModel(causal_graph)

    # Let dowhy choose a mechanism per node based on the data.
    gcm.auto.assign_causal_mechanisms(
        causal_model,
        observational_data,
        causal_mechanism_quality,
    )

    # Fit the mechanisms. Keep whatever gcm.fit returns so that a
    # requested performance summary is not silently thrown away.
    evaluation_summary = gcm.fit(
        causal_model, observational_data, perf_summary_causal_model
    )
    if perf_summary_causal_model and evaluation_summary is not None:
        print(evaluation_summary)

    return causal_model