import copy
import numpy as np
from src import d4rl_utils
from jaxrl_m.evaluation import (
    EpisodeMonitor,
)

# from src.binary_datasets import BinaryDataset
from jaxrl_m.dataset import ReplayBuffer

# 30-entry target state used to render the goal image for the "visual" kitchen
# envs: get_env_and_dataset concatenates it with env.init_qvel, passes the
# result to env.sim.set_state, and renders the resulting scene as the goal
# observation. Presumably these are simulator qpos values — TODO confirm.
# Only entries 0-8 are used verbatim; entries 9: are replaced there with values
# taken from the observation returned by env.reset().
# Shared module-level list: treat as read-only and copy before modifying.
KITCHEN_VISION_GOAL_STATE = [
    -2.3403780e00,
    -1.3053924e00,
    1.1021180e00,
    -1.8613019e00,
    1.5087037e-01,
    1.7687809e00,
    1.2525779e00,
    2.9698312e-02,
    3.0899283e-02,
    3.9908718e-04,
    4.9550228e-05,
    -1.9946630e-05,
    2.7519276e-05,
    4.8786267e-05,
    3.2835731e-05,
    2.6504624e-05,
    3.8422750e-05,
    -6.9888681e-01,
    -5.0150707e-02,
    3.4855098e-01,
    -9.8701166e-03,
    -7.6958216e-03,
    -8.0031347e-01,
    -1.9142720e-01,
    7.2064394e-01,
    1.6191028e00,
    1.0021452e00,
    -3.2998802e-04,
    3.7205056e-05,
    5.3616576e-02,
]


# Number of leading observation columns kept when kitchen observations are
# truncated (replay buffer / goal-conditioned dataset when
# ``kitchen_full_obs=False``, and always for the HRL high-level dataset).
_KITCHEN_OBS_DIM = 30


def _truncate_observations(data, dim=_KITCHEN_OBS_DIM):
    """Return ``data.copy(...)`` with (next_)observations cut to ``dim`` columns.

    Only ``observations`` and ``next_observations`` are replaced; every other
    field is left to ``data.copy``. Callers re-assign ``size`` (and ``pointer``
    for replay buffers) afterwards — presumably because ``copy`` does not carry
    them over; TODO confirm against the dataset class.
    """
    return data.copy(
        {
            "observations": data["observations"][:, :dim],
            "next_observations": data["next_observations"][:, :dim],
        }
    )


def get_env_and_dataset(FLAGS, kitchen_full_obs=True, goal_conditioned=True, hrl=False):
    """Build the environment and offline datasets selected by ``FLAGS.env_name``.

    Supported families (matched by substring of ``FLAGS.env_name``):

    * ``"antmaze"``: ``"ultra"`` variants are created with ``gym.make`` (after
      importing ``d4rl_ext``) and wrapped in ``EpisodeMonitor``; others use
      ``d4rl_utils.make_env``. Every reward of the goal-conditioned dataset is
      shifted by ``-1.0``. The render camera is positioned for ``"large"`` and
      ``"ultra"`` mazes.
    * ``"kitchen"``: ``"visual-<name>"`` variants load pre-rendered data from
      ``data/d4rl_kitchen_rendered/<name>.npz`` and render a goal image into
      ``goal_info["ob"]``. Other kitchen variants optionally truncate
      observations to their first 30 columns.

    Args:
        FLAGS: Config object; only ``FLAGS.env_name`` is read.
        kitchen_full_obs: Non-visual kitchen only. If False, the replay buffer
            and goal-conditioned dataset keep only the first 30 observation
            columns.
        goal_conditioned: Forwarded to ``d4rl_utils.get_dataset`` (antmaze only).
        hrl: If True, also build ``high_dataset`` (antmaze and non-visual
            kitchen). For kitchen its observations are always truncated to 30
            columns.

    Returns:
        dict with keys ``"env"``, ``"eval_env"``, ``"replay_buffer"``,
        ``"dataset"`` (the goal-conditioned dataset), ``"aux_env"`` (always
        empty here), ``"goal_info"`` (empty unless visual kitchen) and
        ``"high_dataset"`` (None unless ``hrl``).

    Raises:
        NotImplementedError: for env families other than antmaze/kitchen, or
            for ``hrl=True`` with a visual kitchen env.
    """
    aux_env = {}
    goal_info = {}
    high_dataset = None
    env_name = FLAGS.env_name

    if "antmaze" in env_name:
        if "ultra" in env_name:
            # Imported only for its side effects (presumably registering the
            # "ultra" env ids with gym) — TODO confirm.
            import d4rl_ext  # noqa: F401
            import gym

            env = EpisodeMonitor(gym.make(env_name))
        else:
            env = d4rl_utils.make_env(env_name)

        replay_buffer = d4rl_utils.get_replay_buffer(
            env,
            env_name,
            goal_conditioned=False,
        )

        gc_dataset = d4rl_utils.get_dataset(
            env, env_name, goal_conditioned=goal_conditioned
        )
        # Shift every reward by -1 (presumably sparse {0, 1} -> {-1, 0}; confirm).
        gc_dataset = gc_dataset.copy({"rewards": gc_dataset["rewards"] - 1.0})

        if hrl:
            # NOTE(review): unlike gc_dataset, these rewards are NOT shifted by -1.
            high_dataset = d4rl_utils.get_replay_buffer(
                env,
                env_name,
                goal_conditioned=False,
            )

        # Render once before touching env.viewer below (presumably this is what
        # creates the viewer object) — TODO confirm.
        env.render(mode="rgb_array", width=200, height=200)
        if "large" in env_name:
            env.viewer.cam.lookat[0] = 18
            env.viewer.cam.lookat[1] = 12
            env.viewer.cam.distance = 50
            env.viewer.cam.elevation = -90
        elif "ultra" in env_name:
            env.viewer.cam.lookat[0] = 26
            env.viewer.cam.lookat[1] = 18
            env.viewer.cam.distance = 70
            env.viewer.cam.elevation = -90

    elif "kitchen" in env_name:
        if "visual" in env_name:
            if hrl:
                # Fail before any environment creation or dataset loading.
                raise NotImplementedError(
                    f"hrl=True is not supported for visual kitchen env {env_name!r}"
                )

            orig_env_name = env_name.split("visual-")[1]
            env = d4rl_utils.make_env(orig_env_name)
            # Materialize every array, then close the .npz file handle.
            with np.load(f"data/d4rl_kitchen_rendered/{orig_env_name}.npz") as npz:
                dataset_dict = dict(npz)
            gc_dataset = d4rl_utils.get_dataset(
                env, env_name, dataset=dataset_dict, filter_terminals=True
            )
            replay_buffer = d4rl_utils.get_replay_buffer(
                env, env_name, dataset=dataset_dict, filter_terminals=True
            )

            state = env.reset()

            # Copy so the slice assignment below cannot mutate the shared
            # module-level constant.
            goal_state = list(KITCHEN_VISION_GOAL_STATE)
            # Set the goal state for kitchen-mixed-v0: entries 9: are taken
            # from the reset observation.
            goal_state[9:] = state[39:]
            env.sim.set_state(np.concatenate([goal_state, env.init_qvel]))
            env.sim.forward()
            goal_info = {
                "ob": d4rl_utils.kitchen_render(env).astype(np.float32),
            }
            env.reset()
        else:
            env = d4rl_utils.make_env(env_name)
            replay_buffer = d4rl_utils.get_replay_buffer(
                env, env_name, filter_terminals=True
            )
            buf_size = replay_buffer.size

            gc_dataset = d4rl_utils.get_dataset(env, env_name, filter_terminals=True)
            gc_size = gc_dataset.size

            if hrl:
                high_dataset = d4rl_utils.get_replay_buffer(
                    env, env_name, filter_terminals=True, goal_conditioned=False
                )
                high_size = high_dataset.size
                high_dataset = _truncate_observations(high_dataset)
                high_dataset.size = high_dataset.pointer = high_size
            if not kitchen_full_obs:
                replay_buffer = _truncate_observations(replay_buffer)
                replay_buffer.size = replay_buffer.pointer = buf_size

                gc_dataset = _truncate_observations(gc_dataset)
                gc_dataset.size = gc_size

    else:
        raise NotImplementedError(f"Unsupported env_name: {env_name!r}")

    # NOTE(review): copy.copy is shallow — attribute objects such as env.sim
    # are shared between env and eval_env. Confirm this is intended.
    eval_env = copy.copy(env)

    return {
        "env": env,
        "eval_env": eval_env,
        "replay_buffer": replay_buffer,
        "dataset": gc_dataset,
        "aux_env": aux_env,
        "goal_info": goal_info,
        "high_dataset": high_dataset,
    }
