#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Run automated experiments for:
  - Mixing function types (the commented lambdas in your f_mixing cell)
  - Number of environments k ∈ {3, 6, 9}
  - Noise distributions: Gaussian; Gamma(alpha≈1, theta≈2); Gamma(alpha≈2, theta≈2)
    (the exact parameters are drawn per dimension from the ranges documented in
    `_sample_env_sources`)

For each (func, k, dist) and for each seed, we:
  1) Generate k environments (env 0 is "base"), sample sources with requested distribution,
     apply the mixing function to get X_e, and stack to shape (k, n, d).
  2) Estimate Σ̂ per Algorithm 1 using your Stein routines: find_sigma_per_group(...).
  3) Solve for J_inv with j_inv_linear_system, permute/scale as in your cell, estimate causal order.
  4) Log a row with SHD (0/1 for the bivariate case) and metadata.

Outputs:
  - CSV at logs/experiments.csv (appends by default).
"""

import os
import csv
import time
import argparse
from math import floor
from typing import Dict, Callable, Tuple

import numpy as np
from scipy.optimize import linear_sum_assignment

# --- Your library imports (must exist in your repo) ---
# find_sigma_per_group must accept (X, stein_fn, first_group_index=int, verbose=bool)

from utils.score import find_sigma_per_group, stein_score_hess_all 
from utils.jacobian import j_inv_linear_system  # type: ignore
from utils.lingam import estimate_causal_order as _estimate_causal_order  # type: ignore

# -------------------------
# Mixing function zoo (your lambdas from the notebook cell)
# -------------------------
def _apply_mixing(s: np.ndarray, f: Callable[[np.ndarray, np.ndarray], np.ndarray], direction: str) -> np.ndarray:
    """
    Mix a pair of source columns into a pair of observed columns.

    Exactly one of the two output columns is produced by `f`; the other is
    the corresponding source column passed through unchanged.

    Parameters
    ----------
    s : ndarray of shape (n, 2)
        Source samples; column 0 and column 1 are the two sources.
    f : callable
        Mixing map called as `f(cause, effect)`.
    direction : {"forward", "backward"}
        "forward":  output is `[s[:, 0], f(s[:, 0], s[:, 1])]`.
        "backward": output is `[f(s[:, 1], s[:, 0]), s[:, 1]]`.

    Returns
    -------
    X : ndarray of shape (n, 2)
        The two output columns stacked along axis 1.

    Raises
    ------
    ValueError
        If `direction` is neither "forward" nor "backward".
    """
    if direction == "forward":
        # Column 0 is the cause; column 1 is generated from both sources.
        first_col = s[:, 0]
        second_col = f(s[:, 0], s[:, 1])
    elif direction == "backward":
        # Column 1 is the cause; column 0 is generated from both sources.
        first_col = f(s[:, 1], s[:, 0])
        second_col = s[:, 1]
    else:
        raise ValueError("direction must be 'forward' or 'backward'")

    return np.stack([first_col, second_col], axis=1)


def _mix_linear(cause, effect):
    """Linear model: 5 * cause + effect."""
    return 5 * cause + effect


def _mix_anm(cause, effect):
    """Additive noise model (ANM): cause**2 + effect."""
    return cause**2 + effect


def _mix_pnl(cause, effect):
    """Post-nonlinear model (PNL): (cause + effect)**3."""
    return (cause + effect)**3


def _mix_lsnm(cause, effect):
    """Location-scale style model (LSNM): cause**2 * effect."""
    return cause**2 * effect


def _mix_c2e_plus_arctan_e(cause, effect):
    """cause**2 * effect + arctan(effect) (tagged "good" by the original author)."""
    return cause**2 * effect + np.arctan(effect)


def _mix_c2_arctan_e_plus_e3(cause, effect):
    """cause**2 * arctan(effect) + effect**3 (tagged "good" by the original author)."""
    return cause**2 * np.arctan(effect) + effect**3


def _mix_c2_plus_arctan_c_times_e_plus_c_e3(cause, effect):
    """cause**2 + arctan(cause) * effect + cause * effect**3 (tagged "workish")."""
    return cause**2 + np.arctan(cause) * effect + cause * effect**3


# Registry of bivariate mixing maps. Keys are the identifiers accepted by
# `--functions` and written to the `func_name` CSV column; insertion order is
# the default execution order.
FUNCTIONS: Dict[str, Callable[[np.ndarray, np.ndarray], np.ndarray]] = {
    "linear_5c_plus_e": _mix_linear,
    "anm_c2_plus_e": _mix_anm,
    "pnl_cpluse_cubed": _mix_pnl,
    "lsnm_c2_times_e": _mix_lsnm,
    "c2e_plus_arctan_e": _mix_c2e_plus_arctan_e,
    "c2_arctan_e_plus_e3": _mix_c2_arctan_e_plus_e3,
    "c2_plus_arctan_c_times_e_plus_c_e3": _mix_c2_plus_arctan_c_times_e_plus_c_e3,
}


# -------------------------
# Sampling k environments with requested distribution
# -------------------------
def _sample_env_sources(
    rng: np.random.Generator,
    n: int,
    d: int,
    k: int,
    distribution: str,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw independent source samples for `k` environments.

    All environments share one set of distribution parameters (drawn once,
    per dimension); environments differ only through a per-environment,
    per-dimension scale multiplier.

    Parameters
    ----------
    rng : numpy.random.Generator
        Source of randomness; every draw in this function goes through it.
    n : int
        Samples per environment.
    d : int
        Number of source dimensions.
    k : int
        Number of environments; environment 0 is the unscaled base.
    distribution : {"gaussian", "gamma_a1_theta2", "gamma_a2_theta2"}
        - "gaussian":        mu ~ U(0.5, 1.0),    sigma ~ U(1.0, 1.5)
        - "gamma_a1_theta2": alpha ~ U(0.5, 1.0), theta ~ U(1.75, 2.25)
        - "gamma_a2_theta2": alpha ~ U(2.0, 2.5), theta ~ U(1.75, 2.25)

    Returns
    -------
    S_stack : ndarray of shape (k, n, d)
        Sampled sources, one slab per environment.
    L_list : ndarray of shape (k, d)
        Scale multipliers. Row 0 is all ones; every later row is drawn as
        `U(1, 1.5) + 1`, independently per dimension.

    Raises
    ------
    ValueError
        If `distribution` is not one of the supported names.

    Notes
    -----
    - Gaussian environment e: Normal(loc=mu, scale=sigma * L_list[e]).
    - Gamma environment e:    Gamma(shape=alpha, scale=theta * L_list[e]).
    - The order of draws from `rng` (multipliers, then parameters, then the
      samples environment by environment) determines the output for a seed.
    """
    # Scale multipliers: base environment stays at 1, the rest lie in [2, 2.5).
    env_scales = np.ones((k, d))
    if k > 1:
        env_scales[1:] = rng.uniform(1, 1.5, size=(k - 1, d)) + 1.0

    if distribution == "gaussian":
        mu = rng.uniform(0.5, 1, size=d)
        sigma = rng.uniform(1.0, 1.5, size=d)
        samples = [
            rng.normal(loc=mu, scale=sigma * row, size=(n, d))
            for row in env_scales
        ]
    elif distribution in ("gamma_a1_theta2", "gamma_a2_theta2"):
        # The two gamma families differ only in the range of the shape parameter.
        if distribution == "gamma_a1_theta2":
            alpha_lo, alpha_hi = 0.5, 1.0
        else:
            alpha_lo, alpha_hi = 2.0, 2.5
        alpha = rng.uniform(alpha_lo, alpha_hi, size=d)
        theta = rng.uniform(1.75, 2.25, size=d)
        samples = [
            rng.gamma(shape=alpha, scale=theta * row, size=(n, d))
            for row in env_scales
        ]
    else:
        raise ValueError(f"Unknown distribution: {distribution}")

    return np.stack(samples, axis=0), env_scales



def _make_X_stack(
    S_stack: np.ndarray,
    f_lambda: Callable[[np.ndarray, np.ndarray], np.ndarray],
    direction: str
) -> np.ndarray:
    """
    Mix the sources of every environment with the same mixing function.

    Parameters
    ----------
    S_stack : ndarray
        Sources indexed by environment along axis 0; each `S_stack[e]` is
        handed to `_apply_mixing` unchanged.
    f_lambda : callable
        Mixing function called as `f(cause, effect)`.
    direction : {"forward", "backward"}
        Direction argument forwarded to `_apply_mixing`.

    Returns
    -------
    X_stack : ndarray
        Per-environment mixed observations, stacked along axis 0 in the same
        environment order as `S_stack`.
    """
    mixed_per_env = [
        _apply_mixing(env_sources, f_lambda, direction)
        for env_sources in S_stack
    ]
    return np.stack(mixed_per_env, axis=0)


def _estimate_order_from_J(J_inv_hat: np.ndarray) -> str:
    """
    Infer the causal direction from an estimated inverse Jacobian.

    The matrix `B = I - J_inv_hat` is passed to `_estimate_causal_order`
    (imported from `utils.lingam`), and the returned variable order is mapped
    to a direction label.

    Parameters
    ----------
    J_inv_hat : ndarray of shape (2, 2)
        Estimated inverse Jacobian (after permutation and diagonal normalization).

    Returns
    -------
    direction : {"forward", "backward"}
        "forward" for order [0, 1], "backward" for order [1, 0].

    Raises
    ------
    ValueError
        If the estimated order is anything other than [0, 1] or [1, 0]
        (including `None`). Previously this case fell through and returned
        `None`, contradicting the `-> str` contract.
    """
    B_est = np.eye(J_inv_hat.shape[0]) - J_inv_hat
    raw_order = _estimate_causal_order(B_est)

    # Normalise to a plain Python list so that lists, tuples and NumPy arrays
    # all compare correctly; comparing an ndarray to a list with `==` yields
    # an array whose truth value is ambiguous.
    order = np.asarray(raw_order).ravel().tolist()

    if order == [0, 1]:
        return "forward"
    if order == [1, 0]:
        return "backward"
    raise ValueError(
        f"Unexpected causal order {raw_order!r}; expected [0, 1] or [1, 0]."
    )



def main():
    """
    Command-line entry point: run every configuration and log one CSV row per run.

    For each known function name in `--functions`, each `k` in `--ks`, each
    name in `--distributions` and each seed in `range(--seeds)`, this:

      1. samples sources and mixes them into `X_stack`,
      2. calls `find_sigma_per_group` with a Stein estimator callable,
      3. solves for the inverse Jacobian with `j_inv_linear_system`,
      4. permutes and row-normalises it, then estimates the causal direction,
      5. writes exactly one row to the CSV.

    CSV columns: timestamp, func_name, distribution, k_envs, seed, n, d,
    direction_true, estimated_direction, shd, success, failure_reason.
    Failed runs (exception, `None` inverse Jacobian, or no valid direction)
    are logged with `shd=1`, `success=0` and a non-empty `failure_reason`.

    Raises
    ------
    ValueError
        If `--d` is not 2.
    """
    parser = argparse.ArgumentParser(description="Automate experiments and log SHD.")
    parser.add_argument("--seeds", type=int, default=50, help="Number of seeds per configuration.")
    parser.add_argument("--n", type=int, default=2000, help="Samples per environment.")
    parser.add_argument("--d", type=int, default=2, help="Dimensionality (must be 2 for these experiments).")
    parser.add_argument("--direction", type=str, default="forward", choices=["forward", "backward"],
                        help="Ground-truth direction used by the mixing function.")
    parser.add_argument("--ks", type=int, nargs="+", default=[3, 6, 9], help="Environment counts to test.")
    parser.add_argument("--functions", type=str, nargs="*", default=list(FUNCTIONS.keys()),
                        help="Subset of function names to run; default = all.")
    parser.add_argument("--distributions", type=str, nargs="+",
                        default=["gaussian", "gamma_a1_theta2", "gamma_a2_theta2"],
                        help="Which noise families to run.")
    parser.add_argument("--out_csv", type=str, default="logs/experiments.csv", help="CSV log file path.")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite CSV instead of appending.")
    parser.add_argument("--stein_eta_g", type=float, default=1e-2)
    parser.add_argument("--stein_eta_h", type=float, default=1e-2)
    args = parser.parse_args()

    if args.d != 2:
        raise ValueError("This pipeline expects d=2 (bivariate). Set --d 2.")

    # os.path.dirname() is "" for a bare filename such as "results.csv", and
    # os.makedirs("") raises FileNotFoundError, so only create a real directory.
    out_dir = os.path.dirname(args.out_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    mode = "w" if args.overwrite else "a"
    # Also write the header when appending to an existing but empty file.
    write_header = (
        args.overwrite
        or not os.path.exists(args.out_csv)
        or os.path.getsize(args.out_csv) == 0
    )

    # Stein estimator with the regularisation strengths taken from the CLI.
    def stein_fn(x):
        return stein_score_hess_all(x, eta_G=args.stein_eta_g, eta_H=args.stein_eta_h)

    # Unknown function names are skipped below, so leave them out of the
    # progress denominator.
    n_known_functions = sum(1 for name in args.functions if name in FUNCTIONS)
    total_runs = n_known_functions * len(args.ks) * len(args.distributions) * args.seeds

    # ---------- PRINT: high-level run summary ----------
    print("========== Experiment Runner ==========")
    print(f"Output CSV          : {args.out_csv}")
    print(f"d                   : {args.d}")
    print(f"n per environment   : {args.n}")
    print(f"ground-truth dir    : {args.direction}")
    print(f"functions           : {args.functions}")
    print(f"k values            : {args.ks}")
    print(f"distributions       : {args.distributions}")
    print(f"seeds per config    : {args.seeds}")
    print(f"Total runs          : {total_runs}")
    print("=======================================\n")

    run_idx = 0
    t0_all = time.time()

    with open(args.out_csv, mode, newline="") as out_file:
        writer = csv.writer(out_file)
        if write_header:
            writer.writerow([
                "timestamp", "func_name", "distribution", "k_envs", "seed",
                "n", "d", "direction_true", "estimated_direction",
                "shd", "success", "failure_reason"
            ])

        for func_name in args.functions:
            if func_name not in FUNCTIONS:
                print(f"[WARN] Skipping unknown func '{func_name}'.")
                continue

            f_lambda = FUNCTIONS[func_name]
            print(f"\n--- Function: {func_name} ---")

            for k in args.ks:
                print(f"  > k={k} environments")

                for dist in args.distributions:
                    print(f"    · distribution={dist}")

                    for seed in range(args.seeds):
                        run_idx += 1
                        t_start = time.time()
                        rng = np.random.default_rng(seed)

                        # Outcome of this run; filled in by the pipeline below.
                        est_dir = ""
                        failure_reason = ""
                        fail_status = ""

                        # Only the numerical pipeline lives inside the try:
                        # logging happens afterwards, so a run can never
                        # produce more than one CSV row.
                        try:
                            # 1) sample sources + mixing
                            S_stack, _ = _sample_env_sources(rng, args.n, args.d, k, dist)
                            X_stack = _make_X_stack(S_stack, f_lambda, args.direction)

                            # 2) Σ̂ components via Stein (use mid-split)
                            first_group = floor(k / 2)
                            Sigma_hat_per_group, _ = find_sigma_per_group(
                                X_stack, stein_fn, first_group_index=first_group, verbose=False
                            )
                            H_diff_1, H_diff_2 = Sigma_hat_per_group

                            # 3) Solve linear system for J^{-1}
                            J_inv_hat = j_inv_linear_system(H_diff_1, H_diff_2, tol=1e-6)
                            if J_inv_hat is None:
                                fail_status = "FAIL (J_inv_hat=None)"
                                failure_reason = "j_inv_linear_system returned None"
                            else:
                                # 4) Hungarian perm + scaling (mirror notebook):
                                #    reorder rows to minimise sum(1/|entry|) on the
                                #    diagonal, then divide each row by its diagonal.
                                _, col_index = linear_sum_assignment(1.0 / np.abs(J_inv_hat))
                                PW_ica = np.zeros_like(J_inv_hat)
                                PW_ica[col_index] = J_inv_hat
                                D = np.diag(PW_ica)[:, np.newaxis]
                                J_inv_hat = PW_ica / D

                                # 5) Causal order
                                direction = _estimate_order_from_J(J_inv_hat)
                                if direction in ("forward", "backward"):
                                    est_dir = direction
                                else:
                                    fail_status = "FAIL (no causal direction)"
                                    failure_reason = f"unexpected causal direction: {direction!r}"

                        except Exception as ex:
                            # Deliberately broad: one failing configuration must
                            # not abort the whole sweep; the error is logged.
                            fail_status = f"EXC {type(ex).__name__}: {ex}"
                            failure_reason = repr(ex)

                        success = 0 if failure_reason else 1
                        shd = 0 if (success and est_dir == args.direction) else 1

                        writer.writerow([time.time(), func_name, dist, k, seed,
                                         args.n, args.d, args.direction, est_dir,
                                         shd, success, failure_reason])
                        # Individual runs can be slow; flush so completed rows
                        # survive an interrupted sweep.
                        out_file.flush()

                        if success:
                            status = f"OK  dir={est_dir:<8} SHD={shd}"
                        else:
                            status = fail_status
                        dt = time.time() - t_start
                        print(f"      [seed={seed:03d}] {status}  "
                              f"[{run_idx}/{total_runs}]  {dt:.2f}s")

    dt_all = time.time() - t0_all
    print(f"\n✓ Done. Logs written to {args.out_csv}")
    print(f"Total elapsed time: {dt_all:.2f}s")
    print("Tip: run `python plot_results.py --in_csv logs/experiments.csv` to produce the figures.")


# Run the experiment sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()

