
from ucb1 import UCB1, UCB1_multiple_C
from bayesianucb import BayesianUCB
from environments import Env,  DynamicEnvOrdered, DynamicEnvOrdered_Incremental, DynamicEnvRandom, DynamicEnvOrdered_Mix_Env
from simulation import do_run, do_run_BO, do_run_UCB
from methods import  RemoveOldestNonPillarClass, PillarHolder
from multi_agents import MultiAgentSystem, MultiAgentSystem_BO
from gp import *
import matplotlib.pyplot as plt
from tqdm import tqdm

import numpy as np
import os
import pandas as pd
import time

from workload import *
from Dynamic_algo import SW_UCB, Discounted_UCB, ThompsonSamplingSW,  ThompsonSampling_fDSW
from test_methods import  *
import warnings
from sklearn.exceptions import ConvergenceWarning

# scikit-learn emits ConvergenceWarning (presumably from model fitting in the
# imported `gp` helpers — confirm) many times per experiment; silence only
# that warning category so other warnings still surface.
warnings.filterwarnings(action="ignore", category=ConvergenceWarning)

def test_method_change_c_multiagent(mus_all, sigmas_all, method_type='UCB1', change_every=2000):
    """
    Test method for multi-agent systems
    
    Args:
        mus_all: Mean values for actions
        sigmas_all: Standard deviations for actions
        method_type: Algorithm type for multi-agent (currently supports 'UCB1')
        change_every: Steps between environment changes
    """
    n_steps = 6000

    if len(mus_all.shape) == 1:
        mus_all = np.array([mus_all])
        sigmas_all = np.array([sigmas_all])
        change_every = n_steps + 1

    Avg_res, Best_action, Agent_prob = [], [], []

    num_actions = mus_all.shape[1] if len(mus_all.shape) > 1 else mus_all.shape[0]

    # Multi-agent systems
    method_type == 'UCB1'
    agent_params = {
        
        "n_actions": num_actions,
    }
    select_action_class = MultiAgentSystem(UCB1, agent_params, n_actions=num_actions)
    env = DynamicEnvOrdered(mus_all, sigmas_all, change_every=change_every) if len(mus_all.shape) > 1 else Env(mus_all, sigmas_all)
    results = do_run(env, select_action_class, n_steps=n_steps, n_actions=num_actions)
    

    Avg_res.append(results["average_response_time"])
    Best_action.append(results["best_chosen"])
    Agent_prob.append(results["correct_identifications"])

    return np.array(Avg_res), np.array(Best_action), np.array(Agent_prob), select_action_class, results

def test_method_change_c_singleagent(mus_all, sigmas_all, method_type='UCB1', change_every=2000):
    """
    Test method for single-agent systems
    
    Args:
        mus_all: Mean values for actions
        sigmas_all: Standard deviations for actions
        method_type: Algorithm type ('UCB1', 'SW_UCB', 'Discounted_UCB', 'ThompsonSamplingSW', 'ThompsonSampling_fDSW', 'BayesianUCB')
        change_every: Steps between environment changes
    """
    n_steps = 6000

    if len(mus_all.shape) == 1:
        mus_all = np.array([mus_all])
        sigmas_all = np.array([sigmas_all])
        change_every = n_steps + 1

    Avg_res, Best_action, Agent_prob = [], [], []

    num_actions = mus_all.shape[1] if len(mus_all.shape) > 1 else mus_all.shape[0]

    # Single-agent systems
    if method_type == 'UCB1':
        
        select_action_class = UCB1(n_actions=num_actions, c=1.0)
    elif method_type == 'SW_UCB':
        select_action_class = SW_UCB(n_actions=num_actions)
    elif method_type == 'Discounted_UCB':
        select_action_class = Discounted_UCB(n_actions=num_actions)
    elif method_type == 'ThompsonSamplingSW':
        select_action_class = ThompsonSamplingSW(n_actions=num_actions)
    elif method_type == 'BayesianUCB':
        select_action_class = BayesianUCB(n_actions=num_actions)
    elif method_type == 'ThompsonSampling_fDSW':
        select_action_class = ThompsonSampling_fDSW(n_actions=num_actions)
    else:
        raise ValueError(f"Unknown single-agent method type: {method_type}")
    
    env = DynamicEnvOrdered(mus_all, sigmas_all, change_every=change_every) if len(mus_all.shape) > 1 else Env(mus_all, sigmas_all)
    results = do_run_UCB(env, select_action_class, n_steps)

    Avg_res.append(results["average_response_time"])
    Best_action.append(results["best_chosen"])
    
    # Handle the case where correct_identifications might not exist for non-multi-agent methods
    if "correct_identifications" in results:
        Agent_prob.append(results["correct_identifications"])
    else:
        Agent_prob.append([0] * n_steps)  # Default value if not available

    return np.array(Avg_res), np.array(Best_action), np.array(Agent_prob), select_action_class, results

def test_method_change_c_BO(mus_all, sigmas_all, method_low, method_mic, method_high, start_size=10,
                            init_method=None, change_every=2000, verbose=True, n_steps=6000,
                            c_timestep_change=40):
    """
    Run one BO-driven multi-agent UCB experiment and collect its metrics.

    Args:
        mus_all: Mean values for actions. A 1-D array is wrapped into a
            single-row 2-D array and the change interval is set beyond
            ``n_steps`` so no environment switch is scheduled during the run.
        sigmas_all: Standard deviations for actions, same layout as ``mus_all``.
        method_low, method_mic, method_high: Indexable pairs where ``[0]`` is
            an add-points function and ``[1]`` a remove-points function. Each
            tier's ``[0]`` is forwarded to the agent system; only
            ``method_low[1]`` is used as the removal function.
        start_size: Number of initial exploration constants ``c``.
        init_method: Optional callable invoked as ``init_method(size=start_size)``
            to produce the initial ``c`` values. When falsy, ``start_size``
            evenly spaced points on ``[0.01, 10.0]`` are used.
        change_every: Steps between environment changes (ignored for 1-D input).
        verbose: Forwarded to the agent system.
        n_steps: Number of simulation steps (defaults to the previously
            hard-coded 6000).
        c_timestep_change: Value of the agent parameter of the same name
            (defaults to the previously hard-coded 40).

    Returns:
        ``(Avg_res, Best_action, Agent_prob, select_action_class, results)``.
        The first three are numpy arrays that wrap, with a leading axis of
        length 1, ``results["average_response_time"]``,
        ``results["best_chosen"]`` and ``results["correct_identifications"]``.
    """
    if len(mus_all.shape) == 1:
        # Single configuration: promote to 2-D and keep the environment fixed.
        mus_all = np.array([mus_all])
        sigmas_all = np.array([sigmas_all])
        change_every = n_steps + 1

    new_points = init_method(size=start_size) if init_method else np.linspace(0.01, 10.0, start_size)

    # mus_all is at least 2-D here, so the action count is the column count.
    num_actions = mus_all.shape[1]

    agent_params = {
        "c_list": new_points,
        "n_actions": num_actions,
        "c_timestep_change": c_timestep_change,
        "add_points_function_low": method_low[0],
        "add_points_function_mic": method_mic[0],
        "add_points_function_high": method_high[0],
        # NOTE(review): only the low tier's removal function is forwarded;
        # method_mic[1] / method_high[1] are ignored — confirm this is intended.
        "remove_points_function": method_low[1],
        "verbose": verbose,
    }

    select_action_class = MultiAgentSystem_BO(UCB1_multiple_C, agent_params, n_actions=num_actions)
    env = DynamicEnvOrdered(mus_all, sigmas_all, change_every=change_every)
    results = do_run_BO(env, select_action_class, n_steps=n_steps, n_actions=num_actions)

    Avg_res = np.array([results["average_response_time"]])
    Best_action = np.array([results["best_chosen"]])
    Agent_prob = np.array([results["correct_identifications"]])

    return Avg_res, Best_action, Agent_prob, select_action_class, results

def test_method_Bandit_dynamic(mus_all, sigmas_all, method_type, change_every=2000, n_steps=6000):
    """
    Run one single-agent experiment and drop the raw results object.

    Returns:
        The first four outputs of ``test_method_change_c_singleagent``:
        ``(Avg_res, Best_action, Agent_prob, select_action_class)``.

    NOTE(review): ``n_steps`` is accepted but not forwarded, so the run length
    is whatever ``test_method_change_c_singleagent`` uses by default — confirm
    whether callers expect this argument to take effect.
    """
    outcome = test_method_change_c_singleagent(mus_all, sigmas_all, method_type, change_every)
    first_four = outcome[:4]
    return first_four

def test_method_n_times_multiagent(n, method_type='UCB1', **kwargs):
    """
    Repeat a multi-agent experiment ``n`` times and stack the per-run metrics.

    Keyword arguments are forwarded unchanged to
    ``test_method_change_c_multiagent``.

    Returns:
        Three numpy arrays (response times, best-action indicators and agent
        identification accuracy), each with the run index as its leading axis.
    """
    collected = {"response": [], "best": [], "accuracy": []}
    for _trial in range(n):
        avg_resp, best_flag, accuracy, _agent, _raw = test_method_change_c_multiagent(
            method_type=method_type, **kwargs
        )
        collected["response"].append(avg_resp)
        collected["best"].append(best_flag)
        collected["accuracy"].append(accuracy)

    return tuple(np.array(collected[key]) for key in ("response", "best", "accuracy"))

def test_method_n_times_singleagent(n, method_type='UCB1', **kwargs):
    """
    Repeat a single-agent experiment ``n`` times and stack the per-run metrics.

    Keyword arguments are forwarded unchanged to
    ``test_method_change_c_singleagent``.

    Returns:
        Three numpy arrays (response times, best-action indicators and agent
        identification accuracy or its zero placeholder), each with the run
        index as its leading axis.
    """
    trials = []
    for _ in range(n):
        avg_resp, best_flag, accuracy, _agent, _raw = test_method_change_c_singleagent(
            method_type=method_type, **kwargs
        )
        trials.append((avg_resp, best_flag, accuracy))

    return (
        np.array([trial[0] for trial in trials]),
        np.array([trial[1] for trial in trials]),
        np.array([trial[2] for trial in trials]),
    )

def test_method_n_times_nor(n, **kwargs):
    """
    Repeat a single-agent dynamic-bandit experiment ``n`` times.

    Keyword arguments are forwarded unchanged to ``test_method_Bandit_dynamic``.

    Returns:
        Three numpy arrays stacked over runs: response times, best-action
        indicators, and the third output of ``test_method_Bandit_dynamic``.
        That output is its ``Agent_prob`` array, not the selected actions,
        despite the historical variable name.
    """
    response_runs = []
    best_runs = []
    agent_prob_runs = []
    for _ in range(n):
        response, best, agent_prob, _agent = test_method_Bandit_dynamic(**kwargs)
        response_runs.append(response)
        best_runs.append(best)
        agent_prob_runs.append(agent_prob)

    return np.array(response_runs), np.array(best_runs), np.array(agent_prob_runs)

def test_method_n_times(n, method_type='UCB1', **kwargs):
    """
    Alias for ``test_method_n_times_multiagent``, kept for older callers.

    Forwards ``n``, ``method_type`` and all keyword arguments unchanged and
    returns its result.
    """
    delegated = test_method_n_times_multiagent(n, method_type=method_type, **kwargs)
    return delegated

def test_method_n_times_BO(n, **kwargs):
    """
    Repeat the BO-driven multi-agent experiment ``n`` times.

    Keyword arguments are forwarded unchanged to ``test_method_change_c_BO``.

    Returns:
        Three numpy arrays (response times, best-action indicators and agent
        identification accuracy), each with the run index as its leading axis.
    """
    buckets = ([], [], [])
    for _trial in range(n):
        avg_resp, best_flag, accuracy, _agent, _raw = test_method_change_c_BO(**kwargs)
        for bucket, value in zip(buckets, (avg_resp, best_flag, accuracy)):
            bucket.append(value)

    return tuple(np.array(bucket) for bucket in buckets)

def _confidence_interval(values, z=1.96):
    """
    Return ``(mean, lower, upper)`` of a normal-approximation interval.

    ``values`` is a numpy array; the half-width is
    ``z * np.std(values) / sqrt(len(values))``.

    NOTE(review): ``np.std`` reduces over *all* elements while ``len`` counts
    only the leading axis. For an array shaped like ``(1, T)`` this is the
    spread over ``T`` entries divided by ``sqrt(1)``, not a standard error
    across runs — confirm this is the intended interval.
    """
    mean = values.mean()
    std_error = np.std(values) / np.sqrt(len(values))
    return mean, mean - z * std_error, mean + z * std_error

def _summarize_environments(num_actions_ordered, run_once, announce):
    """
    Run ``run_once`` on every environment in ``environments`` and tabulate.

    Args:
        num_actions_ordered: Keys of ``environments`` to visit, in output order.
        run_once: Callable ``(mus, sigmas)`` returning a 3-tuple whose first two
            items are the response-time and best-action arrays of that run.
        announce: Called with each ``num_actions`` before its environments run
            (progress output).

    Returns:
        A ``pandas.DataFrame`` with one row per ``num_actions`` and the columns
        written to the summary CSVs.
    """
    response_stats, best_stats = [], []
    for num_actions in num_actions_ordered:
        announce(num_actions)
        means, bests = [], []
        for mus, sigmas in environments[num_actions]:
            mean_response_times, best_action_chosen, _ = run_once(mus, sigmas)
            means.append(mean_response_times)
            bests.append(best_action_chosen)
        # Average across environments of this size, then summarize.
        response_stats.append(_confidence_interval(np.mean(np.array(means), 0)))
        best_stats.append(_confidence_interval(np.mean(np.array(bests), 0)))

    return pd.DataFrame({
        'num_actions': num_actions_ordered,
        'mean_response_time': [stat[0] for stat in response_stats],
        'lower_bound_response_time': [stat[1] for stat in response_stats],
        'upper_bound_response_time': [stat[2] for stat in response_stats],
        'mean_best_action': [stat[0] for stat in best_stats],
        'lower_bound_best_action': [stat[1] for stat in best_stats],
        'upper_bound_best_action': [stat[2] for stat in best_stats],
    })

def main():
    """
    Run every configured algorithm over all environments and write summary CSVs.

    ``environments`` is not defined in this file's visible code; presumably it
    comes from one of the star imports (``workload`` / ``test_methods``) —
    confirm. It is indexed by action count and iterated as ``(mus, sigmas)``
    pairs.

    Writes into ``all_method_data/``:
      - ``DAMAS-{algorithm}-summary.csv`` for each multi-agent algorithm,
      - ``{algorithm}-single-summary.csv`` for each single-agent algorithm,
      - ``BO-DAMAS-UCB-summary.csv`` for the BO-driven multi-agent system.
    """
    pillar_holder = PillarHolder(num_pillar_points=8)
    remove_oldest = RemoveOldestNonPillarClass(pillar_holder)
    gp_with_pillar_rerun = GPWithPillarReRun(pillar_holder)

    # NOTE(review): all three BO tiers share the same GP and pillar-removal
    # objects, and those objects are reused across every environment run.
    # Confirm any state they keep is meant to persist.
    method_low_DBO = (gp_with_pillar_rerun.select_points_GP, remove_oldest.remove_oldest_non_pillar)
    method_mic_DBO = (gp_with_pillar_rerun.select_points_GP, remove_oldest.remove_oldest_non_pillar)
    method_high_DBO = (gp_with_pillar_rerun.select_points_GP, remove_oldest.remove_oldest_non_pillar)

    num_actions_ordered = sorted(np.unique(list(environments.keys())))

    # Unified output directory
    output_dir = 'all_method_data'
    os.makedirs(output_dir, exist_ok=True)

    def announce(num_actions):
        print(f"Num actions: {num_actions}")

    # Multi-agent algorithms (currently only UCB1 supported)
    multiagent_algorithms = ['UCB1']

    print("Running Multi-Agent Methods")
    for algorithm in multiagent_algorithms:
        print(f"Running Multi-Agent {algorithm}")
        summary_data = _summarize_environments(
            num_actions_ordered,
            lambda mus, sigmas, alg=algorithm: test_method_n_times_multiagent(
                n=1, mus_all=mus, sigmas_all=sigmas, method_type=alg
            ),
            announce,
        )
        # Multi-agent naming: DAMAS-{algorithm}-summary
        summary_file = os.path.join(output_dir, f"DAMAS-{algorithm}-summary.csv")
        summary_data.to_csv(summary_file, index=False)

    # Single-agent algorithms (including UCB1)
    singleagent_algorithms = ['UCB1', 'SW_UCB', 'Discounted_UCB', 'ThompsonSamplingSW', 'BayesianUCB', 'ThompsonSampling_fDSW']

    print("Running Single-Agent Methods")
    for algorithm in singleagent_algorithms:
        print(f"Running Single-Agent {algorithm}")
        summary_data = _summarize_environments(
            num_actions_ordered,
            lambda mus, sigmas, alg=algorithm: test_method_n_times_singleagent(
                n=1, mus_all=mus, sigmas_all=sigmas, method_type=alg
            ),
            announce,
        )
        # Single-agent naming: {algorithm}-single-summary
        summary_file = os.path.join(output_dir, f"{algorithm}-single-summary.csv")
        summary_data.to_csv(summary_file, index=False)

    # BO methods; progress line kept in its original two-argument print form.
    print("Running BO-DAMAS-UCB Methods")
    summary_data_BO = _summarize_environments(
        num_actions_ordered,
        lambda mus, sigmas: test_method_n_times_BO(
            1, mus_all=mus, sigmas_all=sigmas,
            method_low=method_low_DBO,
            method_mic=method_mic_DBO,
            method_high=method_high_DBO,
        ),
        lambda num_actions: print("Num actions: ", num_actions),
    )
    summary_data_BO.to_csv(os.path.join(output_dir, 'BO-DAMAS-UCB-summary.csv'), index=False)

    print("All experiments completed!")
    print("Results saved in the following files:")
    print("Multi-Agent Methods:")
    for algorithm in multiagent_algorithms:
        print(f"  - DAMAS-{algorithm}-summary.csv")
    print("Single-Agent Methods:")
    for algorithm in singleagent_algorithms:
        print(f"  - {algorithm}-single-summary.csv")
    print("BO Methods:")
    print("  - BO-DAMAS-UCB-summary.csv")

# Run all experiments only when executed as a script, not when imported.
if __name__ == "__main__":
    main()