import gym
from utils import *
import matplotlib.pyplot as plt
import os


class PolicyNet(nn.Module):
    """Single-layer softmax policy: one linear map followed by a softmax.

    The only trainable layer is ``fc1``, an ``nn.Linear(n_state, n_action)``.
    """

    def __init__(self, n_state, n_action):
        """Create the linear layer mapping ``n_state`` inputs to ``n_action`` outputs."""
        super(PolicyNet, self).__init__()
        self.fc1 = nn.Linear(n_state, n_action)

    def forward(self, x):
        """Return the softmax of ``fc1(x)`` taken along dimension 1.

        Because the softmax is over ``dim=1``, ``x`` must be at least 2-D
        (one state per row).
        """
        logits = self.fc1(x)
        return torch.softmax(logits, dim=1)


def wrap(s, n_state=64):
    """Return a one-hot vector of length ``n_state`` with entry ``s`` set to 1.

    Args:
        s: index of the entry to set (used directly to index the array).
        n_state: length of the returned vector. Defaults to 64, the value
            that was previously hard-coded, so existing ``wrap(s)`` calls
            behave exactly as before.

    Returns:
        A new ``numpy`` array of zeros with ``state[s] == 1``.
    """
    state = np.zeros(n_state)
    state[s] = 1
    return state


def collect_data(env, pol, num_episode, n_state):
    """Roll out ``num_episode`` complete episodes and record every step.

    Each episode starts with ``env.reset()`` and runs until ``env.step``
    reports ``done``. Raw environment states are converted with ``wrap``
    before being passed to ``get_action`` or stored.

    Args:
        env: environment exposing ``reset()`` and a ``step(action)`` that
            returns a 4-tuple ``(next_state, reward, done, info)``.
        pol: policy handed to ``get_action``.
        num_episode: number of episodes to run.
        n_state: forwarded to ``get_action``.

    Returns:
        ``(states, actions, rewards)``: three lists with one inner list per
        episode, holding the wrapped states, the actions (stored as
        ``float``) and the rewards for each step of that episode.
    """
    all_states, all_actions, all_rewards = [], [], []

    for _ in range(num_episode):
        traj_states, traj_actions, traj_rewards = [], [], []

        current = wrap(env.reset())
        finished = False
        while not finished:
            chosen = get_action(pol, current, n_state)
            raw_next, step_reward, finished, _ = env.step(chosen)
            following = wrap(raw_next)

            # Record the state the action was taken in, not the successor.
            traj_states.append(current)
            traj_actions.append(float(chosen))
            traj_rewards.append(step_reward)

            current = following

        all_rewards.append(traj_rewards)
        all_states.append(traj_states)
        all_actions.append(traj_actions)

    return all_states, all_actions, all_rewards


def _surrogate_loss(net, states, actions, scores, batch_size):
    """Return the score-weighted negative log-probability of the taken actions.

    ``net`` is evaluated on ``states``, the probability of each entry of
    ``actions`` is gathered along dimension 1, and the sum of
    ``-log(prob) * score`` is divided by ``batch_size``.
    """
    probs = net(Variable(torch.from_numpy(states).float()))
    selected_probs = probs.gather(1, Variable(torch.from_numpy(actions).long()))
    return (-torch.log(selected_probs) * torch.from_numpy(scores)).sum() / batch_size


def TSIVR_PG(env, num_episode, lr, gamma, N, m, B, n_state, n_action):
    """Train a ``PolicyNet`` with a variance-reduced policy gradient.

    Every outer iteration copies the current policy parameters into a
    reference network, takes one step on a gradient estimated from ``N``
    fresh episodes, and then performs ``m`` inner steps. Each inner step
    collects ``B`` episodes and corrects the running gradient estimate with
    the difference between the current-policy gradient and the
    reference-policy gradient on that same batch.

    Args:
        env: environment passed to ``collect_data``.
        num_episode: number of outer iterations.
        lr: learning rate for all RMSprop optimizers.
        gamma: discount factor forwarded to the score helpers.
        N: number of episodes in the large batch of each outer iteration.
        m: number of inner iterations per outer iteration.
        B: number of episodes in each inner mini-batch.
        n_state: input size of the policy network.
        n_action: output size of the policy network.

    Returns:
        A list extended with the output of ``get_cumu_discounted_rewards``
        for every batch collected, in collection order.
    """
    score_his = []

    policy_net = PolicyNet(n_state, n_action)
    policy_reference_net = PolicyNet(n_state, n_action)
    optimizer = torch.optim.RMSprop(policy_net.parameters(), lr=lr)
    # Separate optimizer (and separate RMSprop state) for the large-batch step.
    optimizer2 = torch.optim.RMSprop(policy_net.parameters(), lr=lr)
    # Only used to clear the reference network's gradients; it never steps.
    optimizer_ref = torch.optim.RMSprop(policy_reference_net.parameters(), lr=lr)

    for _ in range(num_episode):
        # Snapshot the current policy into the reference network.
        set_flat_params_to(policy_reference_net, get_flat_params_from(policy_net))

        states, actions, rewards = collect_data(env, policy_net, N, n_state)
        scores = calculate_scores(rewards, gamma)

        states = process(states, n_state)
        actions = process(actions)
        scores = process(scores)

        score_his += get_cumu_discounted_rewards(rewards, gamma)

        # Large-batch gradient: this seeds the running estimate.
        optimizer.zero_grad()
        loss = _surrogate_loss(policy_net, states, actions, scores, N)
        loss.backward()
        base_grad = get_flat_grads_from(policy_net)
        optimizer2.step()

        for _ in range(m):
            states, actions, rewards = collect_data(env, policy_net, B, n_state)
            # presumably importance weights between the reference and the
            # current policy; verify against utils.calculate_w
            ratios = calculate_w(states, actions, policy_reference_net, policy_net, n_state)
            scores_orig = calculate_scores(rewards, gamma)
            scores = calculate_scores_w(rewards, gamma, ratios)

            states = process(states, n_state)
            actions = process(actions)
            scores_orig = process(scores_orig)
            scores = process(scores)

            score_his += get_cumu_discounted_rewards(rewards, gamma)

            # Gradient of the current policy on the mini-batch.
            optimizer.zero_grad()
            loss = _surrogate_loss(policy_net, states, actions, scores_orig, B)
            loss.backward()
            grad1 = get_flat_grads_from(policy_net)

            # Gradient of the reference policy on the same mini-batch, using
            # the ratio-adjusted scores.
            optimizer_ref.zero_grad()
            loss = _surrogate_loss(policy_reference_net, states, actions, scores, B)
            loss.backward()

            # This backward pass only populated the reference network's
            # gradients, so they must be read from policy_reference_net.
            # Reading them from policy_net would just return grad1 again and
            # make the correction term vanish.
            grad2 = get_flat_grads_from(policy_reference_net)
            grad = base_grad + grad1 - grad2
            base_grad = grad
            # NOTE(review): base_grad is updated recursively while the
            # reference network is refreshed only once per outer iteration;
            # confirm this matches the intended estimator.

            set_flat_grads_to(policy_net, grad)
            optimizer.step()

    return score_his


if __name__ == '__main__':
    # Sweep over large-batch sizes N; for each N run one training per seed and
    # save the stacked score histories to ./lr<lr>_eps_<eps>/N=<N>.npy.
    learning_rate = 0.01
    eps = 200
    gamma = 0.99

    env = gym.make('FrozenLake8x8-v0')
    n_state = 64
    n_action = env.action_space.n

    seeds = range(50)
    Ns = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 150, 200, 250, 300, 400, 500]

    # The output directory does not depend on N, so create it once up front;
    # this also surfaces filesystem problems before any training has run.
    filepath = './lr' + str(learning_rate).replace('.', '_') + '_eps_' + str(eps)
    if not os.path.exists(filepath):
        os.makedirs(filepath)

    for N in Ns:
        B = int(np.sqrt(N))
        m = int(np.sqrt(N))
        res = []
        # Iterate the seed list itself so the run count cannot drift from it.
        for i, seed in enumerate(seeds):
            print('N=' + str(N) + ' i=' + str(i))
            torch.manual_seed(seed)
            np.random.seed(seed)
            env.seed(seed)
            # m and B are passed by keyword: TSIVR_PG declares them in the
            # order (m, B), and the previous positional call supplied (B, m).
            res.append(TSIVR_PG(env, eps, learning_rate, gamma, N,
                                m=m, B=B, n_state=n_state, n_action=n_action))

        res = np.array(res)
        filename = filepath + '/N=' + str(N)
        np.save(filename, res)
