import torch
import torch.nn as nn
import os
import numpy as np
from PIL import Image
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVOrthographicCameras,
    PointsRasterizationSettings,
    PointsRenderer,
    PointsRasterizer,
    AlphaCompositor,
)
from torch_scatter import scatter_mean

class ObjectRenderer(nn.Module):
    """Render each predicted object mask of a point-cloud scene from several viewpoints.

    For every scene in a batch and every object mask, the masked points are
    centered on their mean and scaled so the largest absolute coordinate is 1,
    then rasterized once per camera position in ``eye`` with an orthographic
    camera looking at the origin (z up). In ``'context'`` mode a second image
    per view is rendered from the scene points nearest the object's centroid,
    i.e. the object together with its surroundings.
    """

    def __init__(self, eye, rasterizer_setting):
        """
        Args:
            eye: camera positions, one per view; ``len(eye)`` sets the number
                of views and the positions are passed to
                ``look_at_view_transform``.
            rasterizer_setting: keyword arguments for
                ``PointsRasterizationSettings``. It is read both as an
                attribute (``.image_size``) and via ``**`` unpacking, so it is
                presumably an EasyDict/OmegaConf-style config — confirm.
        """
        super().__init__()
        self.image_size = rasterizer_setting.image_size
        self.eye = eye
        self.views = len(eye)
        self.renderer = PointsRenderer(
            rasterizer=PointsRasterizer(
                # Cameras are supplied per call in forward().
                cameras=None,
                raster_settings=PointsRasterizationSettings(**rasterizer_setting),
            ),
            compositor=AlphaCompositor(background_color=(1.0, 1.0, 1.0)),  # white background
        )
        # The renderer is not trained; freeze anything it registers.
        for param in self.parameters():
            param.requires_grad = False

    @staticmethod
    def _normalize(xyz):
        """Center points on their mean and scale so the largest |coordinate| is 1.

        A degenerate cloud (e.g. a single point) has zero extent after
        centering; it is left at the origin instead of being divided by zero
        into NaNs.
        """
        centered = xyz - xyz.mean(dim=0)
        scale = centered.abs().max()
        return centered / torch.where(scale > 0, scale, torch.ones_like(scale))

    @staticmethod
    def _nearest_context_indices(dist, num_obj_points):
        """Return indices of the points closest to the object, nearest first.

        Selects 3x the object's point count when the scene has strictly more
        points than that, otherwise 2x under the same rule, otherwise every
        scene point.
        """
        num_points = dist.shape[0]
        if num_points > 3 * num_obj_points:
            k = 3 * num_obj_points
        elif num_points > 2 * num_obj_points:
            k = 2 * num_obj_points
        else:
            k = num_points
        return torch.topk(dist, k=k, dim=0, largest=False, sorted=True)[1]

    @classmethod
    def _extract_object(cls, scene_xyz, scene_rgb, mask, with_context):
        """Cut one object (and optionally its context) out of a scene.

        Args:
            scene_xyz: per-point coordinates of one scene, indexed ``[mask]``
                and reduced over dim 0 (presumably ``(N, 3)`` — confirm).
            scene_rgb: per-point color features aligned with ``scene_xyz``.
            mask: boolean mask over the scene points selecting the object.
            with_context: also gather the points nearest the object centroid.

        Returns:
            ``(obj_xyz, obj_rgb, ctx_xyz, ctx_rgb)``. The xyz tensors are
            normalized with ``_normalize``. The context entries are ``None``
            when ``with_context`` is false. An empty mask yields empty
            ``(0, C)`` tensors for the object and, if requested, the context.
        """
        if not mask.any():
            empty_xyz = scene_xyz.new_zeros((0, 3))
            empty_rgb = scene_rgb.new_zeros((0, scene_rgb.shape[-1]))
            if with_context:
                return empty_xyz, empty_rgb, empty_xyz, empty_rgb
            return empty_xyz, empty_rgb, None, None

        obj_xyz = scene_xyz[mask]
        obj_rgb = scene_rgb[mask]
        centroid = obj_xyz.mean(dim=0)
        norm_obj_xyz = cls._normalize(obj_xyz)
        if not with_context:
            return norm_obj_xyz, obj_rgb, None, None

        # squeeze(1), not squeeze(): keep a 1-D tensor even for a one-point scene.
        dist = torch.cdist(scene_xyz, centroid.unsqueeze(0), p=2).squeeze(1)
        idx = cls._nearest_context_indices(dist, obj_xyz.shape[0])
        return norm_obj_xyz, obj_rgb, cls._normalize(scene_xyz[idx]), scene_rgb[idx]

    @staticmethod
    def _to_uint8(img):
        """Convert a float image with values in [0, 1] to a uint8 numpy array.

        Values are clipped first so out-of-range pixels saturate instead of
        wrapping around.
        """
        return (img.detach().cpu().numpy().clip(0.0, 1.0) * 255).astype(np.uint8)

    def _save_images(self, output_img, context_img, scan_id, split):
        """Write every rendered (object, view) image of one scene as a PNG.

        Files go to ``data/multiview/<split>/<scan_id>/img_<obj>_<view>.png``.
        When ``context_img`` is given, it goes to ``img_context_<obj>_<view>.png``.
        """
        imag_path = os.path.join("data", "multiview", split, scan_id)
        os.makedirs(imag_path, exist_ok=True)
        for img_id in range(output_img.shape[0]):
            # Images are laid out object-major: view index varies fastest.
            obj_idx, view_idx = divmod(img_id, self.views)
            suffix = str(obj_idx) + "_" + str(view_idx) + ".png"
            Image.fromarray(self._to_uint8(output_img[img_id])).save(
                os.path.join(imag_path, "img_" + suffix)
            )
            if context_img is not None:
                Image.fromarray(self._to_uint8(context_img[img_id])).save(
                    os.path.join(imag_path, "img_context_" + suffix)
                )

    def forward(self, data_dict, mode, *, save_images=False):
        """Render every predicted object of every scene in the batch.

        Args:
            data_dict: batch dictionary. Keys read:
                ``"scan_idx"`` (only its length, used as the batch size);
                ``"batch_offset"`` (scene ``i`` owns the flattened points
                ``batch_offset[i]:batch_offset[i + 1]``);
                ``"coords_float"`` and ``"rgb"`` (flattened per-point
                coordinates and color features);
                ``"pred_masks"`` (per scene, one mask row per object, where
                entries ``> 0`` select the object's points);
                ``"scene_id"`` and ``"split"`` (only read when
                ``save_images`` is true).
            mode: ``'context'`` additionally renders each object's
                neighbourhood; any other value renders objects only.
            save_images: if true, also write every rendered view as a PNG
                under ``data/multiview/<split>/<scene_id>/`` (debug aid).

        Returns:
            ``(output_imgs, context_imgs)``. ``output_imgs`` concatenates the
            renderer outputs of all scenes along dim 0, ordered scene, then
            object, then view. ``context_imgs`` is the matching tensor in
            ``'context'`` mode and an empty list otherwise.
        """
        batch_size = len(data_dict["scan_idx"])
        batch_offset = data_dict["batch_offset"]
        device = batch_offset.device
        with_context = mode == 'context'
        coords = data_dict["coords_float"]
        rgb = data_dict["rgb"]
        pred_masks = data_dict["pred_masks"]

        # Also stored on self so other code can read the latest view transforms.
        self.R, self.T = look_at_view_transform(
            eye=self.eye, at=((0, 0, 0),), up=((0, 0, 1),), device=device
        )

        output_imgs = []
        context_imgs = []
        for i in range(batch_size):
            start, end = batch_offset[i], batch_offset[i + 1]
            scene_xyz = coords[start:end]
            scene_rgb = rgb[start:end]

            obj_xyz_list, obj_rgb_list = [], []
            ctx_xyz_list, ctx_rgb_list = [], []
            for obj_mask in pred_masks[i]:
                obj_xyz, obj_rgb, ctx_xyz, ctx_rgb = self._extract_object(
                    scene_xyz, scene_rgb, obj_mask > 0, with_context
                )
                # One copy per view; the camera batch below is object-major to match.
                obj_xyz_list.extend([obj_xyz] * self.views)
                obj_rgb_list.extend([obj_rgb] * self.views)
                if with_context:
                    ctx_xyz_list.extend([ctx_xyz] * self.views)
                    ctx_rgb_list.extend([ctx_rgb] * self.views)

            num_objects = len(obj_xyz_list) // self.views
            # Tile the per-view transforms: (views, ...) -> (num_objects * views, ...).
            R = self.R.expand(num_objects, -1, -1, -1).flatten(0, 1)
            T = self.T.expand(num_objects, -1, -1).flatten(0, 1)
            cameras = FoVOrthographicCameras(device=device, R=R, T=T, znear=0.01)

            pcd = Pointclouds(points=obj_xyz_list, features=obj_rgb_list)
            pcd.device = device
            output_img = self.renderer(pcd, dtype=torch.float32, device=device, cameras=cameras)
            output_imgs.append(output_img)

            context_img = None
            if with_context:
                ctx_pcd = Pointclouds(points=ctx_xyz_list, features=ctx_rgb_list)
                ctx_pcd.device = device
                context_img = self.renderer(
                    ctx_pcd, dtype=torch.float32, device=device, cameras=cameras
                )
                context_imgs.append(context_img)

            if save_images:
                self._save_images(
                    output_img, context_img, data_dict["scene_id"][i], data_dict["split"]
                )

        output_imgs = torch.cat(output_imgs, 0)
        if with_context:
            context_imgs = torch.cat(context_imgs, 0)
        return output_imgs, context_imgs
