#!/usr/bin/env python

# -----------------------------------------------
# Install (only once):
#   pip install pyagrum
# -----------------------------------------------

import numpy as np
import time
from soft_fci import fast_soft_fci
from Utils import (
    random_admg,
    random_intervention_targets,
    generate_interventional_data_pyagrum_from_admg,
    admg_to_mag,
    mag_two_mats_to_marks,
    edge_mark_metrics_acc_dis_cnt,
)


# ---- Experiment configuration (mirrors the IFCI setup for a fair comparison) ----

# Random ADMG generation: size and density knobs forwarded to random_admg.
n_obs = 5          # observed variables; also the size of the scored mark matrix
rho_dag = 0.5      # presumably directed-edge density -- confirm in Utils.random_admg
rho_bi = 0.8       # presumably bidirected-edge density -- confirm in Utils.random_admg
n_admgs = 100      # number of independent experiments (one ADMG each)

# Intervention targets, forwarded to random_intervention_targets.
n_targets = 3
max_target_size = 2

# Data simulation: samples requested per environment from the pyAgrum generator.
n_samples = 50_000

# fast_soft_fci conditional-independence testing settings.
max_cond_set_size = 2
alpha = 0.01
ci_method = "chisq"

# Forwarded to the pyAgrum data generator; semantics defined in Utils -- verify.
beta_a = 0.5
beta_b = 0.5
min_variation = 0.4

# Per-experiment results, filled in by the main loop (four distinct lists).
acc_values, dis_values, cnt_values, runtime_values = [], [], [], []

# Main experiment loop: each iteration builds one random ADMG, simulates
# observational + interventional data from it, runs fast_soft_fci, and scores
# the recovered edge marks against the MAG of the true ADMG.
for exp_idx in range(n_admgs):
    # Set seed for reproducibility: the experiment index seeds every generator
    # call below, so rerunning the script reproduces the same experiments.
    seed = exp_idx

    # 1. Generate random ADMG (returned as a directed and a bidirected part).
    A_di, A_bi = random_admg(
        n_obs=n_obs,
        rho_dag=rho_dag,
        rho_bi=rho_bi,
        seed=seed,
    )

    # 2. Generate random intervention target list.
    targets = random_intervention_targets(
        n_obs=n_obs,
        n_targets=n_targets,
        max_target_size=max_target_size,
        seed=seed,
    )

    # 3. Generate interventional data using the pyAgrum generator.
    # The observational regime (empty target set) is placed first, as environment 0.
    all_targets = [frozenset(), *targets]
    data_dict = generate_interventional_data_pyagrum_from_admg(
        A_di=A_di,
        A_bi=A_bi,
        targets=all_targets,
        n_samples=n_samples,
        seed=seed,
        beta_a=beta_a,
        beta_b=beta_b,
        min_variation=min_variation,
    )

    # 4. fast_soft_fci expects targets as Dict[int, FrozenSet[int]], keyed by
    # environment index in the same order as all_targets.
    targets_dict = dict(enumerate(all_targets))

    # 5. Run fast_soft_fci with timing. Only the solver call is guarded: a
    # solver failure is reported and the experiment skipped, while errors in
    # data generation or scoring still abort the script. perf_counter is used
    # because it is monotonic and high-resolution, unlike time.time.
    start_time = time.perf_counter()
    try:
        state = fast_soft_fci(
            datasets=data_dict,
            targets=targets_dict,
            max_cond_set_size=max_cond_set_size,
            ci_method=ci_method,
            alpha=alpha,
            seed=seed,
        )
    except Exception as exc:
        print(f"Warning: fast_soft_fci failed for experiment {exp_idx}: {exc}")
    else:
        # Exactly one runtime entry per successful solver call.
        runtime_values.append(time.perf_counter() - start_time)

        # 6. Keep only the observed-variable block of the predicted marks.
        # state.A has shape (n_obs + n_F_nodes, n_obs + n_F_nodes); the
        # observed nodes are assumed to come first.
        pred_mark_matrix = state.A[:n_obs, :n_obs].copy()

        # 7. Convert the true ADMG to a MAG, then to a mark matrix so it can be
        # compared mark-by-mark with the prediction.
        M_di, M_bi = admg_to_mag(A_di, A_bi)
        true_mag_mark_matrix = mag_two_mats_to_marks(M_di, M_bi)

        # 8. Compute metrics (the fourth return value is not used here).
        ACC, DIS, CNT, _ = edge_mark_metrics_acc_dis_cnt(
            pred_mark_matrix, true_mag_mark_matrix
        )

        # Store metrics, skipping NaNs; the three lists may therefore end up
        # with different lengths.
        if not np.isnan(ACC):
            acc_values.append(ACC)
        if not np.isnan(DIS):
            dis_values.append(DIS)
        if not np.isnan(CNT):
            cnt_values.append(CNT)

    # Progress indicator, printed even when this experiment's solver failed.
    if (exp_idx + 1) % 10 == 0:
        print(f"Completed {exp_idx + 1}/{n_admgs} experiments")

# 9. Report mean and standard error of each collected per-experiment quantity.
def _mean_and_se(values):
    """Return ``(mean, standard_error)`` for a list of per-experiment values.

    The mean is NaN when *values* is empty. The standard error is the sample
    standard deviation (``ddof=1``) divided by sqrt(n). It is NaN when fewer
    than two values are available, because it is undefined there; reporting
    0.0 would wrongly suggest zero uncertainty.
    """
    n = len(values)
    if n == 0:
        return np.nan, np.nan
    mean = np.mean(values)
    if n < 2:
        return mean, np.nan
    return mean, np.std(values, ddof=1) / np.sqrt(n)


acc_mean, acc_se = _mean_and_se(acc_values)
dis_mean, dis_se = _mean_and_se(dis_values)
cnt_mean, cnt_se = _mean_and_se(cnt_values)
runtime_mean, runtime_se = _mean_and_se(runtime_values)

# The main loop appends to runtime_values once per successful fast_soft_fci
# call. acc_values additionally drops experiments whose ACC was NaN, so its
# length would undercount successful runs.
n_successful = len(runtime_values)

print("\n" + "=" * 60)
print("fast_soft_fci Experiment Results")
print("=" * 60)
print(f"Number of successful experiments: {n_successful}/{n_admgs}")
print("\nACC (Accuracy):")
print(f"  Mean: {acc_mean:.4f}")
print(f"  Standard Error: {acc_se:.4f}")
print("\nDIS (Distance):")
print(f"  Mean: {dis_mean:.4f}")
print(f"  Standard Error: {dis_se:.4f}")
print("\nCNT (Coverage):")
print(f"  Mean: {cnt_mean:.4f}")
print(f"  Standard Error: {cnt_se:.4f}")

print("\nRuntime (seconds):")
print(f"  Mean: {runtime_mean:.4f}")
print(f"  Standard Error: {runtime_se:.4f}")
print("=" * 60)
