"""
MC-COCO Experiment Runner

Implements all 6 experiment blocks from EXPERIMENT_PLAN.md.
Results saved to JSON for later analysis and plotting.
"""

import numpy as np
import json
import os
import time
import argparse
from pathlib import Path

import sys
sys.path.insert(0, str(Path(__file__).parent))
from algorithms import (
    MC1_ConstrainedExperts, MC3_SmoothConvex, MC1_Heterogeneous,
    NaiveIndependent, AdversaryExpert, AdversarySmooth,
    theoretical_ccv_bound_ce, theoretical_regret_bound_ce,
    theoretical_ccv_bound_smooth, theoretical_regret_bound_smooth,
)

# Default adversary type for the expert setting; used as the default
# `adversary_type` argument of the MC-1 and naive-baseline runners.
DEFAULT_ADV = 'conflicting'


# Default output directory: "<parent of this file's directory>/results".
# NOTE: the directory is created as a side effect of importing this module.
RESULTS_DIR = Path(__file__).parent.parent / "results"
RESULTS_DIR.mkdir(exist_ok=True)


def run_mc1_experiment(N: int, K: int, T: int, beta: float,
                       adversary_type: str = DEFAULT_ADV, seed: int = 42,
                       num_seeds: int = 5) -> dict:
    """Evaluate the MC-1 (Constrained Experts) learner over several seeds.

    For each of ``num_seeds`` runs a fresh ``AdversaryExpert`` (seeded with
    ``seed + i``) and a fresh ``MC1_ConstrainedExperts`` are created and
    played against each other for ``T`` rounds. Regret is measured against
    the cumulative cost of expert 0, which is the expert passed to the
    adversary as ``feasible_expert``.

    Args:
        N, K, T: Problem sizes forwarded to the adversary and the learner;
            ``T`` is also the number of rounds played.
        beta: Trade-off exponent forwarded to the learner and to the
            theoretical-bound helpers.
        adversary_type: Adversary variant name passed to ``AdversaryExpert``.
        seed: Base seed; run ``i`` uses ``seed + i``.
        num_seeds: Number of independent runs to aggregate.

    Returns:
        Dict holding the configuration, mean/std (``np.std`` default, i.e.
        population std) of regret, max CCV and total CCV across seeds, the
        theoretical regret / per-constraint CCV bounds, and the per-seed
        records under ``'raw_results'``.
    """
    runs = []

    for offset in range(num_seeds):
        run_seed = seed + offset
        adversary = AdversaryExpert(N, K, feasible_expert=0,
                                    adversary_type=adversary_type,
                                    seed=run_seed)
        learner = MC1_ConstrainedExperts(N, K, T, beta=beta)

        # Cumulative cost of the comparator (expert 0).
        comparator_cost = 0.0
        for _ in range(T):
            costs, violations = adversary.generate()
            learner.step(costs, violations)
            comparator_cost += costs[0]

        regret = sum(learner.cost_history) - comparator_cost
        per_constraint = learner.get_per_constraint_ccv()
        worst_ccv = learner.get_max_ccv()
        summed_ccv = learner.get_total_ccv()

        runs.append({
            'regret': float(regret),
            'max_ccv': float(worst_ccv),
            'total_ccv': float(summed_ccv),
            'per_ccv': per_constraint.tolist(),
            'seed': run_seed,
        })

    def _mean_std(key):
        # Mean and std of one per-seed metric across all runs.
        values = [run[key] for run in runs]
        return float(np.mean(values)), float(np.std(values))

    regret_mean, regret_std = _mean_std('regret')
    max_ccv_mean, max_ccv_std = _mean_std('max_ccv')
    total_ccv_mean, total_ccv_std = _mean_std('total_ccv')

    # Theoretical bounds for the same configuration.
    bound_regret = theoretical_regret_bound_ce(T, K, N, beta)
    bound_ccv = theoretical_ccv_bound_ce(T, K, N, beta)

    return {
        'algorithm': 'MC-1',
        'N': N, 'K': K, 'T': T, 'beta': beta,
        'adversary_type': adversary_type,
        'regret_mean': regret_mean,
        'regret_std': regret_std,
        'max_ccv_mean': max_ccv_mean,
        'max_ccv_std': max_ccv_std,
        'total_ccv_mean': total_ccv_mean,
        'total_ccv_std': total_ccv_std,
        'theoretical_regret': float(bound_regret),
        'theoretical_per_ccv': float(bound_ccv),
        'raw_results': runs,
    }


def run_naive_experiment(N: int, K: int, T: int, beta: float,
                         adversary_type: str = DEFAULT_ADV, seed: int = 42,
                         num_seeds: int = 5) -> dict:
    """Run the Naive Independent baseline with multiple seeds.

    Each run pairs a fresh ``AdversaryExpert`` (seeded with ``seed + i``)
    with a fresh ``NaiveIndependent`` learner for ``T`` rounds. Regret is
    measured against the cumulative cost of expert 0, the expert given to
    the adversary as ``feasible_expert``.

    Args:
        N, K, T: Problem sizes forwarded to the adversary and the learner;
            ``T`` is also the number of rounds played.
        beta: Trade-off exponent forwarded to the learner.
        adversary_type: Adversary variant name passed to ``AdversaryExpert``.
        seed: Base seed; run ``i`` uses ``seed + i``.
        num_seeds: Number of independent runs to aggregate.

    Returns:
        Dict with the configuration, mean/std of regret, max CCV and total
        CCV across seeds, and the per-seed records under ``'raw_results'``.
        The total-CCV aggregates mirror what the MC-1 runner reports so the
        two result dicts can be compared key-for-key.
    """
    all_results = []

    for s in range(num_seeds):
        actual_seed = seed + s
        adv = AdversaryExpert(N, K, feasible_expert=0,
                              adversary_type=adversary_type, seed=actual_seed)
        algo = NaiveIndependent(N, K, T, beta=beta)

        # Cumulative cost of the comparator (expert 0).
        cum_cost_star = 0.0
        for _ in range(T):
            f_t, g_t = adv.generate()
            algo.step(f_t, g_t)
            cum_cost_star += f_t[0]

        regret = sum(algo.cost_history) - cum_cost_star
        per_ccv = algo.get_per_constraint_ccv()
        max_ccv = algo.get_max_ccv()
        total_ccv = algo.get_total_ccv()

        all_results.append({
            'regret': float(regret),
            'max_ccv': float(max_ccv),
            'total_ccv': float(total_ccv),
            'per_ccv': per_ccv.tolist(),
            'seed': actual_seed,
        })

    regrets = [r['regret'] for r in all_results]
    max_ccvs = [r['max_ccv'] for r in all_results]
    # Previously collected per seed but never aggregated.
    total_ccvs = [r['total_ccv'] for r in all_results]

    return {
        'algorithm': 'Naive-Independent',
        'N': N, 'K': K, 'T': T, 'beta': beta,
        'adversary_type': adversary_type,
        'regret_mean': float(np.mean(regrets)),
        'regret_std': float(np.std(regrets)),
        'max_ccv_mean': float(np.mean(max_ccvs)),
        'max_ccv_std': float(np.std(max_ccvs)),
        'total_ccv_mean': float(np.mean(total_ccvs)),
        'total_ccv_std': float(np.std(total_ccvs)),
        'raw_results': all_results,
    }


def run_mc3_experiment(d: int, K: int, T: int, D: float, M: float,
                       beta: float, seed: int = 42,
                       num_seeds: int = 5) -> dict:
    """Evaluate the MC-3 (Smooth Convex OGD) learner over several seeds.

    Every run builds an ``AdversarySmooth`` (seeded with ``seed + i``) and an
    ``MC3_SmoothConvex`` learner. In each of the ``T`` rounds the learner's
    current action is fed to the adversary, the returned values/gradients
    are passed to ``step``, and the adversary's reported cost at the origin
    is accumulated as the comparator cost.

    Args:
        d, K, T: Problem sizes forwarded to the adversary and the learner;
            ``T`` is also the number of rounds played.
        D, M: Problem constants forwarded to the adversary, the learner and
            the theoretical-bound helpers.
        beta: Trade-off exponent forwarded to the learner and the bounds.
        seed: Base seed; run ``i`` uses ``seed + i``.
        num_seeds: Number of independent runs to aggregate.

    Returns:
        Dict holding the configuration, mean/std of regret, max CCV and
        total CCV across seeds, the theoretical bounds, and the per-seed
        records under ``'raw_results'``.
    """
    runs = []

    for offset in range(num_seeds):
        run_seed = seed + offset
        adversary = AdversarySmooth(d, K, D=D, M=M, seed=run_seed)
        learner = MC3_SmoothConvex(d, K, T, D=D, M=M, beta=beta)

        # Comparator: cumulative cost at the origin, as reported by the
        # adversary after each round.
        origin_cost = 0.0
        for _ in range(T):
            action = learner.get_action()
            loss, loss_grad, cons_vals, cons_grads = adversary.generate(action)
            learner.step(loss, loss_grad, cons_vals, cons_grads)
            origin_cost += adversary.get_last_f_at_origin()

        regret = sum(learner.cost_history) - origin_cost
        per_constraint = learner.get_per_constraint_ccv()
        worst_ccv = learner.get_max_ccv()
        summed_ccv = learner.get_total_ccv()

        runs.append({
            'regret': float(regret),
            'max_ccv': float(worst_ccv),
            'total_ccv': float(summed_ccv),
            'per_ccv': per_constraint.tolist(),
            'seed': run_seed,
        })

    def _mean_std(key):
        # Mean and std of one per-seed metric across all runs.
        values = [run[key] for run in runs]
        return float(np.mean(values)), float(np.std(values))

    regret_mean, regret_std = _mean_std('regret')
    max_ccv_mean, max_ccv_std = _mean_std('max_ccv')
    total_ccv_mean, total_ccv_std = _mean_std('total_ccv')

    bound_regret = theoretical_regret_bound_smooth(T, K, D, M, beta)
    bound_ccv = theoretical_ccv_bound_smooth(T, K, D, M, beta)

    return {
        'algorithm': 'MC-3',
        'd': d, 'K': K, 'T': T, 'D': D, 'M': M, 'beta': beta,
        'regret_mean': regret_mean,
        'regret_std': regret_std,
        'max_ccv_mean': max_ccv_mean,
        'max_ccv_std': max_ccv_std,
        'total_ccv_mean': total_ccv_mean,
        'total_ccv_std': total_ccv_std,
        'theoretical_regret': float(bound_regret),
        'theoretical_per_ccv': float(bound_ccv),
        'raw_results': runs,
    }


def run_hetero_experiment(N: int, K: int, T: int, beta: float,
                          alphas: list,
                          adversary_type: str = 'conflicting', seed: int = 42,
                          num_seeds: int = 5,
                          constraint_difficulty: list = None) -> dict:
    """Run MC-1 with heterogeneous constraint prioritization.

    Each run pairs a fresh ``AdversaryExpert`` (seeded with ``seed + i``)
    with a fresh ``MC1_Heterogeneous`` learner for ``T`` rounds. Regret is
    measured against the cumulative cost of expert 0, the expert given to
    the adversary as ``feasible_expert``.

    Args:
        N, K, T: Problem sizes forwarded to the adversary and the learner;
            ``T`` is also the number of rounds played.
        beta: Trade-off exponent forwarded to the learner.
        alphas: Per-constraint priority weights (list or numpy array),
            converted to an array and passed to the learner.
        adversary_type: Adversary variant name passed to ``AdversaryExpert``.
        seed: Base seed; run ``i`` uses ``seed + i``.
        num_seeds: Number of independent runs to aggregate.
        constraint_difficulty: Optional per-constraint difficulty values
            forwarded to the adversary as an array. ``None`` or an empty
            sequence means "not specified" (``None`` is forwarded).

    Returns:
        Dict with the configuration, mean/std of regret, per-constraint CCV
        means and stds across seeds, and the per-seed records under
        ``'raw_results'``.
    """
    all_results = []
    alphas_arr = np.array(alphas)

    # Explicit None/length test: plain truthiness raises ValueError for a
    # multi-element numpy array.
    if constraint_difficulty is not None and len(constraint_difficulty) > 0:
        cd = np.array(constraint_difficulty)
    else:
        cd = None

    for s in range(num_seeds):
        actual_seed = seed + s
        adv = AdversaryExpert(N, K, feasible_expert=0,
                              adversary_type=adversary_type, seed=actual_seed,
                              constraint_difficulty=cd)
        algo = MC1_Heterogeneous(N, K, T, beta=beta, alphas=alphas_arr)

        # Cumulative cost of the comparator (expert 0).
        cum_cost_star = 0.0
        for _ in range(T):
            f_t, g_t = adv.generate()
            algo.step(f_t, g_t)
            cum_cost_star += f_t[0]

        regret = sum(algo.cost_history) - cum_cost_star
        per_ccv = algo.get_per_constraint_ccv()

        all_results.append({
            'regret': float(regret),
            'per_ccv': per_ccv.tolist(),
            'seed': actual_seed,
        })

    regrets = [r['regret'] for r in all_results]
    per_ccvs = [r['per_ccv'] for r in all_results]

    return {
        'algorithm': 'MC-1-Hetero',
        'N': N, 'K': K, 'T': T, 'beta': beta,
        # Keep the result JSON-serializable when alphas arrives as an array.
        'alphas': alphas.tolist() if isinstance(alphas, np.ndarray) else alphas,
        'adversary_type': adversary_type,
        'regret_mean': float(np.mean(regrets)),
        'regret_std': float(np.std(regrets)),
        'per_ccv_means': np.mean(per_ccvs, axis=0).tolist(),
        'per_ccv_stds': np.std(per_ccvs, axis=0).tolist(),
        'raw_results': all_results,
    }


# ============================================================
# Experiment Blocks
# ============================================================

def block1_sanity(num_seeds: int = 5) -> list:
    """Block 1: sanity check for MC-1.

    Runs MC-1 with N=50, K=5, T=10000 for three beta values, prints regret
    and max CCV next to their theoretical bounds, and marks each run
    "PASS" when the measured max CCV is below 3x the theoretical
    per-constraint CCV bound, otherwise "FAIL".

    Args:
        num_seeds: Number of seeds forwarded to ``run_mc1_experiment``.

    Returns:
        List of result dicts, each extended with ``'status'`` and ``'time'``.
    """
    rule = "=" * 60
    print(rule)
    print("Block 1: Sanity Check — MC-CE Algorithm")
    print(rule)

    configs = [{'N': 50, 'K': 5, 'T': 10000, 'beta': b}
               for b in (0.5, 0.7, 1.0)]

    outcomes = []
    for run_no, cfg in enumerate(configs, start=1):
        print(f"\n  Run 1.{run_no}: N={cfg['N']}, K={cfg['K']}, T={cfg['T']}, β={cfg['beta']}")
        started = time.time()
        res = run_mc1_experiment(**cfg, num_seeds=num_seeds)
        elapsed = time.time() - started

        regret_line = (f"    Regret: {res['regret_mean']:.1f} ± {res['regret_std']:.1f}"
                       f"  (theory: {res['theoretical_regret']:.1f})")
        ccv_line = (f"    Max CCV: {res['max_ccv_mean']:.1f} ± {res['max_ccv_std']:.1f}"
                    f"  (theory: {res['theoretical_per_ccv']:.1f})")
        print(regret_line)
        print(ccv_line)
        print(f"    Time: {elapsed:.1f}s")

        # Empirical CCV relative to the theoretical bound.
        ratio = res['max_ccv_mean'] / res['theoretical_per_ccv']
        if ratio < 3.0:
            status = "PASS"
        else:
            status = "FAIL"
        print(f"    CCV/Theory ratio: {ratio:.3f} — {status}")

        res['status'] = status
        res['time'] = elapsed
        outcomes.append(res)

    return outcomes


def block2_k_dependence(num_seeds: int = 5) -> list:
    """Block 2: K-dependence, logarithmic vs linear.

    For beta in (0.9, 0.5) and K in (2, 5, 10, 20, 50, 100), runs both MC-1
    and the naive independent baseline with N=50, T=10000, and prints their
    max CCV together with the relative advantage of MC-1.

    Design note kept from the original: beta=0.9 (large λ) is used so that
    the Lyapunov mechanism is active and the MC-1 vs Naive difference
    becomes visible; beta=0.5 is included for comparison.

    Args:
        num_seeds: Number of seeds forwarded to both runners.

    Returns:
        List of dicts with keys ``'K'``, ``'beta'``, ``'mc1'``, ``'naive'``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Block 2: K-Dependence — Logarithmic vs Linear")
    print(rule)

    constraint_counts = [2, 5, 10, 20, 50, 100]
    shared = {'N': 50, 'T': 10000, 'num_seeds': num_seeds}
    comparisons = []

    for beta in (0.9, 0.5):
        print(f"\n  --- β = {beta} ---")
        for K in constraint_counts:
            print(f"\n  K={K}:")
            started = time.time()

            ours = run_mc1_experiment(K=K, beta=beta, **shared)
            baseline = run_naive_experiment(K=K, beta=beta, **shared)

            elapsed = time.time() - started

            ours_ccv = ours['max_ccv_mean']
            base_ccv = baseline['max_ccv_mean']
            print(f"    MC-1 Max CCV: {ours_ccv:.1f} ± {ours['max_ccv_std']:.1f}")
            print(f"    Naive Max CCV: {base_ccv:.1f} ± {baseline['max_ccv_std']:.1f}")
            print(f"    Theory CCV: {ours['theoretical_per_ccv']:.1f}")

            # Relative gap in percent; the denominator is floored at 1.
            gap = base_ccv - ours_ccv
            diff_pct = 100 * gap / max(ours_ccv, 1)
            print(f"    MC-1 advantage: {diff_pct:.2f}%")
            print(f"    Time: {elapsed:.1f}s")

            comparisons.append({
                'K': K,
                'beta': beta,
                'mc1': ours,
                'naive': baseline,
            })

    return comparisons


def block3_t_scaling(num_seeds: int = 5) -> list:
    """Block 3: T-scaling, rate verification.

    Runs MC-1 with N=50, K=10 for beta in (0.5, 0.7, 0.9) over a grid of
    horizons T, and prints regret, max CCV and two diagnostic ratios
    (CCV/T and CCV/T^(1-β)) for each run.

    Design notes kept from the original: β=0.9 was added to observe clear
    sub-linear CCV growth; at β=0.5 the CCV is expected to be roughly linear
    in T (≈ 0.7T) because λ is too small, while at β=0.9 it should grow as
    T^{0.1} · polylog(T).

    Args:
        num_seeds: Number of seeds forwarded to ``run_mc1_experiment``.

    Returns:
        List of result dicts, each extended with ``'T_actual'`` and
        ``'beta_actual'``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Block 3: T-Scaling — Rate Verification")
    print(rule)

    horizons = [100, 500, 1000, 5000, 10000, 30000]
    collected = []

    for beta in (0.5, 0.7, 0.9):
        print(f"\n  β = {beta}:")
        for T in horizons:
            started = time.time()
            res = run_mc1_experiment(N=50, K=10, T=T, beta=beta, num_seeds=num_seeds)
            elapsed = time.time() - started

            # Diagnostic ratios (zero when T is not positive).
            ccv = res['max_ccv_mean']
            if T > 0:
                ccv_over_T = ccv / T
                ccv_over_T_1mb = ccv / (T**(1-beta))
            else:
                ccv_over_T = 0
                ccv_over_T_1mb = 0

            summary = (f"    T={T:6d}: Regret={res['regret_mean']:8.1f}  "
                       f"Max CCV={ccv:8.1f}  "
                       f"CCV/T={ccv_over_T:.4f}  "
                       f"CCV/T^(1-β)={ccv_over_T_1mb:.1f}  "
                       f"(theory V={res['theoretical_per_ccv']:.0f})  "
                       f"[{elapsed:.1f}s]")
            print(summary)

            res['T_actual'] = T
            res['beta_actual'] = beta
            collected.append(res)

    return collected


def block4_smooth_convex(num_seeds: int = 5) -> list:
    """Block 4: smooth convex setting with the MC-3 algorithm.

    First runs four fixed configurations (d=10, D=M=1.0, T=10000 with
    varying K and beta), then a T-scaling sweep at d=10, K=10, beta=0.5.
    Results of both parts are appended to the same list, fixed
    configurations first.

    Args:
        num_seeds: Number of seeds forwarded to ``run_mc3_experiment``.

    Returns:
        List of result dicts, each extended with ``'time'``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Block 4: Smooth Convex — MC-3 Algorithm")
    print(rule)

    collected = []

    # Part 1: fixed configurations, given as (K, beta) pairs.
    base = {'d': 10, 'T': 10000, 'D': 1.0, 'M': 1.0}
    fixed_runs = [
        {**base, 'K': K, 'beta': beta}
        for K, beta in ((5, 0.5), (20, 0.5), (5, 0.7), (5, 1.0))
    ]

    for run_no, cfg in enumerate(fixed_runs, start=1):
        print(f"\n  Run 4.{run_no}: d={cfg['d']}, K={cfg['K']}, T={cfg['T']}, β={cfg['beta']}")
        started = time.time()
        res = run_mc3_experiment(**cfg, num_seeds=num_seeds)
        elapsed = time.time() - started
        print(f"    Regret: {res['regret_mean']:.1f} ± {res['regret_std']:.1f}")
        print(f"    Max CCV: {res['max_ccv_mean']:.1f} ± {res['max_ccv_std']:.1f}")
        print(f"    Theory: R={res['theoretical_regret']:.0f}, V={res['theoretical_per_ccv']:.0f}")
        print(f"    Time: {elapsed:.1f}s")
        res['time'] = elapsed
        collected.append(res)

    # Part 2: T-scaling sweep.
    print("\n  T-scaling for MC-3:")
    for T in (100, 500, 1000, 5000, 10000, 30000):
        started = time.time()
        res = run_mc3_experiment(d=10, K=10, T=T, D=1.0, M=1.0, beta=0.5,
                                 num_seeds=num_seeds)
        elapsed = time.time() - started
        line = (f"    T={T:6d}: Regret={res['regret_mean']:8.1f}  "
                f"Max CCV={res['max_ccv_mean']:8.1f}  [{elapsed:.1f}s]")
        print(line)
        res['time'] = elapsed
        collected.append(res)

    return collected


def block5_heterogeneous(num_seeds: int = 5) -> list:
    """Block 5: heterogeneous constraint prioritization.

    Runs MC-1-Hetero with N=50, T=10000 for beta in (0.9, 1.0) and three
    alpha configurations, printing per-constraint CCV and, for non-uniform
    configurations, the ratio CCV_k/CCV_0 next to the expected 1/α_k.

    Design notes kept from the original: the CONFLICTING adversary is used
    (not asymmetric with equal difficulty). In the conflicting adversary,
    every non-feasible expert violates most constraints heavily (~0.75), so
    the algorithm must trade off between constraints. With larger α_k,
    constraint k gets a stronger Lyapunov penalty (λ_k = α_k · Λ), so it
    should achieve lower CCV.

    Theory: CCV_k = O(T^{1-β} ln N / α_k),
    so CCV_k / CCV_j should ≈ α_j / α_k.

    The number of constraints K is taken from the length of each alpha
    vector rather than hard-coded, so alpha configurations of any length
    can be added to the table below.

    Args:
        num_seeds: Number of seeds forwarded to ``run_hetero_experiment``.

    Returns:
        List of result dicts, each extended with ``'config_name'``,
        ``'beta'`` and ``'time'``.
    """
    print("\n" + "=" * 60)
    print("Block 5: Heterogeneous Constraint Prioritization")
    print("=" * 60)

    results = []

    alpha_configs = [
        {'name': 'uniform', 'alphas': [1.0, 1.0, 1.0, 1.0, 1.0]},
        {'name': 'geometric', 'alphas': [1.0, 0.5, 0.25, 0.125, 0.0625]},
        {'name': 'one_critical', 'alphas': [1.0, 1.0, 1.0, 1.0, 0.01]},
    ]

    def _inverse(alpha):
        # 1/α with a guard so a zero weight prints "inf" instead of raising.
        return 1.0 / alpha if alpha != 0 else float('inf')

    for beta_val in [0.9, 1.0]:
        print(f"\n  --- β = {beta_val} ---")
        for cfg in alpha_configs:
            alphas = cfg['alphas']
            print(f"\n  Config '{cfg['name']}': α = {alphas}")
            t0 = time.time()
            # Use CONFLICTING adversary: constraints compete with each other.
            # Only this adversary creates inter-constraint tension where
            # alpha_k weighting can influence which constraints get protected.
            res = run_hetero_experiment(
                N=50, K=len(alphas), T=10000, beta=beta_val,
                alphas=alphas, num_seeds=num_seeds,
                adversary_type='conflicting',
                constraint_difficulty=None)
            dt = time.time() - t0

            per_ccv = res['per_ccv_means']
            per_ccv_std = res['per_ccv_stds']
            print(f"    Regret: {res['regret_mean']:.1f}")
            for k, (ccv, std, alpha) in enumerate(zip(per_ccv, per_ccv_std, alphas)):
                print(f"    CCV_{k}: {ccv:.1f} ± {std:.1f}"
                      f"  (α_k={alpha:.4f}, 1/α_k={_inverse(alpha):.1f})")

            # Compute ratios to show the α_k effect.
            if cfg['name'] != 'uniform':
                ref_ccv = per_ccv[0]  # CCV for α_0=1.0
                print("    CCV ratios (CCV_k/CCV_0):")
                for k, (ccv, alpha) in enumerate(zip(per_ccv, alphas)):
                    ratio = ccv / ref_ccv if ref_ccv > 0 else 0
                    expected = _inverse(alpha)  # Expected: α_0/α_k = 1/α_k
                    print(f"      k={k}: actual={ratio:.3f}, expected(1/α_k)={expected:.1f}")
            print(f"    Time: {dt:.1f}s")

            res['config_name'] = cfg['name']
            res['beta'] = beta_val
            res['time'] = dt
            results.append(res)

    return results


def block6_tradeoff(num_seeds: int = 5) -> list:
    """Block 6: Regret × CCV trade-off.

    Runs MC-1 with N=50, T=10000 in two sweeps: varying beta at K=10, then
    varying K at beta=1.0. For each run the product of |mean regret| and
    mean total CCV is printed and stored, both raw and normalized by K*T.

    Args:
        num_seeds: Number of seeds forwarded to ``run_mc1_experiment``.

    Returns:
        List of result dicts (beta sweep first, then K sweep), each
        extended with ``'product'`` and ``'product_normalized'``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Block 6: Regret × CCV Trade-off")
    print(rule)

    horizon = 10000
    collected = []

    def _report(res, K, label):
        # Print and store |regret| x total CCV for one run, then keep it.
        product = abs(res['regret_mean']) * res['total_ccv_mean']
        KT = K * horizon
        print(f"    {label}: Regret={res['regret_mean']:.1f}, "
              f"TotalCCV={res['total_ccv_mean']:.1f}, "
              f"|R|×V={product:.0f}, |R|×V/(KT)={product/KT:.4f}")
        res['product'] = float(product)
        res['product_normalized'] = float(product / KT)
        collected.append(res)

    # Sweep over beta at fixed K.
    print("\n  Varying β with K=10, T=10000:")
    for beta in (0.3, 0.5, 0.7, 0.9, 1.0):
        outcome = run_mc1_experiment(N=50, K=10, T=horizon, beta=beta,
                                     num_seeds=num_seeds)
        _report(outcome, 10, f"β={beta:.1f}")

    # Sweep over K at beta=1.
    print("\n  Varying K at β=1.0, T=10000:")
    for K in (2, 5, 10, 20, 50):
        outcome = run_mc1_experiment(N=50, K=K, T=horizon, beta=1.0,
                                     num_seeds=num_seeds)
        _report(outcome, K, f"K={K:3d}")

    return collected


# ============================================================
# Main
# ============================================================

def main():
    """Command-line entry point: run the selected blocks and save JSON.

    Parses ``--blocks``, ``--num-seeds`` and ``--output-dir``, runs the
    requested experiment blocks in numeric order, prints the total wall
    time, and writes all results to ``experiment_results.json`` inside the
    output directory (created if missing).

    Invalid input is rejected up front by argparse (exit status 2): block
    numbers outside 1-6 and a seed count below 1 both abort before any
    experiment starts, instead of silently producing empty or NaN results.
    """
    # Block number -> (result key, runner), in execution order.
    blocks = {
        1: ('block1_sanity', block1_sanity),
        2: ('block2_k_dependence', block2_k_dependence),
        3: ('block3_t_scaling', block3_t_scaling),
        4: ('block4_smooth_convex', block4_smooth_convex),
        5: ('block5_heterogeneous', block5_heterogeneous),
        6: ('block6_tradeoff', block6_tradeoff),
    }

    parser = argparse.ArgumentParser(description='MC-COCO Experiments')
    parser.add_argument('--blocks', nargs='+', type=int, default=[1, 2, 3, 4, 5, 6],
                        choices=sorted(blocks), metavar='BLOCK',
                        help='Which experiment blocks to run (1-6)')
    parser.add_argument('--num-seeds', type=int, default=5,
                        help='Number of random seeds per experiment')
    parser.add_argument('--output-dir', type=str, default=str(RESULTS_DIR),
                        help='Output directory for results')
    args = parser.parse_args()

    if args.num_seeds < 1:
        parser.error('--num-seeds must be at least 1')

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    all_results = {}
    t_total = time.time()

    requested = set(args.blocks)
    for number, (key, runner) in blocks.items():
        if number in requested:
            all_results[key] = runner(args.num_seeds)

    dt_total = time.time() - t_total
    print(f"\n{'='*60}")
    print(f"Total time: {dt_total:.1f}s")

    # Save results. NOTE: default=str stringifies any value json cannot
    # encode natively (e.g. stray numpy objects) rather than failing.
    output_file = output_dir / 'experiment_results.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f"Results saved to {output_file}")

# Run the experiment CLI only when this file is executed as a script.
if __name__ == '__main__':
    main()
