import os, argparse, yaml, torch
import gymnasium as gym

import numpy as np
import torch.nn as nn

from mani_skill.utils import gym_utils
from typing import Union, Any
from stable_baselines3 import SAC, PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from transformers import AutoTokenizer, AutoModel
from rlkit.envs.wrappers import NormalizedBoxEnv
from mani_skill.vector.wrappers.sb3 import ManiSkillSB3VectorEnv
from mani_skill.utils.geometry import rotation_conversions
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize


class CustomFeatureExtractor(BaseFeaturesExtractor):
    """Features extractor that compresses the trailing part of a flat observation.

    The observation is split at ``int(features_dim / 2)``. The leading slice
    is passed through unchanged; the trailing slice is fed through an MLP
    (768 -> 256 -> 128 -> ``int(features_dim / 2)``, ReLU after every layer).
    Both pieces are concatenated along the feature axis.
    """

    def __init__(self, observation_space, feature_dim):
        """Build the embedding MLP.

        Args:
            observation_space: Observation space forwarded to the base class.
            feature_dim: Output feature size forwarded to the base class; half
                of it is used as the width of the compressed embedding.
        """
        super().__init__(observation_space, feature_dim)
        half_dim = int(self._features_dim / 2)
        layers = [
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, half_dim),
            nn.ReLU(),
        ]
        self.embed_extractor = nn.Sequential(*layers)

    def forward(self, observations):
        """Return ``[leading slice | MLP(trailing slice)]`` for a 2-D batch."""
        split_at = int(self._features_dim / 2)
        raw_part = observations[:, :split_at]
        embed_part = observations[:, split_at:]
        compressed = self.embed_extractor(embed_part)
        return torch.cat((raw_part, compressed), dim=1)

class Tokenizer():
    """Sentence encoder for the steps of a text plan.

    Loads the ``princeton-nlp/unsup-simcse-bert-base-uncased`` tokenizer and
    model and exposes helpers that turn plan steps into embeddings taken from
    the model's ``pooler_output``.
    """

    def __init__(self, text_plan):
        """Store the plan and load the encoder.

        Args:
            text_plan: Iterable of strings, one per plan step.
        """
        self.text_plan = text_plan
        # Prefer the first GPU, but stay usable on CPU-only hosts instead of
        # failing when the model is moved to a missing CUDA device.
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.sts_tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/unsup-simcse-bert-base-uncased")
        self.sts_model = AutoModel.from_pretrained("princeton-nlp/unsup-simcse-bert-base-uncased").to(self.device)

    def embed_single(self, message):
        """Return the pooled embedding tensor for ``message`` (on ``self.device``)."""
        inputs = self.sts_tokenizer(message, padding=True, truncation=True, return_tensors="pt").to(self.device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            embeddings = self.sts_model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
        return embeddings

    def embed_plans(self):
        """Embed every step of ``self.text_plan``.

        Returns:
            A list with one NumPy array per plan step, in plan order
            (presumably shape ``(1, 768)`` each -- TODO confirm against the
            encoder's hidden size). Empty when the plan is empty.
        """
        return [self.embed_single(msg).detach().cpu().numpy() for msg in self.text_plan]

class ContinuousTaskWrapper(gym.Wrapper):
    """Un-batch a single wrapped env and apply plan-based observation/reward shaping.

    The wrapped env is assumed to return torch tensors with a leading batch
    axis of size 1 (the code calls ``.cpu().numpy().squeeze(0)`` on them).

    Behaviour depends on ``llm``:

    * ``'lm'``    -- goal embedding appended to the observation; reward is
      ``clip(10 * (gamma * progress - prev_progress), -1, 1)``.
    * ``'ellm'``  -- goal embedding appended; reward is 5.0 when the plan index
      returned by ``progress_function`` increased, else 0.0.
    * ``'lm_pr'`` -- goal embedding appended; env reward kept.
    * ``'lm_re'`` -- observation unchanged; reward shaped as for ``'lm'``.
    * ``'t2r'``   -- reward replaced by ``t2r_reward_function(action)``.
    * anything else -- observation and reward passed through.

    ``progress_function`` and ``t2r_reward_function`` are placeholders that the
    launching script replaces on the class before environments are created.
    """

    def __init__(self, env, max_episode_steps, embed_goal, llm, gamma):
        """Wrap ``env``.

        Args:
            env: Environment to wrap.
            max_episode_steps: Number of steps after which ``step`` reports
                ``done=True``.
            embed_goal: Sequence of goal embeddings indexed by plan step; each
                entry must support ``.squeeze(0)``.
            llm: Mode string, see the class docstring.
            gamma: Factor applied to the current progress in the shaped reward.
        """
        super().__init__(env)
        self._elapsed_steps = 0
        self.pre_obs = None
        # Plan index seen at the previous transition. Initialised here as well
        # as in ``reset`` so ``step`` never reads a missing attribute.
        self.pre_idx = 0
        self._max_episode_steps = max_episode_steps

        self.embed_goal = embed_goal
        self.llm = llm
        # (sic) attribute name kept: externally loaded functions receive ``self``.
        self.gammaa = gamma

        # Space of the wrapped env; index 0 drops its leading batch axis.
        base_space = self.observation_space
        if self.llm in ['lm', 'lm_pr', 'ellm']:
            # The goal embedding is unbounded and 768-dimensional.
            self.embedding_low = np.full(768, -np.inf, dtype=np.float32)
            self.embedding_high = np.full(768, +np.inf, dtype=np.float32)
            self.observation_space = gym.spaces.Box(
                np.hstack((base_space.low[0], self.embedding_low)),
                np.hstack((base_space.high[0], self.embedding_high)),
            )
            self.prev_progress = 0
            self.progress = 0
        else:
            self.observation_space = gym.spaces.Box(base_space.low[0], base_space.high[0])

    def progress_function(self):
        """Placeholder; expected to return ``(progress, plan_index)`` once replaced."""
        # An explicit raise (instead of ``assert``) also fires under ``python -O``.
        raise NotImplementedError("progress_function must be assigned before use")

    def reset(self, *, seed: Union[int, None] = None, options: Union[dict[str, Any], None] = None):
        """Reset the wrapped env and return ``(observation, info)``.

        ``seed`` and ``options`` are forwarded to the wrapped env when given,
        so seeding requested by the caller actually takes effect.
        """
        self._elapsed_steps = 0
        # Forward only what was supplied so a plain ``reset()`` keeps issuing
        # the same bare call to the wrapped env.
        reset_kwargs = {}
        if seed is not None:
            reset_kwargs["seed"] = seed
        if options is not None:
            reset_kwargs["options"] = options
        self.pre_obs, reset_info = self.env.reset(**reset_kwargs)
        self.pre_obs = self.pre_obs.cpu().numpy().squeeze(0)
        if self.llm in ['lm', 'lm_pr', 'ellm']:
            # Episodes always start on the first plan step.
            self.pre_obs = self._append_goal(self.pre_obs, 0)
            self.prev_progress, _ = self.progress_function()
        elif self.llm == 'lm_re':
            self.prev_progress, _ = self.progress_function()
        self.pre_idx = 0
        return self.pre_obs, reset_info

    def t2r_reward_function(self, action, obs=None):
        """Placeholder for the ``'t2r'`` reward; ``step`` calls it with ``action`` only."""
        raise NotImplementedError("t2r_reward_function must be assigned before use")

    def convert_tensors_to_nparray(self, data):
        """Recursively replace torch tensors in ``data`` with NumPy arrays (in place).

        Returns:
            The same ``data`` dict, for convenience.
        """
        for key, value in data.items():
            if isinstance(value, torch.Tensor):
                data[key] = value.cpu().numpy()
            elif isinstance(value, dict):
                self.convert_tensors_to_nparray(value)
        return data

    def _append_goal(self, obs, idx):
        """Return ``obs`` with the embedding of plan step ``idx`` appended."""
        return np.concatenate((obs, self.embed_goal[idx].squeeze(0)), axis=0)

    def _progress_reward(self):
        """Return the clipped shaped reward from ``progress`` and ``prev_progress``."""
        return np.clip(10 * (self.gammaa * self.progress - self.prev_progress), -1, 1)

    def step(self, action):
        """Step the wrapped env and apply the mode-specific shaping.

        Returns:
            ``(ob, rew, done, truncated, info)``. ``done`` is True when the
            step budget is exhausted or ``info["success"]`` is truthy; the
            wrapped env's own ``terminated`` flag is discarded. ``truncated``
            is the wrapped env's value, un-batched.
        """
        ob, rew, _, truncated, info = super().step(action)
        ob = ob.cpu().numpy().squeeze(0)
        rew = rew.cpu().numpy().squeeze()
        truncated = truncated.cpu().numpy().squeeze()
        info = self.convert_tensors_to_nparray(info)

        if self.llm == 'lm':
            self.progress, idx = self.progress_function()
            ob = self._append_goal(ob, idx)
            rew = self._progress_reward()
            self.prev_progress = self.progress
        elif self.llm == 'ellm':
            _, idx = self.progress_function()
            ob = self._append_goal(ob, idx)
            # Sparse bonus only when the agent advanced to a later plan step.
            rew = 5.0 if idx > self.pre_idx else 0.0
            self.pre_idx = idx
        elif self.llm == 'lm_pr':
            _, idx = self.progress_function()
            ob = self._append_goal(ob, idx)
        elif self.llm == 'lm_re':
            self.progress, _ = self.progress_function()
            rew = self._progress_reward()
            self.prev_progress = self.progress
        elif self.llm == 't2r':
            rew = self.t2r_reward_function(action)

        self._elapsed_steps += 1
        done = bool(self._elapsed_steps >= self._max_episode_steps)
        info["TimeLimit.truncated"] = done

        info["is_success"] = info["success"]
        if info["is_success"]:
            done = True
            if self.llm in ['lm', 'ellm', 'lm_re']:
                # Terminal bonus for the modes that replace the env reward.
                rew = rew + 50
        return ob, rew, done, truncated, info
    
def plan_generator(env_id):
    """Return the ordered sub-task descriptions for ``env_id``.

    ``PullCube-v1`` and ``PushCube-v1`` share a two-step plan, ``PickCube-v1``
    has a four-step plan ending in ``"Task Complete"``, and every other id
    gets the four-step plan that includes a rotation step. A new list is
    built on every call.
    """
    if env_id in ['PullCube-v1', 'PushCube-v1']:
        return [
            "Move to the object",
            "Move to the goal",
        ]

    if env_id in ['PickCube-v1']:
        return [
            "Move to object",
            "Grasp object",
            "Move to goal",
            "Task Complete",
        ]

    # Fallback plan for all remaining environments.
    return [
        "Move to object",
        "Grasp object",
        "Rotate the object",
        "Move to goal",
    ]


def make_env(env_id, max_episode_steps: Union[int, None] = None, record_dir: Union[str, None] = None, embed_goal: Union[list, None] = None, llm: Union[str, None] = None, gamma: Union[float, None] = None):
    """Return a zero-argument factory that builds one wrapped environment.

    Args:
        env_id: Id passed to ``gym.make``.
        max_episode_steps: Step budget handed to ``ContinuousTaskWrapper``.
            When None, the value reported by
            ``gym_utils.find_max_episode_steps_value`` for the created env is
            used instead.
        record_dir: Accepted for interface compatibility; not used here.
        embed_goal: Goal embeddings forwarded to ``ContinuousTaskWrapper``.
        llm: Mode string forwarded to ``ContinuousTaskWrapper``.
        gamma: Shaping factor forwarded to ``ContinuousTaskWrapper``.

    Returns:
        A callable creating the environment, suitable for a vectorized-env
        constructor that expects a list of factories.
    """
    def _init() -> "gym.Env":
        env = gym.make(env_id, obs_mode="state", control_mode="pd_ee_delta_pos")
        # Use a separate local name: assigning to ``max_episode_steps`` here
        # would shadow the outer argument and discard the caller's value.
        episode_limit = max_episode_steps
        if episode_limit is None:
            episode_limit = gym_utils.find_max_episode_steps_value(env)
        # NOTE(review): the env returned by gym.make may still enforce its own
        # registered step limit -- confirm if a longer budget is requested.
        env = NormalizedBoxEnv(env)

        if episode_limit is not None:
            env = ContinuousTaskWrapper(env=env, max_episode_steps=episode_limit, embed_goal=embed_goal, llm=llm, gamma=gamma)
        return env

    return _init

# Script entry point: parse CLI options, install the mode-specific progress or
# reward function on ContinuousTaskWrapper, build eval/train vectorized envs,
# then train and save a PPO model.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=314159)
    parser.add_argument('--env_id', type=str, default="LiftPegUpright-v1")
    parser.add_argument('--train_num', type=int, default=1)
    parser.add_argument('--eval_num', type=int, default=1)
    parser.add_argument('--eval_freq', type=int, default=12800)
    parser.add_argument('--max_episode_steps', type=int, default=100)
    parser.add_argument('--rollout_steps', type=int, default=3200)
    parser.add_argument('--train_max_steps', type=int, default=10_000_000)
    # NOTE: --n_epochs, --n_steps and --eval_seed are parsed but not consumed
    # below; they are numeric, so they are declared as int.
    parser.add_argument('--n_epochs', type=int, default=10)
    parser.add_argument('--n_steps', type=int, default=2048)
    parser.add_argument('--eval_seed', type=int, default=1)
    parser.add_argument('--our', type=str, default="lm")
    parser.add_argument('--gamma', type=float, default=0.8)
    args = parser.parse_args()

    def load_external_function(source_path, function_name):
        """Execute the Python file at ``source_path`` and return ``function_name`` from it.

        The file is executed in a copy of this module's globals so it can use
        the names imported at the top of this script.
        """
        # NOTE(security): exec runs the file with this process's privileges;
        # only point it at trusted, locally authored files.
        with open(source_path, "r") as f:
            code_str = f.read()
        namespace = {**globals()}
        exec(code_str, namespace)
        return namespace[function_name]

    algo_name = "Oracle"
    policy_kwargs = dict(net_arch=[256, 256, 256])
    env_name = args.env_id
    training_id = ""
    function_id = "-v2"
    progress_path = f"./progress_functions/{env_name}{function_id}.py"
    text_plan = plan_generator(env_id=env_name)
    embed_goal = []
    if args.our in ['lm', 'lm_pr', 'ellm']:
        algo_name = {'lm': "prm", 'lm_pr': 'lmpr', 'ellm': 'ellm'}[args.our]
        text_encoder = Tokenizer(text_plan)
        embed_goal = text_encoder.embed_plans()
        # Throwaway env, used only to read the raw observation width.
        temp_env = gym.make(env_name, num_envs=1)
        temp_env = ManiSkillSB3VectorEnv(temp_env)
        env_ori_obs_dim = temp_env.observation_space.shape[0]
        temp_env.close()
        policy_kwargs['features_extractor_class'] = CustomFeatureExtractor
        # Half of the features keep the raw observation, half hold the
        # compressed goal embedding.
        policy_kwargs['features_extractor_kwargs'] = dict(feature_dim=env_ori_obs_dim*2)
        # loading corresponding progress functions
        ContinuousTaskWrapper.progress_function = load_external_function(progress_path, 'progress_function')
    elif args.our == 't2r':
        algo_name = "t2r"
        reward_path = f"./t2r_reward_code/{env_name}/specific.py"
        ContinuousTaskWrapper.t2r_reward_function = load_external_function(reward_path, 'compute_dense_reward')
    elif args.our == 'lm_re':
        algo_name = "lmre"
        ContinuousTaskWrapper.progress_function = load_external_function(progress_path, 'progress_function')

    eval_path = f"./logs/evaluation/{env_name}/{algo_name}-{training_id}/{args.seed}"
    # set up eval environment
    eval_env = SubprocVecEnv([make_env(env_name, embed_goal=embed_goal, llm=args.our, gamma=args.gamma) for i in range(args.eval_num)])
    eval_env = VecMonitor(eval_env)
    eval_env.seed(args.seed)
    eval_env.reset()

    # set up training environment
    env = SubprocVecEnv([make_env(env_name, max_episode_steps=args.max_episode_steps, embed_goal=embed_goal, llm=args.our, gamma=args.gamma) for i in range(args.train_num)])
    env = VecMonitor(env)
    env.seed(args.seed)
    env.reset()

    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=eval_path,
        log_path=eval_path,
        # eval_freq counts calls to the callback, so divide by the env count.
        eval_freq=args.eval_freq // args.train_num,
        deterministic=True,
        render=False,
        n_eval_episodes=10
    )
    set_random_seed(args.seed)

    model = PPO(
        policy="MlpPolicy",
        env=env,
        verbose=1,
        gamma=0.8,
        n_epochs=15,
        learning_rate=0.0002,
        n_steps=args.rollout_steps // args.train_num,
        batch_size=400,
        target_kl=0.1,
        gae_lambda=0.9,
        policy_kwargs=policy_kwargs,
        tensorboard_log=f"./logs/training/{env_name}/{algo_name}-{training_id}",
        seed=int(args.seed),
    )

    model.learn(
        total_timesteps=args.train_max_steps,
        tb_log_name=f"{args.seed}",
        callback=[eval_callback]
        )
    # save model
    model.save(f"./saved_models/{algo_name}-{env_name}-{args.seed}")

