"""
Speedup ablation study: Comparing CLARITree vs CLARITreeFull performance.

This script benchmarks the speedup achieved by CLARITree's rank-one updates
compared to CLARITreeFull's full recomputation approach.
"""
import gc
import matplotlib.pyplot as plt

import time
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler
from clari_tree import (
    CLARITree, CLARITreeFull,
)


def make_multibeta_linear(
    n=5000, p=10, n_groups=3, rho=0.5, noise=0.2, seed=0
):
    """
    Simulate a piecewise-linear regression problem with `n_groups` regimes.

    Model:
        X ~ N(0, Σ),   Σ_ij = rho^|i-j|
        y_i = X_i^T β_{g_i} + ε_i,   ε_i ~ N(0, noise²)

    The regime label g_i depends on the first feature only: X[:, 0] is cut
    at its empirical k/n_groups quantiles (k = 1 .. n_groups-1), which gives
    `n_groups` stripes holding roughly the same number of samples each.

    Parameters:
        n         Number of samples
        p         Number of features
        n_groups  Number of distinct beta groups
        rho       Feature correlation coefficient
        noise     Standard deviation of the noise
        seed      Random seed (coefficients use a generator seeded `seed + 7`)

    Returns:
        X      (n, p) feature matrix
        y      (n,) target vector
        g      (n,) integer group label of every sample
        betas  list with one length-p coefficient vector per group
    """
    sample_rng = np.random.default_rng(seed)

    # --- Features: white noise coloured by the Cholesky factor of Σ ---
    positions = np.arange(p)
    cov = rho ** np.abs(positions[:, None] - positions[None, :])
    # The tiny diagonal term guards the factorisation against round-off.
    chol = np.linalg.cholesky(cov + 1e-12 * np.eye(p))
    X = sample_rng.standard_normal((n, p)) @ chol.T

    # --- Group labels: stripes along X0 delimited by interior quantiles ---
    first_feature = X[:, 0]
    interior_probs = np.linspace(0, 1, n_groups + 1)[1:-1]
    cuts = np.quantile(first_feature, interior_probs)
    g = np.digitize(first_feature, cuts)

    # --- Coefficients: separate generator, so they do not depend on n ---
    coef_rng = np.random.default_rng(seed + 7)
    # Later features get exponentially smaller weights; later groups get
    # a larger overall scale.
    decay = np.exp(-np.linspace(0, 2, p))
    betas = [
        coef_rng.normal(0, 1, p) * ((1.0 + 0.5 * j) * decay)
        for j in range(n_groups)
    ]

    # --- Targets: group-specific linear signal plus Gaussian noise ---
    f = np.zeros(n)
    for j, beta in enumerate(betas):
        members = g == j
        f[members] = X[members] @ beta
    y = f + noise * sample_rng.standard_normal(n)

    return X, y, g, betas


def time_fit(model_cls, X, y, **kwargs):
    """
    Measure how long it takes to build and fit one model.

    Parameters:
        model_cls  Model class; instantiated as ``model_cls(**kwargs)``
        X          First argument handed to ``fit``
        y          Second argument handed to ``fit``
        **kwargs   Keyword arguments forwarded to the constructor

    Returns:
        Elapsed seconds (float) spanning construction and ``fit``; no
        prediction is performed.
    """
    # Collect garbage before starting the clock so that leftovers from
    # earlier runs are less likely to be collected inside the timed region.
    gc.collect()
    # perf_counter() is monotonic and high-resolution; time.time() follows
    # the system clock, which can be adjusted mid-run and may be too coarse
    # for short fits.
    start = time.perf_counter()
    model_cls(**kwargs).fit(X, y)
    return time.perf_counter() - start


def _benchmark_single_n(args):
    """
    Time CLARITree and CLARITreeFull on one synthetic dataset of size n.

    Takes a single tuple so the function can be submitted as-is to a
    process pool. The tuple is, in order:
        (n, p, depth, kappa, lambda_, stride, repeats, rho, noise, seed)
    A `stride` of None is replaced by max(1, n // 20).

    Returns a dict with the sample size ("n"), the mean fit time of each
    model over `repeats` runs ("t_clari_tree_s", "t_clari_tree_full_s") and
    their ratio ("speedup_full_over_clari", NaN when the CLARITree mean is
    not positive).
    """
    n, p, depth, kappa, lambda_, stride, repeats, rho, noise, seed = args

    # One dataset shared by both models, so neither is favoured by sampling.
    X, y, _labels, _coefs = make_multibeta_linear(
        n=n, p=p, n_groups=4, rho=rho, noise=noise, seed=seed
    )

    # Scale the feature columns first; the intercept column is prepended
    # afterwards so that it stays a column of ones (the models are expected
    # to find the intercept in column 0).
    X = StandardScaler().fit_transform(X)
    intercept = np.ones((X.shape[0], 1))
    X = np.concatenate([intercept, X], axis=1)

    # Remove the mean of the response.
    y = y - np.mean(y)

    # Default stride: n // 20, never below 1.
    if stride is None:
        stride = max(1, n // 20)

    model_kwargs = {
        "depth": depth,
        "kappa": kappa,
        "lambda_": lambda_,
        "stride": stride,
        "verbose": False,
    }

    # Alternate the two models inside each repeat and average at the end.
    clari_times, full_times = [], []
    for _ in range(repeats):
        clari_times.append(time_fit(CLARITree, X, y, **model_kwargs))
        full_times.append(time_fit(CLARITreeFull, X, y, **model_kwargs))

    clari_avg = float(np.mean(clari_times))
    full_avg = float(np.mean(full_times))
    if clari_avg > 0:
        speedup = full_avg / clari_avg
    else:
        speedup = np.nan

    return {
        "n": n,
        "t_clari_tree_s": clari_avg,
        "t_clari_tree_full_s": full_avg,
        "speedup_full_over_clari": speedup,
    }

def _print_benchmark_row(result):
    """Print the one-line timing summary for a single benchmarked n."""
    n_val = result["n"]
    speedup = result["speedup_full_over_clari"]
    t_clari = result["t_clari_tree_s"]
    t_full = result["t_clari_tree_full_s"]
    print(f"  n={n_val:5d}: CLARITree={t_clari:.4f}s, CLARITreeFull={t_full:.4f}s, Speedup={speedup:.2f}x")


def benchmark_speedup(
    n_list=(100, 200, 400, 1600, 3200),
    p=10,
    depth=5,
    kappa=1e-3,
    lambda_=0.0,
    stride=None,  # None means auto-calculate as n // 20
    repeats=3,
    rho=0.8,
    noise=0.0,
    seed=0,
    parallel=True,
):
    """
    Benchmark speedup for a list of n.

    Parameters:
        n_list    Sample sizes to benchmark; may be empty
        p         Number of features of the synthetic data
        depth     Forwarded to both model constructors
        kappa     Forwarded to both model constructors
        lambda_   Forwarded to both model constructors
        stride    Forwarded to both model constructors; None means it is
                  computed per n as max(1, n // 20)
        repeats   Number of timed fits per model and per n (averaged)
        rho       Feature correlation of the synthetic data
        noise     Noise standard deviation of the synthetic data
        seed      Random seed of the synthetic data
        parallel  If True, every n runs in its own worker process; if
                  False, they run one after another in this process

    Returns:
        DataFrame sorted by "n" with columns "n", "t_clari_tree_s",
        "t_clari_tree_full_s" and "speedup_full_over_clari". An empty
        `n_list` yields an empty frame with those columns.

    A tqdm bar shows overall progress and each result is printed as soon as
    it is available (in completion order when parallel=True).

    NOTE: with parallel=True several fits are timed at the same moment in
    different processes, so the measurements may be affected by contention
    between workers; parallel=False is the less noisy choice.
    """
    args_list = [
        (n, p, depth, kappa, lambda_, stride, repeats, rho, noise, seed)
        for n in n_list
    ]

    # Nothing to benchmark: return a correctly-shaped empty table instead of
    # failing later in sort_values("n") on a frame that has no columns.
    if not args_list:
        return pd.DataFrame(
            columns=[
                "n",
                "t_clari_tree_s",
                "t_clari_tree_full_s",
                "speedup_full_over_clari",
            ]
        )

    rows = []

    if parallel:
        # Parallel execution for each n
        with ProcessPoolExecutor() as ex:
            futures = [ex.submit(_benchmark_single_n, args) for args in args_list]
            for fut in tqdm(
                as_completed(futures),
                total=len(futures),
                desc="Running benchmark over n (parallel)",
                unit="n",
            ):
                # result() re-raises any exception from the worker.
                result = fut.result()
                rows.append(result)
                _print_benchmark_row(result)
    else:
        # Serial version (still with tqdm)
        for args in tqdm(
            args_list,
            total=len(args_list),
            desc="Running benchmark over n (serial)",
            unit="n",
        ):
            result = _benchmark_single_n(args)
            rows.append(result)
            _print_benchmark_row(result)

    # Completion order is arbitrary in parallel mode, so sort by n.
    return pd.DataFrame(rows).sort_values("n").reset_index(drop=True)


if __name__ == "__main__":
    # Settings for this ablation run.
    run_config = {
        "n_list": [128],
        "p": 2,
        "depth": 4,
        "kappa": 1e-16,
        "lambda_": 0.0,
        # Fixed stride; pass None to have it computed as n // 20 for each n.
        "stride": 1,
        # 1 gives the fastest run; 5 or more gives steadier averages.
        "repeats": 5,
        "rho": 0.8,
        "noise": 0.0,
        "seed": 1,
        "parallel": True,
    }
    df_speed = benchmark_speedup(**run_config)

    # Summary table: times with four decimals, speedup with two and an "x".
    seconds_fmt = "{:.4f}".format
    column_formatters = {
        "t_clari_tree_s": seconds_fmt,
        "t_clari_tree_full_s": seconds_fmt,
        "speedup_full_over_clari": "{:.2f}x".format,
    }
    print("\n=== Speedup table (speedup = t_full / t_clari) ===")
    print(df_speed.to_string(index=False, formatters=column_formatters))

    # Figure: speedup as a function of the sample size.
    sample_sizes = df_speed["n"].values
    speedups = df_speed["speedup_full_over_clari"].values

    plt.figure()
    plt.plot(sample_sizes, speedups, marker="o")
    plt.xlabel("n (number of samples)")
    plt.ylabel("Speedup: CLARITreeFull time / CLARITree time")
    plt.title("CLARITree vs CLARITreeFull Speedup vs n")
    # A base-2 log axis spaces sample sizes that grow by doubling evenly.
    plt.xscale("log", base=2)
    plt.grid(True, which="both", linestyle="--", alpha=0.5)
    plt.tight_layout()
    plt.savefig("speedup_vs_n.png", dpi=200)
    print("Plot saved to speedup_vs_n.png")
