import os
from datetime import datetime
import numpy as np
import pandas as pd
import argparse 
from data_generation import compute_real_means_offline_online, generate_movielens_stochastic, compute_V_matrix
from environment import RealDataPreferenceEnv
from algorithms import Sto_to_Rel_UCB, RUCB, InterleavedFilter2
from tqdm import tqdm

def main_real_example(num_runs, output_dir, experiment_name, task_id, real_data):
    """
    Real dataset experiment: Compare the performance of Sto_to_Rel_UCB, RUCB, Pure_Online and InterleavedFilter2 algorithms.

    In every run the four algorithms are executed one after another for
    T_online rounds against the same RealDataPreferenceEnv instance. The
    cumulative regret of each run is saved as a .npy file in a per-algorithm
    subdirectory, and per-run as well as aggregated results (mean / min / max
    of the cumulative regret, selected best arms) are written to a text log
    inside the experiment directory.

    Parameters:
    num_runs: number of experimental repetitions (must be at least 1)
    output_dir: output directory
    experiment_name: name of the subdirectory of output_dir that receives the log and the results
    task_id: task number (passed in from the command line); it enters the per-run seeds and the output file names
    real_data: path of the ratings CSV file, used as both the offline and the online ratings source;
               the file is read with pandas and must provide 'movieId' and 'rating' columns

    Raises:
    ValueError: if num_runs is smaller than 1
    """
    # Fail fast: with zero runs the min/max aggregation at the end would raise
    # anyway, but only after all the expensive setup had already been done.
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    # Real data experiment parameters
    offline_ratings_file = real_data
    online_ratings_file = real_data
    K = 10
    min_ratings = 100
    T_online = 10000
    feedback_mode = "data"
    delta = 0.02
    seed = 1
    sample_size = 100000

    # (file tag, display / directory name) in the order used for saving and logging
    algorithms = (
        ("sto", "Sto_to_Rel_UCB"),
        ("pure", "Pure_Online"),
        ("rucb", "RUCB"),
        ("if", "InterleavedFilter2"),
    )

    # Create output directory and one subdirectory per algorithm.
    # exist_ok=True avoids the check-then-create race when several tasks
    # start at the same time and share the same experiment directory.
    experiment_dir = os.path.join(output_dir, experiment_name)
    os.makedirs(experiment_dir, exist_ok=True)
    algo_dirs = {tag: os.path.join(experiment_dir, label) for tag, label in algorithms}
    for sub_dir in algo_dirs.values():
        os.makedirs(sub_dir, exist_ok=True)

    # Log experiment parameters
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(experiment_dir, f"real_experiment_log_{timestamp}_{task_id}.txt")
    with open(log_file, "w", encoding="utf-8") as f:
        f.write(f"Real Data Experiment Parameters (Task ID: {task_id}):\n")
        f.write(f"offline_ratings_file: {offline_ratings_file}\n")
        f.write(f"online_ratings_file: {online_ratings_file}\n")
        f.write(f"K: {K}\n")
        f.write(f"min_ratings: {min_ratings}\n")
        f.write(f"T_online: {T_online}\n")
        f.write(f"feedback_mode: {feedback_mode}\n")
        f.write(f"delta: {delta}\n")
        f.write(f"num_runs: {num_runs}\n\n")

    # Compute means for real data
    mu_off_dict, mu_on_dict, selected_ids = compute_real_means_offline_online(
        offline_ratings_file, online_ratings_file, K, min_ratings=min_ratings, seed=seed
    )

    # From here on K is the number of arms that were actually selected
    K = len(selected_ids)
    id2index = {mid: idx for idx, mid in enumerate(selected_ids)}
    df_off = pd.read_csv(offline_ratings_file)
    # .copy() so that the new column is added to an independent frame and not
    # to a slice of the full ratings table (chained assignment).
    df_off = df_off[df_off['movieId'].isin(selected_ids)].copy()
    df_off['rating_norm'] = df_off['rating'] / 5.0

    # Sample with replacement only if the requested size exceeds the population size
    replace = sample_size > len(df_off)

    # Cumulative regret curves and best arms of every run, keyed by algorithm tag
    all_cumulative_regrets = {tag: [] for tag, _ in algorithms}
    all_best_arms = {tag: [] for tag, _ in algorithms}

    # Run multiple times to compute average, minimum, and maximum regret
    for run in tqdm(range(num_runs), desc="Total Runs"):
        print(f"Running experiment {run+1}/{num_runs} (Task ID: {task_id})...")

        # NOTE(review): with more than two runs per task these seeds overlap
        # between consecutive task IDs - confirm that this is intended.
        run_seed = seed + run + (task_id - 1) * 2

        # Sample offline data (resampled in each run). The per-run seed makes
        # the draw differ between runs; the modulo keeps it inside the range
        # pandas accepts for random_state.
        sampled = df_off.sample(n=sample_size, random_state=run_seed % (2 ** 32), replace=replace)
        offline_data = [
            (id2index[movie_id], rating)
            for movie_id, rating in zip(sampled['movieId'].tolist(), sampled['rating_norm'].tolist())
        ]

        # Initialize environment
        movie_rewards_dict = generate_movielens_stochastic(offline_ratings_file, selected_ids)
        mu_off_arr = np.array([mu_off_dict[mid] for mid in selected_ids])
        mu_on_arr = np.array([mu_on_dict[mid] for mid in selected_ids])
        env = RealDataPreferenceEnv(arm_ids=selected_ids, movie_rewards_dict=movie_rewards_dict, mu_on=mu_on_arr, feedback_mode=feedback_mode)

        cumulative = {}
        best_arm = {}

        # Run Sto_to_Rel_UCB (V is attached after the offline data has been fitted)
        algo_sto = Sto_to_Rel_UCB(K=K, delta=delta, V=None, use_offline=True, seed=run_seed)
        algo_sto.fit_offline_data(offline_data)
        algo_sto.V = compute_V_matrix(mu_off_arr, mu_on_arr)
        regret_history_sto, _ = algo_sto.run_online(env, T=T_online, num_relative_draws=1, true_mu_on=mu_on_arr)
        cumulative["sto"] = np.cumsum(regret_history_sto)
        best_arm["sto"] = algo_sto.get_best_arm()

        # Run RUCB
        algo_rucb = RUCB(K=K, alpha=0.51, delta=delta, seed=run_seed)
        regret_history_rucb, _ = algo_rucb.run_online(env, T=T_online, true_mu_on=mu_on_arr)
        cumulative["rucb"] = np.cumsum(regret_history_rucb)
        best_arm["rucb"] = algo_rucb.get_best_arm()

        # Pure_Online (implemented using Sto_to_Rel_UCB without offline data)
        algo_pure = Sto_to_Rel_UCB(K=K, delta=delta, V=None, use_offline=False, seed=run_seed)
        regret_history_pure, _ = algo_pure.run_online(env, T=T_online, num_relative_draws=1, true_mu_on=mu_on_arr)
        cumulative["pure"] = np.cumsum(regret_history_pure)
        best_arm["pure"] = algo_pure.get_best_arm()

        # Run InterleavedFilter2
        algo_if = InterleavedFilter2(K=K, T=T_online, delta=delta, seed=run_seed)
        regret_history_if = algo_if.run(env, true_mu_on=mu_on_arr)
        cumulative["if"] = np.cumsum(regret_history_if)
        best_arm["if"] = algo_if.get_best_arm()

        # Collect the results and save the cumulative regret of this run to
        # the corresponding subdirectory
        for tag, _ in algorithms:
            all_cumulative_regrets[tag].append(cumulative[tag])
            all_best_arms[tag].append(best_arm[tag])
            np.save(
                os.path.join(algo_dirs[tag], f"real_regret_{tag}_run_{run}_{timestamp}_{task_id}.npy"),
                cumulative[tag],
            )

        # Log the final results of this run
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(f"Run {run+1} (Task ID: {task_id}):\n")
            for tag, label in algorithms:
                f.write(f"  {label} - Final Cumulative Regret: {cumulative[tag][-1]:.2f}, Best Arm: {best_arm[tag]}\n")
            f.write("\n")

    # Compute average, minimum, and maximum regret over the runs (axis 0)
    # before the log is reopened, so a failure here leaves the log untouched.
    summary = {}
    for tag, _ in algorithms:
        regrets = np.array(all_cumulative_regrets[tag])
        summary[tag] = (
            np.mean(regrets, axis=0),
            np.min(regrets, axis=0),
            np.max(regrets, axis=0),
        )

    # Record results
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"Average Results (over {num_runs} runs, Task ID: {task_id}):\n")
        f.write("Best Arm Selections:\n")
        for tag, label in algorithms:
            f.write(f"  {label}: {all_best_arms[tag]}\n")
        f.write(f"Final Results (at T={T_online}):\n")
        for tag, label in algorithms:
            mean_regret, min_regret, max_regret = summary[tag]
            f.write(f"  {label} Mean Cumulative Regret: {mean_regret[-1]:.2f} (Min: {min_regret[-1]:.2f}, Max: {max_regret[-1]:.2f})\n")

    print(f"Real data experiment (Task ID: {task_id}) completed. Results saved to {experiment_dir}")

def parse_cli_args(argv=None):
    """
    Parse the command-line arguments of the real data experiment script.

    All five arguments are positional and optional; an omitted argument falls
    back to its default. They are consumed in the order
    task_id, ratings_file, output_dir, experiment_name, num_runs.

    Parameters:
    argv: list of argument strings to parse; None means sys.argv[1:] is used

    Returns:
    argparse.Namespace with the attributes task_id, ratings_file, output_dir,
    experiment_name and num_runs
    """
    parser = argparse.ArgumentParser(description="Run Real Data Example with task ID and custom paths")
    # nargs="?" is required for the defaults to take effect: argparse ignores
    # `default` on a plain positional argument and treats it as mandatory.
    parser.add_argument("task_id", type=int, nargs="?", default=1, help="Task ID for this run")
    parser.add_argument("ratings_file", type=str, nargs="?", default="../real_data/movielens.csv", help="Path to the ratings CSV file")
    parser.add_argument("output_dir", type=str, nargs="?", default="../output_real", help="Directory to save experiment results")
    parser.add_argument("experiment_name", type=str, nargs="?", default="real_example_yelp", help="Name of the experiment")
    parser.add_argument("num_runs", type=int, nargs="?", default=3, help="Number of experiment runs")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Parse command-line arguments
    args = parse_cli_args()

    # Run the experiment
    main_real_example(
        num_runs=args.num_runs,
        output_dir=args.output_dir,
        experiment_name=args.experiment_name,
        task_id=args.task_id,
        real_data=args.ratings_file
    )