import os
import torch
from src.models.network.lgm.models import LGM
from src.models.network.lgm.options import config_defaults
import numpy as np
import tqdm
from kiui.op import recenter
from kiui.cam import orbit_camera

# Random placeholder input used to smoke-test the network.
# NOTE(review): presumably laid out as [B, V, C, H, W] = 1 batch, 4 views,
# 10 channels, 256x256 — confirm the channel count against the modified LGM.
input_image = torch.randn(1, 4, 10, 256, 256).cuda()

opt = config_defaults['big']
model = LGM(opt).cuda()

# Checkpoint location; can be overridden with the LGM_CKPT environment
# variable so the script is usable outside the original cluster layout.
mode_path = os.environ.get(
    "LGM_CKPT",
    "/mnt/petrelfs/wenhao1/3D/LAST/pretrain/LGM/model.safetensors",
)
from safetensors import safe_open

# Read every tensor of the checkpoint onto the CPU; load_state_dict then
# copies the values into the model's (CUDA) parameters.
with safe_open(mode_path, framework="pt", device="cpu") as f:
    tensors = {key: f.get_tensor(key) for key in f.keys()}

# strict=False tolerates architecture differences between the checkpoint and
# this model, but silently dropping the result hides real mismatches, so
# report what did not line up.
missing_keys, unexpected_keys = model.load_state_dict(tensors, strict=False)
if missing_keys:
    print(f"[load] {len(missing_keys)} missing keys, e.g. {missing_keys[:5]}")
if unexpected_keys:
    print(f"[load] {len(unexpected_keys)} unexpected keys, e.g. {unexpected_keys[:5]}")

# The model is only used for inference below; disable training-mode behaviour.
model.eval()

device = "cuda"

# Perspective projection built from the config's vertical FOV and near/far
# planes. It is laid out for row-vector multiplication: it is right-multiplied
# onto the already-transposed view matrix (`cam_view @ proj_matrix` below).
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
proj_matrix[0, 0] = 1 / tan_half_fov
proj_matrix[1, 1] = 1 / tan_half_fov
proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
proj_matrix[2, 3] = 1

images = []
elevation = 0
azimuth = np.arange(0, 360, 2, dtype=np.int32)  # one frame every 2 degrees

# Inference only. no_grad avoids building an autograd graph for the forward
# pass and every render. Without it the rendered image requires grad and
# `.numpy()` below raises a RuntimeError.
with torch.no_grad():
    timesteps = torch.tensor([1], device=device)
    # The rasterizer below needs the predicted gaussians; previously the
    # result was only stored in `output`, and `gaussians` was undefined.
    gaussians = model.forward_gaussians(input_image, timesteps)
    output = gaussians  # keep the original name available for later code

    for azi in tqdm.tqdm(azimuth):
        cam_poses = torch.from_numpy(
            orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)
        ).unsqueeze(0).to(device)

        cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

        # Cameras needed by the gaussian rasterizer.
        cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
        cam_view_proj = cam_view @ proj_matrix  # [V, 4, 4]
        # NOTE(review): the negated translation is unusual for a camera
        # position; kept as-is — confirm against the renderer's convention.
        cam_pos = - cam_poses[:, :3, 3]  # [V, 3]

        image = model.gs.render(
            gaussians,
            cam_view.unsqueeze(0),
            cam_view_proj.unsqueeze(0),
            cam_pos.unsqueeze(0),
            scale_modifier=1,
        )['image']
        # Presumably [1, 1, 3, H, W] -> [1, H, W, 3] (verify). Clamp to [0, 1]
        # so the uint8 conversion cannot wrap around on out-of-range values.
        frame = image.squeeze(1).permute(0, 2, 3, 1).contiguous().float().clamp(0, 1)
        images.append((frame.cpu().numpy() * 255).astype(np.uint8))