# Utility functions for Meta-World (MW).
import time

import gym
import metaworld  # modified mw env with additional gym wrappers
import numpy as np
from metaworld.data.dataset import read_trajs


def get_mw_env_and_data(dataset_path, discount, obs_type='state', res=128, max_episode_steps=np.inf):
    """Create a Meta-World ML1 gym env and load the matching offline dataset.

    Assumes ML1. The env is identified by searching ``dataset_path`` for one of
    ``metaworld.ML1.ENV_NAMES``. Rewards of both the env and the dataset are
    normalized with ``(r - 10) / 10``.

    Args:
        dataset_path: Path of the trajectory dataset. It must contain an ML1
            env name as a substring; if several names match, the longest
            (most specific) one is used.
        discount: Discount factor forwarded to ``qlearning_dataset`` for
            ``n_step_discount_rewards``.
        obs_type: ``'state'`` to expose ``obs['full_state']`` from the env and
            the dataset's ``states`` arrays, or ``'image'`` to expose
            ``obs['image']`` and the dataset's ``observations`` arrays.
        res: Camera height and width (square images) passed to the env.
        max_episode_steps: Forwarded to ``qlearning_dataset`` as
            ``n_steps_for_reward``; caps how many rewards per trajectory are
            summed into ``n_step_rewards`` / ``n_step_discount_rewards``.

    Returns:
        ``(env, data, env_name)``: the wrapped gym env, a dict of numpy arrays,
        and the matched ML1 env name. In ``data``, ``observations``,
        ``next_observations``, ``actions``, ``rewards``, ``terminals`` and
        ``timeouts`` have one entry per transition, while ``successes``,
        ``n_step_rewards`` and ``n_step_discount_rewards`` have one entry per
        trajectory (as built by ``qlearning_dataset``). Only ``rewards`` is
        normalized; the n-step returns are left in the raw reward scale.

    Raises:
        ValueError: if no ML1 env name occurs in ``dataset_path``.
        NotImplementedError: if ``obs_type`` is not ``'state'`` or ``'image'``.
    """
    env_name_list = metaworld.ML1.ENV_NAMES  # list of MW ENVS
    path_str = str(dataset_path)  # also accept pathlib.Path for matching
    matches = [name for name in env_name_list if name in path_str]
    if not matches:
        raise ValueError(f"Cannot find matched MW env in {path_str!r}.")
    # Several env names are substrings of others (e.g. the v2 names 'push-v2'
    # and 'coffee-push-v2' / 'stick-push-v2'), so the first match in list
    # order can be the wrong task. The longest match is the most specific.
    env_name = max(matches, key=len)

    # Decide which dataset arrays and which env observation entry to expose.
    # Checked up front so a bad obs_type fails before the env is built and
    # the dataset is loaded.
    if obs_type == 'state':
        data_key, next_data_key, env_obs_key = 'states', 'next_states', 'full_state'
    elif obs_type == 'image':
        data_key, next_data_key, env_obs_key = 'observations', 'next_observations', 'image'
    else:
        raise NotImplementedError(f"Unsupported obs_type: {obs_type!r}")

    # We modify the MW env to stop at the goal, and map every reward r to
    # (r - 10) / 10. Assuming raw MW rewards lie in [0, 10], all rewards become
    # non-positive and the reward at the goal is zero.
    MAX_steps_at_goal = 1
    stop_at_goal = True
    train_distrib = False

    def normalize_reward(reward):
        """Shift by -10 and scale by 1/10; works on scalars and numpy arrays."""
        return (reward - 10.0) / 10.0

    # Create gym env
    env = metaworld.mw_gym_make(
        env_name,
        stop_at_goal=stop_at_goal,
        steps_at_goal=MAX_steps_at_goal,
        train_distrib=train_distrib,
        goal_cost_reward=False,
        cam_height=res,
        cam_width=res,
        depth=False,
    )

    class RewardNormalizer(gym.RewardWrapper):
        def reward(self, reward):
            return normalize_reward(reward)

    env = RewardNormalizer(env)

    # Read the 'original' reward type for consistency with the env's reward.
    rawdata = qlearning_dataset(dataset_path, "original", discount, max_episode_steps)
    # qlearning_dataset returns column vectors; [:, 0] flattens them to 1-D.
    data = dict(
        actions=rawdata['actions'],
        rewards=normalize_reward(rawdata['rewards'][:, 0]),
        terminals=rawdata['terminals'][:, 0],
        timeouts=rawdata['timeouts'][:, 0],
        successes=rawdata['successes'][:, 0],
        n_step_rewards=rawdata['n_step_rewards'][:, 0],
        n_step_discount_rewards=rawdata['n_step_discount_rewards'][:, 0],
    )
    data['observations'] = rawdata[data_key]
    data['next_observations'] = rawdata[next_data_key]

    class ObservationWrapper(gym.ObservationWrapper):
        # NOTE(review): observation_space is inherited unchanged from the
        # wrapped env and may still describe the full obs dict — confirm.
        def observation(self, obs):
            return obs[env_obs_key]

    env = ObservationWrapper(env)

    return env, data, env_name


def qlearning_dataset(dataset_path, reward_type, discount, n_steps_for_reward=np.inf):
    """Load MW trajectories and concatenate them into one transition dataset.

    A trajectory with T recorded steps contributes T-1 transitions
    ``(s_t, a_t, r_t, s_{t+1})``: per-step arrays are sliced ``[:-1]`` for the
    current step and ``[1:]`` for the next one.

    The trajectory dicts returned by ``read_trajs`` are modified in place: a
    ``'timeouts'`` array is added and ``'terminals'`` is set at success steps.

    Args:
        dataset_path: Path forwarded to ``read_trajs``.
        reward_type: Reward type forwarded to ``read_trajs``.
        discount: Discount factor used for ``n_step_discount_rewards``.
        n_steps_for_reward: Only the first ``n_steps_for_reward`` rewards of
            each trajectory are summed into the per-trajectory returns.
            ``np.inf`` (default) uses all of them; finite floats are allowed.

    Returns:
        dict of numpy arrays. One row per transition: ``states``,
        ``next_states``, ``proprio_states``, ``next_proprio_states``,
        ``observations``, ``next_observations``, ``actions``, ``rewards``,
        ``terminals``, ``timeouts``. One row per trajectory: ``successes``
        (the success flag of the last step), ``n_step_rewards`` (undiscounted
        sum of the first rewards) and ``n_step_discount_rewards`` (discounted
        sum). ``depths`` and ``next_depths`` are ``None``.

    Raises:
        ValueError: if ``read_trajs`` returns no trajectories.
    """
    _env_metadata, all_trajs = read_trajs(dataset_path, reward_type)
    if len(all_trajs) == 0:
        # np.vstack would otherwise fail with a generic "need at least one array".
        raise ValueError(f"No trajectories found in {dataset_path!r}.")
    print("Concatenating all trajectories into one big dataset...")
    start_extend = time.time()

    # Add timeouts info and mark success steps as terminal.
    for traj in all_trajs:
        traj_len = len(traj['rewards'])
        traj['timeouts'] = np.full((traj_len, 1), False, dtype=bool)
        # Bool-cast and flatten so the mask selects success steps whether
        # 'successes' is stored 1-D or as a column, and as bool or 0/1 ints
        # (an int array would otherwise be used as row *indices*).
        success_mask = np.asarray(traj['successes'], dtype=bool).reshape(-1)
        traj['terminals'][success_mask, 0] = True
        if not success_mask[-1]:
            # The [:-1] slicing below drops each trajectory's last entry, so
            # flag the last two entries to keep a timeout on the final
            # stored transition of an unsuccessful trajectory.
            traj['timeouts'][-2:] = True

    def first_n_rewards(rewards):
        """Return at most the first ``n_steps_for_reward`` rows of ``rewards``."""
        rewards = np.asarray(rewards)
        if len(rewards) > n_steps_for_reward:
            # int() so a finite float horizon can be used as a slice bound.
            rewards = rewards[:int(n_steps_for_reward)]
        return rewards

    def add_n_step_reward(rewards):
        """Undiscounted sum of the first ``n_steps_for_reward`` rewards."""
        return first_n_rewards(rewards).sum()

    def add_n_step_discount_reward(rewards):
        """Discounted sum of the first ``n_steps_for_reward`` rewards (column 0)."""
        head = first_n_rewards(rewards)
        return (head[:, 0] * discount ** np.arange(len(head))).sum()

    def stack(key, rows):
        """Vertically stack ``traj[key][rows]`` over all trajectories."""
        return np.vstack([traj[key][rows] for traj in all_trajs])

    current, following = slice(None, -1), slice(1, None)
    result = {
        'states': stack('states', current),
        'next_states': stack('states', following),
        'proprio_states': stack('proprio_states', current),
        'next_proprio_states': stack('proprio_states', following),
        'observations': stack('observations', current),
        'next_observations': stack('observations', following),
        'depths': None,
        'next_depths': None,
        'actions': stack('actions', current),
        'rewards': stack('rewards', current),
        'terminals': stack('terminals', current),
        'timeouts': stack('timeouts', current),
        # Per-trajectory entries (one row per trajectory, not per transition).
        'successes': np.vstack([traj['successes'][-1] for traj in all_trajs]),
        'n_step_rewards': np.vstack(
            [add_n_step_reward(traj['rewards'][:-1]) for traj in all_trajs]),
        'n_step_discount_rewards': np.vstack(
            [add_n_step_discount_reward(traj['rewards'][:-1]) for traj in all_trajs]),
    }

    end_extend = time.time()
    print(f'Total time concatenating the dataset: {end_extend - start_extend}s.')

    return result
