import os
import numpy as np
from datetime import datetime
from tqdm import tqdm
import argparse  
from data_generation import generate_synthetic_means, generate_synthetic_relative_offline_data, compute_V_matrix
from environment import SyntheticRelStoEnv
from algorithms import RelToStoElm, ETC, UCB, ThompsonSampling

# One row per algorithm, in the order the algorithms are run and logged:
# (output sub-directory name, label in the per-run log, label in the per-K
#  summary, label in the pull-count error, caption for the returned arm set).
_ALGORITHMS = (
    ("Rel_to_Sto_Elm", "Rel_to_Sto_Elm", "RelToStoElm", "RelToStoElm", "Surviving Arms"),
    ("Pure_Online", "Pure_Online", "Pure-Online", "Pure-Online", "Surviving Arms"),
    ("ETC", "ETC", "ETC", "ETC", "Chosen Arm"),
    ("UCB", "UCB", "UCB", "UCB", "Chosen Arm"),
    ("ThompsonSampling", "Thompson Sampling", "ThompsonSampling", "Thompson Sampling", "Chosen Arm"),
)


def _algorithm_factories(*, K, T_online, num_per_phase, N_offline, m, delta,
                         mu_off, mu_on, offline_data, V_mat, env, algo_seed):
    """Return ``{sub-directory name: zero-argument constructor}`` for one run.

    Construction is deferred so that each algorithm object is only created
    after the previous algorithm has finished running.
    """
    total_pulls = T_online * num_per_phase
    return {
        "Rel_to_Sto_Elm": lambda: RelToStoElm(
            K=K,
            T=T_online,
            num_per_phase=num_per_phase,
            T_S=N_offline,
            means_offline=mu_off,
            means_online=mu_on,
            offline_data=offline_data,
            V_matrix=V_mat,
            delta=delta,
            env=env,
            seed=algo_seed,
        ),
        # "Pure-Online" is RelToStoElm given no offline samples and an
        # all-zero V matrix.
        "Pure_Online": lambda: RelToStoElm(
            K=K,
            T=T_online,
            num_per_phase=num_per_phase,
            T_S=0,
            means_offline=mu_off,
            means_online=mu_on,
            offline_data=[],
            V_matrix=np.zeros((K, K)),
            delta=delta,
            env=env,
            seed=algo_seed,
        ),
        "ETC": lambda: ETC(
            K=K, T=total_pulls, m=m, means_online=mu_on, env=env, seed=algo_seed
        ),
        "UCB": lambda: UCB(
            K=K, T=total_pulls, means_online=mu_on, env=env, delta=delta, seed=algo_seed
        ),
        "ThompsonSampling": lambda: ThompsonSampling(
            K=K, T=total_pulls, means_online=mu_on, env=env,
            prior_mean=0.5, init_pulls=1, seed=algo_seed
        ),
    }


def main_synthetic_example_1(num_runs: int, output_dir: str, experiment_name: str, task_id: int) -> None:
    """
    Experiment 1: Compare the performance of RelToStoElm, Pure-Online, ETC, UCB and ThompsonSampling algorithms under different K values.

    For every K and every run, the five algorithms are executed in that order
    against the same environment object. Each cumulative-regret curve is
    saved as a ``.npy`` file under ``<output_dir>/<experiment_name>/K_<K>/<algorithm>/``,
    and per-run results plus a per-K summary (mean/min/max final regret) are
    appended to a timestamped log file in the experiment directory.

    Parameters:
    num_runs: number of experimental repetitions (must be at least 1)
    output_dir: output directory
    experiment_name: experiment name
    task_id: task number (passed in from the command line)

    Raises:
    ValueError: if num_runs is smaller than 1.
    AssertionError: if an algorithm returns a number of pulls different from
        T_online * num_per_phase.
    """
    # With no runs the per-K min/max reductions would fail on empty arrays
    # after directories and a log file had already been created.
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    Ks = [8, 16, 32]
    gap = 0.1
    bias = 0.01
    seed = 123
    T_online = 600
    delta = 0.05
    N_offline = 5000
    num_per_phase = 50
    m = 500
    total_pulls = T_online * num_per_phase
    print(f"m (ETC): {m}")

    # exist_ok=True keeps directory creation safe when several task IDs are
    # launched at the same time against the same output tree.
    experiment_dir = os.path.join(output_dir, experiment_name)
    k_dirs = {}
    for K in Ks:
        k_dirs[K] = {}
        for dir_name, *_ in _ALGORITHMS:
            sub_dir = os.path.join(experiment_dir, f"K_{K}", dir_name)
            os.makedirs(sub_dir, exist_ok=True)
            k_dirs[K][dir_name] = sub_dir

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(experiment_dir, f"log_{timestamp}_{task_id}.txt")
    with open(log_file, "w", encoding="utf-8") as f:
        f.writelines([
            f"Experiment 1 Parameters (Task ID: {task_id}):\n",
            f"Experiment Name: {experiment_name}\n",
            f"Ks: {Ks}\n",
            f"gap: {gap}\n",
            f"bias: {bias}\n",
            f"seed: {seed}\n",
            f"T_online: {T_online}\n",
            f"delta: {delta}\n",
            f"N_offline: {N_offline}\n",
            f"num_per_phase: {num_per_phase}\n",
            f"m (ETC): {m}\n",
            f"num_runs: {num_runs}\n",
            f"Total pulls per run: {total_pulls}\n\n",
        ])

    for K in tqdm(Ks, desc="Processing K values"):
        print(f"Processing K={K}...")

        all_regrets = {dir_name: [] for dir_name, *_ in _ALGORITHMS}
        best_arms = {dir_name: [] for dir_name, *_ in _ALGORITHMS}

        for run in tqdm(range(num_runs), desc=f"Runs for K={K}", leave=False):
            print(f"  Run {run+1}/{num_runs} for K={K}...")
            np.random.seed(seed + run)
            # NOTE(review): the stride of 5 matches the num_runs=5 used by the
            # script entry point; with num_runs > 5 the seeds of adjacent task
            # IDs overlap (e.g. task 1 / run 5 equals task 2 / run 0).
            algo_seed = seed + run + task_id * 5

            mu_off, mu_on = generate_synthetic_means(K=K, gap=gap, bias=bias, seed=seed + run)
            offline_data, _ = generate_synthetic_relative_offline_data(mu_off, N_offline=N_offline)
            V_mat = compute_V_matrix(mu_off, mu_on)

            env = SyntheticRelStoEnv(mu_on)

            factories = _algorithm_factories(
                K=K, T_online=T_online, num_per_phase=num_per_phase,
                N_offline=N_offline, m=m, delta=delta, mu_off=mu_off,
                mu_on=mu_on, offline_data=offline_data, V_mat=V_mat,
                env=env, algo_seed=algo_seed,
            )

            pull_counts = []
            run_lines = []
            for dir_name, run_label, _, check_label, arms_caption in _ALGORITHMS:
                algo = factories[dir_name]()
                # The first three elements of run()'s result are used as the
                # returned arm set, the per-pull rewards and the per-pull regrets.
                arms, pull_rewards, pull_regrets, _ = algo.run()
                cumulative_regret = np.cumsum(pull_regrets)
                sorted_arms = sorted(arms)
                all_regrets[dir_name].append(cumulative_regret)
                best_arms[dir_name].append(sorted_arms)

                # Rel_to_Sto_Elm files have never carried the seed suffix; the
                # naming is kept as-is so existing result readers keep working.
                file_stem = f"regret_run_{run}_{timestamp}_{task_id}"
                if dir_name != "Rel_to_Sto_Elm":
                    file_stem += f"_{algo_seed}"
                np.save(os.path.join(k_dirs[K][dir_name], f"{file_stem}.npy"), cumulative_regret)

                pull_counts.append((check_label, len(pull_rewards)))
                run_lines.append(
                    f"  {run_label} - Final Cumulative Regret: {cumulative_regret[-1]:.2f}, "
                    f"{arms_caption}: {sorted_arms}\n"
                )

            # Raised explicitly (rather than via `assert`) so the check is not
            # stripped under `python -O`; curves of unequal length could not
            # be aggregated into one array below.
            for check_label, n_pulls in pull_counts:
                if n_pulls != total_pulls:
                    raise AssertionError(f"{check_label} pulls: {n_pulls}")

            with open(log_file, "a", encoding="utf-8") as f:
                f.write(f"K={K}, Run {run+1} (Task ID: {task_id}):\n")
                f.write(f"  mu_off: {mu_off}\n")
                f.write(f"  mu_on: {mu_on}\n")
                f.writelines(run_lines)
                f.write("\n")

        arm_lines = []
        regret_lines = []
        for dir_name, _, summary_label, _, _ in _ALGORITHMS:
            stacked = np.array(all_regrets[dir_name])  # one row per run
            mean_regret = np.mean(stacked, axis=0)
            min_regret = np.min(stacked, axis=0)
            max_regret = np.max(stacked, axis=0)
            arm_lines.append(f"    {summary_label}: {best_arms[dir_name]}\n")
            regret_lines.append(
                f"    {summary_label}: {mean_regret[-1]:.2f} "
                f"(Min: {min_regret[-1]:.2f}, Max: {max_regret[-1]:.2f})\n"
            )

        with open(log_file, "a", encoding="utf-8") as f:
            f.write(f"K={K} Summary (over {num_runs} runs, Task ID: {task_id}):\n")
            f.write("  Best Arm Selections:\n")
            f.writelines(arm_lines)
            f.write(f"  Final Mean Cumulative Regret (with Min and Max at T={total_pulls}):\n")
            f.writelines(regret_lines)
            f.write("\n")

    print(f"Experiment 1 (Task ID: {task_id}) completed. Results saved to {experiment_dir}")

if __name__ == "__main__":
    # Script entry point: read the task ID from the command line and launch
    # Experiment 1 with the fixed run count and output location.
    cli = argparse.ArgumentParser(description="Run Synthetic Example 1 with task ID")
    cli.add_argument("--task_id", type=int, default=1, help="Task ID for this run")
    task_id = cli.parse_args().task_id

    run_settings = {
        "num_runs": 5,
        "output_dir": "../output_syn",
        "experiment_name": "synthetic_example_1",
        "task_id": task_id,
    }
    main_synthetic_example_1(**run_settings)