import numpy as np
from PIL import Image
#import open3d as o3d
import trimesh
import imageio
import torch
import math
import os
import json, ast

def get_number_of_images(pose_intrinsic_json):
    """Return the number of top-level entries in a JSON file.

    For the per-frame pose/intrinsic files read by ``Camera`` the top-level
    keys are presumably frame ids (``frame_000000``, ...), so this is the
    frame count -- TODO confirm against the file producer.

    Args:
        pose_intrinsic_json: path to the JSON file.

    Returns:
        ``len`` of the decoded top-level JSON value.
    """
    # Context manager so the handle is closed even if decoding fails;
    # JSON text is UTF-8 by specification, independent of the locale.
    with open(pose_intrinsic_json, encoding="utf-8") as f:
        data = json.load(f)
    return len(data)

class Camera:
    """Per-frame camera data read from a JSON file, plus depth-image loading.

    The JSON file maps frame ids of the form ``frame_`` + six zero-padded
    digits to records containing an ``'intrinsic'`` matrix and an
    ``'aligned_pose'`` matrix. Only ``frame_000000``'s intrinsic is read; it is
    presumably shared by all frames (TODO confirm).
    """

    @staticmethod
    def _frame_id(idx):
        """Return the JSON key / file stem for frame ``idx`` (e.g. 7 -> 'frame_000007')."""
        return f"frame_{str(idx).zfill(6)}"

    def __init__(self,
                 intrinsic_resolution,
                 poses_intrinsic_path,
                 depths_path,
                 extension_depth,
                 depth_scale):
        """
        Args:
            intrinsic_resolution: 2-element resolution at which the stored
                intrinsic is valid. Its axis order must match the
                ``desired_resolution`` given to ``get_adapted_intrinsic``
                (NOTE(review): confirm whether callers use (W, H) or (H, W)).
            poses_intrinsic_path: path to the per-frame JSON file.
            depths_path: directory holding the depth images.
            extension_depth: suffix appended verbatim to the frame id to form
                a depth file name, so it should include the leading dot.
            depth_scale: raw depth values are divided by this; used by
                ``load_depth`` when no explicit scale is passed.
        """
        # Decode the JSON directly. The former json -> dumps -> ast.literal_eval
        # round trip raised ValueError on JSON true/false/null/NaN, which are
        # not Python literals, and the file handle was never closed.
        with open(poses_intrinsic_path, encoding="utf-8") as file:
            self.poses_intrinsic_json = json.load(file)
        self.intrinsic = np.array(self.poses_intrinsic_json['frame_000000']['intrinsic'])
        self.intrinsic_original_resolution = intrinsic_resolution
        self.poses_intrinsic_path = poses_intrinsic_path
        self.depths_path = depths_path
        self.extension_depth = extension_depth
        self.depth_scale = depth_scale

    def get_adapted_intrinsic(self, desired_resolution):
        """Return the intrinsic matrix rescaled to ``desired_resolution``.

        A new array is always returned, so callers may modify the result
        without corrupting ``self.intrinsic``.

        Args:
            desired_resolution: 2-element sequence in the same axis order as
                ``intrinsic_original_resolution`` (list, tuple or array).
        """
        # Compare as tuples: a list resolution (e.g. from a config) never
        # equals a tuple resolution (e.g. an array .shape) under ``==``.
        original = tuple(self.intrinsic_original_resolution)
        desired = tuple(desired_resolution)
        if original == desired:
            return self.intrinsic.copy()

        # Under a (width, height) reading, this is the width the original image
        # would have if resized to the desired height with its aspect ratio kept.
        resize_width = int(math.floor(desired[1] * float(original[0]) / float(original[1])))

        adapted_intrinsic = self.intrinsic.copy()
        adapted_intrinsic[0, 0] *= float(resize_width) / float(original[0])
        adapted_intrinsic[1, 1] *= float(desired[1]) / float(original[1])
        # Principal point is scaled over the (size - 1) pixel-index span.
        adapted_intrinsic[0, 2] *= float(desired[0] - 1) / float(original[0] - 1)
        adapted_intrinsic[1, 2] *= float(desired[1] - 1) / float(original[1] - 1)
        return adapted_intrinsic

    def load_poses(self, indices):
        """Return the inverse of each requested frame's ``aligned_pose``, first three rows only.

        The inversion presumably turns a camera-to-world pose into a
        world-to-camera extrinsic -- TODO confirm.

        Args:
            indices: sized sequence of frame indices.

        Returns:
            float64 array of shape ``(len(indices), 3, 4)`` for 4x4 poses;
            ``poses[i]`` belongs to ``indices[i]``.
        """
        # Output shape comes from frame_000000 so that an empty ``indices``
        # still yields a correctly shaped array. Inversion preserves the shape
        # of a square matrix, so it is not needed here.
        pose_shape = np.asarray(self.poses_intrinsic_json["frame_000000"]['aligned_pose'])[:3, :].shape
        poses = np.zeros((len(indices),) + pose_shape)
        for i, idx in enumerate(indices):
            pose = np.array(self.poses_intrinsic_json[self._frame_id(idx)]['aligned_pose'])
            poses[i] = np.linalg.inv(pose)[:3, :]
        return poses

    def load_depth(self, idx, depth_scale=None):
        """Read the depth image of frame ``idx`` and divide it by ``depth_scale``.

        Args:
            idx: frame index.
            depth_scale: divisor for the raw values; ``None`` (default) uses the
                ``depth_scale`` given at construction.

        Returns:
            The image array returned by ``imageio.v2.imread`` divided by the scale.
        """
        if depth_scale is None:
            depth_scale = self.depth_scale
        depth_path = os.path.join(self.depths_path, self._frame_id(idx) + self.extension_depth)
        sensor_depth = imageio.v2.imread(depth_path) / depth_scale
        return sensor_depth


class Images:
    """RGB images for a list of frame indices, loaded eagerly at construction."""

    def __init__(self,
                 images_path,
                 extension,
                 indices):
        """
        Args:
            images_path: directory containing files named ``frame_`` + six
                zero-padded digits + ``extension``.
            extension: suffix appended verbatim to the frame id, so it should
                include the leading dot.
            indices: frame indices to load; ``self.images[i]`` corresponds to
                ``indices[i]``.
        """
        self.images_path = images_path
        self.extension = extension
        self.indices = indices
        self.images = self.load_images(indices)

    def load_images(self, indices):
        """Open each requested frame and return it converted to RGB, in ``indices`` order."""
        images = []
        for idx in indices:
            img_path = os.path.join(self.images_path, f"frame_{str(idx).zfill(6)}{self.extension}")
            # convert() returns a new image object, so the source file can be
            # closed deterministically here instead of relying on Pillow's
            # auto-close, which only applies to single-frame files.
            with Image.open(img_path) as img:
                images.append(img.convert("RGB"))
        return images

    def get_as_np_list(self):
        """Return ``self.images`` as a list of numpy arrays (one per image, same order)."""
        return [np.asarray(image) for image in self.images]
    
class InstanceMasks3D:
    """Instance masks loaded from a file written with ``torch.save``."""

    def __init__(self, masks_path, map_location=None):
        """
        Args:
            masks_path: path to a file readable by ``torch.load``.
            map_location: forwarded to ``torch.load``, e.g. ``"cpu"`` to load
                tensors saved on a GPU on a CPU-only machine. ``None`` (default)
                keeps torch's default behavior.
        """
        # SECURITY: torch.load unpickles the file. Depending on the torch
        # version's ``weights_only`` default, a crafted file can execute
        # arbitrary code, so only load mask files from trusted sources.
        self.masks = torch.load(masks_path, map_location=map_location)
        # Masks are counted along dim 1; presumably the layout is
        # (num_points, num_masks) -- TODO confirm with the mask producer.
        self.num_masks = self.masks.shape[1]
    
    
class PointCloud:
    """Vertex positions of a geometry file loaded through trimesh."""

    def __init__(self,
                 point_cloud_path):
        """Load ``point_cloud_path`` with trimesh and keep its vertices as ``self.points``."""
        # process=False so trimesh does not clean/merge the stored vertices.
        geometry = trimesh.load(point_cloud_path, process=False)
        vertices = np.asarray(geometry.vertices)
        self.points = vertices
        self.num_points = vertices.shape[0]

    def get_homogeneous_coordinates(self):
        """Return ``self.points`` with a column of ones appended as the last axis."""
        ones_column = np.ones((self.num_points, 1))
        return np.concatenate((self.points, ones_column), axis=-1)
    