# Imports for rally-length distribution plots and Jensen-Shannon divergence (JSD) evaluation
import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import entropy

################ JSD ################
def compute_rally_length_distribution(
    player_names, player_opps, real_folder, sim_folder,
    output_png="rally_length_distribution.png"
):
    name_map = {
        'bc': 'BC',
        'dd': 'DD',
        'dbc2': 'DBC',
        'dbc': 'DP',
        'ddgi': 'DDGI',
    }
    display_order = ['BC', 'DD', 'DP', 'DBC', 'DDGI']

    os.makedirs(output_png, exist_ok=True)

    def compute_rally_lengths_from_column(rally_series):
        rally_lengths = []
        current_rally = None
        count = 0
        for val in rally_series:
            if val != current_rally:
                if current_rally is not None:
                    rally_lengths.append(count)
                current_rally = val
                count = 1
            else:
                count += 1
        if current_rally is not None:
            rally_lengths.append(count)
        return rally_lengths

    for player_name in player_names:
        rally_data_by_model = {}
        jsd_results = {}

        real_lengths = []
        for csv_file in glob.glob(os.path.join(real_folder, "*.csv")):
            if os.path.basename(csv_file) != f"{player_name}_dataset.csv":
                continue
            df = pd.read_csv(csv_file)
            if 'rally' not in df.columns:
                continue
            real_lengths = compute_rally_lengths_from_column(df['rally'].dropna().astype(int).tolist())
            break

        for csv_file in glob.glob(os.path.join(sim_folder, "*.csv")):
            filename = os.path.basename(csv_file)
            filename_no_ext = filename.replace('.csv', '')
            segments = filename_no_ext.split('_')
            if len(segments) < 4:
                continue
            method_key, file_playerA, _, file_playerB = segments[0], segments[1], segments[2], segments[3]

            if file_playerA != player_name or file_playerB not in player_opps:
                continue

            method = name_map.get(method_key)
            if method is None:
                continue

            df = pd.read_csv(csv_file)
            if 'rally' not in df.columns:
                continue
            lengths = compute_rally_lengths_from_column(df['rally'].dropna().astype(int).tolist())
            rally_data_by_model.setdefault(method, []).extend(lengths)

        fig, axs = plt.subplots(1, len(display_order), figsize=(4 * len(display_order), 4), sharey=True)
        if len(display_order) == 1:
            axs = [axs]

        for i, method in enumerate(display_order):
            ax = axs[i]
            if method not in rally_data_by_model:
                ax.set_visible(False)
                continue
            sim_lengths = rally_data_by_model[method]
            real_counts, _ = np.histogram(real_lengths, bins=range(0, 51), density=True)
            sim_counts, _ = np.histogram(sim_lengths, bins=range(0, 51), density=True)

            p = np.asarray(real_counts)
            q = np.asarray(sim_counts)
            m = 0.5 * (p + q)
            jsd_value = 0.5 * entropy(p, m) + 0.5 * entropy(q, m)
            jsd_results[method] = jsd_value

            sns.histplot(real_lengths, bins=range(0, 51), color="skyblue", stat="density",
                         label="Real", kde=False, ax=ax, alpha=0.3)
            sns.histplot(sim_lengths, bins=range(0, 51), color="sandybrown", stat="density",
                         label="Simulated", kde=False, ax=ax, alpha=0.3)

            ax.text(10, 0.12, f"JSD = {jsd_value:.4f}", fontsize=20, weight='bold')
            ax.set_title(method, fontsize=16, weight='bold')
            if i == 0:
                ax.set_ylabel("Density", fontsize=14)
            ax.set_xlabel("Rally Length", fontsize=14)
            #ax.legend()
            #ax.get_legend().remove()
            
        plt.tight_layout()
        output_path = os.path.join(output_png, f"rally_length_{player_name}.png")
        plt.savefig(output_path)
        plt.close()
        
        
def rally_legend_image(output_path):
    """Save a standalone two-entry legend ("Real" / "Simulated") image.

    The swatch colours and alpha (skyblue / sandybrown, alpha 0.3) match the
    histogram colours used for the rally-length plots in this file.

    Args:
        output_path: destination image file. Its parent directory is created
            if it is missing. A bare filename is written to the current
            working directory.
    """
    import matplotlib.lines as mlines

    handles = [
        mlines.Line2D([], [], color='skyblue', linewidth=10, label='Real', alpha=0.3),
        mlines.Line2D([], [], color='sandybrown', linewidth=10, label='Simulated', alpha=0.3),
    ]

    fig, ax = plt.subplots(figsize=(3, 0.8))
    ax.axis("off")
    legend = ax.legend(handles=handles, loc="center", frameon=True, ncol=2)
    for text in legend.get_texts():
        text.set_fontsize(12)

    fig.tight_layout()
    out_dir = os.path.dirname(output_path)
    if out_dir:
        # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(output_path, bbox_inches='tight', dpi=150)
    plt.close(fig)


################ JSD ################


if __name__ == "__main__":
    # Each player is evaluated on simulated matches against any player in the
    # same list.
    player_list = ['Viktor AXELSEN', 'Kento MOMOTA', 'CHOU Tien Chen']

    compute_rally_length_distribution(
        player_names=player_list,
        player_opps=player_list,
        real_folder="./data/badminton",
        sim_folder="./evaluation/data/badminton",
        output_png="./evaluation/plot/badminton",  # an output directory, despite the name
    )
    rally_legend_image(output_path="./evaluation/plot/legend_jsd.png")
