import os
import os.path as osp
import numpy as np
from PIL import Image, ImageFile

import torch
from torch.utils.data import Dataset

from nerv.utils import glob_all, load_obj

#from .utils import BaseTransforms

from ...base_slots.datasets.langtable import LangtableInstDataset
from ...base_slots.datasets.utils import BaseTransforms

ImageFile.LOAD_TRUNCATED_IMAGES = True

class LangtableActionDataset(LangtableInstDataset):
    """Langtable dataset with language instructions and per-episode actions.

    Extends ``LangtableInstDataset`` with two extra per-sample fields:
        - ``actions``: the action at the last input frame of the sample, taken
          from ``<data_root>/actions/actions.npy``.
        - ``instruction_raw``: the episode's ``<inst_root>/<episode>.npy``
          array, returned as a tensor.
    """

    def __init__(
        self,
        data_root,
        split,
        langtable_transform,
        n_sample_frames=16,
        frame_offset=None,
        video_len=50,
        input_frames=6,
    ):
        super().__init__(
            data_root=data_root,
            split=split,
            langtable_transform=langtable_transform,
            n_sample_frames=n_sample_frames,
            frame_offset=frame_offset,
            video_len=video_len,
        )
        # Number of context frames; `_read_actions` returns the action at the
        # last of these frames.
        self.input_frames = input_frames

        # Load every episode's actions once; the outer index is the integer
        # episode id parsed from the episode folder name (see `_read_actions`).
        # NOTE(review): allow_pickle=True unpickles the file, which can execute
        # arbitrary code -- only load trusted data.
        self.act_root = os.path.join(data_root, "actions")
        actions = np.load(
            os.path.join(self.act_root, "actions.npy"), allow_pickle=True
        )
        self.actions = [torch.from_numpy(a) for a in actions]

    @staticmethod
    def _episode_name(folder):
        """Return the episode folder's base name, tolerating a trailing separator."""
        # normpath strips a trailing '/', which would otherwise make basename ''.
        return os.path.basename(os.path.normpath(folder))

    def _read_actions(self, idx):
        """Return the action at the last input frame of sample ``idx``.

        ``self.valid_idx[idx]`` is unpacked as ``(folder, start_idx)``; the
        episode id is the folder's name parsed as an int.
        """
        folder, start_idx = self.valid_idx[idx]
        episode_actions = self.actions[int(self._episode_name(folder))]  # [Length, act_dim]
        # The sample's inputs span frames [start_idx, start_idx + input_frames);
        # take the action at the last of them.
        return episode_actions[start_idx + self.input_frames - 1]

    def _read_insts_raw(self, idx):
        """Load the raw instruction array of sample ``idx``'s episode as a tensor."""
        folder, _ = self.valid_idx[idx]
        inst_file = os.path.join(
            self.inst_root, self._episode_name(folder) + ".npy"
        )
        return torch.tensor(np.load(inst_file))

    def __getitem__(self, idx):
        """Return one sample.

        Data dict:
            - data_idx: int, the dataset index ``idx``
            - img: [T, 3, H, W], frames from ``_read_frames``
            - instruction: output of ``_read_insts``
            - actions: action tensor at the last input frame
            - instruction_raw: tensor loaded from the episode's instruction .npy
        """
        frames = self._read_frames(idx)
        insts = self._read_insts(idx)
        actions = self._read_actions(idx)
        insts_raw = self._read_insts_raw(idx)
        data_dict = {
            'data_idx': idx,
            'img': frames,
            'instruction': insts,
            'actions': actions,
            'instruction_raw': insts_raw,
        }
        return data_dict

def build_langtable_action_dataset(params, val_only=False):
    """Build the Langtable action dataset(s).

    Args:
        params: config object; reads ``model``, ``resolution``, ``data_root``,
            ``n_sample_frames``, ``frame_offset``, ``video_len`` and
            ``input_frames``.
        val_only: if True, build and return only the validation split.

    Returns:
        ``val_dataset`` if ``val_only`` else ``(train_dataset, val_dataset)``.
    """
    if params.model == 'RoboticsTransformer':
        # Skip BaseTransforms: only convert to a [3, H, W] tensor and resize.
        # Presumably any remaining preprocessing is done later by the
        # HuggingFace processor -- confirm against the model code.
        import torchvision.transforms as transforms
        langtable_transform = transforms.Compose([
            transforms.ToTensor(),  # [3, H, W]
            transforms.Resize(params.resolution),
        ])
    else:
        langtable_transform = BaseTransforms(params.resolution)

    # Arguments shared by both splits; `split` is passed per call so this
    # dict is never mutated.
    args = dict(
        data_root=params.data_root,
        langtable_transform=langtable_transform,
        n_sample_frames=params.n_sample_frames,
        frame_offset=params.frame_offset,
        video_len=params.video_len,
        input_frames=params.input_frames,
    )
    val_dataset = LangtableActionDataset(split='val', **args)
    if val_only:
        return val_dataset
    train_dataset = LangtableActionDataset(split='train', **args)
    return train_dataset, val_dataset

