import torch
import mano
import numpy as np
import os
import json
import cv2

class hand_pose_estimator:
    """Fit a right-hand MANO model to 2D keypoints seen by several calibrated cameras.

    The PCA hand pose, global orientation and translation are instance
    attributes that the optimizer updates in place, so each call to
    :meth:`estimate_pose` warm-starts from the result of the previous call.
    The number of optimisation iterations per call is controlled by
    :class:`IterationScheduler`.
    """

    # Initial PCA pose coefficients. Only the first ``n_comps`` entries are
    # used; any additional components start at zero.
    _DEFAULT_PCA_POSE = (0.2895, 0.6036, 0.0153, -1.0968, 0.7020, -0.1357)

    class IterationScheduler:
        """Linearly decay a per-call iteration budget from ``max_iter`` to ``min_iter``.

        Call ``k`` (0-based) returns
        ``max(max_iter - min(k, decay_step) * (max_iter - min_iter) / decay_step, min_iter)``,
        so the budget reaches ``min_iter`` after ``decay_step`` calls and stays there.
        A non-positive ``decay_step`` means "decay immediately" (always ``min_iter``).
        """

        def __init__(self, max_iter=5000, min_iter=500, decay_step=5):
            self.max_iter = max_iter
            self.min_iter = min_iter
            self.current_frame = 0
            self.decay_step = decay_step
            self.current_iter = max_iter  # budget most recently handed out (starts at the maximum)

        def get_iteration_num(self):
            """Return the iteration budget for this call and advance the call counter."""
            if self.decay_step > 0:
                step = (self.max_iter - self.min_iter) / self.decay_step
                # Derive the budget from the frame index rather than subtracting
                # ``frame * step`` cumulatively, which decays quadratically.
                frames = min(self.current_frame, self.decay_step)
                budget = self.max_iter - frames * step
            else:
                budget = self.min_iter
            # Never go below the minimum (also covers max_iter < min_iter).
            self.current_iter = max(budget, self.min_iter)
            self.current_frame += 1
            return int(self.current_iter)

    def __init__(self, model_path, n_comps=6, batch_size=10, device=None):
        """Load the right-hand MANO model and initialise the fitted parameters.

        Args:
            model_path: path passed to ``mano.load``.
            n_comps: number of PCA pose components used by the model and optimised.
            batch_size: batch size passed to ``mano.load``.
            device: torch device; defaults to CUDA when available, else CPU.
        """
        self.model_path = model_path
        self.n_comps = n_comps
        self.batch_size = batch_size
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.rh_model = mano.load(model_path=self.model_path,
                                  is_rhand=True,
                                  num_pca_comps=self.n_comps,
                                  batch_size=self.batch_size,
                                  flat_hand_mean=False).to(self.device)

        # Shape coefficients are never handed to the optimizer, so they stay at zero.
        self.betas = torch.zeros((1, 10), device=self.device)
        # Pose must have exactly ``n_comps`` PCA coefficients to match the model;
        # seed the leading ones from the default pose and zero the rest.
        pose = torch.zeros((1, self.n_comps), device=self.device)
        n_init = min(self.n_comps, len(self._DEFAULT_PCA_POSE))
        pose[0, :n_init] = torch.tensor(self._DEFAULT_PCA_POSE[:n_init], device=self.device)
        self.pose = pose.requires_grad_(True)
        # NOTE(review): random initialisation makes the first fit non-deterministic.
        self.global_orient = torch.rand((1, 3), requires_grad=True, device=self.device)
        self.transl = torch.rand((1, 3), requires_grad=True, device=self.device)
        self.iteration_scheduler = self.IterationScheduler()

    def _forward_joints(self):
        """Run the MANO model on the current parameters and return its joints."""
        output = self.rh_model.forward(betas=self.betas,
                                       global_orient=self.global_orient,
                                       hand_pose=self.pose,
                                       transl=self.transl,
                                       return_tips=True,
                                       return_verts=True)
        return output.joints  # [1, 21, 3] per the original annotation

    def estimate_pose(self, keypoints_2d, K_list, RT_list, num_cams=2):
        """Fit pose/orientation/translation to multi-view 2D keypoints.

        Minimises the sum over cameras of the mean squared reprojection error
        with Adam (lr=0.4). The iteration count comes from the scheduler.

        Args:
            keypoints_2d: array-like indexed by camera; ``keypoints_2d[cam]``
                must broadcast against the projected joints (``[21, 2]`` per the
                original annotation).
            K_list, RT_list: per-camera matrices; ``K_list[cam] @ RT_list[cam]``
                must be a 3x4 projection matrix (presumably 3x3 intrinsics and
                3x4 extrinsics).
            num_cams: number of leading cameras to use.

        Returns:
            Joint positions recomputed (without gradient tracking) from the
            final, optimised parameters.

        Raises:
            ValueError: if ``num_cams`` is smaller than 1.
        """
        if num_cams < 1:
            raise ValueError(f"num_cams must be at least 1, got {num_cams}")
        optimizer = torch.optim.Adam([self.pose, self.global_orient, self.transl], lr=0.4)
        keypoints_2d_torch = torch.as_tensor(keypoints_2d, dtype=torch.float32, device=self.device)  # [num_cams, 21, 2]
        # The camera matrices do not change during the fit; build them once.
        projections = [self._projection_matrix(K_list[cam], RT_list[cam], self.device)
                       for cam in range(num_cams)]

        iteration_num = self.iteration_scheduler.get_iteration_num()
        print(f"Optimize for {iteration_num} iterations.")
        for step in range(iteration_num):
            optimizer.zero_grad()
            joints_3d = self._forward_joints()
            loss = sum(
                ((self._project_with_matrix(joints_3d, P)[0] - keypoints_2d_torch[cam]) ** 2).mean()
                for cam, P in enumerate(projections)
            )
            loss.backward()
            optimizer.step()
            if step % 50 == 0:
                print(f'Iter {step}, Loss: {loss.item()}')
        print(self.pose)

        # Joints computed inside the loop predate the last optimizer.step() (and
        # do not exist at all for 0 iterations), so re-evaluate the model on the
        # final parameters.
        with torch.no_grad():
            return self._forward_joints()

    @staticmethod
    def _projection_matrix(K, RT, device):
        """Return ``K @ RT`` as a float32 tensor on *device*."""
        # Multiply in NumPy first, exactly as the original pipeline did, then downcast.
        return torch.from_numpy(np.asarray(K) @ np.asarray(RT)).float().to(device)

    @staticmethod
    def _project_with_matrix(points_3d, P):
        """Project ``[B, N, 3]`` points with a 3x4 matrix *P*; returns ``[B, N, 2]``."""
        B, N, _ = points_3d.shape
        ones = torch.ones(B, N, 1, device=points_3d.device, dtype=points_3d.dtype)
        points_3d_h = torch.cat([points_3d, ones], dim=-1)  # [B, N, 4]
        P = P.to(device=points_3d.device, dtype=points_3d.dtype)
        proj = torch.einsum('ij,bnj->bni', P, points_3d_h)  # [B, N, 3]
        # Perspective divide; points with zero depth produce inf/nan.
        return proj[..., :2] / proj[..., 2:3]

    @staticmethod
    def project_points(points_3d, K, RT):
        """Project 3D points through ``K @ RT`` into pixel coordinates.

        Args:
            points_3d: tensor of shape ``[B, N, 3]``.
            K, RT: array-likes whose product is a 3x4 projection matrix.

        Returns:
            Tensor of shape ``[B, N, 2]`` with the perspective-divided x, y.
        """
        P = hand_pose_estimator._projection_matrix(K, RT, points_3d.device)  # [3, 4]
        return hand_pose_estimator._project_with_matrix(points_3d, P)