import os
import os.path as osp
import numpy as np
from PIL import Image, ImageFile
import pickle

from tqdm import tqdm
import torch
from torch.utils.data import Dataset

from nerv.utils import glob_all, load_obj

from .utils import BaseTransforms

# Let PIL decode truncated / partially written image files instead of raising.
# This assigns a PIL module attribute, so it affects every PIL user in the
# process, not just this module.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class LangtableDataset(Dataset):
    """Langtable video dataset.

    Videos are stored one per sub-folder of ``<data_root>/<split>``; every
    folder holds frames named ``test_1.png``, ``test_2.png``, ... (1-based).
    Only videos with more than ``video_len`` frames are kept.

    By default ``__getitem__`` returns short clips of ``n_sample_frames``
    frames; set ``load_video = True`` to index whole videos instead.
    """

    def __init__(
        self,
        data_root,
        split,
        langtable_transform,
        n_sample_frames=6,
        frame_offset=None,
        video_len=50,
    ):
        """
        Args:
            data_root: dataset root; videos are read from
                ``<data_root>/<split>``.
            split: one of ``'train'``, ``'val'``, ``'test'``.
            langtable_transform: callable applied to every RGB PIL frame; its
                outputs are combined with ``torch.stack``, so it must return
                same-shaped tensors.
            n_sample_frames: number of frames per sampled clip.
            frame_offset: temporal stride between consecutive sampled frames.
                ``None`` is treated as 1 (consecutive frames).
            video_len: videos with at most ``video_len`` frames are skipped;
                training clips are drawn only from the first ``video_len``
                frames of each kept video.
        """
        assert split in ['train', 'val', 'test']
        self.data_root = os.path.join(data_root, split)
        self.split = split
        self.langtable_transform = langtable_transform
        self.n_sample_frames = n_sample_frames
        # `None` would make every `n * frame_offset` below raise TypeError,
        # so interpret it as "use consecutive frames".
        self.frame_offset = 1 if frame_offset is None else frame_offset
        self.video_len = video_len

        # Scan the split folder; this also sets `self.files` and
        # `self.num_videos`.
        self.valid_idx = self._get_sample_idx()

        # by default, we load small video clips
        self.load_video = False

    def _get_video_start_idx(self, idx):
        """Return the ``(video_folder, start_idx)`` pair of clip ``idx``.

        ``start_idx`` is 0-based; the matching frame file is
        ``test_{start_idx + 1}.png``.
        """
        return self.valid_idx[idx]

    def _frame_paths(self, folder, first_frame, num_frames):
        """Return paths of ``num_frames`` frames in ``folder``.

        Paths start at the 1-based file index ``first_frame`` and advance by
        ``self.frame_offset``.
        """
        template = osp.join(folder, 'test_{}.png')
        return [
            template.format(first_frame + n * self.frame_offset)
            for n in range(num_frames)
        ]

    def _load_frames(self, paths):
        """Load ``paths`` as RGB images, transform and stack them.

        Returns:
            tensor ``[N, C, H, W]`` with ``N == len(paths)``.
        """
        frames = []
        for path in paths:
            # `with` releases the file handle as soon as the pixels are decoded.
            with Image.open(path) as img:
                frames.append(self.langtable_transform(img.convert('RGB')))
        return torch.stack(frames, dim=0)  # [N, C, H, W]

    def _read_frames(self, idx):
        """Load clip ``idx`` as a ``[n_sample_frames, C, H, W]`` tensor."""
        folder, start_idx = self._get_video_start_idx(idx)
        # frame files start from 'test_1.png', sample indices from 0
        paths = self._frame_paths(folder, start_idx + 1, self.n_sample_frames)
        return self._load_frames(paths)

    def _read_bboxes(self, idx):
        """Load empty bbox and pres mask for compatibility.

        Returns:
            bboxes: all-zero array ``[n_sample_frames, 5, 4]``.
            pres_mask: all-zero array ``[n_sample_frames, 5]``.
        """
        bboxes = np.zeros((self.n_sample_frames, 5, 4))
        pres_mask = np.zeros((self.n_sample_frames, 5))
        return bboxes, pres_mask

    def _num_video_frames(self, folder):
        """Number of frames ``get_video`` can load from ``folder``.

        Assumes the folder contains exactly the frame files
        ``test_1.png`` ... ``test_K.png``. Frames ``1, 1 + off, 1 + 2 * off``
        ... must not exceed ``K``, i.e. ``ceil(K / frame_offset)`` frames.
        """
        num_files = len(os.listdir(folder))
        return (num_files + self.frame_offset - 1) // self.frame_offset

    def get_video(self, video_idx):
        """Load a whole video (every ``frame_offset``-th frame).

        Args:
            video_idx: index into ``self.files``.

        Returns:
            dict with ``'video'`` (``[T, C, H, W]`` tensor) and
            ``'data_idx'`` (``video_idx``).
        """
        folder = self.files[video_idx]
        # Previously this used the raw file count, which requested
        # nonexistent files whenever frame_offset > 1.
        num_frames = self._num_video_frames(folder)
        paths = self._frame_paths(folder, 1, num_frames)
        return {
            'video': self._load_frames(paths),
            'data_idx': video_idx,
        }

    def __getitem__(self, idx):
        """Data dict:
            - data_idx: int
            - img: [T, 3, H, W]
            - bbox: [T, max_num_obj, 4], empty, for compatibility
            - pres_mask: [T, max_num_obj], empty, for compatibility
        ``bbox`` / ``pres_mask`` are only added for non-train splits.
        When ``self.load_video`` is set, returns ``get_video(idx)`` instead.
        """
        if self.load_video:
            return self.get_video(idx)

        frames = self._read_frames(idx)
        data_dict = {
            'data_idx': idx,
            'img': frames,
        }
        if self.split != 'train':
            bboxes, pres_mask = self._read_bboxes(idx)
            data_dict['bbox'] = torch.from_numpy(bboxes).float()
            data_dict['pres_mask'] = torch.from_numpy(pres_mask).bool()
        return data_dict

    def _get_sample_idx(self):
        """Scan the split folder and build the clip index.

        Side effects: sets ``self.files`` to the kept video folders and
        ``self.num_videos`` to their count.

        Returns:
            list of ``(video_folder, start_idx)`` with 0-based ``start_idx``;
            every start index for ``'train'``, a single ``0`` per video
            otherwise.
        """
        all_folders = [
            s.rstrip('/') for s in glob_all(self.data_root, only_dir=True)
        ]
        # Number of 0-based start positions whose clip still ends within the
        # first `video_len` frames.
        num_train_starts = self.video_len - \
            (self.n_sample_frames - 1) * self.frame_offset

        valid_idx = []  # (video_folder, start_idx)
        kept_folders = []
        for folder in tqdm(all_folders):
            # skip trajectories that are not longer than video_len
            if len(os.listdir(folder)) <= self.video_len:
                continue
            kept_folders.append(folder)
            if self.split == 'train':
                # simply use random uniform sampling over all clip starts
                valid_idx.extend(
                    (folder, start) for start in range(num_train_starts))
            else:
                # only test once per video
                valid_idx.append((folder, 0))

        self.files = kept_folders
        self.num_videos = len(kept_folders)
        return valid_idx

    def __len__(self):
        """Number of videos when ``load_video`` is set, else number of clips."""
        if self.load_video:
            return len(self.files)
        return len(self.valid_idx)


class LangtableSlotsDataset(LangtableDataset):
    """Langtable dataset with pre-computed slots and instruction embeddings."""

    def __init__(
        self,
        data_root,
        video_slots,
        split,
        langtable_transform,
        n_sample_frames=16,
        frame_offset=None,
        video_len=50,
    ):
        """Build the clip index and attach pre-computed slots / instructions.

        Args:
            data_root: dataset root; frames come from ``<data_root>/<split>``
                and instruction labels from ``<data_root>/labels/inst.npy``.
            video_slots: mapping from a video folder's basename to its slots,
                indexed per frame (``[T, N, C]``).
            split, langtable_transform, n_sample_frames, frame_offset,
            video_len: forwarded unchanged to ``LangtableDataset``.
        """
        super().__init__(
            data_root=data_root,
            split=split,
            langtable_transform=langtable_transform,
            n_sample_frames=n_sample_frames,
            frame_offset=frame_offset,
            video_len=video_len,
        )

        # pre-computed slots
        self.video_slots = video_slots
        # One instruction row per video, looked up by the integer folder name.
        self.inst_root = os.path.join(data_root, "labels")
        inst = np.load(os.path.join(self.inst_root, "inst.npy"))
        self.inst = torch.from_numpy(inst).float()

    def _read_slots(self, idx):
        """Return the slots of clip ``idx`` as float32 ``[n_sample_frames, N, C]``.

        Slot index ``t`` is 0-based, matching frame file ``test_{t + 1}.png``.
        """
        folder, start_idx = self._get_video_start_idx(idx)
        video_slots = self.video_slots[os.path.basename(folder)]  # [T, N, C]
        clip = [
            video_slots[start_idx + n * self.frame_offset]
            for n in range(self.n_sample_frames)
        ]
        return np.stack(clip, axis=0).astype(np.float32)

    def _read_insts(self, idx):
        """Return the ``inst.npy`` row of the video that clip ``idx`` belongs to."""
        folder, _ = self._get_video_start_idx(idx)
        # assumes video folders are named by their integer row in inst.npy
        return self.inst[int(os.path.basename(folder))]

    def _decode_inst(self, inst):
        """Decode an array of UTF-8 byte values into a string.

        Zero entries (presumably padding) are dropped before decoding.
        """
        return bytes(inst[inst != 0].tolist()).decode("utf-8")

    def _read_insts_text(self, idx):
        """Decode the per-video file ``<labels>/<folder>.npy`` into text."""
        folder, _ = self._get_video_start_idx(idx)
        inst_file = os.path.join(
            self.inst_root, os.path.basename(folder) + ".npy")
        return self._decode_inst(np.load(inst_file))

    def __getitem__(self, idx):
        """Data dict:
            - data_idx: int
            - slots: [T, N, C] float32 slots pre-computed from the Langtable
              video frames
            - img: [T, 3, H, W]
            - instruction: float tensor, the ``inst.npy`` row of this video
            - bbox: [T, max_num_obj, 4], empty, for compatibility
            - pres_mask: [T, max_num_obj], empty, for compatibility
        ``bbox`` / ``pres_mask`` are only added for non-train splits.
        """
        slots = self._read_slots(idx)
        frames = self._read_frames(idx)
        insts = self._read_insts(idx)
        data_dict = {
            'data_idx': idx,
            'slots': slots,
            'img': frames,
            'instruction': insts,
        }
        if self.split != 'train':
            bboxes, pres_mask = self._read_bboxes(idx)
            data_dict['bbox'] = torch.from_numpy(bboxes).float()
            data_dict['pres_mask'] = torch.from_numpy(pres_mask).bool()
        return data_dict

class LangtableInstDataset(LangtableDataset):
    """Langtable dataset with language instructions."""

    def __init__(
        self,
        data_root,
        split,
        langtable_transform,
        n_sample_frames=16,
        frame_offset=None,
        video_len=50,
    ):
        """Build the clip index and load instruction labels.

        Instruction arrays are read from ``<data_root>/labels/inst.npy`` and
        ``<data_root>/labels/inst_word.npy``; both are indexed by the integer
        value of a video's folder name. The remaining arguments are forwarded
        unchanged to ``LangtableDataset``.
        """
        super().__init__(
            data_root=data_root,
            split=split,
            langtable_transform=langtable_transform,
            n_sample_frames=n_sample_frames,
            frame_offset=frame_offset,
            video_len=video_len,
        )

        # language instructions
        self.inst_root = os.path.join(data_root, "labels")
        inst = np.load(os.path.join(self.inst_root, "inst.npy"))
        inst_word = np.load(os.path.join(self.inst_root, "inst_word.npy"))
        self.inst = torch.from_numpy(inst).float()
        self.inst_word = torch.from_numpy(inst_word).float()

    def _read_insts(self, idx):
        """Return the ``inst.npy`` row of the video that clip ``idx`` belongs to."""
        folder, _ = self._get_video_start_idx(idx)
        # assumes video folders are named by their integer row in inst.npy
        return self.inst[int(os.path.basename(folder))]

    def _read_insts_word(self, idx):
        """Return the ``inst_word.npy`` row of the video of clip ``idx``."""
        folder, _ = self._get_video_start_idx(idx)
        return self.inst_word[int(os.path.basename(folder))]

    def _decode_inst(self, inst):
        """Decode an array of UTF-8 byte values into a string.

        Zero entries (presumably padding) are dropped before decoding.
        """
        return bytes(inst[inst != 0].tolist()).decode("utf-8")

    def _read_insts_text(self, idx):
        """Decode the per-video file ``<labels>/<folder>.npy`` into text."""
        folder, _ = self._get_video_start_idx(idx)
        inst_file = os.path.join(
            self.inst_root, os.path.basename(folder) + ".npy")
        return self._decode_inst(np.load(inst_file))

    def _read_mask(self, idx):
        """Placeholder mask reader (not used by ``__getitem__``).

        NOTE(review): currently returns the same ``inst.npy`` row as
        ``_read_insts``, not a mask - confirm intended source.
        """
        folder, _ = self._get_video_start_idx(idx)
        return self.inst[int(os.path.basename(folder))]

    def __getitem__(self, idx):
        """Data dict:
            - data_idx: int
            - img: [T, 3, H, W]
            - instruction: float tensor, the ``inst.npy`` row of this video
            - mask: always None (placeholder)
            - bbox: [T, max_num_obj, 4], empty, for compatibility
            - pres_mask: [T, max_num_obj], empty, for compatibility
        ``bbox`` / ``pres_mask`` are only added for non-train splits.
        """
        frames = self._read_frames(idx)
        insts = self._read_insts(idx)
        # Word-level embeddings (`_read_insts_word`) were read here and then
        # discarded; that dead lookup has been removed.
        data_dict = {
            'data_idx': idx,
            'img': frames,
            'instruction': insts,
            # NOTE(review): torch's default_collate raises TypeError on None
            # values, so batching this dict needs a custom collate_fn (or a
            # tensor placeholder) - confirm against the DataLoader setup.
            'mask': None,
        }
        if self.split != 'train':
            bboxes, pres_mask = self._read_bboxes(idx)
            data_dict['bbox'] = torch.from_numpy(bboxes).float()
            data_dict['pres_mask'] = torch.from_numpy(pres_mask).bool()
        return data_dict

def build_langtable_dataset(params, val_only=False):
    """Create the Langtable clip datasets described by ``params``.

    Args:
        params: config exposing ``data_root``, ``resolution``,
            ``n_sample_frames``, ``frame_offset`` and ``video_len``.
        val_only: when true, only the validation split is built.

    Returns:
        the ``'val'`` dataset if ``val_only``, else
        ``(train_dataset, val_dataset)``. Both splits share one transform.
    """
    shared_kwargs = dict(
        data_root=params.data_root,
        langtable_transform=BaseTransforms(params.resolution),
        n_sample_frames=params.n_sample_frames,
        frame_offset=params.frame_offset,
        video_len=params.video_len,
    )
    val_dataset = LangtableDataset(split='val', **shared_kwargs)
    if val_only:
        return val_dataset
    return LangtableDataset(split='train', **shared_kwargs), val_dataset


def build_langtable_slots_dataset(params, val_only=False):
    """Create Langtable datasets that carry pre-computed slots.

    ``params.slots_root`` is loaded with ``load_obj`` and must provide a
    ``'val'`` entry (and a ``'train'`` entry unless ``val_only``), each being
    the ``video_slots`` mapping of that split.

    Returns:
        the ``'val'`` dataset if ``val_only``, else
        ``(train_dataset, val_dataset)``.
    """
    all_slots = load_obj(params.slots_root)
    val_slots = all_slots['val']
    shared_kwargs = dict(
        data_root=params.data_root,
        langtable_transform=BaseTransforms(params.resolution),
        n_sample_frames=params.n_sample_frames,
        frame_offset=params.frame_offset,
        video_len=params.video_len,
    )
    val_dataset = LangtableSlotsDataset(
        video_slots=val_slots, split='val', **shared_kwargs)
    if val_only:
        return val_dataset
    train_dataset = LangtableSlotsDataset(
        video_slots=all_slots['train'], split='train', **shared_kwargs)
    return train_dataset, val_dataset

def build_langtable_inst_dataset(params, val_only=False):
    """Build Langtable video datasets with language instructions.

    Args:
        params: config exposing ``data_root``, ``resolution``,
            ``n_sample_frames``, ``frame_offset`` and ``video_len``.
        val_only: when true, only the validation split is built.

    Returns:
        the ``'val'`` ``LangtableInstDataset`` if ``val_only``, else
        ``(train_dataset, val_dataset)``.
    """
    args = dict(
        data_root=params.data_root,
        split='val',
        langtable_transform=BaseTransforms(params.resolution),
        n_sample_frames=params.n_sample_frames,
        frame_offset=params.frame_offset,
        video_len=params.video_len,
    )
    val_dataset = LangtableInstDataset(**args)
    if val_only:
        return val_dataset
    args['split'] = 'train'
    train_dataset = LangtableInstDataset(**args)
    return train_dataset, val_dataset
