import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import glob

# --- Configurable parameters (modify these to match your setup) ---

# 1. Root directory of the experiment logs.
#    CURRENT_DIR is the directory containing this script, PROJECT_ROOT is its
#    parent, and logs are read from <PROJECT_ROOT>/exp/qc.
#    find_log_files() expects the layout
#    <BASE_LOG_PATH>/<algorithm>/<env_name>*/sd*/eval.csv.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(CURRENT_DIR)
BASE_LOG_PATH = os.path.join(PROJECT_ROOT, 'exp', 'qc')

# 2. Environment names to process (the main loop handles them one by one).
# Note: find_log_files() globs for '<name>*', so an entry matches every folder
# whose name STARTS WITH it. For example, 'can' will match 'can-mh-low_dim'.
ENV_NAME_LIST = [
    'can', 
    'lift', 
    'square',
    'cube-double-play-singletask-task2',
    'cube-double-play-singletask-task3',
    'cube-double-play-singletask-task4',
    'cube-triple-play-singletask-task2',
    'cube-triple-play-singletask-task3',
    'cube-triple-play-singletask-task4',
]

# 3. Algorithm names. Each one is used as a sub-directory name under
#    BASE_LOG_PATH; load_data() upper-cases it for the 'algorithm' column.
ALGORITHMS_TO_PLOT = [
    'qc', 
    'bfn', 
    'fql',
    'meanflow'
]

# 4. Columns read from each eval.csv: the value plotted on the Y-axis and the
#    column used for the X-axis.
VALUE_TO_PLOT = 'success'
X_AXIS_COLUMN = 'step'

def find_log_files(base_path, algorithms, env_name):
    """Find the paths of all eval.csv files that belong to one environment.

    The directory layout searched is
    ``<base_path>/<algo>/<env_name>*/sd*/eval.csv``. Progress and warnings are
    printed to stdout.

    Args:
        base_path: Root directory that holds one sub-directory per algorithm.
        algorithms: Iterable of algorithm sub-directory names to search.
        env_name: Prefix of the environment folder name. It is used as a glob
            pattern followed by ``*``, so 'can' matches 'can-mh-low_dim'.

    Returns:
        A list of ``(algo, eval_csv_path)`` tuples. Entries follow the order of
        ``algorithms``; within an algorithm, environment and seed folders are
        visited in sorted order so the result is deterministic.
    """
    all_files = []
    print("Start looking for log files...")

    for algo in algorithms:
        algo_dir = os.path.join(base_path, algo)
        # Escape the literal directory part: without this, characters such as
        # '[' or '*' in the path would be treated as glob wildcards and the
        # search would silently match nothing. env_name stays unescaped on
        # purpose because it is the pattern part.
        env_pattern = os.path.join(glob.escape(algo_dir), f'{env_name}*')
        # glob returns entries in arbitrary order; sort for reproducibility.
        env_folders = sorted(glob.glob(env_pattern))

        if not env_folders:
            print(f"Warning: Environment '{env_name}' folder not found in '{algo_dir}'.")
            continue

        print(f"\nAlogrithm '{algo}' find seed:")
        for env_folder in env_folders:
            # Match all seed folders, such as 'sd00020250819_053848'.
            # env_folder is a concrete path returned by glob, so escape it
            # before reusing it inside another pattern.
            seed_pattern = os.path.join(glob.escape(env_folder), 'sd*')
            seed_folders = sorted(glob.glob(seed_pattern))

            if not seed_folders:
                print(f"  (No seed folders found in '{env_folder}')")
                continue

            for seed_folder in seed_folders:
                seed_name = os.path.basename(seed_folder)
                log_file = os.path.join(seed_folder, 'eval.csv')
                if os.path.exists(log_file):
                    all_files.append((algo, log_file))
                    print(f"  {seed_name}")
                else:
                    # Seed folder exists but holds no evaluation log.
                    print(f"  {seed_name} (No eval.csv)")

    print(f"\nA total of {len(all_files)} valid log files were found.")
    return all_files

def load_data(files_to_load, x_col, y_col):
    """Load the selected columns of every log file into one long DataFrame.

    Each file contributes its ``x_col`` and ``y_col`` columns plus an
    ``'algorithm'`` column holding the upper-cased algorithm name. Files that
    cannot be read, or that lack one of the two columns, are reported on
    stdout and skipped.

    A synthetic start point ``(x=0, y=0)`` is prepended to a run only when the
    run has no row at ``x == 0``. Adding it unconditionally would duplicate the
    origin for logs that already contain a measurement there and distort the
    mean and the error band at the first x position.

    Args:
        files_to_load: Iterable of ``(algo, file_path)`` tuples.
        x_col: Name of the column used for the X-axis.
        y_col: Name of the column holding the plotted value.

    Returns:
        A DataFrame with columns ``[x_col, y_col, 'algorithm']`` containing the
        rows of all successfully loaded files, or an empty DataFrame when
        nothing could be loaded.
    """
    data_frames = []
    for algo, file_path in files_to_load:
        try:
            df = pd.read_csv(file_path)
            # Make sure the required columns exist
            if x_col not in df.columns or y_col not in df.columns:
                print(f"Warning: The column '{x_col}' or '{y_col}' is missing in the file '{file_path}'.")
                continue

            label = algo.upper()  # e.g. 'qc' -> 'QC'
            run_df = df[[x_col, y_col]].copy()
            run_df['algorithm'] = label

            # Anchor the curve at the origin only if this run did not log a
            # value at x == 0 itself.
            if not (run_df[x_col] == 0).any():
                start_point = pd.DataFrame({
                    x_col: [0],
                    y_col: [0],
                    'algorithm': [label]
                })
                run_df = pd.concat([start_point, run_df], ignore_index=True)

            data_frames.append(run_df)
        except Exception as e:
            # Best effort: one unreadable log must not abort the whole plot.
            print(f"Error: Error occurred when reading the file '{file_path}' : {e}")

    if not data_frames:
        return pd.DataFrame()

    return pd.concat(data_frames, ignore_index=True)

def plot_results(data, env_name, x_col, y_col, base_log_path):
    """Plot one line per algorithm with seaborn and save the figure.

    The figure is written to ``base_log_path`` as
    ``success_rate_<env_name>.pdf`` and ``success_rate_<env_name>.png`` and is
    then shown with ``plt.show()``. When ``data`` is empty a message is printed
    and nothing is drawn or saved.

    Args:
        data: Long-format DataFrame containing ``x_col``, ``y_col`` and an
            ``'algorithm'`` column used for hue and marker style.
        env_name: Environment name; used (capitalized) as the plot title and
            verbatim in the output file names.
        x_col: Column plotted on the X-axis. Tick labels are divided by 1e6.
        y_col: Column plotted on the Y-axis.
        base_log_path: Directory the PDF and PNG files are written to.

    Returns:
        None.
    """
    if data.empty:
        print("There is no data available for plotting.")
        return

    sns.set_theme(style="whitegrid")
    fig = plt.figure(figsize=(8, 4))
    # Light gray background band over the x range [0, 1e6].
    plt.axvspan(0, 1000000, alpha=0.4, color='lightgray')

    n_algorithms = len(data['algorithm'].unique())
    palette = sns.color_palette("husl", n_algorithms)
    markers = ['o', 's', 'D', '^', 'v', 'P', '*']

    ax = sns.lineplot(
        data=data,
        x=x_col,
        y=y_col,
        hue='algorithm',
        style='algorithm',
        markers=markers[:n_algorithms],
        markersize=8,
        linewidth=2.5,
        palette=palette,
        errorbar='sd'  # Use the standard deviation as the error band
    )

    ax.set_title(f'{env_name.capitalize()}', fontsize=16)
    ax.set_ylabel('Success Rate', fontsize=14)
    ax.set_xlabel('Steps ($\\times10^6$)', fontsize=14)

    # Show raw x values in units of 1e6 to match the axis label.
    formatter = FuncFormatter(lambda x, pos: f'{x / 1e6:.1f}')
    ax.xaxis.set_major_formatter(formatter)

    # Small margin around [0, 1] so markers at the extremes are not clipped.
    ax.set_ylim(-0.05, 1.05)

    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles, labels=labels, title='Algorithm', fontsize=11, title_fontsize=12, loc='best')

    plt.tight_layout()

    output_filename_pdf = os.path.join(base_log_path, f'success_rate_{env_name}.pdf')
    plt.savefig(output_filename_pdf, dpi=300, bbox_inches='tight', format='pdf')

    output_filename_png = os.path.join(base_log_path, f'success_rate_{env_name}.png')
    plt.savefig(output_filename_png, dpi=300, bbox_inches='tight', format='png')

    print("\nThe plotting is completed!")
    print(f"PDF saved as: '{output_filename_pdf}'")
    print(f"PNG saved as: '{output_filename_png}'")

    plt.show()

    # Release the figure once show() has returned. This function is called once
    # per environment, and on non-GUI backends show() returns immediately, so
    # without this every call would leave another figure registered in pyplot.
    # In interactive mode show() does not block, so the window is left open.
    if not plt.isinteractive():
        plt.close(fig)


if __name__ == '__main__':
    # Produce one figure per configured environment.
    for env_name in ENV_NAME_LIST:
        print(f"\nProcessing environment: {env_name}")

        # Step 1: discover the eval.csv files for this environment.
        discovered = find_log_files(BASE_LOG_PATH, ALGORITHMS_TO_PLOT, env_name)
        if not discovered:
            print("\nNo log files were found.")
            continue

        # Step 2: merge every run into a single long-format table.
        merged = load_data(discovered, x_col=X_AXIS_COLUMN, y_col=VALUE_TO_PLOT)

        # Step 3: draw and save the figure.
        plot_results(
            merged,
            env_name,
            x_col=X_AXIS_COLUMN,
            y_col=VALUE_TO_PLOT,
            base_log_path=BASE_LOG_PATH,
        )
