import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

def process_real_landing(player_name, real_folder):
    """Collect end-of-rally landing points for one player from a real-match CSV.

    Reads ``<real_folder>/<player_name>_dataset.csv`` and drops rows whose
    ``time`` is missing. It then walks consecutive rallies in order of first
    appearance. A rally qualifies when its last shot was hit by
    ``player_name`` and has both landing coordinates. Each qualifying rally is
    classified by comparing that player with the player of the last shot of
    the following rally:

    * same player      -> ``'win'``,  stored as ``(x, (y + 1) / 2)``
    * different player -> ``'lose'``, stored as ``(-x, -(y + 1) / 2)``

    The final rally has no successor and is never classified.

    NOTE(review): ``landing_y`` presumably lies in [-1, 1], so ``(y + 1) / 2``
    maps it to [0, 1]. Confirm this against the dataset.

    Args:
        player_name: Player whose shots are collected; also names the CSV.
        real_folder: Directory containing the per-player dataset CSVs.

    Returns:
        ``{player_name: {'win': [(x, y), ...], 'lose': [(x, y), ...]}}``.
        Both lists are empty when the file is missing (a message is printed)
        or when any of the columns ``time``, ``rally``, ``player``,
        ``landing_x``, ``landing_y`` is absent.
    """
    player_landings = {player_name: {'win': [], 'lose': []}}

    target_file = os.path.join(real_folder, f"{player_name}_dataset.csv")
    if not os.path.exists(target_file):
        print(f"No real data found for {player_name}")
        return player_landings

    df = pd.read_csv(target_file)
    # Validate every column used below before touching any of them, so a
    # malformed file yields empty results instead of a KeyError.
    required = {'time', 'rally', 'player', 'landing_x', 'landing_y'}
    if not required.issubset(df.columns):
        return player_landings
    df = df.dropna(subset=['time'])

    # Compute the last shot of every rally once. Re-filtering the whole frame
    # for each rally would be O(rows * rallies). NaN rally ids are excluded
    # by groupby, so they are skipped below just like an empty rally.
    last_shot = {
        rally_id: group.iloc[-1]
        for rally_id, group in df.groupby('rally', sort=False)
    }
    rally_ids = df['rally'].unique()

    for rally_id, next_rally_id in zip(rally_ids[:-1], rally_ids[1:]):
        last_row = last_shot.get(rally_id)
        next_last_row = last_shot.get(next_rally_id)
        if last_row is None or next_last_row is None:
            continue

        current_player = last_row['player']
        # NOTE(review): the outcome is judged from the player of the *last*
        # shot of the next rally. If the intent is "the rally winner serves
        # next", the first shot (iloc[0]) may be the right reference. Confirm.
        next_player = next_last_row['player']

        if current_player != player_name:
            continue

        x = last_row['landing_x']
        y = last_row['landing_y']
        if pd.isna(x) or pd.isna(y):
            continue

        if current_player == next_player:
            player_landings[player_name]['win'].append((x, (y + 1) / 2))
        else:
            # Losing points are mirrored into the opposite half of the plot.
            player_landings[player_name]['lose'].append((-x, -(y + 1) / 2))

    return player_landings


def process_sim_landing(player_name, player_opps, sim_folder):
    """Collect end-of-rally landing points from simulated-match CSVs.

    Scans ``sim_folder`` for ``*.csv`` files whose stem splits on ``_`` into at
    least four parts ``<method>_<playerA>_<any>_<playerB>``. A file is used
    only when ``playerA == player_name`` and ``playerB`` is in ``player_opps``.
    In those files the focal player is coded as ``'A'`` in the ``player``
    column (presumably matching ``playerA`` in the filename).

    For each pair of consecutive rallies, the rally's last shot is used when:

    * it was hit by ``'A'``,
    * it has finite ``landing_x``/``landing_y``,
    * ``0 <= landing_y <= 480``.

    The coordinates are scaled by ``1/177.5`` (x) and ``1/480`` (y); these are
    presumably simulator court dimensions, so confirm before reusing. The point
    is recorded as a ``'win'`` ``(x, y)`` when the next rally's last shot is
    also by ``'A'``, otherwise as a ``'lose'`` ``(-x, -y)``.

    Args:
        player_name: Name expected in the ``playerA`` filename slot.
        player_opps: Collection of accepted opponent names (membership only).
        sim_folder: Directory containing the simulated CSVs.

    Returns:
        ``{method: {'win': [(x, y), ...], 'lose': [(x, y), ...]}}``. A method is
        present whenever a matching file has ``rally`` and ``player`` columns,
        even if it contributes no points.
    """
    sim_landings = {}

    for csv_file in glob.glob(os.path.join(sim_folder, "*.csv")):
        filename = os.path.splitext(os.path.basename(csv_file))[0]
        segments = filename.split('_')
        if len(segments) < 4:
            continue
        method, file_player_a, _, file_player_b = segments[:4]
        if file_player_a != player_name or file_player_b not in player_opps:
            continue

        # Read the globbed path directly instead of rebuilding it from the stem.
        df = pd.read_csv(csv_file)
        if 'rally' not in df.columns or 'player' not in df.columns:
            continue

        bucket = sim_landings.setdefault(method, {'win': [], 'lose': []})
        if 'landing_x' not in df.columns or 'landing_y' not in df.columns:
            continue

        # Last shot of every rally, computed once rather than per rally.
        last_shot = {
            rally_id: group.iloc[-1]
            for rally_id, group in df.groupby('rally', sort=False)
        }
        rally_ids = df['rally'].unique()

        for rally_id, next_rally_id in zip(rally_ids[:-1], rally_ids[1:]):
            last_row = last_shot.get(rally_id)
            next_last_row = last_shot.get(next_rally_id)
            if last_row is None or next_last_row is None:
                continue
            if last_row['player'] != 'A':
                continue

            raw_x = last_row['landing_x']
            raw_y = last_row['landing_y']
            # NaN compares False against both bounds, so it must be rejected
            # explicitly or it would leak into the KDE input.
            if pd.isna(raw_x) or pd.isna(raw_y) or not 0 <= raw_y <= 480:
                continue

            x = raw_x / 177.5
            y = raw_y / 480

            if next_last_row['player'] == 'A':
                bucket['win'].append((x, y))
            else:
                bucket['lose'].append((-x, -y))

    return sim_landings


def plot_landing_density(
    player_name, real_data, sim_data, output_dir="output_landing_density_colored"
):
    """Plot win/lose landing-density contours for real data and each simulated method.

    One subplot is drawn per available method, in the order
    ``Real, BC, DD, DP, DBC, DDGI``. Methods absent from ``sim_data`` get no
    column. Winning points are drawn with the ``Blues`` colormap and losing
    points with ``Reds``, under a schematic court outline. The figure is saved
    to ``<output_dir>/<player_name>_landing.png``.

    Args:
        player_name: Used only to name the output file.
        real_data: Dict with ``'win'`` and ``'lose'`` lists of ``(x, y)`` points.
        sim_data: Dict mapping method keys (``'bc'``, ``'dd'``, ``'dP'``,
            ``'dbc'``, ``'ddgi'``) to dicts with ``'win'`` and ``'lose'`` point
            lists. Other keys are ignored.
        output_dir: Directory for the PNG; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    name_map = {
        'bc': 'BC',
        'dd': 'DD',
        'dbc': 'DBC',
        'dP': 'DP',
        'ddgi': 'DDGI',
    }
    inverse_map = {v: k for k, v in name_map.items()}
    display_order = ['Real', 'BC', 'DD', 'DP', 'DBC', 'DDGI']
    # Follow display_order, and only allocate a column for methods with data.
    # This avoids empty, unstyled trailing axes.
    methods = [
        m for m in display_order if m == 'Real' or inverse_map[m] in sim_data
    ]
    n_methods = len(methods)
    # squeeze=False always yields a 2-D array, so a single column needs no
    # special-casing.
    fig, axs = plt.subplots(
        1, n_methods, figsize=(5 * n_methods, 10), sharey=True, squeeze=False
    )
    axs = axs[0]

    def draw_court_background(ax):
        """Draw schematic court lines on a [-1, 1] x [-1, 1] frame and hide the axes."""
        for x in [-0.9, 0.0, 0.9]:
            ax.axvline(x, color='black', linewidth=1)
        for y in [-0.9, -0.2, 0.0, 0.2, 0.9]:
            ax.axhline(y, color='black', linewidth=1)
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_aspect('equal')
        ax.axis('off')

    def kde_contour(ax, data, cmap):
        """Draw filled 2-D KDE contours of ``(x, y)`` points on a 100x100 grid over [-1, 1]^2.

        Non-finite points are dropped. Nothing is drawn when fewer than two
        points remain, or when the points are degenerate (e.g. identical or
        collinear) so that the KDE covariance is singular.
        """
        points = np.asarray(data, dtype=float).reshape(-1, 2)
        points = points[np.isfinite(points).all(axis=1)]
        if len(points) < 2:
            return
        try:
            kde = gaussian_kde(points.T)
        except np.linalg.LinAlgError:
            return
        xgrid = np.linspace(-1, 1, 100)
        ygrid = np.linspace(-1, 1, 100)
        X, Y = np.meshgrid(xgrid, ygrid)
        positions = np.vstack([X.ravel(), Y.ravel()])
        Z = np.reshape(kde(positions), X.shape)
        ax.contourf(X, Y, Z, levels=15, cmap=cmap, alpha=0.6)

    for ax, method in zip(axs, methods):
        if method == 'Real':
            points = real_data
        else:
            points = sim_data[inverse_map[method]]

        kde_contour(ax, points['win'], 'Blues')
        kde_contour(ax, points['lose'], 'Reds')
        draw_court_background(ax)
        ax.set_title(method, fontsize=24, weight='bold')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{player_name}_landing.png"))
    plt.close(fig)


def kde_colorbar_legend(output_path="output_landing_density_colored/kde_colorbar.png"):
    """Save a standalone vertical ``coolwarm`` colour bar to use as a legend.

    The bar is a 256-step gradient from -1.0 (bottom) to 1.0 (top), with five
    labelled ticks on the right and the title ``Density``.

    NOTE(review): the tick labels are fixed and not derived from any KDE
    output. ``plot_landing_density`` draws with ``Blues``/``Reds`` rather than
    ``coolwarm``, so confirm this legend matches the intended figures.

    Args:
        output_path: Destination PNG. Its parent directory is created if the
            path has one.
    """
    fig, ax = plt.subplots(figsize=(1.8, 6))
    fig.subplots_adjust(left=0.5, right=0.7)

    gradient = np.linspace(-1, 1, 256).reshape(256, 1)
    # Pass the colormap by name: matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9, and imshow resolves registered names itself.
    ax.imshow(gradient, aspect='auto', cmap='coolwarm', origin='lower')
    ax.set_xticks([])
    ax.set_yticks([0, 64, 128, 192, 255])
    ax.set_yticklabels(['-1.0', '-0.5', '0.0', '0.5', '1.0'], fontsize=11)
    ax.set_title('Density', fontsize=12)
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position("right")

    # os.makedirs('') raises, so only create a directory when there is one.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(output_path, bbox_inches='tight', dpi=150)
    plt.close(fig)


if __name__ == "__main__":
    # Build the landing-density comparison figure for one focal player
    # against the other listed players.
    player_list = ['Viktor AXELSEN', 'Kento MOMOTA', 'CHOU Tien Chen']
    target_player = 'CHOU Tien Chen'
    # process_sim_landing only tests membership, so the order of opponents
    # does not matter.
    opponents = [p for p in player_list if p != target_player]

    real_by_player = process_real_landing(target_player, "./data/badminton")
    sim_by_method = process_sim_landing(
        target_player, opponents, "./evaluation/data/badminton"
    )

    plot_landing_density(
        player_name=target_player,
        real_data=real_by_player[target_player],
        sim_data=sim_by_method,
        output_dir="./evaluation/plot/badminton",
    )

    # Optional: also write the standalone colour-bar legend.
    # kde_colorbar_legend(output_path="./evaluation/plot/legend_landing.png")
                