import numpy as np
import argparse
import json

class ExpressionMapper:
    """Map ARKit-style blendshape sequences (plus optional head/eye rotations)
    to per-frame FLAME parameters and save them as a compressed ``.npz``.

    The blendshape -> (expression, jaw) mapping is a single linear transform
    loaded with ``np.load`` from ``args.model_path``.
    """

    # Input channels 10..21 are removed from both the blendshape sequence and
    # the transform rows before mapping. NOTE(review): presumably channels the
    # transform should ignore -- confirm against how the transform was fitted.
    _DROPPED_CHANNELS = np.s_[10:22]
    # Leading output columns of the transform that are expression
    # coefficients; all remaining columns are treated as jaw pose.
    _NUM_EXPR = 100
    # Vertex count used for the all-zero ``dynamic_offset`` array.
    _NUM_VERTICES = 5143

    def __init__(self, args):
        """Load the mapping matrix and, if available, the template parameters.

        Args:
            args: object with ``model_path`` and ``template_path`` attributes.
                ``jaw_scale``, ``head_scale`` and ``eyes_scale`` are optional
                and default to 1.4, 2.0 and 0.5 respectively.
        """
        self.args = args
        self.model = np.load(args.model_path)
        try:
            self.template = np.load(args.template_path)
        except Exception as e:
            # Best effort: the template is only needed by save_npz, so the
            # mapping methods stay usable without it. No defaults are
            # substituted; save_npz refuses to run without a template.
            print(f'No template to load ({e}); save_npz will be unavailable.')
            self.template = None

        self.jaw_scale = getattr(args, 'jaw_scale', 1.4)
        self.head_scale = getattr(args, 'head_scale', 2.0)
        self.eyes_scale = getattr(args, 'eyes_scale', 0.5)
        # Constant global rotation / translation written for every frame.
        self.fixed_rotation = np.array([-0.15350625, -0.03199791, -0.0172692], dtype=np.float32)
        self.fixed_translation = np.array([-0.00397511, -0.00149212, -0.00931257], dtype=np.float32)

    def get_expr_and_jaw(self, arkit_seq):
        """Linearly map a blendshape sequence to expression and jaw pose.

        Args:
            arkit_seq: array-like of shape (T, C); C must equal the number of
                rows of the loaded transform (before channel dropping).

        Returns:
            ``(expr, jaw_pose)``: float32 arrays of shape (T, 100) and
            (T, K - 100), where K is the transform's column count. The jaw
            part is multiplied by ``self.jaw_scale``.

        Raises:
            AssertionError: if the input is not 2-D or its channel count does
                not match the transform.
        """
        arkit_seq = np.asarray(arkit_seq, dtype=np.float32)
        assert arkit_seq.ndim == 2, \
            f"Blendshape sequence should have 2D shape (T, C), but got {arkit_seq.shape}"
        assert self.model.shape[0] == arkit_seq.shape[1], \
            f"Transformation matrix and blendshape should have the same number of columns, but got {self.model.shape[0]} and {arkit_seq.shape[1]}"
        # Drop the same channels from the input columns and the transform rows
        # so the matrix product stays aligned. np.delete returns new arrays.
        kept_seq = np.delete(arkit_seq, self._DROPPED_CHANNELS, axis=1)
        trans_mat = np.delete(self.model, self._DROPPED_CHANNELS, axis=0)
        expr_and_jaw = kept_seq @ trans_mat
        expr = expr_and_jaw[:, :self._NUM_EXPR].astype(np.float32)
        jaw_pose = expr_and_jaw[:, self._NUM_EXPR:].astype(np.float32) * self.jaw_scale
        return expr, jaw_pose

    def get_head_rotation(self, rotation):
        """Convert per-frame head rotations to neck pose.

        Args:
            rotation: array-like of shape (T, 3).

        Returns:
            float32 array of shape (T, 3) equal to
            ``rotation / 180 * self.head_scale``.
            NOTE(review): this is not a degrees-to-radians conversion (that
            would be ``* pi / 180``); presumably the input units and
            ``head_scale`` were tuned together -- confirm.

        Raises:
            AssertionError: if ``rotation`` is not of shape (T, 3).
        """
        rotation = np.asarray(rotation, dtype=np.float32)
        assert rotation.ndim == 2 and rotation.shape[1] == 3, "Rotation should have 2D shape (T, 3)"
        return rotation / 180.0 * self.head_scale

    def get_eye_rotation(self, rotation):
        """Convert per-frame eye rotations to eyes pose.

        Args:
            rotation: array-like of shape (T, 6): two 3-vectors, one per eye.

        Returns:
            float32 array of shape (T, 6): ``rotation / 100``, with the second
            component of each eye's 3-vector (indices 1 and 4, the y-axis
            rotation) further multiplied by ``self.eyes_scale``.

        Raises:
            AssertionError: if ``rotation`` is not of shape (T, 6).
        """
        rotation = np.asarray(rotation, dtype=np.float32)
        assert rotation.ndim == 2 and rotation.shape[1] == 6, "Rotation should have 2D shape (T, 6)"
        # The division yields a new array, so the in-place scaling below never
        # modifies the caller's data.
        eyes_pose = rotation / 100.0
        eyes_pose[:, [1, 4]] *= self.eyes_scale  # scale down the y-axis rotation
        return eyes_pose

    def save_npz(self, arkit_seq, output_path, rotations=None):
        """Map a blendshape sequence to FLAME parameters and write them to disk.

        Args:
            arkit_seq: blendshape sequence of shape (T, C); see get_expr_and_jaw.
            output_path: destination passed to ``np.savez_compressed``.
            rotations: optional dict with ``"neck_pose"`` (T, 3) and
                ``"eyes_pose"`` (T, 6) sequences. When omitted, both poses
                are written as zeros.

        Raises:
            RuntimeError: if no template was loaded in ``__init__``. Its
                ``shape`` and ``static_offset`` entries are copied into the
                output.
            AssertionError: if the rotation and blendshape sequences differ in
                length, or any input has the wrong shape.
        """
        if self.template is None:
            raise RuntimeError(
                "save_npz requires a template providing 'shape' and 'static_offset', "
                "but none was loaded; check template_path."
            )
        expr, jaw_pose = self.get_expr_and_jaw(arkit_seq)
        num_frames = expr.shape[0]
        if rotations is not None:
            assert len(arkit_seq) == len(rotations["neck_pose"]) == len(rotations["eyes_pose"]), \
                "Blendshape and rotation should have the same length"
            neck_pose = self.get_head_rotation(rotations["neck_pose"])
            eyes_pose = self.get_eye_rotation(rotations["eyes_pose"])
        else:
            neck_pose = np.zeros((num_frames, 3), dtype=np.float32)
            eyes_pose = np.zeros((num_frames, 6), dtype=np.float32)

        np.savez_compressed(
            output_path,
            shape=self.template['shape'],
            static_offset=self.template['static_offset'],
            expr=expr,
            jaw_pose=jaw_pose,
            rotation=np.tile(self.fixed_rotation, (num_frames, 1)),
            neck_pose=neck_pose,
            eyes_pose=eyes_pose,
            translation=np.tile(self.fixed_translation, (num_frames, 1)),
            dynamic_offset=np.zeros((num_frames, self._NUM_VERTICES, 3), dtype=np.float32),
        )
        print(f"Saved new flame params with {num_frames} frames to: {output_path}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input_path', type=str, default='data/arkit_record.json', help='Blendshape sequence file')
    parser.add_argument('-t', '--template_path', type=str, default='media/306/flame_param.npz', help='canonical flame model sequence file')
    # There is no usable default: ExpressionMapper always np.load()s this path.
    parser.add_argument('-m', '--model_path', type=str, required=True,
                        help='Path to the blendshape-to-FLAME transformation matrix (loaded with np.load).')
    parser.add_argument('-o', '--output_path', type=str, default='data/306/test.npz', help='Output flame params file')

    parser.add_argument('--jaw_scale', type=float, default=1.4, help='Jaw rotation scaling factor')
    parser.add_argument('--head_scale', type=float, default=2.0, help='Head rotation scaling factor')
    parser.add_argument('--eyes_scale', type=float, default=0.5, help='Eyes rotation scaling factor')
    args = parser.parse_args()

    # Reject unsupported inputs before loading anything, instead of silently
    # doing nothing.
    if not args.input_path.endswith(('.json', '.npy')):
        parser.error(f"unsupported input file {args.input_path!r}: expected a .json or .npy file")

    mapper = ExpressionMapper(args)

    if args.input_path.endswith('.json'):
        rotations = {"neck_pose": [], "eyes_pose": []}
        arkit_seq = []
        with open(args.input_path, 'r', encoding='utf-8') as f:
            records = json.load(f)
        for d in records:
            arkit_seq.append(d['/W'][:-1])  # drop the last weight: 52 --> 51
            rotations["neck_pose"].append(d['/HR'])
            # Append a zero third component after each eye's rotation to form
            # the 6 values get_eye_rotation expects (presumably 2 values per
            # eye in the recording -- confirm).
            rotations["eyes_pose"].append(d["/ELR"] + [0] + d["/ERR"] + [0])
        mapper.save_npz(arkit_seq, args.output_path, rotations)
    else:
        arkit_seq = np.load(args.input_path)
        mapper.save_npz(arkit_seq, args.output_path)