from lpcmdp.env.FrozenLake import *
import torch
from tqdm import tqdm
from lpcmdp.algorithm.model import *
from lpcmdp.algorithm.utils import *

# 4x4 grid built from the FrozenLakeEnv_nocost variant.
env = FrozenLakeEnv_nocost(ncol=4, nrow=4)

# Dimensions handed to the DDQN constructor. action_dim is hard-coded;
# presumably the four grid moves -- confirm against the environment.
state_dim = env.state_size
action_dim = 4

# NOTE(review): neither of these is referenced elsewhere in this script.
expert_percent = 0.5
device = 'cpu'

# Agent hyperparameters (forwarded to DDQN).
lr = 1e-4
gamma = 0.9
epsilon = 0.5
target_update = 50
hidden_dim = 64

# Training-loop / replay settings (forwarded to ReplayBuffer and train_DDQN).
num_episodes = 2000
buffer_size = 5000
minimal_size = 1000
batch_size = 32

def train_DDQN(agent, env, num_episodes, replay_buffer, minimal_size, batch_size):
    """Train a DDQN agent on ``env`` using experience replay.

    Training is split into 10 iterations, each with its own tqdm progress
    bar. Every environment step is pushed into ``replay_buffer``. Once the
    buffer holds more than ``minimal_size`` transitions, a minibatch of
    ``batch_size`` is sampled and passed to ``agent.update`` after each step.

    Args:
        agent: object exposing ``take_action(state)``, ``max_q_value(state)``
            and ``update(transition_dict)``.
        env: environment exposing ``reset()`` and a ``step`` table indexed as
            ``env.step[observation][action]`` that yields
            ``(next_observation, reward, cost, goal, hole)``.
        num_episodes: total number of episodes to run. They are split as
            evenly as possible over the 10 iterations.
        replay_buffer: buffer exposing ``add``, ``size`` and ``sample``.
        minimal_size: buffer size that must be exceeded before updates start.
        batch_size: number of transitions sampled per update.

    Returns:
        list: the reward received on the goal-reaching step of each episode
        that reached the goal (episodes ending in a hole add nothing).
    """
    num_iterations = 10
    return_list = []
    max_q_value = 0
    # Spread any remainder over the first iterations so that exactly
    # ``num_episodes`` episodes are run in total.
    base_episodes, extra_episodes = divmod(num_episodes, num_iterations)
    episodes_done = 0
    for i in range(num_iterations):
        iteration_episodes = base_episodes + (1 if i < extra_episodes else 0)
        with tqdm(total=iteration_episodes, desc='Iteration %d' % i) as pbar:
            for i_episode in range(iteration_episodes):
                observation = env.reset()
                state = Encode(env, observation, type='state')
                done = False
                while not done:
                    action = agent.take_action(state)
                    # Exponential moving average of the agent's max-Q estimate,
                    # reported in the progress bar as a training diagnostic.
                    max_q_value = agent.max_q_value(state) * 0.005 + max_q_value * 0.995
                    # ``cost`` is not used by this training loop.
                    next_observation, reward, cost, goal, hole = env.step[observation][action]
                    # Encode with the same ``type`` as ``state`` so that both
                    # ends of every stored transition share one representation.
                    next_state = Encode(env, next_observation, type='state')
                    done = goal or hole
                    replay_buffer.add(state, action, reward, next_state, done)
                    # Advance BOTH the raw observation (used to index env.step)
                    # and the encoded state (fed to the agent).
                    observation = next_observation
                    state = next_state
                    if replay_buffer.size() > minimal_size:
                        b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                        agent.update({
                            'states': b_s,
                            'actions': b_a,
                            'rewards': b_r,
                            'next_states': b_ns,
                            'dones': b_d,
                        })
                    if goal:
                        # Reaching the goal also sets ``done``, so this records
                        # at most one entry per episode.
                        return_list.append(reward)
                episodes_done += 1
                if (i_episode + 1) % 10 == 0:
                    pbar.set_postfix({
                        'episode': '%d' % episodes_done,
                        'return': '%.3f' % sum(return_list),
                        'max_q': '%.3f' % float(max_q_value),
                    })
                pbar.update(1)
    return return_list

# Build the replay buffer and agent, train, then evaluate. The guard keeps
# importing this module from launching a full training run.
if __name__ == "__main__":
    replay_buffer = ReplayBuffer(buffer_size)
    agent = DDQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update)
    _ = train_DDQN(agent, env, num_episodes, replay_buffer, minimal_size, batch_size)
    test(env, type='DDQN', para=agent, random=True)

