from lpcmdp.algorithm.utils import *
import torch
from lpcmdp.algorithm.model import *
from torch.utils.data import DataLoader
import itertools
from tqdm import tqdm
from lpcmdp.env.Taxi import Taxi_nocost
from lpcmdp.env.datacollector import Taxi_DataCollector
# Build the cost-free Taxi environment and run value iteration on it; the
# resulting policy (exp.pi) is handed to the data collector as `exper_pi`.
env_nocost = Taxi_nocost()
exp = ValueIteration(
    env_nocost,
    0.01,  # NOTE(review): presumably the convergence threshold -- confirm
    env_nocost.gamma,
)
exp.value_iteration()
# Debug aid: plot_policy(np.array(exp.pi), env_nocost, "w")

# Collector for the offline dataset, configured with 1000 episodes.
# NOTE(review): `percent=0.5` presumably controls how much of the data follows
# the value-iteration policy -- verify against Taxi_DataCollector.
offline_dataset_collector = Taxi_DataCollector(
    env_nocost,
    exper_pi=exp.pi,
    percent=0.5,
    num_episodes=1000,
)
# Debug aid: print(offline_dataset_collector.get_original_dataset())
from lpcmdp.algorithm.coptidice import CoptidiceTrainer
from lpcmdp.algorithm.importance_sampling_solver import ImportanceSamplingApproximateSolver
# coptidice_trainer = CoptidiceTrainer(env_nocost, offline_dataset_collector)
# coptidice_policy = coptidice_trainer.train()
# env_nocost.plot_policy(coptidice_policy)
# print(coptidice_trainer.get_logger())
# Train the importance-sampling approximate solver on the offline dataset,
# then print its logged test reward / cost entries and the last of each.
print("Importance Sampling Approximate Solver Training...")
approximatesolver = ImportanceSamplingApproximateSolver(
    env_nocost,
    offline_dataset_collector,
    # NOTE(review): the meaning of the three positional numbers below is not
    # visible from this script -- confirm against the solver's signature.
    0,
    0.1,
    1e-4,
    behavior_policy_style='real',
)
approximatepolicy = approximatesolver.train()

# Everything logged under each key first ...
print(approximatesolver.get_logger()['test_reward'])
print(approximatesolver.get_logger()['test_cost'])

# ... then only the final entry of each.
ours_reward, ours_cost = (
    approximatesolver.get_logger()['test_reward'][-1],
    approximatesolver.get_logger()['test_cost'][-1],
)
print("reward: ", ours_reward, "cost: ", ours_cost)

# Train the COptiDICE trainer on the same environment and offline dataset,
# then print its logged test reward / cost entries and the last of each.
print("Coptidice Training...")
coptidice_trainer = CoptidiceTrainer(
    env_nocost,
    offline_dataset_collector,
)
coptidice_policy = coptidice_trainer.train()

# Everything logged under each key first ...
print(coptidice_trainer.get_logger()['test_reward'])
print(coptidice_trainer.get_logger()['test_cost'])

# ... then only the final entry of each.
coptidice_reward, coptidice_cost = (
    coptidice_trainer.get_logger()['test_reward'][-1],
    coptidice_trainer.get_logger()['test_cost'][-1],
)
print("reward: ", coptidice_reward, "cost: ", coptidice_cost)
