# taken from: https://gist.github.com/Pocuston/13f1a7786648e1e2ff95bfad02a51521
# also based on: http://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
import gym
from gym import wrappers
import random
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import wandb
import numpy as np
from tqdm import tqdm
from models import NonLinearModel, LinearModel
import matplotlib.pyplot as plt
from utils import upload_plot_to_wandb
import gridworld
from wrappers import FlattenWrapper


class ReplayMemory:
    """Bounded FIFO buffer of transitions for experience replay.

    Transitions are kept in insertion order; once more than ``capacity``
    items have been pushed, the oldest one is discarded.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, transition):
        """Append ``transition``, evicting the oldest entry if over capacity."""
        self.memory.append(transition)
        if len(self.memory) > self.capacity:
            # O(len) shift, but keeps ``self.memory`` a plain list for callers.
            del self.memory[0]

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions chosen uniformly at random.

        Raises:
            ValueError: (from ``random.sample``) if ``batch_size`` exceeds the
                number of stored transitions.
        """
        return random.sample(self.memory, batch_size)

    def sample_specific_action(self, batch_size, desired_action):
        """Return up to ``batch_size`` stored transitions taken with ``desired_action``.

        Each transition is unpacked as ``(obs, action, next_obs, reward)``, and
        the oldest matching transitions are returned first. If the buffer holds
        fewer than ``batch_size`` matches, the shorter (possibly empty) list is
        returned. Previously the method returned ``None`` in that case.
        """
        if batch_size <= 0:
            return []
        filtered_data = []
        for a_datum in self.memory:
            _, action, _, _ = a_datum
            if action == desired_action:
                filtered_data.append(a_datum)
                if len(filtered_data) == batch_size:
                    break
        return filtered_data

    def __len__(self):
        return len(self.memory)


def select_action_lens_dict(state):
    """Greedy action choice routed through the ``transition_model`` lookup table.

    For each action index ``a`` in ``range(env.action_space.n)``:
      1. look up the predicted next state under ``"forward" + str(state) + str(a)``;
      2. map that prediction to an equivalent state under
         ``"backward" + str(prediction) + "0"``;
      3. score the equivalent state with output ``[0][0]`` of ``model``.

    The index with the highest score is returned as a ``[[action]]``
    ``LongTensor``.

    NOTE(review): ``transition_model`` is not defined in the visible part of
    this file; it must exist as a global before this function is called.
    """

    def _score(action_index):
        forward_key = "forward" + str(state) + str(action_index)
        predicted_state = transition_model[forward_key]
        backward_key = "backward" + str(predicted_state) + str(0)
        equivalent_state = transition_model[backward_key]
        output = model(torch.cuda.FloatTensor([equivalent_state]))
        return output[0][0].detach().cpu()

    scores = [_score(action_index) for action_index in range(env.action_space.n)]
    best = np.argmax(scores)
    return LongTensor([[best]])


def select_action(state):
    """Choose an action for ``state`` with an epsilon-greedy policy.

    The exploration probability decays exponentially from ``EPS_START``
    towards ``EPS_END`` with time constant ``EPS_DECAY``, measured in calls to
    this function (tracked by the global ``STEPS_DONE``).

    Args:
        state: observation tensor. It is cast to ``torch.cuda.FloatTensor``
            before being passed to ``model``. Presumably shaped
            ``(1, input_size)``; TODO confirm.

    Returns:
        A ``(1, 1)`` long tensor holding the chosen action index.
    """
    global STEPS_DONE
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(
        -1.0 * STEPS_DONE / EPS_DECAY
    )
    STEPS_DONE += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # Greedy: index of the largest Q-value along the action axis.
            return model(state.type(torch.cuda.FloatTensor)).max(1)[1].view(1, 1)
    # Explore uniformly over the env's actual action count. The previous
    # hard-coded 5 could pick invalid actions or never explore some actions.
    return torch.cuda.LongTensor([[random.randrange(env.action_space.n)]])


def learn():
    """Run one DQN optimisation step on a random minibatch from ``memory``.

    Does nothing until the buffer holds at least ``BATCH_SIZE`` transitions.
    The target is ``reward + GAMMA * max_a Q(next_state, a)``, computed with the
    same ``model`` that is being trained (detached, no separate target
    network). The loss is ``F.smooth_l1_loss`` between that target and the
    predicted Q-value of the action actually taken.
    """
    if len(memory) < BATCH_SIZE:
        return

    # Each stored transition is (state, action, next_state, reward); stack
    # every field along dim 0 to form the batch.
    states, actions, next_states, rewards = (
        torch.cat(column) for column in zip(*memory.sample(BATCH_SIZE))
    )

    predicted_q = model(states).gather(1, actions)[:, 0]
    bootstrap_q = model(next_states).detach().max(1)[0]
    # NOTE(review): transitions carry no "done" flag, so terminal next states
    # are bootstrapped like any other state.
    target_q = rewards + GAMMA * bootstrap_q

    loss = F.smooth_l1_loss(predicted_q, target_q)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def run_episode(e):
    """Play one episode in ``env``, training the model after every step.

    Each step chooses an action with ``select_action`` and applies it to
    ``env``. It then stores ``(state, action, next_state, reward)`` in
    ``memory`` and calls ``learn()``. When the episode ends, the undiscounted
    return is logged to wandb, the step count is appended to
    ``episode_durations``, and a coloured summary line is printed.

    Args:
        e: episode index, used only in the printed summary.
    """
    state = env.reset()
    steps = 0
    return_value = 0
    while True:
        # Built once per step and reused for storage. select_action only reads
        # it (``.type`` to the same dtype returns the same tensor), so this
        # saves a second host-to-GPU copy.
        state_tensor = torch.cuda.FloatTensor([state])
        action = select_action(state_tensor)
        next_state, reward, done, _ = env.step(action[0, 0].item())

        return_value += reward

        memory.push(
            (
                state_tensor,
                action,  # action is already a tensor
                torch.cuda.FloatTensor([next_state]),
                torch.cuda.FloatTensor([reward]),
            )
        )

        learn()

        state = next_state
        steps += 1

        if done:
            wandb.log({"return": return_value})
            episode_durations.append(steps)
            # NOTE(review): the 195-step "success" threshold looks inherited
            # from the CartPole gist this is based on; confirm it fits this env.
            colour = "\033[92m" if steps >= 195 else "\033[39m"
            # Trailing "\033[0m" resets attributes so the colour does not
            # leak into subsequent terminal output.
            print(
                "{2} Episode {0} finished after {1} steps\033[0m".format(
                    e, steps, colour
                )
            )
            break


import argparse

# Command-line options: only the learning rate is configurable.
parser = argparse.ArgumentParser()
parser.add_argument("--LR", type=float, default=1e-2)
args = parser.parse_args()

# Hyper-parameters; read as module globals by the functions above.
EPISODES = 30
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
GAMMA = 0.8
LR = args.LR
BATCH_SIZE = 64
NUMBER_OF_REPEATS = 1
wandb.init(project="predator prey")

# Tensor constructors for the available device.
use_cuda = torch.cuda.is_available()
if use_cuda:
    FloatTensor, LongTensor, ByteTensor = (
        torch.cuda.FloatTensor,
        torch.cuda.LongTensor,
        torch.cuda.ByteTensor,
    )
else:
    FloatTensor, LongTensor, ByteTensor = (
        torch.FloatTensor,
        torch.LongTensor,
        torch.ByteTensor,
    )
Tensor = FloatTensor

episodes_working_list = []
for _repeat in range(NUMBER_OF_REPEATS):
    # Fresh env, model, buffer and optimiser per repeat. These are globals
    # consumed by run_episode / select_action / learn.
    env = FlattenWrapper(gym.make("GridEnv-v1"))
    model = LinearModel(
        hidden_size=1024,
        input_size=env.size[0] * 3 * env.size[1] * 3,
        output_size=env.action_space.n,
    )
    model.cuda()
    memory = ReplayMemory(100000)
    optimizer = optim.Adam(model.parameters(), LR)
    CAN_STEP = True  # NOTE(review): not read in the visible code
    STEPS_DONE = 0
    episode_durations = []

    for episode_index in range(EPISODES):
        run_episode(episode_index)

    for index, duration in enumerate(episode_durations):
        wandb.log({"episode length": duration, "episode": index})
    # An episode counts as "working" when it finished in fewer than 10 steps.
    episodes_working = sum(1 for length in episode_durations if length < 10)
    env.close()

    episodes_working_list.append(episodes_working)

average_episodes_working = np.mean(episodes_working_list)
wandb.log({"average_episodes_working": average_episodes_working})