import os
import numpy as np
from datetime import datetime
from tqdm import tqdm
import argparse  
from data_generation import generate_synthetic_means, generate_synthetic_relative_offline_data, compute_V_matrix
from environment import SyntheticRelStoEnv
from algorithms import RelToStoElm

def _simulate_cumulative_regret(mu_off, mu_on, *, K, n_offline, T_online,
                                num_per_phase, delta, algo_seed):
    """Run RelToStoElm once and return the cumulative sum of its regrets.

    Offline data is generated from ``mu_off`` and the online environment is
    built from ``mu_on``. The statement order matches the original inline
    code so that any use of the global NumPy RNG happens in the same sequence.

    Parameters:
    mu_off: offline means passed to the data generator and the algorithm
    mu_on: online means passed to the environment and the algorithm
    K: value forwarded as the algorithm's ``K`` argument
    n_offline: offline sample count (used both as ``N_offline`` for data
        generation and as the algorithm's ``T_S``)
    T_online: value forwarded as the algorithm's ``T`` argument
    num_per_phase: value forwarded unchanged to the algorithm
    delta: value forwarded unchanged to the algorithm
    algo_seed: seed handed to the algorithm

    Returns:
    ``np.cumsum`` of the third value returned by ``RelToStoElm.run()``
    (presumably the per-pull regrets -- confirm against ``algorithms``).
    """
    offline_data, _ = generate_synthetic_relative_offline_data(mu_off, N_offline=n_offline)
    env = SyntheticRelStoEnv(mu_on)
    algo_elm = RelToStoElm(
        K=K,
        T=T_online,
        num_per_phase=num_per_phase,
        T_S=n_offline,
        means_offline=mu_off,
        means_online=mu_on,
        offline_data=offline_data,
        V_matrix=compute_V_matrix(mu_off, mu_on),
        delta=delta,
        env=env,
        seed=algo_seed,
    )
    _, _, pull_regrets_elm, _ = algo_elm.run()
    return np.cumsum(pull_regrets_elm)


def _append_final_regret(log_file, label, value, task_id, total_pulls, all_regrets):
    """Append the final mean/min/max cumulative regret of one setting to the log.

    ``all_regrets`` is a list with one cumulative-regret curve per run; the
    statistics are taken across runs and only the last time step is reported.
    """
    curves = np.array(all_regrets)
    mean_regret = np.mean(curves, axis=0)
    min_regret = np.min(curves, axis=0)
    max_regret = np.max(curves, axis=0)
    with open(log_file, "a") as f:
        f.write(f"{label}={value} (Task ID: {task_id}):\n")
        f.write(f"  Final Mean Cumulative Regret (at T={total_pulls}): {mean_regret[-1]:.2f} (Min: {min_regret[-1]:.2f}, Max: {max_regret[-1]:.2f})\n")


def main_synthetic_example_2(num_runs, output_dir, experiment_name, task_id):
    """
    Experiment 2: Analyze the individual effects of N_offline, gap and bias parameters on the convergence of the RelToStoElm algorithm.

    Three sweeps are executed in order (N_offline, gap, bias). For every
    swept value the algorithm is run ``num_runs`` times; each run's
    cumulative-regret curve is saved as a ``.npy`` file in the sweep's
    sub-directory, and the final mean/min/max regret is appended to a
    timestamped log file inside ``output_dir/experiment_name``.

    Parameters:
    num_runs: number of experimental repetitions (must be >= 1)
    output_dir: output directory
    experiment_name: experiment name
    task_id: task number (passed in from the command line); it is embedded in
        the output file names and offsets the algorithm seed

    Raises:
    ValueError: if ``num_runs`` is smaller than 1. This is checked before any
        file or directory is created.
    """
    # Fail fast: with no runs the statistics below would be taken over an
    # empty array, which NumPy rejects only after output has been written.
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs!r}")

    K_default = 10               # K used by the N_offline and gap sweeps
    K_bias = 4                   # K used by the bias sweep (length of each mu_pairs entry)
    seed = 123
    T_online = 600
    delta = 0.05
    num_per_phase = 50
    N_offlines = [100, 200, 500]
    gaps = [0.05, 0.1, 0.2]
    biases = [0.01, 0.05, 0.1]
    default_N_offline = 100
    default_gap = 0.1
    default_bias = 0.01
    total_pulls = T_online * num_per_phase

    # Hand-picked (offline means, online means) pairs, one per entry of `biases`.
    mu_pairs = [
        ((0.8, 0.6, 0.5, 0.4), (0.81, 0.59, 0.51, 0.4)),    # bias = 0.01
        ((0.75, 0.75, 0.6, 0.55), (0.8, 0.7, 0.6, 0.5)),    # bias = 0.05
        ((0.5, 0.45, 0.45, 0.45), (0.6, 0.5, 0.5, 0.5)),    # bias = 0.1
    ]
    mu_pairs = [(list(mu_off), list(mu_on)) for mu_off, mu_on in mu_pairs]

    experiment_dir = os.path.join(output_dir, experiment_name)
    n_offline_dir = os.path.join(experiment_dir, "N_offline")
    gap_dir = os.path.join(experiment_dir, "gap")
    bias_dir = os.path.join(experiment_dir, "bias")
    # exist_ok=True avoids the check-then-create race when several task IDs
    # start concurrently and share the same output directory.
    for directory in (experiment_dir, n_offline_dir, gap_dir, bias_dir):
        os.makedirs(directory, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(experiment_dir, f"log_{timestamp}_{task_id}.txt")
    with open(log_file, "w") as f:
        f.write(f"Sensitivity Analysis Parameters (Task ID: {task_id}):\n")
        f.write(f"Experiment Name: {experiment_name}\n")
        f.write(f"K_default (N_offline, gap): {K_default}\n")
        f.write(f"K_bias (bias): {K_bias}\n")
        f.write(f"seed: {seed}\n")
        f.write(f"T_online: {T_online}\n")
        f.write(f"delta: {delta}\n")
        f.write(f"num_per_phase: {num_per_phase}\n")
        f.write(f"N_offlines: {N_offlines}\n")
        f.write(f"gaps: {gaps}\n")
        f.write(f"biases: {biases}\n")
        f.write(f"mu_pairs for bias: {mu_pairs}\n")
        f.write(f"num_runs: {num_runs}\n")
        f.write(f"Total pulls per run: {total_pulls}\n")
        f.write(f"default_N_offline: {default_N_offline}\n")
        f.write(f"default_gap: {default_gap}\n")
        f.write(f"default_bias: {default_bias}\n\n")

    def sweep_setting(label, value, file_tag, save_dir, K, n_offline, means_for_run):
        """Run all repetitions for one swept value, save each curve, log the summary.

        ``means_for_run(run)`` returns the ``(mu_off, mu_on)`` pair for that
        repetition; it is called right after the global NumPy RNG is seeded.
        """
        print(f"Processing {label}={value}...")
        all_regrets = []
        for run in tqdm(range(num_runs), desc=f"Runs for {label}={value}", leave=False):
            print(f"  Run {run+1}/{num_runs} for {label}={value}...")
            # The global seed depends only on the run index; only the
            # algorithm's own seed additionally depends on task_id.
            np.random.seed(seed + run)
            mu_off, mu_on = means_for_run(run)
            cumulative_regret_elm = _simulate_cumulative_regret(
                mu_off,
                mu_on,
                K=K,
                n_offline=n_offline,
                T_online=T_online,
                num_per_phase=num_per_phase,
                delta=delta,
                algo_seed=seed + run + task_id*5,
            )
            all_regrets.append(cumulative_regret_elm)
            np.save(
                os.path.join(save_dir, f"regret_{file_tag}{value}_run_{run}_{timestamp}_{task_id}.npy"),
                cumulative_regret_elm,
            )
        _append_final_regret(log_file, label, value, task_id, total_pulls, all_regrets)

    # N_offline sweep: gap and bias stay at their defaults.
    for N_offline in tqdm(N_offlines, desc="Processing N_offline values"):
        sweep_setting(
            "N_offline", N_offline, "Noffline", n_offline_dir, K_default, N_offline,
            lambda run: generate_synthetic_means(
                K=K_default, gap=default_gap, bias=default_bias, seed=seed + run
            ),
        )

    # gap sweep: N_offline and bias stay at their defaults.
    for gap in tqdm(gaps, desc="Processing gap values"):
        sweep_setting(
            "gap", gap, "gap", gap_dir, K_default, default_N_offline,
            lambda run, gap=gap: generate_synthetic_means(
                K=K_default, gap=gap, bias=default_bias, seed=seed + run
            ),
        )

    # bias sweep: uses the fixed mean pairs above instead of generated means.
    for bias, (mu_off, mu_on) in tqdm(list(zip(biases, mu_pairs)), desc="Processing bias values"):
        sweep_setting(
            "bias", bias, "bias", bias_dir, K_bias, default_N_offline,
            lambda run, pair=(mu_off, mu_on): pair,
        )

    print(f"Sensitivity analysis (Task ID: {task_id}) completed. Results saved to {experiment_dir}")

if __name__ == "__main__":
    # Script entry point: read the optional --task_id flag (defaults to 1)
    # and launch the sensitivity analysis with the fixed experiment settings.
    cli = argparse.ArgumentParser(description="Run Synthetic Example 2 with task ID")
    cli.add_argument("--task_id", type=int, default=1, help="Task ID for this run")
    cli_args = cli.parse_args()

    # Fixed settings for this experiment; only the task ID comes from the CLI.
    run_settings = {
        "num_runs": 5,
        "output_dir": "../output_syn",
        "experiment_name": "synthetic_example_2",
        "task_id": cli_args.task_id,
    }
    main_synthetic_example_2(**run_settings)