import os
import glob
import pickle
import gzip
import pdb
import numpy as np
from copy import deepcopy

def restore_pool(
    replay_pool,
    experiment_root,
    max_size,
    save_path=None,
    eval_pool=False,
    env=None,
    multitask_type=None,
    load_traj=False,
    relabel_balance_batch=False,
):
    """Fill ``replay_pool`` from stored data, then print the resulting pool size.

    Dispatch depends on the exact type of ``experiment_root``:

    * a ``str`` must contain ``'d4rl'`` and requires ``env``; the string with
      its first five characters removed (presumably a ``'d4rl/'`` prefix --
      TODO confirm against callers) is handed to ``restore_pool_d4rl``;
    * anything else is forwarded to ``restore_pool_softlearning_multitask``
      together with ``max_size`` and ``save_path``.

    ``eval_pool`` is accepted but not used by this function.
    """
    if type(experiment_root) is str:
        assert 'd4rl' in experiment_root
        assert env is not None
        dataset_name = experiment_root[5:]
        restore_pool_d4rl(
            replay_pool,
            dataset_name,
            env,
            multitask_type=multitask_type,
            load_traj=load_traj,
        )
    else:
        restore_pool_softlearning_multitask(
            replay_pool,
            experiment_root,
            max_size,
            save_path,
            multitask_type=multitask_type,
            relabel_balance_batch=relabel_balance_batch,
        )
    print('[ cuds/off_policy ] Replay pool has size: {}'.format(replay_pool.size))


def _fetch_all_goals(env):
    """Return ``env.get_all_goals(random=False)``, retrying on ``env.unwrapped`` if that fails."""
    try:
        return env.get_all_goals(random=False)
    except Exception:
        # Best-effort fallback: retry on the unwrapped environment when the
        # direct call fails for any reason.
        return env.unwrapped.get_all_goals(random=False)


def _goal_reached(next_observations, goals, threshold=0.5):
    """Return a float32 0/1 vector: 1 where the first two columns of a row lie within ``threshold`` of ``goals``."""
    distances = np.linalg.norm(next_observations[:, :2] - goals, axis=-1)
    return (distances <= threshold).astype(np.float32)


def _partition_one_hot(num_rows, num_goals, partition_size, total_size):
    """Return a (num_rows, num_goals) one-hot matrix of each row's partition; leftover rows get the last goal."""
    one_hot = np.zeros((num_rows, num_goals))
    for idx in range(num_goals):
        start_idx = idx * partition_size
        end_idx = min((idx + 1) * partition_size, total_size)
        one_hot[start_idx:end_idx, idx] = 1.0
    one_hot[min(num_goals * partition_size, total_size): total_size, num_goals - 1] = 1.0
    return one_hot


def _relabel_mask(rewards, start_idx, end_idx, is_last_goal, total_size):
    """Return ones shaped like ``rewards`` with zeros on ``[start_idx, end_idx)`` (and on the tail for the last goal)."""
    mask = np.ones_like(rewards)
    mask[start_idx:end_idx] = 0.0
    if is_last_goal:
        # Leftover transitions beyond the equal partitions belong to the last goal.
        mask[end_idx:total_size] = 0.0
    return mask


def _nearest_goal_indices(goals, all_goals):
    """Return, per row of ``goals``, the index of the closest entry of ``all_goals`` (result is squeezed)."""
    distances = np.linalg.norm(
        np.expand_dims(goals, axis=1) - np.expand_dims(np.array(all_goals), axis=0),
        axis=-1,
    )
    return np.squeeze(np.argmin(distances, axis=-1))


def modify_data(env, data, multitask_type, relabel_balance_batch=False):
    """Relabel an offline transition dataset for a multi-task / goal-reaching setup.

    The goal list is read from ``env`` (or ``env.mujoco_envs[0]`` when that
    attribute exists) via ``get_all_goals(random=False)``. Rewards and
    terminals are recomputed as a 0/1 indicator of whether the first two
    columns of ``next_observations`` are within 0.5 of a goal.

    Supported ``multitask_type`` values:

    * ``'multi-goal'``: split the buffer into ``num_goals`` equal chunks and
      relabel chunk ``i`` with goal ``i`` (leftover rows use the last goal).
    * ``'relabel-multi-goal'``: as above, plus a one-hot goal id appended to
      the observations and a ``'task_ids'`` entry.
    * ``'multi-goal-all'``: replicate the buffer once per goal, each copy
      relabelled with a different goal, and concatenate the copies.
    * ``'hipi'``: keep one copy of the transitions (observations extended by
      the one-hot id of the chunk they came from) and return rewards and
      terminals as ``(N, num_goals)`` matrices, one column per goal.
    * ``'relabel-all'``: replicate the buffer once per goal with the goal
      one-hot appended; rewards/terminals outside the copy's own chunk are
      zeroed. With ``relabel_balance_batch`` the result also carries
      ``'relabel_masks'`` and ``'original_task_id_onehot'``.
    * ``'relabel-all-zero'``: original data plus one zero-reward copy per goal.
    * ``'single_task'``, ``'multi_task'``, ``'single_task_goal_condition'``,
      ``'multi_task_goal_condition'``, ``'single_task_relabel_zero'``:
      variants driven by the per-transition ``data['goals']`` array.

    Several modes modify ``data`` (and its arrays) in place and return it;
    the replication modes return a new dict.

    Args:
        env: environment exposing ``get_all_goals`` directly or via ``unwrapped``.
        data: dict of arrays with at least ``'observations'``,
            ``'next_observations'``, ``'rewards'`` and ``'terminals'``.
        multitask_type: one of the mode strings listed above.
        relabel_balance_batch: only read by ``'hipi'`` and ``'relabel-all'``.

    Returns:
        The relabelled data dict.

    Raises:
        AssertionError: if ``multitask_type`` is not a known mode.
    """
    if hasattr(env, "mujoco_envs"):
        env = env.mujoco_envs[0]
    all_goals = np.array(_fetch_all_goals(env))
    num_goals = len(all_goals)
    total_size = data['rewards'].shape[0]
    buffer_partition_size = int(total_size // num_goals)

    if multitask_type == 'multi-goal':
        # Undirected multi-task setting: equal chunks of the buffer are
        # relabelled with different goals.
        for idx in range(num_goals):
            start_idx = idx * buffer_partition_size
            end_idx = min((idx + 1) * buffer_partition_size, total_size)
            reached = _goal_reached(data['next_observations'][start_idx:end_idx], all_goals[idx])
            data['rewards'][start_idx:end_idx] = reached
            data['terminals'][start_idx:end_idx] = reached

        # Leftover transitions (when the size is not divisible) use the last goal.
        start_idx = num_goals * buffer_partition_size
        reached = _goal_reached(data['next_observations'][start_idx:total_size], all_goals[-1])
        data['rewards'][start_idx:total_size] = reached
        data['terminals'][start_idx:total_size] = reached
        return data

    elif multitask_type == 'relabel-multi-goal':
        # Equal chunks relabelled with different goals, with the goal id
        # appended to the observations as a one-hot context.
        one_hot_arr = np.eye(num_goals)
        num_rows = data['observations'].shape[0]
        obs_dim = data['observations'].shape[1]
        next_obs_dim = data['next_observations'].shape[1]
        current_observations = np.zeros((num_rows, obs_dim + num_goals))
        next_observations = np.zeros_like(current_observations)
        task_ids = np.zeros((num_rows,))

        for idx in range(num_goals):
            start_idx = idx * buffer_partition_size
            end_idx = min((idx + 1) * buffer_partition_size, total_size)
            reached = _goal_reached(data['next_observations'][start_idx:end_idx], all_goals[idx])
            data['rewards'][start_idx:end_idx] = reached
            data['terminals'][start_idx:end_idx] = reached
            task_ids[start_idx:end_idx] = idx

            current_observations[start_idx:end_idx, :obs_dim] = data['observations'][start_idx:end_idx]
            current_observations[start_idx:end_idx, obs_dim:] = one_hot_arr[idx]
            next_observations[start_idx:end_idx, :next_obs_dim] = data['next_observations'][start_idx:end_idx]
            next_observations[start_idx:end_idx, obs_dim:] = one_hot_arr[idx]

        # Leftover transitions are assigned to the last goal. The slice starts
        # at the beginning of the last chunk, so that chunk is relabelled again.
        # NOTE(review): this pass uses a 0.55 threshold while every other
        # relabelling uses 0.5 -- kept as-is, confirm whether it is intended.
        idx = num_goals - 1
        start_idx = idx * buffer_partition_size
        reached = _goal_reached(data['next_observations'][start_idx:total_size], all_goals[-1], threshold=0.55)
        data['rewards'][start_idx:total_size] = reached
        data['terminals'][start_idx:total_size] = reached
        current_observations[start_idx:total_size, :obs_dim] = data['observations'][start_idx:total_size]
        current_observations[start_idx:total_size, obs_dim:] = one_hot_arr[idx]
        next_observations[start_idx:total_size, :next_obs_dim] = data['next_observations'][start_idx:total_size]
        next_observations[start_idx:total_size, obs_dim:] = one_hot_arr[idx]
        # The start must be idx * partition size (the original used `+`, which
        # overwrote the task ids of earlier chunks with the last goal's id).
        task_ids[start_idx:total_size] = idx

        data['observations'] = current_observations
        data['next_observations'] = next_observations
        data['task_ids'] = task_ids
        return data

    elif multitask_type == 'multi-goal-all':
        # Replicate the buffer once per goal; each copy is labelled with a different goal.
        orig_data = deepcopy(data)
        all_data = []
        for idx in range(num_goals):
            curr_dict = dict(orig_data)
            reached = _goal_reached(orig_data['next_observations'], all_goals[idx])
            curr_dict['rewards'] = reached
            curr_dict['terminals'] = reached.copy()
            all_data.append(curr_dict)

        return {key: np.concatenate([x[key] for x in all_data], 0) for key in orig_data}

    elif multitask_type == 'hipi':
        orig_data = deepcopy(data)
        orig_task_id_onehot = _partition_one_hot(
            data['observations'].shape[0], num_goals, buffer_partition_size, total_size)
        print (min(num_goals * buffer_partition_size, total_size), total_size)

        all_data = []
        for idx in range(num_goals):
            curr_dict = dict(orig_data)
            reached = _goal_reached(orig_data['next_observations'], all_goals[idx])
            curr_dict['rewards'] = reached
            curr_dict['terminals'] = reached.copy()

            # Append the one-hot id of the chunk each transition came from.
            curr_dict['observations'] = np.concatenate(
                [curr_dict['observations'], orig_task_id_onehot], -1)
            curr_dict['next_observations'] = np.concatenate(
                [curr_dict['next_observations'], orig_task_id_onehot], -1)

            if relabel_balance_batch:
                start_idx = idx * buffer_partition_size
                end_idx = min((idx + 1) * buffer_partition_size, total_size)
                curr_dict['relabel_masks'] = _relabel_mask(
                    curr_dict['rewards'], start_idx, end_idx, idx == num_goals - 1, total_size)
                curr_dict['original_task_id_onehot'] = orig_task_id_onehot

            all_data.append(curr_dict)

        # Only keys present in the input data are copied to the result, so the
        # mask entries added above are not returned unless the input had them.
        final_data = {}
        for key in orig_data.keys():
            if key in ('rewards', 'terminals'):
                # One column per goal. The stacked array is already (N, num_goals);
                # it is deliberately not squeezed, because squeezing would drop an
                # axis when there is a single goal or a single row.
                final_data[key] = np.concatenate([x[key][:, None] for x in all_data], 1)
                assert final_data[key].shape[1] == num_goals, "Rewards and terminals don't have shape num_goals"
            else:
                final_data[key] = all_data[0][key]
        return final_data

    elif multitask_type == 'relabel-all':
        # Replicate the buffer once per goal; each copy gets that goal's
        # one-hot appended and keeps rewards only inside its own chunk.
        orig_data = deepcopy(data)
        orig_task_id_onehot = _partition_one_hot(
            data['observations'].shape[0], num_goals, buffer_partition_size, total_size)
        print (min(num_goals * buffer_partition_size, total_size), total_size)

        # Reward written to transitions that come from another goal's chunk.
        # The original code printed this name without ever defining it.
        relabel_value = 0.0

        all_data = []
        for idx in range(num_goals):
            curr_dict = dict(orig_data)
            reached = _goal_reached(orig_data['next_observations'], all_goals[idx])
            curr_dict['rewards'] = reached
            curr_dict['terminals'] = reached.copy()

            goal_arr = np.zeros((data['observations'].shape[0], len(all_goals)))
            goal_arr[:, idx] = 1.0
            curr_dict['observations'] = np.concatenate([curr_dict['observations'], goal_arr], -1)
            curr_dict['next_observations'] = np.concatenate([curr_dict['next_observations'], goal_arr], -1)

            start_idx = idx * buffer_partition_size
            end_idx = min((idx + 1) * buffer_partition_size, total_size)

            if relabel_balance_batch:
                curr_dict['relabel_masks'] = _relabel_mask(
                    curr_dict['rewards'], start_idx, end_idx, idx == num_goals - 1, total_size)
                curr_dict['original_task_id_onehot'] = orig_task_id_onehot

            print('Assign rewards %s to relabeled data.' % str(relabel_value))
            curr_dict['rewards'][:start_idx] = relabel_value
            curr_dict['terminals'][:start_idx] = 0.0
            if idx != num_goals - 1:
                # The last goal keeps its tail because leftover rows belong to it.
                curr_dict['rewards'][end_idx:] = relabel_value
                curr_dict['terminals'][end_idx:] = 0.0

            all_data.append(curr_dict)

        return {key: np.concatenate([x[key] for x in all_data], 0) for key in all_data[0].keys()}

    elif multitask_type == 'relabel-all-zero':
        orig_data = deepcopy(data)
        all_data = [orig_data]
        for idx in range(num_goals):
            curr_dict = dict(orig_data)
            curr_dict['rewards'] = curr_dict['rewards'] * 0.0
            curr_dict['terminals'] = curr_dict['terminals'] * 0.0

            goal_arr = np.zeros((data['observations'].shape[0], len(all_goals)))
            goal_arr[:, idx] = 1.0
            curr_dict['observations'] = np.concatenate([curr_dict['observations'], goal_arr], -1)
            curr_dict['next_observations'] = np.concatenate([curr_dict['next_observations'], goal_arr], -1)
            all_data.append(curr_dict)

        # NOTE(review): the original copy gets a single extra column holding
        # num_goals, while the other copies get num_goals one-hot columns, so
        # the concatenation below only succeeds when num_goals == 1. Behaviour
        # kept unchanged; confirm the intended encoding.
        idx = num_goals
        goal_arr = np.zeros((orig_data['observations'].shape[0], 1)) + idx
        all_data[0]['observations'] = np.concatenate([all_data[0]['observations'], goal_arr], -1)
        all_data[0]['next_observations'] = np.concatenate([all_data[0]['next_observations'], goal_arr], -1)

        return {key: np.concatenate([x[key] for x in all_data], 0) for key in orig_data}

    elif multitask_type == 'single_task':
        # Keep only transitions whose stored goal is closest to the selected goal
        # (env.task_idx if present, otherwise the last goal).
        print('original dataset size is %d' % data['observations'].shape[0])
        goal_idx = getattr(env, "task_idx", len(all_goals) - 1)
        print('task idx is', goal_idx)
        mask = _nearest_goal_indices(data['goals'], all_goals) == goal_idx
        for key in data:
            # Every key except one literally named 'goal' is filtered; the
            # goals array lives under 'goals' and is therefore filtered too.
            if key != 'goal':
                data[key] = data[key][mask]
        print('single task dataset size is %d' % data['observations'].shape[0])
        return data

    elif multitask_type == 'multi_task':
        # Append the one-hot index of the nearest known goal and relabel
        # rewards/terminals against each transition's own stored goal.
        actual_goals = data['goals']
        goal_arr = np.zeros((data['observations'].shape[0], len(all_goals)))
        # Loop-invariant, so computed once instead of once per goal.
        nearest = _nearest_goal_indices(data['goals'], all_goals)
        for idx in range(num_goals):
            mask = nearest == idx
            goal_arr[mask, idx] = 1.0
            reached = _goal_reached(data['next_observations'][mask], actual_goals[mask])
            data['rewards'][mask] = reached
            data['terminals'][mask] = reached
        data['observations'] = np.concatenate([data['observations'], goal_arr], -1)
        data['next_observations'] = np.concatenate([data['next_observations'], goal_arr], -1)
        return data

    elif multitask_type == 'single_task_goal_condition':
        # Keep transitions whose stored goal is within 0.5 of the last goal,
        # then relabel against the stored goals.
        print('original dataset size is %d' % data['observations'].shape[0])
        actual_goals = data['goals']
        mask = (np.linalg.norm(data['goals'] - all_goals[-1], axis=-1) <= 0.5)
        for key in data:
            if key != 'goal':
                data[key] = data[key][mask]
        reached = _goal_reached(data['next_observations'], actual_goals[mask])
        data['rewards'] = reached
        data['terminals'] = reached.copy()
        print('single task dataset size is %d' % data['observations'].shape[0])
        return data

    elif multitask_type == 'multi_task_goal_condition':
        # Directed multi-task setting: relabel against each transition's stored
        # goal and append that goal to the observations.
        actual_goals = data['goals']
        reached = _goal_reached(data['next_observations'], actual_goals)
        data['rewards'] = reached
        data['terminals'] = reached.copy()
        data['observations'] = np.concatenate([data['observations'], actual_goals], -1)
        data['next_observations'] = np.concatenate([data['next_observations'], actual_goals], -1)
        return data

    elif multitask_type == 'single_task_relabel_zero':
        # Zero rewards/terminals of transitions whose stored goal is not
        # within 0.5 of the last goal; nothing is removed.
        print('original dataset size is %d' % data['observations'].shape[0])
        mask = (np.linalg.norm(data['goals'] - all_goals[-1], axis=-1) <= 0.5)
        data['rewards'] = data['rewards'] * mask
        data['terminals'] = data['terminals'] * mask
        return data

    else:
        # Raised explicitly so the check survives `python -O`.
        raise AssertionError("multitask_type is unknown")

def restore_pool_d4rl(replay_pool, name, env, multitask_type=None, load_traj=False):
    """Load a D4RL Q-learning dataset for ``env`` into ``replay_pool``.

    The dataset comes from ``d4rl.qlearning_dataset(env.unwrapped)``. When
    ``multitask_type`` is given it is first relabelled by ``modify_data``
    (always with ``relabel_balance_batch=True``). Rewards are then rescaled
    according to substrings found in ``name``, and for every mode except
    ``'hipi'`` rewards/terminals are reshaped to ``(N, 1)``.

    Args:
        replay_pool: pool exposing ``add_samples`` and, for trajectory
            loading, ``add_paths``.
        name: dataset name; only used for substring checks that select the
            reward scaling and whether trajectories may be split.
        env: environment whose ``unwrapped`` attribute is passed to d4rl.
        multitask_type: optional relabelling mode forwarded to ``modify_data``.
        load_traj: if true and ``'antmaze'`` is in ``name``, split the data
            into trajectories (ending at a terminal or after 1000 steps) and
            add them with ``add_paths`` instead of ``add_samples``.
    """
    # `gym` is not referenced below; the import is kept from the original in
    # case its import-time side effects are relied on -- TODO confirm.
    import gym  # noqa: F401
    import d4rl

    data = d4rl.qlearning_dataset(env.unwrapped)
    if multitask_type is not None:
        data = modify_data(env.unwrapped, data, multitask_type, relabel_balance_batch=True)

    is_hipi = multitask_type == 'hipi'

    # NOTE(review): these are plain substring tests, so e.g. a name containing
    # "open" also matches 'pen' -- confirm the dataset names that reach here.
    if 'antmaze' in name or 'kitchen' in name:
        if not is_hipi:
            # The original expanded the reward dimension twice here; the
            # reshape to (N, 1) below makes a single expansion equivalent.
            data['rewards'] = (np.expand_dims(data['rewards'], axis=1) - 0.5) * 4.0
    elif 'pen' in name:
        data['rewards'] = np.expand_dims(data['rewards'], axis=1)*0.02 - 0.5
    elif 'hammer' in name:
        data['rewards'] = np.expand_dims(data['rewards'], axis=1)*0.02 - 0.05
    elif 'door' in name:
        data['rewards'] = np.expand_dims(data['rewards'], axis=1)*0.1
    else:
        data['rewards'] = (np.expand_dims(data['rewards'], axis=1) - 0.5) * 4.0

    if not is_hipi:
        data['rewards'] = np.reshape(data['rewards'], (data['rewards'].shape[0], 1))
        data['terminals'] = np.reshape(data['terminals'], (data['terminals'].shape[0], 1))
        # Only some relabelling modes produce 'relabel_masks'; indexing it
        # unconditionally raised KeyError for all the others.
        if multitask_type is not None and 'relabel_masks' in data:
            data['relabel_masks'] = np.reshape(data['relabel_masks'], (data['relabel_masks'].shape[0], 1))
            print (data['rewards'].shape, data['terminals'].shape, data['relabel_masks'].shape)

    if load_traj and 'antmaze' in name:
        max_path_length = 1000
        num_steps = data['terminals'].shape[0]
        trajectories = []
        path_start = 0  # index of the first transition of the current path
        for i in range(num_steps):
            if data['terminals'][i][0] or i - path_start + 1 == max_path_length:
                trajectory = {key: data[key][path_start:i + 1] for key in data.keys()}
                terminal_count = trajectory['terminals'].astype(float).sum()
                if not (terminal_count == 1.0 or trajectory['terminals'].shape[0] == max_path_length):
                    # Sanity check only: report instead of dropping into an
                    # interactive debugger, which blocks unattended runs.
                    print('[ cuds/off_policy ] Warning: trajectory ending at index {} has {} terminals and length {}'.format(
                        i, terminal_count, trajectory['terminals'].shape[0]))
                trajectories.append(trajectory)
                path_start = i + 1

        # Trailing transitions after the last detected path end. This also
        # covers the case where no path end was found at all, which used to
        # raise IndexError.
        if path_start < num_steps:
            trajectory = {key: data[key][path_start:] for key in data.keys()}
            terminal_count = trajectory['terminals'].astype(float).sum()
            if not (terminal_count == 1.0 or terminal_count == 0.0):
                print('[ cuds/off_policy ] Warning: trailing trajectory starting at index {} has {} terminals'.format(
                    path_start, terminal_count))
            trajectories.append(trajectory)
        replay_pool.add_paths(trajectories)
    else:
        replay_pool.add_samples(data)


def restore_pool_softlearning(replay_pool, experiment_root, max_size, save_path=None):
    """Load softlearning (SAC) replay-pool pickles into ``replay_pool``.

    If ``experiment_root`` does not contain ``'pkl'`` it is treated as a
    directory of ``checkpoint_<epoch>`` folders: each folder's
    ``replay_pool.pkl`` is loaded in increasing epoch order until the pool
    holds at least ``max_size`` samples. Otherwise ``experiment_root`` is
    loaded directly as a single pool file.

    When ``save_path`` is given, the first ``replay_pool.size`` entries of
    every pool field are written to ``pool_<size>.pkl`` and per-path return
    statistics (average, max, min) to ``pool_stat_<size>.pkl`` in that
    directory. A path ends at a truthy terminal or after 1000 steps.

    Args:
        replay_pool: pool exposing ``load_experience``, ``size`` and ``fields``.
        experiment_root: checkpoint directory or path of a pool pickle.
        max_size: stop loading checkpoints once the pool reaches this size.
        save_path: optional directory for the exported pool and statistics.
    """
    print('[ cuds/off_policy ] Loading SAC replay pool from: {}'.format(experiment_root))
    if 'pkl' not in experiment_root:
        experience_paths = sorted(glob.iglob(os.path.join(experiment_root, 'checkpoint_*')))
        checkpoint_epochs = sorted(int(path.split('_')[-1]) for path in experience_paths)
        if max_size == 250e3:
            # The two earliest checkpoints are skipped for this specific pool
            # size; the reason is not visible here -- TODO confirm.
            checkpoint_epochs = checkpoint_epochs[2:]

        for epoch in checkpoint_epochs:
            fullpath = os.path.join(experiment_root, 'checkpoint_{}'.format(epoch), 'replay_pool.pkl')
            print('[ cuds/off_policy ] Loading replay pool data: {}'.format(fullpath))
            replay_pool.load_experience(fullpath)
            if replay_pool.size >= max_size:
                break
    else:
        print('[ cuds/off_policy ] Loading replay pool data: {}'.format(experiment_root))
        replay_pool.load_experience(experiment_root)

    if save_path is None:
        return

    size = replay_pool.size
    stat_path = os.path.join(save_path, 'pool_stat_{}.pkl'.format(size))
    save_path = os.path.join(save_path, 'pool_{}.pkl'.format(size))
    d = {key: replay_pool.fields[key][:size] for key in replay_pool.fields.keys()}

    # Locate path ends. A set gives O(1) membership in the scan below; the
    # original list lookup made that scan quadratic in the pool size.
    num_paths = 0
    path_start = 0
    path_end_idx = set()
    for i in range(d['terminals'].shape[0]):
        if d['terminals'][i] or i - path_start + 1 == 1000:
            num_paths += 1
            path_start = i + 1
            path_end_idx.add(i)

    total_return = d['rewards'].sum()
    avg_return = total_return / num_paths
    buffer_max, buffer_min = -np.inf, np.inf
    path_return = 0.0
    for i in range(d['rewards'].shape[0]):
        path_return += d['rewards'][i]
        if i in path_end_idx:
            if path_return > buffer_max:
                buffer_max = path_return
            if path_return < buffer_min:
                buffer_min = path_return
            path_return = 0.0

    print('[ cuds/off_policy ] Replay pool average return is {}, buffer_max is {}, buffer_min is {}'.format(avg_return, buffer_max, buffer_min))
    d_stat = dict(avg_return=avg_return, buffer_max=buffer_max, buffer_min=buffer_min)
    # Context managers make sure the files are flushed and closed.
    with open(stat_path, 'wb') as stat_file:
        pickle.dump(d_stat, stat_file)

    print('[ cuds/off_policy ] Saving replay pool to: {}'.format(save_path))
    with open(save_path, 'wb') as pool_file:
        pickle.dump(d, pool_file)

    
    ####
    # val_size = 1000
    # print('NOT USING LAST {} SAMPLES'.format(val_size))
    # replay_pool._pointer -= val_size
    # replay_pool._size -= val_size
    # print(replay_pool._pointer, replay_pool._size)
    # pdb.set_trace()


def restore_pool_bear(replay_pool, load_path):
    """Load a gzip-compressed BEAR pickle from ``load_path`` into ``replay_pool``.

    Prints the trajectory count (sum of ``'terminals'``, or 1000 when that
    sum is zero) and the average return, removes the policy-statistics
    entries ``'log_pis'``, ``'data_policy_mean'`` and ``'data_policy_logvar'``
    when present, and passes the remaining dict to ``replay_pool.add_samples``.

    Args:
        replay_pool: pool exposing ``add_samples``.
        load_path: path to the gzip-compressed pickle file.
    """
    print('[ cuds/off_policy ] Loading BEAR replay pool from: {}'.format(load_path))
    # NOTE: unpickling executes arbitrary code; only load files from trusted sources.
    with gzip.open(load_path, 'rb') as f:
        data = pickle.load(f)
    num_trajectories = data['terminals'].sum() or 1000
    avg_return = data['rewards'].sum() / num_trajectories
    print('[ cuds/off_policy ] {} trajectories | avg return: {}'.format(num_trajectories, avg_return))

    for key in ['log_pis', 'data_policy_mean', 'data_policy_logvar']:
        # Tolerate files that do not carry every policy-statistics entry.
        data.pop(key, None)

    replay_pool.add_samples(data)


def restore_pool_contiguous(replay_pool, load_path):
    """Load a single 2-D array from ``load_path`` and add it to ``replay_pool``.

    The columns of the array are split, in order, into observations, actions,
    next observations, one reward column and one done column. The observation
    and action widths are taken from the second dimension of
    ``replay_pool.fields['observations']`` and ``replay_pool.fields['actions']``.

    Raises:
        AssertionError: if the array's column count differs from
            ``2 * state_dim + action_dim + 2``.
    """
    print('[ cuds/off_policy ] Loading contiguous replay pool from: {}'.format(load_path))
    packed = np.load(load_path)

    obs_width = replay_pool.fields['observations'].shape[1]
    act_width = replay_pool.fields['actions'].shape[1]
    column_widths = [obs_width, act_width, obs_width, 1, 1]

    expected_dim = sum(column_widths)
    actual_dim = packed.shape[1]
    assert expected_dim == actual_dim, 'Expected {} dimensions, found {}'.format(expected_dim, actual_dim)

    # Running totals of the widths are the column indices at which to cut.
    boundaries = np.cumsum(column_widths).tolist()
    pieces = np.split(packed, boundaries, axis=1)
    states, actions, next_states, rewards, dones = pieces[:5]

    samples = {
        'observations': states,
        'actions': actions,
        'next_observations': next_states,
        'rewards': rewards,
        'terminals': dones.astype(bool),
    }
    replay_pool.add_samples(samples)

def restore_pool_softlearning_multitask(replay_pool, experiment_roots, max_size, save_path=None, multitask_type=None, relabel_balance_batch=False):
    """Load one pickled dataset per task and add them to ``replay_pool``.

    Each entry of ``experiment_roots`` is a file path that is unpickled (gzip
    is tried first, then a plain file). A list payload is merged into a single
    dict by concatenating every key along axis 0. If the first path contains
    ``'metaworld'``, rewards are replaced by 0/1 success indicators computed
    from the observations.

    Behaviour by ``multitask_type``:

    * ``'relabel-all'`` / ``'hipi'``: for every task, the transitions of all
      tasks are pooled, rewards (and, for walker/ant, terminals) are
      recomputed for that task with the hard-coded formulas below, and the
      trailing ``len(experiment_roots)`` observation columns are overwritten
      with the one-hot id of that task. ``'relabel-all'`` stacks the relabeled
      copies along axis 0; ``'hipi'`` keeps the original transitions and
      concatenates the per-task rewards/terminals along axis 1.
    * anything else: the per-task datasets are concatenated along axis 0.

    After ``replay_pool.add_samples`` the per-task sub-pools
    (``replay_pool.pools[task_idx]``) are scanned to print transition counts
    and average / max / min path returns.

    Args:
        replay_pool: object exposing ``add_samples(dict)`` and ``pools``, a
            sequence whose items have ``size`` and ``fields``.
        experiment_roots: sequence of dataset file paths, one per task.
        max_size: divided by the number of tasks; not otherwise applied below.
        save_path: not referenced in this function body.
        multitask_type: relabeling mode, see above.
        relabel_balance_batch: when relabeling, also store ``relabel_masks``
            (1 for transitions borrowed from other tasks, 0 for the task's own)
            and multiply the relabeled rewards by ``1 - relabel_masks``.

    Raises:
        ValueError: if relabeling is requested for a task-count / environment
            combination that has no relabeling rule below.

    NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    dataset = [{'observations': None, 'actions': None, 'next_observations': None, 'rewards': None, 'terminals': None} for task_idx in range(len(experiment_roots))]
    reward_jump = [None for _ in range(len(experiment_roots))]
    reward_run = None
    # NOTE(review): this per-task size is computed but never applied to the
    # data loaded below - confirm whether truncation was intended.
    max_size = max_size / len(experiment_roots)
    for i, experiment_root in enumerate(experiment_roots):
        print('[ cuds/off_policy ] Loading replay pool data: {}'.format(experiment_root))
        try:
            with gzip.open(experiment_root, 'rb') as f:
                latest_samples = pickle.load(f)
        except (OSError, EOFError):
            # Not a gzip file (gzip raises an OSError subclass) or empty
            # stream: fall back to reading it as a plain pickle.
            with open(experiment_root, 'rb') as f:
                latest_samples = pickle.load(f)

        if isinstance(latest_samples, list):
            # A list of per-chunk dicts: merge chunk-wise along the sample axis.
            chunks = latest_samples
            latest_samples = {key: np.concatenate([chunk[key] for chunk in chunks], axis=0) for key in chunks[0].keys()}
        key = list(latest_samples.keys())[0]
        num_samples = latest_samples[key].shape[0]
        for field_name, data in latest_samples.items():
            assert data.shape[0] == num_samples, data.shape
        if dataset[i]['observations'] is None:
            for key in dataset[i]:
                dataset[i][key] = latest_samples[key]
        else:
            for key in dataset[i]:
                dataset[i][key] = np.concatenate([dataset[i][key], latest_samples[key]], axis=0)
        # Each task is visited exactly once, so reward_jump[i] is always unset
        # here; a missing key is reported rather than treated as an error.
        if 'reward_jump' in latest_samples:
            reward_jump[i] = latest_samples['reward_jump']
        else:
            print('%s does not contain reward jump' % experiment_root)
        if 'reward_run' in latest_samples.keys() and 'jump' in experiment_root:
            print('%s contains reward run' % experiment_root)
            if reward_run is None:
                reward_run = latest_samples['reward_run']
            else:
                reward_run = np.concatenate([reward_run, latest_samples['reward_run']], axis=0)

    # change reward to successes for metaworld tasks
    if 'metaworld' in experiment_roots[0]:
        for task_idx in range(len(experiment_roots)):
            if task_idx == 0:
                dataset[task_idx]['rewards'] = np.expand_dims((np.linalg.norm(dataset[task_idx]['observations'][:, 3:5] - np.array([[0.425, 0.55]]), axis=-1) <= 0.15).astype(float), axis=1)
            if task_idx == 1:
                dataset[task_idx]['rewards'] = np.expand_dims((np.linalg.norm(dataset[task_idx]['observations'][:, 3:5] - np.array([[-0.075, 0.7]]), axis=-1) <= 0.15).astype(float), axis=1)
            if task_idx == 2:
                dataset[task_idx]['rewards'] = np.expand_dims((np.abs(dataset[task_idx]['observations'][:, 7] - 0.54) <= 0.06).astype(float), axis=1)
            if task_idx == 3:
                dataset[task_idx]['rewards'] = np.expand_dims((np.abs(dataset[task_idx]['observations'][:, 7] - 0.74) <= 0.06).astype(float), axis=1)

    all_data = {}
    if multitask_type == 'relabel-all' or multitask_type == 'hipi':
        relabel_data = [{'observations': None, 'actions': None, 'next_observations': None, 'rewards': None, 'terminals': None} for task_idx in range(len(experiment_roots))]
        # Per-transition control term, scaled by an environment-specific coefficient.
        if 'cheetah' in experiment_roots[0]:
            reward_ctrl =  [-0.1*np.expand_dims(np.square(dataset[j]['actions']).sum(axis=-1), axis=1) for j in range(len(experiment_roots))]
        elif 'walker' in experiment_roots[0]:
            reward_ctrl =  [-0.001*np.expand_dims(np.square(dataset[j]['actions']).sum(axis=-1), axis=1) for j in range(len(experiment_roots))]
        elif 'ant' in experiment_roots[0]:
            reward_ctrl =  [-0.5*np.expand_dims(np.square(dataset[j]['actions']).sum(axis=-1), axis=1) for j in range(len(experiment_roots))]
            # NOTE(review): np.sum has no axis here, so each entry is one scalar
            # summed over the whole task dataset, unlike reward_ctrl which is
            # per-transition - confirm this is intended.
            reward_contact =  [0.5 * 1e-3 * np.sum(np.square(np.clip(dataset[j]['observations'][:, 27:-len(experiment_roots)], -1, 1))) for j in range(len(experiment_roots))]
        for task_idx in range(len(experiment_roots)):
            # Reset per task so an unsupported configuration fails with a clear
            # error below instead of a NameError.
            relabel_rewards = None
            relabel_terminals = None
            for key in relabel_data[task_idx]:
                if key != 'rewards':
                    relabel_data[task_idx][key] = np.concatenate([dataset[j][key] for j in range(len(experiment_roots))], axis=0)
            if len(experiment_roots) == 4:
                if 'cheetah' in experiment_roots[0]:
                    if task_idx == 0:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        relabel_rewards[1] = relabel_rewards[1] - reward_jump[1]
                        relabel_rewards[2] = -(relabel_rewards[2] - reward_ctrl[2]) + reward_ctrl[2]
                        relabel_rewards[3] = -(relabel_rewards[3] - reward_jump[3] - reward_ctrl[3]) + reward_ctrl[3]
                    if task_idx == 1:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        relabel_rewards[0] = relabel_rewards[0] + reward_jump[0]
                        relabel_rewards[2] = -(relabel_rewards[2] - reward_ctrl[2]) + reward_ctrl[2] + reward_jump[2]
                        relabel_rewards[3] = -(relabel_rewards[3] - reward_jump[3] - reward_ctrl[3]) + reward_ctrl[3] + reward_jump[3]
                    if task_idx == 2:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        relabel_rewards[0] = -(relabel_rewards[0] - reward_ctrl[0]) + reward_ctrl[0]
                        relabel_rewards[1] = -(relabel_rewards[1] - reward_jump[1] - reward_ctrl[1]) + reward_ctrl[1]
                        relabel_rewards[3] = relabel_rewards[3] - reward_jump[3]
                    if task_idx == 3:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        relabel_rewards[0] = -(relabel_rewards[0] - reward_ctrl[0]) + reward_ctrl[0] + reward_jump[0]
                        relabel_rewards[1] = -(relabel_rewards[1] - reward_jump[1] - reward_ctrl[1]) + reward_ctrl[1] + reward_jump[1]
                        relabel_rewards[2] = relabel_rewards[2] + reward_jump[2]
                elif 'metaworld' in experiment_roots[0]:
                    if task_idx == 0:
                        relabel_rewards = np.expand_dims((np.linalg.norm(relabel_data[task_idx]['observations'][:, 3:5] - np.array([[0.425, 0.55]]), axis=-1) <= 0.15).astype(float), axis=1)
                    if task_idx == 1:
                        relabel_rewards = np.expand_dims((np.linalg.norm(relabel_data[task_idx]['observations'][:, 3:5] - np.array([[-0.075, 0.7]]), axis=-1) <= 0.15).astype(float), axis=1)
                    if task_idx == 2:
                        relabel_rewards = np.expand_dims((np.abs(relabel_data[task_idx]['observations'][:, 7] - 0.54) <= 0.06).astype(float), axis=1)
                    if task_idx == 3:
                        relabel_rewards = np.expand_dims((np.abs(relabel_data[task_idx]['observations'][:, 7] - 0.74) <= 0.06).astype(float), axis=1)
            elif len(experiment_roots) == 3:
                if 'cheetah' in experiment_roots[0]:
                    if task_idx == 0:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        relabel_rewards[1] = -(relabel_rewards[1] - reward_ctrl[1]) + reward_ctrl[1]
                        relabel_rewards[2] = relabel_rewards[2] - reward_jump[2] + np.minimum(reward_run, 3.0)
                    if task_idx == 1:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        relabel_rewards[0] = -(relabel_rewards[0] - reward_ctrl[0]) + reward_ctrl[0]
                        relabel_rewards[2] = np.minimum(-reward_run, 3.0) + reward_ctrl[2]
                    if task_idx == 2:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        relabel_rewards[0] = reward_ctrl[0] + reward_jump[0]
                        relabel_rewards[1] = reward_ctrl[1] + reward_jump[1]
                elif 'walker' in experiment_roots[0]:
                    if task_idx == 0:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        relabel_rewards[1] = -(relabel_rewards[1] - reward_ctrl[1] - 1.0) + reward_ctrl[1] + 1.0
                        relabel_rewards[2] = relabel_rewards[2] - reward_jump[2] + reward_run + np.abs(reward_run)
                        relabel_terminals = [dataset[j]['terminals'] for j in range(len(experiment_roots))]
                        not_dones = (dataset[2]['next_observations'][:, 0] > 0.8) \
                                    * (dataset[2]['next_observations'][:, 0] < 2.0) \
                                    * (dataset[2]['next_observations'][:, 1] > -1.0) \
                                    * (dataset[2]['next_observations'][:, 1] < 1.0)
                        relabel_terminals[2] = np.expand_dims(~not_dones, axis=1)
                    if task_idx == 1:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        relabel_rewards[0] = -(relabel_rewards[0] - reward_ctrl[0] - 1.0) + reward_ctrl[0] + 1.0
                        relabel_rewards[2] = relabel_rewards[2] - reward_jump[2] - reward_run + np.abs(reward_run)
                        relabel_terminals = [dataset[j]['terminals'] for j in range(len(experiment_roots))]
                        not_dones = (dataset[2]['next_observations'][:, 0] > 0.8) \
                                    * (dataset[2]['next_observations'][:, 0] < 2.0) \
                                    * (dataset[2]['next_observations'][:, 1] > -1.0) \
                                    * (dataset[2]['next_observations'][:, 1] < 1.0)
                        relabel_terminals[2] = np.expand_dims(~not_dones, axis=1)
                    if task_idx == 2:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        # relabel_rewards[0] = reward_ctrl[0] + reward_jump[0] + 1.0
                        # relabel_rewards[1] = reward_ctrl[1] + reward_jump[1] + 1.0
                        relabel_rewards[0] = reward_ctrl[0] + reward_jump[0] + 1.0 - np.abs(relabel_rewards[0] - reward_ctrl[0] - 1.0)
                        relabel_rewards[1] = reward_ctrl[1] + reward_jump[1] + 1.0 - np.abs(relabel_rewards[1] - reward_ctrl[1] - 1.0)
                        relabel_terminals = [dataset[j]['terminals'] for j in range(len(experiment_roots))]
                        relabel_terminals[0] = ~((dataset[0]['next_observations'][:, 0] > 0.8) \
                                    * (dataset[0]['next_observations'][:, 1] > -1.0) \
                                    * (dataset[0]['next_observations'][:, 1] < 1.0))
                        relabel_terminals[0] = np.expand_dims(relabel_terminals[0], axis=1)
                        relabel_terminals[1] = ~((dataset[1]['next_observations'][:, 0] > 0.8) \
                                    * (dataset[1]['next_observations'][:, 1] > -1.0) \
                                    * (dataset[1]['next_observations'][:, 1] < 1.0))
                        relabel_terminals[1] = np.expand_dims(relabel_terminals[1], axis=1)
                elif 'ant' in experiment_roots[0]:
                    if task_idx == 0:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        relabel_rewards[1] = -(relabel_rewards[1] - reward_ctrl[1] - reward_contact[1] - 1.0) + reward_ctrl[1] + reward_contact[1] + 1.0
                        relabel_rewards[2] = relabel_rewards[2] - reward_jump[2] + reward_run + np.abs(reward_run)
                        relabel_terminals = [dataset[j]['terminals'] for j in range(len(experiment_roots))]
                        not_dones = (dataset[2]['next_observations'][:, 0] > 0.2) \
                                    * (dataset[2]['next_observations'][:, 0] < 1.0)
                        relabel_terminals[2] = np.expand_dims(~not_dones, axis=1)
                    if task_idx == 1:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        relabel_rewards[0] = -(relabel_rewards[0] - reward_ctrl[0] - reward_contact[0] - 1.0) + reward_ctrl[0] + reward_contact[0] + 1.0
                        relabel_rewards[2] = relabel_rewards[2] - reward_jump[2] - reward_run + np.abs(reward_run)
                        relabel_terminals = [dataset[j]['terminals'] for j in range(len(experiment_roots))]
                        not_dones = (dataset[2]['next_observations'][:, 0] > 0.2) \
                                    * (dataset[2]['next_observations'][:, 0] < 1.0)
                        relabel_terminals[2] = np.expand_dims(~not_dones, axis=1)
                    if task_idx == 2:
                        relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                        # relabel_rewards[0] = reward_ctrl[0] + reward_jump[0] + 1.0
                        # relabel_rewards[1] = reward_ctrl[1] + reward_jump[1] + 1.0
                        relabel_rewards[0] = reward_ctrl[0] + reward_jump[0] + reward_contact[0] + 1.0 - np.abs(relabel_rewards[0] - reward_ctrl[0] - reward_contact[0] - 1.0)
                        relabel_rewards[1] = reward_ctrl[1] + reward_jump[1] + reward_contact[1] + 1.0 - np.abs(relabel_rewards[1] - reward_ctrl[1] - reward_contact[1] - 1.0)
                        relabel_terminals = [dataset[j]['terminals'] for j in range(len(experiment_roots))]
                        relabel_terminals[0] = ~(dataset[0]['next_observations'][:, 0] > 0.2)
                        relabel_terminals[0] = np.expand_dims(relabel_terminals[0], axis=1)
                        relabel_terminals[1] = ~(dataset[1]['next_observations'][:, 0] > 0.2)
                        relabel_terminals[1] = np.expand_dims(relabel_terminals[1], axis=1)
            elif len(experiment_roots) == 2:
                if task_idx == 0:
                    relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                    relabel_rewards[1] = relabel_rewards[1] - reward_jump[1] + reward_run
                if task_idx == 1:
                    relabel_rewards = [dataset[j]['rewards'] for j in range(len(experiment_roots))]
                    relabel_rewards[0] = reward_ctrl[0] + reward_jump[0] + 1.0 - np.abs(relabel_rewards[0] - reward_ctrl[0] - 1.0)
            if relabel_rewards is None:
                raise ValueError('No reward relabeling rule for {} tasks with root {!r}'.format(len(experiment_roots), experiment_roots[0]))
            if isinstance(relabel_rewards, list):
                relabel_rewards = np.concatenate([relabel_rewards[j] for j in range(len(experiment_roots))], axis=0)
            if 'walker' in experiment_roots[0] or 'ant' in experiment_roots[0]:
                if relabel_terminals is None:
                    raise ValueError('No terminal relabeling rule for {} tasks with root {!r}'.format(len(experiment_roots), experiment_roots[0]))
                relabel_terminals = np.concatenate([relabel_terminals[j] for j in range(len(experiment_roots))], axis=0)
                relabel_data[task_idx]['terminals'] = relabel_terminals
            if relabel_balance_batch:
                # Mask is 1 for transitions borrowed from other tasks, 0 for this task's own.
                relabel_masks = [np.ones_like(dataset[j]['rewards']) for j in range(len(experiment_roots))]
                relabel_masks[task_idx] = np.zeros_like(dataset[task_idx]['rewards'])
                relabel_masks = np.concatenate([relabel_masks[j] for j in range(len(experiment_roots))], axis=0)
                relabel_data[task_idx]['relabel_masks'] = relabel_masks
                # original_task_id_onehot = relabel_data[task_idx]['observations'][:, -len(experiment_roots):].copy()
                # relabel_data[task_idx]['original_task_id_onehot'] = original_task_id_onehot
                relabel_rewards = relabel_rewards * (1.0 - relabel_masks)
            relabel_data[task_idx]['rewards'] = relabel_rewards
            # Overwrite the trailing task-id columns with the one-hot id of task_idx.
            relabel_data[task_idx]['observations'][:, -len(experiment_roots):] = 0.
            relabel_data[task_idx]['next_observations'][:, -len(experiment_roots):] = 0.
            relabel_data[task_idx]['observations'][:, -(len(experiment_roots)-task_idx)] = 1.
            relabel_data[task_idx]['next_observations'][:, -(len(experiment_roots)-task_idx)] = 1.
        if multitask_type == 'relabel-all':
            for key in relabel_data[0]:
                all_data[key] = np.concatenate([relabel_data[task_idx][key] for task_idx in range(len(experiment_roots))], axis=0)
        else:
            assert multitask_type == 'hipi'
            for key in dataset[0].keys():
                if key != 'rewards' and key != 'terminals':
                    all_data[key] = np.concatenate([dataset[task_idx][key] for task_idx in range(len(experiment_roots))], axis=0)
                else:
                    # One column per task: rewards/terminals as relabeled for each task.
                    all_data[key] = np.concatenate([relabel_data[task_idx][key] for task_idx in range(len(experiment_roots))], axis=1)
    else:
        for key in dataset[0].keys():
            all_data[key] = np.concatenate([dataset[task_idx][key] for task_idx in range(len(experiment_roots))], axis=0)
    replay_pool.add_samples(all_data)
    for task_idx in range(len(experiment_roots)):
        d = {}
        size = replay_pool.pools[task_idx].size
        for key in replay_pool.pools[task_idx].fields.keys():
            d[key] = replay_pool.pools[task_idx].fields[key][:size]

        num_paths = 0
        temp = 0
        path_end_idx = []
        # A path ends at a terminal flag or after path_len consecutive steps.
        path_len = 1000 if 'metaworld' not in experiment_roots[0] else 200
        for i in range(d['terminals'].shape[0]):
            if multitask_type == 'hipi':
                termination_flag = d['terminals'][i][task_idx]
            else:
                termination_flag = d['terminals'][i]
            if termination_flag or i - temp + 1 == path_len:
                num_paths += 1
                temp = i + 1
                path_end_idx.append(i)
        if multitask_type == 'hipi':
            rewards = d['rewards'][:, task_idx]
        else:
            rewards = d['rewards']
        total_return = rewards.sum()
        avg_return = total_return / num_paths
        buffer_max, buffer_min = -np.inf, np.inf
        path_return = 0.0
        # Set membership keeps the scan linear in the number of transitions.
        path_end_set = set(path_end_idx)
        for i in range(d['rewards'].shape[0]):
            if multitask_type == 'hipi':
                path_return += d['rewards'][i, task_idx]
            else:
                path_return += d['rewards'][i]
            if i in path_end_set:
                if path_return > buffer_max:
                    buffer_max = path_return
                if path_return < buffer_min:
                    buffer_min = path_return
                path_return = 0.0
        print('[ cuds/off_policy ] Replay pool {} has {} transitions'.format(task_idx, replay_pool.pools[task_idx].size))
        print('[ cuds/off_policy ] Replay pool {} average return is {}, buffer_max is {}, buffer_min is {}'.format(task_idx, avg_return, buffer_max, buffer_min))
