import numpy as np
import pandas as pd
import os
from scipy.stats import spearmanr

# Load the tab-separated performance table.  The file has no header row, so
# pandas labels the columns 0, 1, ...; only the first row for each distinct
# value of column 0 is kept.
perf = pd.read_csv("../target_policies/performances.txt", delimiter='\t', header=None)
perf = perf.drop_duplicates(subset=[0])

# Column 1 is split into three buckets at its 20th and 60th percentiles.
_q20 = perf[1].quantile(.2)
_q60 = perf[1].quantile(.6)

# Lowest bucket: column 1 strictly below the 20th percentile, ascending.
low = perf[perf[1] < _q20].sort_values(by=1)
# Ten positions spread evenly from the first to the last row of the bucket.
low_sampled = low.iloc[np.round(np.linspace(0, len(low) - 1, num=10)).astype(np.int64)]

# Middle bucket: 20th percentile (inclusive) up to the 60th (exclusive).
mid = perf[(perf[1] >= _q20) & (perf[1] < _q60)].sort_values(by=1)
mid_sampled = mid.iloc[np.round(np.linspace(0, len(mid) - 1, num=10)).astype(np.int64)]

# Highest bucket: column 1 at or above the 60th percentile.
high = perf[perf[1] >= _q60].sort_values(by=1)
high_sampled = high.iloc[np.round(np.linspace(0, len(high) - 1, num=10)).astype(np.int64)]

# Candidate targets: 5 rows from the low bucket, 3 from the middle bucket and
# 2 from the high bucket.  Positions are spread evenly between 1 and len-2 of
# each (sorted) bucket rather than between 0 and len-1.
target_policies = pd.concat([
    low.iloc[np.round(np.linspace(1, low.shape[0]-2, 5)).astype(np.int64)],
    mid.iloc[np.round(np.linspace(1, mid.shape[0]-2, 3)).astype(np.int64)],
    high.iloc[np.round(np.linspace(1, high.shape[0]-2, 2)).astype(np.int64)],
])

# Thin the 10 candidates down to 8 evenly spaced rows.
# `np.int` was removed in NumPy 1.24 and raises AttributeError there, so the
# explicit np.int64 already used in the concat above is used instead.
# .copy() yields an independent frame, so the column assignment below is never
# treated as a write into a slice of another frame.
target_policies = target_policies.iloc[
    np.round(np.linspace(0, target_policies.shape[0]-1, 8)).astype(np.int64)
].copy()

# Column 5: the path from column 0 with its last two components dropped,
# followed by "/<name>.pkl", where <name> is the final component up to its
# first ".".
target_policies[5] = target_policies[0].map(lambda x : "/".join(x.split("/")[:-2]) + "/{}.pkl".format(x.split("/")[-1].split(".")[0]))

# Keep 5 evenly spaced rows of the 10 previously sampled from the low bucket.
low_sampled = low_sampled.iloc[np.round(np.linspace(0, low_sampled.shape[0]-1, 5)).astype(np.int64)]

# Names of the result files found in ./iw/low (os.listdir already returns a
# fresh list, so no copying comprehension is needed).
iw_results = os.listdir("./iw/low")

# Replace column 0 with the bare policy name: the last "/"-separated component
# of column 5 with ".pkl" removed.  The last component can never contain "/",
# so no further separator handling is required.
target_policies[0] = target_policies[5].map(lambda x : x.split("/")[-1].replace(".pkl", ""))

# Column 5 was only a stepping stone towards the name above.
del target_policies[5]

# Create (or reset) columns 2..10 as all-NaN placeholders.
for i in range(2,11):
    target_policies[i] = np.nan

# Index the rows by policy name and drop the now-redundant name column.
target_policies.index = target_policies[0]
del target_policies[0]

# Store each result file's value, divided by 0.005, in column 8 of the row
# whose label is the file name minus its last five "_"-separated fields.
# NOTE(review): assigning through .loc with a label that is not in the index
# appends a new row (column 1 = NaN), which would turn every metric below into
# NaN -- confirm that each file name maps onto a selected policy.
for f in iw_results:
    current = "_".join(f.split("_")[:-5])
    estimate = np.loadtxt("./iw/low/{}".format(f))
    target_policies.loc[current, 8] = estimate / 0.005

# Rows for which an estimate was loaded.
_has_estimate = ~target_policies[8].isna()

# Rescale column 1 of those rows: value * 100 + 25.
# NOTE(review): the constants 0.005, 100 and 25 presumably put the estimate
# and the reference value on a common scale -- confirm their origin.
target_policies.loc[_has_estimate, 1] = target_policies.loc[_has_estimate, 1] * 100 + 25

truth = target_policies.loc[_has_estimate, 1].values
pred = target_policies.loc[_has_estimate, 8].values

# Printed as "MAE": mean of |(truth - pred) / truth|, i.e. a relative error.
_relative_error = np.abs((truth - pred) / truth)
maes = np.asarray([_relative_error.mean()])

# Spearman rank correlation between reference values and estimates.
_rho = spearmanr(truth, pred)[0]
ranks = np.asarray([_rho])

# Relative gap between the best reference value and the reference value of
# the row with the highest estimate.
_best_truth = truth.max()
_chosen_truth = truth[np.argmax(pred)]
regrets = np.asarray([(_best_truth - _chosen_truth) / _best_truth])

print("MAE", maes.mean())
print("Rank", ranks.mean())
print("Regret@1", regrets.mean())
