import gym
from TSIVR_PG_utils import *
import matplotlib.pyplot as plt
import os


class PolicyNet(nn.Module):
    """Linear softmax policy: one fully-connected layer followed by a softmax.

    Args:
        n_state: length of the input feature vector fed to the network.
        n_action: number of discrete actions; the output holds one
            probability per action.
    """

    def __init__(self, n_state, n_action):
        super().__init__()
        # Keep the attribute name `fc1`: it determines the state_dict keys
        # ("fc1.weight", "fc1.bias") and the parameter registration order.
        self.fc1 = nn.Linear(n_state, n_action)

    def forward(self, x):
        """Return action probabilities for a batch of inputs.

        The softmax is taken over dim 1, which is the action axis only when
        ``x`` is a 2-D batch of shape (batch, n_state).
        """
        logits = self.fc1(x)
        return logits.softmax(dim=1)


def wrap(s, n_state=64):
    """Return a float one-hot vector of length ``n_state`` with entry ``s`` set to 1.

    ``s`` is used directly as an index, so it is presumably a discrete
    observation id from a tabular environment (confirm against the env used).
    ``n_state`` defaults to 64 so existing ``wrap(s)`` calls are unchanged.
    """
    state = np.zeros(n_state)
    state[s] = 1
    return state


def collect_data(env, pol, num_episode, n_state):
    """Roll out ``num_episode`` episodes of policy ``pol`` in ``env``.

    Observations are one-hot encoded to length ``n_state`` via ``wrap`` so they
    match the policy network's input size; actions are chosen with
    ``get_action(pol, state, n_state)``.

    NOTE(review): assumes the pre-0.26 gym API -- ``env.reset()`` returns only
    the observation and ``env.step`` returns a 4-tuple.

    Returns:
        (states, actions, rewards): three lists with one inner list per
        episode. ``states[i][t]`` is the one-hot state at which action
        ``actions[i][t]`` (stored as a float) was taken and reward
        ``rewards[i][t]`` received. The terminal observation is not stored.
    """
    states, actions, rewards = [], [], []

    for _ in range(num_episode):
        epi_states, epi_actions, epi_rewards = [], [], []
        # Encode with n_state (not a fixed 64) so the vector fits the policy input.
        state = wrap(env.reset(), n_state)
        done = False

        while not done:
            action = get_action(pol, state, n_state)
            next_state, reward, done, _info = env.step(action)

            epi_states.append(state)
            epi_actions.append(float(action))
            epi_rewards.append(reward)

            state = wrap(next_state, n_state)

        states.append(epi_states)
        actions.append(epi_actions)
        rewards.append(epi_rewards)

    return states, actions, rewards


def _policy_grad(net, states, actions, scores, batch_size):
    """Backpropagate the REINFORCE surrogate loss through ``net``; return its flat gradient.

    The loss is ``sum(-log(net(states).gather(1, actions)) * scores) / batch_size``.
    It is the negated surrogate objective, so an optimizer step (which descends
    the loss) ascends the objective. Existing gradients on ``net`` are cleared
    first, and the result is read back with ``get_flat_grads_from``.

    Args:
        net: policy network mapping a batch of states to action probabilities.
        states, actions, scores: numpy arrays as returned by ``process``;
            ``actions`` must be shaped for ``probs.gather(1, actions)`` --
            presumably (T, 1); confirm against ``process``.
        batch_size: number of episodes in the batch (the loss is averaged over
            episodes, not time steps).
    """
    net.zero_grad()
    probs = net(torch.from_numpy(states).float())
    selected_probs = probs.gather(1, torch.from_numpy(actions).long())
    loss = (-torch.log(selected_probs) * torch.from_numpy(scores)).sum() / batch_size
    loss.backward()
    return get_flat_grads_from(net)


def _truncate_grad(grad, lr, delta):
    """Rescale ``grad`` so that ``lr * ||grad||`` does not exceed ``delta``.

    Returns ``grad`` itself when ``lr * ||grad|| < delta``. Otherwise it returns
    ``grad / (||grad|| * lr) * delta``, whose scaled norm equals ``delta``.
    """
    grad_norm = torch.norm(grad)
    if lr * grad_norm >= delta:
        return grad / (grad_norm * lr) * delta
    return grad


def TSIVR_PG(env, num_episode, lr, gamma, N, m, B, n_state, n_action, delta=0.1):
    """Train a linear softmax policy with truncated, variance-reduced policy gradient.

    Each of the ``num_episode`` outer iterations:
      1. copies the current policy into a reference network,
      2. collects ``N`` episodes and computes a large-batch policy gradient,
      3. runs ``m`` inner updates. Each inner update collects ``B`` episodes
         and forms the recursive estimate
         ``grad = grad + g(current; scores) - g(reference; weighted scores)``,
         where the weighted scores come from ``calculate_scores_w`` with the
         ratios from ``calculate_w``. It then truncates ``grad`` so that
         ``lr * ||grad|| <= delta`` and applies it with Adam.

    Args:
        env: environment passed to ``collect_data``.
        num_episode: number of outer iterations (not environment episodes).
        lr: Adam learning rate; also used in the truncation threshold.
        gamma: discount passed to ``calculate_scores`` / ``calculate_scores_w``.
        N: episodes per large-batch gradient.
        m: inner updates per outer iteration.
        B: episodes per inner update.
        n_state, n_action: policy network input / output sizes.
        delta: truncation radius for ``lr * ||grad||``.

    Returns:
        ``score_his``: the concatenation of ``get_cumu_discounted_rewards(rewards, 1)``
        over every collected batch (large and inner). Presumably one
        undiscounted return per episode -- confirm against the helper.
    """
    score_his = []

    policy_net = PolicyNet(n_state, n_action)
    policy_reference_net = PolicyNet(n_state, n_action)
    # NOTE(review): the truncation bounds lr * ||grad||, which is the step
    # length only for plain gradient descent. Adam rescales each coordinate,
    # so the actual update is not bounded by delta. Consider torch.optim.SGD
    # if the trust-region bound matters; Adam is kept to preserve behavior.
    optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)

    for _ in range(num_episode):
        # The reference policy starts each outer iteration as a copy of the current one.
        set_flat_params_to(policy_reference_net, get_flat_params_from(policy_net))

        # Large-batch gradient estimate at the snapshot.
        states, actions, rewards = collect_data(env, policy_net, N, n_state)
        scores = calculate_scores(rewards, gamma)

        states = process(states, n_state)
        actions = process(actions)
        scores = process(scores)

        score_his += get_cumu_discounted_rewards(rewards, 1)
        print_score(score_his)

        grad = _policy_grad(policy_net, states, actions, scores, N)

        # NOTE(review): on the first inner iteration policy_net still equals
        # the reference, so the correction term is presumably ~0 (if
        # calculate_w yields unit weights for identical policies). Those B
        # episodes then only serve to take the first step along the
        # large-batch gradient -- confirm this is intended.
        for _ in range(m):
            states, actions, rewards = collect_data(env, policy_net, B, n_state)
            ratios = calculate_w(states, actions, policy_reference_net, policy_net, n_state)
            scores_orig = calculate_scores(rewards, gamma)
            scores = calculate_scores_w(rewards, gamma, ratios)

            states = process(states, n_state)
            actions = process(actions)
            scores_orig = process(scores_orig)
            scores = process(scores)

            score_his += get_cumu_discounted_rewards(rewards, 1)

            # Recursive variance-reduced estimate: previous estimate + fresh
            # gradient at the current policy - weighted gradient at the reference.
            grad_cur = _policy_grad(policy_net, states, actions, scores_orig, B)
            grad_ref = _policy_grad(policy_reference_net, states, actions, scores, B)
            grad = grad + grad_cur - grad_ref

            # The current parameters become the next reference; this must
            # happen before the optimizer moves policy_net.
            set_flat_params_to(policy_reference_net, get_flat_params_from(policy_net))

            set_flat_grads_to(policy_net, _truncate_grad(grad, lr, delta))
            optimizer.step()

    return score_his
