from lpcmdp.algorithm.utils import *
import torch
from lpcmdp.algorithm.model import *
import itertools
import OSRL
from tqdm import tqdm
from lpcmdp.env.FrozenLake import FrozenLakeEnv, FrozenLakeEnv_nocost
from lpcmdp.env.datacollector import DataCollector
from lpcmdp.algorithm.importance_sampling_solver import ImportanceSamplingDiscreteSolver, ImportanceSamplingApproximateSolver
from lpcmdp.algorithm.coptidice import CoptidiceTrainer
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.backends.backend_pdf import PdfPages


# Render figures through the non-interactive PDF backend.
matplotlib.use('PDF')
# Font and LaTeX text settings applied to every figure produced by this script.
matplotlib.rcParams.update({
    'ps.useafm': True,
    'pdf.use14corefonts': True,
    'text.usetex': True,
})
import os
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort raised when several libraries bundle their own copy - confirm it is
# still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

def _final_test_metrics(trainer):
    """Print a trainer's logged test curves and return their last entries.

    Args:
        trainer: Object exposing ``get_logger()`` whose result is indexable by
            ``'test_reward'`` and ``'test_cost'``.

    Returns:
        Tuple ``(final_reward, final_cost)``: the last element of each curve.
    """
    print(trainer.get_logger()['test_reward'])
    print(trainer.get_logger()['test_cost'])
    final_reward = trainer.get_logger()['test_reward'][-1]
    final_cost = trainer.get_logger()['test_cost'][-1]
    return final_reward, final_cost


def get_goalrate(seed, datasize):
    """Run one seeded comparison of our solver against COptiDICE.

    Despite its name, the function returns final test reward and test cost,
    not a goal rate.

    Args:
        seed: Value forwarded to ``seed_all`` before anything else is built.
        datasize: Number of trajectories requested from the ``DataCollector``.

    Returns:
        Tuple ``(ours_reward, ours_cost, coptidice_reward, coptidice_cost)``
        holding the last logged test reward/cost of each method.
    """
    seed_all(seed)

    # Expert policy: value iteration on the cost-free 8x8 environment.
    reward_only_env = FrozenLakeEnv_nocost(ncol=8, nrow=8)
    expert = ValueIteration(reward_only_env, 0.01, reward_only_env.gamma)
    expert.value_iteration()

    # Both methods below are trained on the same offline dataset.
    env = FrozenLakeEnv(ncol=8, nrow=8)
    collector = DataCollector(
        env,
        expert_pi=expert.pi,
        percent=0.5,
        num_trajectories=datasize,
    )

    solver = ImportanceSamplingApproximateSolver(
        env, collector, 0, 0.1, 0.0001, behavior_policy_style='real'
    )
    solver.train()
    ours_reward, ours_cost = _final_test_metrics(solver)

    baseline = CoptidiceTrainer(
        env, collector, behavior_policy_style='behavior_clone', epochs=100000
    )
    baseline.train()
    coptidice_reward, coptidice_cost = _final_test_metrics(baseline)

    return ours_reward, ours_cost, coptidice_reward, coptidice_cost

# Experiment grid: every dataset size is evaluated once per random seed.
seed = [567895, 5665251, 0, 1, 2, 3, 6]
datasizes = [20, 100, 200, 500, 1000, 2000, 20000]

# Result tables: one row per dataset size, one entry per seed within a row.
ours_reward_list = []
ours_cost_list = []
coptidice_reward_list = []
coptidice_cost_list = []
print(1)

# Listed in the same order as the tuple returned by get_goalrate, so each
# returned value lands in its own table.
_result_tables = (
    ours_reward_list,
    ours_cost_list,
    coptidice_reward_list,
    coptidice_cost_list,
)

with tqdm(total=len(datasizes) * len(seed)) as pbar:
    for datasize in datasizes:
        # Open a fresh row for this dataset size in every table.
        for table in _result_tables:
            table.append([])
        for s in seed:
            for table, value in zip(_result_tables, get_goalrate(s, datasize)):
                table[-1].append(value)
            pbar.update(1)
            
# Cost statistics per dataset size; axis 1 runs over the seeds of one row.
_ours_cost_runs = np.asarray(ours_cost_list)
_coptidice_cost_runs = np.asarray(coptidice_cost_list)

avg_ours_cost = _ours_cost_runs.mean(axis=1).squeeze()
std_ours_cost = _ours_cost_runs.std(axis=1).squeeze()
avg_coptidice_cost = _coptidice_cost_runs.mean(axis=1).squeeze()
std_coptidice_cost = _coptidice_cost_runs.std(axis=1).squeeze()

# Shaded band: mean +/- one standard deviation.
ours_cost_lower_bound = avg_ours_cost - std_ours_cost
ours_cost_upper_bound = avg_ours_cost + std_ours_cost
coptidice_cost_lower_bound = avg_coptidice_cost - std_coptidice_cost
coptidice_cost_upper_bound = avg_coptidice_cost + std_coptidice_cost

plt.figure()
for _mean_curve, _low_curve, _high_curve, _curve_label, _curve_color in (
    (avg_ours_cost, ours_cost_lower_bound, ours_cost_upper_bound, 'Ours', 'red'),
    (avg_coptidice_cost, coptidice_cost_lower_bound, coptidice_cost_upper_bound, 'Coptidice', 'blue'),
):
    plt.plot(datasizes, _mean_curve, label=_curve_label, color=_curve_color)
    plt.fill_between(datasizes, _low_curve, _high_curve, color=_curve_color, alpha=0.2)
plt.xscale('log')
plt.xlabel('Datasize')
plt.ylabel('Cost')
plt.legend()
plt.savefig('cost.pdf')


# Reward and cost versus number of trajectories, saved as PNG.
#
# The previous version of this section read arrays (ours_goalrate,
# ours_goalrate_bc, ours_holerate, coptidice_goalrate, ...) that nothing in
# this script defines, so it stopped with a NameError only after every
# experiment had already been run. The figures are now drawn from the result
# tables that the experiment loop actually fills. No behaviour-cloned variant
# of our solver is run by this script, so there is no "Alg-BC" curve to draw.


def _plot_mean_with_half_std_band(runs, label, color):
    """Plot the per-datasize mean of ``runs`` with a +/- half-std band.

    Args:
        runs: Nested sequence with one row per entry of ``datasizes`` and one
            value per seed inside each row.
        label: Legend label of the curve.
        color: Matplotlib colour used for both the line and the band.
    """
    # Flatten everything belonging to one dataset size into a single row so
    # the statistics are 1-D even if individual results carry extra
    # singleton dimensions (fill_between requires 1-D inputs).
    runs = np.asarray(runs).reshape(len(datasizes), -1)
    mean = np.mean(runs, axis=1)
    half_std = np.std(runs, axis=1) / 2.0
    plt.plot(datasizes, mean, label=label, color=color)
    plt.fill_between(datasizes, mean - half_std, mean + half_std, color=color, alpha=0.3)


plt.figure()
_plot_mean_with_half_std_band(ours_reward_list, 'Alg', 'b')
_plot_mean_with_half_std_band(coptidice_reward_list, 'COptiDICE', 'r')
plt.xscale('log')
plt.xlabel('Trajectory')
plt.ylabel('Reward')
plt.legend()
plt.savefig('datasize_reward.png')


plt.figure()
_plot_mean_with_half_std_band(ours_cost_list, 'Alg', 'c')
_plot_mean_with_half_std_band(coptidice_cost_list, 'COptiDICE', 'y')
plt.xscale('log')
plt.xlabel('Trajectory')
plt.ylabel('Cost')
plt.legend()
plt.savefig('datasize_cost.png')

