# Solution of the OpenAI Gym environment "CartPole-v0" (https://gym.openai.com/envs/CartPole-v0) using DQN and PyTorch.
# It is a slightly modified version of the PyTorch DQN tutorial from
# http://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html.
# The main difference is that it does not take the rendered screen as input; it simply uses the observation values
# from the environment.
import math
import random
from collections import deque

import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from gym import wrappers
from torch.autograd import Variable

from models import NonLinearModel, LinearModel


class ReplayMemory:
    """Fixed-capacity FIFO buffer of transitions for experience replay.

    Once ``capacity`` transitions are stored, pushing a new one evicts the
    oldest. Backed by ``collections.deque(maxlen=...)`` so eviction is O(1)
    rather than the O(n) shift caused by deleting the head of a list. The
    oldest-first ordering is the same as a list-based buffer, so sampling
    with a given RNG state picks the same transitions.
    """

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        self.capacity = capacity
        # A bounded deque drops its leftmost (oldest) item automatically.
        self.memory = deque(maxlen=capacity)

    def push(self, transition):
        """Store ``transition``, evicting the oldest one if the buffer is full."""
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)


class Network(nn.Module):
    """Three-layer fully connected Q-network.

    Maps an observation vector to one Q-value per discrete action. The
    defaults (4 inputs, 2 outputs) match the sizes the original script used.

    Args:
        activation_function: ``"relu"`` or ``"tanh"``; applied after each of
            the two hidden layers.
        hidden_size: width of both hidden layers. When omitted, the
            module-level ``HIDDEN_LAYER`` hyperparameter is used.
        input_size: length of the observation vector.
        output_size: number of Q-values produced (one per action).

    Raises:
        ValueError: if ``activation_function`` is not a supported name.
            Previously this went unnoticed until ``forward`` failed with an
            AttributeError.
    """

    _ACTIVATIONS = {"relu": F.relu, "tanh": torch.tanh}

    def __init__(self, activation_function, hidden_size=None, input_size=4, output_size=2):
        super().__init__()
        if activation_function not in self._ACTIVATIONS:
            raise ValueError(
                "unsupported activation_function {!r}; expected one of {}".format(
                    activation_function, sorted(self._ACTIVATIONS)
                )
            )
        if hidden_size is None:
            hidden_size = HIDDEN_LAYER
        # Layers are created in the same order as before (l1, l2, l3), so the
        # weight-initialisation RNG draws and state_dict keys are unchanged.
        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, output_size)
        self.activation_function = self._ACTIVATIONS[activation_function]

    def forward(self, x):
        """Return raw Q-values (no activation on the output layer)."""
        x = self.activation_function(self.l1(x))
        x = self.activation_function(self.l2(x))
        return self.l3(x)


def select_action(state):
    """Pick an action for ``state`` with an epsilon-greedy policy.

    Epsilon decays exponentially from ``EPS_START`` toward ``EPS_END`` with
    time constant ``EPS_DECAY``, measured in calls to this function. The
    global ``steps_done`` counter is incremented on every call.

    Args:
        state: observation batch fed to ``model``. The caller in this file
            passes a single observation wrapped as ``FloatTensor([state])``.

    Returns:
        A (1, 1) long tensor holding the chosen action index.
    """
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(
        -1.0 * steps_done / EPS_DECAY
    )
    steps_done += 1
    if sample > eps_threshold:
        # Greedy branch is pure inference: skip recording an autograd graph.
        with torch.no_grad():
            q_values = model(state.type(FloatTensor))
        return q_values.max(1)[1].view(1, 1)
    # Exploration branch: uniform over the two actions.
    return LongTensor([[random.randrange(2)]])


def run_episode(e, environment):
    """Play one episode in ``environment``, training the model after every step.

    Each transition is pushed to the global replay ``memory`` and followed by
    one ``learn()`` update. Once episode index ``e`` exceeds ``REDUCE_LR_EP``,
    the learning-rate scheduler is stepped exactly once; the global
    ``CAN_STEP`` flag guards this. When the episode ends, its length is
    appended to ``episode_durations`` and the accumulated reward is logged to
    wandb as ``"return"``.

    Args:
        e: zero-based episode index.
        environment: gym environment whose ``step`` returns
            ``(obs, reward, done, info)``.
    """
    global CAN_STEP
    state = environment.reset()
    steps = 0
    total_reward = 0
    while True:
        action = select_action(FloatTensor([state]))
        next_state, reward, done, _ = environment.step(action[0, 0].item())

        # Reward shaping: the step that ends the attempt is penalised.
        # NOTE(review): if the env is wrapped in gym's TimeLimit, done is also
        # True when the step limit truncates a successful run, so that step is
        # penalised too -- confirm this is intended.
        if done:
            reward = -1

        memory.push(
            (
                FloatTensor([state]),
                action,  # already a (1, 1) long tensor
                FloatTensor([next_state]),
                FloatTensor([reward]),
            )
        )

        learn()

        state = next_state
        steps += 1
        if CAN_STEP and e > REDUCE_LR_EP:
            lr_scheduler.step()
            print(lr_scheduler.get_last_lr())
            CAN_STEP = False
        # Accumulates the shaped reward (including the -1 penalty above).
        total_reward += reward
        if done:
            # Green for near-solved episodes. "\033[0m" resets attributes so
            # the colour does not leak into later output; the previous
            # "\033[99m" is not a standard SGR code and was ignored.
            colour = "\033[92m" if steps >= 195 else "\033[0m"
            print(
                "{2} Episode {0} finished after {1} steps\033[0m".format(
                    e, steps, colour
                )
            )
            episode_durations.append(steps)
            wandb.log({"return": total_reward})
            break


def learn():
    """Run one DQN gradient step on a random minibatch from replay memory.

    Does nothing until ``memory`` holds at least ``BATCH_SIZE`` transitions.
    Targets are ``reward + GAMMA * max_a Q(next_state, a)``, computed with the
    same ``model`` that is being trained (there is no separate target
    network).

    NOTE(review): stored transitions carry no "done" flag, so terminal
    transitions still bootstrap from their next state.
    """
    if len(memory) < BATCH_SIZE:
        return

    # Each stored transition is (state, action, next_state, reward).
    transitions = memory.sample(BATCH_SIZE)
    states, actions, next_states, rewards = zip(*transitions)

    state_batch = torch.cat(states)
    action_batch = torch.cat(actions)
    reward_batch = torch.cat(rewards)
    next_state_batch = torch.cat(next_states)

    # Q(s, a) for the actions actually taken; shape (BATCH_SIZE, 1).
    current_q_values = model(state_batch).gather(1, action_batch)
    # Bootstrapped targets must not receive gradients. no_grad also avoids
    # building an autograd graph for this forward pass.
    with torch.no_grad():
        max_next_q_values = model(next_state_batch).max(1)[0]
    expected_q_values = reward_batch + GAMMA * max_next_q_values

    # Huber loss between current and target Q-values.
    loss = F.smooth_l1_loss(current_q_values[:, 0], expected_q_values)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def plot_durations(durations=None, window=100):
    """Plot episode lengths and, once enough exist, their moving average.

    Args:
        durations: sequence of episode lengths to plot. Defaults to the
            global ``episode_durations`` list filled in by ``run_episode``.
        window: moving-average window. The average curve is drawn only when
            at least ``window`` episodes exist. It is left-padded with
            ``window - 1`` zeros so it lines up with the raw curve.

    Raises:
        ValueError: if ``window`` is smaller than 1.
    """
    if durations is None:
        durations = episode_durations
    if window < 1:
        raise ValueError("window must be >= 1, got {}".format(window))

    plt.figure(2)  # reuse the same figure on every call
    plt.clf()
    durations_t = torch.FloatTensor(durations)
    plt.title("Training...")
    plt.xlabel("Episode")
    plt.ylabel("Duration")
    plt.plot(durations_t.numpy())
    if len(durations_t) >= window:
        # unfold yields every length-`window` slice; their means form the moving average.
        means = durations_t.unfold(0, window, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(window - 1), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated


import argparse

# Command-line hyperparameters. The parsed namespace is passed to wandb.init
# below as the run config.
parser = argparse.ArgumentParser()
parser.add_argument("--EPISODES", type=int, default=30, help="episodes per training repeat")
parser.add_argument("--EPS_START", type=float, default=0.9, help="initial epsilon for e-greedy")
parser.add_argument("--EPS_END", type=float, default=0.05, help="final epsilon for e-greedy")
parser.add_argument(
    "--EPS_DECAY",
    type=int,
    default=200,
    help="time constant of the exponential epsilon decay, in action selections",
)
parser.add_argument("--GAMMA", type=float, default=0.8, help="Q-learning discount factor")
parser.add_argument("--LR", type=float, default=1e-3, help="Adam learning rate")
parser.add_argument("--HIDDEN_LAYER", type=int, default=1024, help="hidden layer width")
parser.add_argument("--BATCH_SIZE", type=int, default=64, help="replay minibatch size")
parser.add_argument(
    "--ACTIVATION_FUNCTION",
    type=str,
    default="relu",
    # Network only understands these names; reject typos at parse time
    # instead of failing deep inside training.
    choices=("relu", "tanh"),
    help="activation applied after each hidden layer",
)
parser.add_argument(
    "--REDUCE_LR_EP",
    type=int,
    default=15,
    help="the learning rate is reduced once, after this episode index",
)
args = parser.parse_args()

# hyper parameters
EPISODES = args.EPISODES  # number of episodes
EPS_START = args.EPS_START  # e-greedy threshold start value
EPS_END = args.EPS_END  # e-greedy threshold end value
EPS_DECAY = args.EPS_DECAY  # e-greedy threshold decay
GAMMA = args.GAMMA  # Q-learning discount factor
LR = args.LR  # NN optimizer learning rate
HIDDEN_LAYER = args.HIDDEN_LAYER  # NN hidden layer size
BATCH_SIZE = args.BATCH_SIZE  # Q-learning batch size
NUMBER_OF_REPEATS = 1  # independent training runs averaged in the final summary
ACTIVATION_FUNCTION = args.ACTIVATION_FUNCTION
REDUCE_LR_EP = args.REDUCE_LR_EP
CAN_STEP = True  # whether run_episode may still step the LR scheduler
wandb.init(project="DQN cartpole", config=args)

# Choose the CPU or CUDA legacy tensor constructors once. The rest of the
# script builds tensors through these aliases, so it runs on either device.
use_cuda = torch.cuda.is_available()
_tensor_namespace = torch.cuda if use_cuda else torch
FloatTensor = _tensor_namespace.FloatTensor
LongTensor = _tensor_namespace.LongTensor
ByteTensor = _tensor_namespace.ByteTensor
Tensor = FloatTensor


def _build_eval_linear_model():
    """Return a new 4-in/4-out ``LinearModel`` of width ``HIDDEN_LAYER``, switched to eval mode."""
    linear_model = LinearModel(hidden_size=HIDDEN_LAYER, input_size=4, output_size=4)
    linear_model.eval()
    return linear_model


# NOTE(review): these four models are not referenced anywhere else in the
# visible code. They are still constructed in the original order, so any RNG
# draws LinearModel makes during initialisation happen exactly as before.
backwards_model_canonical = _build_eval_linear_model()
backwards_model_lens = _build_eval_linear_model()
forwards_model_canonical = _build_eval_linear_model()
forwards_model_lens = _build_eval_linear_model()

# Train NUMBER_OF_REPEATS independent agents and summarise how quickly each
# one learned. run_episode / learn / select_action read model, memory,
# optimizer, lr_scheduler, steps_done, CAN_STEP and episode_durations as
# module globals, so they are rebound here at module level for every repeat.
episodes_taken = []  # per repeat: first episode index lasting > 175 steps
episodes_working_list = []  # per repeat: count of episodes lasting > 190 steps
for repeat in range(NUMBER_OF_REPEATS):
    env = gym.make("CartPole-v0")
    try:
        model = Network(ACTIVATION_FUNCTION)
        if use_cuda:
            model.cuda()
        memory = ReplayMemory(10000)
        optimizer = optim.Adam(model.parameters(), LR)
        # run_episode steps this scheduler exactly once (after REDUCE_LR_EP),
        # so step_size=1 divides the learning rate by 10 at that point. This
        # is the same schedule the former step_size=2 setup produced with its
        # priming lr_scheduler.step(). That priming step came before any
        # optimizer.step(), which PyTorch warns about.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        CAN_STEP = True
        steps_done = 0
        episode_durations = []

        for e in range(EPISODES):
            run_episode(e, env)
    finally:
        env.close()

    solved_at = None
    episodes_working = 0
    for index, an_episode_duration in enumerate(episode_durations):
        wandb.log({"episode length": an_episode_duration, "episode": index})
        if an_episode_duration > 190:
            episodes_working += 1
        if solved_at is None and an_episode_duration > 175:
            solved_at = index

    episodes_working_list.append(episodes_working)
    # A repeat that never exceeds 175 steps counts as needing every episode.
    episodes_taken.append(len(episode_durations) if solved_at is None else solved_at)

average_episodes_taken = np.mean(episodes_taken)
average_episodes_working = np.mean(episodes_working_list)
wandb.log({"average_episodes_taken": average_episodes_taken})
wandb.log({"average_episodes_working": average_episodes_working})
