from envs.factored_mdp import FactoredMDP, factored_to_tabular
from envs.tabular_mdp import TabularMDP
from envs.random_generators import random_factored, random_factored_sparse
from algs.psrl import psrl, Prior
from algs.factored_psrl import factored_psrl, FactoredPrior
from algs.causal_psrl import causal_psrl, HierarchicalPrior, generate_causal_prior

import numpy as np
import argparse
import os

def _str2bool(text):
    """Convert a command-line string into a bool.

    argparse's ``type=bool`` simply calls ``bool(text)``, which is True for
    every non-empty string -- including ``"False"`` and ``"0"`` -- so a flag
    declared that way can never be switched off with a readable value.

    Accepted spellings are case-insensitive. The empty string maps to False,
    which is what ``bool("")`` produced before, so invocations relying on
    ``--eps_greedy ""`` keep working.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(text, bool):
        return text
    lowered = text.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(text))


# Command-line interface of the experiment script.
parser = argparse.ArgumentParser()
parser.add_argument('--env', type=str)                      # {random, random_sparse}
parser.add_argument('--alg', type=str)                      # {psrl, factored_psrl, causal_psrl}
parser.add_argument('--n_iters', type=int)                  # alg iterations
parser.add_argument('--n_seeds', type=int)                  # experiment seeds
parser.add_argument('--mdp_seed', type=int, default=None)   # seed of the random environment (for random)
parser.add_argument('--T', type=int)                        # env horizon
parser.add_argument('--d_S', type=int, default=None)        # state features
parser.add_argument('--d_A', type=int, default=None)        # action features
parser.add_argument('--Z', type=int, default=None)          # sparsity
parser.add_argument('--n', type=int, default=None)          # feature support
parser.add_argument('--eta', type=int, default=0)           # edges in the causal prior (for causal psrl)
parser.add_argument('--eps_greedy', type=_str2bool, default=True)  # epsilon greedy exploration (true/false)
parser.add_argument('--eps', type=float, default=1e-2)      # epsilon parameter
args = parser.parse_args()

# Shorthand for the experiment parameters used throughout the script.
d_S, d_A, Z, n, T = args.d_S, args.d_A, args.Z, args.n, args.T
n_iters, n_seeds = args.n_iters, args.n_seeds
seeds = range(n_seeds)

# Output directory layout:
#   ./data/<environment>/<problem size>/<algorithm>/<run size>/
# The environment and algorithm components are omitted when the
# corresponding option is not one of the known values.
env_tags = {
    'random': 'random_mdp',
    'random_sparse': 'random_mdp_sparse',
}
alg_tags = {
    'psrl': 'psrl/',
    'factored_psrl': 'factored_psrl/',
    'causal_psrl': 'causal_psrl_eta_{}/'.format(args.eta),
}

dir_name = './data/'
if args.env in env_tags:
    # A fixed environment seed, when given, is appended to the folder name.
    seed_suffix = '' if args.mdp_seed is None else '_{}'.format(args.mdp_seed)
    dir_name += env_tags[args.env] + seed_suffix + '/'
dir_name += 'dS_{}_dA_{}_Z_{}_n_{}_T_{}/'.format(d_S, d_A, Z, n, T)
dir_name += alg_tags.get(args.alg, '')
dir_name += 'iters_{}_seeds_{}/'.format(n_iters, n_seeds)

# run
# Result buffers: one row per experiment seed, one column per iteration.
regrets = np.zeros(shape=(n_seeds, n_iters))
errors = np.zeros(shape=(n_seeds, n_iters))
values = np.zeros(shape=(n_seeds, n_iters))

# Fail early, with a readable message, on options the script cannot handle
# (previously this surfaced later as a NameError on an unassigned variable).
env_generators = {
    'random': random_factored,
    'random_sparse': random_factored_sparse,
}
if args.env not in env_generators:
    raise ValueError('unknown --env {!r}; expected one of {}'.format(
        args.env, sorted(env_generators)))
known_algs = ('psrl', 'factored_psrl', 'causal_psrl')
if args.alg not in known_algs:
    raise ValueError('unknown --alg {!r}; expected one of {}'.format(
        args.alg, list(known_algs)))
generate_env = env_generators[args.env]

if args.mdp_seed is not None:
    # A single environment shared by every experiment seed. The RNG is
    # seeded for both environment types: the sparse generator used to be
    # called without seeding, so --mdp_seed did not make it reproducible.
    # NOTE(review): presumes the generators draw from NumPy's global RNG.
    np.random.seed(args.mdp_seed)
    fmdp = generate_env(d_S, d_A, Z, n, T)

for seed in seeds:
    np.random.seed(seed)
    if args.mdp_seed is None:
        # No fixed environment: sample a fresh one for each experiment seed.
        fmdp = generate_env(d_S, d_A, Z, n, T)
    if args.alg == 'psrl':
        # Plain PSRL runs on the tabular version of the factored MDP.
        mdp = factored_to_tabular(fmdp)
        prior = Prior(mdp.nS, mdp.nA)
        _, regret, error, value = psrl(mdp, prior, n_iters, eps_greedy=args.eps_greedy, eps=args.eps)
    elif args.alg == 'factored_psrl':
        prior = FactoredPrior(fmdp.scopes, fmdp.n)
        _, regret, error, value = factored_psrl(fmdp, prior, n_iters, eps_greedy=args.eps_greedy, eps=args.eps)
    else:  # 'causal_psrl' (validated above)
        causal_prior = generate_causal_prior(fmdp.Z, fmdp.scopes, args.eta)
        prior = HierarchicalPrior(fmdp.d_S, fmdp.d_A, fmdp.Z, fmdp.n, causal_prior)
        _, regret, error, value = causal_psrl(fmdp, prior, n_iters, eps_greedy=args.eps_greedy, eps=args.eps)
    regrets[seed] = regret
    errors[seed] = error
    values[seed] = value


################## SAVE RESULTS ##################

def _save_summary(runs, fname):
    """Write the per-iteration mean and confidence band of *runs* to *fname*.

    *runs* is a 2-D array with one row per run; statistics are taken over
    the rows (axis 0). The band is ``mean +/- 2 * std / sqrt(n_rows)``, with
    ``np.std`` at its default settings. The CSV has the three columns
    ``mean,lower,upper``.
    """
    mean = np.mean(runs, axis=0)
    half_width = 2 * np.std(runs, axis=0) / np.sqrt(runs.shape[0])
    table = np.column_stack((mean, mean - half_width, mean + half_width))
    np.savetxt(fname=fname, X=table, delimiter=',', header='mean,lower,upper')


# make sure the output folder exists; exist_ok avoids the check-then-create
# race when several experiments start at the same time
os.makedirs(dir_name, exist_ok=True)

# save raw and aggregated results for every tracked metric
for metric_name, metric_runs in (('regret', regrets),
                                 ('error', errors),
                                 ('value', values)):
    np.savetxt(fname=dir_name + metric_name + '_raw.csv', X=metric_runs, delimiter=',')
    _save_summary(metric_runs, dir_name + metric_name + '.csv')

