import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

import argparse
from distutils.util import strtobool
import numpy as np
import gym
import gym_microrts
from gym.wrappers import TimeLimit, Monitor

from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnvWrapper

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

if __name__ == "__main__":
    # Command-line interface of the experiment. Boolean flags accept an
    # optional explicit value ("--flag", "--flag true", "--flag 0", ...).
    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
    # NOTE: the extension is removed with os.path.splitext; str.rstrip(".py")
    # strips any trailing '.', 'p' or 'y' characters and would turn e.g.
    # "ppo_happy.py" into "ppo_ha".
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],
                        help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsRandomEnemyHRL3-v1",
                        help='the id of the gym environment')
    parser.add_argument('--learning-rate', type=float, default=2.5e-4,
                        help='the learning rate of the optimizer')
    parser.add_argument('--seed', type=int, default=1,
                        help='seed of the experiment')
    parser.add_argument('--total-timesteps', type=int, default=10000000,
                        help='total timesteps of the experiments')
    parser.add_argument('--torch-deterministic', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
                        help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--cuda', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
                        help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
                        help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
                        help='weather to capture videos of the agent performances (check out `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                        help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                        help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--n-minibatch', type=int, default=4,
                        help='the number of mini batch')
    parser.add_argument('--num-envs', type=int, default=8,
                        help='the number of parallel game environment')
    parser.add_argument('--num-steps', type=int, default=128,
                        help='the number of steps per game environment')
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.95,
                        help='the lambda for the general advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                        help="coefficient of the entropy")
    parser.add_argument('--vf-coef', type=float, default=0.5,
                        help="coefficient of the value function")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                        help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.1,
                        help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=4,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
                        help='If toggled, the policy updates will be early stopped w.r.t target-kl')
    parser.add_argument('--kle-rollback', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
                        help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl')
    parser.add_argument('--target-kl', type=float, default=0.03,
                        help='the target-kl variable that is referred by --kl')
    parser.add_argument('--gae', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
                        help='Use GAE for advantage computation')
    parser.add_argument('--norm-adv', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
                         help="Toggles advantages normalization")
    parser.add_argument('--anneal-lr', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
                         help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--clip-vloss', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
                         help='Toggles wheter or not to use a clipped loss for the value function, as per the paper.')

    # Exploration schedules (epsilon drives which policy acts, see the rollout loop)
    parser.add_argument('--start-e', type=float, default=1.0,
                        help="the starting epsilon for exploration")
    parser.add_argument('--end-e', type=float, default=0.00,
                        help="the ending epsilon for exploration")
    parser.add_argument('--start-a', type=float, default=1.0,
                        help="the starting alpha for exploration")
    parser.add_argument('--end-a', type=float, default=0.8,
                        help="the ending alpha for exploration")
    parser.add_argument('--exploration-fraction', type=float, default=0.8,
                        help="the fraction of `total-timesteps` it takes from start-e to go end-e")

    # The help texts of the next two options used to be a copy of the
    # --n-minibatch text; they now describe how the training loop uses them.
    parser.add_argument('--shift', type=int, default=400000,
                        help='the number of global steps during which the policy-selection epsilon is held at its initial value')
    parser.add_argument('--adaptation', type=int, default=1000000,
                        help='the number of global steps over which epsilon is annealed to end-e after the shift phase')
    parser.add_argument('--positive-likelihood', type=float, default=0.0,
                        help="the likelihood to still do update if there are no positive reward in the trajectory")
    args = parser.parse_args()
    # A falsy seed (0) requests a time-based seed.
    if not args.seed:
        args.seed = int(time.time())

# Derived sizes: one batch is a full rollout over all environments, split
# evenly into `n_minibatch` minibatches. These run at module level and need
# `args`, which only exists when the file is executed as a script.
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.n_minibatch)

class ImageToPyTorch(gym.ObservationWrapper):
    """Observation wrapper that permutes observation axes with (2, 0, 1).

    The advertised observation space is rebuilt so that the last entry of the
    wrapped space's shape comes first, followed by its first two entries.
    Presumably this converts height-width-channel observations to
    channel-first layout for the convolutional network - confirm against the
    environment.
    """

    def __init__(self, env):
        super().__init__(env)
        src_shape = self.observation_space.shape
        permuted_shape = (src_shape[-1], src_shape[0], src_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0, high=1, shape=permuted_shape, dtype=np.int32)

    def observation(self, observation):
        """Return `observation` with its axes reordered to (2, 0, 1)."""
        return np.transpose(observation, (2, 0, 1))

class VecPyTorch(VecEnvWrapper):
    """Vectorized-environment wrapper that talks torch tensors to the caller.

    Observations coming out of the wrapped env are converted from numpy to
    float tensors on `device`; actions going in are converted from tensors to
    numpy arrays.
    """

    def __init__(self, venv, device):
        super().__init__(venv)
        self.device = device

    def reset(self):
        """Reset the wrapped env and return the observations as a float tensor on `device`."""
        return torch.from_numpy(self.venv.reset()).float().to(self.device)

    def step_async(self, actions):
        """Forward tensor `actions` to the wrapped env as a numpy array."""
        self.venv.step_async(actions.cpu().numpy())

    def step_wait(self):
        """Return (obs, reward, done, info) with obs and reward as float tensors.

        The reward gains a trailing dimension of size 1 and, unlike the
        observation, is not moved to `device`. `done` and `info` are passed
        through untouched.
        """
        raw_obs, raw_reward, done, info = self.venv.step_wait()
        obs_tensor = torch.from_numpy(raw_obs).float().to(self.device)
        reward_tensor = torch.from_numpy(raw_reward).unsqueeze(dim=1).float()
        return obs_tensor, reward_tensor, done, info

class MicroRTSStatsRecorder(gym.Wrapper):
    """Accumulates `info["raw_rewards"]` per step and reports episode totals.

    When an episode ends, the per-step raw rewards collected since the last
    reset/episode end are summed along axis 0 and stored in
    `info['microrts_stats']`, keyed by `str(rf)` for each entry of `self.rfs`
    (an attribute not set in this class; presumably provided by the wrapped
    environment).
    """

    def reset(self, **kwargs):
        """Reset the wrapped env and clear the accumulated raw rewards."""
        first_obs = super().reset(**kwargs)
        self.raw_rewards = []
        return first_obs

    def step(self, action):
        """Step the wrapped env, recording raw rewards and final statistics."""
        observation, reward, done, info = super().step(action)
        self.raw_rewards += [info["raw_rewards"]]
        if not done:
            return observation, reward, done, info

        # Episode finished: publish the per-reward-function totals and start over.
        totals = np.array(self.raw_rewards).sum(0)
        labels = [str(rf) for rf in self.rfs]
        info['microrts_stats'] = dict(zip(labels, totals))
        self.raw_rewards = []
        return observation, reward, done, info


# TRY NOT TO MODIFY: setup the environment
# The run name combines the environment id, experiment name, seed and the
# current Unix time; it is used for the TensorBoard directory below.
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
# Log every parsed argument as a two-column markdown table.
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
        '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))
if args.prod_mode:
    # wandb is imported lazily so it is only required in production mode.
    import wandb
    wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, sync_tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True, save_code=True)
    # `writer` is rebound to a new SummaryWriter under /tmp; the
    # hyperparameter table above was written to the first writer only.
    writer = SummaryWriter(f"/tmp/{experiment_name}")

# TRY NOT TO MODIFY: seeding
# CUDA is used only when it is both available and requested via --cuda.
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
# Seed the Python, NumPy and PyTorch random number generators with the same seed.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
def make_env(gym_id, seed, idx):
    """Return a zero-argument factory that builds one wrapped, seeded env.

    The environment created from `gym_id` is wrapped, innermost first, by
    ImageToPyTorch, RecordEpisodeStatistics and MicroRTSStatsRecorder. When
    `args.capture_video` is set, the environment with `idx == 0` is also
    wrapped in a Monitor writing to `videos/<experiment_name>`. The env and
    its action/observation spaces are seeded with `seed`.
    """
    def thunk():
        env = MicroRTSStatsRecorder(
            gym.wrappers.RecordEpisodeStatistics(
                ImageToPyTorch(gym.make(gym_id))))
        # Only the first environment records videos.
        if args.capture_video and idx == 0:
            env = Monitor(env, f'videos/{experiment_name}')
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env
    return thunk
# Build `args.num_envs` environments (environment i is seeded with
# args.seed + i), wrap them in a DummyVecEnv and expose them through the
# tensor-based VecPyTorch interface.
envs = VecPyTorch(DummyVecEnv([make_env(args.gym_id, args.seed+i, i) for i in range(args.num_envs)]), device)
# Disabled alternative kept for reference: SubprocVecEnv in production mode.
# if args.prod_mode:
#     envs = VecPyTorch(
#         SubprocVecEnv([make_env(args.gym_id, args.seed+i, i) for i in range(args.num_envs)], "fork"),
#         device
#     )
# The agent splits its logits by `action_space.nvec`, hence the restriction.
assert isinstance(envs.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"

# ALGO LOGIC: initialize agent here:
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linearly interpolate from `start_e` to `end_e` over `duration` steps.

    Args:
        start_e: value of the schedule at `t == 0`.
        end_e: value the schedule reaches at `t == duration` and holds afterwards.
        duration: number of steps the interpolation takes. A non-positive
            duration means the schedule is already finished.
        t: current step.

    Returns:
        The scheduled value at step `t`, never moving past `end_e`. Both
        decreasing (`start_e > end_e`) and increasing (`start_e < end_e`)
        schedules are supported.
    """
    if duration <= 0:
        # Nothing to interpolate over; avoids a division by zero.
        return end_e
    slope = (end_e - start_e) / duration
    value = slope * t + start_e
    # Clamp at the end value in whichever direction the schedule moves.
    if slope <= 0:
        return max(value, end_e)
    return min(value, end_e)

class CategoricalMasked(Categorical):
    """Categorical distribution that can exclude entries through a mask.

    Args:
        probs: forwarded to `Categorical` (only usable without `masks`).
        logits: unnormalized log-probabilities; required when `masks` is given.
        validate_args: forwarded to `Categorical`.
        masks: optional tensor with the same shape as `logits`; entries that
            convert to True are kept, all others have their logit replaced by
            -1e8 before normalization. `None` or an empty sequence disables
            masking, in which case the class behaves like `Categorical`.
    """

    def __init__(self, probs=None, logits=None, validate_args=None, masks=None):
        if masks is None or len(masks) == 0:
            # An empty list marks the "no masking" mode checked in entropy().
            self.masks = []
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
            return
        # Follow the logits' device/dtype rather than a module-level global.
        self.masks = masks.to(device=logits.device, dtype=torch.bool)
        fill_value = torch.tensor(-1e+8, dtype=logits.dtype, device=logits.device)
        logits = torch.where(self.masks, logits, fill_value)
        super(CategoricalMasked, self).__init__(probs, logits, validate_args)

    def entropy(self):
        """Return the entropy, with masked-out entries contributing zero."""
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        # Drop the masked entries' terms explicitly instead of relying on
        # their probability being exactly zero.
        p_log_p = torch.where(self.masks, p_log_p, p_log_p.new_zeros(()))
        return -p_log_p.sum(-1)

class Scale(nn.Module):
    """Module that multiplies its input by a fixed factor."""

    def __init__(self, scale):
        super(Scale, self).__init__()
        # Stored as a plain attribute, not as a parameter or buffer.
        self.scale = scale

    def forward(self, x):
        """Return `x` multiplied by the stored factor."""
        scaled = x * self.scale
        return scaled

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialize `layer` in place and return it.

    Args:
        layer: module exposing a `weight` tensor and, optionally, a `bias`.
        std: gain passed to the orthogonal weight initialization.
        bias_const: constant every bias element is set to.

    Returns:
        The same `layer` object, so the call can be used inline when
        building a network.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with bias=False expose `bias` as None; leave them alone.
    if getattr(layer, "bias", None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class Agent(nn.Module):
    """Actor-critic network with a shared convolutional feature extractor.

    The feature extractor takes 27-channel inputs and produces 128 features;
    the flatten size of 32*3*3 fixes the spatial size the convolutions must
    output (presumably a 10x10 map - confirm against the environment). The
    actor head emits one logit per entry of every component of the global
    `envs` MultiDiscrete action space; the critic head emits a scalar value.
    The `frames` argument is accepted but not used.
    """

    def __init__(self, frames=4):
        super().__init__()
        feature_layers = [
            layer_init(nn.Conv2d(27, 16, kernel_size=3, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(16, 32, kernel_size=2)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(32*3*3, 128)),
            nn.ReLU(),
        ]
        self.network = nn.Sequential(*feature_layers)
        n_logits = envs.action_space.nvec.sum()
        self.actor = layer_init(nn.Linear(128, n_logits), std=0.01)
        self.critic = layer_init(nn.Linear(128, 1), std=1)

    def forward(self, x):
        """Return the shared features for observations `x`."""
        return self.network(x)

    def get_action(self, x, action=None, invalid_action_masks=None):
        """Sample (or score) a multi-discrete action for observations `x`.

        Args:
            x: batch of observations.
            action: optional actions to score, indexed by action component
                along the first dimension; sampled when omitted.
            invalid_action_masks: optional mask laid out like the
                concatenated logits; when given, each component uses a
                CategoricalMasked distribution.

        Returns:
            Tuple of the action (component-first), the log-probability summed
            over components, and the entropy summed over components.
        """
        component_sizes = envs.action_space.nvec.tolist()
        flat_logits = self.actor(self.forward(x))
        logit_chunks = torch.split(flat_logits, component_sizes, dim=1)

        if invalid_action_masks is None:
            dists = [Categorical(logits=chunk) for chunk in logit_chunks]
        else:
            mask_chunks = torch.split(invalid_action_masks, component_sizes, dim=1)
            dists = [
                CategoricalMasked(logits=chunk, masks=mask)
                for chunk, mask in zip(logit_chunks, mask_chunks)
            ]

        if action is None:
            action = torch.stack([dist.sample() for dist in dists])
        per_component_logprob = torch.stack(
            [dist.log_prob(component) for component, dist in zip(action, dists)])
        per_component_entropy = torch.stack([dist.entropy() for dist in dists])
        return action, per_component_logprob.sum(0), per_component_entropy.sum(0)

    def get_value(self, x):
        """Return the critic's value estimate for observations `x`."""
        return self.critic(self.forward(x))


# One agent/optimizer pair per reward function reported by the first env.
num = envs.get_attr("num_reward_function")[0]
agents = [Agent().to(device) for _ in range(num)]
optimizers = [
    optim.Adam(policy.parameters(), lr=args.learning_rate, eps=1e-5)
    for policy in agents
]
if args.anneal_lr:
    # https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/defaults.py#L20
    def lr(f):
        return f * args.learning_rate

# ALGO Logic: Storage for epoch data
# Observations, actions, log-probs and masks are shared by all agents;
# rewards, dones and values are kept separately per agent.
rollout_shape = (args.num_steps, args.num_envs)
obs = torch.zeros(rollout_shape + envs.observation_space.shape, device=device)
actions = torch.zeros(rollout_shape + envs.action_space.shape, device=device)
logprobs = torch.zeros(rollout_shape, device=device)
# AC logic:
rewardss = [torch.zeros(rollout_shape, device=device) for _ in range(num)]
doness = [torch.zeros(rollout_shape, device=device) for _ in range(num)]
valuess = [torch.zeros(rollout_shape, device=device) for _ in range(num)]
invalid_action_masks = torch.zeros(rollout_shape + (envs.action_space.nvec.sum(),), device=device)
# TRY NOT TO MODIFY: start the game
global_step = 0
# Note how `next_obs` and `next_done` are used; their usage is equivalent to
# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/84a7582477fb0d5c82ad6d850fe476829dddd2e1/a2c_ppo_acktr/storage.py#L60
next_obs = envs.reset()
next_dones = [torch.zeros(args.num_envs, device=device) for _ in range(num)]
num_updates = args.total_timesteps // args.batch_size
# Main training loop: each update collects `num_steps` transitions from every
# environment, then runs a PPO update for each agent on its own reward signal.
for update in range(1, num_updates+1):
    # Annealing the rate if instructed to do so.
    if args.anneal_lr:
        frac = 1.0 - (update - 1.0) / num_updates
        lrnow = lr(frac)
        for optimizer in optimizers:
            optimizer.param_groups[0]['lr'] = lrnow

    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(0, args.num_steps):
        # envs.env_method("render", indices=0)
        global_step += 1 * args.num_envs
        obs[step] = next_obs
        for i in range(num):
            doness[i][step] = next_dones[i]
        invalid_action_masks[step] = torch.Tensor(np.array(envs.get_attr("action_mask")))

        # ALGO LOGIC: put action logic here
        # AC logic:
        # NOTE: `alpha` is computed but not consumed anywhere in this loop.
        alpha = linear_schedule(args.start_a, args.end_a, args.exploration_fraction*args.total_timesteps, global_step)
        # Epsilon is the probability of acting with the auxiliary policy
        # (index 1) instead of the main policy (index 0): it is held at 0.95
        # for the first `shift` global steps, then annealed to `end_e` over
        # `adaptation` steps.
        if 0 <= global_step <= args.shift:
            epsilon = 0.95
        else:
            epsilon = linear_schedule(0.95, args.end_e, args.adaptation, global_step-args.shift)
        if random.random() < epsilon:
            # randint(1, 2) always yields 1; kept as a draw so the NumPy RNG
            # stream stays the same.
            selected_policy_idx = np.random.randint(1, 2)
        else:
            selected_policy_idx = 0

        with torch.no_grad():
            # Every agent records its own value estimate; only the selected
            # policy chooses the action that is actually executed.
            for i in range(num):
                valuess[i][step] = agents[i].get_value(obs[step]).flatten()
            action, logproba, _ = agents[selected_policy_idx].get_action(obs[step], invalid_action_masks=invalid_action_masks[step])

        actions[step] = action.T
        logprobs[step] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rs, ds, infos = envs.step(action.T)
        # Per-reward-function rewards come from the info dicts; convert them
        # (and the done flags) once per step instead of once per agent.
        step_rewards = torch.tensor([info['rewards'] for info in infos])
        step_dones = torch.Tensor(ds)
        for i in range(num):
            rewardss[i][step] = step_rewards[:, i]
            next_dones[i][:] = step_dones

        for info in infos:
            if 'episode' in info.keys():
                print(f"global_step={global_step}, episode_reward={info['episode']['r']}")
                writer.add_scalar("charts/episode_reward", info['episode']['r'], global_step)
                for key in info['microrts_stats']:
                    writer.add_scalar(f"charts/episode_reward/{key}", info['microrts_stats'][key], global_step)
                # Only the first finished episode of this step is logged.
                break

    # bootstrap reward if not done. reached the batch limit
    raw_names = [str(rf) for rf in envs.get_attr("rfs")[0]]
    for i in range(num):
        agent = agents[i]
        dones = doness[i]
        rewards = rewardss[i]
        values = valuess[i]
        optimizer = optimizers[i]
        next_done = next_dones[i]
        # When the rollout's rewards for this agent sum to zero, the update is
        # skipped unless a random draw falls within `positive_likelihood`.
        if rewards.sum() == 0.0:
            if random.random() > args.positive_likelihood:
                continue

        with torch.no_grad():
            last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)
            if args.gae:
                advantages = torch.zeros_like(rewards).to(device)
                lastgaelam = 0
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done
                        nextvalues = last_value
                    else:
                        nextnonterminal = 1.0 - dones[t+1]
                        nextvalues = values[t+1]
                    delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
                returns = advantages + values
            else:
                returns = torch.zeros_like(rewards).to(device)
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done
                        next_return = last_value
                    else:
                        nextnonterminal = 1.0 - dones[t+1]
                        next_return = returns[t+1]
                    returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
                advantages = returns - values

        # flatten the batch
        b_obs = obs.reshape((-1,)+envs.observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,)+envs.action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)
        b_invalid_action_masks = invalid_action_masks.reshape((-1, invalid_action_masks.shape[-1]))

        # Optimizing the policy and value network
        # `target_agent` snapshots the weights at the start of each epoch so
        # that --kle-rollback can restore them.
        target_agent = Agent().to(device)
        inds = np.arange(args.batch_size,)
        for i_epoch_pi in range(args.update_epochs):
            np.random.shuffle(inds)
            target_agent.load_state_dict(agent.state_dict())
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                minibatch_ind = inds[start:end]
                mb_advantages = b_advantages[minibatch_ind]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                _, newlogproba, entropy = agent.get_action(
                    b_obs[minibatch_ind],
                    b_actions.long()[minibatch_ind].T,
                    b_invalid_action_masks[minibatch_ind])
                ratio = (newlogproba - b_logprobs[minibatch_ind]).exp()

                # Stats
                approx_kl = (b_logprobs[minibatch_ind] - newlogproba).mean()

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1-args.clip_coef, 1+args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()
                entropy_loss = entropy.mean()

                # Value loss
                new_values = agent.get_value(b_obs[minibatch_ind]).view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = ((new_values - b_returns[minibatch_ind]) ** 2)
                    v_clipped = b_values[minibatch_ind] + torch.clamp(new_values - b_values[minibatch_ind], -args.clip_coef, args.clip_coef)
                    v_loss_clipped = (v_clipped - b_returns[minibatch_ind])**2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    # Reduce to a scalar: without .mean() the total loss is a
                    # vector and loss.backward() raises, and v_loss.item()
                    # below fails as well.
                    v_loss = 0.5 * ((new_values - b_returns[minibatch_ind]) ** 2).mean()

                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            # Both KL checks use the last minibatch of the epoch.
            if args.kle_stop:
                if approx_kl > args.target_kl:
                    break
            if args.kle_rollback:
                if (b_logprobs[minibatch_ind] - agent.get_action(
                        b_obs[minibatch_ind],
                        b_actions.long()[minibatch_ind].T,
                        b_invalid_action_masks[minibatch_ind])[1]).mean() > args.target_kl:
                    agent.load_state_dict(target_agent.state_dict())
                    break

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar(f"losses/value_loss/{raw_names[i]}", v_loss.item(), global_step)
        writer.add_scalar(f"losses/policy_loss/{raw_names[i]}", pg_loss.item(), global_step)
        writer.add_scalar(f"losses/entropy/{raw_names[i]}", entropy.mean().item(), global_step)
        writer.add_scalar(f"losses/approx_kl/{raw_names[i]}", approx_kl.item(), global_step)
        if args.kle_stop or args.kle_rollback:
            writer.add_scalar(f"debug/pg_stop_iter/{raw_names[i]}", i_epoch_pi, global_step)
    # All optimizers share the same learning rate, so the first one is logged
    # rather than whichever the loop variable happened to reference last.
    writer.add_scalar(f"charts/learning_rate", optimizers[0].param_groups[0]['lr'], global_step)
    writer.add_scalar(f"losses/epsilon", epsilon, global_step)

envs.close()
writer.close()
