import os
import torch
import torchvision
import tqdm
import numpy as np
from nerf.provider import NeRFDataset
from nerf.utils import *
from nerf.network_particle import NeRFNetwork
# from camera import get_rays  # computes ray_origins and ray_directions (replaced by the local get_rays below)

def get_rays(H, W, focal, cam_pose):
    """
    Generate one ray per pixel for a pinhole camera.

    Camera-frame directions are built with +x to the right, +y up and the
    viewing direction along -z, then rotated by the upper-left 3x3 block of
    ``cam_pose``; the ray origin is the translation column of ``cam_pose``
    (i.e. the matrix is used as a camera-to-world transform).

    Args:
        H, W: image height and width in pixels.
        focal: focal length in pixels, shared by both image axes.
        cam_pose: 4x4 pose matrix (tensor or array-like).

    Returns:
        rays_o: (H*W, 3) ray origins (the camera position repeated per pixel).
        rays_d: (H*W, 3) unit-length ray directions.
        Both are in row-major pixel order and live on ``cam_pose``'s device
        with its floating-point dtype.
    """
    cam_pose = torch.as_tensor(cam_pose)
    if not cam_pose.is_floating_point():
        # Integer poses cannot be normalised / matrix-multiplied with the
        # floating-point pixel directions below.
        cam_pose = cam_pose.float()
    device, dtype = cam_pose.device, cam_pose.dtype

    # Build the pixel grid directly on the pose's device and dtype so the
    # rotation below never mixes devices (CPU grid vs. CUDA pose) or dtypes.
    cols = torch.arange(W, device=device, dtype=dtype)
    rows = torch.arange(H, device=device, dtype=dtype)
    i, j = torch.meshgrid(cols, rows, indexing="xy")  # each (H, W)

    # Normalised camera-frame directions, shape (H, W, 3).
    dirs = torch.stack(
        [
            (i - W * 0.5) / focal,
            -(j - H * 0.5) / focal,
            -torch.ones_like(i),
        ],
        dim=-1,
    )

    # Row-vector form of R @ d, broadcast over every pixel.
    rays_d = dirs @ cam_pose[:3, :3].T
    rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)

    # Every ray starts at the camera position.
    rays_o = cam_pose[:3, 3].expand_as(rays_d)

    return rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

class Renderer:
    """
    Load a NeRFNetwork checkpoint and render a ring of views to image files.

    NOTE(review): this class reads ``self.model.focal`` and expects
    ``self.model.render(rays_o, rays_d)`` to yield an (rgb, depth) pair; neither
    is visible in this file -- confirm against ``nerf.network_particle``.
    """

    def __init__(self, checkpoint_path, device="cuda", opt=None):
        """
        Build the network on ``device`` and restore its weights.

        Args:
            checkpoint_path: path to a checkpoint readable by ``torch.load``.
            device: device the model and generated rays are placed on.
            opt: options object forwarded to ``NeRFNetwork``. Defaults to
                ``None``, which matches the previous hard-coded behaviour.
        """
        self.device = device
        self.opt = opt
        self.model = NeRFNetwork(self.opt).to(self.device)
        self.load_checkpoint(checkpoint_path)

    def load_checkpoint(self, checkpoint_path):
        """Restore model weights from the ``"model"`` entry of a checkpoint."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model"])
        print(f"Checkpoint loaded from {checkpoint_path}")

    def generate_images(self, output_dir, num_views=10, image_size=256):
        """
        Render ``num_views`` evenly spaced views and save RGB and depth PNGs.

        Files are written to ``output_dir`` (created if missing) as
        ``view_XXX_rgb.png`` and ``view_XXX_depth.png``.
        """
        os.makedirs(output_dir, exist_ok=True)
        self.model.eval()

        with torch.no_grad():
            for idx in tqdm.tqdm(range(num_views), desc="Rendering Views"):
                rays_o, rays_d = self.get_rays_for_view(idx, image_size, num_views)

                preds, preds_depth = self.render_single_view(rays_o, rays_d)

                # Rays are flattened to (H*W, 3), so put the outputs back on
                # the image grid before saving; this is a no-op when the model
                # already returns image-shaped tensors.
                rgb = preds.reshape(image_size, image_size, -1).permute(2, 0, 1)
                depth = preds_depth.reshape(1, image_size, image_size)

                # Save the RGB image. `value_range` replaces the `range`
                # keyword that newer torchvision releases no longer accept.
                rgb_path = os.path.join(output_dir, f"view_{idx:03d}_rgb.png")
                torchvision.utils.save_image(
                    rgb, rgb_path, normalize=True, value_range=(0, 1)
                )

                # Save the depth map, min-max normalised per image.
                depth_path = os.path.join(output_dir, f"view_{idx:03d}_depth.png")
                torchvision.utils.save_image(depth, depth_path, normalize=True)

                print(f"Saved {rgb_path} and {depth_path}")

    def get_rays_for_view(self, view_idx, image_size=256, num_views=10):
        """
        Return ``(rays_o, rays_d)`` on ``self.device`` for one view.

        Args:
            view_idx: index of the view on the ring.
            image_size: side length of the square image in pixels.
            num_views: total number of views the ring is divided into.
        """
        cam_pose = self.get_camera_pose(view_idx, num_views)
        # Use the module-level get_rays (this class has no such method).
        # Rays are generated on the CPU and transferred once, which keeps the
        # pose and the pixel grid on the same device whatever self.device is.
        rays_o, rays_d = get_rays(
            image_size, image_size, self.model.focal, cam_pose.cpu()
        )
        return rays_o.to(self.device), rays_d.to(self.device)

    def get_camera_pose(self, view_idx, num_views=10, radius=3.0):
        """
        Camera-to-world pose for view ``view_idx`` on a ring around the origin.

        The camera sits at distance ``radius`` from the origin in the x-z
        plane, rotated by ``2*pi*view_idx/num_views`` about the y axis, with
        its -z axis (the viewing direction used by ``get_rays``) pointing at
        the origin.

        Returns:
            (4, 4) float32 tensor on ``self.device``.

        Raises:
            ValueError: if ``num_views`` is not positive.
        """
        if num_views <= 0:
            raise ValueError("num_views must be positive")

        theta = 2 * np.pi * (view_idx / num_views)
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        # Position is R_y(theta) @ (0, 0, radius). Paired with rotation
        # R_y(theta) the forward axis R @ (0, 0, -1) points back at the origin,
        # so the camera actually orbits the object instead of turning in place.
        cam_pose = np.array([
            [cos_t, 0, sin_t, radius * sin_t],
            [0, 1, 0, 0],
            [-sin_t, 0, cos_t, radius * cos_t],
            [0, 0, 0, 1],
        ])
        return torch.tensor(cam_pose, dtype=torch.float32).to(self.device)

    def render_single_view(self, rays_o, rays_d):
        """Run the model's ``render`` on one view's rays under CUDA autocast."""
        with torch.cuda.amp.autocast():
            preds, preds_depth = self.model.render(rays_o, rays_d)
        return preds, preds_depth

if __name__ == "__main__":
    # Script entry point: restore a trained checkpoint and write the rendered
    # views (RGB + depth) to the output directory.
    ckpt_file = "/gpfs/share/home/2301111469/sim_3d_reward/exp7/new_allstage_pickscore100_sim_cfg4.5_cutebunny_c7.5_stage3/2025-01-27-scale-7.5-lr-0.001-albedo-le-10-render-512-cube-sd-2.1-5000-finetune-dth-0.2-tet-256/checkpoints/best_df_ep0300.pth"
    save_dir = "outputs/rendered_nerf"
    view_count = 10  # number of viewpoints to render

    Renderer(ckpt_file).generate_images(save_dir, view_count)
