# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppopy
import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from gym_macro_overcooked.macActEnvWrapper import MacEnvWrapper
# from score_model.twosome.policy_pomdp import LLMAgent
from critic.policy_pomdp import LLMAgent, obs2text

import logging
import pdb
import json
def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default="Nethack",
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--total-timesteps", type=int, default=500000,
        help="total timesteps of the experiments")
    
    parser.add_argument("--policy-learning-rate", type=float, default=5e-7,
        help="the learning rate of the optimizer")
    parser.add_argument("--value-learning-rate", type=float, default=1e-5,
        help="the learning rate of the optimizer")
    
    parser.add_argument("--num-envs", type=int, default=1,
        help="the number of parallel game environments")
    parser.add_argument("--num-steps", type=int, default=128*1,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the general advantage estimation")
    # parser.add_argument("--num-minibatches", type=int, default=16,
    #     help="the number of mini-batches")
    parser.add_argument("--policy-num-minibatches", type=int, default=64*1,
        help="the number of mini-batches")
    parser.add_argument("--value-num-minibatches", type=int, default=4,
        help="the number of mini-batches")

    parser.add_argument("--update-epochs", type=int, default=1,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.2,
        help="the surrogate clipping coefficient")
    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
    parser.add_argument("--ent-coef", type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=0.02,
        help="the target KL divergence threshold")
    
    parser.add_argument('--gradient-checkpointing-steps', action='store',  type=int,             default=8,                     help='The number of steps for gradient checkpointing')
    parser.add_argument('--critic-warm-up-steps',   action='store',        type=int,             default=0,                  help='The number of time steps to warm up critic')
    
    #env_parameter
    parser.add_argument('--env-id',                 action='store',        type=str,             default='trainTask',  help='Domain name')
    parser.add_argument('--n-agent',                action='store',        type=int,             default=1,                     help='Number of agents')
    parser.add_argument('--grid-dim',               action='store',        type=int,   nargs=2,  default=[7,7],                 help='Grid world size')
    parser.add_argument('--task',                   action='store',        type=int,             default=0,                     help='The receipt agent cooks')
    parser.add_argument('--map-type',               action='store',        type=str,             default="A",                   help='The type of map')
    parser.add_argument('--obs-radius',             action='store',        type=int,             default=2,                     help='The radius of the agents')
    parser.add_argument('--env-reward',             action='store',        type=float, nargs=4,  default=[0.1, 1, 0, 0.001],    help='The reward list of the env')
    parser.add_argument('--mode',                   action='store',        type=str,             default="vector",              help='The type of the observation(vector/image)')    
    parser.add_argument('--debug',                  action='store',        type=bool,            default=False,                 help='Whehter print the debug information and render') 
    
    
    parser.add_argument('--load-8bit',              action='store',        type=bool,            default=False,                 help='Whether to convert model to 8bits')
    
    parser.add_argument('--save-path',              action='store',        type=str,             default="saved_models",        help='The path to save the checkpoint')
    parser.add_argument('--save-interval',          action='store',        type=int,             default=10,                    help='The interval for saving model for certain num_updates')
    parser.add_argument('--resume',                 action='store',        type=bool,            default=False,                 help='Whehter resume from previous checkpoint')
    parser.add_argument('--load-path',              action='store',        type=str,             default="saved_models",        help='The path to load the checkpoint')    
    parser.add_argument('--record-path',            action='store',        type=str,             default="critic/workdir",           help='The path to save the tensorboard results')    

    parser.add_argument('--normalization-mode',     action='store',        type=str,             default="word",               help='The normalization mode of how to deal with the logits of each token')    
    
    # add these for running the _run code without reporting any errors
    parser.add_argument("--n-level", type=str, default="L5", help="level for task")
    parser.add_argument("--test-config", type=str, default="vanilla", help="test config file for llm planner")
    parser.add_argument("--train-config", type=str, default="train_evolve", help="test config file for llm planner")
    parser.add_argument("--transfer-config", type=str, default="transfer_evolve", help="test config file for llm planner")
    parser.add_argument("--n-planner", type=int, default=3, help="level for task")
    
    # add these for running the nethack without reporting any errors
    parser.add_argument('--agent', type=str, default="llm", help="Choose which agent to run (handcrafted, llm or copic).")
    parser.add_argument('-agent_mode', type=str, default="cot-few-shot", help="Choose the mode of llm (vanilla, cot-zero-shot, cot-few-shot, reflexion)")
    parser.add_argument('--task_id', type=int, default=5, help="the id of task")
    parser.add_argument('configs/base_config.yaml', type=str, default="none", help="the path of base config file")

    args = parser.parse_args()
    assert args.num_envs == 1, "only 1 env is supported for now."
    args.batch_size = int(args.num_envs * args.num_steps)
    
    args.policy_minibatch_size = int(args.batch_size // args.policy_num_minibatches)
    args.value_minibatch_size = int(args.batch_size // args.value_num_minibatches)
    
    # fmt: on
    return args


# def make_env(env_id, seed, idx, capture_video, run_name, env_params):
#     def thunk():

#         env = gym.make(env_id, **env_params)
#         env = MacEnvWrapper(env)
#         if capture_video:
#             if idx == 0:
#                 env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
#         env.seed(seed)
#         env.action_space.seed(seed)
#         env.observation_space.seed(seed)
#         return env

#     return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return the same object.

    The weight gets an orthogonal initialisation with gain ``std``; every bias
    entry is filled with ``bias_const``.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer

# if __name__ == "__main__":
# Parse the CLI once at import time; the Twosome class below reads this `args`.
args = parse_args()

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)  # If you're using CUDA
# Honour the --torch-deterministic flag (default True) instead of hard-coding True.
torch.backends.cudnn.deterministic = args.torch_deterministic
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# rewardList = {"subtask finished": args.env_reward[0], "correct delivery": args.env_reward[1], "wrong delivery": -args.env_reward[2], "step penalty": -args.env_reward[3]}
TASKLIST = ["tomato salad", "lettuce salad", "onion salad", "lettuce-tomato salad", "onion-tomato salad", "lettuce-onion salad", "lettuce-onion-tomato salad"]
# env_params = {'grid_dim': args.grid_dim,
#                 'task': TASKLIST[args.task],
#                 'rewardList': rewardList,
#                 'map_type': args.map_type,
#                 'n_agent': args.n_agent,
#                 'obs_radius': args.obs_radius,
#                 'mode': args.mode,
#                 'debug': args.debug
#             }

# env setup
# envs = gym.vector.SyncVectorEnv(
#     [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, env_params) for i in range(args.num_envs)]
# )
# assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

class Twosome:
    """PPO trainer around an ``LLMAgent`` (actor + critic) fed one decision step at a time.

    :meth:`critic` records a step into rollout buffers of ``args.num_steps`` slots and
    returns the sampled action; :meth:`update` is a no-op unless the buffer write index
    has wrapped back to 0, in which case it runs a PPO update on the buffer.
    Hyperparameters are read from the module-level ``args`` namespace and tensors are
    placed on the module-level ``device``.
    """

    def __init__(
        self,
        resume=False,
        load_path=None,
        record_path=None,
        run_name=None,
        task_id=None,
        infer=False,
    ) -> None:
        """Build the agent, both optimizers and the rollout storage.

        Args:
            resume: if True, ``LLMAgent`` is constructed with ``load_path``.
            load_path: checkpoint path forwarded to ``LLMAgent`` when ``resume`` is set.
            record_path: directory for TensorBoard logs, ``skills.json`` and checkpoints.
            run_name: currently unused; kept for interface compatibility.
            task_id: currently unused; kept for interface compatibility.
            infer: if True no ``SummaryWriter`` is created; ``self.writer`` stays None
                and :meth:`update` skips TensorBoard logging.
        """
        self.record_path = record_path

        self.writer = None
        if not infer:
            self.writer = SummaryWriter(self.record_path)
            self.writer.add_text(
                "hyperparameters",
                "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
            )
        if resume:
            self.agent = LLMAgent(normalization_mode=args.normalization_mode, load_path=load_path, load_8bit=args.load_8bit, task=args.task)
        else:
            self.agent = LLMAgent(normalization_mode=args.normalization_mode, load_8bit=args.load_8bit, task=args.task)

        self.action_list = ["Plan1", "Plan2", "Plan3"]
        self.action_name2id = {action_name: idx for idx, action_name in enumerate(self.action_list)}
        self.action_id2name = {idx: action_name for idx, action_name in enumerate(self.action_list)}

        self.n_actions = len(self.action_list)

        # Separate optimizers (and learning rates) for the actor and the critic;
        # only parameters with requires_grad are optimized.
        self.policy_optimizer = optim.AdamW(filter(lambda p: p.requires_grad, self.agent.actor.parameters()), lr=args.policy_learning_rate, eps=1e-5, weight_decay=0)
        self.value_optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.agent.critic.parameters()), lr=args.value_learning_rate, eps=1e-5)

        self.storage_setup()
        self.init_params()

        self.num_subtask_each_step = 1

        self.step_each_evolve = 0
        self.meet_flag = False

        # Recorded skill lists: {record index: {position: skill}}.
        self.skills = {}
        self.n_skill = 0

    def record_skill(self, skill_list):
        """Append ``skill_list`` to ``self.skills`` under the next integer key."""
        self.skills[self.n_skill] = {i: skill for i, skill in enumerate(skill_list)}
        self.n_skill += 1

    def store_skill(self):
        """Write ``self.skills`` as indented JSON to ``<record_path>/skills.json``."""
        with open(os.path.join(self.record_path, "skills.json"), "w", encoding="utf-8") as f:
            json.dump(self.skills, f, indent=4)

    @staticmethod
    def _optimizer_params(optimizer):
        """Return every parameter managed by ``optimizer``, across all param groups."""
        return [p for group in optimizer.param_groups for p in group["params"]]

    @staticmethod
    def _normalize_advantages(advantages):
        """Standardize advantages to zero mean and unit std.

        Minibatches with fewer than two elements are returned unchanged: the unbiased
        std of a single value is NaN, which would turn the whole policy loss into NaN.
        """
        if advantages.numel() < 2:
            return advantages
        return (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    def storage_setup(self):
        """Allocate rollout buffers of shape ``(args.num_steps, args.num_envs)`` on ``device``.

        Observations are held in a nested Python list rather than a tensor.
        """
        shape = (args.num_steps, args.num_envs)
        self.obs = [[None for _ in range(args.num_envs)] for _ in range(args.num_steps)]
        self.actions = torch.zeros(shape, device=device)
        self.logprobs = torch.zeros(shape, device=device)
        self.rewards = torch.zeros(shape, device=device)
        self.dones = torch.zeros(shape, device=device)
        self.values = torch.zeros(shape, device=device)
        self.steps = torch.zeros(shape, device=device)

    def init_params(self):
        """Reset step/update counters and the wall-clock start time."""
        self.global_step = 0
        self.pre_global_step = 0
        self.start_time = time.time()
        self.num_updates = args.total_timesteps // args.batch_size
        self.num_critic_warm_up_updates = args.critic_warm_up_steps // args.batch_size
        self.is_warmup = True
        self.step = 0  # rollout-buffer write index, wraps modulo args.num_steps
        self.update_idx = 1

    def critic(self, message, return_value=True):
        """Record one step from ``message`` and sample the next action.

        ``message`` must contain ``"done"`` and ``"reward"`` keys; the observation is
        built from it with ``obs2text``.

        Args:
            message: environment message for the current decision step.
            return_value: if True the critic value is computed and stored in
                ``self.values``; if False only the policy is queried and the value slot
                for this step keeps whatever it held before.

        Returns:
            ``action.item()`` if ``return_value`` is True, else ``(action.item(), probs)``.
        """
        obs = obs2text(message)
        done = message["done"]
        reward = message["reward"]

        # Buffers are indexed per env, so a single observation is wrapped in a list.
        if isinstance(obs, (str, dict)):
            obs = [obs]

        # NOTE(review): obs/dones/rewards/steps are written at index i = self.step, but
        # the index is advanced *before* values/actions/logprobs are written (at i + 1),
        # so update() pairs each stored action with the obs of the *following* call.
        # Confirm this offset is intended.
        self.obs[self.step] = obs
        self.dones[self.step] = done
        self.rewards[self.step] = torch.tensor(reward).to(device)
        # NOTE(review): this stores the buffer index, which bootstrap_value() uses as the
        # exponent of gamma — confirm that is the intended per-step discount.
        self.steps[self.step] = torch.tensor(self.step).to(device)
        self.step = (self.step + 1) % args.num_steps

        with torch.no_grad():
            if return_value:
                action, logprob, _, value, probs = self.agent.get_action_and_value(obs)
                self.values[self.step] = value.flatten()
            else:
                action, logprob, _, _, probs = self.agent.get_action_and_value(obs, return_value=False)
        self.actions[self.step] = action
        self.logprobs[self.step] = logprob

        if return_value:
            return action.item()  # Conventional usage
        return action.item(), probs  # For inferring with reflexion

    def bootstrap_value(self, obs, done):
        """Compute GAE advantages and returns over the whole rollout buffer.

        Args:
            obs: observation whose critic value bootstraps the last buffered step.
            done: done flag for that observation; bootstrapping is scaled by ``1 - done``.

        Returns:
            ``(advantages, returns)``, shaped like ``self.rewards``; both are also stored
            on ``self.advantages`` / ``self.returns``.
        """
        with torch.no_grad():
            next_value = self.agent.get_value(obs).reshape(1, -1)
            advantages = torch.zeros_like(self.rewards)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - self.dones[t + 1]
                    nextvalues = self.values[t + 1]

                # Discount is gamma ** steps[t] (see the NOTE in critic() about steps).
                discount = torch.pow(args.gamma, self.steps[t])
                delta = self.rewards[t] + discount * nextvalues * nextnonterminal - self.values[t]
                advantages[t] = lastgaelam = delta + discount * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + self.values

        self.advantages = advantages
        self.returns = returns
        return advantages, returns

    def update(self):
        """Run one PPO update if the rollout buffer has just wrapped around.

        Returns immediately unless ``self.step`` is 0. Otherwise it advances
        ``global_step`` by one rollout. Once ``num_updates + num_critic_warm_up_updates``
        updates are done it only logs that training should end. Otherwise it:
        computes GAE, trains the critic every epoch, and (after the critic warm-up)
        trains the actor with gradient accumulation over
        ``args.gradient_checkpointing_steps`` minibatches, stopping early when the
        accumulated approximate KL exceeds ``args.target_kl``. Scalars go to
        TensorBoard when a writer exists. A checkpoint is saved whenever
        ``global_step // 100`` changes.
        """
        if self.step != 0:
            # The rollout buffer has not wrapped yet: nothing to train on.
            return

        self.global_step += args.num_envs * args.num_steps
        logging.info("=======================================$Training=======================================")
        logging.info(f"global steps: {self.global_step}, rollout steps: {self.step}, updates: {self.update_idx}")
        logging.info("=======================================$Training=======================================")

        if self.update_idx >= self.num_updates + 1 + self.num_critic_warm_up_updates:
            print("UPDATE should end at this point")
            logging.info("=======================================================================================")
            logging.info("The training steps have reached the set upper limit. Training should end at this point.")
            logging.info("=======================================================================================")
            return

        if self.is_warmup and self.update_idx > self.num_critic_warm_up_updates:
            self.is_warmup = False

        # Linearly anneal both learning rates over the post-warm-up updates.
        if args.anneal_lr and not self.is_warmup:
            frac = 1.0 - (self.update_idx - 1.0 - self.num_critic_warm_up_updates) / self.num_updates
            self.policy_optimizer.param_groups[0]["lr"] = frac * args.policy_learning_rate
            self.value_optimizer.param_groups[0]["lr"] = frac * args.value_learning_rate

        # self.step is 0 here, so index -1 is the last slot written by critic().
        self.bootstrap_value(self.obs[self.step - 1], self.dones[self.step - 1])

        # Flatten the (num_steps, num_envs) rollout into one batch.
        b_obs = np.array([o for obs in self.obs for o in obs])
        b_logprobs = self.logprobs.reshape(-1)
        b_actions = self.actions.reshape((-1,))
        b_advantages = self.advantages.reshape(-1)
        b_returns = self.returns.reshape(-1)
        b_values = self.values.reshape(-1)

        # Clip only the parameters of the optimizer being stepped. Clipping all agent
        # parameters also folded in the other network's leftover gradients (the critic's
        # grads are not cleared after the last value step), shrinking the policy step.
        value_params = self._optimizer_params(self.value_optimizer)
        policy_params = self._optimizer_params(self.policy_optimizer)

        b_inds = np.arange(args.batch_size)
        clipfracs = []
        kl_explode = False
        policy_update_steps = 0
        v_loss = torch.tensor(0)
        pg_loss = torch.tensor(0)
        entropy_loss = torch.tensor(0)
        old_approx_kl = torch.tensor(0)
        approx_kl = torch.tensor(0)
        total_approx_kl = torch.tensor(0)

        for _ in range(args.update_epochs):
            if kl_explode:
                break

            # --- value update (also runs during critic warm-up) ---
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.value_minibatch_size):
                end = start + args.value_minibatch_size
                mb_inds = b_inds[start:end]
                newvalue = self.agent.get_value(b_obs[mb_inds]).view(-1)

                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                loss = v_loss * args.vf_coef

                self.value_optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(value_params, args.max_grad_norm)
                self.value_optimizer.step()

            if self.is_warmup:
                continue

            # --- policy update with gradient accumulation ---
            self.policy_optimizer.zero_grad()
            for start in range(0, args.batch_size, args.policy_minibatch_size):
                if policy_update_steps % args.gradient_checkpointing_steps == 0:
                    total_approx_kl = 0
                policy_update_steps += 1
                end = start + args.policy_minibatch_size
                mb_inds = b_inds[start:end]
                _, newlogprob, entropy, _, _ = self.agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds], self.is_warmup, return_value=False)

                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    total_approx_kl += approx_kl / args.gradient_checkpointing_steps
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = self._normalize_advantages(mb_advantages)

                # Clipped surrogate policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss
                # Scale so the accumulated gradient is the mean over the accumulation window.
                loss /= args.gradient_checkpointing_steps

                loss.backward()

                if policy_update_steps % args.gradient_checkpointing_steps == 0:
                    if args.target_kl is not None and total_approx_kl > args.target_kl:
                        # Drop the accumulated gradient and stop all policy updates.
                        self.policy_optimizer.zero_grad()
                        kl_explode = True
                        policy_update_steps -= args.gradient_checkpointing_steps
                        break

                    nn.utils.clip_grad_norm_(policy_params, args.max_grad_norm)
                    self.policy_optimizer.step()
                    self.policy_optimizer.zero_grad()

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
        num_clipfracs = np.mean(clipfracs) if clipfracs else 0

        self.update_idx += 1

        print("SPS:", self.global_step, (time.time() - self.start_time))
        if self.writer is not None:
            gs = self.global_step
            writer = self.writer
            # TRY NOT TO MODIFY: record rewards for plotting purposes
            writer.add_scalar("charts/policy_learning_rate", self.policy_optimizer.param_groups[0]["lr"], gs)
            writer.add_scalar("charts/value_learning_rate", self.value_optimizer.param_groups[0]["lr"], gs)
            writer.add_scalar("losses/value_loss", v_loss.item(), gs)
            writer.add_scalar("losses/policy_loss", pg_loss.item(), gs)
            writer.add_scalar("losses/entropy", entropy_loss.item(), gs)
            writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), gs)
            writer.add_scalar("losses/approx_kl", approx_kl.item(), gs)
            writer.add_scalar("losses/total_approx_kl", float(total_approx_kl), gs)
            writer.add_scalar("losses/policy_update_times", policy_update_steps // args.gradient_checkpointing_steps, gs)
            writer.add_scalar("losses/clipfrac", num_clipfracs, gs)
            writer.add_scalar("losses/explained_variance", explained_var, gs)
            writer.add_scalar("charts/SPS", gs / (time.time() - self.start_time), gs)
            writer.add_scalar("params/step", self.step, gs)
            writer.add_scalar("params/update_idx", self.update_idx, gs)

        print(f"global_step//100: {self.global_step // 100}, pre_global_step//100: {self.pre_global_step // 100}")
        if self.global_step // 100 != self.pre_global_step // 100:
            self.agent.save(self.global_step // 100, os.path.join(self.record_path, args.save_path))
            self.pre_global_step = self.global_step

# envs.close()
# writer.close()