import os
import argparse
import numpy as np

import lang_hrl
from lang_hrl.datasets.datasets import Dataset, BabyAITrajectoryDataset

def _select(value, index):
    """Return ``value`` restricted to ``index`` (a slice or an integer index array).

    ``None`` passes through unchanged, dicts are filtered value-by-value, lists
    are gathered element-wise, and anything else is indexed as ``value[index]``.
    """
    if value is None:
        return None
    if isinstance(value, dict):
        return {k: _select(v, index) for k, v in value.items()}
    if isinstance(value, list) and not isinstance(index, slice):
        # Plain lists cannot be fancy-indexed with an integer array.
        return [value[i] for i in index]
    return value[index]


def _num_items(value):
    """Return the length of ``value``; for a dict, the length of its first value."""
    if isinstance(value, dict):
        if not value:
            raise ValueError("Cannot infer dataset size from an empty states dict.")
        return _num_items(next(iter(value.values())))
    return len(value)


def _pruned_filename(path, drop, every_n):
    """Build the output file name ``<stem>_prune[_every<N>][_drop<D>]<ext>``."""
    stem, ext = os.path.splitext(os.path.basename(path))
    suffix = "_prune"
    if every_n > 0:
        suffix += "_every" + str(every_n)
    if drop > 0.0:
        suffix += "_drop" + str(drop)
    return stem + suffix + ext


def drop_single_dataset(path, drop=0.0, every_n=-1):
    """Prune a flat transition ``Dataset`` and save it next to the original.

    Two optional reductions are applied, in this order:

    1. ``every_n > 0``: keep every ``every_n``-th transition (stride subsampling).
    2. ``drop > 0``: randomly discard a ``drop`` fraction of the remaining
       transitions. The survivors keep their original order.

    states, actions, next_states, rewards and dones are all filtered with the
    same indices, so transitions stay aligned. ``None`` fields stay ``None``.
    Dict-valued fields are filtered value-by-value.

    The result is saved to ``<dir of path>/<stem>_prune[_every<N>][_drop<D>]<ext>``.

    Args:
        path: Path of a dataset readable by ``Dataset.load``.
        drop: Fraction in ``[0, 1]`` of transitions to discard at random.
        every_n: Subsampling stride; values ``<= 0`` disable subsampling.

    Raises:
        ValueError: if ``drop`` is greater than 1, or ``states`` is an empty dict.
    """
    if drop > 1.0:
        # A fraction above 1 makes the keep-count negative. A negative slice
        # bound would then silently keep *most* of the data instead of none.
        raise ValueError("drop must be a fraction in [0, 1], got {!r}".format(drop))

    dataset = Dataset.load(path)
    fields = {
        "states": dataset.states,
        "actions": dataset.actions,
        "next_states": dataset.next_states,
        "rewards": dataset.rewards,
        "dones": dataset.dones,
    }

    if every_n > 0:
        stride = slice(None, None, every_n)
        fields = {name: _select(value, stride) for name, value in fields.items()}

    if drop > 0.0:
        # Assume the dataset has states; they define the number of transitions.
        dataset_size = _num_items(fields["states"])
        # Round half up to the nearest integer (this is not a ceiling).
        num_keep = int(dataset_size * (1 - drop) + 0.5)
        # Sorting preserves the original temporal order of the kept transitions.
        inds_to_keep = np.sort(np.random.permutation(dataset_size)[:num_keep])
        fields = {name: _select(value, inds_to_keep) for name, value in fields.items()}

    new_dataset = Dataset(**fields)
    new_dataset.save(os.path.join(os.path.dirname(path), _pruned_filename(path, drop, every_n)))

def drop_seq_dataset(path, drop=0.0, every_n=-1):
    """Prune each trajectory of a ``BabyAITrajectoryDataset`` and save the result.

    Each trajectory is pruned independently, in this order:

    1. ``every_n > 0``: keep every ``every_n``-th step.
    2. ``drop > 0``: randomly discard a ``drop`` fraction of that trajectory's
       remaining steps. The survivors keep their original order.

    ``images`` and ``actions`` are indexed along their first axis, and ``mask``
    along its second axis (``mask[:, t]``), using the same indices so the three
    stay aligned. Missions and subgoals are copied unchanged.
    NOTE(review): this presumes mask rows correspond to subgoals and columns to
    time steps — confirm against BabyAITrajectoryDataset.

    The result is saved to ``<dir of path>/<stem>_prune[_every<N>][_drop<D>]<ext>``.

    Args:
        path: Path of a dataset readable by ``BabyAITrajectoryDataset.load``.
        drop: Fraction in ``[0, 1]`` of steps to discard per trajectory.
        every_n: Subsampling stride; values ``<= 0`` disable subsampling.

    Raises:
        ValueError: if ``drop`` is greater than 1.
    """
    if drop > 1.0:
        # A fraction above 1 makes the keep-count negative. A negative slice
        # bound would then silently keep *most* steps instead of none.
        raise ValueError("drop must be a fraction in [0, 1], got {!r}".format(drop))

    dataset = BabyAITrajectoryDataset.load(path)
    new_images, new_missions, new_subgoals, new_masks, new_actions = [], [], [], [], []

    trajectories = zip(dataset.images, dataset.missions, dataset.subgoals,
                       dataset.masks, dataset.actions)
    for images, mission, subgoals, mask, actions in trajectories:
        if every_n > 0:
            images = images[::every_n]
            actions = actions[::every_n]
            mask = mask[:, ::every_n]

        if drop > 0.0:
            traj_size = len(images)
            # Round half up to the nearest integer (this is not a ceiling).
            num_keep = int(traj_size * (1 - drop) + 0.5)
            inds_to_keep = np.sort(np.random.permutation(traj_size)[:num_keep])
            images = images[inds_to_keep]
            actions = actions[inds_to_keep]
            mask = mask[:, inds_to_keep]

        new_images.append(images)
        new_missions.append(mission)
        new_subgoals.append(subgoals)
        new_masks.append(mask)
        new_actions.append(actions)

    stem, ext = os.path.splitext(os.path.basename(path))
    suffix = "_prune"
    if every_n > 0:
        suffix += "_every" + str(every_n)
    if drop > 0.0:
        suffix += "_drop" + str(drop)
    filename = stem + suffix + ext

    new_dataset = BabyAITrajectoryDataset(new_images, new_missions, new_subgoals, new_actions, masks=new_masks)
    new_dataset.save(os.path.join(os.path.dirname(path), filename))


if __name__ == "__main__":

    # CLI entry point: prune one dataset file and save the result beside it.
    parser = argparse.ArgumentParser(
        description="Prune a dataset by stride subsampling and/or random dropping; "
                    "the result is saved next to the input with a '_prune' suffix.")
    # required=True: without it a missing --path crashes later on None.endswith.
    parser.add_argument("--path", type=str, required=True, help="path to dataset")
    parser.add_argument("--drop", type=float, default=0.0,
                        help="fraction in [0, 1] of samples to drop at random")
    parser.add_argument("--every-n", type=int, default=-1,
                        help="keep every N-th sample; values <= 0 disable subsampling")

    args = parser.parse_args()
    if not 0.0 <= args.drop <= 1.0:
        parser.error("--drop must be in [0, 1], got {}".format(args.drop))

    # In this script, '.pkl' files are treated as trajectory datasets and
    # everything else as a flat transition Dataset.
    if args.path.endswith('.pkl'):
        drop_seq_dataset(args.path, args.drop, args.every_n)
    else:
        drop_single_dataset(args.path, args.drop, args.every_n)

