import random
import re
from pathlib import Path

import cmapy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def plot_bars(means,stds,labels,snapshots,title):
    """Draw a grouped bar chart with error bars and display it.

    One cluster of bars is drawn per snapshot; inside each cluster there is
    one bar per method.

    Args:
        means: 2-D sequence indexed as ``means[method][snapshot]``; the bar
            heights.
        stds: same layout as ``means``; passed to ``plt.bar`` as ``yerr``.
        labels: one legend label per method (per row of ``means``).
        snapshots: one x-tick label per cluster (per column of ``means``).
        title: title of the plot.

    Raises:
        IndexError: if ``means`` is empty, because the number of clusters is
            read from ``means[0]``.
    """
    max_bar_width = 0.3
    num_plots = len(means)        # number of methods (bars per cluster)
    num_clusters = len(means[0])  # number of snapshots (clusters)

    # Clusters sit one unit apart on the x axis, so shrink the bars when the
    # default width would make neighbouring clusters overlap.
    if num_plots * max_bar_width <= 1:
        bar_width = max_bar_width
    else:
        bar_width = 0.9 / num_plots

    # A private, fixed-seed generator keeps the colour assignment
    # reproducible without reseeding the global ``random`` module state.
    color_choices = ['red', 'blue', 'green', 'pink', 'yellow', 'black', 'gray']
    rng = random.Random(1)
    if num_plots <= len(color_choices):
        colors = rng.sample(color_choices, num_plots)
    else:
        # More methods than distinct colours: cycle through a shuffled palette
        # instead of letting ``sample`` raise ValueError.
        shuffled = rng.sample(color_choices, len(color_choices))
        colors = [shuffled[k % len(shuffled)] for k in range(num_plots)]

    cluster_starts = np.arange(num_clusters)  # x position of each cluster's first bar

    for k in range(num_plots):
        plt.bar(
            cluster_starts + k * bar_width,
            means[k],
            width=bar_width,
            color=colors[k],
            edgecolor='black',
            yerr=stds[k],
            capsize=7,
            label=labels[k],
        )

    # Centre each tick under its cluster. The midpoint of the bar centres is
    # (num_plots - 1) / 2 bar widths from the first bar; a fixed offset of one
    # bar width is only correct for exactly three bars.
    tick_positions = cluster_starts + bar_width * (num_plots - 1) / 2
    plt.xticks(tick_positions, snapshots)
    plt.ylabel('Reward after 100k fine-tuning')
    plt.legend()
    plt.title(title)
    plt.show()


# Snapshot step counts, method names and task names are all matched against
# the path of every eval CSV found under ./data.
snapshots = ['100000', '500000', '1000000', '2000000']
methods = ['alpha1.0beta3.0', 'alpha1.0beta0.5', 'shannon']
# Alternative task set:
# tasks=['walker_walk','walker_stand','walker_run','walker_flip']
tasks = ['quadruped_stand', 'quadruped_walk', 'quadruped_run', 'quadruped_jump']
# NOTE(review): presumably the number of runs per (task, snapshot, method);
# it is only used below for a sanity-check message -- confirm.
num_trials = 5
result = list(Path("./data").rglob("*eval.csv"))
result_str = ['./' + str(file) for file in result]


def _snapshot_pattern(snapshot):
    """Return a regex matching '_<snapshot>' that is not followed by a digit.

    A plain substring test would let '_100000' also match '_1000000'.
    """
    return re.compile('_' + re.escape(snapshot) + r'(?!\d)')


def _final_rewards(files):
    """Return the last 'episode_reward' value of each non-empty CSV in files."""
    rewards = []
    for file in files:
        data = pd.read_csv(file, usecols=['episode_reward'])
        if data.empty:
            # Nothing was logged in this file; skip it instead of crashing.
            print('warning: no rows in', file)
            continue
        rewards.append(data['episode_reward'].iloc[-1])
    return rewards


for task in tasks:
    print('task', task)
    task_filtered = [path for path in result_str if task in path]
    mean_array = np.zeros([len(methods), len(snapshots)])  # mean_array[method][snapshot]
    std_array = np.zeros([len(methods), len(snapshots)])   # std_array[method][snapshot]
    for j, snapshot in enumerate(snapshots):
        print('snapshot', snapshot)
        pattern = _snapshot_pattern(snapshot)
        snapshot_filtered = [path for path in task_filtered if pattern.search(path)]
        for i, method in enumerate(methods):
            print(method)
            print(i, j)
            method_filtered = [path for path in snapshot_filtered if method in path]
            rewards = _final_rewards(method_filtered)
            if len(rewards) != num_trials:
                print('warning: expected', num_trials, 'runs, found', len(rewards))
            if rewards:
                mean_array[i][j] = np.mean(rewards)
                std_array[i][j] = np.std(rewards)
            else:
                # No data for this combination: record NaN explicitly rather
                # than relying on numpy's empty-slice RuntimeWarning.
                mean_array[i][j] = np.nan
                std_array[i][j] = np.nan
            print(mean_array[i][j], std_array[i][j])
            print('************************************')
    print(len(methods), len(snapshots))
    print(mean_array)
    print(std_array)
    plot_bars(mean_array, std_array, methods, snapshots, task)




# task_walk=list(filter(lambda k: '' in k, result_str))
# task_stand
# task_run
# task_flip
# for file in result_str:
#     data=pd.read_csv(file)
#     print(data)
