from envs.factored_mdp import FactoredMDP
from envs.taxi.taxi import TaxiEnv
from algs.taxi.causal_psrl import causal_psrl, HierarchicalPrior, generate_causal_prior

import numpy as np
import argparse
import os



# Root folder for everything this script writes; sub-folders are appended later.
dir_name = './data_taxi/'

# Command-line interface.
# --n_seed is an *index* into the stored seeds file that is loaded further
# down, not a seed value itself. It is mandatory: without it the attribute
# would be None, and indexing a NumPy array with None inserts a new axis
# instead of selecting one entry.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--n_seed',
    type=int,
    required=True,
    help='index of the entry to use from the stored seeds file',
)
args = parser.parse_args()


# Problem description; these are handed to FactoredMDP further down as
# (d_S, d_A, Z, R, n, nactions, T).
# NOTE(review): n / nactions look like per-factor sizes for the state and
# action variables — confirm against FactoredMDP.
d_S, d_A = 4, 1
Z = 5
T = 15
n = np.array([5, 5, 2, 1])
nactions = np.array([6])
nn = np.concatenate([n, nactions])

# Experiment settings.
n_iters = 400
n_runs = 20

# Select this job's entry from the stored seeds using the --n_seed index.
seed = np.load("/causal-psrl/scripts/taxi/seeds.npy")[args.n_seed]

# Instantiate the taxi environment and read out its dynamics, rewards,
# factored transition model, scopes and sizes.
mdp = TaxiEnv()
P, R = mdp.get_transitions_rewards()
factors, scopes = mdp.factored_P, mdp.scopes
nS, nA = mdp.num_states, mdp.num_actions

# Build the factored MDP from the constants above and attach the
# environment's scopes and factored transitions to it.
fmdp = FactoredMDP(d_S, d_A, Z, R, n, nactions, T)
fmdp.setscopes(scopes)
fmdp.setP(factors)

# Vector over fmdp.nS entries with all weight on index 0
# (presumably the initial-state distribution — confirm against setmu).
mu = np.zeros(fmdp.nS)
mu[0] = 1
fmdp.setmu(mu=mu)

# Output folder: first level encodes the problem configuration, second level
# the number of iterations and the seed(s) used by this job.
dir_name = dir_name + 'dS_' + str(d_S) + '_dA_' + str(d_A) + '_Z_' + str(Z) + '_n_' + str(n) + '_T_' + str(T) + '/'
dir_name = dir_name + 'iters_' + str(n_iters) + '_seed_' + str(seed) + '/'

# `seed` may be a single scalar entry of the seeds file (which has no len())
# or an array of seeds. Flatten it so both cases are handled uniformly.
run_seeds = np.ravel(seed)
n_seeds = len(run_seeds)

# run
# One row per seed, one column per iteration.
regrets = np.zeros(shape=(n_seeds, n_iters))
errors = np.zeros(shape=(n_seeds, n_iters))
values = np.zeros(shape=(n_seeds, n_iters))

# Rows are indexed by run position, never by the seed value itself: seed
# values are arbitrary integers and are not valid row indices.
# NOTE(review): the same fmdp object is reused across runs — confirm that
# causal_psrl does not leave state in it that would couple the runs.
for run_idx, run_seed in enumerate(run_seeds):
    # Seed before building the prior so the whole run is reproducible.
    np.random.seed(int(run_seed))
    causal_prior = generate_causal_prior(fmdp.Z, fmdp.scopes, 2)
    prior = HierarchicalPrior(fmdp.d_S, fmdp.d_A, fmdp.Z, np.concatenate((fmdp.n, fmdp.nactions)), causal_prior)
    _, r, error, value = causal_psrl(fmdp, R, prior, n_iters)
    regrets[run_idx] = r
    errors[run_idx] = error
    values[run_idx] = value


################## SAVE RESULTS ##################


def _summary_table(runs):
    """Return a (mean, lower, upper) table computed across the rows of *runs*.

    *runs* is a 2-D array with one row per run. For every column the table
    holds the mean over rows and the mean minus/plus two standard errors,
    where the standard error is ``np.std`` over rows divided by the square
    root of the number of rows.
    """
    center = np.mean(runs, axis=0)
    half_width = 2 * np.std(runs, axis=0) / np.sqrt(runs.shape[0])
    return np.column_stack((center, center - half_width, center + half_width))


# (file stem, per-run data) for every metric that gets written out.
_metrics = (('regret', regrets), ('error', errors), ('value', values))

# Create the output folder; exist_ok avoids the race between checking for the
# folder and creating it when several jobs start at the same time.
os.makedirs(dir_name, exist_ok=True)

# save raw results
for _stem, _runs in _metrics:
    np.savetxt(fname=dir_name + _stem + '_raw.csv', X=_runs, delimiter=',')

# save aggregated results
for _stem, _runs in _metrics:
    np.savetxt(fname=dir_name + _stem + '.csv', X=_summary_table(_runs), delimiter=',', header='mean,lower,upper')

