import torch
import numpy as np
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pickle

class TensorDataset(Dataset):
    """Dataset over several parallel members indexed along their first axis.

    Each member is either an indexable tensor-like object or a dict mapping
    names to such objects. Indexing the dataset indexes every member (and
    every value of a dict member) with the same index.
    """

    def __init__(self, *tensors):
        """Store the members in the order they were passed."""
        self.tensors = tensors

    def _maybe_tensor_or_dict(self, tensor, index):
        """Index ``tensor`` directly, or index each of its values if it is a dict."""
        if not isinstance(tensor, dict):
            return tensor[index]
        return {name: values[index] for name, values in tensor.items()}

    def __getitem__(self, index):
        """Return a tuple holding one indexed entry per member."""
        return tuple([self._maybe_tensor_or_dict(member, index) for member in self.tensors])

    def __len__(self):
        """Return ``size(0)`` of the first member (or of its first value if a dict)."""
        reference = self.tensors[0]
        if isinstance(reference, dict):
            # Any value would do; the first one in iteration order is used.
            reference = reference[next(iter(reference))]
        return reference.size(0)

class Dataset(object):
    """Container for parallel arrays of transition data.

    Holds up to five members: ``states``, ``actions``, ``next_states``,
    ``rewards`` and ``dones``. ``states`` and ``next_states`` may each be
    either a single array or a dict mapping names to arrays; the other
    members are single arrays. Any member may be ``None``.

    Subclasses define ``fields`` to select which members are exported by
    ``to_tensor_dataset``.
    """

    def __init__(self, states=None, actions=None, next_states=None, rewards=None, dones=None):
        """Store the members and check that all provided ones have equal length.

        Raises:
            AssertionError: if the provided members (including every value of
                a dict member) do not all share the same ``len``.
        """
        self.states = states
        self.actions = actions
        self.next_states = next_states
        self.rewards = rewards
        self.dones = dones

        lengths = []
        # states / next_states may be dicts, in which case every value counts.
        for member in (self.states, self.next_states):
            if member is None:
                continue
            if isinstance(member, dict):
                lengths.extend(len(value) for value in member.values())
            else:
                lengths.append(len(member))
        for member in (self.actions, self.rewards, self.dones):
            if member is not None:
                lengths.append(len(member))

        # With no members at all the generator is empty, so lengths[0] is never read.
        assert all(n == lengths[0] for n in lengths), "Not all dataset members were the same length!"

    @property
    def fields(self):
        """Members exported by ``to_tensor_dataset``; must be overridden."""
        raise NotImplementedError

    def to_tensor_dataset(self):
        """Convert every entry of ``fields`` with ``torch.from_numpy``.

        Dict fields are converted value by value and stay dicts. The result is
        a ``TensorDataset`` built from the converted fields, in order.
        """
        tensors = []
        for field in self.fields:
            if isinstance(field, dict):
                tensors.append({k: torch.from_numpy(v) for k, v in field.items()})
            else:
                tensors.append(torch.from_numpy(field))
        return TensorDataset(*tensors)

    def save(self, path):
        """Write all non-``None`` members to ``path`` with ``np.savez``.

        Array members are stored under their own name. Dict-valued ``states``
        and ``next_states`` are flattened to one entry per key, named
        ``states_<key>`` / ``next_states_<key>``.
        """
        save_kwargs = {
            'actions': self.actions,
            'rewards': self.rewards,
            'dones': self.dones,
        }
        for prefix, member in (('states', self.states), ('next_states', self.next_states)):
            if isinstance(member, dict):
                for k, v in member.items():
                    save_kwargs[prefix + '_' + k] = v
            else:
                save_kwargs[prefix] = member

        save_kwargs = {k: v for k, v in save_kwargs.items() if v is not None}
        np.savez(path, **save_kwargs)

    @staticmethod
    def _load_state_group(data, prefix, keys, num_data_pts):
        """Rebuild ``states`` / ``next_states`` from an opened archive.

        A bare ``prefix`` entry is returned as an array; ``prefix_<key>``
        entries are returned as a dict keyed by ``<key>``; no entries gives
        ``None``. All arrays are truncated to ``num_data_pts`` rows.
        """
        if not keys:
            return None
        if prefix in keys:
            return data[prefix][:num_data_pts]
        # Dict members were saved as "<prefix>_<key>"; strip "<prefix>_".
        # This also covers a dict with a single key, which must stay a dict.
        return {k[len(prefix) + 1:]: data[k][:num_data_pts] for k in keys}

    @classmethod
    def load(cls, path, fraction=1):
        """Load a dataset written by ``save``.

        Args:
            path: archive path passed to ``np.load``.
            fraction: portion of the data to keep; the first
                ``int(fraction * length)`` entries of every member are kept.

        Returns:
            An instance of ``cls`` whose missing members are ``None``.
        """
        # The context manager closes the archive even if reading fails.
        with np.load(path) as data:
            names = list(data.files)
            state_keys = [k for k in names if k == 'states' or k.startswith('states_')]
            next_state_keys = [k for k in names if k == 'next_states' or k.startswith('next_states_')]
            other_keys = [k for k in names if k not in state_keys and k not in next_state_keys]

            # Length reference: states first, then any other member, then next_states.
            length_keys = state_keys or other_keys or next_state_keys
            num_data_pts = int(fraction * len(data[length_keys[0]])) if length_keys else 0

            states = cls._load_state_group(data, 'states', state_keys, num_data_pts)
            next_states = cls._load_state_group(data, 'next_states', next_state_keys, num_data_pts)
            cls_kwargs = {
                k: (data[k][:num_data_pts] if k in names else None)
                for k in ("actions", "rewards", "dones")
            }

        return cls(states=states, next_states=next_states, **cls_kwargs)

    @staticmethod
    def _merge_member(datasets, name):
        """Concatenate member ``name`` across ``datasets`` along axis 0.

        The first dataset decides the layout: ``None`` gives ``None``, a dict
        is concatenated per key, anything else is concatenated directly.
        """
        first = getattr(datasets[0], name)
        if first is None:
            return None
        if isinstance(first, dict):
            return {
                k: np.concatenate([getattr(d, name)[k] for d in datasets], axis=0)
                for k in first
            }
        return np.concatenate([getattr(d, name) for d in datasets], axis=0)

    @classmethod
    def merge(cls, datasets):
        """Concatenate several datasets into a new instance of ``cls``.

        Each member is handled independently, so ``next_states`` no longer
        depends on the type of ``states``.

        Raises:
            AssertionError: if ``datasets`` is empty or contains a non-Dataset.
        """
        assert len(datasets) > 0, "Must pass in at least one dataset"
        assert all(isinstance(d, Dataset) for d in datasets), "Must all be datasets"

        merged = {
            name: cls._merge_member(datasets, name)
            for name in ("states", "actions", "next_states", "rewards", "dones")
        }
        return cls(**merged)

class BehaviorCloningDataset(Dataset):
    """Dataset whose exported fields are the states and the actions."""

    @property
    def fields(self):
        """Return the ``(states, actions)`` members as a tuple, in that order."""
        members = (self.states, self.actions)
        return members


class InverseModelDataset(Dataset):
    """Dataset whose exported fields are states, next states and actions."""

    @property
    def fields(self):
        """Return ``(states, next_states, actions)`` as a tuple, in that order."""
        members = (self.states, self.next_states, self.actions)
        return members
        

class BabyAITrajectoryDataset(Dataset):
    """Trajectory dataset holding per-trajectory numpy arrays.

    ``images``, ``missions``, ``subgoals`` and ``actions`` are sequences
    indexed by trajectory; ``masks`` and ``next_images`` are optional
    sequences of the same kind, or ``None``.
    """

    # Trimming limits applied in __getitem__ so samples do not exceed the
    # block size. NOTE: this is a bit of a hack, should be fixed later.
    MAX_SUBGOAL_LEN = 350
    MAX_IMAGE_LEN = 340

    def __init__(self, images, missions, subgoals, actions, masks=None, next_images=None):
        """Store the per-trajectory sequences; optional ones default to ``None``."""
        self.images = images
        self.missions = missions
        self.subgoals = subgoals
        self.actions = actions
        self.masks = masks
        self.next_images = next_images

    def __getitem__(self, index):
        """Return one trajectory as tensors.

        Returns:
            ``(img, mission, mask, subgoal, action, next_img)``. ``mask`` and
            ``next_img`` are ``None`` when the dataset does not hold them.
            Subgoals are trimmed to ``MAX_SUBGOAL_LEN`` entries and images,
            actions and next images to ``MAX_IMAGE_LEN``; the mask is trimmed
            to match on its first (subgoal) and second (image) axis.
        """
        img = torch.from_numpy(self.images[index])
        mission = torch.from_numpy(self.missions[index])
        subgoal = torch.from_numpy(self.subgoals[index])
        action = torch.from_numpy(self.actions[index])

        next_img = None
        if self.next_images is not None:
            next_img = torch.from_numpy(self.next_images[index])

        mask = None
        if self.masks is not None:
            mask = torch.from_numpy(self.masks[index])

        if subgoal.shape[0] > self.MAX_SUBGOAL_LEN:
            subgoal = subgoal[:self.MAX_SUBGOAL_LEN]
            # The mask is optional; only slice it when it exists.
            if mask is not None:
                mask = mask[:self.MAX_SUBGOAL_LEN, :]
        if img.shape[0] > self.MAX_IMAGE_LEN:
            img = img[:self.MAX_IMAGE_LEN]
            action = action[:self.MAX_IMAGE_LEN]
            if mask is not None:
                mask = mask[:, :self.MAX_IMAGE_LEN]
            if next_img is not None:
                next_img = next_img[:self.MAX_IMAGE_LEN]

        return img, mission, mask, subgoal, action, next_img

    def __len__(self):
        """Return the number of trajectories, taken from ``images``."""
        return len(self.images)

    def save(self, path):
        """Pickle all members to ``path + ".pkl"``.

        Note that ``load`` opens the path it is given unchanged, so it must be
        called with the ``.pkl`` suffix included.
        """
        dataset = {
            'images': self.images,
            'missions': self.missions,
            'subgoals': self.subgoals,
            'actions': self.actions,
            'masks': self.masks,
            'next_images': self.next_images,
        }
        with open(path + ".pkl", 'wb') as f:
            pickle.dump(dataset, f)

    @classmethod
    def load(cls, path, fraction=1):
        """Load a pickled dataset from ``path``.

        Args:
            path: file to open, used exactly as given.
            fraction: when below 1.0, only the first
                ``int(fraction * len(images))`` trajectories are kept.
        """
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # that come from a trusted source.
        with open(path, 'rb') as f:
            dataset = pickle.load(f)
        if fraction < 1.0:
            # Load only a fraction of the dataset; None members are dropped
            # here and restored by the defaults below.
            num_data_pts = int(fraction * len(dataset['images']))
            dataset = {k: v[:num_data_pts] for k, v in dataset.items() if v is not None}
        # Files lacking the optional members load with them set to None.
        dataset.setdefault('next_images', None)
        dataset.setdefault('masks', None)
        return cls(**dataset)

    @classmethod
    def merge(cls, datasets):
        """Concatenate the trajectories of several datasets into a new one.

        ``masks`` / ``next_images`` of the result are ``None`` when no input
        dataset provided any. Inputs are expected to agree on whether they
        carry these optional members; mixed inputs are not realigned.
        """
        images, missions, subgoals, masks, actions, next_images = [], [], [], [], [], []
        for dataset in datasets:
            images.extend(dataset.images)
            missions.extend(dataset.missions)
            subgoals.extend(dataset.subgoals)
            actions.extend(dataset.actions)
            if dataset.masks is not None:
                masks.extend(dataset.masks)
            if dataset.next_images is not None:
                next_images.extend(dataset.next_images)
        if not masks:
            masks = None
        if not next_images:
            next_images = None
        return cls(images, missions, subgoals, actions, masks=masks, next_images=next_images)

def traj_collate_fn(batch):
    """Collate trajectory samples into a padded observation dict and actions.

    Each sample is a 6-tuple ``(image, mission, mask, subgoal, action,
    next_image)``. Missions are stacked; images, subgoals, actions and next
    images are padded to the longest sequence in the batch (actions with
    ``-100``, the others with ``0``). Whether masks and next images are
    present is decided from the first sample only.

    Returns:
        ``(obs, actions)`` where ``obs`` has keys ``'image'``, ``'mission'``,
        ``'label'`` and, when available, ``'next_image'`` and ``'mask'``.
    """
    img_seqs, mission_seqs, mask_seqs, subgoal_seqs, action_seqs, next_img_seqs = zip(*batch)

    # Missions have a fixed length, so they can be stacked directly.
    stacked_missions = torch.stack(mission_seqs, dim=0)
    # Variable-length members are padded to the longest entry in the batch.
    padded_images = pad_sequence(img_seqs, batch_first=True, padding_value=0)
    padded_actions = pad_sequence(action_seqs, batch_first=True, padding_value=-100)
    padded_subgoals = pad_sequence(subgoal_seqs, batch_first=True, padding_value=0).long()

    obs = {
        'image': padded_images,
        'mission': stacked_missions,
        'label': padded_subgoals,
    }

    if next_img_seqs[0] is not None:
        obs['next_image'] = pad_sequence(next_img_seqs, batch_first=True, padding_value=0)

    if mask_seqs[0] is not None:
        # Masks are 2-D per sample, so each one is copied into the top-left
        # corner of a zero-filled (subgoal length, image length) slot.
        batch_size, image_len = padded_images.shape[0], padded_images.shape[1]
        subgoal_len = padded_subgoals.shape[1]
        expanded = torch.zeros(batch_size, subgoal_len, image_len, dtype=torch.bool)
        for row, sample_mask in enumerate(mask_seqs):
            rows, cols = sample_mask.shape[0], sample_mask.shape[1]
            expanded[row, :rows, :cols] = sample_mask
        obs['mask'] = expanded

    return obs, padded_actions
