import argparse
from portfolio_environment import environment
import numpy as np
from itertools import count

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# ---------------------------------------------------------------------------
# Command-line configuration.
#
# NOTE(review): --gamma, --render, --log_path, --env_name, --reg_param and
# --trials are parsed but not read anywhere else in this file: rewards are
# summed undiscounted, results always go to results/, the environment is always
# portfolio_environment.environment(), and main() runs a fixed trial count.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='interval between training status logs (default: 10)')
parser.add_argument('--log_path', type=str, default='results/temp',
                    help='log path')
parser.add_argument('--env_name', type=str, default='CartPole-v1',
                    help='atari game name')
parser.add_argument('--reg_param', type=float, default=1,
                    help='reg param')
parser.add_argument('--trials', type=int, default=100,
                    help='num trials')
parser.add_argument('--lmbd', type=float, default=0.001,
                    help='lambda used in the y update and result file names '
                         '(default: 0.001)')
args = parser.parse_args()

# Lambda coefficient: enters finish_episode's y update as 1/lmbd and is embedded
# in the result CSV file names written by main().  Previously hard-coded to
# 0.001 (still the default).  Presumably the mean-variance risk-aversion
# weight -- TODO confirm.
lmbd = args.lmbd

env = environment()
torch.manual_seed(args.seed)


class Policy(nn.Module):
    """Two-layer MLP mapping an observation to a softmax over discrete actions.

    The instance also carries per-episode buffers: ``saved_log_probs`` is
    appended to by ``select_action`` and ``rewards`` by the rollout loop; both
    are cleared by ``finish_episode``.

    Args:
        n_inputs: observation size.  ``None`` (the default) uses ``env.W + 2``,
            the width the module-level environment's observations are assumed
            to have.
        n_actions: number of discrete actions (default 2).
    """

    def __init__(self, n_inputs=None, n_actions=2):
        super().__init__()
        if n_inputs is None:
            n_inputs = env.W + 2
        # Hidden width equals input width.  Layers are created in the same
        # order as before, so seeded initialisation is unchanged.
        self.affine1 = nn.Linear(n_inputs, n_inputs)
        self.affine2 = nn.Linear(n_inputs, n_actions)

        # Log-probabilities of the actions sampled during the current episode.
        self.saved_log_probs = []
        # Rewards received during the current episode.
        self.rewards = []

    def forward(self, x):
        """Return action probabilities for ``x`` of shape ``(..., n_inputs)``.

        The softmax is taken over the last dimension, so batched
        ``(B, n_inputs)`` input and a single ``(n_inputs,)`` observation both
        work.  For 2-D input this matches the previous ``dim=1`` behaviour.
        """
        hidden = F.relu(self.affine1(x))
        return F.softmax(self.affine2(hidden), dim=-1)


def select_action(state, policy):
    """Sample one action from ``policy`` for a single observation.

    Args:
        state: one observation -- a NumPy array or any array-like accepted by
            ``torch.as_tensor``.  It is converted to float32 and given a
            leading batch dimension of size 1.
        policy: callable mapping the ``(1, ...)`` batch to action
            probabilities, with a ``saved_log_probs`` list attribute.  The
            sampled action's log-probability (with its autograd graph) is
            appended to it for the later policy-gradient step.

    Returns:
        The sampled action index as a Python int.
    """
    batch = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    dist = Categorical(policy(batch))
    action = dist.sample()
    policy.saved_log_probs.append(dist.log_prob(action))
    return action.item()


def finish_episode(model, optimizer, eps, i_episode, y, risk_lambda=None,
                   y_lr=1e-2):
    """Apply one mean-variance policy-gradient update at the end of an episode.

    Follows the Fenchel-dual block-coordinate scheme of Xie et al. (2018),
    "A Block Coordinate Ascent Algorithm for Mean-Variance Optimization".
    With episode return ``R`` and step size ``beta = y_lr / (i_episode + 1)``:

        y     <- y + beta * (2 R + 1/lambda - 2 y)
        theta ascends (2 y R - R**2) * grad log pi   (using the updated y)

    Args:
        model: policy whose ``rewards`` and ``saved_log_probs`` lists hold the
            finished episode; both are cleared before returning.
        optimizer: optimizer over ``model``'s parameters.
        eps: unused; kept so existing callers keep working.
        i_episode: 0-based episode index; both updates are scaled by
            ``1 / (i_episode + 1)``.
        y: current value of the dual variable.
        risk_lambda: lambda coefficient; ``None`` uses the module-level ``lmbd``.
        y_lr: base step size of the ``y`` update (default 1e-2, as before).

    Returns:
        The updated dual variable ``y``.  If the episode recorded no actions,
        no update is made and ``y`` is returned unchanged.
    """
    if risk_lambda is None:
        risk_lambda = lmbd

    if not model.saved_log_probs:
        # Nothing to differentiate (stacking an empty list would raise).
        del model.rewards[:]
        return y

    param = 1 / (i_episode + 1)

    # Undiscounted total return of the episode (no gamma is applied).
    R = sum(model.rewards)

    # Dual-variable step.  The former (2*R**2 - 1/lmbd - 2*y) drove y toward
    # E[R^2] - 1/(2*lmbd) instead of E[R] + 1/(2*lmbd), which flips the sign of
    # the policy weight below for small returns.
    y += y_lr * param * (2 * R + 1 / risk_lambda - 2 * y)

    # Policy step: gradient ascent on weight * sum(log pi), i.e. descent on
    # its negative.  The tensor stays on the left so NumPy-scalar rewards
    # dispatch to torch's multiply.
    weight = 2 * y * R - R ** 2
    policy_loss = torch.stack(model.saved_log_probs).sum() * (-param * weight)
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

    del model.rewards[:]
    del model.saved_log_probs[:]
    return y


def _run_episode(policy, n_steps=49):
    """Roll out one fixed-length episode in ``env`` and return its summed reward.

    Each step samples an action with ``select_action`` (which appends the
    log-probability to ``policy.saved_log_probs``) and appends the reward to
    ``policy.rewards``.  The caller decides whether to train on those buffers
    or discard them.
    """
    state = env.reset()
    ep_reward = 0
    for _ in range(n_steps):
        action = select_action(state, policy)
        state, reward = env.step(action)
        policy.rewards.append(reward)
        ep_reward += reward
    return ep_reward


def main():
    """Train and evaluate a fresh policy per trial, saving rewards to CSV.

    Each of the 1000 trials proceeds as follows:

    * Build a new ``Policy`` and an Adam optimizer (lr 1e-2).
    * Train for 1000 episodes, calling ``finish_episode`` after each one.
    * Run 100 evaluation episodes with no parameter updates.

    After every trial the accumulated per-episode rewards (one row per trial)
    are rewritten to ``results/equm_xie_train<lmbd>.csv`` and
    ``results/equm_xie_test<lmbd>.csv``.
    """
    import os

    # np.savetxt does not create missing parent directories.
    os.makedirs('results', exist_ok=True)

    ep_reward_list_train = []
    ep_reward_list_test = []
    # Passed through to finish_episode, which does not use it.
    eps = np.finfo(np.float32).eps.item()

    # NOTE(review): --trials (default 100) is parsed but ignored here; the
    # trial count stays at 1000 to preserve existing behaviour.
    for _trial in range(1000):
        policy = Policy()
        optimizer = optim.Adam(policy.parameters(), lr=1e-2)
        y = 1  # initial dual variable for finish_episode

        # --- training phase ---
        running_ep_reward = []
        for i_episode in range(1000):
            ep_reward = _run_episode(policy)
            running_ep_reward.append(ep_reward)
            y = finish_episode(policy, optimizer, eps, i_episode, y)

            if i_episode % args.log_interval == 0:
                print('Episode {}\tLast reward: {:.2f}'.format(
                    i_episode, ep_reward))

        ep_reward_list_train.append(running_ep_reward)
        # '%2f' is width 2 with default precision, e.g. 0.001 -> '0.001000'.
        np.savetxt('results/equm_xie_train%2f.csv' % lmbd,
                   ep_reward_list_train)

        # --- evaluation phase: no learning, so no autograd graph is needed ---
        ep_reward_list_test_temp = []
        with torch.no_grad():
            for _ in range(100):
                ep_reward_list_test_temp.append(_run_episode(policy))
                # Drop the buffers so they don't grow across test episodes.
                del policy.rewards[:]
                del policy.saved_log_probs[:]

        ep_reward_list_test.append(ep_reward_list_test_temp)
        np.savetxt('results/equm_xie_test%2f.csv' %
                   lmbd, ep_reward_list_test)


# Script entry point: run training/evaluation only when executed directly.
# Importing the module still parses the command line and constructs `env`
# at module level.
if __name__ == '__main__':
    main()
