import os, argparse, yaml, torch, warnings

import numpy as np
import torch.nn as nn
import gymnasium as gym
import metaworld.envs.mujoco.env_dict as _env_dict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common import type_aliases
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from transformers import AutoTokenizer, AutoModel, LlamaForCausalLM
from rlkit.envs.wrappers import NormalizedBoxEnv
import warnings
warnings.filterwarnings("ignore", message=".*Box*")


def evaluate_policy(
    model: "type_aliases.PolicyPredictor",
    env: Union[gym.Env, VecEnv],
    n_eval_episodes: int = 10,
    deterministic: bool = True,
    render: bool = False,
    callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
    reward_threshold: Optional[float] = None,
    return_episode_rewards: bool = False,
    warn: bool = True,
) -> Union[Tuple[float, float, float], Tuple[List[float], List[int]]]:
    """Roll out ``model`` on ``env`` and report reward and success statistics.

    Variant of Stable-Baselines3's ``evaluate_policy`` that additionally collects
    ``info["is_success"]`` at the end of every finished episode.

    :param model: Object exposing ``predict(obs, state=..., episode_start=..., deterministic=...)``.
    :param env: Environment to evaluate; a non-vectorized env is wrapped in a ``DummyVecEnv``.
    :param n_eval_episodes: Total number of episodes, split as evenly as possible across sub-envs.
    :param deterministic: Forwarded to ``model.predict``.
    :param render: Call ``env.render()`` after every vectorized step.
    :param callback: Called as ``callback(locals(), globals())`` after each per-env step.
    :param reward_threshold: If given, assert that the mean reward exceeds it.
    :param return_episode_rewards: Return per-episode ``(rewards, lengths)`` instead of summary stats.
    :param warn: Warn when the env is not wrapped with ``Monitor``/``VecMonitor``.
    :return: ``(mean_reward, std_reward, mean_success_rate)``; ``mean_success_rate`` is NaN when no
        finished episode reported ``is_success``. With ``return_episode_rewards`` the lists
        ``(episode_rewards, episode_lengths)`` are returned instead.
    """
    # Avoid circular import
    from stable_baselines3.common.monitor import Monitor

    if not isinstance(env, VecEnv):
        # Bind the raw env to its own name so the factory does not depend on late binding of ``env``.
        raw_env = env
        env = DummyVecEnv([lambda: raw_env])  # type: ignore[list-item, return-value]

    is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]

    if not is_monitor_wrapped and warn:
        warnings.warn(
            "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
            "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
            "Consider wrapping environment first with ``Monitor`` wrapper.",
            UserWarning,
        )

    n_envs = env.num_envs
    episode_rewards = []
    episode_lengths = []
    episode_success = []

    episode_counts = np.zeros(n_envs, dtype="int")
    # Divides episodes among different sub environments in the vector as evenly as possible
    episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")

    current_rewards = np.zeros(n_envs)
    current_lengths = np.zeros(n_envs, dtype="int")
    observations = env.reset()
    states = None
    episode_starts = np.ones((env.num_envs,), dtype=bool)
    while (episode_counts < episode_count_targets).any():
        actions, states = model.predict(
            observations,  # type: ignore[arg-type]
            state=states,
            episode_start=episode_starts,
            deterministic=deterministic,
        )
        new_observations, rewards, dones, infos = env.step(actions)
        current_rewards += rewards
        current_lengths += 1
        for i in range(n_envs):
            if episode_counts[i] < episode_count_targets[i]:
                # unpack values so that the callback can access the local variables
                reward = rewards[i]
                done = dones[i]
                info = infos[i]
                episode_starts[i] = done

                if callback is not None:
                    callback(locals(), globals())

                if dones[i]:
                    if is_monitor_wrapped:
                        # Atari wrapper can send a "done" signal when
                        # the agent loses a life, but it does not correspond
                        # to the true end of episode
                        if "episode" in info.keys():
                            # Do not trust "done" with episode endings.
                            # Monitor wrapper includes "episode" key in info if environment
                            # has been wrapped with it. Use those rewards instead.
                            episode_rewards.append(info["episode"]["r"])
                            episode_lengths.append(info["episode"]["l"])
                            if "is_success" in info:
                                episode_success.append(info["is_success"])
                            # Only increment at the real end of an episode
                            episode_counts[i] += 1
                    else:
                        episode_rewards.append(current_rewards[i])
                        episode_lengths.append(current_lengths[i])
                        # Record success here too; previously only the Monitor branch did,
                        # which made the success rate NaN for unmonitored envs.
                        if "is_success" in info:
                            episode_success.append(info["is_success"])
                        episode_counts[i] += 1
                    current_rewards[i] = 0
                    current_lengths[i] = 0

        observations = new_observations

        if render:
            env.render()

    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    # Avoid np.mean([]) (RuntimeWarning + NaN) when no episode reported success info.
    mean_success_rate = np.mean(episode_success) if episode_success else float("nan")
    if reward_threshold is not None:
        assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
    if return_episode_rewards:
        return episode_rewards, episode_lengths
    return mean_reward, std_reward, mean_success_rate


class CustomFeatureExtractor(BaseFeaturesExtractor):
    """Compress the goal-embedding part of the observation and concatenate it with the raw state.

    Each observation row is split into its first ``state_dim`` columns, which are kept as-is, and
    the remaining columns. The remaining columns are assumed to be the ``embed_dim``-wide goal
    embedding appended by ``ContinuousTaskWrapper`` (confirm for other wrappers). The embedding
    passes through ``embed_dim -> 256 -> 128 -> (feature_dim - state_dim)`` linear layers, each
    followed by a ReLU, so the output width always equals ``feature_dim``.

    :param observation_space: Observation space forwarded to ``BaseFeaturesExtractor``.
    :param feature_dim: Width of the produced feature vector; must exceed ``state_dim``.
    :param state_dim: Number of leading raw-state columns (default 39).
    :param embed_dim: Width of the trailing goal embedding (default 768).
    :raises ValueError: if ``feature_dim <= state_dim``.
    """

    def __init__(self, observation_space, feature_dim, *, state_dim: int = 39, embed_dim: int = 768):
        super().__init__(observation_space, feature_dim)
        # Derive the projection width from feature_dim instead of hard-coding it, so the
        # declared features_dim always matches what forward() returns.
        proj_dim = feature_dim - state_dim
        if proj_dim <= 0:
            raise ValueError(
                f"feature_dim ({feature_dim}) must be larger than state_dim ({state_dim})"
            )
        self.state_dim = state_dim
        self.embed_extractor = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, proj_dim),
            nn.ReLU(),
        )

    def forward(self, observations):
        """Return ``[state | projected embedding]`` concatenated along dim 1."""
        ori_obs = observations[:, :self.state_dim]
        emb_obs = observations[:, self.state_dim:]
        extracted_embed = self.embed_extractor(emb_obs)
        return torch.cat((ori_obs, extracted_embed), dim=1)

class Tokenizer():
    """Encode text sub-goals with the unsupervised SimCSE BERT-base model.

    Despite the name, this holds both the HuggingFace tokenizer and the encoder model.

    :param text_plan: Iterable of sub-goal strings that :meth:`embed_plans` encodes.
    :param device: Torch device for the model. When ``None``, uses ``'cuda:0'`` if CUDA is
        available, otherwise ``'cpu'``.
    """

    def __init__(self, text_plan, device=None):
        self.text_plan = text_plan
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = device
        model_name = "princeton-nlp/unsup-simcse-bert-base-uncased"
        self.sts_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.sts_model = AutoModel.from_pretrained(model_name).to(self.device)
        # Inference only: make sure dropout is disabled.
        self.sts_model.eval()

    def embed_single(self, message):
        """Return the model's ``pooler_output`` tensor for ``message`` (on ``self.device``)."""
        inputs = self.sts_tokenizer(message, padding=True, truncation=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            # Only pooler_output is used, so intermediate hidden states are not requested.
            embeddings = self.sts_model(**inputs, return_dict=True).pooler_output
        return embeddings

    def embed_plans(self):
        """Embed every entry of ``self.text_plan``; returns a list of NumPy arrays, one per entry."""
        return [self.embed_single(msg).detach().cpu().numpy() for msg in self.text_plan]

class ContinuousTaskWrapper(gym.Wrapper):
    """Wrapper adding a step limit, success termination and sub-goal conditioning.

    Behaviour depends on the ``llm`` mode string:
      * ``'lm'``    - reward from ``progress_function``; the current sub-goal embedding is appended.
      * ``'ellm'``  - reward 5.0 when the sub-goal index increases (else 0.0); embedding appended.
      * ``'lm_pr'`` - environment reward kept; embedding appended.
      * ``'lmre'``  - reward from ``progress_function``; observation unchanged.
      * ``'t2r'``   - reward from ``t2r_reward_function``.
      * any other value - environment reward and observation unchanged.

    Episodes end after ``max_episode_steps`` steps or when ``info["success"]`` is truthy.
    ``progress_function`` and ``t2r_reward_function`` are placeholders that must be replaced
    (e.g. assigned on the class) before stepping in a mode that uses them.
    """

    # Modes whose observations carry the current sub-goal embedding.
    _EMBED_MODES = ('lm', 'lm_pr', 'ellm')
    # Modes that receive an extra +50 reward on task success.
    _SUCCESS_BONUS_MODES = ('lm', 'ellm', 'lmre')
    # Width of one sub-goal embedding block in the observation space.
    _EMBED_DIM = 768

    def __init__(self, env, max_episode_steps, embed_goal, llm):
        super().__init__(env)
        self._elapsed_steps = 0
        self.pre_obs = None
        self._max_episode_steps = max_episode_steps

        self.embed_goal = embed_goal
        self.llm = llm
        # Must use the same mode set as reset()/step(); previously 'lm_pr' appended the
        # embedding to observations without widening the declared space.
        if self.llm in self._EMBED_MODES:
            self.embedding_low = np.full(self._EMBED_DIM, -np.inf, dtype=np.float32)
            self.embedding_high = np.full(self._EMBED_DIM, +np.inf, dtype=np.float32)
            self.observation_space = gym.spaces.Box(
                np.hstack((self.observation_space.low, self.embedding_low)),
                np.hstack((self.observation_space.high, self.embedding_high)),
            )

    def _append_goal(self, ob, idx):
        """Concatenate the embedding of sub-goal ``idx`` (leading axis squeezed) onto ``ob``."""
        return np.concatenate((ob, self.embed_goal[idx].squeeze(0)), axis=0)

    def reset(self, *, seed: Union[int, None] = None, options: Union[dict[str, Any], None] = None):
        """Reset the wrapped env, forwarding ``seed``/``options``, and restart the step counter."""
        self._elapsed_steps = 0
        # Forward seed/options so VecEnv.seed(...) actually reaches the underlying env.
        self.pre_obs, reset_info = super().reset(seed=seed, options=options)
        if self.llm in self._EMBED_MODES:
            # Every episode starts conditioned on the first sub-goal.
            self.pre_obs = self._append_goal(self.pre_obs, 0)
        self.pre_idx = 0
        return self.pre_obs, reset_info

    def progress_function(self, obs, action):
        """Placeholder; must be replaced by a function returning ``(reward, sub_goal_index)``."""
        raise NotImplementedError("progress_function must be assigned before it is used")

    def t2r_reward_function(self, action, obs):
        """Placeholder; must be replaced by a function returning a scalar reward."""
        raise NotImplementedError("t2r_reward_function must be assigned before it is used")

    def step(self, action):
        """Step the env, then apply mode-specific reward/observation, the step limit and success."""
        ob, rew, done, truncated, info = super().step(action)
        if self.llm == 'lm':
            rew, idx = self.progress_function(ob, action)
            ob = self._append_goal(ob, idx)
        elif self.llm == 'ellm':
            _, idx = self.progress_function(ob, action)
            ob = self._append_goal(ob, idx)
            # Sparse reward only when the sub-goal index moves past the previous one.
            rew = 5.0 if idx > self.pre_idx else 0.0
            self.pre_idx = idx
        elif self.llm == 'lm_pr':
            _, idx = self.progress_function(ob, action)
            ob = self._append_goal(ob, idx)
        elif self.llm == 'lmre':
            rew, _ = self.progress_function(ob, action)
        elif self.llm == 't2r':
            rew = self.t2r_reward_function(action, ob)

        self._elapsed_steps += 1
        # The wrapped env's own `done` is replaced by this step-limit check.
        done = bool(self._elapsed_steps >= self._max_episode_steps)
        info["TimeLimit.truncated"] = done

        info["is_success"] = info["success"]
        if info["success"]:
            done = True
            # 'lmre' was previously misspelled 'lm_re' here and never got the bonus.
            if self.llm in self._SUCCESS_BONUS_MODES:
                rew = rew + 50

        return ob, rew, done, truncated, info
def plan_generator(env_id):
    """Return the natural-language sub-goal plan for the Meta-World task ``env_id``.

    Tasks in the listed two-stage set get a two-step plan (reach the object, then the goal);
    every other task gets the four-step reach/grasp/move/complete plan. A new list is built on
    every call, so callers may mutate the result freely.
    """
    two_stage_tasks = (
        'button-press', 'door-close', 'drawer-close', 'window-open',
        'faucet-open', 'handle-press', 'coffee-button', 'button-press-wall',
    )
    if env_id not in two_stage_tasks:
        return [
            "Move to object",
            "Grasp object",
            "Move to goal",
            "Task Complete",
        ]
    return ["Move to the object", "Move to the goal"]

def make_env(env_id, max_episode_steps: Optional[int] = None, record_dir: Optional[str] = None,
             embed_goal: Optional[list] = None, llm: Optional[str] = None):
    """Return a zero-argument factory that builds one wrapped Meta-World env (e.g. for ``SubprocVecEnv``).

    :param env_id: Key into ``ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE``, e.g. ``"button-press-v2-goal-observable"``.
    :param max_episode_steps: Step limit for ``ContinuousTaskWrapper``; ``None`` (default) uses the
        env's ``max_path_length``. Values above ``max_path_length`` may be rejected by the env itself.
    :param record_dir: Accepted for call-site compatibility; currently unused.
    :param embed_goal: Per-sub-goal embeddings forwarded to ``ContinuousTaskWrapper``.
    :param llm: Reward/observation mode string forwarded to ``ContinuousTaskWrapper``.
    """
    def _init() -> "gym.Env":
        env_cls = _env_dict.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id]
        env = env_cls()
        # NOTE(review): private Meta-World flags. Presumably these allow re-randomising the task on
        # reset and skip the "set_task not called" guard; confirm against the Meta-World version.
        env._freeze_rand_vec = False
        env._set_task_called = True
        env = NormalizedBoxEnv(env)
        # Honour an explicit limit; otherwise fall back to the benchmark's path length as before.
        step_limit = env.max_path_length if max_episode_steps is None else max_episode_steps
        env = ContinuousTaskWrapper(env, step_limit, embed_goal, llm)
        return env

    return _init



def load_function_from_file(path, func_name):
    """Execute the Python source at ``path`` and return the object it binds to ``func_name``.

    The code runs in a copy of this module's globals, so it can use names such as ``np`` and
    ``torch`` without importing them. ``exec`` runs arbitrary code: only use this on trusted,
    locally generated files.

    :raises KeyError: if executing the file does not define ``func_name``.
    """
    with open(path, "r", encoding="utf-8") as f:
        code_str = f.read()
    namespace = {**globals()}
    exec(code_str, namespace)
    try:
        return namespace[func_name]
    except KeyError:
        raise KeyError(f"{path} does not define {func_name!r}") from None


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_id', type=str, default=None)
    parser.add_argument('--eval_num', type=int, default=100)
    args = parser.parse_args()
    train_env_name = args.env_id
    function_id = "-v2"
    seeds = [1024, 1234, 314159, 42, 2986]
    algos = ['lm', 'ellm', 'lmre', 't2r', 'sac']
    # Evaluation tasks per training task: the training task itself followed by other tasks.
    eval_suites = {
        'faucet-open': ['faucet-open', 'button-press', 'drawer-close', 'button-press-wall', 'coffee-button'],
        'pick-place': ['pick-place', 'drawer-close', 'sweep-into', 'handle-press', 'pick-place-wall'],
    }
    # Previously any other env_id silently evaluated nothing.
    if train_env_name not in eval_suites:
        parser.error(f"--env_id must be one of {sorted(eval_suites)}, got {train_env_name!r}")
    eval_env_names = eval_suites[train_env_name]

    # Loading the SimCSE model is expensive; build it once and reuse it for every plan.
    text_encoder = None
    for algo in algos:
        for eval_env_name in eval_env_names:
            print(f"eval {algo} {train_env_name} model in {eval_env_name}\n")
            text_plan = plan_generator(eval_env_name)
            embed_goal = []
            progress_path = "./progress_functions/" + eval_env_name + function_id + ".py"
            if algo in ['lm', 'ellm']:
                if text_encoder is None:
                    text_encoder = Tokenizer(text_plan)
                text_encoder.text_plan = text_plan
                embed_goal = text_encoder.embed_plans()
                ContinuousTaskWrapper.progress_function = load_function_from_file(progress_path, 'progress_function')
            elif algo == 't2r':
                # NOTE(review): always loads the button-press reward code, whatever eval_env_name is; confirm intended.
                t2r_path = "./t2r_reward_code/" + 'button-press' + "-v2/specific.py"
                ContinuousTaskWrapper.t2r_reward_function = load_function_from_file(t2r_path, 'compute_dense_reward')
            elif algo == 'lmre':
                ContinuousTaskWrapper.progress_function = load_function_from_file(progress_path, 'progress_function')

            success_rates = []
            eval_env = VecMonitor(SubprocVecEnv([
                make_env(eval_env_name + "-v2-goal-observable", record_dir="logs/videos",
                         embed_goal=embed_goal, llm=algo)
                for _ in range(args.eval_num)
            ]))
            try:
                for seed in seeds:
                    eval_env.seed(seed)
                    eval_env.reset()

                    model = SAC.load(
                        path=f"./saved_models/{algo}-{train_env_name}-{seed}",
                        env=eval_env,
                    )

                    mean_reward, std_reward, success_rate = evaluate_policy(model, eval_env, n_eval_episodes=100)
                    success_rates.append(success_rate)

                    print(f"{seed}, {mean_reward}+{std_reward}, {success_rate}")
            finally:
                # Always shut down the worker processes, even if loading/evaluation fails.
                eval_env.close()
            print(f"final success rate {np.mean(success_rates)} + {np.std(success_rates)}")


