import numpy as np


def get_datasets(dataset_e_raw, dataset_s_raw, num_e=1, num_s_e=10, num_s_s=1000):
    """
    Returns D_e and D_s.

    D_s is built from the first ``num_s_s`` trajectories of ``dataset_s_raw``.
    From ``dataset_e_raw``, the first ``num_e + num_s_e`` trajectories are
    taken: the last ``num_s_e`` of them are appended to D_s and the rest form
    D_e (see ``dataset_split_expert``).

    Every transition carries a ``'flag'`` entry: 1 for transitions drawn from
    ``dataset_e_raw`` and 0 for transitions drawn from ``dataset_s_raw``.

    Args:
        dataset_e_raw: raw dict of per-step arrays used as the expert source.
        dataset_s_raw: raw dict of per-step arrays used as the D_s source.
        num_e: number of expert trajectories placed in D_e.
        num_s_e: number of expert trajectories added to D_s.
        num_s_s: number of trajectories taken from ``dataset_s_raw``.

    Returns:
        (dataset_e, dataset_s) transition dicts, each including ``'flag'``.
    """
    dataset_s = dataset_m_trajs(dataset_s_raw, num_s_s)
    dataset_s['flag'] = np.zeros_like(dataset_s['terminals'])

    dataset_e, dataset_s_extra = dataset_split_expert(dataset_e_raw, num_s_e, num_e + num_s_e)
    dataset_e['flag'] = np.ones_like(dataset_e['terminals'])

    # dataset_split_expert returns {} when no expert trajectory is assigned to D_s.
    if dataset_s_extra:
        dataset_s_extra['flag'] = np.ones_like(dataset_s_extra['terminals'])
        # Append the expert trajectories after the D_s source trajectories.
        dataset_s = {
            key: np.concatenate([values, dataset_s_extra[key]], 0)
            for key, values in dataset_s.items()
        }
    return dataset_e, dataset_s


def dataset_split_expert(dataset, split_x, exp_num, terminate_on_end=False):
    """
    Returns D_e and expert data in D_s of setting 1 in the paper.

    ``dataset`` is cut into trajectories, each ending at a step whose
    ``timeouts`` or ``terminals`` entry is set. No cuts are made when
    ``terminate_on_end`` is True. The step that ends a trajectory is kept as
    its last transition. The final raw step is never emitted because it has
    no successor observation.

    Of the first ``exp_num`` trajectories, the last ``split_x`` form the
    expert part of D_s and the remaining leading ones form D_e.

    Args:
        dataset: dict of per-step arrays with at least 'observations',
            'actions', 'rewards', 'terminals' and 'timeouts'.
        split_x: number of trajectories, taken from the end of the selection,
            placed in D_s; values <= 0 place none.
        exp_num: total number of leading trajectories to use.
        terminate_on_end: if True, do not split at terminal/timeout steps.

    Returns:
        (dataset_e, dataset_s): dicts with 'observations', 'actions',
        'next_observations', 'rewards', 'terminals' and 'timeouts', each the
        concatenation over the chosen trajectories. ``dataset_s`` is ``{}``
        when no trajectory is assigned to it.

    Raises:
        ValueError: if no trajectory is left for D_e.
    """
    keys = ('observations', 'actions', 'next_observations',
            'rewards', 'terminals', 'timeouts')
    n = dataset['rewards'].shape[0]

    trajectories = [{key: [] for key in keys}]
    for i in range(n - 1):
        traj = trajectories[-1]
        traj['observations'].append(dataset['observations'][i].astype(np.float32))
        traj['next_observations'].append(dataset['observations'][i + 1].astype(np.float32))
        traj['actions'].append(dataset['actions'][i].astype(np.float32))
        traj['rewards'].append(dataset['rewards'][i].astype(np.float32))
        traj['terminals'].append(bool(dataset['terminals'][i]))
        traj['timeouts'].append(bool(dataset['timeouts'][i]))

        final_timestep = dataset['timeouts'][i] | dataset['terminals'][i]
        if (not terminate_on_end) and final_timestep:
            # The boundary step stays as the last transition of the finished
            # trajectory; subsequent steps start a new one.
            trajectories.append({key: [] for key in keys})

    # If the last processed step closed an episode, the loop leaves an empty
    # trajectory at the end. Drop it so it can never be selected: an empty
    # list does not concatenate with the 2-D observation arrays.
    if len(trajectories) > 1 and not trajectories[-1]['observations']:
        trajectories.pop()

    # Use the first exp_num trajectories; the last split_x of them go to D_s.
    inds_succ = list(range(len(trajectories)))[:exp_num]
    inds_s = inds_succ[-split_x:] if split_x > 0 else []
    inds_s_set = set(inds_s)
    # Order-preserving difference (list(set(...)) has no guaranteed order).
    inds_e = [i for i in inds_succ if i not in inds_s_set]

    print('# {} expert trajs in D_e'.format(len(inds_e)))
    print('# {} expert trajs in D_s'.format(len(inds_s)))

    # NOTE(jn): Pen -- the first few expert trajectories reportedly contain no
    # valid data; the workaround used there was to return every trajectory
    # (not just inds_e) as D_e.
    if not inds_e:
        raise ValueError(
            'No expert trajectories left for D_e '
            '(exp_num={}, split_x={}, available={})'.format(
                exp_num, split_x, len(trajectories)))

    def concat_trajectories(inds):
        # Stack the per-step entries of the selected trajectories, key by key.
        return {key: np.concatenate([trajectories[i][key] for i in inds], 0)
                for key in keys}

    dataset_e = concat_trajectories(inds_e)
    dataset_s = concat_trajectories(inds_s) if inds_s else {}
    return dataset_e, dataset_s


def dataset_m_trajs(dataset, m, terminate_on_end=False):
    """
    Returns m trajs from dataset.

    ``dataset`` is cut into trajectories, each ending at a step whose
    ``timeouts`` or ``terminals`` entry is set. No cuts are made when
    ``terminate_on_end`` is True. The step that ends a trajectory is kept as
    its last transition. The final raw step is never emitted because it has
    no successor observation. The first ``m`` trajectories, or fewer if the
    dataset has fewer, are concatenated and returned.

    Args:
        dataset: dict of per-step arrays with at least 'observations',
            'actions', 'rewards', 'terminals' and 'timeouts'.
        m: number of leading trajectories to return.
        terminate_on_end: if True, do not split at terminal/timeout steps.

    Returns:
        dict with 'observations', 'actions', 'next_observations', 'rewards',
        'terminals' and 'timeouts', each the concatenation over the selected
        trajectories.
    """
    keys = ('observations', 'actions', 'next_observations',
            'rewards', 'terminals', 'timeouts')
    n = dataset['rewards'].shape[0]

    trajectories = [{key: [] for key in keys}]
    for i in range(n - 1):
        traj = trajectories[-1]
        traj['observations'].append(dataset['observations'][i].astype(np.float32))
        traj['next_observations'].append(dataset['observations'][i + 1].astype(np.float32))
        traj['actions'].append(dataset['actions'][i].astype(np.float32))
        traj['rewards'].append(dataset['rewards'][i].astype(np.float32))
        traj['terminals'].append(bool(dataset['terminals'][i]))
        traj['timeouts'].append(bool(dataset['timeouts'][i]))

        final_timestep = dataset['timeouts'][i] | dataset['terminals'][i]
        if (not terminate_on_end) and final_timestep:
            # The boundary step stays as the last transition of the finished
            # trajectory; subsequent steps start a new one.
            trajectories.append({key: [] for key in keys})

    # If the last processed step closed an episode, the loop leaves an empty
    # trajectory at the end. Drop it so it can never be selected: an empty
    # list does not concatenate with the 2-D observation arrays.
    if len(trajectories) > 1 and not trajectories[-1]['observations']:
        trajectories.pop()

    # Select the first m trajectories.
    inds = list(range(len(trajectories)))[:m]

    # Report how many trajectories were actually selected, which can be fewer than m.
    print('# {} diverse trajs in D_s'.format(len(inds)))

    return {key: np.concatenate([trajectories[i][key] for i in inds], 0)
            for key in keys}
