import os
import time
import logging
import pickle

from typing import Any, Dict, Tuple

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

import path_config

# https://github.com/tejaskhot/pytorch-LunarLander/blob/master/reinforce.py

# https://github.com/pytorch/examples/blob/master/imagenet/main.py#L139

# Log records are prefixed with the source line number and a timestamp.
# Level 15 sits between DEBUG (10) and INFO (20), so INFO messages are shown.
logging_format = "%(lineno)4s: %(asctime)s: %(message)s"
logging_level = 15
logging.basicConfig(format=logging_format, level=logging_level)

logger = logging.getLogger(__name__)


# Choose CUDA tensor constructors when a GPU is visible, CPU ones otherwise.
use_cuda = torch.cuda.is_available()
logger.info("use_cuda : {}".format(use_cuda))
if use_cuda:
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
    ByteTensor = torch.cuda.ByteTensor
else:
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor
    ByteTensor = torch.ByteTensor
# Default tensor type used throughout the script.
Tensor = FloatTensor


class Policy(nn.Module):
    """Small MLP mapping an observation vector to action probabilities.

    Architecture: three 16-unit ReLU layers followed by a softmax output
    layer with ``num_actions`` units. The attribute names ``fc1``..``fc4``
    become the keys of ``state_dict()``, which is what checkpoints store, so
    they must not be renamed or old checkpoints will stop loading.
    """

    def __init__(self, state_size, num_actions):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(state_size, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 16)
        self.fc4 = nn.Linear(16, num_actions)

    def forward(self, x):
        """Return action probabilities normalised over the last dimension.

        *x* may be one observation of shape ``(state_size,)`` or a batch of
        shape ``(..., state_size)``. Normalising over ``dim=-1`` yields one
        distribution per observation in both cases. For 1-D input this is
        identical to ``dim=0``, but ``dim=0`` would normalise across the batch
        for 2-D input.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.softmax(self.fc4(x), dim=-1)
        return x


def select_action(state: np.ndarray, policy) -> Tuple[np.ndarray, torch.Tensor]:
    """Sample one action from *policy* for a single observation.

    Args:
        state: observation (array-like of floats) as returned by the env.
        policy: callable mapping a float32 tensor to a probability vector
            over actions. When it is an ``nn.Module`` with parameters, the
            observation is placed on the same device as those parameters.

    Returns:
        ``(action, log_prob)``. ``action`` is the sampled action index as a
        0-d numpy array, suitable for ``env.step``. ``log_prob`` is its
        log-probability tensor, still attached to the autograd graph.
    """
    first_param = None
    if isinstance(policy, torch.nn.Module):
        first_param = next(policy.parameters(), None)
    device = first_param.device if first_param is not None else None
    state_t = torch.as_tensor(state, dtype=torch.float32, device=device)

    action_probs = policy(state_t)
    dist = torch.distributions.Categorical(probs=action_probs)
    action = dist.sample()
    # Taking log() of the whole probability vector and then indexing would
    # back-propagate 0 * (1 / 0) = NaN through every unchosen action whose
    # softmax probability underflowed to 0. Categorical.log_prob works on
    # clamped probabilities, so the gradient stays finite.
    action_log_prob = dist.log_prob(action)
    action_np = action.detach().cpu().numpy()
    return action_np, action_log_prob


def play_episode(env, policy, max_steps=None):
    """Roll out one episode with actions sampled from *policy*.

    Args:
        env: environment whose ``step`` returns ``(state, reward, done, info)``.
        policy: model passed to ``select_action``.
        max_steps: optional cap on the number of steps. ``None`` (the default)
            runs until the environment reports a terminal state.

    Returns:
        ``(steps, rewards, log_probs)``: the number of steps taken, the
        per-step rewards, and the per-step log-probabilities of the chosen
        actions as returned by ``select_action``.
    """
    state = env.reset()
    rewards = []
    log_probs = []
    while True:
        action, log_prob = select_action(state, policy)
        state, reward, is_terminal, _ = env.step(action)
        log_probs.append(log_prob)
        rewards.append(reward)
        if is_terminal or (max_steps is not None and len(rewards) >= max_steps):
            break
    return len(rewards), rewards, log_probs


def test_policy(env,
                policy,
                num_episodes: int,
                solved_threshold: float = 200.0) -> Tuple[list, list]:
    """Evaluate *policy* for *num_episodes* episodes and log the mean reward.

    Args:
        env: environment to play in.
        policy: model passed to ``play_episode``.
        num_episodes: number of evaluation episodes. With 0, nothing is played
            and empty lists are returned.
        solved_threshold: mean total reward at or above which a "Solved!"
            banner is logged (default 200).

    Returns:
        ``(testing_rewards, testing_steps)``: the total reward and the step
        count of each episode.
    """
    testing_rewards = []
    testing_steps = []
    # Evaluation never back-propagates, so don't build autograd graphs for
    # every step of every episode.
    with torch.no_grad():
        for _ in range(num_episodes):
            steps, rewards, _log_probs = play_episode(env, policy)
            testing_rewards.append(sum(rewards))
            testing_steps.append(steps)

    if not testing_rewards:
        # np.mean([]) would warn and yield nan; there is nothing to report.
        logger.info("No test episodes requested; skipping evaluation.")
        return testing_rewards, testing_steps

    mean_reward = np.mean(testing_rewards)
    logger.info("Mean reward achieved : {} ".format(mean_reward))
    logger.info("--------------------------------------------------")
    if mean_reward >= solved_threshold:
        logger.info("------------------ Solved! -----------------------")
        logger.info(
            "Mean reward achieved : {} in {} steps".format(
                mean_reward, np.mean(testing_steps)
            )
        )
        logger.info("-------------------------------------------------")
    return testing_rewards, testing_steps


def optimize(optimizer, rewards, log_probs, gamma, reward_scale: float = 1e-2) -> float:
    """Apply one REINFORCE update from a single finished episode.

    The loss is ``-mean_t(log_prob_t * G_t)``. ``G_t`` is the discounted
    return-to-go, ``G_t = reward_scale * r_t + gamma * G_{t+1}``.

    Args:
        optimizer: optimizer over the parameters that produced *log_probs*.
        rewards: per-step rewards of the episode.
        log_probs: per-step log-probabilities (one-element tensors attached
            to the autograd graph), aligned index-by-index with *rewards*.
        gamma: discount factor.
        reward_scale: multiplier applied to every reward. The default 1e-2
            downscales rewards to help training.

    Returns:
        The loss value computed before the optimizer step, as a Python float.

    Raises:
        ValueError: if the episode is empty or the two sequences differ in
            length.
    """
    num_steps = len(rewards)
    if num_steps == 0:
        raise ValueError("optimize() needs at least one reward")
    if len(log_probs) != num_steps:
        raise ValueError(
            "got {} rewards but {} log-probs".format(num_steps, len(log_probs)))

    # Discounted returns-to-go, accumulated backwards in plain Python floats
    # instead of one small tensor operation per step.
    returns = [0.0] * num_steps
    running_return = 0.0
    for i in reversed(range(num_steps)):
        # downscaling rewards by reward_scale (1e-2 by default) to help training
        running_return = gamma * running_return + rewards[i] * reward_scale
        returns[i] = running_return

    log_prob_t = torch.cat([lp.reshape(1) for lp in log_probs])
    returns_t = torch.tensor(returns, dtype=log_prob_t.dtype, device=log_prob_t.device)
    loss = -(log_prob_t * returns_t).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def _save_checkpoint(epoch: int, policy, optimizer) -> None:
    """Pickle the epoch number and the policy/optimizer state dicts for *epoch*.

    The checkpoint directory is created if needed. The data is first written
    to a hidden temp file, whose name does not start with "checkpoint_" and
    is therefore ignored by checkpoint discovery. It is then renamed into
    place, so an interrupted save never leaves a truncated checkpoint_*.pkl.
    """
    checkpoint_fullfilename = build_checkpoint_name(epoch)
    checkpoint_dir, checkpoint_basename = os.path.split(checkpoint_fullfilename)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = {
        "epoch": epoch,
        "policy_state_dict": policy.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    logger.info("Saving to {}".format(checkpoint_fullfilename))

    tmp_fullfilename = os.path.join(checkpoint_dir, "." + checkpoint_basename + ".tmp")
    with open(tmp_fullfilename, 'wb') as pklfile:
        pickle.dump(checkpoint, pklfile, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp_fullfilename, checkpoint_fullfilename)


def train_policy(env,
                 policy,
                 optimizer,
                 start_epoch: int,
                 end_epoch: int,
                 test_episodes,
                 gamma: float = .99,
                 save_freq: int = 1000,
                 test_freq: int = 500,
                 log_freq: int = 100) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Train *policy* with REINFORCE for epochs ``start_epoch .. end_epoch - 1``.

    Each epoch plays one episode and applies one ``optimize`` update. Every
    *log_freq* epochs the progress is logged and the loss/reward curves are
    redrawn on the current matplotlib figure. Every *test_freq* epochs the
    policy is evaluated over *test_episodes* episodes. Every *save_freq*
    epochs a checkpoint is written. None of these periodic actions run at
    epoch 0.

    Returns:
        ``(train_rewards, train_losses, train_steps)``, each of length
        *end_epoch*. Entries for epochs that were not run in this call
        (e.g. before *start_epoch* when resuming) stay NaN.
    """
    train_losses = np.full((end_epoch, ), np.nan)
    train_rewards = np.full((end_epoch, ), np.nan)
    train_steps = np.full((end_epoch, ), np.nan)

    for epoch in range(start_epoch, end_epoch):
        steps, rewards, log_probs = play_episode(env, policy)
        episode_total_reward = sum(rewards)
        this_loss = optimize(optimizer, rewards, log_probs, gamma)

        train_rewards[epoch] = episode_total_reward
        train_steps[epoch] = steps
        train_losses[epoch] = this_loss

        if epoch % log_freq == 0 and epoch > 0:
            to_log = "Episode: {:05d}, reward: {:+.4f}, steps: {}".format(epoch, episode_total_reward, steps)
            logger.info(to_log)
            # Clear first: otherwise every call stacks two more
            # end_epoch-long lines onto the same figure and memory grows
            # for the whole run.
            plt.clf()
            plt.plot(train_losses)
            plt.plot(train_rewards)

        if epoch % test_freq == 0 and epoch > 0:
            logger.info("----------- testing now ----------")
            test_policy(env, policy, test_episodes)

        if epoch % save_freq == 0 and epoch > 0:
            _save_checkpoint(epoch, policy, optimizer)
    return train_rewards, train_losses, train_steps


def episode_play_video(env, policy, max_steps: int = 1000, frame_delay: float = .1):
    """Play and render one episode for visual inspection.

    After every step the env is rendered and the loop sleeps *frame_delay*
    seconds. The single-step reward is printed every 100 steps. The episode
    stops at a terminal state or once *max_steps* steps have been played;
    ``>=`` is used so that exactly *max_steps* is the upper bound.

    Returns:
        ``(steps, rewards, log_probs)``: the step count, the per-step rewards,
        and the per-step log-probabilities from ``select_action``.
    """
    state = env.reset()
    steps = 0
    rewards = []
    log_probs = []

    while True:
        action, log_prob = select_action(state, policy)
        state, reward, is_terminal, _ = env.step(action)

        if steps % 100 == 0:
            print("Steps: {}, rewards = {}".format(steps, reward))
        env.render()
        time.sleep(frame_delay)
        log_probs.append(log_prob)
        rewards.append(reward)
        steps += 1
        if is_terminal or steps >= max_steps:
            break
    return steps, rewards, log_probs


def _get_checkpoint_dir() -> str:
    """Return the directory holding checkpoints: the "cached_calculations" entry of ``path_config.get_paths()``."""
    return path_config.get_paths()["cached_calculations"]


def build_checkpoint_name(epoch: int) -> str:
    """Return the full path of the checkpoint file for *epoch*.

    The file name is ``checkpoint_<epoch zero-padded to 5 digits>.pkl``,
    located in the directory returned by ``_get_checkpoint_dir``.
    """
    directory = _get_checkpoint_dir()
    return os.path.join(directory, "checkpoint_{:05d}.pkl".format(epoch))


def get_last_epoch_filename(checkpoint_filedir=None) -> str:
    """Return the file name (not path) of the highest-epoch checkpoint.

    Only names of the exact form ``checkpoint_<digits>.pkl`` are considered,
    and they are ranked by the integer epoch. A plain string ``max`` would
    rank ``checkpoint_99000.pkl`` above ``checkpoint_100000.pkl``, and would
    also pick up stray files that merely start with ``checkpoint_``.

    Args:
        checkpoint_filedir: directory to search. Defaults to
            ``_get_checkpoint_dir()``.

    Raises:
        FileNotFoundError: if the directory does not exist or contains no
            checkpoint files.
    """
    import re  # only needed here

    if checkpoint_filedir is None:
        checkpoint_filedir = _get_checkpoint_dir()

    pattern = re.compile(r"checkpoint_([0-9]+)\.pkl")
    candidates = []
    for filename in os.listdir(checkpoint_filedir):
        match = pattern.fullmatch(filename)
        if match:
            candidates.append((int(match.group(1)), filename))

    if not candidates:
        raise FileNotFoundError(
            "No checkpoint_<epoch>.pkl files found in {}".format(checkpoint_filedir))
    return max(candidates)[1]


def load_last_checkpoint() -> Dict[str, Any]:
    """Unpickle and return the newest checkpoint in the checkpoint directory.

    The newest file is chosen by ``get_last_epoch_filename``. The returned
    dict is whatever was pickled. ``train_policy`` writes the keys "epoch",
    "policy_state_dict" and "optimizer_state_dict".

    Note: ``pickle.load`` executes arbitrary code from the file. Only load
    checkpoints this script wrote itself.
    """
    newest_name = get_last_epoch_filename()
    newest_path = os.path.join(_get_checkpoint_dir(), newest_name)
    with open(newest_path, 'rb') as handle:
        return pickle.load(handle)


if __name__ == "__main__":
    end_epoch = 50000
    test_episodes = 100

    env_name = "LunarLander-v2"
    env = gym.make(env_name)

    seed = 123
    env.seed(seed)
    torch.manual_seed(seed)

    lr = 5e-4
    policy = Policy(env.observation_space.shape[0],
                    env.action_space.n)
    # Move the model to the GPU *before* building the optimizer and loading
    # its state. Optimizer.load_state_dict casts Adam's moment buffers to the
    # parameters' device at load time, so calling .cuda() afterwards would
    # leave them on the CPU and make optimizer.step() fail with a device
    # mismatch.
    if use_cuda:
        policy.cuda()
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
    start_epoch = 0

    use_last_checkpoint = True
    if use_last_checkpoint:
        try:
            last_checkpoint = load_last_checkpoint()
        except FileNotFoundError as err:
            # First run: no checkpoint directory/files yet, so start fresh.
            logger.info("No checkpoint to resume from ({}); starting from scratch".format(err))
        else:
            policy.load_state_dict(last_checkpoint['policy_state_dict'])
            optimizer.load_state_dict(last_checkpoint['optimizer_state_dict'])
            # The checkpoint is written after its epoch has been trained, so
            # resume with the following epoch rather than repeating it.
            start_epoch = last_checkpoint["epoch"] + 1
            logger.info("Resuming from epoch {}".format(start_epoch))

    play_video_before_starting_fit = False
    if play_video_before_starting_fit:
        episode_play_video(env, policy)
        env.close()

    try:
        train_rewards, train_losses, train_steps = train_policy(env,
                                                                policy,
                                                                optimizer,
                                                                start_epoch,
                                                                end_epoch,
                                                                test_episodes)
    finally:
        env.close()
