import numpy as np
import torch
import torch.nn.functional as F
import gym
import os

import metaworld
import metaworld.envs.mujoco.env_dict as _env_dict
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
from gym.wrappers.time_limit import TimeLimit
from env.wrappers import NormalizedBoxEnv
import pickle as pkl


def make_metaworld_env(env_name, seed):
    """Create a seeded Meta-World environment wrapped for training.

    The ``"metaworld_"`` prefix is stripped from ``env_name``. The remaining
    task name is looked up in the v2 registry first and then in the v1
    registry (``KeyError`` if it is in neither). The raw environment is wrapped
    in ``NormalizedBoxEnv`` and then in a ``TimeLimit`` of
    ``env.max_path_length`` steps.

    Args:
        env_name: task name, optionally prefixed with ``"metaworld_"``.
        seed: value passed to the environment's ``seed()``.

    Returns:
        The wrapped environment.
    """
    task_name = env_name.replace("metaworld_", "")
    v2_registry = _env_dict.ALL_V2_ENVIRONMENTS
    if task_name not in v2_registry:
        env_cls = _env_dict.ALL_V1_ENVIRONMENTS[task_name]
    else:
        env_cls = v2_registry[task_name]

    base_env = env_cls()
    # Override private Meta-World flags directly. A previous debugging note
    # recorded both `_partially_observable` and `_freeze_rand_vec` as True
    # right after construction. Presumably False/False makes the goal part
    # of the observation and lets it be re-randomized on reset, and setting
    # `_set_task_called` bypasses the library's set_task() guard.
    # TODO confirm against the installed metaworld version.
    base_env._partially_observable = False
    base_env._freeze_rand_vec = False
    base_env._set_task_called = True
    base_env.seed(seed)
    return TimeLimit(NormalizedBoxEnv(base_env), base_env.max_path_length)

def MetaWorld_mr_dataset(config):
    """
    MetaWorld medium-replay dataset from LiRE (Choi et al., 2024).

    If ``config.human`` is truthy, loads ``human_feedback/<config.env>/dataset.pkl``.
    Otherwise, for each of the seeds 0, 1 and 2, loads
    ``dataset/MetaWorld_medium-replay/<task>/saved_replay_buffer_1000000_seed<seed>.pkl``
    (``<task>`` is ``config.env`` without the ``"metaworld_"`` prefix), keeps its
    first ``int(config.data_quality * 100_000)`` transitions, and concatenates
    the three seeds in order. The stored ``"dones"`` key is exposed as ``terminals``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this
    at trusted dataset files.

    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs float32 array of observations.
            actions: An N x dim_action float32 array of actions.
            next_observations: An N x dim_obs float32 array of next observations.
            rewards: An N-dim float32 array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.

    Raises:
        ValueError: if the per-transition arrays do not cover all N rewards.
    """
    array_keys = ("observations", "actions", "next_observations", "rewards")

    if config.human:
        path = os.path.join(os.getcwd(), "human_feedback", str(config.env), "dataset.pkl")
        with open(path, "rb") as f:
            raw = pkl.load(f)
        dataset = {key: np.array(raw[key]) for key in array_keys}
        dataset["terminals"] = np.array(raw["dones"])
    else:
        task_name = str(config.env).replace("metaworld_", "")
        task_dir = os.path.join(os.getcwd(), "dataset", "MetaWorld_medium-replay", task_name)
        # Number of leading transitions kept from EACH seed's replay buffer.
        limit = int(config.data_quality * 100_000)
        parts = []
        for seed in range(3):
            path = os.path.join(task_dir, f"saved_replay_buffer_1000000_seed{seed}.pkl")
            with open(path, "rb") as f:
                buffer = pkl.load(f)
            part = {key: buffer[key][:limit] for key in array_keys}
            part["terminals"] = buffer["dones"][:limit]
            parts.append(part)
        dataset = {
            key: np.concatenate([part[key] for part in parts], axis=0)
            for key in parts[0]
        }

    # N is the leading dimension of the rewards array. Rewards and terminals
    # may be stored as (N, 1), so they are flattened before slicing.
    n = dataset["rewards"].shape[0]
    out = {
        "observations": np.asarray(dataset["observations"])[:n].astype(np.float32),
        "actions": np.asarray(dataset["actions"])[:n].astype(np.float32),
        "next_observations": np.asarray(dataset["next_observations"])[:n].astype(np.float32),
        "rewards": np.asarray(dataset["rewards"]).reshape(-1)[:n].astype(np.float32),
        "terminals": np.asarray(dataset["terminals"]).reshape(-1)[:n].astype(bool),
    }
    for key, value in out.items():
        if value.shape[0] != n:
            raise ValueError(f"{key!r} has {value.shape[0]} rows, expected {n}")
    return out


def MetaWorld_me_dataset(config):
    """
    MetaWorld medium-expert dataset following the approaches of IPL (Hejna & Sadigh, 2024) and LiRE (Choi et al., 2024).

    Loads ``dataset/MetaWorld_medium-expert/<task>/trajectory.npz``, where
    ``<task>`` is the second ``_``-separated field of ``config.env``. The file
    is treated as consecutive 500-step trajectories. If
    ``config.data_quality * 100_000`` is at least the total number of
    transitions, every trajectory is returned. Otherwise a fixed number of
    trajectories is taken from the start of each of five index segments
    (see the comments below).

    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs float32 array of observations.
            actions: An N x dim_action float32 array of actions.
            next_observations: An N x dim_obs float32 array of next observations.
            rewards: An N-dim float32 array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.

    Raises:
        ValueError: if the number of transitions is not a multiple of 500.
    """
    traj_len = 500
    task_name = str(config.env).split("_")[1]
    path = os.path.join(os.getcwd(), "dataset/MetaWorld_medium-expert/" + task_name, "trajectory.npz")
    # Use NpzFile as a context manager so its underlying file handle is closed.
    with np.load(path) as npz:
        dataset = {key: npz[key] for key in npz.keys()}

    n_transitions = dataset["rewards"].shape[0]
    if n_transitions % traj_len != 0:
        raise ValueError(
            f"expected a multiple of {traj_len} transitions, got {n_transitions}"
        )
    N = n_transitions // traj_len  # number of trajectories

    if config.data_quality * 100_000 >= n_transitions:
        idx = np.arange(N)
    else:
        # Take trajectories proportional to the data quality. 200 = 100_000 / 500,
        # so roughly data_quality * 200 trajectories are selected in total.
        # NOTE(review): assumes the file stores trajectories grouped by source at
        # offsets 0, N/12, N/6, N/3 and 5N/6 (expert, within-env, random,
        # eps-greedy, cross-env); confirm against the dataset builder. The
        # random/eps-greedy counts can exceed their segment widths for large
        # data_quality and spill into the next segment.
        n_expert = int(config.data_quality * 200 / 12)
        idx_expert = np.arange(n_expert)
        n_within_env = int(config.data_quality * 200 / 12)
        idx_within_env = np.arange(n_within_env) + int(N / 12)
        n_random = int(config.data_quality * 200 / 3)
        idx_random = np.arange(n_random) + int(N / 6)
        n_eps_greedy = int(config.data_quality * 200 / 3)
        idx_eps_greedy = np.arange(n_eps_greedy) + int(N / 3)
        n_cross_env = int(config.data_quality * 200 / 6)
        idx_cross_env = np.arange(n_cross_env) + int(5 * N / 6)

        idx = np.concatenate(
            (idx_expert, idx_within_env, idx_random, idx_eps_greedy, idx_cross_env), axis=0
        )

    state_dim = dataset["states"].shape[1]
    action_dim = dataset["actions"].shape[1]

    # Select whole trajectories first, then cast, so only the kept subset is copied.
    return {
        "observations": dataset["states"].reshape(-1, traj_len, state_dim)[idx]
        .reshape(-1, state_dim).astype(np.float32),
        "actions": dataset["actions"].reshape(-1, traj_len, action_dim)[idx]
        .reshape(-1, action_dim).astype(np.float32),
        "next_observations": dataset["next_states"].reshape(-1, traj_len, state_dim)[idx]
        .reshape(-1, state_dim).astype(np.float32),
        "rewards": dataset["rewards"].reshape(N, traj_len, -1)[idx].reshape(-1).astype(np.float32),
        "terminals": dataset["dones"].reshape(N, traj_len, -1)[idx].reshape(-1).astype(bool),
    }

def MetaWorld_mix_dataset(config, mix_ratio=0.5, rng=None):
    """
    MetaWorld dataset that mixes random trajectories and expert trajectories.

    Loads ``trajectory.npz`` from ``dataset/MetaWorld_random_data/<task>`` and
    ``dataset/MetaWorld_expert_data/<task>``, where ``<task>`` is the second
    ``_``-separated field of ``config.env``. Both files must hold the same
    number of transitions, a multiple of 500 (one trajectory = 500 steps).
    ``int(N_traj * mix_ratio)`` whole trajectories are sampled without
    replacement from the random file and the remaining ones from the expert
    file. Random trajectories come first in the output.

    Args:
        config: object with an ``env`` attribute such as ``"metaworld_<task>"``.
        mix_ratio: fraction of trajectories drawn from the random data, in [0, 1].
        rng: optional ``np.random.Generator`` for reproducible sampling. When
            None, the global ``np.random`` state is used (the original behavior).

    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs float32 array of observations.
            actions: An N x dim_action float32 array of actions.
            next_observations: An N x dim_obs float32 array of next observations.
            rewards: An N-dim float32 array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.

    Raises:
        ValueError: if ``mix_ratio`` is outside [0, 1], or the two files differ
            in size or do not contain a whole number of trajectories.
    """
    traj_len = 500
    if not 0.0 <= mix_ratio <= 1.0:
        raise ValueError(f"mix_ratio must be in [0, 1], got {mix_ratio}")
    task_name = str(config.env).split("_")[1]

    def load(subdir):
        """Read every array of ``dataset/<subdir>/<task>/trajectory.npz`` and close the file."""
        path = os.path.join(os.getcwd(), f"dataset/{subdir}/" + task_name, "trajectory.npz")
        with np.load(path) as npz:
            return {key: npz[key] for key in npz.keys()}

    random_dataset = load("MetaWorld_random_data")
    expert_dataset = load("MetaWorld_expert_data")

    n_transitions = random_dataset["rewards"].shape[0]
    if expert_dataset["rewards"].shape[0] != n_transitions:
        raise ValueError(
            f"random ({n_transitions}) and expert ({expert_dataset['rewards'].shape[0]}) "
            "datasets must contain the same number of transitions"
        )
    if n_transitions % traj_len != 0:
        raise ValueError(f"expected a multiple of {traj_len} transitions, got {n_transitions}")
    n_traj = n_transitions // traj_len
    obs_dim = random_dataset["states"].shape[-1]
    action_dim = random_dataset["actions"].shape[-1]

    n_random = int(n_traj * mix_ratio)
    n_expert = n_traj - n_random

    # Same call order as before (random first), so seeding the global
    # np.random state reproduces earlier results when rng is None.
    sampler = np.random if rng is None else rng
    idx_random = sampler.choice(n_traj, n_random, replace=False)
    idx_expert = sampler.choice(n_traj, n_expert, replace=False)

    def merge(key, row_shape, dtype):
        """Concatenate the sampled random then expert trajectories of ``key``, flattened to transitions."""
        traj_shape = (-1, traj_len) + row_shape
        picked = np.concatenate(
            (
                random_dataset[key].reshape(traj_shape)[idx_random],
                expert_dataset[key].reshape(traj_shape)[idx_expert],
            ),
            axis=0,
        )
        return picked.astype(dtype).reshape((-1,) + row_shape)

    return {
        "observations": merge("states", (obs_dim,), np.float32),
        "next_observations": merge("next_states", (obs_dim,), np.float32),
        "actions": merge("actions", (action_dim,), np.float32),
        "rewards": merge("rewards", (), np.float32),
        "terminals": merge("dones", (), bool),
    }

def MetaWorld_mixed_dataset(config, mix_ratio=0.5, rng=None):
    """
    MetaWorld dataset that mixes random trajectories and expert trajectories.

    Same procedure as the random/expert mix, but the "expert" trajectories
    are read from ``dataset/MetaWorld_MoE_data/<task>`` (presumably data from
    a mixture-of-experts policy; confirm against the dataset builder). The
    random trajectories are read from ``dataset/MetaWorld_random_data/<task>``.
    ``<task>`` is the second ``_``-separated field of ``config.env``. Both
    ``trajectory.npz`` files must hold the same number of transitions, a
    multiple of 500 (one trajectory = 500 steps). ``int(N_traj * mix_ratio)``
    whole trajectories are sampled without replacement from the random file
    and the rest from the MoE file. Random trajectories come first in the output.

    Args:
        config: object with an ``env`` attribute such as ``"metaworld_<task>"``.
        mix_ratio: fraction of trajectories drawn from the random data, in [0, 1].
        rng: optional ``np.random.Generator`` for reproducible sampling. When
            None, the global ``np.random`` state is used (the original behavior).

    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs float32 array of observations.
            actions: An N x dim_action float32 array of actions.
            next_observations: An N x dim_obs float32 array of next observations.
            rewards: An N-dim float32 array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.

    Raises:
        ValueError: if ``mix_ratio`` is outside [0, 1], or the two files differ
            in size or do not contain a whole number of trajectories.
    """
    traj_len = 500
    if not 0.0 <= mix_ratio <= 1.0:
        raise ValueError(f"mix_ratio must be in [0, 1], got {mix_ratio}")
    task_name = str(config.env).split("_")[1]

    def load(subdir):
        """Read every array of ``dataset/<subdir>/<task>/trajectory.npz`` and close the file."""
        path = os.path.join(os.getcwd(), f"dataset/{subdir}/" + task_name, "trajectory.npz")
        with np.load(path) as npz:
            return {key: npz[key] for key in npz.keys()}

    random_dataset = load("MetaWorld_random_data")
    expert_dataset = load("MetaWorld_MoE_data")

    n_transitions = random_dataset["rewards"].shape[0]
    if expert_dataset["rewards"].shape[0] != n_transitions:
        raise ValueError(
            f"random ({n_transitions}) and expert ({expert_dataset['rewards'].shape[0]}) "
            "datasets must contain the same number of transitions"
        )
    if n_transitions % traj_len != 0:
        raise ValueError(f"expected a multiple of {traj_len} transitions, got {n_transitions}")
    n_traj = n_transitions // traj_len
    obs_dim = random_dataset["states"].shape[-1]
    action_dim = random_dataset["actions"].shape[-1]

    n_random = int(n_traj * mix_ratio)
    n_expert = n_traj - n_random

    # Same call order as before (random first), so seeding the global
    # np.random state reproduces earlier results when rng is None.
    sampler = np.random if rng is None else rng
    idx_random = sampler.choice(n_traj, n_random, replace=False)
    idx_expert = sampler.choice(n_traj, n_expert, replace=False)

    def merge(key, row_shape, dtype):
        """Concatenate the sampled random then expert trajectories of ``key``, flattened to transitions."""
        traj_shape = (-1, traj_len) + row_shape
        picked = np.concatenate(
            (
                random_dataset[key].reshape(traj_shape)[idx_random],
                expert_dataset[key].reshape(traj_shape)[idx_expert],
            ),
            axis=0,
        )
        return picked.astype(dtype).reshape((-1,) + row_shape)

    return {
        "observations": merge("states", (obs_dim,), np.float32),
        "next_observations": merge("next_states", (obs_dim,), np.float32),
        "actions": merge("actions", (action_dim,), np.float32),
        "rewards": merge("rewards", (), np.float32),
        "terminals": merge("dones", (), bool),
    }