# -*- coding: utf-8 -*-
import glob
import json
import math
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torchvision
from torchvision import transforms
from PIL import Image

from .shared_dataset import SharedDataset
from utils.camera_utils import get_loop_cameras, get_rays
from utils.camera_utils import get_loop_cameras, camera_normalization_objaverse, build_camera_principle

from utils.graphics_utils import getProjectionMatrix

from utils.image_utils import make_normalize_transform

# Dataset locations. All of these are placeholders and must be edited to point
# at your local copy of the data before this module is used.

# Root directory of the fixed-viewpoint renders ("_front", "_left", ...) that
# are read as input views. Change this to your data directory.
OBJAVERSE_ROOT_FIX = None
# Root directory of the renders that are read as supervision views
# (files named "rgb_000_<idx>.webp"). Change this to your data directory.
OBJAVERSE_ROOT_RANDOM = None
# JSON file with the object ids used for the training / validation split.
OBJAVERSE_LVIS_ANNOTATION_PATH = None
# JSON file with the object ids of the "val_vis_test" split. Only read when
# that split is requested, so it may stay None otherwise.
OBJAVERSE_TEST_OBJ_PATH = None

# Fail at import time so that a missing configuration is noticed immediately.
assert OBJAVERSE_LVIS_ANNOTATION_PATH is not None, "Update filtering .json path"


class ObjaverseDataset(SharedDataset):
    def __init__(self,
                 cfg,
                 dataset_name="train",
                 total_view_input=4,
                 total_view_sup=32
                 ) -> None:
        """Read the object list, pick the requested split and cache constants.

        Args:
            cfg: Project configuration object. This method reads
                ``cfg.data.input_images``, ``cfg.data.training_resolution``,
                ``cfg.data.subset``, ``cfg.data.znear``, ``cfg.data.zfar``,
                ``cfg.data.fov`` and ``cfg.opt.imgs_per_obj``.
            dataset_name: Split selector. ``"val"`` / ``"vis"`` use the tail
                of the object list, ``"val_vis_test"`` reads its own id file,
                ``"test"`` is not supported, and any other value selects the
                training portion.
            total_view_input: Accepted for backward compatibility only; the
                number of input views is always ``cfg.data.input_images``.
            total_view_sup: Number of supervision views stored per object.

        Raises:
            NotImplementedError: If ``dataset_name`` is ``"test"``.
            ValueError: If ``dataset_name`` is ``"val_vis_test"`` but
                ``OBJAVERSE_TEST_OBJ_PATH`` has not been configured.
        """
        # ``super(ObjaverseDataset).__init__()`` only re-initialises a
        # throw-away ``super`` object and never reaches the parent class.
        super().__init__()
        self.cfg = cfg
        self.root_dir_input = Path(OBJAVERSE_ROOT_FIX)
        self.root_dir_sup = Path(OBJAVERSE_ROOT_RANDOM)
        # The view count comes from the config, not from ``total_view_input``.
        self.total_view_input = self.cfg.data.input_images
        self.total_view_sup = total_view_sup

        # load the file names
        with open(OBJAVERSE_LVIS_ANNOTATION_PATH) as f:
            self.object_ids = json.load(f)
        train_rate = 99.9  # percentage of objects used for training
        print('total number of training objects: ', len(self.object_ids))

        resolution = self.cfg.data.training_resolution
        self.resize = transforms.Resize((resolution, resolution))
        self.normalize = transforms.Compose(
            [
                make_normalize_transform(),
            ]
        )

        # split the dataset for training and validation
        total_objects = len(self.object_ids)
        split_at = math.floor(total_objects / 100. * train_rate)
        self.dataset_name = dataset_name
        if dataset_name in ("val", "vis"):
            # validation or visualisation on Objaverse: last 0.1% of the list
            self.object_ids = self.object_ids[split_at:]
        elif dataset_name == 'val_vis_test':
            # Looked up dynamically so that a missing setting produces a
            # readable error instead of a NameError.
            test_obj_path = globals().get("OBJAVERSE_TEST_OBJ_PATH")
            if test_obj_path is None:
                raise ValueError(
                    "OBJAVERSE_TEST_OBJ_PATH must point to the .json file with "
                    "the test object ids to use the 'val_vis_test' split")
            with open(test_obj_path) as f:
                self.object_ids = json.load(f)
        elif dataset_name == "test":
            # Objaverse does not have a separate test subset
            raise NotImplementedError
        else:
            # training: first 99.9% of the list
            self.object_ids = self.object_ids[:split_at]

        # Optionally truncate the split (-1 keeps everything).
        if cfg.data.subset != -1:
            self.object_ids = self.object_ids[:cfg.data.subset]

        print('============= length of dataset %d =============' % len(self.object_ids))

        # Same field of view in both directions, converted from degrees to radians.
        fov_rad = cfg.data.fov * 2 * np.pi / 360
        self.projection_matrix = getProjectionMatrix(
            znear=self.cfg.data.znear, zfar=self.cfg.data.zfar,
            fovX=fov_rad,
            fovY=fov_rad).transpose(0, 1)

        self.image_side_target = resolution
        # Flips the y and z axes (OpenGL -> COLMAP / OpenCV camera convention).
        self.opengl_to_colmap = torch.tensor([[1,  0,  0,  0],
                                              [0, -1,  0,  0],
                                              [0,  0, -1,  0],
                                              [0,  0,  0,  1]], dtype=torch.float32)

        self.imgs_per_obj_train = self.cfg.opt.imgs_per_obj
        

    def __len__(self):
        return len(self.object_ids)

    def load_im(self, path, color=(0,0,0,0)):
        """Load an image, resize it and composite it onto a white background.

        Args:
            path: Location of the image file.
            color: Unused; kept so that existing callers keep working. The
                background colour is always white.

        Returns:
            A tuple ``(img, norm_img, fg_mask)`` where ``img`` is the resized
            RGB image blended with white using the alpha channel,
            ``norm_img`` is ``img`` passed through ``self.normalize`` and
            ``fg_mask`` is the resized alpha channel.
        """
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into a tensor.
        with Image.open(path) as pixel_img:
            # Images without an alpha channel are treated as fully opaque so
            # that the channel split below always finds four channels.
            rgba_img = pixel_img if pixel_img.mode == "RGBA" else pixel_img.convert("RGBA")
            # transform to tensor with values in [0, 1]
            img = torchvision.transforms.functional.pil_to_tensor(rgba_img) / 255.0

        bg_color = torch.tensor([1., 1., 1.], dtype=torch.float32).unsqueeze(1).unsqueeze(2)
        fg_mask = self.resize(img[3:, ...])
        img = self.resize(img[:3, ...])
        # Alpha-blend the foreground over the white background.
        img = img * fg_mask + bg_color * (1 - fg_mask)
        norm_img = self.normalize(img)
        return img, norm_img, fg_mask

    def load_imgs_and_convert_cameras(self, index, num_views):
        """Load the images and cameras of one object.

        Outside of training every stored view is loaded in order (input views
        first, then supervision views) and ``num_views`` is ignored. During
        training the input views come first, followed by up to ``num_views``
        views drawn at random from all stored views (a drawn view may repeat
        an input view).

        Args:
            index: Position of the object in ``self.object_ids``.
            num_views: Number of random views to draw in training.

        Returns:
            Dict with the stacked tensors ``gt_images``, ``norm_imgs``,
            ``fg_masks``, ``w2c_source`` (poses as read from disk),
            ``world_view_transforms``, ``view_to_world_transforms``,
            ``full_proj_transforms``, ``camera_centers`` and ``pps_pixels``
            (all zeros). ``source_camera`` is added when
            ``cfg.data.mod_camera_dec`` is set.

        Raises:
            NotImplementedError: If the number of input views has no known
                set of view names.
        """
        # File-name suffixes of the fixed input views, keyed by view count.
        view_names_by_count = {
            2: ["_front", "_back"],
            4: ["_front", "_left", "_back", "_right"],
            6: ["_front", "_left", "_front_right", "_back", "_right", "_back_left"],
            14: ["_front", "_left", "_bottom", "_back", "_right", "_top", "_front_right", "_front_left", "_back_right", "_back_left", "_front_right_top", "_front_left_top", "_back_right_top", "_back_left_top"],
        }
        if self.total_view_input not in view_names_by_count:
            raise NotImplementedError(f"not implemented for {self.total_view_input}")
        VIEWS = view_names_by_count[self.total_view_input]

        # read input images and extrinsics
        foldername_input = os.path.join(self.root_dir_input, self.object_ids[index])
        foldername_sup = os.path.join(self.root_dir_sup, self.object_ids[index])
        # validation dataset is used for scoring - fix cond frame for reproducibility
        # in training need to randomly sample the conditioning frame
        total_views = self.total_view_input + self.total_view_sup
        if self.dataset_name != "train":
            indexes = torch.arange(total_views)
        else:
            indexes = torch.randperm(total_views)[:num_views]
            indexes = torch.cat([torch.arange(self.total_view_input), indexes], dim=0)
        # Plain ints are used for list indexing and file-name formatting.
        view_ids = indexes.tolist()

        world_view_transforms = []
        view_world_transforms = []
        w2c_source = []
        camera_centers = []
        imgs = []
        norm_imgs = []
        fg_masks = []

        bottom_row = torch.tensor([[0, 0, 0, 1]], dtype=torch.float32)
        for i in view_ids:
            if i < self.total_view_input:
                # fixed, named input view
                img_path = os.path.join(foldername_input, f"rgb_000{VIEWS[i]}.webp")
                pose_path = os.path.join(foldername_input, f"000{VIEWS[i]}_RT.txt")
            else:
                # numbered supervision view
                idx = i - self.total_view_input
                img_path = os.path.join(foldername_sup, f"rgb_000_{idx:03d}.webp")
                pose_path = os.path.join(foldername_sup, f"000_{idx:03d}_RT.txt")
            img, norm_img, fg_mask = self.load_im(img_path)
            # w2c
            w2c_cmo = torch.from_numpy(np.loadtxt(pose_path)).to(dtype=torch.float32)

            imgs.append(img)
            fg_masks.append(fg_mask)
            norm_imgs.append(norm_img)
            w2c_source.append(w2c_cmo)
            w2c_cmo = torch.cat([w2c_cmo, bottom_row], dim=0)  # 4x4
            # camera poses in the *_RT.txt files are in OpenGL convention:
            #     x right, y up, z into the camera (backward),
            # need to transform to COLMAP / OpenCV:
            #     x right, y down, z away from the camera (forward)
            w2c_cmo = torch.matmul(self.opengl_to_colmap, w2c_cmo)
            # need row major order
            world_view_transform = w2c_cmo.transpose(0, 1)
            view_world_transform = w2c_cmo.inverse().transpose(0, 1)
            # The last row of the transposed camera-to-world matrix holds the
            # camera position.
            camera_center = view_world_transform[3, :3].clone()

            world_view_transforms.append(world_view_transform)
            view_world_transforms.append(view_world_transform)
            camera_centers.append(camera_center)

        imgs = torch.stack(imgs)
        norm_imgs = torch.stack(norm_imgs)
        fg_masks = torch.stack(fg_masks)
        w2c_source = torch.stack(w2c_source, dim=0) #(n_view+input_view, 3, 4)
        world_view_transforms = torch.stack(world_view_transforms) # (n_view+input_view, 4, 4)
        view_world_transforms = torch.stack(view_world_transforms)
        camera_centers = torch.stack(camera_centers)

        pps_pixels = torch.zeros((imgs.shape[0], 2))

        if self.cfg.data.mod_camera_dec:
            # Built from the unscaled poses read from disk.
            # NOTE(review): the meaning of the [128, 128] intrinsics row is
            # defined by build_camera_principle -- confirm against that helper.
            poses = camera_normalization_objaverse(normed_dist_to_center='auto', poses=w2c_source)
            intrinsics = torch.tensor([self.cfg.data.intrinsics[:2], self.cfg.data.intrinsics[2:], [128, 128]]).repeat(len(view_ids), 1, 1)
            source_camera = build_camera_principle(poses, intrinsics)
        # fix the distance of the source camera to the object / world center
        assert torch.norm(camera_centers[0]) > 1e-5, \
            "Camera is at {} from center".format(torch.norm(camera_centers[0]))
        # Rescale all translations so that the first camera sits at distance 2.
        translation_scaling_factor = 2.0 / torch.norm(camera_centers[0])
        world_view_transforms[:, 3, :3] *= translation_scaling_factor
        view_world_transforms[:, 3, :3] *= translation_scaling_factor
        camera_centers *= translation_scaling_factor

        full_proj_transforms = world_view_transforms.bmm(self.projection_matrix.unsqueeze(0).expand(
            world_view_transforms.shape[0], 4, 4))

        data_dict = {"gt_images": imgs,
                    "norm_imgs": norm_imgs,
                    "w2c_source": w2c_source,
                    "world_view_transforms": world_view_transforms,
                    "view_to_world_transforms": view_world_transforms,
                    "full_proj_transforms": full_proj_transforms,
                    "camera_centers": camera_centers,
                    "pps_pixels": pps_pixels,
                    "fg_masks": fg_masks}
        if self.cfg.data.mod_camera_dec:
            data_dict['source_camera'] = source_camera

        return data_dict


    def load_loop(self, index, num_imgs_in_loop):
        """Build a visualisation sample: the source views plus a camera loop.

        The first ``cfg.data.input_images`` entries are the source views with
        their real cameras. They are followed by one entry per camera
        returned by ``get_loop_cameras``; each of those borrows the image and
        mask of the ground-truth view whose camera centre is closest.

        Args:
            index: Position of the object in ``self.object_ids``.
            num_imgs_in_loop: Forwarded to ``get_loop_cameras``.

        Returns:
            Dict with the same keys as ``load_imgs_and_convert_cameras``.
            ``w2c_source`` is a stacked tensor only when
            ``cfg.data.mod_camera_dec`` is set; otherwise it stays a list.
        """
        gt_imgs_and_cameras = self.load_imgs_and_convert_cameras(index, self.total_view_input + self.total_view_sup)
        loop_cameras_c2w_cmo, camera_poses = get_loop_cameras(num_imgs_in_loop=num_imgs_in_loop)

        w2c_source = []
        world_view_transforms = []
        view_world_transforms = []
        camera_centers = []
        imgs = []
        norm_imgs = []
        fg_masks = []

        # Source views: copied over unchanged.
        for src_idx in range(self.cfg.data.input_images):
            imgs.append(gt_imgs_and_cameras["gt_images"][src_idx])
            # Keep the normalised images aligned with ``gt_images``; leaving
            # them out shifted every later entry by the number of source views.
            norm_imgs.append(gt_imgs_and_cameras["norm_imgs"][src_idx])
            fg_masks.append(gt_imgs_and_cameras["fg_masks"][src_idx])
            camera_centers.append(gt_imgs_and_cameras["camera_centers"][src_idx])
            world_view_transforms.append(gt_imgs_and_cameras["world_view_transforms"][src_idx])
            view_world_transforms.append(gt_imgs_and_cameras["view_to_world_transforms"][src_idx])
            w2c_source.append(gt_imgs_and_cameras["w2c_source"][src_idx])

        # Loop cameras: synthetic poses paired with the nearest real image.
        for loop_idx, loop_camera_c2w_cmo in enumerate(loop_cameras_c2w_cmo):
            c2w = torch.from_numpy(loop_camera_c2w_cmo)
            view_world_transform = c2w.transpose(0, 1)
            world_view_transform = c2w.inverse().transpose(0, 1)
            camera_center = view_world_transform[3, :3].clone()

            camera_centers.append(camera_center)
            world_view_transforms.append(world_view_transform)
            view_world_transforms.append(view_world_transform)
            w2c_source.append(camera_poses[loop_idx])

            # use the closest camera as reference gt image
            closest_gt_idx = torch.argmin(torch.norm(
                gt_imgs_and_cameras["camera_centers"] - camera_center.unsqueeze(0), dim=-1)).item()
            imgs.append(gt_imgs_and_cameras["gt_images"][closest_gt_idx])
            # Already normalised when the image was loaded.
            norm_imgs.append(gt_imgs_and_cameras["norm_imgs"][closest_gt_idx])
            fg_masks.append(gt_imgs_and_cameras["fg_masks"][closest_gt_idx])

        imgs = torch.stack(imgs)
        norm_imgs = torch.stack(norm_imgs)
        fg_masks = torch.stack(fg_masks)
        world_view_transforms = torch.stack(world_view_transforms)
        view_world_transforms = torch.stack(view_world_transforms)
        camera_centers = torch.stack(camera_centers)

        full_proj_transforms = world_view_transforms.bmm(self.projection_matrix.unsqueeze(0).expand(
            world_view_transforms.shape[0], 4, 4))

        pps_pixels = torch.zeros((imgs.shape[0], 2))

        if self.cfg.data.mod_camera_dec:
            w2c_source = torch.stack(w2c_source, dim=0)
            # NOTE(review): the meaning of the [128, 128] intrinsics row is
            # defined by build_camera_principle -- confirm against that helper.
            poses = camera_normalization_objaverse(normed_dist_to_center='auto', poses=w2c_source)
            intrinsics = torch.tensor([self.cfg.data.intrinsics[:2], self.cfg.data.intrinsics[2:], [128, 128]]).repeat(w2c_source.shape[0], 1, 1)
            source_camera = build_camera_principle(poses, intrinsics)
        data_dict = {"gt_images": imgs.to(memory_format=torch.channels_last),
                    "norm_imgs": norm_imgs,
                    "w2c_source": w2c_source,
                    "world_view_transforms": world_view_transforms,
                    "view_to_world_transforms": view_world_transforms,
                    "full_proj_transforms": full_proj_transforms,
                    "camera_centers": camera_centers,
                    "pps_pixels": pps_pixels,
                    'fg_masks': fg_masks}
        if self.cfg.data.mod_camera_dec:
            data_dict['source_camera'] = source_camera
        return data_dict

    def get_example_id(self, index):
        example_id = self.object_ids[index]
        return example_id

    def __getitem__(self, index):
        """Assemble the sample for the object at ``index``.

        The ``"vis"`` and ``"val_vis_test"`` splits return a 100-camera loop
        from ``load_loop``; every other split returns the stored views from
        ``load_imgs_and_convert_cameras``. The result is then passed through
        ``make_poses_relative_to_first`` and extended with
        ``source_cv2wT_quat`` and, when ``cfg.data.use_plucker_emb`` is set,
        with ``plucker_emb`` for the input views.

        Args:
            index: Position of the object in ``self.object_ids``.

        Returns:
            Dict of images and camera data for the object.
        """
        # load the rendered images
        if self.dataset_name in ("vis", "val_vis_test"):
            images_and_camera_poses = self.load_loop(index, 100)
        else:
            if self.dataset_name == "train":
                num_views = self.imgs_per_obj_train
            else:
                num_views = self.total_view_sup + self.total_view_input
            images_and_camera_poses = self.load_imgs_and_convert_cameras(index, num_views)

        images_and_camera_poses = self.make_poses_relative_to_first(images_and_camera_poses)

        if self.cfg.data.use_plucker_emb:
            resolution = self.cfg.data.training_resolution
            view_to_world = images_and_camera_poses["view_to_world_transforms"]
            plucker_embs = []
            for input_idx in range(self.cfg.data.input_images):
                rays_o, rays_d = get_rays(view_to_world[input_idx], resolution, resolution, self.cfg.data.fov, opengl=False) # [h, w, 3]
                # Pluecker coordinates of each ray: (origin x direction, direction).
                plucker_emb = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
                plucker_embs.append(plucker_emb)

            plucker_embs = torch.stack(plucker_embs, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w]
            images_and_camera_poses["plucker_emb"] = plucker_embs

        images_and_camera_poses["source_cv2wT_quat"] = self.get_source_cw2wT(images_and_camera_poses["view_to_world_transforms"])
        return images_and_camera_poses
