import asyncio
import websockets
import json
import torch
import numpy as np
from mano_utils import hand_pose_estimator

# Fits whose final loss exceeds this are rejected as likely caused by poor
# input data (see the error message raised in handle_client).
_MAX_ACCEPTABLE_LOSS = 2200


def _prepare_item(item):
    """Validate one request batch entry and normalise it for the estimator.

    Args:
        item: One element of the request's ``"batch"`` list. Must be a JSON
            object with ``"keypoints_2d"``, ``"K_list"`` and ``"RT_list"``
            (one entry per camera). ``"num_cams"`` is optional and defaults
            to the number of ``keypoints_2d`` entries.

    Returns:
        dict with ``"keypoints_2d"`` (np.ndarray), ``"K_list"`` and
        ``"RT_list"`` (lists of np.ndarray) and ``"num_cams"``. Two-camera
        input is padded to three cameras by repeating camera 0.

    Raises:
        ValueError: if the item is malformed, has one camera or none, or its
            per-camera lists disagree in length with ``num_cams``.
    """
    if not isinstance(item, dict):
        raise ValueError("Each batch item must be a JSON object")
    missing = [key for key in ("keypoints_2d", "K_list", "RT_list") if key not in item]
    if missing:
        raise ValueError(f"Batch item is missing required keys: {missing}")

    num_cams = item.get("num_cams", len(item["keypoints_2d"]))
    keypoints_2d = np.array(item["keypoints_2d"])
    K_list = [np.array(k) for k in item["K_list"]]
    RT_list = [np.array(rt) for rt in item["RT_list"]]

    if num_cams == 1:
        raise ValueError("Single camera detection is not supported")

    # Check the counts against what the client actually sent *before* padding.
    # Otherwise an empty K_list/RT_list crashes with IndexError in the padding
    # step instead of being reported as a clear mismatch.
    if len(keypoints_2d) != num_cams or len(K_list) != num_cams or len(RT_list) != num_cams:
        raise ValueError(f"Camera data count mismatch: "
                         f"keypoints_2d={len(keypoints_2d)}, "
                         f"K_list={len(K_list)}, "
                         f"RT_list={len(RT_list)}, "
                         f"expected={num_cams}")
    if num_cams == 0:
        raise ValueError("No camera data provided")

    if num_cams == 2:
        # Pad to three views by repeating camera 0.
        # NOTE(review): presumably the estimator needs >= 3 views; repeating
        # camera 0 may weight that view twice in the fit — confirm.
        print("Duplicating camera data to make 3 cameras")
        keypoints_2d = np.vstack([keypoints_2d, keypoints_2d[0:1]])
        K_list.append(K_list[0])
        RT_list.append(RT_list[0])
        num_cams = 3

    return {
        "keypoints_2d": keypoints_2d,
        "K_list": K_list,
        "RT_list": RT_list,
        "num_cams": num_cams,
    }


def _to_list(tensor):
    """Detach *tensor*, copy it to host memory and return it as nested lists."""
    return tensor.detach().cpu().numpy().tolist()


async def handle_client(websocket):
    """Serve hand-pose estimation requests on one WebSocket connection.

    Each message must be a JSON object ``{"batch": [item, ...]}``; see
    ``_prepare_item`` for the per-item schema. The whole batch is fitted with
    a single ``estimate_pose_batch`` call. The result is sent back as JSON
    with keys ``joints_3d_batch``, ``betas``, ``global_orient``,
    ``hand_pose``, ``transl`` and ``loss``.

    A message that fails is answered with a JSON object holding ``error``,
    ``type`` ("validation" or "processing") and ``status``. The connection
    stays open afterwards.
    """
    print("Client connected")
    # NOTE(review): a new estimator is built for every connection; if
    # construction loads the MANO model from disk, consider sharing one.
    estimator = hand_pose_estimator(
        model_path="./MANO/mano_v1_2/models",
        batch_size=350
    )
    loop = asyncio.get_running_loop()

    try:
        async for message in websocket:
            try:
                # json.JSONDecodeError is a ValueError subclass, so malformed
                # JSON is reported as a validation error.
                data = json.loads(message)
                batch = data.get("batch") if isinstance(data, dict) else None
                if not isinstance(batch, list) or not batch:
                    raise ValueError("Request must be a JSON object with a non-empty 'batch' list")
                processed_batch = [_prepare_item(item) for item in batch]

                await websocket.ping()

                keypoints_2d_list = [item["keypoints_2d"] for item in processed_batch]
                K_list_list = [item["K_list"] for item in processed_batch]
                RT_list_list = [item["RT_list"] for item in processed_batch]
                num_cams_list = [item["num_cams"] for item in processed_batch]

                # estimate_pose_batch is synchronous. Running it in the default
                # thread pool keeps the event loop free to answer keepalive
                # pings and to serve other connections in the meantime.
                # NOTE(review): separate connections can now fit concurrently
                # (each with its own estimator) — watch GPU memory.
                output, transl, loss = await loop.run_in_executor(
                    None,
                    lambda: estimator.estimate_pose_batch(
                        keypoints_2d_list=keypoints_2d_list,
                        K_list_list=K_list_list,
                        RT_list_list=RT_list_list,
                        num_cams_list=num_cams_list,
                    ),
                )

                # loss may be a Python float, a NumPy scalar or a 0-d tensor
                # (can't tell from here). float() accepts all three and makes
                # the value JSON-serializable.
                loss = float(loss)
                if loss > _MAX_ACCEPTABLE_LOSS:
                    raise ValueError(f"High loss detected: {loss}. This may indicate poor input data quality.")

                response = json.dumps({
                    "joints_3d_batch": _to_list(output.joints),
                    "betas": _to_list(output.betas),
                    "global_orient": _to_list(output.global_orient),
                    "hand_pose": _to_list(output.hand_pose),
                    "transl": _to_list(transl),
                    "loss": loss,
                })
                await websocket.send(response)

            except websockets.exceptions.ConnectionClosed:
                # The socket is gone; don't try to send an error frame on it.
                raise
            except ValueError as e:
                error_msg = f"Data validation error: {str(e)}"
                print(error_msg)
                await websocket.send(json.dumps({
                    "error": error_msg,
                    "type": "validation",
                    "status": "rejected"
                }))
            except Exception as e:
                error_msg = f"Processing error: {str(e)}"
                print(error_msg)
                await websocket.send(json.dumps({
                    "error": error_msg,
                    "type": "processing",
                    "status": "failed"
                }))

    except websockets.exceptions.ConnectionClosed as e:
        print(f"Connection closed: {e}")
    except Exception as e:
        print(f"Unexpected error: {e}")

async def main(host="0.0.0.0", port=8766):
    """Run the hand-pose WebSocket server until this task is cancelled.

    Args:
        host: Interface address to bind (defaults to all IPv4 interfaces).
        port: TCP port to listen on.
    """
    async with websockets.serve(
        handle_client,
        host,
        port,
        ping_interval=20,
        # NOTE(review): 600 s timeouts presumably keep long estimations from
        # tripping the keepalive/close handshake — confirm they are still needed.
        ping_timeout=600,
        close_timeout=600,
        max_size=10 * 1024 * 1024,  # 10 MiB limit on incoming messages
    ):
        # Announce only once the server is bound. If binding fails (e.g. the
        # port is in use), serve() raises before this line runs.
        print(f"WebSocket server started on ws://{host}:{port}")
        await asyncio.Future()  # run forever

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop this server; exit without a traceback.
        print("Server stopped")