import os
from datetime import datetime
import numpy as np
import argparse  # used to parse command-line arguments (e.g. --task_id)
from data_generation import generate_synthetic_means, generate_synthetic_stochastic_offline_data, compute_V_matrix
from environment import SyntheticPreferenceEnv
from algorithms import Sto_to_Rel_UCB
from tqdm import tqdm

def _run_sensitivity_point(mu_off, mu_on, *, num_arms, n_offline, label, file_tag,
                           save_dir, num_runs, seed, seed_offset, delta, T_online,
                           timestamp, task_id_str, log_file):
    """Run every repetition for one parameter setting and log its final-regret summary.

    Each repetition draws offline data from ``mu_off`` (``n_offline`` samples),
    runs Sto_to_Rel_UCB online against ``mu_on`` for ``T_online`` rounds, and saves
    the cumulative-regret curve to
    ``save_dir/regret_{file_tag}_run_{run}_{timestamp}_{task_id_str}.npy``.
    Afterwards the mean/min/max of the last cumulative-regret value across runs
    is appended to ``log_file``.

    Repetition ``run`` uses ``seed + seed_offset + run`` both for NumPy's global
    RNG and for the algorithm's own seed.

    Returns:
        np.ndarray holding one cumulative-regret curve per row (num_runs rows).
    """
    print(f"Processing {label} (Task ID: {task_id_str})...")
    curves = []
    for run in tqdm(range(num_runs), desc=f"Runs for {label}", leave=False):
        print(f"  Run {run+1}/{num_runs} for {label} (Task ID: {task_id_str})...")
        run_seed = seed + seed_offset + run
        # Reseed right before generating data; presumably the offline-data
        # generator and/or the environment draw from NumPy's global RNG (TODO confirm).
        np.random.seed(run_seed)
        offline_data, _ = generate_synthetic_stochastic_offline_data(mu_off, N_offline=n_offline)
        env = SyntheticPreferenceEnv(mu_on)
        algo_sto = Sto_to_Rel_UCB(K=num_arms, delta=delta, V=None, use_offline=True, seed=run_seed)
        algo_sto.fit_offline_data(offline_data)
        # V is assigned only after fit_offline_data, in case fitting touches V (unverified).
        algo_sto.V = compute_V_matrix(mu_off, mu_on)
        regret_history_sto, _ = algo_sto.run_online(env, T=T_online, num_relative_draws=1, true_mu_on=mu_on)
        # np.cumsum without an axis always returns a flat 1-D curve.
        cumulative_regret_sto = np.cumsum(regret_history_sto)
        curves.append(cumulative_regret_sto)
        np.save(
            os.path.join(save_dir, f"regret_{file_tag}_run_{run}_{timestamp}_{task_id_str}.npy"),
            cumulative_regret_sto,
        )

    all_regrets = np.array(curves)
    final_regrets = all_regrets[:, -1]  # last cumulative-regret value of each run
    with open(log_file, "a") as f:
        f.write(f"{label} (Task ID: {task_id_str}):\n")
        f.write(
            f"  Final Mean Cumulative Regret (at T={T_online}): {final_regrets.mean():.2f} "
            f"(Min: {final_regrets.min():.2f}, Max: {final_regrets.max():.2f})\n"
        )
    return all_regrets


def main_synthetic_example_2(num_runs, output_dir, experiment_name, task_id):
    """
    Experiment 2: Analyze the individual effects of N_offline, gap and bias parameters on the convergence of the Sto_to_Rel_UCB algorithm.

    Each parameter is swept while the others stay at their defaults. Every
    repetition's cumulative-regret curve is saved under
    ``output_dir/experiment_name/{N_offline,gap,bias}/``. A log file with the
    settings and per-setting summaries is written to ``output_dir/experiment_name``.

    Parameters:
    num_runs: number of repetitions per parameter setting (must be >= 1)
    output_dir: root output directory
    experiment_name: sub-directory of output_dir for this experiment
    task_id: 1-based task number (passed in from the command line). It tags every
        output file and offsets all seeds by (task_id - 1) * num_runs, so separate
        tasks run distinct repetitions.

    Raises:
    ValueError: if num_runs < 1 or task_id < 1.
    """
    # Validate before touching the filesystem.
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")
    if task_id < 1:
        raise ValueError(f"task_id must be >= 1, got {task_id}")

    K = 20
    seed = 1
    T_online = 30000
    delta = 0.02
    N_offlines = [500, 300, 100]
    gaps = [0.2, 0.1, 0.01]
    biases = [0.01, 0.05, 0.1]
    default_N_offline = 500
    default_gap = 0.1
    default_bias = 0

    # Hand-crafted 4-arm (mu_off, mu_on) instances for the bias sweep. Each label
    # in `biases` equals the largest |mu_off[i] - mu_on[i]| of the matching pair.
    mu_pairs = [
        ((0.8, 0.6, 0.5, 0.4), (0.81, 0.59, 0.51, 0.4)),
        ((0.75, 0.75, 0.6, 0.55), (0.8, 0.7, 0.6, 0.5)),
        ((0.5, 0.45, 0.45, 0.45), (0.6, 0.5, 0.5, 0.5)),
    ]
    if len(biases) != len(mu_pairs):
        raise ValueError("each bias label needs exactly one (mu_off, mu_on) pair")

    experiment_dir = os.path.join(output_dir, experiment_name)
    n_offline_dir = os.path.join(experiment_dir, "N_offline")
    gap_dir = os.path.join(experiment_dir, "gap")
    bias_dir = os.path.join(experiment_dir, "bias")
    for sub_dir in (n_offline_dir, gap_dir, bias_dir):
        os.makedirs(sub_dir, exist_ok=True)  # also creates experiment_dir

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    task_id_str = f"{task_id:03d}"
    # Tasks get disjoint seed ranges: task t uses seed + (t-1)*num_runs + run.
    seed_offset = (task_id - 1) * num_runs
    log_file = os.path.join(experiment_dir, f"log_{timestamp}_{task_id_str}.txt")
    header = [
        f"Sensitivity Analysis Parameters (Task ID: {task_id_str}):",
        f"Experiment Name: {experiment_name}",
        f"K: {K}",
        f"seed: {seed}",
        f"T_online: {T_online}",
        f"delta: {delta}",
        f"N_offlines: {N_offlines}",
        f"gaps: {gaps}",
        f"biases: {biases}",
        f"num_runs: {num_runs}",
        f"default_N_offline: {default_N_offline}",
        f"default_gap: {default_gap}",
        f"default_bias: {default_bias}",
    ]
    with open(log_file, "w") as f:
        f.write("\n".join(header) + "\n\n")

    common = dict(
        num_runs=num_runs,
        seed=seed,
        seed_offset=seed_offset,
        delta=delta,
        T_online=T_online,
        timestamp=timestamp,
        task_id_str=task_id_str,
        log_file=log_file,
    )

    # generate_synthetic_means is called with a fixed seed, so it presumably returns
    # the same instance every time (TODO confirm). It is called once per setting,
    # before the per-run reseed, so it cannot reset the RNG that the runs rely on.

    # Sensitivity analysis for N_offline (gap and bias at their defaults)
    mu_off, mu_on = generate_synthetic_means(K=K, gap=default_gap, bias=default_bias, seed=seed)
    for N_offline in tqdm(N_offlines, desc="Processing N_offline values"):
        _run_sensitivity_point(
            mu_off, mu_on, num_arms=K, n_offline=N_offline,
            label=f"N_offline={N_offline}", file_tag=f"Noffline{N_offline}",
            save_dir=n_offline_dir, **common,
        )

    # Sensitivity analysis for gap (N_offline and bias at their defaults)
    for gap in tqdm(gaps, desc="Processing gap values"):
        mu_off, mu_on = generate_synthetic_means(K=K, gap=gap, bias=default_bias, seed=seed)
        _run_sensitivity_point(
            mu_off, mu_on, num_arms=K, n_offline=default_N_offline,
            label=f"gap={gap}", file_tag=f"gap{gap}",
            save_dir=gap_dir, **common,
        )

    # Sensitivity analysis for bias (hand-crafted instances, N_offline at its default)
    for bias, (mu_off, mu_on) in tqdm(zip(biases, mu_pairs), total=len(mu_pairs),
                                      desc="Processing bias values"):
        _run_sensitivity_point(
            mu_off, mu_on, num_arms=len(mu_off), n_offline=default_N_offline,
            label=f"bias={bias}", file_tag=f"bias{bias}",
            save_dir=bias_dir, **common,
        )

    print(f"Sensitivity analysis (Task ID: {task_id_str}) completed. Results saved to {experiment_dir}")

if __name__ == "__main__":
    # Command-line entry point. All options default to the values previously
    # hard-coded here, so `python <script> --task_id N` behaves as before.
    # --task_id is forwarded to main_synthetic_example_2, which uses it to offset
    # seeds and to tag output files.
    parser = argparse.ArgumentParser(description="Run Synthetic Example 2 with task ID")
    parser.add_argument("--task_id", type=int, default=1, help="Task ID for this run")
    parser.add_argument("--num_runs", type=int, default=5,
                        help="Number of repetitions per parameter setting")
    parser.add_argument("--output_dir", default="../output_syn",
                        help="Root directory for experiment outputs")
    parser.add_argument("--experiment_name", default="synthetic_example_2",
                        help="Sub-directory of output_dir for this experiment")
    args = parser.parse_args()

    main_synthetic_example_2(
        num_runs=args.num_runs,
        output_dir=args.output_dir,
        experiment_name=args.experiment_name,
        task_id=args.task_id,
    )