import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


# algorithms_list = [
#     "Discounted-klUCB",
#     "Discounted-TS",
#     "Master+UCB",
#     "M-UCB ($w=200$, with diminishing)",
#     "M-UCB ($w=200$)",
#     "CUSUM-UCB (with diminishing)",
#     "CUSUM-UCB",
#     "GLR-UCB (with diminishing)",
#     "GLR-UCB"
# ]

# markers = ['o', 'd', 'p', '^', '^', 's', 's', 'v', 'v']
# colors = ['#bcbd22', '#2ca02c', '#d62728', '#9467bd', '#9467bd','#1f77b4','#1f77b4', '#8B564A','#8B564A']  
# linestyles = ['-','-','-',':','-',':','-',':','-']

# Plot styling lookup tables.
# `colors` is keyed by the base algorithm name (the text before any "(").
# Both variants of an algorithm currently share the same color; whether an
# algorithm's name contains 'diminishing' is shown by the marker shape instead
# (see `markers` below).
colors = {
    'Master+UCB': {'normal': '#d62728', 'diminishing': '#d62728'},  # red
    'CUSUM-UCB': {'normal': '#1f77b4', 'diminishing': '#1f77b4'},  # blue
    'GLR-UCB': {'normal': '#8B564A', 'diminishing': '#8B564A'},  # brown
    'M-UCB': {'normal': '#9467bd', 'diminishing': '#9467bd'}  # purple
}

# Matplotlib marker code per variant.
markers = {
    'diminishing': 'X',  # filled "X": names containing 'diminishing'
    'normal': '^'  # upward triangle: all other names
}

def get_algorithm_base_name(algorithm_name):
    """Return the algorithm name without its parenthesised suffix.

    If *algorithm_name* contains "(", the text before the first "(" is
    returned with surrounding whitespace stripped, e.g.
    "GLR-UCB (with diminishing)" -> "GLR-UCB". Otherwise the name is returned
    unchanged (not stripped).
    """
    head, paren, _rest = algorithm_name.partition('(')
    if not paren:
        return algorithm_name
    return head.strip()

def get_algorithm_marker(algorithm_name):
    """Return the matplotlib marker for *algorithm_name*.

    Names containing the substring 'diminishing' get markers['diminishing'];
    every other name gets markers['normal'].
    """
    variant = 'diminishing' if 'diminishing' in algorithm_name else 'normal'
    return markers[variant]

def get_algorithm_key(algorithm_name):
    """Build the grouping key "<base name>_<variant>" for *algorithm_name*.

    The base name comes from get_algorithm_base_name(). The variant is
    'diminishing' when the full name contains that substring, else 'normal'.
    Example: "CUSUM-UCB (with diminishing)" -> "CUSUM-UCB_diminishing".
    """
    base = get_algorithm_base_name(algorithm_name)
    variant = 'diminishing' if 'diminishing' in algorithm_name else 'normal'
    return f"{base}_{variant}"

def get_algorithm_color(algorithm_name, default=None):
    """Return the plot color for *algorithm_name*.

    The color is looked up in the module-level ``colors`` table by base name
    (text before the first "(") and variant: 'diminishing' if the full name
    contains that substring, otherwise 'normal'.

    Args:
        algorithm_name: Full algorithm label, e.g. "GLR-UCB (with diminishing)".
        default: Returned when the base name has no entry in ``colors``.
            With the default ``None``, matplotlib falls back to its color
            cycle instead of the script aborting with ``KeyError``.

    Returns:
        A matplotlib color spec, or *default* for algorithms without an
        assigned color.
    """
    variants = colors.get(get_algorithm_base_name(algorithm_name))
    if variants is None:
        # e.g. 'Discounted-TS' / 'Discounted-klUCB' have no entry in `colors`.
        return default
    variant = 'diminishing' if 'diminishing' in algorithm_name else 'normal'
    return variants[variant]

def read_algorithm_name(path):
    """Read the algorithm names listed in ``<path>/alg_str.txt``.

    The file holds one name per line. Surrounding whitespace (including the
    newline) is stripped, and blank lines are skipped. Otherwise an empty
    name would be produced and later turned into bogus file names such as
    ``mean_regrets_.csv``.

    Args:
        path: Result directory containing ``alg_str.txt``.

    Returns:
        List of non-empty algorithm names in file order.

    Raises:
        FileNotFoundError: If ``alg_str.txt`` does not exist in *path*.
    """
    alg_file_path = os.path.join(path, 'alg_str.txt')
    with open(alg_file_path, 'r') as file:
        stripped = (line.strip() for line in file)
        return [name for name in stripped if name]

def read_mean_regret(path, algorithm_name):
    """Return the final mean-regret value for *algorithm_name*.

    Loads ``<path>/mean_regrets_<algorithm_name>.csv`` with pandas (its first
    row is treated as the header) and returns the first-column value of the
    last data row.
    """
    csv_path = os.path.join(path, f"mean_regrets_{algorithm_name}.csv")
    last_row = pd.read_csv(csv_path).iloc[-1]
    return last_row.to_numpy()[0]

def read_execution_time(path, algorithm_name, rep=100):
    """Read the logged execution times for *algorithm_name*.

    Args:
        path: Result directory containing the log file.
        algorithm_name: Algorithm label used as the file-name suffix.
        rep: Repetition count embedded in the file name
            ``execution_time_log_rep_<rep><algorithm_name>.txt``. Defaults to
            100, the value that was previously hard-coded.

    Returns:
        List of floats, one per non-blank line of the log.

    Raises:
        FileNotFoundError: If the log file does not exist.
        ValueError: If a non-blank line is not a valid float.
    """
    file_name = f"execution_time_log_rep_{rep}{algorithm_name}.txt"
    file_path = os.path.join(path, file_name)
    with open(file_path, 'r') as file:
        # float() tolerates surrounding whitespace. Blank lines (e.g. an extra
        # empty line at the end) are skipped instead of crashing on float('').
        return [float(line) for line in file if line.strip()]

def process_paths(*paths, output):
    """Pool results from every result directory in *paths* and plot them.

    For each directory, the algorithm names in ``alg_str.txt`` are read. For
    each name, its final mean regret and all of its execution times are
    collected under the key from get_algorithm_key(). Runs of the same
    algorithm variant from different directories therefore land in one group.

    Args:
        *paths: Result directories to read.
        output: Output file stem, forwarded to
            visualize_algorithm_performance().
    """
    grouped = {}
    for directory in paths:
        for name in read_algorithm_name(directory):
            regret = read_mean_regret(directory, name)
            times = read_execution_time(directory, name)
            bucket = grouped.setdefault(
                get_algorithm_key(name),
                {'mean_regrets': [], 'execution_times': [], 'full_names': []},
            )
            bucket['mean_regrets'].append(regret)
            bucket['execution_times'].extend(times)
            bucket['full_names'].append(name)

    visualize_algorithm_performance(grouped, output)

def visualize_algorithm_performance(algorithm_results, output):
    """Plot mean regret against computation time, one point per algorithm group.

    Each group's point sits at (mean of ``mean_regrets``, mean of
    ``execution_times``). Its horizontal and vertical error bars are the
    respective population standard deviations (``np.std``, ddof=0).

    Args:
        algorithm_results: Mapping from group key to a dict with the lists
            ``'mean_regrets'``, ``'execution_times'`` and ``'full_names'``
            (as built by process_paths()).
        output: Output path stem. The figure is saved as ``<output>.png``,
            ``<output>.eps`` and ``<output>.pdf``, then shown.
    """
    fig, ax = plt.subplots(figsize=(8, 8))

    for data in algorithm_results.values():
        mean_regrets = data['mean_regrets']
        execution_times = data['execution_times']

        mean_regret_avg = np.mean(mean_regrets)
        mean_regret_std = np.std(mean_regrets)  # horizontal error bar
        execution_time_avg = np.mean(execution_times)
        execution_time_std = np.std(execution_times)  # vertical error bar

        # The aggregated point is identical for every directory in this group,
        # so draw it once per distinct label. Redrawing it per directory only
        # stacked duplicate artists at the same spot.
        for full_name in dict.fromkeys(data['full_names']):
            ax.errorbar(mean_regret_avg, execution_time_avg,
                        xerr=mean_regret_std, yerr=execution_time_std,
                        fmt=get_algorithm_marker(full_name),
                        color=get_algorithm_color(full_name),
                        label=full_name, markersize=20, capsize=7, lw=2)

    ax.set_xlabel('Mean Regret', fontsize=20)
    ax.set_ylabel('Computation Time (Average)', fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=16)
    # Labels are attached above, but no legend is drawn; call ax.legend() to show them.
    ax.set_facecolor("white")
    for spine in ax.spines.values():
        spine.set_edgecolor("black")
        spine.set_linewidth(3)

    # Save through the figure handle so this figure is written even if another
    # figure has become the "current" one.
    for fmt in ('png', 'eps', 'pdf'):
        fig.savefig(f'{output}.{fmt}', format=fmt)
    plt.show()

# Example usage:
# process_paths('C:\\Users\\USER\\Code\\diminishing-exploration\\plot\\normal\\K3_T20000_N100_M51_envId16_seed10', 
#               'C:\\Users\\USER\\Code\\diminishing-exploration\\plot\\normal\\K3_T20000_N100_M51_envId16_seed100', 
#               'C:\\Users\\USER\\Code\\diminishing-exploration\\plot\\normal\\K3_T20000_N100_M51_envId16_seed1000',
#               output="main")
