# Solution of Open AI gym environment "Cartpole-v0" (https://gym.openai.com/envs/CartPole-v0) using DQN and Pytorch.
# It is a slightly modified version of the PyTorch DQN tutorial from
# http://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html.
# The main difference is that it does not take the rendered screen as input; it simply uses the observation values
# from the environment.
import wandb
import gym
from gym import wrappers
import random
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
from models import NonLinearModel, LinearModel
import numpy as np
from copy import deepcopy
from tqdm import tqdm


class ReplayMemory:
    """Fixed-capacity FIFO buffer of transitions.

    Transitions are stored in insertion order in ``self.memory``; once more
    than ``capacity`` items have been pushed the oldest one is discarded.
    Each transition is expected to be a 4-item sequence whose second item is
    the action (see ``sample_specific_action``).
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, transition):
        """Append ``transition`` and evict the oldest entry if over capacity."""
        self.memory.append(transition)
        if len(self.memory) > self.capacity:
            del self.memory[0]

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions chosen uniformly at random.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def sample_specific_action(self, batch_size, desired_action):
        """Return up to ``batch_size`` random transitions taken with ``desired_action``.

        The previous implementation always returned the *first* ``batch_size``
        matches (so repeated calls yielded the same batch) and fell off the end
        of the function, returning ``None``, when too few matches existed.

        Returns:
            A list of transitions. When fewer than ``batch_size`` transitions
            match, all of them are returned (possibly an empty list) instead
            of ``None``.
        """
        matching = [
            transition
            for transition in self.memory
            if transition[1] == desired_action
        ]
        if len(matching) <= batch_size:
            return matching
        return random.sample(matching, batch_size)

    def __len__(self):
        return len(self.memory)


class Network(nn.Module):
    """Three-layer fully connected network mapping a 4-feature input to one scalar.

    Args:
        activation_function: ``"relu"`` or ``"tanh"``; applied after the first
            two linear layers (the output layer is linear).
        hidden_size: width of the two hidden layers. Defaults to the
            module-level ``HIDDEN_LAYER`` when omitted.

    Raises:
        ValueError: if ``activation_function`` is not a supported name.
    """

    def __init__(self, activation_function, hidden_size=None):
        super().__init__()
        if hidden_size is None:
            hidden_size = HIDDEN_LAYER

        # Fail at construction time; previously an unknown name left the
        # attribute unset and only surfaced as an AttributeError in forward().
        if activation_function == "relu":
            activation = F.relu
        elif activation_function == "tanh":
            activation = torch.tanh
        else:
            raise ValueError(
                "activation_function must be 'relu' or 'tanh', got {!r}".format(
                    activation_function
                )
            )

        self.l1 = nn.Linear(4, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, 1)
        self.activation_function = activation

    def forward(self, x):
        x = self.activation_function(self.l1(x))
        x = self.activation_function(self.l2(x))
        x = self.l3(x)
        return x


def select_action_lens(state):
    """Pick the greedy action by comparing two scalar outputs of ``model``.

    Index 0 is scored by ``model`` on the state itself; index 1 is scored by
    ``model`` on the state obtained by passing the input through
    ``forwards_model_lens`` and then ``backwards_model_canonical``.

    Returns:
        A ``LongTensor`` of shape (1, 1) holding the index of the larger score
        (index 0 wins ties, as with ``np.argmax``).
    """
    observation = Variable(state).type(FloatTensor)

    score_direct = model(observation)[0][0]

    lens_successor = forwards_model_lens(observation)
    canonical_equivalent = backwards_model_canonical(lens_successor)
    score_via_lens = model(canonical_equivalent)[0][0]

    scores = [
        score.detach().cpu().numpy() for score in (score_direct, score_via_lens)
    ]
    return LongTensor([[np.argmax(scores)]])


def select_action(state):
    """Epsilon-greedy action selection with an exponentially decaying epsilon.

    Every call increments the global ``steps_done`` counter. With probability
    given by the current threshold a uniformly random action in {0, 1} is
    returned; otherwise the choice is delegated to ``select_action_lens``.

    Returns:
        A ``LongTensor`` of shape (1, 1) containing the chosen action.
    """
    global steps_done
    draw = random.random()
    decay = math.exp(-1.0 * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    # Explore whenever the draw does not exceed the threshold.
    if not draw > eps_threshold:
        return LongTensor([[random.randrange(2)]])
    return select_action_lens(state)


def run_episode(e, environment):
    """Play one episode, storing every transition and calling ``learn`` each step.

    Args:
        e: index of the episode (used for the learning-rate switch and logging).
        environment: object exposing ``reset()`` and a 4-tuple ``step(action)``.

    Side effects: pushes to the global ``memory``, appends the episode length
    to ``episode_durations``, and steps ``lr_scheduler`` at most once (guarded
    by the global ``CAN_STEP`` flag) after ``e`` exceeds ``REDUCE_LR_EP``.
    """
    global CAN_STEP
    observation = environment.reset()
    step_count = 0
    finished = False

    while not finished:
        chosen = select_action(FloatTensor([observation]))
        following, reward, finished, _ = environment.step(chosen[0, 0].item())

        # The step that ends the episode is penalised.
        if finished:
            reward = -1

        transition = (
            FloatTensor([observation]),
            chosen,  # already a tensor
            FloatTensor([following]),
            FloatTensor([reward]),
        )
        memory.push(transition)

        learn()

        observation = following
        step_count += 1

        # One-shot learning-rate step once the configured episode is passed.
        if e > REDUCE_LR_EP and CAN_STEP:
            lr_scheduler.step()
            print(lr_scheduler.get_last_lr())
            CAN_STEP = False

    colour_code = "\033[92m" if step_count >= 195 else "\033[99m"
    print(
        "{2} Episode {0} finished after {1} steps".format(e, step_count, colour_code)
    )
    episode_durations.append(step_count)


def get_opposite_actions(actions):
    """Rewrite every action in ``actions`` to 0, in place, and return it.

    Despite the name, this does not swap 0 and 1: action 0 is left as is and
    action 1 is replaced by ``action - 1`` (i.e. 0), so the result contains
    only zeros.

    Args:
        actions: mutable, indexable collection of actions (e.g. a list of ints
            or a tensor); it is modified in place.

    Returns:
        The same ``actions`` object.

    Raises:
        ValueError: if an action is neither 0 nor 1. Previously such a value
            either raised ``UnboundLocalError`` or silently reused the value
            computed for the preceding element.
    """
    for index, action in enumerate(actions):
        if action == 0:
            continue  # already 0; nothing to rewrite
        if action == 1:
            actions[index] = action - 1
        else:
            raise ValueError("unsupported action: {!r}".format(action))
    return actions


def flip_state_actions(states, actions):
    """Re-express every (state, action) pair as a pair whose action is 0.

    Pairs with action 0 keep their state. Pairs with action 1 have their state
    replaced by ``backwards_model_canonical(forwards_model_lens(state))``,
    the same composition used by ``select_action_lens`` and ``learn``. The
    original code applied the two models in the opposite order here.

    Args:
        states: iterable of per-sample state tensors (e.g. a batch tensor).
        actions: indexable collection of single-element action tensors; it is
            rewritten in place by ``get_opposite_actions``.

    Returns:
        ``(new_states, actions)`` where ``new_states`` is the per-sample states
        stacked along dim 0.

    Raises:
        ValueError: if an action is neither 0 nor 1.
    """
    new_states = []
    for index, a_state in enumerate(states):
        action_value = actions[index].item()
        if action_value == 0:
            new_states.append(a_state)
        elif action_value == 1:
            # NOTE(review): order changed to match the other two call sites;
            # confirm this is the intended "equivalent state" definition.
            new_states.append(backwards_model_canonical(forwards_model_lens(a_state)))
        else:
            raise ValueError
    actions = get_opposite_actions(actions)
    new_states = torch.stack(new_states, dim=0)
    return new_states, actions


def process_replay_batch(replay_batch):
    """Split a batch of transitions into flattened current/next observation tensors.

    Args:
        replay_batch: iterable of transitions; item 0 of each transition is the
            current observation tensor and item 2 the next observation tensor.

    Returns:
        ``(current_obs, next_obs)``: each is the corresponding tensors stacked
        along a new leading dimension and reshaped to ``(batch, -1)``.
    """
    rows = [list(transition) for transition in replay_batch]

    stacked_current = torch.stack([row[0] for row in rows])
    stacked_next = torch.stack([row[2] for row in rows])

    flat_current = stacked_current.reshape(stacked_current.shape[0], -1)
    flat_next = stacked_next.reshape(stacked_next.shape[0], -1)
    return flat_current, flat_next


def _fit_transition_step(transition_model, transition_opt, inputs, targets):
    """Run one MSE gradient step of ``transition_model`` mapping ``inputs`` to ``targets``."""
    transition_opt.zero_grad()
    loss = F.mse_loss(transition_model(inputs), targets)
    wandb.log({"transition model loss": loss})
    loss.backward()
    transition_opt.step()


def train_transition_model(canonical_or_lens):
    """Fit the forwards and backwards transition models for one action.

    Args:
        canonical_or_lens: ``"lens"`` trains the lens models on transitions
            taken with ``LENS_ACTION``; ``"canonical"`` trains the canonical
            models on transitions taken with ``CANONICAL_ACITON``.

    Raises:
        ValueError: if ``canonical_or_lens`` is neither ``"lens"`` nor
            ``"canonical"`` (previously this left the locals unbound).

    Fixes relative to the original:
      * the canonical branch used ``LENS_ACTION``, so both model families were
        trained on the same action's data;
      * the backwards model was fitted current -> next, exactly like the
        forwards model; it is now fitted next -> current.
    """
    print("Training transition model...")
    if canonical_or_lens == "lens":
        transition_model_forwards = forwards_model_lens
        transition_opt_forwards = forwards_model_lens_opt
        transition_model_backwards = backwards_model_lens
        transition_opt_backwards = backwards_model_lens_opt
        action = LENS_ACTION
    elif canonical_or_lens == "canonical":
        transition_model_forwards = forwards_model_canonical
        transition_opt_forwards = forwards_model_canonical_opt
        transition_model_backwards = backwards_model_canonical
        transition_opt_backwards = backwards_model_canonical_opt
        action = CANONICAL_ACITON  # (sic) spelling of the module-level constant
    else:
        raise ValueError(
            "canonical_or_lens must be 'lens' or 'canonical', got {!r}".format(
                canonical_or_lens
            )
        )

    for _ in tqdm(range(LENS_TRAINING_STEPS)):
        # Forwards model: predict the next observation from the current one.
        replay_batch = memory.sample_specific_action(
            batch_size=LENS_BATCH_SIZE,
            desired_action=action,
        )
        current_obs, next_obs = process_replay_batch(replay_batch)
        _fit_transition_step(
            transition_model_forwards, transition_opt_forwards, current_obs, next_obs
        )

        # Backwards model: predict the current observation from the next one.
        replay_batch = memory.sample_specific_action(
            batch_size=LENS_BATCH_SIZE,
            desired_action=action,
        )
        current_obs, next_obs = process_replay_batch(replay_batch)
        _fit_transition_step(
            transition_model_backwards, transition_opt_backwards, next_obs, current_obs
        )


def learn():
    """Run one optimisation step of ``model`` on a random replay batch.

    Does nothing until ``memory`` holds at least ``BATCH_SIZE`` transitions.
    On the first call that passes that check, the canonical and lens
    transition models are trained once (tracked by ``MODELS_TRAINED``).

    The target for each sample is ``reward + GAMMA * max(v0, v1)``, where
    ``v0`` is ``model`` evaluated on the next state and ``v1`` is ``model``
    evaluated on ``backwards_model_canonical(forwards_model_lens(next_state))``.
    """
    global MODELS_TRAINED
    if len(memory) < BATCH_SIZE:
        return

    if not MODELS_TRAINED:
        for which in ("canonical", "lens"):
            train_transition_model(canonical_or_lens=which)
        MODELS_TRAINED = True

    sampled = memory.sample(BATCH_SIZE)
    states, actions, next_states, rewards = zip(*sampled)
    states = Variable(torch.cat(states))
    actions = Variable(torch.cat(actions))
    rewards = Variable(torch.cat(rewards))
    next_states = Variable(torch.cat(next_states))

    # Re-express every sample so that its action index is 0 before gathering.
    states, actions = flip_state_actions(states, actions)
    predicted = model(states).gather(1, actions)

    lens_successors = forwards_model_lens(next_states)
    canonical_equivalents = backwards_model_canonical(lens_successors)
    value_direct = model(next_states)[:, 0].detach()
    value_via_lens = model(canonical_equivalents)[:, 0].detach()
    candidates = torch.stack((value_direct, value_via_lens), dim=1)
    best_next = candidates.max(1)[0]
    targets = rewards + (GAMMA * best_next)

    loss = F.smooth_l1_loss(predicted[:, 0], targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


import argparse

# Command-line options as (flag, type, default); registered in this order.
_CLI_OPTIONS = (
    ("--EPISODES", int, 30),
    ("--EPS_START", float, 0.9),
    ("--EPS_END", float, 0.05),
    ("--EPS_DECAY", int, 200),
    ("--GAMMA", float, 0.8),
    ("--LR", float, 1e-3),
    ("--HIDDEN_LAYER", int, 1024),
    ("--BATCH_SIZE", int, 64),
    ("--ACTIVATION_FUNCTION", str, "relu"),
    ("--REDUCE_LR_EP", int, 15),
)

parser = argparse.ArgumentParser()
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()

# Hyper parameters taken from the command line.
EPISODES = args.EPISODES  # number of episodes
EPS_START = args.EPS_START  # e-greedy threshold start value
EPS_END = args.EPS_END  # e-greedy threshold end value
EPS_DECAY = args.EPS_DECAY  # e-greedy threshold decay
GAMMA = args.GAMMA  # Q-learning discount factor
LR = args.LR  # NN optimizer learning rate
HIDDEN_LAYER = args.HIDDEN_LAYER  # NN hidden layer size
BATCH_SIZE = args.BATCH_SIZE  # Q-learning batch size
ACTIVATION_FUNCTION = args.ACTIVATION_FUNCTION
REDUCE_LR_EP = args.REDUCE_LR_EP

# Fixed settings.
NUMBER_OF_REPEATS = 10
LENS_BATCH_SIZE = 16
LENS_TRAINING_STEPS = int(2e4)
CANONICAL_ACITON = 0
LENS_ACTION = 1
TRANSITION_LR = 1e-3

# Mutable run-state flags, updated through `global` inside the functions above.
CAN_STEP = True
MODELS_TRAINED = False

wandb.init(project="DQN cartpole", config=args)

# Pick GPU tensor constructors when CUDA is available, CPU ones otherwise.
use_cuda = torch.cuda.is_available()
if use_cuda:
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
    ByteTensor = torch.cuda.ByteTensor
else:
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor
    ByteTensor = torch.ByteTensor
Tensor = FloatTensor


# Put the transition models on the GPU only when one is available. The
# original moved them to "cuda" unconditionally, which fails on CPU-only
# machines even though `use_cuda` is computed above.
transition_device = torch.device("cuda" if use_cuda else "cpu")


def _make_transition_model():
    """Build one 4-in/4-out ``LinearModel`` and its Adam optimizer.

    The model is switched to eval mode and moved to ``transition_device``.

    Returns:
        ``(model, optimizer)``.
    """
    net = LinearModel(hidden_size=HIDDEN_LAYER, input_size=4, output_size=4)
    opt = optim.Adam(net.parameters(), TRANSITION_LR)
    net.eval()
    net.to(transition_device)
    return net, opt


# Construction order is kept as in the original so parameter initialisation
# consumes the random number generator in the same sequence.
backwards_model_canonical, backwards_model_canonical_opt = _make_transition_model()
backwards_model_lens, backwards_model_lens_opt = _make_transition_model()
forwards_model_canonical, forwards_model_canonical_opt = _make_transition_model()
forwards_model_lens, forwards_model_lens_opt = _make_transition_model()

# Per-repeat index of the first episode longer than 175 steps (or the episode
# count when no episode got that far).
episodes_taken = []
# NOTE(review): nothing ever appends to this list, so its mean below is NaN
# (with a NumPy warning) — confirm what was meant to be collected here.
episodes_working_list = []

for repeat in range(NUMBER_OF_REPEATS):
    # Fresh environment, value network, replay buffer and optimizer per repeat.
    env = gym.make("CartPole-v0")
    model = Network(ACTIVATION_FUNCTION)
    if use_cuda:
        model.cuda()
    memory = ReplayMemory(10000)
    optimizer = optim.Adam(model.parameters(), LR)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
    # The scheduler is stepped once up front; run_episode steps it one more time.
    lr_scheduler.step()
    CAN_STEP = True
    steps_done = 0
    episode_durations = []

    for e in range(EPISODES):
        run_episode(e, env)

    first_long_episode = None
    for index, duration in enumerate(episode_durations):
        wandb.log({"episode length": duration, "episode": index})
        if first_long_episode is None and duration > 175:
            first_long_episode = index

    if first_long_episode is None:
        episodes_taken.append(len(episode_durations))
    else:
        episodes_taken.append(first_long_episode)
    env.close()

average_episodes_taken = np.mean(episodes_taken)
average_episodes_working = np.mean(episodes_working_list)
wandb.log({"average episodes taken": average_episodes_taken})
wandb.log({"average_episodes_working": average_episodes_working})
