import numpy as np
import solvers, mdps
import os 

import warnings
warnings.filterwarnings('ignore')

# Learner factories keyed by the label written to the results CSV.
# Each factory receives the trial's MDP and returns a fresh learner.
# The disabled entries below can be re-enabled for reference baselines;
# the "Optimal Policy" one passes the true MDP to an LP solver (i.e. it cheats).
ALGORITHMS = dict([
    # ("Random Policy", lambda env: solvers.UniformRandom(env.info)),
    # ("Optimal Policy", lambda env: solvers.LinearProgramming(env.info, mdp=env)),
    ("CAP (kappa=0.00)", lambda env: solvers.ReplayLP(env.info, mdp=env)),
    ("CAP (kappa=0.01)", lambda env: solvers.ReplayLP(env.info, mdp=env, kappa=0.01)),
    ("CAP (kappa=0.05)", lambda env: solvers.ReplayLP(env.info, mdp=env, kappa=0.05)),
    ("CAP (kappa=0.1)", lambda env: solvers.ReplayLP(env.info, mdp=env, kappa=0.1)),
    # Adaptive variant: starts from kappa=0.0.
    ("CAP", lambda env: solvers.ReplayLPAdaptive(env.info, mdp=env, kappa=0.0)),
])

# Experiment configuration.
DIMS = [8, 8]         # passed to ConstrainedGridworld (presumably the grid shape)
C_LIMIT = 0.1         # cost limit given to the MDP and used as the violation threshold
TOTAL_DATA = 500      # samples drawn per iteration
TRIALS = 100          # independent MDP instances
ITERATIONS = 30       # collect / ingest / re-plan rounds per algorithm per trial

# Rows of the output CSV; the first row is the header.
results = ["algo trial iter return cost violation feasible kappa".split()]

# For each trial: build one ConstrainedGridworld, instantiate every learner on it,
# then run each learner through ITERATIONS rounds of
# sample -> ingest -> plan -> evaluate, appending one CSV row per round.
for trial_idx in range(TRIALS):
    print("Trial %d/%d" % (trial_idx, TRIALS), flush=True)
    env = mdps.ConstrainedGridworld(DIMS, C_limit=C_LIMIT)

    # All learners are constructed up front, before any data is sampled.
    learners = {label: make(env) for label, make in ALGORITHMS.items()}

    for label, learner in learners.items():
        # The first batch is gathered with a uniform-random policy.
        behaviour_policy = solvers.UniformRandom(env.info).get_policy()
        # Running totals over the iterations of this (trial, algorithm) pair.
        n_violations = 0
        n_feasible = 0

        for step in range(ITERATIONS):
            # Each round requests the full TOTAL_DATA budget.
            batch = env.sample_data(policy=behaviour_policy,
                                    count=TOTAL_DATA,
                                    state_dist=env.s0)
            learner.ingest_data(batch)

            policy, rho_star, feasible = learner.get_policy()
            policy_return, policy_cost = learner.evaluate_policy(policy, rho_star)

            if feasible:
                n_feasible += 1
            # Small tolerance so numerical noise at the limit is not counted.
            if policy_cost > C_LIMIT + 1e-3:
                n_violations += 1

            results.append([
                label, trial_idx, step, policy_return, policy_cost,
                n_violations, n_feasible, learner.kappa,
            ])
            # The next batch is gathered with the policy just computed.
            behaviour_policy = policy

            # If the "Optimal Policy" baseline is re-enabled in ALGORITHMS,
            # a single iteration is enough for it (it is ground truth).

# Write the collected rows (header first) as a comma-separated file under
# ``results/``. The directory is created if needed; ``exist_ok=True`` avoids
# the separate exists-then-create check.
# Fixes: the original passed the ``results`` list (not ``results_dir``) both to
# os.makedirs and into the file path.
results_dir = 'results'
os.makedirs(results_dir, exist_ok=True)
results_path = os.path.join(
    results_dir, f"cmdp_TRIALS{TRIALS}_ITERATIONS{ITERATIONS}.csv")
with open(results_path, "w") as f:
    # One line per row; values are formatted with str().
    f.write("\n".join(",".join(str(x) for x in line) for line in results))