from PG_utils import *

def collect_data(env, pol, num_episode, horizon, eps=0, verbal=False):
    """Roll out ``pol`` in ``env`` for ``num_episode`` episodes.

    Each episode starts from ``env.reset()`` and runs for at most ``horizon``
    steps, stopping early as soon as ``env.step`` reports ``done``.

    Args:
        env: environment exposing ``reset()`` and ``step(action)``, where
            ``step`` returns a 4-tuple ``(next_state, reward, done, info)``.
        pol: policy exposing ``get_action(state, eps)``.
        num_episode: number of episodes to collect.
        horizon: maximum number of steps per episode.
        eps: forwarded unchanged to ``pol.get_action``.
        verbal: accepted for backward compatibility only; it has no effect.

    Returns:
        A tuple ``(states, actions, next_states, rewards)`` of four lists.
        Each has one inner list per episode, holding one entry per step
        actually taken. Actions are stored as ``float``.
    """
    states, actions, next_states, rewards = [], [], [], []

    # NOTE: ``verbal`` used to choose between two identical ``range`` objects
    # (presumably a leftover from a removed progress bar). It is still
    # accepted so existing callers that pass it keep working.
    for _episode in range(num_episode):
        epi_states, epi_actions, epi_next_states, epi_rewards = [], [], [], []

        state = env.reset()
        for _step in range(horizon):
            action = pol.get_action(state, eps)
            next_state, reward, done, _info = env.step(action)

            epi_states.append(state)
            epi_actions.append(float(action))
            epi_next_states.append(next_state)
            epi_rewards.append(reward)

            state = next_state
            if done:
                break

        states.append(epi_states)
        actions.append(epi_actions)
        next_states.append(epi_next_states)
        rewards.append(epi_rewards)

    return states, actions, next_states, rewards


def PG(policy_net, optimizer, env, num_episode, gamma, N, H):
    """Run the policy-gradient training loop.

    Each of the ``num_episode`` iterations does three things:
      1. Collects ``N`` rollouts of at most ``H`` steps with ``policy_net``.
      2. Appends their per-episode scores to a running history and records
         ``print_score`` of that history.
      3. Takes one ``optimizer`` step on the gradient that
         ``sample_gradient`` back-propagates into ``policy_net``.

    Returns:
        A list holding the value returned by ``print_score`` after every
        iteration.
    """
    score_history = []
    progress = []

    for _iteration in range(num_episode):
        ep_states, ep_actions, _ep_next_states, ep_rewards = collect_data(
            env, policy_net, N, H
        )

        # Scores are logged with a discount of 1, independent of the
        # training ``gamma`` (presumably the undiscounted episode return).
        score_history.extend(get_cumu_discounted_rewards(ep_rewards, 1))
        progress.append(print_score(score_history))

        optimizer.zero_grad()
        # Called for its side effect: loss.backward() fills the .grad fields.
        sample_gradient(policy_net, ep_states, ep_actions, ep_rewards, gamma)
        optimizer.step()

    return progress


def sample_gradient(policy_net, states, actions, rewards, gamma):
    """Back-propagate the policy-gradient surrogate loss into ``policy_net``.

    The loss is ``-sum_t(log p(a_t | s_t) * score_t) / N``, where ``N`` is the
    number of episodes and ``score_t`` comes from
    ``calculate_scores(rewards, gamma)``. The network's gradients are zeroed
    first, so afterwards its parameters' ``.grad`` fields hold exactly the
    gradient of this loss.

    Args:
        policy_net: module whose output is indexed as per-action
            probabilities (presumably shape (M, n_actions)). It must expose
            ``n_state``.
        states, actions, rewards: per-episode nested lists, as returned by
            ``collect_data``.
        gamma: discount factor forwarded to ``calculate_scores``.

    Returns:
        The negation of ``get_flat_grads_from(policy_net)``, presumably the
        flattened gradient of the policy objective (an ascent direction).
    """
    N = len(states)

    scores = process(calculate_scores(rewards, gamma))
    states = process(states, policy_net.n_state)
    actions = process(actions)

    policy_net.zero_grad()
    # ``Variable`` has been a no-op wrapper since PyTorch 0.4; plain tensors
    # carry autograd information directly.
    probs = policy_net(torch.from_numpy(states).float())
    # gather needs an index with the same rank as probs; assumes ``process``
    # yields actions as (M, 1) — TODO confirm.
    selected_probs = probs.gather(1, torch.from_numpy(actions).long())

    # Flatten both factors so the product is element-wise whether ``process``
    # returns scores as (M,) or (M, 1). A (M, 1) * (M,) product would
    # silently broadcast to (M, M) and corrupt the loss.
    log_probs = torch.log(selected_probs).reshape(-1)
    weights = torch.from_numpy(scores).reshape(-1)
    loss = -(log_probs * weights).sum() / N
    loss.backward()

    return -get_flat_grads_from(policy_net)
