from custom_models.load_model import load_custom_model
from utils.utils import get_params


import networkx as nx
import numpy as np
import pandas as pd
from omegaconf import DictConfig


import random
from typing import Any, Dict, Optional


def run_custom_model(
    seed: int,
    causal_graph: nx.Graph,
    experiment_cfg: DictConfig,
    treatments_list: list[Any],
    outcome: dict[str, Any],
    data: Optional[pd.DataFrame] = None,
    int_table: Optional[pd.DataFrame] = None,
    evidence: Optional[dict[str, Any]] = None,
    save_dir: Optional[str] = None,
) -> tuple[list[Dict[str, Any]], list[Dict[str, Any]]]:
    """
    Fit a custom causal model once, then estimate the effect of each treatment.

    The model named by ``experiment_cfg.method_name`` is built for
    ``causal_graph``, trained a single time (using the first entry of
    ``treatments_list``), and then ``estimate_effect`` is called once per
    treatment. Every returned results/runtime dict is tagged with ``"Seed"``;
    every runtime dict also receives the shared training statistics and a
    ``"Total Time"`` entry.

    Args:
        seed (int): Seeds NumPy's global RNG and the ``random`` module, and is
            forwarded to ``fit`` and ``estimate_effect``.
        causal_graph (nx.Graph): Graph handed to ``load_custom_model``.
        experiment_cfg (DictConfig): Experiment configuration; must provide
            ``method_name`` plus whatever ``get_params`` reads.
        treatments_list (list[Any]): Treatments to estimate, in order. The
            first entry is also the treatment used when fitting. Must be
            non-empty.
        outcome (dict[str, Any]): Outcome specification, forwarded unchanged
            to ``fit`` and ``estimate_effect`` (its exact expected form is
            defined by the model — TODO confirm name vs. dict).
        data (Optional[pd.DataFrame], optional): Data forwarded to both
            ``fit`` and ``estimate_effect``. Defaults to None.
        int_table (Optional[pd.DataFrame], optional): Forwarded to ``fit``
            only (presumably interventional data — confirm). Defaults to None.
        evidence (Optional[dict[str, Any]], optional): Evidence forwarded to
            ``fit`` and ``estimate_effect``. Defaults to a fresh empty dict on
            every call.
        save_dir (Optional[str], optional): Forwarded to ``fit`` and
            ``estimate_effect`` (presumably where model artifacts are
            written). Defaults to None.

    Returns:
        tuple[list[Dict[str, Any]], list[Dict[str, Any]]]:
            ``(results_list, runtime_list)``, one entry per treatment, in the
            order of ``treatments_list``.

    Raises:
        ValueError: If ``treatments_list`` is empty.
    """
    # len() instead of truthiness so array-like treatment containers also work.
    if len(treatments_list) == 0:
        raise ValueError("treatments_list must contain at least one treatment")

    # A fresh dict per call: a shared `{}` default would let mutations made by
    # the model in one run leak into every later run.
    if evidence is None:
        evidence = {}

    # Set random seed
    np.random.seed(seed)
    random.seed(seed)

    # Step 1: Initialize Causal Model
    model = load_custom_model(experiment_cfg.method_name, causal_graph)
    method_params = get_params(experiment_cfg)

    # Step 2: Train Causal Model (once, on the first treatment)
    train_runtime = model.fit(
        data=data,
        int_table=int_table,
        method_params=method_params,
        seed=seed,
        save_dir=save_dir,
        outcome=outcome,
        treatment=treatments_list[0],
        evidence=evidence,
    )
    train_time = train_runtime["Training Time"]
    print(f"Time for training Causal Model: {train_time} CPU seconds")

    # Step 3: Estimate the effect of every treatment with the trained model.
    # Params are re-read from the config rather than reusing the dict given to
    # fit(), presumably so in-place changes made by fit() don't leak — confirm.
    method_params = get_params(experiment_cfg)

    results_list: list[Dict[str, Any]] = []
    runtime_list: list[Dict[str, Any]] = []
    for treatment in treatments_list:
        results, runtime = model.estimate_effect(
            outcome=outcome,
            treatment=treatment,
            evidence=evidence,
            method_params=method_params,
            seed=seed,
            save_dir=save_dir,
            data=data,
        )
        results_list.append(results)
        runtime_list.append(runtime)

    for results in results_list:
        results["Seed"] = seed

    # The model is fitted once, so its training statistics are shared by
    # every treatment's runtime record.
    shared_train_stats: Dict[str, Any] = {"Training Time": train_time}
    for key in ("Avg. Epoch Time", "Avg. GPU Memory"):
        if key in train_runtime:
            shared_train_stats[key] = train_runtime[key]

    # Averages are not additive durations, and a pre-existing total must not
    # be counted twice, so both are left out of "Total Time".
    non_additive_time_keys = frozenset({"Avg. Epoch Time", "Total Time"})

    for runtime in runtime_list:
        runtime["Seed"] = seed
        runtime.update(shared_train_stats)
        runtime["Total Time"] = sum(
            value
            for key, value in runtime.items()
            if key.endswith("Time") and key not in non_additive_time_keys
        )

    return results_list, runtime_list