import os
import numpy as np
import pandas as pd
import random
from datetime import datetime
from tqdm import tqdm
import argparse
from data_generation import generate_movielens_relative_offline_data, generate_movielens_stochastic_online_data, compute_V_matrix
from environment import RealDataRelStoEnv
from algorithms import RelToStoElm, ETC, UCB, ThompsonSampling

def main_real_example(num_runs, output_dir, experiment_name, task_id, ratings_file, sample_size):
    """
    Experiment: Compare the performance of 5 algorithms on real datasets.

    Runs Rel_to_Sto_Elm (with offline data), Pure_Online (RelToStoElm with no
    offline data and an all-zero V matrix), ETC, UCB and Thompson Sampling on
    K movies drawn at random from the 100 most-rated movies in ``ratings_file``.
    For every run, each algorithm's cumulative regret curve is saved as a
    ``.npy`` file in its own sub-directory, and per-run and aggregate results
    are appended to a timestamped log file in ``<output_dir>/<experiment_name>``.

    Parameters:
      num_runs: Number of experiment runs (must be >= 1)
      output_dir: Output directory (passed via command line)
      experiment_name: Experiment name
      task_id: Task ID (passed via command line); also offsets the RNG seeds
      ratings_file: Path to the ratings CSV file (passed via command line);
        must contain a 'movieId' column
      sample_size: Forwarded to RealDataRelStoEnv as ``sample_size``
        (NOTE(review): meaning is defined in environment.py — confirm there)

    Raises:
      ValueError: if num_runs < 1 (aggregate statistics need at least one run).
      RuntimeError: if an algorithm does not make exactly
        T_online * num_per_phase pulls.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    # Experiment parameters
    K = 10
    seed = 1
    T_online = 100
    delta = 0.05
    num_per_phase = 100
    N_offline = 1000
    m = 200
    feedback_mode = "data"
    environment_mode = "data"
    # Pull budget every algorithm must use; also the horizon for ETC/UCB/TS.
    total_pulls = T_online * num_per_phase
    print(f"m (ETC): {m}")

    # Display name -> (result sub-directory, label for the arm set in per-run log lines).
    # Insertion order is the order in which results are logged.
    algo_meta = {
        "Rel_to_Sto_Elm": ("Rel_to_Sto_Elm", "Surviving Arms"),
        "Pure_Online": ("Pure_Online", "Surviving Arms"),
        "ETC": ("ETC", "Chosen Arm"),
        "UCB": ("UCB", "Chosen Arm"),
        "Thompson Sampling": ("ThompsonSampling", "Chosen Arm"),
    }

    # Create the output directories. exist_ok=True avoids the check-then-create
    # race when several tasks write into the same experiment directory.
    experiment_dir = os.path.join(output_dir, experiment_name)
    algo_dirs = {name: os.path.join(experiment_dir, sub) for name, (sub, _) in algo_meta.items()}
    for directory in (experiment_dir, *algo_dirs.values()):
        os.makedirs(directory, exist_ok=True)

    # Seed the global RNGs for data preparation so the selected movie subset is
    # reproducible from (seed, task_id). NOTE(review): this also makes the data
    # generators reproducible only if they draw from `random` / `np.random` —
    # confirm in data_generation.py.
    random.seed(seed + task_id)
    np.random.seed(seed + task_id)

    print(f"Loading dataset from {ratings_file}...")
    df = pd.read_csv(ratings_file)
    movie_counts = df['movieId'].value_counts()
    top_100_ids = movie_counts.head(100).index.tolist()  # 100 most-rated movies
    select_id = random.sample(top_100_ids, K)
    print(f"Selected movie IDs: {select_id}")

    # Generate offline and online data
    print("Generating offline and online data...")
    offline_data, mu_off = generate_movielens_relative_offline_data(
        ratings_file=ratings_file,
        select_id=select_id,
        N_offline_per_pair=N_offline,
        feedback_mode=feedback_mode
    )
    movie_rewards, mu_on = generate_movielens_stochastic_online_data(
        ratings_file=ratings_file,
        select_id=select_id
    )

    # Record experiment parameters
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(experiment_dir, f"log_{timestamp}_{task_id}.txt")
    with open(log_file, "w") as f:
        f.write(f"Real Data Experiment Parameters (Task ID: {task_id}):\n")
        f.write(f"Experiment Name: {experiment_name}\n")
        f.write(f"Data File: {ratings_file}\n")
        f.write(f"Output Directory: {output_dir}\n")
        f.write(f"K: {K}\n")
        f.write(f"Selected Movie IDs: {select_id}\n")
        f.write(f"seed: {seed}\n")
        f.write(f"T_online: {T_online}\n")
        f.write(f"delta: {delta}\n")
        f.write(f"num_per_phase: {num_per_phase}\n")
        f.write(f"N_offline: {N_offline}\n")
        f.write(f"m (ETC): {m}\n")
        f.write(f"num_runs: {num_runs}\n")
        f.write(f"Total pulls per run: {total_pulls}\n")
        f.write(f"mu_off: {mu_off}\n")
        f.write(f"mu_on: {mu_on}\n\n")

    # Cumulative-regret curves per algorithm, one entry per run.
    regrets = {name: [] for name in algo_meta}

    # Run the experiment multiple times
    for run in tqdm(range(num_runs), desc="Running experiments"):
        print(f"Running experiment {run+1}/{num_runs}...")
        np.random.seed(seed + run)
        algo_seed = seed + run + task_id * 3

        # Create the online environment (shared by all algorithms in this run)
        env = RealDataRelStoEnv(arm_ids=select_id, movie_rewards_dict=movie_rewards, feedback_mode=environment_mode, sample_size=sample_size)

        # V matrix computed from the offline/online means; Pure_Online gets zeros.
        V_mat = compute_V_matrix(mu_off, mu_on)
        V_mat_zeros = np.zeros((K, K))

        # Factories are consumed immediately below, within this iteration, so
        # capturing loop variables is safe; each algorithm is constructed right
        # before it runs, in the same order as the logged results.
        factories = [
            ("Rel_to_Sto_Elm", lambda: RelToStoElm(
                K=K, T=T_online, num_per_phase=num_per_phase, T_S=N_offline,
                means_offline=mu_off, means_online=mu_on, offline_data=offline_data,
                V_matrix=V_mat, delta=delta, env=env, seed=algo_seed)),
            # Pure_Online: same elimination algorithm without any offline data.
            ("Pure_Online", lambda: RelToStoElm(
                K=K, T=T_online, num_per_phase=num_per_phase, T_S=0,
                means_offline=mu_off, means_online=mu_on, offline_data=[],
                V_matrix=V_mat_zeros, delta=delta, env=env, seed=algo_seed)),
            ("ETC", lambda: ETC(
                K=K, T=total_pulls, m=m, means_online=mu_on, env=env, seed=algo_seed)),
            ("UCB", lambda: UCB(
                K=K, T=total_pulls, means_online=mu_on, env=env, delta=delta, seed=algo_seed)),
            ("Thompson Sampling", lambda: ThompsonSampling(
                K=K, T=total_pulls, means_online=mu_on, env=env,
                prior_mean=0.5, init_pulls=1, seed=algo_seed)),
        ]

        run_lines = [f"Run {run+1} (Task ID: {task_id}):\n"]
        for name, make_algo in factories:
            arms, pull_rewards, pull_regrets, _ = make_algo().run()
            # Explicit check (not assert) so it survives `python -O`; done
            # before saving so a truncated curve is never written to disk.
            if len(pull_rewards) != total_pulls:
                raise RuntimeError(f"{name} pulls: {len(pull_rewards)}")
            cumulative_regret = np.cumsum(pull_regrets)
            regrets[name].append(cumulative_regret)
            np.save(os.path.join(algo_dirs[name], f"regret_run_{run}_{timestamp}_{task_id}.npy"), cumulative_regret)
            arm_label = algo_meta[name][1]
            run_lines.append(
                f"  {name} - Final Cumulative Regret: {cumulative_regret[-1]:.2f}, {arm_label}: {sorted(arms)}\n"
            )
        run_lines.append("\n")

        # Log the final results of this run
        with open(log_file, "a") as f:
            f.writelines(run_lines)

    # Aggregate the final cumulative regret across runs and log it.
    with open(log_file, "a") as f:
        f.write(f"Average Results (over {num_runs} runs, Task ID: {task_id}):\n")
        for name, curves in regrets.items():
            final = np.array([curve[-1] for curve in curves])
            f.write(f"  {name} - Final Mean Cumulative Regret: {final.mean():.2f} (Min: {final.min():.2f}, Max: {final.max():.2f})\n")

    print(f"Real data experiment (Task ID: {task_id}) completed. Results saved to {experiment_dir}")

def build_arg_parser():
    """Build the command-line parser for this script.

    The six arguments are positional, in this order: task_id, ratings_file,
    output_dir, experiment_name, num_runs, sample_size. Each one uses
    nargs="?", so trailing arguments may be omitted and then take their
    defaults. A required positional never uses ``default``, which made the
    original defaults dead values.
    """
    parser = argparse.ArgumentParser(description="Run Real Data Example 1 with task ID and custom paths")
    parser.add_argument("task_id", type=int, nargs="?", default=1, help="Task ID for this run")
    parser.add_argument("ratings_file", type=str, nargs="?", default="../real_data/movielens.csv", help="Path to the ratings CSV file")
    parser.add_argument("output_dir", type=str, nargs="?", default="../output_real", help="Directory to save experiment results")
    parser.add_argument("experiment_name", type=str, nargs="?", default="real_example_movielens", help="Name of the experiment")
    parser.add_argument("num_runs", type=int, nargs="?", default=3, help="Number of experiment runs")
    parser.add_argument("sample_size", type=int, nargs="?", default=100, help="Sample size forwarded to the online environment (RealDataRelStoEnv)")
    return parser


if __name__ == "__main__":
    # Parse command-line arguments
    args = build_arg_parser().parse_args()

    # Run the experiment
    main_real_example(
        num_runs=args.num_runs,
        output_dir=args.output_dir,
        experiment_name=args.experiment_name,
        task_id=args.task_id,
        ratings_file=args.ratings_file,
        sample_size=args.sample_size
    )