import os
import pandas as pd
import matplotlib.pyplot as plt

def load_selected_progress_columns(output_dir, selected_columns):
    """
    Loads specific columns from every '{output_dir}/{model}/{exp}/progress.csv'
    file and combines them into a single DataFrame.

    Every sub-directory of `output_dir` is treated as a model and every
    sub-directory of a model as an experiment; no naming pattern is enforced.
    Directories are visited in sorted name order so the row order of the result
    is reproducible. Files that cannot be read or parsed (including files that
    lack one of the requested columns) are reported on stdout and skipped.

    Parameters:
    - output_dir (str): Root directory containing models and experiments.
    - selected_columns (list of str): List of column names to load from progress.csv files.

    Returns:
    - pd.DataFrame: Combined DataFrame with the selected columns, including additional
      columns for 'model' and 'exp'. When no file could be loaded the DataFrame is
      empty but still carries the selected column names plus 'model' and 'exp'.
    """
    data_frames = []

    # os.listdir returns entries in arbitrary order; sort for deterministic output.
    for model in sorted(os.listdir(output_dir)):
        model_path = os.path.join(output_dir, model)
        if not os.path.isdir(model_path):
            continue

        for exp in sorted(os.listdir(model_path)):
            exp_path = os.path.join(model_path, exp)
            if not os.path.isdir(exp_path):
                continue

            csv_file = os.path.join(exp_path, 'progress.csv')
            if not os.path.exists(csv_file):
                continue

            try:
                # Read the CSV file with selected columns
                df = pd.read_csv(csv_file, usecols=selected_columns)
            except (OSError, ValueError) as e:
                # Best effort: pandas parser/empty-file/missing-column errors are
                # ValueError subclasses; one bad file must not abort the whole load.
                print(f"Error reading {csv_file}: {e}")
                continue

            # Tag every row with where it came from.
            df['model'] = model
            df['exp'] = exp
            data_frames.append(df)

    if not data_frames:
        # Keep the expected schema so callers can index columns without a KeyError.
        return pd.DataFrame(columns=[*selected_columns, 'model', 'exp'])

    # Combine all DataFrames into one
    return pd.concat(data_frames, ignore_index=True)

# Example usage: plot, per timestep, the share of rows whose run has already
# reached the success threshold, for two models.
output_dir = "output"
selected_columns = ['rollout/progress', 'rollout/timesteps', 'rollout/episode_len']  # Specify the columns you want

# A run is considered successful once its best reward exceeds this value.
SUCCESS_REWARD_THRESHOLD = -19

df = load_selected_progress_columns(output_dir, selected_columns)

# Fail early with a clear message; otherwise the column lookups below raise an
# opaque KeyError when nothing was loaded.
if df.empty:
    raise SystemExit(f"No progress.csv data found under '{output_dir}'; nothing to plot.")

# Reward is the negated episode length, so shorter episodes score higher.
df['rewards'] = -df['rollout/episode_len']
# Running maximum of the reward within each (model, exp) run.
df['max_rewards'] = df.groupby(['model', 'exp'])['rewards'].cummax()
df['success'] = (df['max_rewards'] > SUCCESS_REWARD_THRESHOLD)

a2c_df = df[df['model'] == 'a2c_no_explore']
lt_a2c_df = df[df['model'] == 'lt_a2c_no_explore']

# Alternative model directory names, kept for quick switching:
# a2c_df = df[df['model'] == 'a2c_no_explore2']
# lt_a2c_df = df[df['model'] == 'lt_a2c_no_explore2']

print(a2c_df)
print(lt_a2c_df)

# Mean of the boolean success flag over all rows sharing a timestep value.
a2c_sr = a2c_df.groupby('rollout/timesteps')['success'].mean()
lt_a2c_sr = lt_a2c_df.groupby('rollout/timesteps')['success'].mean()

plt.plot(a2c_sr, label='A2C')
plt.plot(lt_a2c_sr, label='LT_A2C')

plt.xlabel('Timesteps')
plt.ylabel('Success Rate')

# NOTE(review): the experiment count in the title is hard-coded — confirm it
# matches the number of runs actually present under output_dir.
plt.title('Success escape rate vs Timesteps (100 experiments)')

plt.legend()
plt.savefig('success_rate.png')

# a2c_median = a2c_df.groupby('rollout/progress')['rewards'].median()
# lt_a2c_median = lt_a2c_df.groupby('rollout/progress')['rewards'].median()

# # a2c_std = a2c_df.groupby('rollout/progress')['rewards'].std()
# # lt_a2c_std = lt_a2c_df.groupby('rollout/progress')['rewards'].std()

# a2c_upper = a2c_df.groupby('rollout/progress')['rewards'].quantile(1)
# lt_a2c_upper = lt_a2c_df.groupby('rollout/progress')['rewards'].quantile(1)

# a2c_lower = a2c_df.groupby('rollout/progress')['rewards'].quantile(0.1)
# lt_a2c_lower = lt_a2c_df.groupby('rollout/progress')['rewards'].quantile(0.1)


# plt.plot(a2c_median, label='A2C')
# plt.fill_between(a2c_median.index, a2c_lower, a2c_upper, alpha=0.3)

# plt.plot(lt_a2c_median, label='LT_A2C')
# plt.fill_between(lt_a2c_median.index, lt_a2c_lower, lt_a2c_upper, alpha=0.3)

# plt.xlabel('Progress')
# plt.ylabel('Episode Length')

# plt.ylim(-200,0)

# plt.legend()
# plt.show()