import os
import re
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import numpy as np
import pandas as pd

# Lambda statistics plots: (1) ribbon plot of per-step mean ± std, (2) violin plot of per-episode means.
def vis_lam(lam_log, episode_means, labels):
    """
    Plot lambda statistics side by side on a new figure.

    Left panel: ribbon plot of the mean ± one standard deviation of ``lam_log``
    (reduced over axis 0) against the step index.
    Right panel: violin plot of ``episode_means`` grouped by ``labels``.

    Parameters:
    - lam_log: array-like reduced with np.mean / np.std over axis 0; presumably
      shaped (runs, steps) — TODO confirm against callers.
    - episode_means: sequence of per-episode lambda means (violin y values).
    - labels: sequence, same length as episode_means, naming each value's setting.

    Returns:
    - (fig, (ax_ribbon, ax_violin)) so callers can further style or save the figure.
    """
    # Separate Axes: drawing both plots on the implicit current Axes overlaid the
    # categorical violin on top of the numeric step-indexed ribbon.
    fig, (ax_ribbon, ax_violin) = plt.subplots(1, 2, figsize=(12, 4.5))

    # 1. Ribbon plot: per-step mean with a ±1 std band.
    mean = np.mean(lam_log, axis=0)
    std = np.std(lam_log, axis=0)
    steps = np.arange(len(mean))
    ax_ribbon.fill_between(steps, mean - std, mean + std, alpha=0.2)
    ax_ribbon.plot(steps, mean)

    # 2. Violin plot: one violin per distinct setting label.
    df = pd.DataFrame({"lambda_mean": episode_means, "setting": labels})
    sns.violinplot(data=df, x="setting", y="lambda_mean", ax=ax_violin)

    fig.tight_layout()
    return fig, (ax_ribbon, ax_violin)
    


def _find_alpha_csvs(folder_path):
    """Return sorted ``(alpha, filename)`` pairs for ``lam_a<alpha>_d<...>.csv`` files.

    Files whose alpha fragment is not a valid float (e.g. ``lam_a1.2.3_d1.csv``)
    are skipped instead of aborting the whole plot. Pairs are sorted by alpha and
    then by filename, so the drawing order does not depend on ``os.listdir``.
    """
    pattern = re.compile(r'lam_a([\d.]+)_d[\d.]+\.csv')
    found = []
    for filename in os.listdir(folder_path):
        if not filename.endswith(".csv"):
            continue
        match = pattern.search(filename)
        if match is None:
            continue
        try:
            alpha = float(match.group(1))
        except ValueError:
            continue
        found.append((alpha, filename))
    found.sort()
    return found


def plot_alpha(folder_path, save_path="comparison_plot.png", title_env="env",
               legend_path="evaluation/plot/legend_alpha.png"):
    """
    Plot mean ± std curves from all matching CSV files in a folder, with formal styling.

    CSV filenames must contain 'lam_a<alpha>_d<...>.csv' and include the columns
    time_step, mean, std; matching files without those columns are skipped.
    One curve per file is drawn in order of increasing alpha, colored with the
    viridis colormap. The legend is written to a separate image via save_legend_only.

    Parameters:
    - folder_path: str, path to the folder containing CSV files.
    - save_path: str, output path to save the plot image.
    - title_env: str, plot title (e.g. the environment name).
    - legend_path: str, output path for the standalone legend image. The default
      is shared by every call, so successive calls overwrite the same file.

    Raises:
    - ValueError: if no matching file exists, or none has the required columns.
    """
    files = _find_alpha_csvs(folder_path)
    if not files:
        raise ValueError(
            f"No 'lam_a<alpha>_d<...>.csv' files found in {folder_path!r}"
        )

    required = {'time_step', 'mean', 'std'}
    curves = []  # (position in `files`, alpha, DataFrame)
    for index, (alpha, filename) in enumerate(files):
        df = pd.read_csv(os.path.join(folder_path, filename))
        if required.issubset(df.columns):
            curves.append((index, alpha, df))
    if not curves:
        raise ValueError(
            f"No CSV in {folder_path!r} has the required columns {sorted(required)}"
        )

    # One color slot per matched file (skipped files keep their slot), so colors
    # stay tied to a file's position in the alpha ordering.
    colors = cm.viridis(np.linspace(0, 1, len(files)))

    handles, labels = [], []
    all_y_min, all_y_max = [], []

    fig = plt.figure(figsize=(8, 6), dpi=150)
    try:
        for index, alpha, df in curves:
            time, mean, std = df['time_step'], df['mean'], df['std']
            color = colors[index]
            label = f"$\\alpha={alpha}$"
            line, = plt.plot(time, mean, label=label, color=color, linewidth=2.5)
            plt.fill_between(time, mean - std, mean + std, color=color, alpha=0.2)

            handles.append(line)
            labels.append(label)
            all_y_min.append((mean - std).min())
            all_y_max.append((mean + std).max())

        # Tight y-range with a 10% margin to highlight small differences.
        y_min, y_max = min(all_y_min), max(all_y_max)
        y_span = y_max - y_min
        if y_span > 0:
            y_margin = y_span * 0.1
        else:
            # All bands collapse to one value: avoid identical lower/upper limits.
            y_margin = abs(y_max) * 0.1 or 1.0
        plt.ylim([y_min - y_margin, y_max + y_margin])

        # Final styling
        plt.title(title_env, fontsize=18)
        plt.tick_params(axis='both', which='major', labelsize=14)
        plt.xlabel("Denoise Step", fontsize=16)
        plt.ylabel("Lam", fontsize=16)
        plt.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    save_legend_only(handles, labels, save_path=legend_path, ncol=len(labels))
    

def save_legend_only(handles, labels, save_path="legend_only.png", ncol=None):
    """
    Save a standalone legend as an image.

    The legend is attached directly to a new figure (no Axes are created) and the
    saved image is cropped to its contents with bbox_inches='tight'.

    Parameters:
    - handles: list of line handles (from plt.plot or similar)
    - labels: list of legend labels corresponding to the handles
    - save_path: path to save the legend PNG
    - ncol: number of columns in the legend (optional); None or 0 means one column
    """
    # Legend expects a positive column count; forwarding None/0 would fail.
    ncol = ncol or 1
    # Width scales with the number of entries but must stay positive.
    fig_legend = plt.figure(figsize=(max(len(labels), 1), 0.8), dpi=150)
    try:
        fig_legend.legend(handles, labels, loc='center', ncol=ncol, frameon=True)
        fig_legend.savefig(save_path, bbox_inches='tight', dpi=150, transparent=True)
    finally:
        plt.close(fig_legend)


if __name__ == "__main__":
    # Each environment keeps its per-alpha lambda CSVs under evaluation/plot/<env>/.
    for env_name in ['push', 'tag', 'spread', 'box', 'tennis', 'connect4', 'holdem']:
        env_dir = f"evaluation/plot/{env_name}"
        plot_alpha(
            folder_path=env_dir,
            save_path=f"{env_dir}/alpha.png",
            title_env=env_name,
        )
