# -----------------------------------------------
# Install (only once):
#   pip install pyagrum
# -----------------------------------------------

import numpy as np
import time
from soft_fci import ifci
from Utils import (
    random_admg,
    random_intervention_targets,
    generate_interventional_data_pyagrum_from_admg,
    admg_to_mag,
    mag_two_mats_to_marks,
    edge_mark_metrics_acc_dis_cnt,
)


# --- Experiment configuration ---------------------------------------------

# Graph generation (passed to random_admg)
n_obs = 5
rho_dag = 0.5
rho_bi = 0.8

# Number of random ADMGs to evaluate (one experiment per graph)
n_admgs = 100

# Intervention targets (passed to random_intervention_targets)
n_targets = 3
max_target_size = 2

# Data generation (passed to generate_interventional_data_pyagrum_from_admg)
n_samples = 50_000
beta_a = 0.5
beta_b = 0.5
min_variation = 0.4

# IFCI settings (passed to ifci)
max_cond_set_size = 2
alpha = 0.01
ci_method = "chisq"

# --- Per-experiment result accumulators -----------------------------------
# Each list receives at most one entry per experiment.
acc_values, dis_values, cnt_values, runtime_values = [], [], [], []

# Main experiment loop
# Main experiment loop: one randomly generated ADMG per iteration.
for exp_idx in range(n_admgs):
    # The experiment index doubles as the seed so each run is reproducible.
    seed = exp_idx

    # 1. Generate random ADMG
    A_di, A_bi = random_admg(n_obs=n_obs, rho_dag=rho_dag, rho_bi=rho_bi, seed=seed)

    # 2. Generate the random intervention targets; their number and maximum
    #    size are controlled by n_targets and max_target_size.
    targets = random_intervention_targets(
        n_obs=n_obs,
        n_targets=n_targets,
        max_target_size=max_target_size,
        seed=seed,
    )

    # 3. Generate interventional data using the pyagrum generator.
    #    The observational setting (empty target) is the first environment,
    #    followed by the interventions.
    all_targets = [frozenset()] + targets
    data_dict = generate_interventional_data_pyagrum_from_admg(
        A_di=A_di,
        A_bi=A_bi,
        targets=all_targets,
        n_samples=n_samples,
        seed=seed,
        beta_a=beta_a,
        beta_b=beta_b,
        min_variation=min_variation,
    )

    # 4. Format targets as a dictionary for IFCI.
    #    IFCI expects targets as Dict[int, FrozenSet[int]], keyed here by the
    #    environment's position in all_targets.
    targets_dict = dict(enumerate(all_targets))

    # 5. Run IFCI (with timing).
    #    perf_counter() is a monotonic, high-resolution clock, so the measured
    #    interval is not affected by system clock adjustments.
    try:
        start_time = time.perf_counter()
        state = ifci(
            datasets=data_dict,
            targets=targets_dict,
            max_cond_set_size=max_cond_set_size,
            ci_method=ci_method,
            alpha=alpha,
            seed=seed,
        )
        runtime = time.perf_counter() - start_time
    except Exception as e:
        # Best effort: a failing graph is reported and skipped so that the
        # remaining experiments still run.
        print(f"Warning: IFCI failed for experiment {exp_idx}: {e}")
        continue
    # Exactly one runtime entry is recorded per successful IFCI run.
    runtime_values.append(runtime)

    # 6. Extract the first n_obs x n_obs submatrix from state.A.
    #    state.A has shape (n_obs + n_F_nodes, n_obs + n_F_nodes);
    #    only the observed-node part is compared against the ground truth.
    pred_mark_matrix = state.A[:n_obs, :n_obs].copy()

    # 7. Convert input ADMG to MAG
    M_di, M_bi = admg_to_mag(A_di, A_bi)
    true_mag_mark_matrix = mag_two_mats_to_marks(M_di, M_bi)

    # 8. Compute metrics
    ACC, DIS, CNT, _ = edge_mark_metrics_acc_dis_cnt(
        pred_mark_matrix, true_mag_mark_matrix
    )

    # Store metrics; each metric independently skips NaN values, so the
    # three lists may end up with different lengths.
    if not np.isnan(ACC):
        acc_values.append(ACC)
    if not np.isnan(DIS):
        dis_values.append(DIS)
    if not np.isnan(CNT):
        cnt_values.append(CNT)

    # Progress indicator
    if (exp_idx + 1) % 10 == 0:
        print(f"Completed {exp_idx + 1}/{n_admgs} experiments")

# 9. Report mean and standard error
def _mean_and_se(values):
    """Return ``(mean, standard_error)`` of *values*.

    The mean is NaN for an empty sequence.  The standard error is the sample
    standard deviation (``ddof=1``) divided by ``sqrt(n)``; with fewer than
    two values it cannot be estimated and is reported as 0.0.
    """
    n = len(values)
    mean = np.mean(values) if n else np.nan
    se = np.std(values, ddof=1) / np.sqrt(n) if n > 1 else 0.0
    return mean, se


acc_mean, acc_se = _mean_and_se(acc_values)
dis_mean, dis_se = _mean_and_se(dis_values)
cnt_mean, cnt_se = _mean_and_se(cnt_values)
runtime_mean, runtime_se = _mean_and_se(runtime_values)

print("\n" + "="*60)
print("Experiment Results")
print("="*60)
# runtime_values gets exactly one entry per IFCI run that did not raise,
# whereas acc_values additionally drops runs whose ACC was NaN and would
# therefore undercount the successful experiments.
print(f"Number of successful experiments: {len(runtime_values)}/{n_admgs}")
print("\nACC (Accuracy):")
print(f"  Mean: {acc_mean:.4f}")
print(f"  Standard Error: {acc_se:.4f}")
print("\nDIS (Distance):")
print(f"  Mean: {dis_mean:.4f}")
print(f"  Standard Error: {dis_se:.4f}")
print("\nCNT (Coverage):")
print(f"  Mean: {cnt_mean:.4f}")
print(f"  Standard Error: {cnt_se:.4f}")
print("\nRuntime (seconds):")
print(f"  Mean: {runtime_mean:.4f}")
print(f"  Standard Error: {runtime_se:.4f}")
print("="*60)
