"""
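Live-mode entry point for the Gaussian Head Renderer module described in the paper.

Listens for JSON parameter packets on UDP port 9000, maps them through
ExpressionMapper, and pushes the result into the interactive viewer's FLAME
parameters.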
"""

import socket
import tyro
import json
import threading
from exp_mapper import ExpressionMapper
import socket, json, threading, time, queue
import torch
import sys
from pathlib import Path

import import_helper
from local_viewer import LocalViewer, Config


# Target rate (updates per second) for pushing parameters into the viewer.
UPDATE_FPS = 60
# Corresponding period in seconds; gui_updater uses it both as its queue poll
# timeout and as the minimum spacing between two viewer updates.
UPDATE_INTERVAL = 1 / UPDATE_FPS

def socket_receiver(sock, q):
    """Receive JSON datagrams from ``sock`` forever, keeping only the newest in ``q``.

    Each datagram is decoded as JSON. The following packets are dropped:
    packets that are not valid JSON, packets that are not valid text, and
    packets whose top-level value is not a JSON object. The consumer expects a
    mapping keyed by addresses such as ``"/W"``, so anything else is useless
    to it.

    Before a good message is enqueued, any messages still waiting in ``q`` are
    discarded. The consumer therefore always sees the most recent state rather
    than a growing backlog.

    Args:
        sock: A bound datagram socket; only ``recvfrom`` is used.
        q: A ``queue.Queue`` shared with the consumer thread. Because stale
            items are removed by clearing ``q.queue`` directly, the queue's
            task accounting is bypassed and ``q.join()`` must not be used.

    Exceptions raised by ``sock.recvfrom`` (e.g. ``OSError`` on a closed
    socket) are not caught and end the loop.
    """
    while True:
        # 65536 bytes covers the largest possible UDP/IPv4 payload.
        data, _addr = sock.recvfrom(65536)
        try:
            msg = json.loads(data)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError
            # subclasses: a malformed packet is simply skipped.
            continue
        if not isinstance(msg, dict):
            # Valid JSON, but not an object the consumer can look keys up in.
            continue
        # Drop unconsumed messages so only the newest survives. Queue.put
        # acquires q.mutex itself (non-reentrant), so it runs after release.
        with q.mutex:
            q.queue.clear()
        q.put(msg)

def _map_tracking_message(msg, mapper):
    """Translate one decoded packet into ``{flame_param key: float32 tensor}``.

    Only keys present in ``msg`` produce entries. Any exception (unexpected
    message type, bad payload, mapper or tensor-conversion failure) propagates
    to the caller, so a message is either mapped completely or not at all.
    """
    raw = {}
    if "/W" in msg:
        # The final "/W" entry is excluded before mapping (presumably not an
        # expression weight -- TODO confirm against the sender's format).
        raw["expr"], raw["jaw"] = mapper.get_expr_and_jaw([msg["/W"][:-1]])
    if "/HR" in msg:
        raw["neck"] = mapper.get_head_rotation([msg["/HR"]])
    if "/ER" in msg:
        raw["eyes"] = mapper.get_eye_rotation([msg["/ER"]])
    return {key: torch.tensor(value, dtype=torch.float32) for key, value in raw.items()}


def _push_params_to_gui(gui, state):
    """Copy the cached parameters into ``gui.flame_param`` and request a redraw."""
    for key in ("expr", "jaw", "neck", "eyes"):
        if key in state:
            # Hand the viewer a fresh tensor each push so it never aliases the
            # cached copy.
            gui.flame_param[key] = state[key].clone()
    gui.gaussians.update_mesh_by_param_dict(gui.flame_param)
    gui.need_update = True


def gui_updater(q, gui, mapper, interval=None):
    """Apply tracking messages from ``q`` to the viewer, at most once per ``interval``.

    Messages are mapped through ``mapper`` and merged into a cache of the
    latest ``expr``/``jaw``/``neck``/``eyes`` values. The cache is pushed into
    ``gui.flame_param`` whenever it holds unpushed changes and at least
    ``interval`` seconds have passed since the previous push. A message
    arriving sooner than that is flushed once the interval elapses, even if no
    further message follows, so the final pose is never left unrendered.

    Malformed messages are reported on stderr and skipped instead of
    terminating the thread.

    Args:
        q: Queue of decoded message dicts (as filled by ``socket_receiver``).
        gui: Viewer exposing ``flame_param``, ``gaussians`` and ``need_update``.
        mapper: Object providing ``get_expr_and_jaw``, ``get_head_rotation``
            and ``get_eye_rotation``.
        interval: Minimum seconds between viewer updates; defaults to the
            module-level ``UPDATE_INTERVAL``.

    Runs forever; only exceptions raised by ``q.get`` itself (other than
    ``queue.Empty``) escape.
    """
    if interval is None:
        interval = UPDATE_INTERVAL
    state = {}                    # flame_param key -> latest mapped tensor
    dirty = False                 # state holds values not yet pushed
    last_update = float("-inf")   # so the very first message is applied at once
    while True:
        if dirty:
            # Wake up exactly when the pending update becomes due.
            timeout = max(0.0, last_update + interval - time.monotonic())
        else:
            timeout = interval
        try:
            latest = q.get(timeout=timeout)
        except queue.Empty:
            pass
        else:
            try:
                updates = _map_tracking_message(latest, mapper)
            except Exception as exc:  # thread boundary: report and keep running
                print(
                    f"gui_updater: dropping bad message ({type(exc).__name__}: {exc})",
                    file=sys.stderr,
                )
            else:
                if updates:
                    state.update(updates)
                    dirty = True

        # monotonic() is immune to wall-clock jumps that would stall throttling.
        now = time.monotonic()
        if dirty and now - last_update >= interval:
            # NOTE(review): flame_param is mutated from this worker thread while
            # gui.run() renders on the main thread; confirm LocalViewer is OK with that.
            _push_params_to_gui(gui, state)
            dirty = False
            last_update = now

if __name__ == "__main__":
    cfg = tyro.cli(Config)

    # UDP endpoint the tracking client streams JSON packets to.
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.bind(("0.0.0.0", 9000))

    q = queue.Queue()
    mapper = ExpressionMapper(cfg)
    gui = LocalViewer(cfg)

    # Background workers, started in this order: network receiver first,
    # then the thread that feeds received parameters into the viewer.
    workers = (
        (socket_receiver, (sock, q)),
        (gui_updater, (q, gui, mapper)),
    )
    for target, worker_args in workers:
        threading.Thread(target=target, args=worker_args, daemon=True).start()

    # Blocks in the viewer's own loop; the daemon workers die with the process.
    gui.run()
