"""
Main experiment script for RCDP-UCB (Figure 1 & 2 Reproduction).
Runs robustness experiments comparing RCDP-UCB against various baselines 
under different delay and corruption scenarios.
"""
import numpy as np
import matplotlib.pyplot as plt

import argparse
import pandas as pd
import os
from contextual_dueling_bandit import (
    ContextualDuelingBanditEnv, 
    DuelingGLMLearner, 
    BaselineDuelingGLMLearner,
    NonRobustDuelingGLMLearner,
    RobustBaselineDuelingGLMLearner,
    OracleDuelingGLMLearner, # Kept for potential future use if needed, though removed from main loop
    RCDBLearner,
    MaxInPLearner,
    MaxPairUCBLearner,
    ColSTIMLearner,
    StrategicDelay,
    StochasticGaussianDelay,
    StrategicOutcomeCorruption,
    run_simulation
)

def _base_seed_for_mapping(phi_type):
    """Return the seed of the first run for the feature mapping *phi_type*.

    Each known mapping gets its own offset of 1000 so that different mappings
    do not share run seeds. An unknown mapping falls back to the offset of the
    first known mapping (and therefore shares its seeds); a warning is printed
    so the collision does not go unnoticed.
    """
    known_phi_types = ('polynomial', 'abs', 'cosine', 'sinusoidal')
    try:
        phi_idx = known_phi_types.index(phi_type)
    except ValueError:
        print(f"Warning: unknown phi_type '{phi_type}'; "
              f"reusing the seed offset of '{known_phi_types[0]}'")
        phi_idx = 0
    return 42 + phi_idx * 1000 + 999


def _build_env(phi_type, args, d, e, K, run_seed, theta, zeta):
    """Create a freshly seeded environment for one (run, method) pair.

    The global NumPy RNG is re-seeded with *run_seed* before anything is
    constructed, so every method of the same run starts from the same random
    state. The delay model follows ``args.delay_type`` ('strategic',
    'stochastic', anything else means no delay); outcome corruption with
    budget ``args.corruption_budget`` is attached unless the delay type is
    'clean'. The environment's ``theta_star`` / ``zeta_star`` are overwritten
    with copies of *theta* / *zeta* so all methods of a run share them.
    """
    np.random.seed(run_seed)

    if args.delay_type == 'strategic':
        # The per-round delay magnitude is the integer square root of the budget.
        inst_mag = int(np.sqrt(args.delay_mag))
        delay_model = StrategicDelay(budget=args.delay_mag, magnitude=inst_mag, threshold_val=0.0)
    elif args.delay_type == 'stochastic':
        delay_model = StochasticGaussianDelay(mean_delay=args.mean, std_delay=args.std)
    else:
        delay_model = None

    if args.delay_type != 'clean':
        corruption_model = StrategicOutcomeCorruption(budget=args.corruption_budget)
    else:
        corruption_model = None

    # The environment receives 'e' so that it generates the post-serving context Y.
    env = ContextualDuelingBanditEnv(d=d, e=e, k=K, delay_model=delay_model,
                                    corruption_model=corruption_model, phi_type=phi_type)
    env.theta_star, env.zeta_star = theta.copy(), zeta.copy()
    return env


def _learner_factories(d, e):
    """Return ``[(method_name, zero-arg factory), ...]`` in plotting order.

    Factories (rather than instances) are returned because every run needs a
    brand-new learner, and the learner must be created *after* its environment
    to keep the order of RNG consumption stable.
    """
    # Fixed worst-case parameters: the algorithms assume an unknown
    # environment with a high corruption and delay budget.
    fixed_C = 25.0
    fixed_Lambda = 10000.0
    fixed_mu_tau = 100.0
    kappa_val = 0.25

    # Tuned confidence width for the baselines that only see the d-dim context.
    beta_base_d = 1.0 * np.sqrt(d)

    # RCDB only uses the corruption level C (no delay parameters).
    rcdb_alpha = np.sqrt(d) / (np.sqrt(kappa_val) * fixed_C)
    rcdb_beta = beta_base_d + rcdb_alpha * fixed_C

    return [
        ("RCDP-UCB (Ours)",
         lambda: DuelingGLMLearner(d=d, e=e, lambda_reg=1.0, alpha=0.1,
                                   C=fixed_C, Lambda=fixed_Lambda,
                                   mu_tau=fixed_mu_tau, kappa=kappa_val)),
        ("RCDB",
         lambda: RCDBLearner(d=d, lambda_reg=1.0, alpha=0.5,
                             rcdb_alpha=rcdb_alpha, rcdb_beta=rcdb_beta,
                             kappa=kappa_val)),
        ("ColSTIM", lambda: ColSTIMLearner(d=d, lambda_reg=1.0, alpha=0.5)),
        ("MaxInP", lambda: MaxInPLearner(d=d, lambda_reg=1.0, alpha=0.5)),
        ("MaxPairUCB", lambda: MaxPairUCBLearner(d=d, lambda_reg=1.0, alpha=beta_base_d)),
    ]


def _simulate_all_methods(phi_type, args, d, e, K, T):
    """Run every method ``args.n_runs`` times and collect the regret curves.

    Returns a dict mapping method name -> array of shape ``(n_runs, T)``; the
    dict's insertion order is the order in which methods are plotted/exported.
    """
    factories = _learner_factories(d, e)
    method_regrets = {name: np.zeros((args.n_runs, T)) for name, _ in factories}
    base_seed = _base_seed_for_mapping(phi_type)

    for run_idx in range(args.n_runs):
        run_seed = base_seed + run_idx * 100
        print(f"  > Run {run_idx+1}/{args.n_runs} (Seed {run_seed})")

        # Ground-truth parameters shared by every method within this run.
        np.random.seed(run_seed)
        run_theta = np.random.randn(d) * 0.1
        run_zeta = np.random.randn(e) * 5.0  # Strong dependence on Y

        for name, make_learner in factories:
            # Order matters for reproducibility: environment first (re-seeds
            # the RNG), then the learner, then the simulation.
            env = _build_env(phi_type, args, d, e, K, run_seed, run_theta, run_zeta)
            learner = make_learner()
            method_regrets[name][run_idx] = run_simulation(env, learner, T, name=None)

    return method_regrets


def _plot_regret_curves(method_regrets, phi_type, args, d, K, T):
    """Plot mean regret per method with std error bars and save it as a PDF.

    The figure is written below ``fig1/``; the file name encodes the delay
    scenario and its parameters. Note that ``plt.rcParams`` is updated
    globally and stays modified after this function returns.
    """
    # Professional style: huge fonts and thick lines.
    plt.rcParams.update({
        'font.size': 32,            # Global font size
        'axes.titlesize': 40,       # Title size
        'axes.labelsize': 36,       # Axis label size
        'xtick.labelsize': 32,      # X-tick size
        'ytick.labelsize': 32,      # Y-tick size
        'legend.fontsize': 32,      # Legend size
        'lines.linewidth': 7.0,     # Thicker lines
        'figure.figsize': (14, 11)  # Single figure size
    })

    # Professional color palette (all solid lines).
    colors = ['#1f77b4', '#2ca02c', '#ff7f0e', '#9467bd', '#d62728']
    styles = ['-', '-', '-', '-', '-']
    alphas = [1.0, 0.9, 0.9, 0.9, 0.9]

    t_range = np.arange(T)
    # 11 equally spaced rounds (10 intervals), including the first and the
    # last round, at which error bars are drawn.
    error_indices = np.linspace(0, T-1, 11).astype(int)

    fig = plt.figure()
    try:
        for i, (name, data) in enumerate(method_regrets.items()):
            # data has shape (n_runs, T); aggregate across runs.
            mean_regret = np.mean(data, axis=0)
            std_regret = np.std(data, axis=0)

            if np.any(np.isnan(mean_regret)):
                print(f"Warning: {name} has NaN values!")
            else:
                print(f"{name} Final Regret: {mean_regret[-1]:.2f}")

            c = colors[i % len(colors)]
            s = styles[i % len(styles)]
            a = alphas[i % len(alphas)]

            # Mean curve.
            plt.plot(t_range, mean_regret, label=name,
                     linestyle=s, color=c, alpha=a)

            # Error bars (+/- one standard deviation across runs).
            plt.errorbar(t_range[error_indices], mean_regret[error_indices],
                         yerr=std_regret[error_indices],
                         fmt='none',
                         ecolor=c,
                         capsize=5,
                         elinewidth=2.0,
                         alpha=a)

        plt.xlabel('Rounds (t)')
        plt.ylabel('Cumulative Regret')

        plt.grid(True, linestyle='-', alpha=0.3)
        plt.legend(frameon=True, framealpha=0.9, edgecolor='gray')

        # Ensure fig1 directory exists
        os.makedirs("fig1", exist_ok=True)

        if args.delay_type == 'strategic':
            filename = f"fig1/fig1_strategic_robustness_{phi_type}_d{d}_K{K}_T{T}_C{args.corruption_budget}_Mag{args.delay_mag}.pdf"
        elif args.delay_type == 'stochastic':
            filename = f"fig1/fig1_stochastic_robustness_{phi_type}_d{d}_K{K}_T{T}_C{args.corruption_budget}_Mean{args.mean}_Std{args.std}.pdf"
        else:
            filename = f"fig1/fig1_clean_robustness_{phi_type}_d{d}_K{K}_T{T}.pdf"

        plt.tight_layout()
        plt.savefig(filename, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed;
        # this function is called many times in one process.
        plt.close(fig)
    print(f"Plot saved to {filename}")


def _save_regret_csv(method_regrets, phi_type, args, d, K, T):
    """Export all regret curves in long format (one row per round) to ``exps/``."""
    rounds = np.arange(1, T + 1)

    # Ensure exps directory exists
    os.makedirs("exps", exist_ok=True)

    is_strategic = args.delay_type == 'strategic'
    is_stochastic = args.delay_type == 'stochastic'

    # One frame per (method, run); scalar values are broadcast over all rounds.
    frames = [
        pd.DataFrame({
            "Round": rounds,
            "Regret": regret_data[r_idx],
            "Method": name,
            "Run": r_idx,
            "Phi": phi_type,
            "DelayType": args.delay_type,
            "DelayMag": args.delay_mag if is_strategic else 0,
            "StochasticMean": args.mean if is_stochastic else 0,
            "StochasticStd": args.std if is_stochastic else 0,
            "Corruption": args.corruption_budget
        })
        for name, regret_data in method_regrets.items()
        for r_idx in range(args.n_runs)
    ]

    if not frames:
        # Nothing was simulated (e.g. n_runs == 0); pd.concat would raise.
        return

    final_df = pd.concat(frames, ignore_index=True)

    # Unique filename for this setting
    if is_strategic:
        setting_str = f"C{args.corruption_budget}_Mag{args.delay_mag}"
    else:
        setting_str = f"C{args.corruption_budget}_Mean{args.mean}_Std{args.std}"

    csv_filename = f"exps/fig1_robustness_{args.delay_type}_{phi_type}_d{d}_K{K}_{setting_str}.csv"
    final_df.to_csv(csv_filename, index=False)
    print(f"Data saved to {csv_filename}")


def run_experiment_for_mapping(phi_type, args, d, e, K, T):
    """Run, plot and export the robustness experiment for one feature mapping.

    Args:
        phi_type: Name of the feature mapping passed to the environment.
        args: Namespace providing ``n_runs``, ``delay_type``, ``delay_mag``,
            ``mean``, ``std`` and ``corruption_budget``.
        d: Context dimension.
        e: Post-serving context dimension.
        K: Number of arms.
        T: Number of rounds per simulation.

    Side effects:
        Prints progress, writes a PDF below ``fig1/`` and a CSV below
        ``exps/``, re-seeds the global NumPy RNG and updates ``plt.rcParams``.
    """
    print(f"\n--- Running Experiment for Mapping: {phi_type} (n_runs={args.n_runs}, Post-Serving e={e}) ---")

    method_regrets = _simulate_all_methods(phi_type, args, d, e, K, T)
    _plot_regret_curves(method_regrets, phi_type, args, d, K, T)
    _save_regret_csv(method_regrets, phi_type, args, d, K, T)

def main():
    """Parse the command line and run the full robustness sweep.

    The sweep iterates over context dimensions, arm counts, delay scenarios
    and feature mappings, calling ``run_experiment_for_mapping`` for each
    combination.

    The scenario flags (``--corruption_budget``, ``--delay_type``,
    ``--delay_mag``, ``--mean``, ``--std``) are optional overrides: when a
    flag is omitted the built-in reproduction setting is used (C=25,
    mean=100.0, std=100.0, delay magnitude=10000, both delay types), so a run
    without flags behaves exactly as the hard-coded sweep did before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--T', type=int, default=2000, help='Total rounds')
    parser.add_argument('--corruption_budget', type=int, default=None,
                        help='Total corruption budget (default: 25)')
    parser.add_argument('--n_runs', type=int, default=5, help='Number of independent runs')

    # Scenario overrides; None means "use the built-in setting".
    parser.add_argument('--delay_type', type=str, choices=['strategic', 'stochastic'], default=None,
                        help='Type of delay (default: run both)')
    parser.add_argument('--delay_mag', type=int, default=None,
                        help='Magnitude of strategic delay (default: 10000)')
    parser.add_argument('--mean', type=float, default=None,
                        help='Mean for stochastic delay (default: 100.0)')
    parser.add_argument('--std', type=float, default=None,
                        help='Std dev for stochastic delay (default: 100.0)')

    args = parser.parse_args()

    # Fail early: with fewer than one round/run there is nothing to plot.
    if args.T < 1:
        parser.error('--T must be at least 1')
    if args.n_runs < 1:
        parser.error('--n_runs must be at least 1')

    # --- Experiment Settings ---
    e = 10  # ENABLE Post-Serving ($e > 0$)
    T = args.T

    def _or_default(value, default):
        # Explicit CLI value wins; otherwise keep the reproduction default.
        return default if value is None else value

    # Experiment settings (levels of difficulty), matching FIG0 by default.
    # Format: (C, mean, std, delay_mag)
    settings = [
        (
            _or_default(args.corruption_budget, 25),
            _or_default(args.mean, 100.0),
            _or_default(args.std, 100.0),
            _or_default(args.delay_mag, 10000),
        ),
    ]

    print("=== Running Multi-Mapping Robustness Experiment (Fig1: WITH Post-Serving) ===")
    print(f"T={T}, N={args.n_runs}")

    # Iterate over Mapping Types
    phi_types = ['polynomial', 'abs', 'cosine', 'sinusoidal']

    # Run both the strategic and the stochastic delay scenario unless one
    # was selected explicitly on the command line.
    delay_scenarios = _or_default(
        None if args.delay_type is None else [args.delay_type],
        ['strategic', 'stochastic'],
    )

    d_list = [10, 20, 30]
    K_list = [10, 20, 30]

    for d_val in d_list:
        for K_val in K_list:

            print("\n================================================")
            print(f"DIMENSION d={d_val}, ARM COUNT K={K_val}")
            print("================================================")

            for (c_val, mu_val, std_val, mag_val) in settings:
                # `args` doubles as the scenario carrier for
                # run_experiment_for_mapping, so write the setting back.
                args.corruption_budget = c_val
                args.mean = mu_val
                args.std = std_val
                args.delay_mag = mag_val

                print("\n\n################################################")
                print(f"SETTING: C={c_val}, Mean={mu_val}, Std={std_val}, Mag={mag_val}")
                print("################################################")

                for dtype in delay_scenarios:
                    args.delay_type = dtype
                    print(f"\n  >>> Delay Type: {dtype.upper()}")

                    for phi in phi_types:
                        run_experiment_for_mapping(phi, args, d_val, e, K_val, T)

# Start the experiment sweep only when this file is executed as a script,
# so importing the module does not trigger any simulation.
if __name__ == "__main__":
    main()
