from collections import OrderedDict
from os import listdir
from os.path import isdir
from sys import argv

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import FuncFormatter

def read_proportions(dir_name, max_timesteps=int(1e7), step=250000):
    """Summarise the obtained/best proportion ratio across seeds per checkpoint.

    Every sub-directory of ``dir_name`` is treated as one seed. For each seed
    and each checkpoint ``n_timesteps`` in ``range(0, max_timesteps + 1, step)``
    this reads ``{dir_name}/{seed}/{n_timesteps}/curr_proportion.csv`` and
    ``best_proportion.csv``. It records the mean of their element-wise ratio.
    The per-seed values are then reduced to a median and an inter-quartile band.

    Args:
        dir_name: Directory whose sub-directories are the individual seeds.
        max_timesteps: Last checkpoint to read (inclusive).
        step: Spacing between consecutive checkpoints.

    Returns:
        A tuple ``(x, y, first_quarter, third_quarter)`` of 1-D numpy arrays.
        They hold the checkpoint timesteps, the median ratio, and the 25th
        and 75th percentiles of the ratio across seeds.

    Raises:
        ValueError: If ``dir_name`` contains no seed sub-directories.
    """
    print(f"reading results for {dir_name}")
    # Only directories are seeds; skip stray files such as .DS_Store.
    seeds = sorted(s for s in listdir(dir_name) if isdir(f"{dir_name}/{s}"))
    if not seeds:
        # np.percentile on an empty list would otherwise fail with an opaque IndexError.
        raise ValueError(f"no seed directories found in {dir_name}")

    timesteps = list(range(0, max_timesteps + 1, step))
    ratios = {n_timesteps: [] for n_timesteps in timesteps}
    for seed in seeds:
        print(seed)
        for n_timesteps in timesteps:
            run_dir = f"{dir_name}/{seed}/{n_timesteps}"
            obtained_prop = pd.read_csv(f"{run_dir}/curr_proportion.csv", index_col=False).to_numpy()
            best_prop = pd.read_csv(f"{run_dir}/best_proportion.csv", index_col=False).to_numpy()
            ratios[n_timesteps].append((obtained_prop / best_prop).mean())

    # `timesteps` is ascending, so no re-sorting is needed to align x with y.
    x = np.array(timesteps)
    y = np.array([np.median(ratios[t]) for t in timesteps])
    first_quarter = np.array([np.percentile(ratios[t], 25) for t in timesteps])
    third_quarter = np.array([np.percentile(ratios[t], 75) for t in timesteps])
    return x, y, first_quarter, third_quarter

def read_episode_lengths(dir_name, max_timesteps=int(1e7), step=250000):
    """Summarise mean episode length across seeds per checkpoint.

    Every sub-directory of ``dir_name`` is treated as one seed. For each seed
    and each checkpoint ``n_timesteps`` in ``range(0, max_timesteps + 1, step)``
    this reads ``{dir_name}/{seed}/{n_timesteps}/episode_lengths.csv`` and
    records the mean of its values. The per-seed means are then reduced to a
    median and an inter-quartile band.

    Args:
        dir_name: Directory whose sub-directories are the individual seeds.
        max_timesteps: Last checkpoint to read (inclusive).
        step: Spacing between consecutive checkpoints.

    Returns:
        A tuple ``(x, y, first_quarter, third_quarter)`` of 1-D numpy arrays.
        They hold the checkpoint timesteps, the median episode length, and the
        25th and 75th percentiles across seeds.

    Raises:
        ValueError: If ``dir_name`` contains no seed sub-directories.
    """
    print(f"reading results for {dir_name}")
    # Only directories are seeds; skip stray files such as .DS_Store.
    seeds = sorted(s for s in listdir(dir_name) if isdir(f"{dir_name}/{s}"))
    if not seeds:
        # np.percentile on an empty list would otherwise fail with an opaque IndexError.
        raise ValueError(f"no seed directories found in {dir_name}")

    timesteps = list(range(0, max_timesteps + 1, step))
    ep_lengths = {n_timesteps: [] for n_timesteps in timesteps}
    for seed in seeds:
        print(seed)
        for n_timesteps in timesteps:
            ep_length = pd.read_csv(f"{dir_name}/{seed}/{n_timesteps}/episode_lengths.csv", index_col=False).to_numpy()
            ep_lengths[n_timesteps].append(ep_length.mean())

    # `timesteps` is ascending, so no re-sorting is needed to align x with y.
    x = np.array(timesteps)
    y = np.array([np.median(ep_lengths[t]) for t in timesteps])
    first_quarter = np.array([np.percentile(ep_lengths[t], 25) for t in timesteps])
    third_quarter = np.array([np.percentile(ep_lengths[t], 75) for t in timesteps])
    return x, y, first_quarter, third_quarter

def read_vec_returns(dir_name):
    """Placeholder reader for vector returns; not implemented yet.

    Raising makes the gap explicit. Returning ``None`` would only fail later,
    when a caller tries to unpack the ``(x, y, first_quarter, third_quarter)``
    tuple that the other readers in this script return.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(f"read_vec_returns is not implemented yet (dir_name={dir_name!r})")

def read_scalarized_returns(dir_name):
    """Placeholder reader for scalarized returns; not implemented yet.

    Raising makes the gap explicit. Returning ``None`` would only fail later,
    when a caller tries to unpack the ``(x, y, first_quarter, third_quarter)``
    tuple that the other readers in this script return.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(f"read_scalarized_returns is not implemented yet (dir_name={dir_name!r})")

def plot_results(x, ys, first_quarters, third_quarters, legends, xlabel, ylabel, max_y=1, metric="prop", out_prefix=None):
    """Plot median curves with inter-quartile bands, save them to PDF and show them.

    Args:
        x: Shared x-axis values (timesteps) for every curve.
        ys: One y array per curve.
        first_quarters: Lower band edge per curve, aligned with ``ys``.
        third_quarters: Upper band edge per curve, aligned with ``ys``.
        legends: Legend label per curve, aligned with ``ys``.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        max_y: Upper y-axis limit. Major/minor y ticks are spaced max_y/5 and max_y/20.
        metric: Suffix of the output file name.
        out_prefix: Prefix of the output file name. Defaults to ``argv[1]``,
            the environment name given on the command line.

    Raises:
        ValueError: If the per-curve sequences have different lengths.

    Side effects:
        Draws on the current pyplot figure, writes
        ``{out_prefix}-{metric}.pdf`` and calls ``plt.show()``.
    """
    if not (len(ys) == len(first_quarters) == len(third_quarters) == len(legends)):
        raise ValueError("ys, first_quarters, third_quarters and legends must have the same length")
    if out_prefix is None:
        out_prefix = argv[1]

    f_size = 18
    x_max = int(1e7)

    # Axes decoration does not depend on the individual curves, so configure it once.
    plt.grid(which="both")
    plt.grid(which="major", alpha=0.7)
    plt.grid(which="minor", alpha=0.3)
    plt.xticks(np.arange(0, x_max + 1, 250000), minor=True)
    plt.xticks(np.arange(0, x_max + 1, int(2e6), dtype=np.int64), fontsize=f_size)
    # linspace ends exactly at max_y. arange(0, max_y + 1, ...) placed ticks past the axis limit.
    plt.yticks(np.linspace(0, max_y, 21), minor=True)
    plt.yticks(np.linspace(0, max_y, 6), fontsize=f_size)

    for y, lower, upper, label in zip(ys, first_quarters, third_quarters, legends):
        plt.plot(x, y, "-", marker="D", label=label)
        plt.fill_between(x, lower, upper, alpha=0.2)
    if len(ys) > 0:
        # With no labelled artists, legend() would only emit a warning.
        plt.legend(fontsize=16)

    plt.xlim([0, x_max])
    plt.ylim([0, max_y])
    plt.xlabel(xlabel, fontsize=f_size)
    plt.ylabel(ylabel, fontsize=f_size)

    plt.savefig(f"{out_prefix}-{metric}.pdf", format="pdf", bbox_inches="tight")
    plt.show()

if __name__ == "__main__":
    # Usage: <script> <env> <metric>, where metric is "prop" or "ep_length".
    # Results are read from results/<env>/<algorithm>/<seed>/<timesteps>/.
    if len(argv) < 3:
        raise SystemExit(f"usage: {argv[0]} <env> <prop|ep_length>")
    env = argv[1]
    metric = argv[2]

    # metric -> (y-axis upper limit, y-axis label, reader function)
    metric_config = {
        "prop": (1, "minimal proportion", read_proportions),
        "ep_length": (1000, "episode length", read_episode_lengths),
    }
    if metric not in metric_config:
        raise ValueError(f"unknown metric {metric!r}; expected one of {sorted(metric_config)}")
    max_y, y_label, f = metric_config[metric]

    algorithms = ["gru-dec-EUPG-nash", "split-objective(0,1)-gru-Dec-PG", "centralized-partial-gru-EUPG-nash"]
    legends = ["Decentralized-EUPG", "Decomposition Baseline", "Centralized Baseline"]
    x, ys, first_quarters, third_quarters = None, [], [], []
    for algo in algorithms:
        curr_dir = f"results/{env}/{algo}"
        # Every algorithm is read on the same checkpoint grid, so the last x is shared by all curves.
        x, y, first_quarter, third_quarter = f(curr_dir)
        ys.append(y)
        first_quarters.append(first_quarter)
        third_quarters.append(third_quarter)

        print("done")
    x_label = "timesteps"
    plot_results(x, ys, first_quarters, third_quarters, legends, x_label, y_label, max_y, metric)
