import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import shutil

# Paths are resolved once, relative to the directory the script is launched from.
cur_dir = os.getcwd()
output_dir = os.path.join(cur_dir, "output")  # holds one sub-directory per model


class Visualization(object):
    """Load per-experiment ``progress.csv`` logs, summarize and plot their metrics.

    For every model name in ``model_list`` the constructor reads
    ``<output_dir>/<model>/<exp>/progress.csv`` and keeps the last ``n_tail``
    complete rows of the ``metric*`` columns plus ``rollout/episode_len``.
    It flags rows whose episode length is below ``max_episode_len`` as
    ``succeed``, prints the per-model success rate, keeps only the flagged
    rows and averages them per ``(model, exp)`` into ``self.df``.
    """

    def __init__(self, model_list, n_tail=50, show_outliers=False, max_episode_len=50):
        self.model_list = model_list
        self.model_num = len(model_list)
        self.n_tail = n_tail
        # Passed to seaborn as ``showfliers`` in every boxplot.
        self.show_outliers = show_outliers

        df = self.load_progress()
        df = df.reindex(sorted(df.columns), axis=1)

        # Rows whose logged episode length is below the threshold count as successes.
        df['succeed'] = df['episode_len'] < max_episode_len
        print(df.groupby(['model'])['succeed'].mean())

        # Statistics are computed on successful rows only, averaged per run.
        df = df[df['succeed']]
        self.df = df.groupby(['model', 'exp']).mean().reset_index()

    def load_progress(self, root_dir=None):
        """Read and stack the tail of every experiment's ``progress.csv``.

        Parameters
        ----------
        root_dir : str, optional
            Directory holding one sub-directory per model. Defaults to the
            module-level ``output_dir``.

        Returns
        -------
        pandas.DataFrame
            Columns ``model`` and ``exp`` (the integer suffix of the experiment
            folder name after the last ``_``), followed by the ``metric*``
            columns and ``episode_len``. The ``prefix/`` part of every column
            name is stripped.

        Raises
        ------
        ValueError
            If no experiment could be loaded at all.
        """
        root = output_dir if root_dir is None else root_dir
        df_list = []
        for model in self.model_list:
            model_dir = os.path.join(root, model)
            print(model)
            for exp in sorted(os.listdir(model_dir)):
                results_path = os.path.join(model_dir, exp, 'progress.csv')
                if not os.path.exists(results_path):
                    print(os.path.join(model_dir, exp))
                    print('no result')
                    continue
                try:
                    df_temp = pd.read_csv(results_path)
                    metric_cols = [col for col in df_temp.columns if col.startswith('metric')]
                    df_temp = df_temp[metric_cols + ['rollout/episode_len']].dropna().tail(self.n_tail)
                    df_temp.insert(loc=0, column='exp', value=int(exp.split('_')[-1]))
                    df_temp.insert(loc=0, column='model', value=model)
                except (OSError, KeyError, ValueError) as err:
                    # Unreadable file, missing column or non-numeric exp suffix:
                    # skip this run entirely so no stale frame from an earlier
                    # run is appended in its place.
                    print('fail')
                    print(exp)
                    print(err)
                    continue
                df_list.append(df_temp)

        print(len(df_list))
        if not df_list:
            raise ValueError('No progress.csv could be loaded for models: {}'.format(self.model_list))
        df = pd.concat(df_list).reset_index(drop=True)
        df.columns = [col.split('/')[-1] for col in df.columns]
        return df

    def filter_outliers(self, group_col, value_col):
        """Return the rows of ``self.df`` whose ``value_col`` lies inside the group's IQR fence.

        The fence is ``[Q1 - 1.5*IQR, Q3 + 1.5*IQR]``, computed per ``group_col`` group.
        """
        grouped = self.df.groupby(group_col)[value_col]
        q1 = grouped.transform('quantile', 0.25)
        q3 = grouped.transform('quantile', 0.75)
        iqr = q3 - q1
        values = self.df[value_col]
        return self.df[(values >= q1 - 1.5 * iqr) & (values <= q3 + 1.5 * iqr)]

    def trimmed_dataframe(self, df, group_col, value_cols):
        """Compute the per-group mean and std of ``value_cols`` after IQR outlier removal.

        Within every ``group_col`` group, a row is dropped when ANY of
        ``value_cols`` falls outside ``[Q1 - 1.5*IQR, Q3 + 1.5*IQR]`` of that group.

        Returns
        -------
        (pandas.DataFrame, pandas.DataFrame)
            Trimmed mean and trimmed standard deviation, indexed by ``group_col``
            with one column per entry of ``value_cols``.
        """
        kept = []
        for _, sub_df in df.groupby(group_col):
            values = sub_df[value_cols]
            q1 = values.quantile(0.25)
            q3 = values.quantile(0.75)
            iqr = q3 - q1
            outlier = (values < q1 - 1.5 * iqr) | (values > q3 + 1.5 * iqr)
            kept.append(sub_df[~outlier.any(axis=1)])
        if not kept:
            empty = pd.DataFrame(columns=value_cols)
            return empty, empty.copy()
        # Filter once, then group the surviving rows by the original column.
        # This does not depend on how the pandas version sets apply/group_keys.
        grouped = pd.concat(kept).groupby(group_col)[value_cols]
        return grouped.mean(), grouped.std()

    def print_metrics(self):
        """Write per-model summary tables of ``self.df`` to ``metrics.txt`` in ``cur_dir``.

        The tables are median, mean and std (rounded to 5 decimals), followed
        by the IQR-trimmed mean and std of ``cr``, ``mse``, ``range`` and ``kl_div``.
        """
        df = self.df.copy()
        metric_list = ['cr', 'mse', 'range', 'kl_div']
        trimmed_mean, trimmed_std = self.trimmed_dataframe(df, 'model', metric_list)

        grouped = df.groupby('model')
        sections = [
            ('\n ---------------------------- median -----------------------------\n', grouped.median().round(5)),
            ('\n ----------------------------  average ----------------------------\n', grouped.mean().round(5)),
            ('\n -----------------------  standard deviation -----------------------\n', grouped.std().round(5)),
            ('\n ----------------------------  trimmed mean ----------------------------\n', trimmed_mean),
            ('\n ----------------------------  trimmed std ----------------------------\n', trimmed_std),
        ]

        save_path = os.path.join(cur_dir, 'metrics.txt')
        with open(save_path, 'w') as f:
            for header, table in sections:
                f.write(header)
                f.write(table.to_string())

    def create_melt_df(self, metric_name):
        """Return ``self.df`` in long form with columns ``model``, ``columns`` (= ``metric_name``) and ``value``."""
        return self.df.melt(
            id_vars='model',
            value_vars=[metric_name],
            var_name='columns',
        )

    def _prepare_plot_df(self, mapping):
        """Copy ``self.df`` with model names relabeled through ``mapping``.

        Returns ``(df, order)``, where ``order`` is the list of display labels.
        It is passed to seaborn as the y-axis category order. When ``mapping``
        is None, every model keeps its own name.
        """
        df = self.df.copy()
        if mapping is None:
            mapping = {key: key for key in df['model'].unique()}
        df['model'] = df['model'].replace(mapping)
        return df, list(mapping.values())

    def _draw_boxplot(self, df, x, order):
        """Open a new figure sized to the model count and draw a horizontal boxplot of column ``x``."""
        plt.figure(figsize=(16, 1 + self.model_num / 3))
        return sns.boxplot(
            data=df, x=x, y='model', orient='h',
            order=order,
            meanprops={"marker": "o",
                       "markerfacecolor": "white",
                       "markeredgecolor": "black",
                       "markersize": "8"},
            boxprops={"alpha": 0.3},
            showfliers=self.show_outliers,
            showmeans=False,
        )

    @staticmethod
    def _finish_and_save(filename, message, save_dir='boxplot'):
        """Apply shared tick font sizes and lay out the figure, then save it under ``save_dir`` and close it."""
        plt.xticks(fontsize=11)
        plt.yticks(fontsize=11)
        # Lay out last so that the final labels, ticks and scales are accounted for.
        plt.tight_layout()
        # savefig does not create missing folders.
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, filename))
        print(message)
        plt.close()

    def plot_cr_boxplot(self, mapping=None):
        """Save a per-model boxplot of coverage rate (``cr``) to ``boxplot/cr_boxplot.png``.

        ``mapping`` maps raw model names to display labels, and its values set the row order.
        """
        df, order = self._prepare_plot_df(mapping)
        ax = self._draw_boxplot(df, 'cr', order)
        plt.xlabel('Coverage rate', fontsize=14)
        plt.ylabel('Algorithm', fontsize=14)

        # Major ticks every 0.1 and minor ticks every 0.05, starting at 0.3.
        x_min, x_max = 0.3, 1.0
        ax.set_xticks(np.arange(x_min, x_max, 0.05), minor=True)
        ax.set_xticks(np.arange(x_min, x_max + 0.05, 0.1), minor=False)
        # Reference line at the 95% coverage target.
        ax.axvline(x=0.95, color='red', linestyle='--', label='95%')
        ax.xaxis.grid(True, which='major', linestyle='-', linewidth=1)
        ax.xaxis.grid(True, which='minor', linestyle='--', linewidth=0.5)
        legend = plt.legend(fontsize=11, loc='upper left')
        legend.get_title().set_text('')
        self._finish_and_save('cr_boxplot.png', 'cr boxplot saved')

    def plot_kl_boxplot(self, mapping=None):
        """Save a per-model boxplot of ``kl_div`` (log x-axis) to ``boxplot/kl_boxplot.png``."""
        df, order = self._prepare_plot_df(mapping)
        self._draw_boxplot(df, 'kl_div', order)
        plt.xlabel("KL-divergence", fontsize=14)
        plt.ylabel('Algorithm', fontsize=14)
        plt.xscale('log')
        self._finish_and_save('kl_boxplot.png', 'kl boxplot saved')

    def plot_range_boxplot(self, mapping=None):
        """Save a per-model boxplot of interval ``range`` to ``boxplot/range_boxplot.png``."""
        df, order = self._prepare_plot_df(mapping)
        ax = self._draw_boxplot(df, 'range', order)
        plt.xlabel('Interval range', fontsize=14)
        plt.ylabel('Algorithm', fontsize=14)
        ax.xaxis.grid(True, which='major', linestyle='-', linewidth=1)
        ax.xaxis.grid(True, which='minor', linestyle='--', linewidth=0.5)
        self._finish_and_save('range_boxplot.png', 'range boxplot saved')

    def plot_mse_boxplot(self, mapping=None):
        """Save a per-model boxplot of ``mse`` (log x-axis) to ``boxplot/mse_boxplot.png``."""
        df, order = self._prepare_plot_df(mapping)
        ax = self._draw_boxplot(df, 'mse', order)
        ax.xaxis.grid(True, which='major', linestyle='-', linewidth=0.5)
        ax.tick_params(axis='x', which='minor', bottom=False)
        plt.xlabel(r'$MSE(\hat{V})$', fontsize=14)
        plt.ylabel('Algorithm', fontsize=14)
        plt.xscale('log')
        self._finish_and_save('mse_boxplot.png', 'mse boxplot saved')
        
        
def plot_cr_curves(model, root_dir=None):
    """Plot the coverage-rate curve over timesteps for all runs of ``model``.

    Reads ``<root_dir>/<model>/<exp>/progress.csv`` for every experiment and
    keeps the ``metric*`` columns and ``rollout/timesteps``. It draws the
    median ``cr`` across runs with a percentile-interval band and saves the
    figure to ``cr_plot/<model>.png``.

    Parameters
    ----------
    model : str
        Sub-directory name under ``root_dir``.
    root_dir : str, optional
        Directory holding one sub-directory per model. Defaults to the
        module-level ``output_dir``.

    Returns None without plotting when no run could be loaded.
    """
    root = output_dir if root_dir is None else root_dir
    model_dir = os.path.join(root, model)
    df_list = []
    for exp in sorted(os.listdir(model_dir)):
        results_path = os.path.join(model_dir, exp, 'progress.csv')
        if not os.path.exists(results_path):
            print(os.path.join(model_dir, exp))
            print('no result')
            continue
        try:
            df_temp = pd.read_csv(results_path)
            df_temp = df_temp[[col for col in df_temp.columns if col.startswith('metric')] + ['rollout/timesteps']]
            df_temp.insert(loc=0, column='exp', value=int(exp.split('_')[-1]))
            df_temp.insert(loc=0, column='model', value=model)
        except (OSError, KeyError, ValueError):
            # Unreadable file, missing column or non-numeric exp suffix: skip this run.
            print("{}: {}".format(model, exp))
            continue
        df_list.append(df_temp)

    if not df_list:
        # pd.concat([]) would raise, and there is nothing to draw.
        print("No data to plot.")
        return

    df = pd.concat(df_list).reset_index(drop=True)
    df.columns = [col.split('/')[-1] for col in df.columns]
    df = df.dropna()
    sns.lineplot(x='timesteps', y='cr', data=df, estimator='median', errorbar='pi')
    plt.axhline(y=0.95, color='red', linestyle='--', label='95%')
    plt.ylim(0.0, 1.0)
    plt.xlabel("Timesteps")
    plt.ylabel("Coverage Rate")
    plt.title('Coverage Rate of Prediction Interval')
    # Without a legend the '95%' label of the reference line is never shown.
    plt.legend()
    # savefig does not create missing folders.
    os.makedirs('cr_plot', exist_ok=True)
    plt.savefig(os.path.join('cr_plot', '{}.png'.format(model)))
    plt.close()
    
    
def plot_cr_curves_all(models, root_dir=None):
    """Plot the median coverage-rate curve of several models on one figure.

    For each model, every ``<root_dir>/<model>/<exp>/progress.csv`` is read
    and its ``metric*`` columns plus ``rollout/timesteps`` are kept. The
    median ``cr`` per model is drawn with a standard-deviation band and saved
    to ``cr_plot/cr_plot.png``.

    Parameters
    ----------
    models : iterable of str
        Model sub-directory names. Missing directories are reported and skipped.
    root_dir : str, optional
        Directory holding one sub-directory per model. Defaults to the
        module-level ``output_dir``.

    Returns None without plotting when no run could be loaded.
    """
    root = output_dir if root_dir is None else root_dir
    all_data = []

    for model in models:
        model_dir = os.path.join(root, model)
        if not os.path.exists(model_dir):
            print(f"Directory {model_dir} does not exist.")
            continue

        for exp in sorted(os.listdir(model_dir)):
            results_path = os.path.join(model_dir, exp, 'progress.csv')
            if not os.path.exists(results_path):
                print(f"No result found for {os.path.join(model_dir, exp)}")
                continue
            try:
                df_temp = pd.read_csv(results_path)
                df_temp = df_temp[[col for col in df_temp.columns if col.startswith('metric')] + ['rollout/timesteps']]
                df_temp.insert(loc=0, column='exp', value=int(exp.split('_')[-1]))
                df_temp.insert(loc=0, column='model', value=model)
            except (OSError, KeyError, ValueError) as err:
                # One malformed run should not abort the whole comparison plot.
                print(f"Failed to load {results_path}: {err}")
                continue
            all_data.append(df_temp)

    if not all_data:
        print("No data to plot.")
        return

    df = pd.concat(all_data).reset_index(drop=True)
    df.columns = [col.split('/')[-1] for col in df.columns]
    df = df.dropna()

    sns.lineplot(x='timesteps', y='cr', hue='model', data=df, estimator='median', errorbar='sd')
    plt.axhline(y=0.95, color='red', linestyle='--', label='95% Target')
    plt.ylim(0.0, 1.0)
    plt.xlabel("Timesteps")
    plt.ylabel("Coverage Rate")
    plt.title('Coverage Rate of Prediction Interval Across Models')
    plt.legend(title='Model')
    # savefig does not create missing folders.
    os.makedirs('cr_plot', exist_ok=True)
    plt.savefig(os.path.join('cr_plot', 'cr_plot.png'))
    plt.close()
    
if __name__ == '__main__':

    # Raw model directory name -> display label used on the boxplot y-axis.
    # The values also fix the top-to-bottom row order of the boxplots.
    name_mapping = {
        # 'a2c': r'A2C',
        'a2c_new': r'A2C',

        # 'lt_a2c_new_pp10': r'LT-A2C: $\mathcal{N}=1\times 10^4$',
        # 'lt_a2c_new_pp20': r'LT-A2C: $\mathcal{N}=2\times 10^4$',
        # 'lt_a2c_new_pp40': r'LT-A2C: $\mathcal{N}=4\times 10^4$',

        'lt_a2c_v2_pp10': r'LT-A2C: $\mathcal{N}=1\times 10^4$',
        'lt_a2c_v2_pp20': r'LT-A2C: $\mathcal{N}=2\times 10^4$',
        'lt_a2c_v2_pp40': r'LT-A2C: $\mathcal{N}=4\times 10^4$',

        'ppo_new': r'PPO',

        # 'lt_ppo_new_pp10': r'LT-PPO: $\mathcal{N}=1\times 10^4$',
        # 'lt_ppo_new_pp20': r'LT-PPO: $\mathcal{N}=2\times 10^4$',
        # 'lt_ppo_new_pp40': r'LT-PPO: $\mathcal{N}=4\times 10^4$',

        'lt_ppo_v2_pp10': r'LT-PPO: $\mathcal{N}=1\times 10^4$',
        'lt_ppo_v2_pp20': r'LT-PPO: $\mathcal{N}=2\times 10^4$',
        'lt_ppo_v2_pp40': r'LT-PPO: $\mathcal{N}=4\times 10^4$',
    }

    # Model directories (under output_dir) to load and summarize.
    name_list = [
        'a2c_new',

        # 'lt_a2c_new_pp10',
        # 'lt_a2c_new_pp20',
        # 'lt_a2c_new_pp40',

        'lt_a2c_v2_pp10',
        'lt_a2c_v2_pp20',
        'lt_a2c_v2_pp40',

        'ppo_new',

        # 'lt_ppo_new_pp10',
        # 'lt_ppo_new_pp20',
        # 'lt_ppo_new_pp40',

        'lt_ppo_v2_pp10',
        'lt_ppo_v2_pp20',
        'lt_ppo_v2_pp40',
    ]

    print(name_list)

    # Keep the last 10 logged rows per run; hide outlier points in boxplots.
    visualization = Visualization(name_list, 10, False)

    visualization.print_metrics()

    visualization.plot_cr_boxplot(name_mapping)
    visualization.plot_kl_boxplot(name_mapping)
    visualization.plot_mse_boxplot(name_mapping)
    visualization.plot_range_boxplot(name_mapping)

    # Start from an empty cr_plot/ folder; create it on the first run so
    # os.listdir does not raise.
    folder = os.path.join(cur_dir, 'cr_plot')
    os.makedirs(folder, exist_ok=True)
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            # Best-effort cleanup: report and keep going.
            print('Failed to delete %s. Reason: %s' % (file_path, e))

    # Only used by the commented-out plot_cr_curves_all call at the end.
    models = [
        # 'lt_a2c_sghmc_pp05k',
        # 'lt_a2c_sghmc_pp10k',
        # 'lt_a2c_sghmc_pp20k',

        # 'lt_a2c_sgld_pp05k',
        'lt_a2c_sgld_pp10k',
        # 'lt_a2c_sgld_pp20k',

        # 'a2c_rms',
        'a2c_adam',

        # 'ppo_rms',
        'ppo_adam',
    ]

    for model in name_list:
        plot_cr_curves(model)

    # plot_cr_curves_all(models)