import gym
from utils import *
import matplotlib.pyplot as plt
import os


class PolicyNet(nn.Module):
    """Linear policy: one fully connected layer followed by a softmax.

    The forward pass maps a 2-D batch of state vectors (one row per sample)
    to a row-wise probability distribution over ``n_action`` actions.
    """

    def __init__(self, n_state, n_action):
        """Create the single linear layer mapping ``n_state`` inputs to ``n_action`` logits."""
        super().__init__()
        self.fc1 = nn.Linear(in_features=n_state, out_features=n_action)

    def forward(self, x):
        """Return action probabilities; the softmax is taken along dimension 1."""
        logits = self.fc1(x)
        return logits.softmax(dim=1)


def wrap(s, n_state=64):
    """Return the one-hot encoding of a discrete state index.

    Args:
        s: Index of the active state; used directly to index the vector.
        n_state: Length of the returned vector. Defaults to 64, the size
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        A 1-D float NumPy array of length ``n_state`` that is zero everywhere
        except for a 1 at position ``s``.
    """
    state = np.zeros(n_state)
    state[s] = 1
    return state


def collect_data(env, pol, num_episode, n_state):
    """Roll out ``num_episode`` complete episodes and record the trajectories.

    Each observation returned by ``env`` is passed through ``wrap`` before it
    is stored or handed to ``get_action`` (defined outside this file).

    Args:
        env: Environment exposing ``reset()`` and a ``step(action)`` that
            returns a 4-tuple ``(observation, reward, done, info)``.
        pol: Policy object forwarded to ``get_action``.
        num_episode: Number of episodes to run.
        n_state: Forwarded unchanged to ``get_action``.

    Returns:
        A tuple ``(states, actions, rewards)``. Each element is a list with one
        entry per episode, and each entry is the list of per-step values.
        Actions are stored as ``float``.
    """
    all_states, all_actions, all_rewards = [], [], []

    for _ in range(num_episode):
        traj_states, traj_actions, traj_rewards = [], [], []

        obs = wrap(env.reset())
        finished = False
        while not finished:
            act = get_action(pol, obs, n_state)
            raw_next, rew, finished, _info = env.step(act)

            # Record the transition taken from the current observation.
            traj_states.append(obs)
            traj_actions.append(float(act))
            traj_rewards.append(rew)

            obs = wrap(raw_next)

        all_rewards.append(traj_rewards)
        all_states.append(traj_states)
        all_actions.append(traj_actions)

    return all_states, all_actions, all_rewards


def cal_lam(states, gamma):
    """Average the discounted sum of state features over episodes.

    For every episode the step-``k`` entry is scaled by ``gamma ** k`` and the
    scaled entries are summed; each episode's sum is divided by the number of
    episodes and accumulated.

    Args:
        states: Sequence of episodes, each a sequence of per-step values
            (numbers or NumPy arrays).
        gamma: Discount factor.

    Returns:
        The accumulated average, or ``0`` when ``states`` is empty.
    """
    n_episodes = len(states)
    total = 0
    for trajectory in states:
        discounted = 0
        for step, feat in enumerate(trajectory):
            discounted = gamma**step * feat + discounted
        total = total + discounted / n_episodes
    return total


def cal_lam_IP(states, gamma, ratios):
    """Average the ratio-weighted discounted sum of state features over episodes.

    Same accumulation as ``cal_lam``, except that the step-``j`` entry of
    episode ``i`` is additionally multiplied by ``ratios[i][j]``.

    Args:
        states: Sequence of episodes, each a sequence of per-step values
            (numbers or NumPy arrays).
        gamma: Discount factor.
        ratios: Per-episode sequences of per-step weights, indexed in
            parallel with ``states``.

    Returns:
        The accumulated average, or ``0`` when ``states`` is empty.
    """
    n_episodes = len(states)
    total = 0
    for ep, trajectory in enumerate(states):
        weights = ratios[ep]
        discounted = 0
        for step, feat in enumerate(trajectory):
            discounted = gamma ** step * feat * weights[step] + discounted
        total = total + discounted / n_episodes
    return total


def reward_fn(lam):
    """Return ``1 / (lam + 1/8)``, applied elementwise when ``lam`` is an array."""
    shifted = lam + 1 / 8
    return 1 / shifted


def convert_rew(lam, states):
    """Look up a reward for every recorded state.

    ``reward_fn(lam)`` is evaluated once, and each state vector selects the
    entry at its arg-max position.

    Args:
        lam: Value passed to ``reward_fn``; the result must be indexable.
        states: Sequence of episodes, each a sequence of state vectors.

    Returns:
        A list of per-episode reward lists, parallel to ``states``.
    """
    per_state_reward = reward_fn(lam)
    return [
        [per_state_reward[np.argmax(feat)] for feat in trajectory]
        for trajectory in states
    ]


def _flat_policy_grad(net, optimizer, states, actions, weights, batch_size):
    """Backpropagate a weighted negative log-likelihood through ``net``.

    Args:
        net: Policy network whose forward pass yields per-action probabilities.
        optimizer: Optimizer holding ``net``'s parameters; only used to clear
            stale gradients before the backward pass.
        states: NumPy array of network inputs, one row per sample.
        actions: NumPy array of chosen action indices, used as the ``gather``
            index along dimension 1.
        weights: Sequence of NumPy arrays multiplied, in order, into the
            negative log-probabilities of the chosen actions.
        batch_size: Divisor applied to the summed loss.

    Returns:
        The result of ``get_flat_grads_from(net)`` after the backward pass
        (presumably the flattened gradient; the helper lives in ``utils``).
    """
    optimizer.zero_grad()
    probs = net(Variable(torch.from_numpy(states).float()))
    selected_probs = probs.gather(1, Variable(torch.from_numpy(actions).long()))
    weighted = -torch.log(selected_probs)
    for w in weights:
        weighted = weighted * torch.from_numpy(w)
    loss = weighted.sum() / batch_size
    loss.backward()
    # Always read the gradient from the network that was just differentiated.
    return get_flat_grads_from(net)


def SIVR_PG(env, num_episode, lr, gamma, N, m, B, n_state, n_action):
    """Train a ``PolicyNet`` with a variance-reduced policy-gradient loop.

    Every outer iteration copies the current policy parameters into a
    reference network, collects ``N`` episodes, and computes a base gradient.
    It then runs ``m`` inner steps; each collects ``B`` episodes, corrects the
    running gradient estimate with the difference between a current-policy
    gradient and a ratio-weighted reference-policy gradient, and applies one
    RMSprop step.

    The environment's own rewards are discarded; rewards are derived from the
    running ``lam`` estimate through ``convert_rew``.

    Args:
        env: Environment passed to ``collect_data``.
        num_episode: Number of outer iterations.
        lr: Learning rate for both RMSprop optimizers.
        gamma: Discount factor.
        N: Episodes collected per outer iteration.
        m: Number of inner steps per outer iteration.
        B: Episodes collected per inner step.
        n_state: Input size of the policy network.
        n_action: Output size of the policy network.

    Returns:
        ``score_his``: a list extended with the result of ``get_entropy`` after
        every data collection (outer and inner).
    """
    score_his = []

    policy_net = PolicyNet(n_state, n_action)
    policy_reference_net = PolicyNet(n_state, n_action)
    optimizer = torch.optim.RMSprop(policy_net.parameters(), lr=lr)
    optimizer_ref = torch.optim.RMSprop(policy_reference_net.parameters(), lr=lr)

    for _ in range(num_episode):
        # Snapshot the current policy into the reference network.
        # NOTE(review): the reference is only synchronised here, once per outer
        # iteration, while ``base_grad``/``lam`` are updated recursively in the
        # inner loop -- confirm this matches the intended estimator.
        params = get_flat_params_from(policy_net)
        set_flat_params_to(policy_reference_net, params)

        states, actions, _ = collect_data(env, policy_net, N, n_state)
        lam = cal_lam(states, gamma)
        lam_prev = lam
        lam_prev_prev = lam

        rewards = convert_rew(lam, states)
        scores = calculate_scores(rewards, gamma)

        score_his += get_entropy(states, gamma)
        print_score(score_his)

        states = process(states, n_state)
        actions = process(actions)
        scores = process(scores)

        base_grad = _flat_policy_grad(
            policy_net, optimizer, states, actions, (scores,), N)

        for _ in range(m):
            states, actions, _ = collect_data(env, policy_net, B, n_state)
            ratios = calculate_IP(states, actions, policy_reference_net, policy_net, n_state)

            # Recursive update of lam: add the plain estimate and subtract the
            # ratio-weighted one computed on the same batch.
            lam1 = cal_lam(states, gamma)
            lam2 = cal_lam_IP(states, gamma, ratios)
            lam = lam + lam1 - lam2

            rewards_prev = convert_rew(lam_prev, states)
            rewards_prev_prev = convert_rew(lam_prev_prev, states)
            scores_orig = calculate_scores(rewards_prev, gamma)
            scores = calculate_scores_IP(rewards_prev_prev, gamma, ratios)

            lam_prev_prev = lam_prev
            lam_prev = lam

            score_his += get_entropy(states, gamma)

            states = process(states, n_state)
            actions = process(actions)
            scores = process(scores)
            scores_orig = process(scores_orig)
            ratios = process(ratios)

            # Gradient of the current policy on the fresh batch.
            grad1 = _flat_policy_grad(
                policy_net, optimizer, states, actions, (scores_orig,), B)

            # Ratio-weighted gradient of the reference policy. It must be read
            # from the reference network: reading it from ``policy_net`` would
            # just return ``grad1`` again and cancel the correction term.
            grad2 = _flat_policy_grad(
                policy_reference_net, optimizer_ref, states, actions,
                (scores, ratios), B)

            grad = base_grad + grad1 - grad2
            base_grad = grad

            set_flat_grads_to(policy_net, grad)
            optimizer.step()

    return score_his


if __name__ == '__main__':
    # Experiment configuration.
    learning_rate = 0.09
    eps = 200  # number of outer iterations passed to SIVR_PG
    gamma = 0.99

    env = gym.make('FrozenLake8x8-v0')
    n_state = 64
    n_action = env.action_space.n

    seeds = range(20)
    N = 100
    B = int(np.sqrt(N))
    m = int(np.sqrt(N))
    res = []
    for i, seed in enumerate(seeds):
        print('N=' + str(N) + ' i=' + str(i))
        # Seed every source of randomness so each run is reproducible.
        torch.manual_seed(seed)
        np.random.seed(seed)
        env.seed(seed)
        # Pass m and B by keyword: the signature order is (..., N, m, B, ...),
        # so a positional "N, B, m" would silently swap them.
        res.append(SIVR_PG(env, eps, learning_rate, gamma, N, m=m, B=B,
                           n_state=n_state, n_action=n_action))

    res = np.array(res)
    filepath = './results/lr' + str(learning_rate).replace('.', '_') + '_eps_' + str(eps)
    filename = filepath + '/N=' + str(N)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(filepath, exist_ok=True)
    np.save(filename, res)
