# taken from: https://gist.github.com/Pocuston/13f1a7786648e1e2ff95bfad02a51521
# also based on: http://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
import gym
from gym import wrappers
import random
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import wandb
import numpy as np
from tqdm import tqdm
from models import NonLinearModel, LinearModel
import matplotlib.pyplot as plt
from utils import upload_plot_to_wandb


class ReplayMemory:
    """Fixed-capacity FIFO store of transitions.

    Transitions are kept in insertion order in ``self.memory``; once the
    buffer exceeds ``capacity`` the oldest entry is dropped.
    ``sample_specific_action`` reads the action from position 1 of each
    transition, i.e. it expects ``(obs_current, action, obs_next, reward)``.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, transition):
        """Append *transition*, evicting the oldest entry when over capacity."""
        self.memory.append(transition)
        if len(self.memory) > self.capacity:
            del self.memory[0]

    def sample(self, batch_size):
        """Return *batch_size* distinct transitions chosen uniformly at random.

        Raises:
            ValueError: if *batch_size* exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def sample_specific_action(self, batch_size, desired_action):
        """Return a random sample of transitions whose action equals *desired_action*.

        At most *batch_size* transitions are returned. When fewer matching
        transitions are stored, all of them are returned (possibly an empty
        list) instead of ``None``, so callers always receive a list.
        """
        matching = [
            transition
            for transition in self.memory
            if transition[1] == desired_action
        ]
        # Sample rather than take the first matches, so repeated calls do not
        # keep handing back the same oldest transitions.
        return random.sample(matching, min(batch_size, len(matching)))

    def __len__(self):
        return len(self.memory)


class DQN:
    """Epsilon-greedy deep Q-learning agent for a discrete-action gym environment.

    Configuration keys read by this class: ``env``, ``policy_hidden_size``,
    ``dqn_lr``, ``replay_size``, ``episodes``, ``q_network_batch_size``,
    ``gamma``, ``eps_start``, ``eps_end`` and ``eps_decay``.

    The code uses the classic gym API: ``env.reset()`` returns the observation
    and ``env.step()`` returns a 4-tuple ``(obs, reward, done, info)``.
    """

    def __init__(self, config):
        self.config = config
        self.env = gym.make(self.config["env"])
        self.steps_done = 0  # number of actions selected so far (drives epsilon)
        self.pretrain_episodes = 0  # offset added to the episode index when logging

        self.q_network = NonLinearModel(
            hidden_size=self.config["policy_hidden_size"],
            input_size=self.env.observation_space.shape[0],
            output_size=self.env.action_space.n,
        )
        self.q_network_opt = optim.Adam(
            self.q_network.parameters(), self.config["dqn_lr"]
        )
        self.memory = ReplayMemory(self.config["replay_size"])
        self.episode_lengths = []

    def learn_q_network_weights(self):
        """Run ``config["episodes"]`` training episodes."""
        for e in range(self.config["episodes"]):
            self.run_episode(e + self.pretrain_episodes)

    def DQN_update(self):
        """Take one optimiser step on a random replay batch.

        Does nothing until the memory holds at least
        ``config["q_network_batch_size"]`` transitions. Terminal transitions
        are not masked: the bootstrap term is added for every sample, and the
        ``-1`` reward written by ``run_episode`` is the only terminal signal.
        """
        batch_size = self.config["q_network_batch_size"]
        if len(self.memory) < batch_size:
            return

        transitions = self.memory.sample(batch_size)
        states, actions, next_states, rewards = zip(*transitions)

        batch_state = torch.cat(states)
        batch_action = torch.cat(actions)
        batch_reward = torch.cat(rewards)
        batch_next_state = torch.cat(next_states)

        # Q(s, a) for the action actually taken in each transition.
        current_q_values = self.q_network(batch_state).gather(1, batch_action)
        # The bootstrap target is treated as a constant.
        with torch.no_grad():
            max_next_q_values = self.q_network(batch_next_state).max(1)[0]
        expected_q_values = batch_reward + (self.config["gamma"] * max_next_q_values)

        loss = F.smooth_l1_loss(current_q_values[:, 0], expected_q_values)

        self.q_network_opt.zero_grad()
        loss.backward()
        self.q_network_opt.step()

    def _epsilon_threshold(self):
        """Return the exploration probability for the current ``steps_done``."""
        cfg = self.config
        decay = math.exp(-1.0 * self.steps_done / cfg["eps_decay"])
        return cfg["eps_end"] + (cfg["eps_start"] - cfg["eps_end"]) * decay

    def select_action(self, state):
        """Pick an action for *state* with an exponentially decaying epsilon.

        Args:
            state: tensor that the Q-network maps to one row of action values
                (callers in this class pass shape ``(1, obs_dim)``).

        Returns:
            ``LongTensor`` of shape ``(1, 1)`` holding the action index.
        """
        sample = random.random()
        eps_threshold = self._epsilon_threshold()
        self.steps_done += 1
        if sample > eps_threshold:
            with torch.no_grad():
                return self.q_network(state.float()).max(1)[1].view(1, 1)
        # Explore over every action the environment offers, not just the first two.
        n_actions = int(self.env.action_space.n)
        return torch.LongTensor([[random.randrange(n_actions)]])

    def run_episode(self, e):
        """Play one episode, storing every transition and training after each step.

        Args:
            e: episode index, used only in the progress message.
        """
        state = self.env.reset()
        steps = 0
        while True:
            action = self.select_action(torch.FloatTensor([state]))
            next_state, reward, done, _ = self.env.step(action[0, 0].item())

            # Penalise the step that ends the episode.
            if done:
                reward = -1

            self.memory.push(
                (
                    torch.FloatTensor([state]),
                    action,  # action is already a tensor
                    torch.FloatTensor([next_state]),
                    torch.FloatTensor([reward]),
                )
            )
            # Without this call the Q-network would never be trained.
            self.DQN_update()

            state = next_state
            steps += 1

            if done:
                self.episode_lengths.append(steps)
                print(
                    "{2} Episode {0} finished after {1} steps".format(
                        e, steps, "\033[92m" if steps >= 195 else "\033[99m"
                    )
                )
                break


class HomoDQNLearnt(DQN):
    """DQN variant that relates the dynamics of two actions through a learnt "lens".

    Besides the Q-network inherited from ``DQN`` the agent owns four
    observation-to-observation ``LinearModel`` networks:

    * ``transition_model_forwards_lens`` is fitted to predict the next
      observation of transitions taken with ``config["lens_action"]``.
    * ``transition_model_backwards`` is fitted to predict the current
      observation from the next observation of transitions taken with
      ``config["canonical_action"]``.
    * ``lens`` is fitted so that ``lens(obs_t)`` matches
      ``transition_model_backwards(obs_tp1)`` on lens-action transitions.
    * ``transition_model_forwards`` is fitted but never consulted inside this
      class.

    Where ``can_lens`` accepts a state, a lens-action transition is rewritten
    as a canonical-action transition starting from the lensed state, and
    action values are read from the canonical output of the Q-network.

    Additional configuration keys: ``lens_hidden_size``,
    ``transition_model_hidden_size``, ``lens_learning_rate``,
    ``lens_tolerance``, ``lens_action``, ``canonical_action``,
    ``lens_batch_size`` and ``lens_training_steps``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.q_network_opt, step_size=2, gamma=0.1
        )
        # Prime the scheduler: with step_size=2 this first step leaves the
        # learning rate unchanged, so the single later step() issued from
        # run_episode is the one that applies the 10x decay.
        self.lr_scheduler.step()
        self.can_change_lr = True  # cleared once the one-off decay has happened
        self._reset_lens_models()
        self.trained_lens = False

    def _new_obs_model(self, hidden_size_key):
        """Build an observation-to-observation ``LinearModel`` and its Adam optimiser."""
        obs_dim = self.env.observation_space.shape[0]
        model = LinearModel(
            hidden_size=self.config[hidden_size_key],
            input_size=obs_dim,
            output_size=obs_dim,
        )
        optimiser = optim.Adam(
            params=model.parameters(),
            lr=self.config["lens_learning_rate"],
        )
        return model, optimiser

    def _reset_lens_models(self):
        """(Re)create the lens and the three transition models with fresh weights."""
        hidden_key = "transition_model_hidden_size"
        self.lens, self.lens_opt = self._new_obs_model("lens_hidden_size")
        (
            self.transition_model_forwards,
            self.transition_model_forwards_opt,
        ) = self._new_obs_model(hidden_key)
        (
            self.transition_model_forwards_lens,
            self.transition_model_forwards_lens_opt,
        ) = self._new_obs_model(hidden_key)
        (
            self.transition_model_backwards,
            self.transition_model_backwards_opt,
        ) = self._new_obs_model(hidden_key)

    def _epsilon_threshold(self):
        """Return the exploration probability for the current ``steps_done``."""
        cfg = self.config
        decay = math.exp(-1.0 * self.steps_done / cfg["eps_decay"])
        return cfg["eps_end"] + (cfg["eps_start"] - cfg["eps_end"]) * decay

    def select_action_naive(self, state):
        """Select an action with the plain epsilon-greedy policy of ``DQN``."""
        return super().select_action(state)

    def lens_an_observation(self, state):
        """Return ``self.lens(state)`` computed without tracking gradients."""
        with torch.no_grad():
            return self.lens(state)

    def can_lens(self, state):
        """Return True when the lens agrees with the transition models at *state*.

        The lens output is compared (mean squared error) with
        ``transition_model_backwards(transition_model_forwards_lens(state))``
        and accepted when the error is below ``config["lens_tolerance"]``.
        """
        with torch.no_grad():
            next_state = self.transition_model_forwards_lens(state)
            pred_lensed_state = self.transition_model_backwards(next_state)
            lensed_state = self.lens(state)
            diff = F.mse_loss(lensed_state, pred_lensed_state)
        return bool(diff < self.config["lens_tolerance"])

    def select_action(self, state):
        """Choose between the canonical and the lens action for *state*.

        Falls back to ``select_action_naive`` when the state cannot be lensed.
        Otherwise the greedy choice compares the canonical action value at
        *state* with the canonical action value at the lensed state: the
        canonical action wins when the former is strictly larger.

        Returns:
            ``LongTensor`` of shape ``(1, 1)`` holding the action index.
        """
        if not self.can_lens(state):
            return self.select_action_naive(state)
        lensed_state = self.lens_an_observation(state)

        sample = random.random()
        eps_threshold = self._epsilon_threshold()
        self.steps_done += 1
        if sample <= eps_threshold:
            n_actions = int(self.env.action_space.n)
            return torch.LongTensor([[random.randrange(n_actions)]])

        canonical = self.config["canonical_action"]
        with torch.no_grad():
            q_here = self.q_network(state.float())[0][canonical]
            q_lensed = self.q_network(lensed_state.float())[0][canonical]
        if q_here > q_lensed:
            return torch.LongTensor([[canonical]])
        return torch.LongTensor([[self.config["lens_action"]]])

    def lens_batch(self, batch_state, batch_action, batch_reward, batch_next_state):
        """Rewrite lens-action transitions as canonical ones where possible.

        ``batch_state`` and ``batch_action`` are modified in place. A
        lens-action transition is rewritten only when both its state and its
        next state pass ``can_lens``.

        Returns:
            The four batches plus a list of booleans, one per transition:
            False for lens-action transitions that could not be rewritten,
            True otherwise.
        """
        lens_action = self.config["lens_action"]
        canonical_action = self.config["canonical_action"]
        can_lens = []
        with torch.no_grad():
            for index, action in enumerate(batch_action):
                if action.item() != lens_action:
                    # NOTE(review): canonical transitions are flagged True
                    # without checking that the next state can be lensed.
                    can_lens.append(True)
                    continue
                lensable = self.can_lens(batch_state[index]) and self.can_lens(
                    batch_next_state[index]
                )
                if lensable:
                    batch_state[index] = self.lens(batch_state[index])
                    batch_action[index] = canonical_action
                can_lens.append(lensable)
        return batch_state, batch_action, batch_reward, batch_next_state, can_lens

    def DQN_update(self):
        """Take one optimiser step on a random replay batch, using the lens.

        For transitions flagged by ``lens_batch`` the bootstrap value is the
        larger of the canonical action value at the next state and at the
        lensed next state; otherwise it is the ordinary maximum over actions.
        The per-sample smooth-L1 losses are summed, not averaged.
        """
        batch_size = self.config["q_network_batch_size"]
        if len(self.memory) < batch_size:
            return

        transitions = self.memory.sample(batch_size)
        states, actions, next_states, rewards = zip(*transitions)
        (
            batch_state,
            batch_action,
            batch_reward,
            batch_next_state,
            can_lens,
        ) = self.lens_batch(
            torch.cat(states),
            torch.cat(actions),
            torch.cat(rewards),
            torch.cat(next_states),
        )

        canonical = self.config["canonical_action"]
        gamma = self.config["gamma"]
        loss = 0
        for index, lensable in enumerate(can_lens):
            current_q_value = self.q_network(batch_state[index])[
                batch_action[index]
            ][0]
            # Targets are constants with respect to the optimised parameters.
            with torch.no_grad():
                next_q = self.q_network(batch_next_state[index])
                if lensable:
                    lensed_next_q = self.q_network(
                        self.lens(batch_next_state[index])
                    )
                    max_next_q_value = torch.stack(
                        [next_q[canonical], lensed_next_q[canonical]]
                    ).max()
                else:
                    max_next_q_value = next_q.max()
            expected_q_value = batch_reward[index] + gamma * max_next_q_value
            loss = loss + F.smooth_l1_loss(current_q_value, expected_q_value)

        self.q_network_opt.zero_grad()
        loss.backward()
        self.q_network_opt.step()

    def train_lens(self):
        """Train the transition models and the lens from scratch on the replay memory.

        All four auxiliary models are re-created, fitted, put in eval mode and
        their weights saved (``forwards_model.pt``, ``forwards_model_lens.pt``,
        ``lens.pt``). The lens plots are uploaded and the Q-network is then
        trained on the stored transitions. Everything stays on the CPU, like
        the tensors held in the replay memory.
        """
        self._reset_lens_models()
        self.train_transition_models()
        self.transition_model_backwards.eval()
        self.transition_model_forwards.eval()
        self.transition_model_forwards_lens.eval()
        torch.save(self.transition_model_forwards.state_dict(), "forwards_model.pt")
        torch.save(
            self.transition_model_forwards_lens.state_dict(),
            "forwards_model_lens.pt",
        )
        self.train_lens_model()
        self.test_lens()
        upload_plot_to_wandb()
        self.lens.eval()
        torch.save(self.lens.state_dict(), "lens.pt")
        self.train_q_network_on_memories()

    def train_q_network_on_memories(self):
        """Run one ``DQN_update`` per batch-worth of environment steps recorded so far."""
        n_updates = sum(self.episode_lengths) // self.config["q_network_batch_size"]
        for _ in range(n_updates):
            self.DQN_update()

    def test_lens(self):
        """Save one scatter plot per feature of observations before vs. after the lens.

        Ten random transitions are plotted; files are named ``<feature>.png``.
        The feature names are those of a four-dimensional observation.
        """
        print("testing lens...")
        replay_batch = self.memory.sample(batch_size=10)
        obs_t, _ = self.process_replay_batch(replay_batch)
        with torch.no_grad():
            lensed = self.lens(obs_t)
        before = obs_t.detach().cpu().numpy()
        after = lensed.detach().cpu().numpy()
        feature_strings = ["Position", "Velocity", "Angle", "Angular Velocity"]
        # Never index past the observation width.
        for i, a_feature in enumerate(feature_strings[: before.shape[1]]):
            plt.scatter(before[:, i], after[:, i], marker=".", alpha=0.2)
            plt.title(a_feature)
            plt.xlabel("Before Lens")
            plt.ylabel("After Lens")
            plt.savefig(f"{a_feature}.png")
            plt.close()

    def process_replay_batch(self, replay_batch):
        """Stack the current and next observations of *replay_batch*.

        Returns:
            ``(current_obs, next_obs)``, each reshaped to ``(len(batch), -1)``.
        """
        current_obs = torch.stack([a_datum[0] for a_datum in replay_batch])
        next_obs = torch.stack([a_datum[2] for a_datum in replay_batch])
        current_obs = current_obs.reshape(current_obs.shape[0], -1)
        next_obs = next_obs.reshape(next_obs.shape[0], -1)
        return current_obs, next_obs

    def _sample_action_tensors(self, action):
        """Return ``(obs_t, obs_tp1)`` tensors for a replay batch taken with *action*.

        Raises:
            ValueError: if the memory cannot supply a batch for *action*.
        """
        replay_batch = self.memory.sample_specific_action(
            batch_size=self.config["lens_batch_size"],
            desired_action=action,
        )
        if not replay_batch:
            raise ValueError(
                f"replay memory cannot supply a batch of transitions for action {action}"
            )
        return self.process_replay_batch(replay_batch)

    @staticmethod
    def _regression_step(model, optimiser, inputs, targets):
        """Take one MSE gradient step on *model* and return the loss as a float."""
        optimiser.zero_grad()
        loss = F.mse_loss(model(inputs), targets)
        loss.backward()
        optimiser.step()
        return loss.item()

    def train_transition_models(self):
        """Fit the backwards, forwards and forwards-lens transition models.

        Each of ``config["lens_training_steps"]`` iterations takes one
        gradient step per model, every model with its own optimiser.
        """
        canonical_action = self.config["canonical_action"]
        lens_action = self.config["lens_action"]
        print("Training transition model...")
        for _ in tqdm(range(self.config["lens_training_steps"])):
            # Backwards model: next observation -> current observation, on
            # canonical-action transitions. This is the direction in which
            # can_lens and train_lens_model apply it.
            obs_t, obs_tp1 = self._sample_action_tensors(canonical_action)
            loss = self._regression_step(
                self.transition_model_backwards,
                self.transition_model_backwards_opt,
                obs_tp1,
                obs_t,
            )
            wandb.log({"transition model loss": loss})

            # NOTE(review): this step has always sampled lens-action data;
            # confirm whether the plain forwards model should use the
            # canonical action instead. The model is not used for inference.
            obs_t, obs_tp1 = self._sample_action_tensors(lens_action)
            loss = self._regression_step(
                self.transition_model_forwards,
                self.transition_model_forwards_opt,
                obs_t,
                obs_tp1,
            )
            wandb.log({"transition model loss": loss})

            obs_t, obs_tp1 = self._sample_action_tensors(lens_action)
            self._regression_step(
                self.transition_model_forwards_lens,
                self.transition_model_forwards_lens_opt,
                obs_t,
                obs_tp1,
            )

    def run_episode(self, e):
        """Play one episode, training the Q-network after every step.

        After the episode: the learning rate is decayed once, the first time
        an episode lasts more than 190 steps; and the transition models and
        the lens are trained once, as soon as the memory holds more than
        ``3 * config["lens_batch_size"]`` transitions.

        Args:
            e: episode index, used only in the progress message.
        """
        state = self.env.reset()
        steps = 0
        while True:
            action = self.select_action(torch.FloatTensor([state]))
            next_state, reward, done, _ = self.env.step(action[0, 0].item())

            # Penalise the step that ends the episode.
            if done:
                reward = -1

            self.memory.push(
                (
                    torch.FloatTensor([state]),
                    action,  # action is already a tensor
                    torch.FloatTensor([next_state]),
                    torch.FloatTensor([reward]),
                )
            )
            self.DQN_update()
            state = next_state
            steps += 1

            if not done:
                continue

            self.episode_lengths.append(steps)
            print(
                "{2} Episode {0} finished after {1} steps".format(
                    e, steps, "\033[92m" if steps >= 195 else "\033[99m"
                )
            )
            if steps > 190 and self.can_change_lr:
                print("stepped!")
                self.lr_scheduler.step()
                self.can_change_lr = False
            print(self.lr_scheduler.get_last_lr())

            enough_data = len(self.memory) > 3 * self.config["lens_batch_size"]
            if enough_data and not self.trained_lens:
                self.train_transition_models()
                self.transition_model_backwards.eval()
                self.transition_model_forwards.eval()
                # Also used for inference by can_lens.
                self.transition_model_forwards_lens.eval()
                self.train_lens_model()
                self.test_lens()
                self.lens.eval()
                self.trained_lens = True
            break

    def train_lens_model(self):
        """Fit the lens so that ``lens(obs_t)`` matches ``backwards(obs_tp1)``.

        Uses lens-action transitions for ``config["lens_training_steps"]``
        gradient steps. The backwards model only supplies targets and
        receives no gradients here.
        """
        lens_action = self.config["lens_action"]
        print("Training lens...")
        for _ in tqdm(range(self.config["lens_training_steps"])):
            obs_t_tensor, obs_tp1_tensor = self._sample_action_tensors(lens_action)
            with torch.no_grad():
                obs_t_tensor_alt = self.transition_model_backwards(obs_tp1_tensor)
            self._regression_step(
                self.lens, self.lens_opt, obs_t_tensor, obs_t_tensor_alt
            )
