import os, argparse, yaml, torch

import numpy as np
import torch.nn as nn
import gymnasium as gym
import metaworld.envs.mujoco.env_dict as _env_dict
from typing import Union, Any
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from transformers import AutoTokenizer, AutoModel, LlamaForCausalLM
from rlkit.envs.wrappers import NormalizedBoxEnv


class CustomFeatureExtractor(BaseFeaturesExtractor):
    """Features extractor that compresses the trailing embedding of an observation.

    Every flat observation is split at column 39. The leading 39 entries are
    forwarded untouched; the remaining entries (which must be 768 wide to fit
    the first linear layer) go through a three-layer ReLU MLP that reduces
    them to 39 values. Both parts are concatenated, so each observation
    yields 39 + 39 = 78 features.
    """

    def __init__(self, observation_space, feature_dim):
        """Create the embedding-compression network.

        Args:
            observation_space: Observation space handed to the base class.
            feature_dim: Feature size reported to the base class. The
                network built here always emits 78 features, so callers
                are expected to pass 78.
        """
        super().__init__(observation_space, feature_dim)
        widths = (768, 256, 128, 39)
        stages = []
        # Each Linear layer is followed by a ReLU, including the last one.
        for in_width, out_width in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(in_width, out_width))
            stages.append(nn.ReLU())
        self.embed_extractor = nn.Sequential(*stages)

    def forward(self, observations):
        """Return the raw first 39 columns joined with the compressed embedding.

        Args:
            observations: 2-D tensor indexed as ``[batch, feature]``.

        Returns:
            Tensor concatenated along dimension 1.
        """
        state_part, embed_part = observations[:, :39], observations[:, 39:]
        return torch.cat([state_part, self.embed_extractor(embed_part)], dim=1)

class Tokenizer():
    """Sentence embedder for a list of textual plan steps.

    Loads the ``princeton-nlp/unsup-simcse-bert-base-uncased`` tokenizer and
    model once and exposes helpers that turn text into the model's
    ``pooler_output``.
    """

    # Checkpoint used for both the tokenizer and the encoder.
    MODEL_NAME = "princeton-nlp/unsup-simcse-bert-base-uncased"

    def __init__(self, text_plan, device=None):
        """Load the tokenizer and encoder.

        Args:
            text_plan: Iterable of strings, one per plan step.
            device: Torch device for the encoder and its inputs. When
                ``None`` (the default) ``'cuda:0'`` is used if CUDA is
                available and ``'cpu'`` otherwise, so the class no longer
                crashes on CPU-only machines.
        """
        self.text_plan = text_plan
        self.device = self._resolve_device(device)
        self.sts_tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        self.sts_model = AutoModel.from_pretrained(self.MODEL_NAME).to(self.device)

    @staticmethod
    def _resolve_device(device=None):
        """Return ``device`` if given, else ``'cuda:0'`` when available, else ``'cpu'``."""
        if device is not None:
            return device
        return 'cuda:0' if torch.cuda.is_available() else 'cpu'

    def embed_single(self, message):
        """Embed ``message`` and return the encoder's ``pooler_output`` tensor.

        The tensor stays on ``self.device``; no gradients are recorded.
        Presumably shaped ``(1, hidden)`` for a single string - TODO confirm.
        """
        inputs = self.sts_tokenizer(
            message, padding=True, truncation=True, return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.sts_model(**inputs, output_hidden_states=True, return_dict=True)
        return outputs.pooler_output

    def embed_plans(self):
        """Embed every entry of ``self.text_plan``.

        Returns:
            A list with one NumPy array per plan step, in plan order.
        """
        return [self.embed_single(msg).detach().cpu().numpy() for msg in self.text_plan]

class ContinuousTaskWrapper(gym.Wrapper):
    """Add a step limit, success termination and plan-based signals to a task env.

    The ``llm`` mode string selects how observations and rewards are altered:

    * ``'lm'``    - append the current sub-goal embedding to the observation
      and replace the reward with a clipped, discounted progress difference.
    * ``'lm_ns'`` - same observation and shaped reward as ``'lm'``, but
      without the success bonus.
    * ``'ellm'``  - append the embedding; reward is 5.0 on any step where
      the sub-goal index increases and 0.0 otherwise.
    * ``'lm_pr'`` - append the embedding; keep the environment reward.
    * ``'lm_re'`` - keep the observation; use the shaped progress reward.
    * ``'t2r'``   - keep the observation; reward comes from
      ``t2r_reward_function``.
    * any other value - observation and reward pass through unchanged.

    ``progress_function`` and ``t2r_reward_function`` are placeholders that
    must be replaced on the class before the corresponding mode is used.
    ``progress_function(obs)`` is expected to return ``(progress, idx)``,
    where ``idx`` indexes ``embed_goal``.
    """

    # Modes whose observations are extended with a 768-dim plan embedding.
    _EMBED_MODES = ('lm', 'lm_pr', 'ellm', 'lm_ns')
    # Modes that call ``progress_function``.
    _PROGRESS_MODES = _EMBED_MODES + ('lm_re',)
    # Modes that receive the +50 bonus when the task succeeds.
    _SUCCESS_BONUS_MODES = ('lm', 'ellm', 'lm_re')

    def __init__(self, env, max_episode_steps, embed_goal, llm, gamma):
        """Store the configuration and, if needed, widen the observation space.

        Args:
            env: Environment to wrap.
            max_episode_steps: Number of steps after which ``done`` is forced.
            embed_goal: Sequence of per-sub-goal embeddings; each entry is
                squeezed on axis 0 before being appended to an observation.
            llm: Mode string, see the class docstring.
            gamma: Discount applied to the new progress value in the shaped
                reward.
        """
        super().__init__(env)
        self._elapsed_steps = 0
        self.pre_obs = None
        self.pre_idx = 0  # defined up front so 'ellm' cannot hit a missing attribute
        self._max_episode_steps = max_episode_steps
        self.embed_goal = embed_goal
        self.llm = llm
        self.gammaa = gamma  # attribute name kept as-is for compatibility
        if self.llm in self._EMBED_MODES:
            # The embedding part of the observation is unbounded.
            self.embedding_low = np.full(768, -np.inf, dtype=np.float32)
            self.embedding_high = np.full(768, +np.inf, dtype=np.float32)
            self.observation_space = gym.spaces.Box(
                np.hstack((self.observation_space.low, self.embedding_low)),
                np.hstack((self.observation_space.high, self.embedding_high)),
            )
        self.prev_progress = 0
        self.progress = 0

    def _append_goal(self, obs, idx):
        """Return ``obs`` with the embedding of sub-goal ``idx`` appended."""
        return np.concatenate((obs, self.embed_goal[idx].squeeze(0)), axis=0)

    def _shaped_reward(self):
        """Return the clipped progress-difference reward and roll ``prev_progress`` forward."""
        reward = np.clip(10 * (self.gammaa * self.progress - self.prev_progress), -1, 1)
        self.prev_progress = self.progress
        return reward

    def reset(self, *, seed: Union[int, None] = None, options: Union[dict[str, Any], None] = None):
        """Reset the wrapped env and the wrapper's episode state.

        Returns:
            ``(observation, info)``; the observation carries the first
            sub-goal embedding in the embedding modes.
        """
        self._elapsed_steps = 0
        # NOTE(review): ``seed`` and ``options`` are accepted but not forwarded
        # to the wrapped env (unchanged from the original) - confirm intended.
        self.pre_obs, reset_info = super().reset()
        if self.llm in self._EMBED_MODES:
            self.pre_obs = self._append_goal(self.pre_obs, 0)
        if self.llm in self._PROGRESS_MODES:
            self.prev_progress, _ = self.progress_function(self.pre_obs)
        else:
            # No progress function is installed in the other modes; calling
            # the placeholder here used to crash every reset.
            self.prev_progress = 0
        self.pre_idx = 0
        return self.pre_obs, reset_info

    def progress_function(self, obs, action=None):
        """Placeholder; replace on the class with a function returning ``(progress, idx)``."""
        raise NotImplementedError("progress_function has not been installed")

    def t2r_reward_function(self, action, obs):
        """Placeholder; replace on the class with a function returning a reward."""
        raise NotImplementedError("t2r_reward_function has not been installed")

    def step(self, action):
        """Step the wrapped env and apply the mode-specific observation/reward changes.

        Returns:
            ``(ob, rew, done, truncated, info)``. ``done`` is True when the
            step limit is reached or ``info["success"]`` is truthy;
            ``truncated`` is whatever the wrapped env returned.
        """
        ob, rew, done, truncated, info = super().step(action)
        mode = self.llm
        if mode in ('lm', 'lm_ns'):
            # 'lm_ns' previously wrote the new progress into ``prev_progress``,
            # leaving ``progress`` at 0 and turning the reward into
            # clip(-10 * progress); it now uses the same shaping as 'lm'.
            self.progress, idx = self.progress_function(ob)
            ob = self._append_goal(ob, idx)
            rew = self._shaped_reward()
        elif mode == 'ellm':
            _, idx = self.progress_function(ob)
            ob = self._append_goal(ob, idx)
            rew = 5.0 if idx > self.pre_idx else 0.0
            self.pre_idx = idx
        elif mode == 'lm_pr':
            _, idx = self.progress_function(ob)
            ob = self._append_goal(ob, idx)
        elif mode == 'lm_re':
            self.progress, _ = self.progress_function(ob)
            rew = self._shaped_reward()
        elif mode == 't2r':
            rew = self.t2r_reward_function(action, ob)

        self._elapsed_steps += 1
        # The wrapped env's own ``done`` flag is deliberately overridden: an
        # episode ends only on the step limit or on success.
        # NOTE(review): a time-limit hit is reported through ``done`` rather
        # than ``truncated`` - confirm this is what the training code expects.
        timed_out = self._elapsed_steps >= self._max_episode_steps
        done = timed_out
        info["TimeLimit.truncated"] = timed_out

        info["is_success"] = info["success"]
        if info["success"]:
            done = True
            if mode in self._SUCCESS_BONUS_MODES:
                rew = rew + 50

        return ob, rew, done, truncated, info
    
def plan_generator(env_id):
    """Return the ordered textual sub-goals for ``env_id``.

    A fixed set of tasks gets a two-step plan; every other task id gets a
    four-step plan. A fresh list is returned on each call.
    """
    two_step_tasks = (
        'button-press',
        'door-close',
        'drawer-close',
        'window-open',
        'faucet-open',
    )
    if env_id not in two_step_tasks:
        return [
            "Move to object",
            "Grasp object",
            "Move to goal",
            "Task Complete",
        ]
    return [
        "Move to the object",
        "Move to the goal",
    ]


def make_env(
    env_id,
    max_episode_steps: Union[int, None] = None,
    record_dir: Union[str, None] = None,
    embed_goal: Union[list, None] = None,
    llm: Union[str, None] = None,
    gamma: Union[float, None] = None,
):
    """Return a zero-argument factory that builds one wrapped environment.

    Args:
        env_id: Key into ``ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE``.
        max_episode_steps: Step limit handed to ``ContinuousTaskWrapper``.
        record_dir: Accepted for interface compatibility; not used here.
        embed_goal: Sub-goal embeddings handed to ``ContinuousTaskWrapper``.
        llm: Mode string handed to ``ContinuousTaskWrapper`` (previously
            mis-annotated as ``bool``).
        gamma: Discount handed to ``ContinuousTaskWrapper``.

    Returns:
        A callable that creates the environment when invoked, so that
        construction can happen inside a vectorised-env worker.
    """
    def _init() -> gym.Env:
        env_cls = _env_dict.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id]
        env = env_cls()
        # Private flags on the environment instance, set exactly as before.
        # NOTE(review): presumably these enable per-reset randomisation and
        # skip the explicit set_task() requirement - confirm against the env.
        env._freeze_rand_vec = False
        env._set_task_called = True
        env = NormalizedBoxEnv(env)
        env = ContinuousTaskWrapper(env, max_episode_steps, embed_goal, llm, gamma)
        return env

    return _init

def _load_function_from_file(source_path, function_name):
    """Execute the Python file at ``source_path`` and return one of its globals.

    The file runs in a copy of this module's globals, so it can use the
    names imported at the top of this script.

    SECURITY: ``exec`` runs arbitrary code; only point this at trusted,
    locally authored reward / progress scripts.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the file does not define ``function_name``.
    """
    with open(source_path, "r") as f:
        code_str = f.read()
    namespace = {**globals()}
    exec(code_str, namespace)
    return namespace[function_name]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=314159)
    parser.add_argument('--env_id', type=str, default="pick-place")
    parser.add_argument('--train_num', type=int, default=8)
    parser.add_argument('--eval_num', type=int, default=2)
    parser.add_argument('--eval_freq', type=int, default=16_000)
    parser.add_argument('--max_episode_steps', type=int, default=500)
    parser.add_argument('--train_max_steps', type=int, default=1_000_000)
    parser.add_argument('--our', type=str, default='lm')
    parser.add_argument('--gamma', type=float, default=0.99)
    args = parser.parse_args()
    policy_kwargs = dict(net_arch=[256, 256, 256])

    env_name = args.env_id
    training_id = ""
    function_id = "-v2"
    text_plan = plan_generator(env_id=env_name)
    embed_goal = []

    # Short run name used in log and checkpoint paths; any mode that is not
    # listed runs the unmodified environment under the name "Oracle".
    algo_names = {
        'lm': 'prm',
        'lm_pr': 'lmpr',
        'ellm': 'ellm',
        'lm_ns': 'lmns',
        't2r': 't2r',
        'lm_re': 'lmre',
    }
    algo_name = algo_names.get(args.our, "Oracle")
    progress_path = "./progress_functions/" + env_name + function_id + ".py"

    if args.our in ('lm', 'lm_pr', 'ellm', 'lm_ns'):
        # Embed the plan steps once; the wrapper appends them to observations.
        text_encoder = Tokenizer(text_plan)
        embed_goal = text_encoder.embed_plans()
        # The policy needs the extractor that compresses the embedding part.
        policy_kwargs['features_extractor_class'] = CustomFeatureExtractor
        policy_kwargs['features_extractor_kwargs'] = dict(feature_dim=78)
        # Install the task-specific progress function on the wrapper class.
        ContinuousTaskWrapper.progress_function = _load_function_from_file(
            progress_path, 'progress_function'
        )
    elif args.our == 't2r':
        ContinuousTaskWrapper.t2r_reward_function = _load_function_from_file(
            "./t2r_reward_code/" + env_name + "-v2/specific.py", 'compute_dense_reward'
        )
    elif args.our == 'lm_re':
        ContinuousTaskWrapper.progress_function = _load_function_from_file(
            progress_path, 'progress_function'
        )

    full_env_id = env_name + "-v2-goal-observable"
    env_kwargs = dict(
        max_episode_steps=args.max_episode_steps,
        embed_goal=embed_goal,
        llm=args.our,
        gamma=args.gamma,
    )

    # set up eval environment
    eval_env = SubprocVecEnv([make_env(full_env_id, **env_kwargs) for _ in range(args.eval_num)])
    eval_env = VecMonitor(eval_env)
    eval_env.seed(args.seed)
    eval_env.reset()

    # set up training environment
    env = SubprocVecEnv([make_env(full_env_id, **env_kwargs) for _ in range(args.train_num)])
    env = VecMonitor(env)
    env.seed(args.seed)
    obs = env.reset()

    eval_log_dir = f"./logs/evaluation/{env_name}/{algo_name}-{training_id}/{args.seed}"
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=eval_log_dir,
        log_path=eval_log_dir,
        # eval_freq is counted in vectorised steps, hence the division.
        eval_freq=args.eval_freq // args.train_num,
        deterministic=True,
        render=False,
        n_eval_episodes=10
    )
    set_random_seed(args.seed)

    # NOTE(review): SAC's discount is fixed at 0.99 while --gamma only affects
    # the wrapper's reward shaping - confirm the two are meant to be independent.
    model = SAC(
        policy="MlpPolicy",
        env=env,
        verbose=1,
        seed=int(args.seed),
        tensorboard_log=f"./logs/training/{env_name}/{algo_name}-{training_id}",
        gamma=0.99,
        target_update_interval=2,
        learning_rate=0.0003,
        train_freq=1,
        tau=0.005,
        learning_starts=4000,
        batch_size=512,
        ent_coef='auto_0.1',
        policy_kwargs=policy_kwargs
    )

    model.learn(
        total_timesteps=args.train_max_steps,
        tb_log_name=f"{args.seed}",
        callback=[eval_callback]
    )
    # save model
    model.save(f"./saved_models/{algo_name}-{env_name}-{args.seed}")

