from lpcmdp.algorithm.utils import *
import torch
from lpcmdp.algorithm.model import *
import itertools
import OSRL
from tqdm import tqdm
from lpcmdp.env.FrozenLake import FrozenLakeEnv, FrozenLakeEnv_nocost
from lpcmdp.env.datacollector import DataCollector
from lpcmdp.algorithm.importance_sampling_solver import ImportanceSamplingDiscreteSolver, ImportanceSamplingApproximateSolver
from lpcmdp.algorithm.coptidice import CoptidiceTrainer

import os
# Intel OpenMP: tolerate more than one copy of the OpenMP runtime being loaded
# in this process (a common workaround when torch and other MKL/OpenMP-linked
# libraries coexist and would otherwise abort at import/first use).
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# Fixed global seed for reproducibility (an earlier run used seed_all(221)).
seed_all(567895)

# Solve the cost-free 8x8 FrozenLake with value iteration. The resulting
# policy `exp.pi` is used as the expert policy when collecting the offline
# datasets below.
env_nocost = FrozenLakeEnv_nocost(nrow=8, ncol=8)
# NOTE(review): 0.01 is presumably the value-iteration convergence tolerance — confirm.
exp = ValueIteration(env_nocost, 0.01, env_nocost.gamma)
exp.value_iteration()
# plot_policy(np.array(exp.pi), env_nocost, "w")

# Cost-bearing variant of the same 8x8 map (vs. FrozenLakeEnv_nocost above);
# every offline experiment below trains and evaluates on this environment.
env = FrozenLakeEnv(nrow=8, ncol=8)

# Experiment: offline dataset with 0% optimal (expert) data
# (labelled "0% optimal data" in the plot).
offline_dataset_collector0 = DataCollector(env, expert_pi=exp.pi, percent=0.0)

# Disabled alternative: exact tabular importance-sampling solver.
# print(offline_dataset_collector0.get_behave_policy_prob())
# discretepolicy = ImportanceSamplingDiscreteSolver(env, offline_dataset_collector0, 0, 0.1, 0.0001)
# env.plot_policy(discretepolicy)
# test_goalrate = test(env, discretepolicy, 1000)
# print(test_goalrate)

# Proposed approximate importance-sampling solver ("Alg" in the plot).
approximatesolver0 = ImportanceSamplingApproximateSolver(
    env,
    offline_dataset_collector0,
    0,
    0.1,
    0.0001,
    epoches=200000,
)
approximatepolicy = approximatesolver0.train()
for metric in ('test_reward', 'test_cost'):
    print(approximatesolver0.get_logger()[metric])

# COptiDICE baseline trained on the same dataset.
coptidice_trainer0 = CoptidiceTrainer(env, offline_dataset_collector0, epochs=200000)
coptidice_trainer0.train()
for metric in ('test_reward', 'test_cost'):
    print(coptidice_trainer0.get_logger()[metric])

# Experiment: offline dataset with 50% optimal (expert) data.
offline_dataset_collector50 = DataCollector(env, expert_pi=exp.pi, percent=0.5)

# Proposed approximate importance-sampling solver ("Alg" in the plot).
# NOTE(review): only this run passes behavior_policy_style='real'; the 0%, 25%
# and 75% runs use the solver's default — confirm the asymmetry is intended.
approximatesolver50 = ImportanceSamplingApproximateSolver(
    env,
    offline_dataset_collector50,
    0,
    0.1,
    0.0001,
    epoches=200000,
    behavior_policy_style='real',
)
approximatepolicy = approximatesolver50.train()
for metric in ('test_reward', 'test_cost'):
    print(approximatesolver50.get_logger()[metric])

# COptiDICE baseline trained on the same dataset.
coptidice_trainer50 = CoptidiceTrainer(env, offline_dataset_collector50, epochs=200000)
coptidice_trainer50.train()
for metric in ('test_reward', 'test_cost'):
    print(coptidice_trainer50.get_logger()[metric])

# Experiment: offline dataset with 75% optimal (expert) data.
offline_dataset_collector75 = DataCollector(env, expert_pi=exp.pi, percent=0.75)

# Proposed approximate importance-sampling solver ("Alg" in the plot).
# Must be trained on the 75% dataset so it is compared against COptiDICE on
# the same data (previously it was mistakenly given the 50% collector).
approximatesolver75 = ImportanceSamplingApproximateSolver(env, offline_dataset_collector75, 0, 0.1, 0.0001, epoches=200000)
approximatepolicy = approximatesolver75.train()
print(approximatesolver75.get_logger()['test_reward'])
print(approximatesolver75.get_logger()['test_cost'])

# COptiDICE baseline trained on the same 75% dataset.
coptidice_trainer75 = CoptidiceTrainer(env, offline_dataset_collector75, epochs=200000)
coptidice_trainer75.train()
print(coptidice_trainer75.get_logger()['test_reward'])
print(coptidice_trainer75.get_logger()['test_cost'])

# Experiment: offline dataset with 25% optimal (expert) data.
offline_dataset_collector25 = DataCollector(env, expert_pi=exp.pi, percent=0.25)

# Proposed approximate importance-sampling solver ("Alg" in the plot).
approximatesolver25 = ImportanceSamplingApproximateSolver(
    env,
    offline_dataset_collector25,
    0,
    0.1,
    0.0001,
    epoches=200000,
)
approximatepolicy = approximatesolver25.train()
for metric in ('test_reward', 'test_cost'):
    print(approximatesolver25.get_logger()[metric])

# COptiDICE baseline trained on the same dataset.
coptidice_trainer25 = CoptidiceTrainer(env, offline_dataset_collector25, epochs=200000)
coptidice_trainer25.train()
for metric in ('test_reward', 'test_cost'):
    print(coptidice_trainer25.get_logger()[metric])

# Gather the logged evaluation curves of every run.
# "ours" = ImportanceSamplingApproximateSolver, "coptidice" = COptiDICE baseline;
# the numeric suffix is the percentage of optimal data in the offline dataset.
ours_reward0 = approximatesolver0.get_logger()['test_reward']
coptidice_reward0 = coptidice_trainer0.get_logger()['test_reward']
ours_reward25 = approximatesolver25.get_logger()['test_reward']
coptidice_reward25 = coptidice_trainer25.get_logger()['test_reward']
ours_reward50 = approximatesolver50.get_logger()['test_reward']
coptidice_reward50 = coptidice_trainer50.get_logger()['test_reward']
ours_reward75 = approximatesolver75.get_logger()['test_reward']
coptidice_reward75 = coptidice_trainer75.get_logger()['test_reward']

ours_cost0 = approximatesolver0.get_logger()['test_cost']
coptidice_cost0 = coptidice_trainer0.get_logger()['test_cost']
ours_cost25 = approximatesolver25.get_logger()['test_cost']
coptidice_cost25 = coptidice_trainer25.get_logger()['test_cost']
ours_cost50 = approximatesolver50.get_logger()['test_cost']
coptidice_cost50 = coptidice_trainer50.get_logger()['test_cost']
ours_cost75 = approximatesolver75.get_logger()['test_cost']
coptidice_cost75 = coptidice_trainer75.get_logger()['test_cost']

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
# from matplotlib.backends.backend_pdf import PdfPages

# matplotlib.use('PDF')
# matplotlib.rcParams['ps.useafm'] = True
# matplotlib.rcParams['pdf.use14corefonts'] = True
# matplotlib.rcParams['text.usetex'] = True

# Plot the test-reward curves collected above (the previous version referenced
# undefined *_goal_rates* names). Colour encodes the algorithm (red = ours,
# blue = COptiDICE); line style encodes the share of optimal data.
plt.plot(ours_reward75, label='Alg with 75% optimal data', linestyle='solid', color='red')
plt.plot(ours_reward50, label='Alg with 50% optimal data', linestyle='dashdot', color='red')
plt.plot(ours_reward25, label='Alg with 25% optimal data', linestyle='dashed', color='red')
plt.plot(ours_reward0, label='Alg with 0% optimal data', linestyle='dotted', color='red')

plt.plot(coptidice_reward75, label='COptiDICE with 75% optimal data', linestyle='solid', color='blue')
plt.plot(coptidice_reward50, label='COptiDICE with 50% optimal data', linestyle='dashdot', color='blue')
plt.plot(coptidice_reward25, label='COptiDICE with 25% optimal data', linestyle='dashed', color='blue')
plt.plot(coptidice_reward0, label='COptiDICE with 0% optimal data', linestyle='dotted', color='blue')


# plt.legend(fontsize=8)
# NOTE(review): the x values are indices into the logged evaluation list; how
# many train steps separate two entries depends on the trainers — confirm.
plt.xlabel('Train step')
# plt.text(1.0, -0.037, r'$\times 10^3$', verticalalignment='center', horizontalalignment='left', transform=plt.gca().transAxes)
# The plotted quantity is the logged 'test_reward'; presumably it tracks the
# goal-reaching rate on FrozenLake, but the label states what is actually shown.
plt.ylabel('Test Reward')
plt.savefig('Goal-Rate-with-ite1.png')
plt.show()
