from torch.utils.data import Dataset
from scene.cameras import Camera
import numpy as np
from utils.general_utils import PILtoTorch
from utils.graphics_utils import fov2focal, focal2fov
import torch
from utils.camera_utils import loadCam
from utils.graphics_utils import focal2fov

class FourDGSdataset(Dataset):
    def __init__(
        self,
        dataset,
        args
    ):
        self.dataset = dataset
        self.args = args
    
    def __getitem__(self, index):

        try:
            image, w2c, time = self.dataset[index]
            R,T = w2c
            FovX = focal2fov(self.dataset.focal[0], image.shape[2])
            FovY = focal2fov(self.dataset.focal[0], image.shape[1])
        except:
            caminfo = self.dataset[index]
            image = caminfo.image
            mask = caminfo.mask
            depth = caminfo.depth
            R = caminfo.R
            T = caminfo.T
            FovX = caminfo.FovX
            FovY = caminfo.FovY
            Znear = caminfo.Znear
            Zfar = caminfo.Zfar
            time = caminfo.time
        
        return Camera(colmap_id=index,R=R,T=T,FoVx=FovX,FoVy=FovY,image=image, depth=depth, mask=mask, gt_alpha_mask=None,
                          image_name=f"{index}",uid=index,data_device=torch.device("cuda:0"),time=time,
                          Znear=Znear, Zfar=Zfar)
    
    def __len__(self):
        return len(self.dataset)
