import os
import random
import json
import torch

import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import numpy as np

from decord import VideoReader
from torch.utils.data.dataset import Dataset
from packaging import version as pver

def unpack_mm_params(p):
    """Expand a parameter into a pair of values.

    A scalar (int or float) is duplicated; for a tuple or list the first two
    items are returned. Any other type raises an ``Exception``.
    """
    if isinstance(p, (int, float)):
        # A single number applies to both slots.
        return p, p
    if isinstance(p, (tuple, list)):
        first, second = p[0], p[1]
        return first, second
    raise Exception(f'Unknown input parameter type.\nParameter: {p}.\nType: {type(p)}')


class RandomHorizontalFlipWithPose(nn.Module):
    """Horizontally flip individual images of a stack according to a per-image flag.

    The flags can be drawn up front with :meth:`get_flip_flag` and passed back
    to :meth:`forward`, so a caller can apply the same flip decisions elsewhere.
    """

    def __init__(self, p=0.5):
        super().__init__()
        # Probability with which each image is flipped.
        self.p = p

    def get_flip_flag(self, n_image):
        """Draw ``n_image`` uniform samples and return a bool tensor (True = flip)."""
        samples = torch.rand(n_image)
        return samples < self.p

    def forward(self, image, flip_flag=None):
        """Return ``image`` re-stacked along dim 0 with the flagged entries flipped.

        When ``flip_flag`` is omitted a fresh set of flags is sampled; otherwise
        its length must match the first dimension of ``image``.
        """
        n_image = image.shape[0]
        if flip_flag is None:
            flip_flag = self.get_flip_flag(n_image)
        else:
            assert n_image == flip_flag.shape[0]

        flipped_or_not = [
            F.hflip(frame) if do_flip else frame
            for do_flip, frame in zip(flip_flag, image)
        ]
        return torch.stack(flipped_or_not, dim=0)


class Camera(object):
    """Camera parameters parsed from one flat pose entry.

    ``entry[1:5]`` holds fx, fy, cx, cy and ``entry[7:]`` holds the twelve
    values of a row-major 3x4 world-to-camera matrix.
    """

    def __init__(self, entry):
        self.fx, self.fy, self.cx, self.cy = entry[1:5]

        # Promote the 3x4 extrinsic to a homogeneous 4x4 so it can be inverted.
        extrinsic = np.eye(4)
        extrinsic[:3, :] = np.array(entry[7:]).reshape(3, 4)

        self.w2c_mat = extrinsic
        self.c2w_mat = np.linalg.inv(extrinsic)


def custom_meshgrid(*args):
    """Call ``torch.meshgrid`` with 'ij' indexing where the keyword is passed.

    For torch versions older than 1.10 the call is made without the
    ``indexing`` keyword; from 1.10 on it is passed explicitly.
    """
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) >= pver.parse('1.10'):
        return torch.meshgrid(*args, indexing='ij')
    return torch.meshgrid(*args)


def ray_condition(K, c2w, H, W, device, flip_flag=None):
    """Compute per-pixel Plucker ray embeddings for a batch of camera views.

    Args:
        K: intrinsics of shape (B, V, 4); the last dim is split into
            fx, fy, cx, cy.
        c2w: camera-to-world matrices of shape (B, V, 4, 4).
        H: image height in pixels.
        W: image width in pixels.
        device: device on which the pixel grid is created.
        flip_flag: optional boolean tensor indexing the view dimension. Views
            flagged True use an x-coordinate grid that runs from W - 1 down
            to 0, i.e. a horizontally mirrored pixel grid.

    Returns:
        Tensor of shape (B, V, H, W, 6). The last dimension holds
        ``cross(ray_origin, ray_direction)`` followed by the unit ray direction.
    """
    # c2w: B, V, 4, 4
    # K: B, V, 4

    B, V = K.shape[:2]

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    # Adding 0.5 moves to pixel centres and materialises the expanded views,
    # which makes the in-place writes below safe.
    i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5          # [B, V, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5          # [B, V, HxW]

    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        # Same grid with the column coordinate reversed for the flipped views.
        j_flip, i_flip = custom_meshgrid(
            torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
            torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
        )
        i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        i[:, flip_flag, ...] = i_flip
        j[:, flip_flag, ...] = j_flip

    fx, fy, cx, cy = K.chunk(4, dim=-1)     # B,V, 1

    # Back-project every pixel to a camera-space direction at depth 1.
    zs = torch.ones_like(i)                 # [B, V, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)              # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)             # B, V, HW, 3

    # Rotate the directions into the world frame (row-vector form of R @ d).
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)        # B, V, HW, 3
    rays_o = c2w[..., :3, 3]                                        # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)                   # B, V, HW, 3
    # Take the cross product over the trailing xyz axis explicitly. Without
    # `dim`, torch.cross picks the first dimension of size 3, which gives a
    # wrong result whenever B, V or H*W happens to equal 3.
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)                  # B, V, HW, 3
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)             # B, V, H, W, 6
    return plucker


class RealEstate10KPCDRenderDataset(Dataset):
    """Video dataset yielding a target clip, an anchor clip, a caption and ray embeddings.

    Files are looked up below ``video_root_dir``:

    * ``joint_videos/<clip>.mp4`` - frames whose left ``image_size[1]`` columns
      are returned as the target video and whose remaining columns are
      returned as the anchor video,
    * ``c2w_matrices/<clip>.npy`` - camera-to-world matrices,
    * ``intrinsics/<clip>.npy`` - a 3x3 intrinsic matrix,
    * ``captions/<clip>.txt`` - optional caption text.

    Args:
        video_root_dir: dataset root directory.
        sample_n_frames: number of leading frames read from each clip.
        relative_pose: stored on the instance; not used by this class.
        zero_t_first_frame: stored on the instance; not used by this class.
        image_size: ``(height, width)`` or a single int for a square size.
        rescale_fxy: stored on the instance; not used by this class.
        shuffle_frames: stored on the instance; not used by this class.
        hflip_p: per-frame horizontal flip probability; 0.0 disables flipping.
    """

    _VIDEO_SUFFIX = '.mp4'

    def __init__(
            self,
            video_root_dir,
            sample_n_frames=49,
            relative_pose=True,
            zero_t_first_frame=True,
            image_size=(480, 720),
            rescale_fxy=True,
            shuffle_frames=False,
            hflip_p=0.0,
    ):
        use_flip = hflip_p != 0.0
        self.root_path = video_root_dir
        self.relative_pose = relative_pose
        self.zero_t_first_frame = zero_t_first_frame
        self.sample_n_frames = sample_n_frames

        self.c2w_root = os.path.join(self.root_path, 'c2w_matrices')
        self.video_root = os.path.join(self.root_path, 'joint_videos')
        self.intrinsics_root = os.path.join(self.root_path, 'intrinsics')
        self.captions_root = os.path.join(self.root_path, 'captions')
        self.dataset = self._list_clip_names(self.video_root)
        self.length = len(self.dataset)

        sample_size = (image_size, image_size) if isinstance(image_size, int) else tuple(image_size)
        self.sample_size = sample_size

        # Order matters: __getitem__ addresses the flip transform at index 1.
        pixel_transforms = [transforms.Resize(sample_size)]
        if use_flip:
            pixel_transforms.append(RandomHorizontalFlipWithPose(hflip_p))
        pixel_transforms.append(
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        )

        self.rescale_fxy = rescale_fxy
        self.sample_wh_ratio = sample_size[1] / sample_size[0]

        self.pixel_transforms = pixel_transforms
        self.shuffle_frames = shuffle_frames
        self.use_flip = use_flip

    @staticmethod
    def _list_clip_names(video_root):
        """Return the sorted names, without suffix, of the ``.mp4`` files in ``video_root``."""
        suffix = RealEstate10KPCDRenderDataset._VIDEO_SUFFIX
        # Only strip the trailing suffix and skip non-video entries, so stray
        # files in the directory do not turn into unloadable clips.
        return sorted(
            name[:-len(suffix)]
            for name in os.listdir(video_root)
            if name.endswith(suffix)
        )

    def load_video_reader(self, idx):
        """Open clip ``idx`` and return ``(clip_name, video_reader, caption)``.

        The caption is the stripped content of the clip's caption file, or an
        empty string when that file does not exist.
        """
        clip_name = self.dataset[idx]
        video_path = os.path.join(self.video_root, clip_name + '.mp4')
        video_reader = VideoReader(video_path)
        caption_path = os.path.join(self.captions_root, clip_name + '.txt')
        if os.path.exists(caption_path):
            with open(caption_path, 'r') as caption_file:
                caption = caption_file.read().strip()
        else:
            caption = ''
        return clip_name, video_reader, caption

    def _load_camera(self, clip_name):
        """Return ``(intrinsics, c2w)`` for a clip, or all-zero tensors if loading fails."""
        try:
            # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
            # objects; only use this with trusted dataset files.
            c2w_path = os.path.join(self.c2w_root, clip_name + '.npy')
            c2w_poses = np.load(c2w_path, allow_pickle=True)
            intrinsics_path = os.path.join(self.intrinsics_root, clip_name + '.npy')
            intrinsic_matrix = np.load(intrinsics_path, allow_pickle=True)  # 3x3
            # The single intrinsic matrix is shared by every sampled frame.
            intrinsics = torch.tensor(
                [[intrinsic_matrix[0, 0], intrinsic_matrix[1, 1],
                  intrinsic_matrix[0, 2], intrinsic_matrix[1, 2]]],
                dtype=torch.float32,
            ).repeat(self.sample_n_frames, 1)
            c2w = torch.as_tensor(c2w_poses)[None]  # [1, n_frame, 4, 4]
            intrinsics = intrinsics[None]  # [1, n_frame, 4]
        except Exception:
            # Best-effort fallback kept from the original behaviour.
            # NOTE(review): zero focal lengths make ray_condition divide by
            # zero, so the resulting embedding is not meaningful.
            c2w = torch.zeros(1, self.sample_n_frames, 4, 4, dtype=torch.float32)
            intrinsics = torch.zeros(1, self.sample_n_frames, 4, dtype=torch.float32)
        return intrinsics, c2w

    def get_batch(self, idx):
        """Load clip ``idx``.

        Returns:
            ``(pixel_values, anchor_pixels, video_caption, plucker_embedding,
            flip_flag, clip_name)``. Pixel tensors are frame-major,
            channel-first and scaled to [0, 1].
        """
        clip_name, video_reader, video_caption = self.load_video_reader(idx)
        intrinsics, c2w = self._load_camera(clip_name)

        if self.use_flip:
            flip_flag = self.pixel_transforms[1].get_flip_flag(self.sample_n_frames)
        else:
            flip_flag = torch.zeros(self.sample_n_frames, dtype=torch.bool, device=c2w.device)

        height, width = self.sample_size[0], self.sample_size[1]
        plucker_embedding = ray_condition(
            intrinsics, c2w, height, width, device='cpu', flip_flag=flip_flag
        )[0].permute(0, 3, 1, 2).contiguous()

        indices = np.arange(self.sample_n_frames)
        cated_pixels = torch.from_numpy(video_reader.get_batch(indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        cated_pixels = cated_pixels / 255.

        # Split the joint frame: left part is the target, the rest the anchor.
        pixel_values = cated_pixels[:, :, :height, :width]
        anchor_pixels = cated_pixels[:, :, :height, width:]

        return pixel_values, anchor_pixels, video_caption, plucker_embedding, flip_flag, clip_name

    def __len__(self):
        return self.length

    def _apply_pixel_transforms(self, frames, flip_flag):
        """Run the pixel transforms on ``frames``, handing ``flip_flag`` to the flip step."""
        if not self.use_flip:
            for transform in self.pixel_transforms:
                frames = transform(frames)
            return frames
        resize, flip, normalize = self.pixel_transforms[:3]
        return normalize(flip(resize(frames), flip_flag))

    def __getitem__(self, idx):
        """Return the sample dict for ``idx``.

        If loading fails, a randomly chosen other index is tried until one
        succeeds.
        """
        while True:
            try:
                video, anchor_video, video_caption, plucker_embedding, flip_flag, _ = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        video = self._apply_pixel_transforms(video, flip_flag)
        anchor_video = self._apply_pixel_transforms(anchor_video, flip_flag)
        data = {
            'video': video,
            'anchor_video': anchor_video,
            'caption': video_caption,
            'controlnet_video': plucker_embedding,
        }
        return data
    
class RealEstate10KPCDRenderCapEmbDataset(RealEstate10KPCDRenderDataset):
    """Render dataset variant that returns a stored caption embedding and a mask.

    In addition to the files used by the parent class this reads
    ``<text_embedding_path>/<clip>.pt`` (loaded with ``torch.load``) and,
    when present, ``<video_root_dir>/masks/<clip>.npz`` (array ``mask``).

    Args:
        video_root_dir: dataset root directory.
        text_embedding_path: directory holding one ``.pt`` file per clip.
        sample_n_frames, relative_pose, zero_t_first_frame, image_size,
        rescale_fxy, shuffle_frames, hflip_p: forwarded unchanged to the
            parent constructor.
    """

    def __init__(
            self,
            video_root_dir,
            text_embedding_path,
            sample_n_frames=49,
            relative_pose=True,
            zero_t_first_frame=True,
            image_size=(480, 720),
            rescale_fxy=True,
            shuffle_frames=False,
            hflip_p=0.0,
    ):
        super().__init__(
            video_root_dir,
            sample_n_frames=sample_n_frames,
            relative_pose=relative_pose,
            zero_t_first_frame=zero_t_first_frame,
            image_size=image_size,
            rescale_fxy=rescale_fxy,
            shuffle_frames=shuffle_frames,
            hflip_p=hflip_p,
        )
        self.text_embedding_path = text_embedding_path
        self.mask_root = os.path.join(self.root_path, 'masks')

    def get_batch(self, idx):
        """Load clip ``idx``.

        Returns:
            ``(pixel_values, anchor_pixels, masks, video_caption_emb,
            plucker_embedding, flip_flag, clip_name)``.
        """
        clip_name, video_reader, _ = self.load_video_reader(idx)
        try:
            # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
            # objects; only use this with trusted dataset files.
            c2w_path = os.path.join(self.c2w_root, clip_name + '.npy')
            c2w_poses = np.load(c2w_path, allow_pickle=True)
            intrinsics_path = os.path.join(self.intrinsics_root, clip_name + '.npy')
            intrinsic_matrix = np.load(intrinsics_path, allow_pickle=True)  # 3x3
            # The single intrinsic matrix is shared by every sampled frame.
            intrinsics = torch.tensor(
                [[intrinsic_matrix[0, 0], intrinsic_matrix[1, 1],
                  intrinsic_matrix[0, 2], intrinsic_matrix[1, 2]]],
                dtype=torch.float32,
            ).repeat(self.sample_n_frames, 1)
            c2w = torch.as_tensor(c2w_poses)[None]  # [1, n_frame, 4, 4]
            intrinsics = intrinsics[None]  # [1, n_frame, 4]
        except Exception:
            # Best-effort fallback kept from the original behaviour.
            c2w = torch.zeros(1, self.sample_n_frames, 4, 4, dtype=torch.float32)
            intrinsics = torch.zeros(1, self.sample_n_frames, 4, dtype=torch.float32)

        cap_emb_path = os.path.join(self.text_embedding_path, clip_name + '.pt')
        video_caption_emb = torch.load(cap_emb_path, weights_only=True)

        if self.use_flip:
            flip_flag = self.pixel_transforms[1].get_flip_flag(self.sample_n_frames)
        else:
            flip_flag = torch.zeros(self.sample_n_frames, dtype=torch.bool, device=c2w.device)

        height, width = self.sample_size[0], self.sample_size[1]
        plucker_embedding = ray_condition(
            intrinsics, c2w, height, width, device='cpu', flip_flag=flip_flag
        )[0].permute(0, 3, 1, 2).contiguous()

        indices = np.arange(self.sample_n_frames)
        cated_pixels = torch.from_numpy(video_reader.get_batch(indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        cated_pixels = cated_pixels / 255.

        # Split the joint frame: left part is the target, the rest the anchor.
        pixel_values = cated_pixels[:, :, :height, :width]
        anchor_pixels = cated_pixels[:, :, :height, width:]

        try:
            # Close the archive once the array has been read out of it.
            with np.load(os.path.join(self.mask_root, clip_name + '.npz')) as mask_file:
                masks = mask_file['mask'] * 1.0
            masks = torch.from_numpy(masks).unsqueeze(1)
        except Exception:
            # No usable mask file: mark pixels whose channel sum in the anchor
            # frame is below the threshold.
            threshold = 0.1  # you can adjust this value
            masks = (anchor_pixels.sum(dim=1, keepdim=True) < threshold).float()
        return pixel_values, anchor_pixels, masks, video_caption_emb, plucker_embedding, flip_flag, clip_name

    def _apply_pixel_transforms(self, frames, flip_flag):
        """Run the pixel transforms on ``frames``, handing ``flip_flag`` to the flip step."""
        if not self.use_flip:
            for transform in self.pixel_transforms:
                frames = transform(frames)
            return frames
        resize, flip, normalize = self.pixel_transforms[:3]
        return normalize(flip(resize(frames), flip_flag))

    def __getitem__(self, idx):
        """Return the sample dict for ``idx``.

        If loading fails, a randomly chosen other index is tried until one
        succeeds.
        """
        while True:
            try:
                video, anchor_video, mask, video_caption_emb, plucker_embedding, flip_flag, _ = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        video = self._apply_pixel_transforms(video, flip_flag)
        anchor_video = self._apply_pixel_transforms(anchor_video, flip_flag)
        # NOTE(review): the mask is returned as loaded; it is neither resized
        # nor flipped together with the videos - confirm this is intended when
        # hflip_p > 0.
        data = {
            'video': video,
            'anchor_video': anchor_video,
            'caption_emb': video_caption_emb,
            'controlnet_video': plucker_embedding,
            'mask': mask
        }
        return data