import warnings
warnings.filterwarnings("ignore", message="TORCH_CUDA_ARCH_LIST")
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import json
import tyro
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
import threading
import queue
import multiprocessing
import concurrent.futures
import import_helper

# load functions from GaussianAvatars
from scene import Camera
from gaussian_renderer import GaussianModel, FlameGaussianModel
from mesh_renderer import NVDiffRenderer
from utils.camera_utils import loadCam
from scene.dataset_readers import CameraInfo, focal2fov, fov2focal
from gaussian_renderer import render

@dataclass
class PipelineConfig:
    """Rasterizer switches handed to ``gaussian_renderer.render`` as its pipeline argument."""

    # Debug switch forwarded to the renderer (presumably enables rasterizer debug mode).
    debug: bool = False
    # Forwarded to the renderer; presumably computes 3D covariances in Python instead of CUDA.
    compute_cov3D_python: bool = False
    # Forwarded to the renderer; presumably converts SH coefficients to colour in Python instead of CUDA.
    convert_SHs_python: bool = False

@dataclass
class Config:
    """Command-line options for rendering a Gaussian avatar sequence (parsed via ``tyro.cli``)."""

    # Rasterizer switches passed through to the renderer.
    pipeline: PipelineConfig = field(default_factory=PipelineConfig)
    # Gaussian point cloud to load. Required in practice even though it defaults to None:
    # ``transform.json`` (and optionally ``flame_param.npz``) are read from its directory.
    point_path: Optional[Path] = None
    # Optional motion file forwarded to ``load_ply``.
    motion_path: Optional[Path] = None
    # Destination for the PNG frames and ``renders.mp4``; defaults to the point cloud's directory.
    output_path: Optional[Path] = None
    # Spherical-harmonics degree used when constructing the Gaussian model.
    sh_degree: int = 3
    # RGB background colour used for rendering (default white).
    background_color: tuple[float, float, float] = (1.0, 1.0, 1.0)
    # Frame rate of the encoded video.
    fps: int = 25

def write_data(path2data):
    """Write every ``{Path: payload}`` entry to disk, choosing the encoder from the file suffix.

    Missing parent directories are created first.

    * ``.png`` / ``.jpg``: payload is a CHW float tensor; it is scaled by 255, rounded,
      clamped to [0, 255], moved to CPU as uint8 and saved with PIL.
    * ``.obj`` / ``.txt``: payload is a string written verbatim.
    * ``.npz``: payload is a mapping of array names to arrays, saved with ``np.savez``.

    Raises:
        NotImplementedError: For any other suffix (the parent directory is still created).
    """
    for out_path, payload in path2data.items():
        parent_dir = out_path.parent
        if not parent_dir.exists():
            parent_dir.mkdir(parents=True, exist_ok=True)

        suffix = out_path.suffix
        if suffix in (".png", ".jpg"):
            # Adding 0.5 before the uint8 cast turns truncation into round-to-nearest.
            scaled = payload.mul(255).add_(0.5).clamp_(0, 255)
            pixels = scaled.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
            Image.fromarray(pixels).save(out_path)
        elif suffix in (".obj", ".txt"):
            with open(out_path, "w") as handle:
                handle.write(payload)
        elif suffix == ".npz":
            np.savez(out_path, **payload)
        else:
            raise NotImplementedError(f"Unknown file type: {suffix}")

def readCamerasFromTransform(path, transformsfile, white_background, extension=".png"):
    """Build a ``CameraInfo`` for the first frame listed in a transforms JSON file.

    Args:
        path: Directory holding ``transformsfile``; each frame's ``file_path`` is
            resolved relative to it.
        transformsfile: File name of the JSON inside ``path`` (e.g. ``"transform.json"``).
        white_background: Use a white background ``[1, 1, 1]`` when True, black
            ``[0, 0, 0]`` otherwise. It is stored on the ``CameraInfo`` and, when the
            image is loaded here, used to composite away its alpha channel.
        extension: Appended to the frame's ``file_path`` when it does not already contain it.

    Returns:
        A ``CameraInfo`` with ``uid=0`` for ``contents["frames"][0]``. When the frame
        carries ``w``/``h`` the image is not read (``image=None``); otherwise it is loaded
        from disk and its size defines width/height.

    Raises:
        KeyError: If ``camera_angle_x`` is present neither on the frame nor at the top
            level of the JSON.
    """
    # Parse the JSON and release the file handle before doing any image I/O.
    with open(os.path.join(path, transformsfile)) as json_file:
        contents = json.load(json_file)

    # Only the first frame is used.
    frame = contents["frames"][0]
    file_path = frame["file_path"]
    if extension not in file_path:
        file_path += extension
    # Joined exactly once: the original re-joined ``path`` onto this already-joined
    # path, which doubled the directory prefix whenever ``path`` was relative.
    image_path = os.path.join(path, file_path)
    image_name = Path(image_path).stem

    c2w = np.array(frame["transform_matrix"])
    # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
    c2w[:3, 1:3] *= -1

    # get the world-to-camera transform and set R, T
    w2c = np.linalg.inv(c2w)
    R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
    T = w2c[:3, 3]

    bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])

    if 'w' in frame and 'h' in frame:
        # Resolution is known from the JSON; skip reading the image.
        image = None
        width = frame['w']
        height = frame['h']
    else:
        # The context manager closes the underlying file once the pixels are copied out.
        with Image.open(image_path) as img:
            im_data = np.array(img.convert("RGBA"))
        norm_data = im_data / 255.0
        alpha = norm_data[:, :, 3:4]
        # Alpha-composite the RGB channels over the background colour.
        arr = norm_data[:, :, :3] * alpha + bg * (1 - alpha)
        # uint8, not np.byte: np.byte is *signed* int8, so values above 127 overflowed.
        # Round rather than truncate so e.g. 199.9999 maps to 200.
        rgb = np.clip(np.rint(arr * 255.0), 0, 255).astype(np.uint8)
        image = Image.fromarray(rgb)  # an HxWx3 uint8 array maps to mode "RGB"
        width, height = image.size

    # A per-frame field of view overrides the file-wide one.
    if 'camera_angle_x' in frame:
        fovx = frame["camera_angle_x"]
    elif 'camera_angle_x' in contents:
        fovx = contents["camera_angle_x"]
    else:
        raise KeyError(
            f"'camera_angle_x' is missing from both the first frame and {transformsfile}"
        )
    fovy = focal2fov(fov2focal(fovx, width), height)

    timestep = frame.get("timestep_index")
    # The original checked for 'camera_id' but read 'camera_index'; look up the key it reads.
    camera_id = frame.get("camera_index")

    return CameraInfo(
        uid=0, R=R, T=T, FovY=fovy, FovX=fovx, bg=bg, image=image,
        image_path=image_path, image_name=image_name,
        width=width, height=height,
        timestep=timestep, camera_id=camera_id)

def loadCam(id, cam_info, resolution_scale = 1.0):
    """Turn a ``CameraInfo`` record into a ``Camera`` whose data lives on CUDA.

    Note: this definition shadows the ``loadCam`` imported from ``utils.camera_utils``.

    Args:
        id: Value passed to the ``Camera`` as ``uid``.
        cam_info: Source record; its ``uid`` becomes the ``colmap_id``.
        resolution_scale: The original width/height are divided by this and rounded.

    Returns:
        The constructed ``Camera``.
    """
    scaled_width = round(cam_info.width / resolution_scale)
    scaled_height = round(cam_info.height / resolution_scale)

    camera_kwargs = dict(
        colmap_id=cam_info.uid,
        R=cam_info.R,
        T=cam_info.T,
        FoVx=cam_info.FovX,
        FoVy=cam_info.FovY,
        image_width=scaled_width,
        image_height=scaled_height,
        bg=cam_info.bg,
        image=cam_info.image,
        image_path=cam_info.image_path,
        image_name=cam_info.image_name,
        uid=id,
        timestep=cam_info.timestep,
        data_device='cuda',
    )
    return Camera(**camera_kwargs)

class GaussianRender:
    """Render every timestep of a Gaussian avatar from one fixed camera and encode a video.

    The camera comes from ``transform.json`` next to ``cfg.point_path``. Frames are written
    as ``{timestep:05d}.png`` into ``cfg.output_path``, or into the point cloud's directory
    when that is unset. They are then stitched into ``renders.mp4`` with ffmpeg.

    NOTE(review): this module sets ``CUDA_LAUNCH_BLOCKING=1`` at import time. That makes
    every CUDA launch synchronous, so consider removing it outside debugging.
    """

    def __init__(self, cfg: Config):
        """Load the camera and the Gaussian model described by ``cfg``.

        Raises:
            ValueError: If ``cfg.point_path`` is None (camera and model are located from it).
            FileNotFoundError: If ``cfg.point_path`` does not exist.
        """
        self.cfg = cfg
        if self.cfg.point_path is None:
            raise ValueError("cfg.point_path is required: the camera and Gaussians are loaded relative to it.")
        self.cam = loadCam(0, readCamerasFromTransform(self.cfg.point_path.parent, "transform.json", True))

        # Initialize Gaussian model
        print("Initializing 3D Gaussians...")
        self.init_gaussians()

        # A model without a mesh binding is rendered as a single static frame; previously
        # num_timesteps was left undefined here and render_set raised AttributeError.
        self.num_timesteps = 1
        if self.gaussians.binding is not None:
            self.mesh_renderer = NVDiffRenderer(use_opengl=False)
            self.num_timesteps = self.gaussians.num_timesteps
            self.reset_flame_param()

        # Parallel rendering setup. The streams, queue and lock are kept for compatibility;
        # render_set does not currently use them.
        self.num_workers = 8 # Adjust based on GPU capacity
        self.streams = [torch.cuda.Stream() for _ in range(self.num_workers)]
        self.task_queue = queue.Queue()
        self.lock = threading.Lock()
        self.background_tensor = torch.tensor(self.cfg.background_color).to('cuda')

    def init_gaussians(self):
        """Create ``self.gaussians`` and load ``cfg.point_path`` into it.

        A ``FlameGaussianModel`` is used when ``flame_param.npz`` sits next to the point
        cloud; otherwise a plain ``GaussianModel`` is used.

        Raises:
            FileNotFoundError: If ``cfg.point_path`` is set but does not exist.
        """
        point_path = self.cfg.point_path
        # Guard before building a Path: Path(None) raised TypeError ahead of the None check.
        has_flame = point_path is not None and (Path(point_path).parent / "flame_param.npz").exists()
        if has_flame:
            self.gaussians = FlameGaussianModel(self.cfg.sh_degree)
        else:
            self.gaussians = GaussianModel(self.cfg.sh_degree)

        if point_path is None:
            return
        if not Path(point_path).exists():
            raise FileNotFoundError(f'{self.cfg.point_path} does not exist.')
        self.gaussians.load_ply(point_path, has_target=False, motion_path=self.cfg.motion_path)

    def reset_flame_param(self):
        """Reset ``self.flame_param`` to all-zero FLAME parameters (CPU tensors, batch of 1)."""
        self.flame_param = {
            'expr': torch.zeros(1, self.gaussians.n_expr),
            'rotation': torch.zeros(1, 3),
            'neck': torch.zeros(1, 3),
            'jaw': torch.zeros(1, 3),
            'eyes': torch.zeros(1, 6),
            'translation': torch.zeros(1, 3)
        }

    def render_set(self):
        """Render every timestep to ``{timestep:05d}.png``, then encode ``renders.mp4``.

        PNG encoding runs on a thread pool while the GPU renders the next frames. At most
        ``cpu_count()`` frames are in flight before the batch is drained.

        Raises:
            Any exception raised by ``write_data`` for a frame. Previously these were
            silently discarded by ``concurrent.futures.wait``.
        """
        max_threads = multiprocessing.cpu_count()
        output_folder = Path(self.cfg.point_path.parent if self.cfg.output_path is None else self.cfg.output_path)
        # Create the folder the frames are actually written to. The original passed
        # cfg.output_path here, which raised TypeError when it was None.
        output_folder.mkdir(parents=True, exist_ok=True)
        # Built once instead of once per timestep.
        background = torch.tensor(self.cfg.background_color, dtype=torch.float32, device="cuda")

        pending = []
        # One pool for the whole run rather than a fresh pool per batch.
        with concurrent.futures.ThreadPoolExecutor(max_threads) as executor:
            for timestep in tqdm(range(self.num_timesteps)):
                if self.gaussians.binding is not None:
                    self.gaussians.select_mesh_by_timestep(timestep)
                rendering = render(self.cam, self.gaussians, self.cfg.pipeline, background)["render"]
                frame_path = output_folder / f'{timestep:05d}.png'
                pending.append(executor.submit(write_data, {frame_path: rendering}))

                # Drain a full batch (and the final partial one) to bound how many
                # rendered frames are held at once.
                if len(pending) == max_threads or timestep == self.num_timesteps - 1:
                    for future in pending:
                        future.result()  # re-raises any write failure
                    pending = []

        self._encode_video(output_folder)

    def _encode_video(self, output_folder):
        """Stitch ``output_folder/%05d.png`` into ``renders.mp4`` with ffmpeg (best effort).

        Failures (ffmpeg missing, or a non-zero exit) are reported and not raised, because
        the PNG frames are already on disk.
        """
        import subprocess  # local import: only this method shells out

        output_folder = Path(output_folder)
        command = [
            "ffmpeg", "-y",
            "-framerate", str(self.cfg.fps),
            "-f", "image2",
            # Match exactly the frames render_set wrote (00000.png, 00001.png, ...). Globbing
            # *.png also swept in unrelated PNGs when output_path was unset and the frames
            # landed next to the point cloud.
            "-start_number", "0",
            "-i", str(output_folder / "%05d.png"),
            "-c:v", "libx264", "-preset", "veryslow", "-qp", "0", "-pix_fmt", "yuv444p",
            str(output_folder / "renders.mp4"),
        ]
        try:
            # Passing an argument list (no shell) keeps paths with spaces or quotes intact.
            # check=True turns a non-zero exit into an exception; os.system never raised,
            # so the old success message printed even when encoding failed.
            subprocess.run(command, check=True)
        except (OSError, subprocess.CalledProcessError) as e:
            print(f"Video creation failed: {str(e)}")
        else:
            print(f"Video saved to {output_folder}/renders.mp4")
            
        
if __name__ == "__main__":
    # Parse CLI flags into a Config, then render every timestep and encode the video.
    config = tyro.cli(Config)
    GaussianRender(config).render_set()