import os
import os.path as osp
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import deque
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler

from large_rl.policy.agent import Agent, LARGE_NEG
from large_rl.commons.launcher import launch_embedding, launch_env, launch_agent
from large_rl.commons.utils import logging, mean_dict, save_mp4, VideoFrameBuffer, scale_number
from large_rl.commons.seeds import set_randomSeed
from large_rl.commons.replay_buffer import ReplayBufferRecSim
from large_rl.commons.scheduler import AnnealingSchedule
from large_rl.commons.args import get_all_args


# Agent types (compared lower-cased in `run`) for which the pairwise-distance
# diversity bonus can be switched on during training.
WOLP_ARDDPG = [
    "wolp",
    "wolp-sac",
    "arddpg_cont",
]


def _compute_pairwise_dist(action, dict_embedding, args):
    if args["WOLP_cascade_list_len"] > 1:
        _key = "query"
    else:
        _key = "topk_act"  # dual-critic only has one "query" so we can't compute the pairwise distance
    _num_envs = action[_key].shape[0]
    if _key == "topk_act":
        topk_act = action[_key].reshape(_num_envs, args["WOLP_cascade_list_len"] * args["WOLP_topK"])
        if args['env_name'] != 'reacher':
            emb = dict_embedding["item"].get(index=topk_act, if_np=True)
    else:
        emb = action[_key]
    _dist_list = np.zeros(_num_envs)
    for _i in range(_num_envs):
        # ====== EC =========
        _d = euclidean_distances(X=emb[_i, ...])  # list-len x list-len
        _d = _d.sum()  # Aggregate; get the total distributedness of candidate-actions or queries
        _d /= 2.0  # because it's the symmetric matrix
        _d = scale_number(x=_d, to_min=0.0, to_max=1.0, from_min=0.0, from_max=args["_max_l2_dist"])  # H: spec, L: div
        _dist_list[_i] = _d * args["WOLP_pairwise_distance_bonus_coef"]
    return _dist_list


def _sample_and_update(args, agent, replay_buffer, dict_embedding, if_special_buffer, if_sel, if_ret):
    """Sample one batch from ``replay_buffer`` and apply a single ``agent.update`` on it.

    ``if_special_buffer`` is forwarded as the second positional argument of
    ``replay_buffer.sample``; ``if_sel`` / ``if_ret`` are forwarded to ``agent.update``.
    When the pairwise-distance bonus is off, a trailing axis is added to the sampled
    reward (``r[:, None]``) before the update.
    """
    o, a, r, no, d = replay_buffer.sample(args["batch_size"], if_special_buffer)
    o = torch.tensor(o, device=args["device"])
    no = torch.tensor(no, device=args["device"])
    if not args["WOLP_if_pairwise_distance_bonus"]:
        r = r[:, None]
    agent.update(o, a, r, no, d, dict_embedding["item"], if_sel=if_sel, if_ret=if_ret)


def run(args):
    """Build env(s), embeddings, agent and replay buffer from ``args`` and train.

    ``args`` is the configuration dict (``vars`` of the parsed CLI arguments). It is
    mutated in place. Derived constants (``video_saving_dir``, ``_max_l2_dist``,
    ``env_max_action``, ...) and run-wide state (``global_ts``, the ``train_*``
    episode statistics, ``if_visualise_now``, ...) are stored in it and read back by
    the helper functions.

    For non-random agents the buffer is first warm-started with random interactions.
    Every epoch then evaluates (every ``eval_freq`` epochs and at epoch 0), rolls out
    ``per_train_ts`` vectorised steps, and runs the configured number of updates.

    Returns:
        None.
    """
    print(args)

    if args["agent_save"] and not os.path.exists(args["agent_save_path"]):
        os.makedirs(args["agent_save_path"])

    # === START: Prep before training
    env_name = args["env_name"]
    if env_name == 'mine':
        if args["save_dir"] != "":
            args['video_saving_dir'] = os.path.join(args['mw_video_dir'], args["save_dir"])
        else:
            args['video_saving_dir'] = os.path.join(args['mw_video_dir'], 'seed{:05d}debug2022'.format(args["seed"]))
    elif env_name.startswith("mujoco"):
        if args["save_dir"] != "":
            args['video_saving_dir'] = os.path.join(args['reacher_video_dir'], args["save_dir"])
        else:
            args['video_saving_dir'] = os.path.join(args['reacher_video_dir'],
                                                    'seed{:05d}debug2022'.format(args["seed"]))

    # `_max_l2_dist` normalises the pairwise-distance bonus: the norm of a d-vector of
    # 2s, i.e. the diameter of a [-1, 1]^d box (presumably the action/embedding range).
    if env_name.lower().startswith("recsim"):
        args["_max_l2_dist"] = np.linalg.norm(np.ones(args["recsim_dim_embed"]) * 2)
    env = launch_env(args=args)
    eval_env = launch_env(args=args)
    args["env_max_action"] = 1.
    if env_name == "mine" and args['mw_show_action_embeddings']:
        env.show_action_embeddings()
    if env_name.startswith("mujoco"):
        args["_max_l2_dist"] = np.linalg.norm(np.ones(args["reacher_action_shape"]) * 2)
        args["env_max_action"] = float(env.action_space[0].high[0])
    if env_name == "mine":
        args["_max_l2_dist"] = np.linalg.norm(np.ones(args["mw_action_dim"]) * 2)
    # TD3 noise hyper-parameters are given relative to the action bound.
    args["TD3_policy_noise"] *= args["env_max_action"]
    args["TD3_noise_clip"] *= args["env_max_action"]
    logging("======== Env IS READY ========")

    # Expand noisy dimensions
    if args["env_dim_extra"] > 0 and args["env_name"] != "recsim-data":
        _emb = np.random.random(size=(env.act_embedding.shape[0], args["env_dim_extra"]))
        __emb = MinMaxScaler(feature_range=(-0.01, 0.01)).fit_transform(env.act_embedding)
        _emb = np.hstack([__emb, _emb])
        if args["env_act_emb_tSNE"]:
            logging("======== START: t-SNE on Act Emb ========")
            _emb = TSNE(n_components=__emb.shape[-1],
                        perplexity=3,
                        random_state=0,
                        method="exact",
                        n_iter=1000,
                        n_jobs=-1).fit_transform(_emb)
            logging("======== DONE: t-SNE on Act Emb ========")
    else:
        _emb = env.act_embedding

    if args["recsim_if_tsne_embed"]:  # for Dual-tSNE
        logging("======== START: t-SNE on Act Emb ========")
        _emb = TSNE(n_components=_emb.shape[-1],
                    perplexity=3,
                    random_state=0,
                    method="exact",
                    n_iter=1000,
                    n_jobs=-1).fit_transform(_emb)
        logging("======== DONE: t-SNE on Act Emb ========")

    if env_name in ["recsim", "mine", "recsim-data"]:
        dict_embedding = launch_embedding(args=args)
        dict_embedding["item"].load(embedding=_emb)
        dict_embedding["task"].load(embedding=env.task_embedding)
    elif env_name.startswith("mujoco"):
        dict_embedding = {'item': None, 'task': None}
    else:
        raise NotImplementedError
    agent = launch_agent(args=args, env=env)
    if args['agent_load']:
        agent.load_model(epoch=args['agent_load_epoch'])

    _sch_cr = AnnealingSchedule(
        start=args["epsilon_start_cr"], end=args["epsilon_end_cr"], decay_steps=args["decay_steps_cr"])
    _sch_act = AnnealingSchedule(
        start=args["epsilon_start_act"], end=args["epsilon_end_act"], decay_steps=args["decay_steps_act"],
        _delay=args["delayed_actor_training"])
    scheduler_dict = {"critic": _sch_cr, "actor": _sch_act}
    replay_buffer = ReplayBufferRecSim(args=args)
    # === END: Prep before training

    # === Run-wide state shared with _train / _eval through `args`
    args["global_ts"], epoch = 0, 0
    args["train_ep_return"] = deque(maxlen=10)
    args["train_ep_dist"] = deque(maxlen=10)
    args["train_ep_sr"] = deque(maxlen=10)
    args["train_ep_mining_return"] = deque(maxlen=10)
    args["train_return_vec"] = np.zeros(args["num_envs"])
    args["train_dist_vec"] = np.zeros(args["num_envs"])
    args["train_if_dist_bonus"] = args["WOLP_if_pairwise_distance_bonus"] and args["agent_type"].lower() in WOLP_ARDDPG
    args["eval_if_dist_bonus"] = args["WOLP_if_pairwise_distance_bonus"] or (
            args["if_visualise_agent"] and args["agent_type"].lower() in WOLP_ARDDPG and (
            args["WOLP_cascade_list_len"] * args["WOLP_topK"]) > 1)
    if args["per_train_ts"] is not None:
        args["num_epochs"] = (args["total_ts"] // args["per_train_ts"]) // args["num_envs"]
    else:
        args["per_train_ts"] = (args["total_ts"] // args["num_epochs"]) // args["num_envs"]
    obs = None

    # ==== Warm start with random interactions ====
    if args["agent_type"] != "random":
        min_replay_buffer_size = min(args["buffer_size"] // 10, args["min_replay_buffer_size"])
        replay_buffer = _fill_replay_buffer(args=args,
                                            env=env,
                                            agent=agent,
                                            min_replay_buffer_size=min_replay_buffer_size,
                                            replay_buffer=replay_buffer,
                                            dict_embedding=dict_embedding)
        logging(f"Filled the replay buffer: {len(replay_buffer)} / {args['buffer_size']}")
        args["global_ts"] += len(replay_buffer)

    while args["global_ts"] < args["total_ts"]:
        # For logging purpose including random agent
        _step = args["global_ts"] if args["agent_type"] != "random" else args["per_train_ts"] * epoch * args["num_envs"]

        # === Evaluation (also once at the beginning of the training) ===
        if ((epoch + 1) % args["eval_freq"] == 0) or (epoch == 0):
            args["if_visualise_now"] = (((epoch + 1) % args["visualise_agent_freq_epoch"]) == 0 or (epoch == 0)) \
                                       and args["if_visualise_agent"]
            eval_video_save = False
            if args['env_name'].startswith("mujoco") and args['reacher_save_video'] and \
                    ((epoch + 1) % args['video_save_frequency'] == 0):
                eval_video_save = True
            if args['env_name'] == "mine" and args['mw_test_save_video'] and \
                    ((epoch + 1) % args['video_save_frequency'] == 0 or epoch == 0):
                eval_video_save = True

            eval_metrics = _eval(agent=agent, env=eval_env, dict_embedding=dict_embedding,
                                 eval_video_save=eval_video_save, timestep=args["global_ts"], args=args)

            # Drop the (unused) image lists so they are not dumped into the log line.
            eval_metrics.pop("img_list", None)
            eval_metrics.pop("act_embed_img_path_list", None)
            logging(f"[Eval  ] {_step}/{args['total_ts']} | {eval_metrics}")

            if args["agent_type"] == "random":
                args["global_ts"] = _step

        if args["agent_type"] != "random":
            # === Roll-out ===
            train_video_save = False
            if args['env_name'] == "mine" and args['mw_train_save_video']:
                if (epoch + 1) % args["video_save_frequency"] == 0:
                    train_video_save = True
            train_metrics, args["global_ts"], replay_buffer, obs = _train(
                args=args, env=env, agent=agent, scheduler_dict=scheduler_dict, dict_embedding=dict_embedding,
                replay_buffer=replay_buffer, obs=obs, train_video_save=train_video_save)
            logging(f"[Train ] {args['global_ts']}/{args['total_ts']} | {train_metrics}")

            # === Update ===
            results = list()
            num_update_steps = args['per_train_ts'] * args['num_envs'] if args['if_train_every_ts'] \
                else args['num_updates']
            for _i in range(num_update_steps):
                if args["WOLP_dual_exp_if_ignore"] and args["agent_type"].startswith("wolp"):
                    _sample_and_update(args, agent, replay_buffer, dict_embedding,
                                       if_special_buffer=False, if_sel=True, if_ret=False)
                    if len(replay_buffer._storage_ret) > args["batch_size"]:
                        _sample_and_update(args, agent, replay_buffer, dict_embedding,
                                           if_special_buffer=True, if_sel=False, if_ret=True)
                elif args["WOLP_separate_update"]:
                    # Two independently sampled batches: one per update target.
                    _sample_and_update(args, agent, replay_buffer, dict_embedding,
                                       if_special_buffer=False, if_sel=True, if_ret=False)
                    _sample_and_update(args, agent, replay_buffer, dict_embedding,
                                       if_special_buffer=False, if_sel=False, if_ret=True)
                else:  # Previous update pipeline
                    _sample_and_update(args, agent, replay_buffer, dict_embedding,
                                       if_special_buffer=False, if_sel=True, if_ret=True)

                # Target sync: per update step when syncing every update with several
                # steps, otherwise per epoch (and always at epoch 0).
                if args["sync_every_update"] and num_update_steps > 1:
                    if_sync = _i % args["sync_freq"] == 0
                else:
                    if_sync = (epoch == 0) or (epoch + 1) % args["sync_freq"] == 0
                if if_sync:
                    agent.sync(tau=float(args["soft_update_tau"]))
                results.append(agent.res)

            # === Visualise the results of updating ===
            results = mean_dict(_list_dict=results)
            logging(f"[Update] {args['global_ts']}/{args['total_ts']} | Avg. Loss: {results['loss']} ")
        epoch += 1

        # === Save the agent model ===
        if args['agent_save'] and (epoch + 1) % args['agent_save_frequency'] == 0:
            print("saving agent model")
            agent.save_model(epoch=epoch)
    return None


def _modify_reward(agent, reward, action, dict_embedding):
    _dist_list = None
    # ============ Reward modification
    # Diversity bonus
    if agent._args["train_if_dist_bonus"]:
        _dist_list = _compute_pairwise_dist(action=action, dict_embedding=dict_embedding, args=args)
        reward = np.concatenate([reward[:, None], _dist_list[:, None]], axis=-1)
    else:
        reward = reward
    return reward, _dist_list


def _train(args, env, agent: "Agent", scheduler_dict: dict, dict_embedding: dict, replay_buffer, obs, train_video_save,
           **kwargs):
    """Collect ``args["per_train_ts"]`` vectorised env steps into ``replay_buffer``.

    The rollout continues from ``obs`` (the last observation of the previous call).
    The env is reset only when ``obs`` is ``None``. Each step advances
    ``args["global_ts"]`` by ``args["num_envs"]``. Finished episodes push their
    statistics into the rolling deques stored in ``args`` (``train_ep_return``, ...).

    Args:
        args: run configuration / state dict (mutated in place).
        env: vectorised training env.
        agent: the agent acting with the scheduled exploration epsilons.
        scheduler_dict: ``{"critic": schedule, "actor": schedule}`` giving epsilons
            as a function of ``args["global_ts"]``.
        dict_embedding: embedding stores; ``dict_embedding["item"]`` is passed to
            ``agent.select_action``.
        replay_buffer: buffer receiving the transitions.
        obs: observation to continue from, or ``None`` to reset the env.
        train_video_save: record the rollout to an mp4 (mining world only).
        **kwargs: ignored.

    Returns:
        ``(res, global_ts, replay_buffer, obs)``. ``res`` is a dict of training
        metrics and ``obs`` is the last observation, to be passed to the next call.
    """
    if obs is None:
        obs = env.reset()
    agent.reset()
    if args['env_name'] == "mine" or args["env_name"].startswith("mujoco"):
        env.train()
    agent.train()
    if args["env_name"] == 'mine':
        env.disable_render_info()

    # Video recording is only implemented for the mining world. Every use of
    # `frame_buffer` is guarded by this one flag so other envs never touch it.
    if_record = train_video_save and args["env_name"] == 'mine'
    if if_record:
        frame_buffer = VideoFrameBuffer(args=args, action_meaning_dict=env.true_action_meaning_dict())
        env.enable_render_info()
    for _ in range(args["per_train_ts"]):
        if not torch.is_tensor(obs):
            obs = torch.tensor(obs.astype(np.float32), device=args["device"])
        eps_cr = scheduler_dict["critic"].get_value(ts=args["global_ts"])
        eps_act = scheduler_dict["actor"].get_value(ts=args["global_ts"])
        epsilon = {"critic": eps_cr, "actor": eps_act}
        action = agent.select_action(obs=obs, act_embed_base=dict_embedding["item"], epsilon=epsilon)
        next_obs, reward, done, info = env.step(action["action"])
        if args["env_name"] == 'mine':
            replay_buffer._storage_mine_agent_pos.append(np.asarray(info[0]["agent_pos"]))  # temp

        # === After one batch time-step ===
        args["global_ts"] += args["num_envs"]
        if if_record:
            frame_buffer.append(env.render_frame(mode="rgb_array", info=info), list_action=action["topk_act"])

        _r, _dist_list = _modify_reward(agent=agent, reward=reward, action=action, dict_embedding=dict_embedding)
        replay_buffer = _add_buffer(args, obs, action, replay_buffer, _r, next_obs, done)

        # === Before the next time-step ===
        obs = next_obs
        args["train_return_vec"] += reward
        if args["train_if_dist_bonus"]:
            args["train_dist_vec"] += _dist_list
        if any(done):
            for _id in np.arange(args["num_envs"])[done]:  # indices of the finished envs
                agent.reset(id=_id)
                args["train_ep_return"].append(args["train_return_vec"][_id])
                args["train_return_vec"][_id] = 0.0  # empty the bucket
                # Same flag as the accumulation above, so only a filled bucket is logged.
                if args["train_if_dist_bonus"]:
                    args["train_ep_dist"].append(args["train_dist_vec"][_id])
                    args["train_dist_vec"][_id] = 0.0  # empty the bucket
                if args["env_name"].lower() == "mine":
                    args["train_ep_sr"].append(float(info[_id]["ep_success"]))
                    args["train_ep_mining_return"].append(float(info[_id]["ep_mining_reward"]))

    res = {
        "ep_return": np.mean(args["train_ep_return"]),
        "epsilon": epsilon,
        "buffer_size": len(replay_buffer)
    }
    if args["WOLP_dual_exp_if_ignore"]:
        res["buffer_ret_size"] = len(replay_buffer._storage_ret)
    if if_record:
        if not frame_buffer.empty:
            save_name = '%s_%s' % (str(args["global_ts"]), 'train')
            save_dir = os.path.join(args['video_saving_dir'], 'train')
            frames = frame_buffer.gen_video()
            save_mp4(frames, save_dir, save_name, fps=args['vid_fps'], no_frame_drop=True)
            saved_file_name = '%s.mp4' % osp.join(save_dir, save_name)
            print('Rendered frames to %s' % saved_file_name)
            res["path_video"] = saved_file_name
        env.disable_render_info()

    if args["train_if_dist_bonus"]:
        res["ep_dist"] = np.mean(args["train_ep_dist"])
    if args["env_name"].lower() == "mine":
        res["success_rate"] = np.asarray(args["train_ep_sr"]).mean()
        res["ep_mining_reward"] = np.asarray(args["train_ep_mining_return"]).mean()
    return res, args["global_ts"], replay_buffer, obs


def _eval(args, env, agent: Agent, dict_embedding: dict, eval_video_save: bool, timestep: int):
    agent.eval()
    agent.reset()
    if args['env_name'] == "mine" or args["env_name"].startswith("mujoco"):
        env.eval()
    # Get the random user id for action visualisation
    _env_id = np.random.randint(low=0, high=args["num_envs"])
    act_embed_img_path_list = list()
    ep_return, ep_ts, ep_ctr, img_list, ep_mining_return = list(), list(), list(), list(), list()
    ep_dist, ep_can_multi_opt_bonus = list(), list()
    ep_candidate_Q_mean, total_0th_not_max_index, ep_max_Q_improvement, ep_max_Q_improvement_percent = list(), list(), list(), list()
    ep_if_valid = list()
    if args['env_name'] == 'mine':
        frame_buffer = VideoFrameBuffer(args=args, action_meaning_dict=env.true_action_meaning_dict())
    else:
        frame_buffer = list()
    eval_loop = args["eval_num_episodes"] // args["num_envs"]
    if args['do_naive_eval']:
        ep_return_naive = list()
        for ep in range(eval_loop):
            obs = env.reset()
            done_env_mask = np.asarray([False] * args["num_envs"])
            _ep_return_naive = np.zeros(args["num_envs"])
            while not all(done_env_mask):
                if not torch.is_tensor(obs):
                    obs = torch.tensor(obs.astype(np.float32), device=args["device"])
                action = agent.select_action(obs=obs, act_embed_base=dict_embedding["item"],
                                         epsilon={"actor": args["eval_epsilon_ac"], "critic": args["eval_epsilon_cr"]},
                                         naive_eval=True)
                next_obs, reward, _done, info = env.step(action["action"])
                reward[done_env_mask] = 0.0
                done_env_mask[_done] = True
                next_obs[done_env_mask] *= 0.0
                _ep_return_naive += reward
                obs = next_obs
            ep_return_naive.append(_ep_return_naive)

    for ep in range(eval_loop):
        obs = env.reset()
        info = None
        if_record = eval_video_save and ep == 0
        done_env_mask = np.asarray([False] * args["num_envs"])
        sr_mask = np.asarray([False] * args["num_envs"])
        if if_record and args["env_name"] == 'mine':
            if eval_video_save:
                env.enable_render_info()
                # frame_buffer.append(env.render_frame(mode="rgb_array", info=info))
            else:
                env.disable_render_info()
        elif args["env_name"].lower().startswith("mujoco"):
            _ep_if_valid = np.zeros(args["num_envs"])
        _ep_return, _ep_ts, _ts = np.zeros(args["num_envs"]), np.zeros(args["num_envs"]), 0
        _ep_candidate_Q_mean = np.zeros(args["num_envs"])
        _0th_not_max_index = np.zeros(args["num_envs"])
        _ep_max_Q_improvement = np.zeros(args["num_envs"])
        _ep_max_Q_improvement_percent = np.zeros(args["num_envs"])
        _ep_dist = np.zeros(args["num_envs"])
        _ep_can_multi_opt_bonus = np.zeros(args["num_envs"])
        ep_sr = 0
        ep_mining_reward = 0
        while not all(done_env_mask):
            if if_record:
                if not args['env_name'] == 'mine':
                    frame_buffer.append(env.render(mode="rgb_array"))
            if not torch.is_tensor(obs):
                obs = torch.tensor(obs.astype(np.float32), device=args["device"])
            action = agent.select_action(obs=obs, act_embed_base=dict_embedding["item"],
                                         epsilon={"actor": args["eval_epsilon_ac"], "critic": args["eval_epsilon_cr"]})
            next_obs, reward, _done, info = env.step(action["action"])

            if args["eval_if_dist_bonus"]:
                _dist_list = _compute_pairwise_dist(action=action, dict_embedding=dict_embedding, args=args)

            if args["agent_type"].startswith("wolp") and args["if_visualise_agent"] and args["env_dim_extra"] == 0:
                topk_act = action["topk_act"].reshape((obs.shape[0], args["WOLP_cascade_list_len"] * args["WOLP_topK"]))

                if args["env_name"].lower().startswith("recsim"):
                    topk_act = dict_embedding["item"].get(index=topk_act, if_np=True)
                can_multi_opt_bonus = env.check_candidate_set(candidate_set=topk_act)
                if args["env_name"].lower().startswith("recsim"):
                    can_multi_opt_bonus[reward != 1] = 0.0  # Only consider successful occurrence of multiple solutions
                _ep_can_multi_opt_bonus += can_multi_opt_bonus

            if args["env_name"].lower() == "mine":
                if any(_done):
                    for i in range(args["num_envs"]):
                        if _done[i] and not done_env_mask[i]:
                            ep_sr += float(info[i]["ep_success"])
                            ep_mining_reward += float(info[i]["ep_mining_reward"])
                        if _done[i] and info[i]["ep_success"]:
                            sr_mask[i] = True

                if if_record:
                    frame_buffer.append(frames=env.render_frame(mode="rgb_array", info=info),
                                        list_action=action["topk_act"],
                                        done_mask=done_env_mask)
            elif args["env_name"].lower().startswith("mujoco"):
                # Calculate the number of if_valid actions
                if "if_valid" in info[0].keys():
                    _ep_if_valid += np.asarray([info[i]["if_valid"] for i in range(args["num_envs"])])

            if any(_done):
                ids = np.arange(args["num_envs"])[_done]  # get the index of dead env
                for _id in ids:
                    agent.reset(id=_id)

            # Universal logic to conduct the eval w/h vector env
            reward[done_env_mask] = 0.0
            done_env_mask[_done] = True
            next_obs[done_env_mask] *= 0.0

            """ === After One batch time-step === """
            _ep_return += reward
            _0th_not_max_index += (action["max_index"] != 0).astype(np.float32)
            _ep_max_Q_improvement += action["max_Q_improvement"]
            _ep_max_Q_improvement_percent += action["max_Q_improvement_percent"]
            _ep_ts += (~done_env_mask).astype(np.float32)
            _ts += 1
            if args["agent_type"].startswith("wolp") or args["agent_type"].startswith('arddpg_cont'):
                _ep_candidate_Q_mean += agent._candidate_Q_mean
            if args["eval_if_dist_bonus"]: _ep_dist += _dist_list

            obs = next_obs  # Update obs for next step

        if args['env_name'] == 'mine':
            if if_record:
                frame_buffer.append_dark()

        ep_return.append(np.mean(_ep_return))
        ep_ts.append(np.mean(_ep_ts))
        # ep_0th_not_max_index.append(np.mean(_ep_0th_not_max_index))
        total_0th_not_max_index.append(np.mean(_0th_not_max_index / _ts))
        ep_candidate_Q_mean.append(np.mean(_ep_candidate_Q_mean))
        ep_max_Q_improvement.append(np.mean(_ep_max_Q_improvement))
        ep_max_Q_improvement_percent.append(np.mean(_ep_max_Q_improvement_percent / _ts))
        if args["eval_if_dist_bonus"]: ep_dist.append(np.mean(_ep_dist))
        if args['env_name'].lower() == "mine":
            ep_ctr.append(np.divide(ep_sr, args["num_envs"]))
            ep_mining_return.append(np.divide(ep_mining_reward, args["num_envs"]))
        elif args['env_name'].lower().startswith("recsim"):
            ep_ctr.append(np.mean(np.divide(_ep_return, _ep_ts, out=np.zeros_like(_ep_return), where=_ep_ts != 0)))
        elif args['env_name'].lower().startswith("mujoco"):
            ep_if_valid.append(np.mean(_ep_if_valid / _ts))

        if args["agent_type"].startswith("wolp") and args["if_visualise_agent"]:
            if args["env_name"].lower() == "mine":
                _ep_can_multi_opt_bonus[~sr_mask] = 0.0  # If ep is not successful, then we don't count!
            ep_can_multi_opt_bonus.append(np.mean(_ep_can_multi_opt_bonus))
        # ep_return.append(np.mean(_ep_return))
        # ep_ts.append(np.mean(_ep_ts))
        if args["eval_if_dist_bonus"]: ep_dist.append(np.mean(_ep_dist))

    """ === After All Evaluation === """
    out = {
        "ep_return": np.mean(ep_return),
        "ep_candidate_Q_mean": np.mean(ep_candidate_Q_mean),
        "ep_max_Q_improvement": np.mean(ep_max_Q_improvement),
        "ep_max_Q_improvement_percent": np.mean(ep_max_Q_improvement_percent),
        "0th_not_max_index": np.mean(total_0th_not_max_index),
        "time_steps": np.mean(ep_ts),
        "img_list": img_list,
        "act_embed_img_path_list": act_embed_img_path_list,
    }
    if args['do_naive_eval']:
        out["ep_return_naive"] = np.mean(ep_return_naive)
    if args['env_name'].lower() == "mine":
        out["success_rate"] = np.mean(ep_ctr)
        out["mining_return"] = np.mean(ep_mining_return)
    elif args['env_name'].lower().startswith("recsim"):
        out["ctr"] = np.mean(ep_ctr)
    elif args['env_name'].lower().startswith("mujoco"):
        out["if_valid"] = np.mean(ep_if_valid)

    if args["eval_if_dist_bonus"]: out["ep_dist"] = np.mean(ep_dist)

    if args["agent_type"].startswith("wolp") and args["if_visualise_agent"]:
        out["ep_can_multi_opt_bonus"] = np.mean(ep_can_multi_opt_bonus)

    """ === Save rendered video frames to file === """
    if args['env_name'] == 'mine':
        if not frame_buffer.empty:
            save_name = '%s_%s' % (str(timestep), 'eval')
            save_dir = os.path.join(args['video_saving_dir'], 'test')
            frames = frame_buffer.gen_video()
            save_mp4(frames, save_dir, save_name, fps=args['vid_fps'], no_frame_drop=True)
            saved_file_name = '%s.mp4' % osp.join(save_dir, save_name)
            print('Rendered frames to %s' % saved_file_name)
            out["path_video"] = saved_file_name
    else:
        if len(frame_buffer) > 0:
            save_name = '%s_%s' % (str(timestep), 'eval')
            save_dir = os.path.join(args['video_saving_dir'], 'test')
            save_mp4(frame_buffer, save_dir, save_name, fps=args['vid_fps'], no_frame_drop=True)
            saved_file_name = '%s.mp4' % osp.join(save_dir, save_name)
            print('Rendered frames to %s' % saved_file_name)
            out["path_video"] = saved_file_name
    agent.train()
    return out


def _add_buffer(args, obs, action, replay_buffer, reward, next_obs, done):
    if torch.is_tensor(obs):
        obs = obs.cpu().detach().numpy()
    if torch.is_tensor(next_obs):
        next_obs = next_obs.cpu().detach().numpy()
    for i in range(obs.shape[0]):
        a = action["action"][i, :]
        _if_retriever = False
        if args['env_name'].lower() in ["recsim", "mine", "recsim-data"]:
            if args["agent_type"].startswith("wolp"):
                if action["topk_act"][i, 0][0] > -1:
                    _if_retriever = True
                    _ind = np.where(action["topk_act"][i].flatten() == action["action"][i])[0]
                    if len(_ind) > 1: _ind = _ind[0]  # this is because of duplication is allowed
                    elif len(_ind) == 0: _ind = np.array([-1000])
                    if args["WOLP_if_ar"] or args["WOLP_if_joint_actor"] or args["WOLP_if_dual_critic"] or args[
                        "agent_type"] == "wolp-sac":
                        _a = np.ones((action["query"][i, :].shape[0], 1))  # num-queries x 1
                        a = np.hstack([_a * action["action"][i, :],  # num-queries x 1
                                       _a * _ind,  # num-queries x 1
                                       action["topk_act"][i, :],  # num-queries x topK
                                       action["query"][i, :],  # num-queries x dim_act
                                       _a * action["query_max"][i, :]])
                    else:
                        # 1 x (2 + list-len)
                        a = np.hstack([action["action"][i, :][:, None], _ind[:, None], action["topk_act"][i, :]])
        elif args['env_name'].startswith("mujoco"):
            if args['agent_type'].startswith('wolp'):
                if action["query"][i, 0][0] > LARGE_NEG:
                    _if_retriever = True
                a = np.hstack([action['action'][i, :], action['query'][i, :].flatten(),
                               action['query_max'][i, :]])
            elif args['agent_type'].startswith('arddpg_cont'):
                a = np.hstack([action['action'][i, :], action['query'][i, :].flatten(),
                               action['refinement_index'][i, :]])
        else:
            raise NotImplementedError

        replay_buffer.add(obs[i], a, reward[i], next_obs[i], done[i], if_retriever=_if_retriever)
    return replay_buffer


def _fill_replay_buffer(args, env, agent, min_replay_buffer_size: int, replay_buffer, dict_embedding):
    """Warm-start ``replay_buffer`` with fully random interactions (both epsilons = 1).

    Episodes are rolled out until the buffer holds exactly ``min_replay_buffer_size``
    transitions. The last vectorised step is truncated to the capacity that is
    still missing.

    Returns:
        The filled ``replay_buffer``.
    """
    while len(replay_buffer) < min_replay_buffer_size:
        obs = env.reset()
        done = False

        while not np.all(done):
            obs = torch.tensor(obs.astype(np.float32), device=args["device"])
            action = agent.select_action(obs=obs,
                                         epsilon={"actor": 1.0, "critic": 1.0},
                                         act_embed_base=dict_embedding["item"])
            next_obs, reward, done, _info = env.step(action["action"])
            _r, _ = _modify_reward(agent=agent, reward=reward, action=action, dict_embedding=dict_embedding)

            # Keep only as many transitions as still fit. The previous code sliced by
            # the *overshoot*, which over- or under-filled the buffer and could feed a
            # truncated batch of observations back into the full vectorised env.
            _remaining = min_replay_buffer_size - len(replay_buffer)
            if obs.shape[0] > _remaining:
                action = {k: v[:_remaining] for k, v in action.items() if v is not None}
                obs, next_obs = obs[:_remaining], next_obs[:_remaining]
                _r, done = _r[:_remaining], done[:_remaining]
            replay_buffer = _add_buffer(args, obs, action, replay_buffer, _r, next_obs, done)

            if len(replay_buffer) >= min_replay_buffer_size:
                break
            obs = next_obs
    return replay_buffer


def main(args):
    if args.device == "cuda":
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if args.env_name == "mine" and args.mw_maxRoomSize < 10:
            torch.backends.cudnn.enabled = False

    if args.env_name == "recsim-data":
        args.recsim_data_dir = os.path.join(DATASET_PATH, args.recsim_data_dir)
        args.user_embedding_path = os.path.join(DATASET_PATH, args.recsim_data_dir, "user_attr.npy")
        args.item_embedding_path = os.path.join(DATASET_PATH, args.recsim_data_dir, "trained_weight/item.npy")
        args.save_dir = os.path.join(DATASET_PATH, args.save_dir)

    set_randomSeed(seed=args.seed)
    run(args=vars(args))
    logging("==================== END ====================")


if __name__ == '__main__':
    # Script entry point: parse the CLI configuration and launch training.
    # NOTE(review): this binds `args` as a module-level global (presumably an argparse
    # Namespace, since `main` uses attribute access and `vars(args)`); keep the name in
    # case other code in this module still resolves a bare `args` to it.
    args = get_all_args()
    main(args=args)
