import os
from datetime import datetime
import numpy as np
import argparse  # 新增：用于解析命令行参数
from data_generation import generate_synthetic_means, generate_synthetic_stochastic_offline_data, compute_V_matrix
from environment import SyntheticPreferenceEnv
from algorithms import Sto_to_Rel_UCB, RUCB, InterleavedFilter2
from tqdm import tqdm

def main_synthetic_example_1(num_runs, output_dir, experiment_name, task_id):
    """
    Experiment 1: Compare the performance of Sto_to_Rel_UCB, RUCB, Pure_Online_UCB and InterleavedFilter2 algorithms under different K values.

    For every K in the hard-coded list, the function generates the synthetic
    means and offline data, runs each algorithm ``num_runs`` times, saves the
    cumulative regret of every run as a ``.npy`` file under
    ``<output_dir>/<experiment_name>/K_<K>/`` and appends per-run results and
    a per-K summary (mean/min/max of the final cumulative regret) to a log
    file in ``<output_dir>/<experiment_name>/``.

    Parameters:
    num_runs: number of experimental repetitions (must be at least 1)
    output_dir: output directory
    experiment_name: experiment name (used as sub-directory of output_dir)
    task_id: task number (passed in from the command line); it is embedded in
        the output file names and shifts the seeds so that each task uses its
        own block of ``num_runs`` consecutive seeds

    Raises:
    ValueError: if num_runs is smaller than 1
    """
    if num_runs < 1:
        # With no runs there is nothing to aggregate in the per-K summary;
        # fail before any directory or log file is created.
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    Ks = [8, 16, 24, 32]
    gap = 0.1
    bias = 0.01
    seed = 1
    T_max = 50000
    delta = 0.02
    N_offline = 1000

    # (file-name tag, name used in the log) for each algorithm, in the order
    # in which the algorithms are executed, saved and reported.
    algorithms = (
        ("sto", "Sto_to_Rel_UCB"),
        ("rucb", "RUCB"),
        ("pure", "Pure_Online_UCB"),
        ("if", "InterleavedFilter2"),
    )

    # exist_ok=True: several tasks (different task_id) may start at the same
    # time and share these directories, so "check then create" would race.
    experiment_dir = os.path.join(output_dir, experiment_name)
    os.makedirs(experiment_dir, exist_ok=True)

    k_dirs = {}
    for K in Ks:
        k_dirs[K] = os.path.join(experiment_dir, f"K_{K}")
        os.makedirs(k_dirs[K], exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    task_id_str = f"{task_id:03d}"
    log_file = os.path.join(experiment_dir, f"log_{timestamp}_{task_id_str}.txt")
    with open(log_file, "w") as f:
        f.write(f"Experiment 1 Parameters (Task ID: {task_id_str}):\n")
        f.write(f"Experiment Name: {experiment_name}\n")
        f.write(f"Ks: {Ks}\n")
        f.write(f"gap: {gap}\n")
        f.write(f"bias: {bias}\n")
        f.write(f"seed: {seed}\n")
        f.write(f"T_max: {T_max}\n")
        f.write(f"delta: {delta}\n")
        f.write(f"N_offline: {N_offline}\n")
        f.write(f"num_runs: {num_runs}\n\n")

    # Each task owns a disjoint block of num_runs consecutive seeds, so two
    # tasks never reuse a seed regardless of how many runs are requested.
    task_seed_offset = (task_id - 1) * num_runs

    # Outer loop: iterate over different K values
    for K in tqdm(Ks, desc="Processing K values"):
        print(f"Processing K={K} (Task ID: {task_id_str})...")
        np.random.seed(seed)
        mu_off, mu_on = generate_synthetic_means(K=K, gap=gap, bias=bias, seed=seed)
        offline_data, _ = generate_synthetic_stochastic_offline_data(mu_off, N_offline=N_offline)
        env = SyntheticPreferenceEnv(mu_on)

        # Per-algorithm accumulators over all runs for this K.
        all_regrets = {tag: [] for tag, _ in algorithms}
        best_arms = {tag: [] for tag, _ in algorithms}

        # Middle loop: iterate over num_runs
        for run in tqdm(range(num_runs), desc=f"Runs for K={K}", leave=False):
            print(f"  Run {run+1}/{num_runs} for K={K} (Task ID: {task_id_str})...")
            run_seed = seed + run + task_seed_offset
            algos = {}
            cumulative = {}

            # Sto_to_Rel_UCB: fitted on the offline data, then given the V matrix
            algo = Sto_to_Rel_UCB(K=K, delta=delta, V=None, use_offline=True, seed=run_seed)
            algo.fit_offline_data(offline_data)
            algo.V = compute_V_matrix(mu_off, mu_on)
            regret_history, _ = algo.run_online(env, T=T_max, num_relative_draws=1, true_mu_on=mu_on)
            algos["sto"] = algo
            cumulative["sto"] = np.cumsum(regret_history)
            all_regrets["sto"].append(cumulative["sto"])
            best_arms["sto"].append(algo.get_best_arm())

            # RUCB
            algo = RUCB(K=K, alpha=0.51, delta=delta, seed=run_seed)
            regret_history, _ = algo.run_online(env, T=T_max, true_mu_on=mu_on)
            algos["rucb"] = algo
            cumulative["rucb"] = np.cumsum(regret_history)
            all_regrets["rucb"].append(cumulative["rucb"])
            best_arms["rucb"].append(algo.get_best_arm())

            # Pure_Online_UCB (implemented via Sto_to_Rel_UCB without offline data)
            algo = Sto_to_Rel_UCB(K=K, delta=delta, V=None, use_offline=False, seed=run_seed)
            regret_history, _ = algo.run_online(env, T=T_max, num_relative_draws=1, true_mu_on=mu_on)
            algos["pure"] = algo
            cumulative["pure"] = np.cumsum(regret_history)
            all_regrets["pure"].append(cumulative["pure"])
            best_arms["pure"].append(algo.get_best_arm())

            # InterleavedFilter2 (takes the horizon at construction time)
            algo = InterleavedFilter2(K=K, T=T_max, delta=delta, seed=run_seed)
            regret_history = algo.run(env, true_mu_on=mu_on)
            algos["if"] = algo
            cumulative["if"] = np.cumsum(regret_history)
            all_regrets["if"].append(cumulative["if"])
            best_arms["if"].append(algo.get_best_arm())

            # Save results: one cumulative-regret array per algorithm and run
            for tag, _ in algorithms:
                np.save(
                    os.path.join(k_dirs[K], f"regret_{tag}_run_{run}_{timestamp}_{task_id_str}.npy"),
                    cumulative[tag],
                )

            # Record the final results of each run
            with open(log_file, "a") as f:
                f.write(f"K={K}, Run {run+1} (Task ID: {task_id_str}):\n")
                for tag, name in algorithms:
                    f.write(
                        f"  {name} - Final Cumulative Regret: {cumulative[tag][-1]:.2f}, "
                        f"Best Arm: {algos[tag].get_best_arm()}\n"
                    )
                f.write("\n")

        # Stack the runs (one row per run) and aggregate across runs; only the
        # value at the last time step is reported in the summary.
        stacked = {tag: np.array(all_regrets[tag]) for tag, _ in algorithms}
        final_stats = {
            tag: (
                np.mean(stacked[tag], axis=0)[-1],
                np.min(stacked[tag], axis=0)[-1],
                np.max(stacked[tag], axis=0)[-1],
            )
            for tag, _ in algorithms
        }

        with open(log_file, "a") as f:
            f.write(f"K={K} Summary (over {num_runs} runs, Task ID: {task_id_str}):\n")
            f.write("  Best Arm Selections:\n")
            for tag, name in algorithms:
                f.write(f"    {name}: {best_arms[tag]}\n")
            f.write(f"  Final Mean Cumulative Regret (with Min and Max at T={T_max}):\n")
            for tag, name in algorithms:
                mean_final, min_final, max_final = final_stats[tag]
                f.write(f"    {name}: {mean_final:.2f} (Min: {min_final:.2f}, Max: {max_final:.2f})\n")
            f.write("\n")

    print(f"Experiment 1 (Task ID: {task_id_str}) completed. Results saved to {experiment_dir}")

if __name__ == "__main__":
    # Script entry point: read the task ID from the command line and launch
    # Experiment 1 with the fixed run count and output location.
    cli_parser = argparse.ArgumentParser(
        description="Run Synthetic Example 1 with task ID"
    )
    cli_parser.add_argument(
        "--task_id",
        type=int,
        default=1,
        help="Task ID for this run",
    )
    args = cli_parser.parse_args()

    run_config = dict(
        num_runs=5,
        output_dir="../output_syn",
        experiment_name="synthetic_example_1",
        task_id=args.task_id,
    )
    main_synthetic_example_1(**run_config)