# taken from: https://gist.github.com/Pocuston/13f1a7786648e1e2ff95bfad02a51521
# also based on: http://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
import gym
from gym import wrappers
import random
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import wandb
import numpy as np
from tqdm import tqdm
from models import NonLinearModel, LinearModel
import matplotlib.pyplot as plt
from utils import upload_plot_to_wandb
import gridworld
from wrappers import FlattenWrapper


class ReplayMemory:
    """Bounded FIFO buffer of transitions for experience replay.

    Transitions are kept oldest-first; once the buffer grows past
    ``capacity`` the oldest transition is evicted.
    """

    def __init__(self, capacity):
        # Maximum number of transitions retained.
        self.capacity = capacity
        # Backing list; index 0 holds the oldest transition.
        self.memory = []

    def push(self, transition):
        """Store ``transition``, dropping the oldest entry if over capacity."""
        self.memory.append(transition)
        over_capacity = len(self.memory) > self.capacity
        if over_capacity:
            self.memory.pop(0)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions chosen uniformly at random."""
        return random.sample(self.memory, batch_size)

    def sample_specific_action(self, batch_size, desired_action):
        """Return the oldest ``batch_size`` transitions whose action equals ``desired_action``.

        Each stored transition must unpack into four items, the second being
        the action. Returns ``None`` when the buffer holds fewer than
        ``batch_size`` matching transitions.
        """
        matches = []
        for transition in self.memory:
            _, taken_action, _, _ = transition
            if taken_action == desired_action:
                matches.append(transition)
            # Checked after every transition (matching or not), as originally.
            if len(matches) == batch_size:
                return matches
        return None

    def __len__(self):
        return len(self.memory)


def get_equivalent_state_torch(
    state, action, canonical_action=1, action_scale=5, device="cuda:0"
):
    """Map ``(state, action)`` to a state that reaches the same successor via ``canonical_action``.

    The pretrained forward model predicts the successor of ``state`` under
    ``action``; the pretrained backward model then predicts the predecessor of
    that successor under ``canonical_action``.

    Args:
        state: NumPy observation. It is concatenated with an equally-shaped
            action channel, so presumably a flat 441-element array to match
            the models' ``input_size=882`` — TODO confirm.
        action: action index, encoded as a constant channel
            ``action / action_scale``.
        canonical_action: action the backward model is conditioned on
            (defaults to 1, the value this function always used).
        action_scale: divisor applied to action indices before encoding
            (defaults to 5, the value this function always used).
        device: device holding ``model_forwards`` / ``model_backwards``.

    Returns:
        Tensor reshaped to ``(1, -1)`` on ``device``, without autograd history.
    """
    ones = np.ones_like(state)
    forward_input = np.concatenate((state, (ones * action) / action_scale))
    canonical_channel = torch.FloatTensor((ones * canonical_action) / action_scale).to(
        device
    )
    # The transition models are pretrained and never handed to the optimizer,
    # so tracking gradients through them only wastes memory and makes every
    # loss.backward() also compute gradients for their parameters.
    with torch.no_grad():
        next_state = model_forwards(torch.FloatTensor(forward_input).to(device))
        backward_input = torch.cat((next_state, canonical_channel))
        equivalent_state = model_backwards(backward_input).reshape(1, -1)
    return equivalent_state


def get_equivalent_state(state, action):
    """Dictionary-backed counterpart of ``get_equivalent_state_torch``.

    Looks up ``"forward" + str(state) + str(action)`` in ``transition_model``
    to get the predicted successor, then ``"backward" + str(successor) + "1"``
    to get its predecessor under canonical action 1.

    Returns:
        ``torch.cuda.FloatTensor`` wrapping the looked-up state in a leading
        batch dimension.
    """
    successor = transition_model["".join(("forward", str(state), str(action)))]
    predecessor = transition_model["".join(("backward", str(successor), str(1)))]
    return torch.cuda.FloatTensor([predecessor])


def select_action_lens(state):
    """Return the greedy action under the lens formulation.

    For every action, ``state`` is mapped to its equivalent state and scored
    by column 1 of the Q-network's output. The highest score wins; on ties,
    ``np.argmax`` picks the first.

    Returns:
        ``LongTensor`` of shape ``(1, 1)`` holding the chosen action index.
    """
    scores = [
        model(get_equivalent_state_torch(state, candidate))[0][1].detach().cpu()
        for candidate in range(env.action_space.n)
    ]
    return LongTensor([[np.argmax(scores)]])


def select_action(state):
    """Choose an action epsilon-greedily and advance the global step counter.

    The exploration threshold decays exponentially from ``EPS_START`` towards
    ``EPS_END`` with time constant ``EPS_DECAY`` steps. Greedy actions come
    from ``select_action_lens``; exploratory ones are uniform over the
    environment's action space.

    Returns:
        ``LongTensor`` of shape ``(1, 1)`` holding the chosen action index.
    """
    global STEPS_DONE
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(
        -1.0 * STEPS_DONE / EPS_DECAY
    )
    STEPS_DONE += 1
    if sample > eps_threshold:
        return select_action_lens(state)
    # Explore over the env's real action count rather than a hard-coded 5,
    # matching select_action_lens and the Q-network's output_size.
    return LongTensor([[random.randrange(env.action_space.n)]])


def process_replay_batch_forwards(replay_batch):
    """Build (input, target) tensors for training a forward transition model.

    Each transition ``(state, action, next_state, ...)`` contributes
    ``cat([one_hot(action), state], dim=1)`` as input and ``next_state`` as
    target. State tensors are presumably ``(1, obs_dim)`` as pushed by
    ``run_episode`` — TODO confirm.

    NOTE(review): this layout (one-hot first, then state) differs from the
    encoding ``get_equivalent_state_torch`` feeds the loaded forward model
    (state, then a constant ``action / 5`` channel).

    Returns:
        ``(state_t_with_action, state_tp1)``, each concatenated along dim 0.
    """
    inputs, targets = [], []
    for state, action, next_state, *_ in replay_batch:
        one_hot = torch.zeros((1, env.action_space.n))
        one_hot[:, action.item()] = 1
        inputs.append(torch.cat((one_hot.to("cuda:0"), state), dim=1))
        targets.append(next_state)
    return torch.cat(inputs), torch.cat(targets)


def process_replay_batch_backwards(replay_batch):
    """Build (target, input) tensors for training a backward transition model.

    Each transition ``(state, action, next_state, ...)`` contributes ``state``
    as target and ``cat([one_hot(action), next_state], dim=1)`` as input.

    NOTE(review): this layout (one-hot first, then next state) differs from
    the encoding ``get_equivalent_state_torch`` feeds the loaded backward
    model (next state, then a constant ``action / 5`` channel).

    Returns:
        ``(state_t, state_tp1_with_action)``, each concatenated along dim 0.
    """
    predecessors, inputs = [], []
    for state, action, next_state, *_ in replay_batch:
        one_hot = torch.zeros((1, env.action_space.n))
        one_hot[:, action.item()] = 1
        predecessors.append(state)
        inputs.append(torch.cat((one_hot.to("cuda:0"), next_state), dim=1))
    return torch.cat(predecessors), torch.cat(inputs)


def flip_state_actions(batch_state, batch_action):
    """Replace every (state, action) pair by an equivalent (state', 1) pair.

    Row ``i`` of ``batch_state`` and ``batch_action[i][0]`` are passed through
    ``get_equivalent_state_torch``; the returned action is always the
    canonical action 1.

    Returns:
        ``(states, actions)`` moved to the GPU: the concatenated equivalent
        states and an ``(N, 1)`` ``LongTensor`` of ones.
    """

    def _to_equivalent(index, action_row):
        taken = action_row.detach().cpu().numpy()[0]
        origin = batch_state[index].detach().cpu().numpy()
        return get_equivalent_state_torch(origin, taken)

    equivalent = [_to_equivalent(i, row) for i, row in enumerate(batch_action)]
    stacked = torch.cat(equivalent)
    canonical = torch.LongTensor([1] * len(equivalent)).unsqueeze(1)
    return stacked.cuda(), canonical.cuda()


def _max_next_q_values(batch_next_state):
    """Return, per next state, the max over actions of the canonical-action Q-value.

    For each action, every next state is mapped to its equivalent state and
    column 1 of the Q-network's output is read. This is a bootstrap target,
    so it is computed without autograd.

    Returns:
        1-D tensor with one value per row of ``batch_next_state``.
    """
    with torch.no_grad():
        # All equivalent states are built before any Q-network call, keeping
        # the original call order (relevant if a model has stochastic layers).
        per_action_states = [
            torch.cat(
                [
                    get_equivalent_state_torch(
                        next_state.detach().cpu().numpy(), an_action
                    )
                    for next_state in batch_next_state
                ],
                dim=0,
            )
            for an_action in range(env.action_space.n)
        ]
        per_action_values = [model(states)[:, 1] for states in per_action_states]
        return torch.stack(per_action_values).max(0)[0]


def learn():
    """Run one DQN update on a random minibatch once ``memory`` holds ``BATCH_SIZE`` items.

    Each sampled (state, action) is rewritten by ``flip_state_actions`` into an
    equivalent (state', canonical action) pair. The TD target is
    ``reward + GAMMA * max_a Q(equivalent(next_state, a))[1]``.
    """
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    # Each transition is (state, action, next_state, reward).
    batch_state, batch_action, batch_next_state, batch_reward = zip(*transitions)
    batch_state = torch.cat(batch_state)
    batch_action = torch.cat(batch_action)
    batch_reward = torch.cat(batch_reward)
    batch_next_state = torch.cat(batch_next_state)

    batch_state, batch_action = flip_state_actions(batch_state, batch_action)
    current_q_values = model(batch_state).gather(1, batch_action)

    # NOTE(review): transitions carry no ``done`` flag, so terminal next states
    # are bootstrapped like any other — confirm this is intended.
    max_next_q_values = _max_next_q_values(batch_next_state)
    expected_q_values = batch_reward + (GAMMA * max_next_q_values)

    loss = F.smooth_l1_loss(current_q_values[:, 0], expected_q_values)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def run_episode(e):
    """Play one episode, learning after every step, then log and print its result.

    Every step pushes ``(state, action, next_state, reward)`` into ``memory``
    and calls ``learn()``. At episode end, the return is logged to wandb and
    the length is appended to ``episode_durations``.

    Args:
        e: episode index, used only in the printed message.
    """
    state = env.reset()
    steps = 0
    return_value = 0
    while True:
        action = select_action(state)
        next_state, reward, done, _ = env.step(action[0, 0].item())

        return_value += reward

        memory.push(
            (
                torch.cuda.FloatTensor([state]),
                action,  # action is already a tensor
                torch.cuda.FloatTensor([next_state]),
                torch.cuda.FloatTensor([reward]),
            )
        )

        learn()

        state = next_state
        steps += 1

        if done:
            break

    wandb.log({"return": return_value})
    episode_durations.append(steps)
    # 195 is presumably the CartPole "solved" threshold from the example this
    # file was adapted from — review whether it is meaningful for this grid.
    colour = "\033[92m" if steps >= 195 else "\033[39m"
    # Reset attributes at the end so the green highlight cannot bleed into
    # later terminal output ("\033[99m" is not a defined SGR code).
    print("{2} Episode {0} finished after {1} steps\033[0m".format(e, steps, colour))


import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--LR", type=float, default=1e-3)
args = parser.parse_args()

# hyper parameters
EPISODES = 10  # number of episodes per repeat
EPS_START = 0.9  # e-greedy threshold start value
EPS_END = 0.05  # e-greedy threshold end value
EPS_DECAY = 200  # e-greedy threshold decay (in steps)
GAMMA = 0.8  # Q-learning discount factor
LR = args.LR  # Adam learning rate, settable via --LR
HIDDEN_LAYER = 64  # not used below; the Q-network is built with hidden_size=1024
BATCH_SIZE = 64  # Q-learning batch size
NUMBER_OF_REPEATS = 1  # independent runs, each with a fresh env/model/memory
# ACTIVATION_FUNCTION = "tanh"
CAN_STEP = True
LENS_BATCH_SIZE = 64
LENS_TRAINING_STEPS = int(2e4)
# NOTE(review): CANONICAL_ACITON (sic) is unused; the equivalence helpers
# above use canonical action 1, i.e. LENS_ACTION.
CANONICAL_ACITON = 0

LENS_ACTION = 1
TRANSITION_LR = 1e-3
MODELS_TRAINED = False
wandb.init(project="predator prey")

# The environment is created per repeat inside the loop below; an extra
# instance created here was never used and never closed.

# if gpu is to be used
use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
Tensor = FloatTensor
# NOTE(review): the models below are pinned to cuda:0 regardless of use_cuda,
# so this script requires a GPU.
model_forwards = NonLinearModel(input_size=882, output_size=441, hidden_size=512).to(
    "cuda:0"
)
model_forwards.load_state_dict(torch.load("model_forward.pt"))
model_backwards = NonLinearModel(input_size=882, output_size=441, hidden_size=512).to(
    "cuda:0"
)
model_backwards.load_state_dict(torch.load("model_backward.pt"))
# Lookup table used by get_equivalent_state. allow_pickle=True unpickles
# arbitrary objects, so only load trusted files here.
transition_model = np.load("model.npy", allow_pickle=True).item()

episodes_working_list = []
for repeat in range(NUMBER_OF_REPEATS):
    # Fresh environment, Q-network, replay memory and optimizer per repeat;
    # the functions above read these as module globals.
    env = FlattenWrapper(gym.make("GridEnv-v1"))
    try:
        model = LinearModel(
            hidden_size=1024,
            input_size=env.size[0] * 3 * env.size[1] * 3,
            output_size=env.action_space.n,
        )
        model.cuda()
        episodes_working = 0
        memory = ReplayMemory(100000)
        optimizer = optim.Adam(model.parameters(), LR)
        CAN_STEP = True
        STEPS_DONE = 0
        episode_durations = []

        for e in range(EPISODES):
            run_episode(e)

        for index, an_episode_duration in enumerate(episode_durations):
            wandb.log({"episode length": an_episode_duration, "episode": index})
            # Episodes shorter than 10 steps are counted as "working".
            if an_episode_duration < 10:
                episodes_working += 1
    finally:
        # Close the environment even if training raised.
        env.close()

    episodes_working_list.append(episodes_working)

average_episodes_working = np.mean(episodes_working_list)
wandb.log({"average_episodes_working": average_episodes_working})