import gym
from PG_utils import *
import matplotlib.pyplot as plt
import os


class PolicyNet(nn.Module):
    """Linear policy: one fully-connected layer followed by a softmax.

    Maps a state vector of length ``n_state`` to a probability distribution
    over ``n_action`` discrete actions.
    """

    def __init__(self, n_state, n_action):
        super().__init__()
        self.fc1 = nn.Linear(n_state, n_action)

    def forward(self, x):
        """Return action probabilities for the state(s) in ``x``.

        The softmax is taken over the last axis (the action axis produced by
        ``fc1``). This accepts a single state of shape ``(n_state,)`` as well
        as a batch ``(B, n_state)``. For a 2-D batch it is identical to the
        former ``dim=1``. A 1-D input previously raised, and higher-rank
        inputs were normalised over the wrong axis.
        """
        return torch.softmax(self.fc1(x), dim=-1)


def wrap(s, n_state=64):
    """One-hot encode a discrete state index.

    Args:
        s: integer state index; must satisfy ``0 <= s < n_state``.
        n_state: length of the encoded vector. Defaults to 64, the size
            previously hard-coded here, so existing callers are unaffected.

    Returns:
        A float numpy vector of length ``n_state`` with a single 1 at
        position ``s``.

    Raises:
        IndexError: if ``s`` is outside ``[0, n_state)``. Negative indices
            are rejected explicitly because numpy would otherwise silently
            set an element counted from the end.
    """
    if not 0 <= s < n_state:
        raise IndexError(f"state index {s} out of range for n_state={n_state}")
    state = np.zeros(n_state)
    state[s] = 1
    return state


def collect_data(env, pol, num_episode, n_state):
    """Roll out ``num_episode`` episodes with policy ``pol`` and record them.

    Uses the 4-tuple ``env.step`` API ``(obs, reward, done, info)``.
    Observations are one-hot encoded with ``wrap`` before being stored and
    before being passed to ``get_action``.

    Returns:
        ``(states, actions, rewards)``, three parallel lists with one inner
        list per episode. Actions are stored as ``float``.
    """
    states, actions, rewards = [], [], []

    for _ in range(num_episode):
        ep_states, ep_actions, ep_rewards = [], [], []
        obs = wrap(env.reset())
        finished = False

        while not finished:
            act = get_action(pol, obs, n_state)
            raw_next, r, finished, _info = env.step(act)

            # Record the transition from the state the action was chosen in.
            ep_states.append(obs)
            ep_actions.append(float(act))
            ep_rewards.append(r)

            obs = wrap(raw_next)

        states.append(ep_states)
        actions.append(ep_actions)
        rewards.append(ep_rewards)

    return states, actions, rewards


def PG(env, num_episode, lr, gamma, N, n_state, n_action):
    """Train a ``PolicyNet`` with a REINFORCE-style policy-gradient update.

    Each of the ``num_episode`` iterations:
      1. samples ``N`` episodes with the current policy;
      2. turns the rewards into per-step weights via
         ``calculate_scores(rewards, gamma)``;
      3. takes one Adam step on
         ``sum(-log pi(a|s) * score) / N``.

    Returns:
        The accumulated list of per-episode scores. Each batch is logged with
        ``get_cumu_discounted_rewards(rewards, 1)``, which presumably gives
        undiscounted episode returns.
    """
    net = PolicyNet(n_state, n_action)
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    history = []

    for _ in range(num_episode):
        batch_states, batch_actions, batch_rewards = collect_data(env, net, N, n_state)
        batch_scores = calculate_scores(batch_rewards, gamma)

        state_arr = process(batch_states, n_state)
        action_arr = process(batch_actions)
        score_arr = process(batch_scores)

        history += get_cumu_discounted_rewards(batch_rewards, 1)
        print_score(history)

        opt.zero_grad()
        state_t = Variable(torch.from_numpy(state_arr).float())
        probs = net(state_t)
        action_idx = Variable(torch.from_numpy(action_arr).long())
        # Probability the policy assigned to each action actually taken.
        chosen = probs.gather(1, action_idx)
        weighted = -torch.log(chosen) * torch.from_numpy(score_arr)
        # Average over the N sampled episodes.
        loss = weighted.sum() / N
        loss.backward()
        opt.step()

    return history
