import os
import glob
import shutil
from argparse import ArgumentParser
import numpy as np
import torch
import pickle as pkl
from plyfile import PlyData, PlyElement
from scipy.spatial.transform import Rotation, RotationSpline
import numpy
from PIL import Image, ImageOps
from pytorch3d.io import load_obj, save_ply
import cv2
import json
from tqdm import tqdm



# Input: P 3x4 numpy matrix
# Output: K, R, T such that P = K*[R | T], det(R) positive and K has positive diagonal
#
# Reference implementations: 
#   - Oxford's visual geometry group matlab toolbox 
#   - Scilab Image Processing toolbox
def KRT_from_P(P):
    """Decompose a 3x4 projection matrix into intrinsics, rotation and translation.

    The left 3x3 block of ``P`` is RQ-decomposed with ``rf_rq``. ``K`` is then
    normalised so that its bottom-right entry is 1 and its diagonal is
    positive, and ``R`` is sign-corrected so that ``det(R) >= 0``. Because of
    that normalisation, ``K @ [R | T]`` reproduces ``P`` up to a scale factor.

    Args:
        P: 3x4 numpy array (projection matrix).

    Returns:
        Tuple ``(K, R, T)``: 3x3 upper-triangular ``K``, 3x3 ``R`` and
        length-3 ``T``.
    """
    N = 3
    H = P[:, 0:N]  # left 3x3 block; if not numpy, H = P.to_3x3()

    K, R = rf_rq(H)
    K /= K[-1, -1]

    # from http://ksimek.github.io/2012/08/14/decompose/
    # make the diagonal of K positive
    signs = numpy.sign(numpy.diag(K))
    # A zero diagonal entry has sign 0, which would zero a whole column of K
    # and a whole row of R; leave such entries untouched instead.
    signs[signs == 0] = 1
    sg = numpy.diag(signs)

    K = K @ sg
    R = sg @ R
    # det(R) negative, just invert; the proj equation remains same:
    if numpy.linalg.det(R) < 0:
        R = -R
    # Camera centre, C = -H\P[:,-1]. rcond=None selects NumPy's current
    # default cut-off explicitly and avoids the FutureWarning of NumPy 1.x.
    C = numpy.linalg.lstsq(-H, P[:, -1], rcond=None)[0]
    T = -R @ C
    return K, R, T

# This function is borrowed from IDR: https://github.com/lioryariv/idr
def load_K_Rt_from_P(filename, P=None):
    """Split a projection matrix into 4x4 intrinsics and a 4x4 pose matrix.

    Args:
        filename: Text file holding the matrix, one row per line with values
            separated by single spaces. Only read when ``P`` is None. If the
            file has exactly four lines, the first one is skipped.
        P: Optional projection matrix; when given, ``filename`` is ignored.

    Returns:
        Tuple ``(intrinsics, pose)``. ``intrinsics`` is a 4x4 identity whose
        top-left 3x3 block is the camera matrix scaled so ``K[2, 2] == 1``.
        ``pose`` is a 4x4 float32 matrix holding the transposed rotation from
        ``cv2.decomposeProjectionMatrix`` and its de-homogenised translation
        vector.
    """
    if P is None:
        # Use a context manager so the file handle is closed deterministically.
        with open(filename) as matrix_file:
            lines = matrix_file.read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        # Keep the first four space-separated tokens of every row.
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    out = cv2.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    # t is a 4x1 homogeneous vector; divide by its last component.
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose

# RQ decomposition of a numpy matrix, using only libs that already come with
# blender by default
#
# Author: Ricardo Fabbri
# Reference implementations: 
#   Oxford's visual geometry group matlab toolbox 
#   Scilab Image Processing toolbox
#
# Input: 3x4 numpy matrix P
# Returns: numpy matrices r,q
def rf_rq(P):
    """RQ-decompose ``P`` so that ``r @ q`` equals ``P``.

    NumPy only ships a QR routine, so the RQ factors are obtained by running
    QR on the transposed matrix with both axes reversed and then undoing the
    transpose and the reversal on each factor.

    Returns:
        Tuple ``(r, q)``; ``q`` is sign-adjusted so that ``det(q)`` is not
        negative.
    """
    # numpy only provides qr. Scipy has rq but doesn't ship with blender
    reversed_transpose = P.T[::-1, ::-1]
    orthogonal, triangular = numpy.linalg.qr(reversed_transpose, 'complete')

    q = orthogonal.T[::-1, ::-1]
    r = triangular.T[::-1, ::-1]

    # Negating column 0 of r together with row 0 of q leaves r @ q unchanged
    # while flipping the sign of det(q).
    if numpy.linalg.det(q) < 0:
        r[:, 0] *= -1
        q[0, :] *= -1

    return r, q

def get_camera_params(camera):
    """Decompose a projection matrix and return its rotation as a quaternion.

    Args:
        camera: Projection matrix accepted by ``KRT_from_P``.

    Returns:
        Tuple ``(K, quat, T)`` where ``K`` and ``T`` come straight from
        ``KRT_from_P`` and ``quat`` is the rotation matrix converted with
        ``Rotation.as_quat()`` (SciPy's default component order).
    """
    intrinsics, rotation_matrix, translation = KRT_from_P(camera)
    quaternion = Rotation.from_matrix(rotation_matrix).as_quat()
    return intrinsics, quaternion, translation

def get_concat(im1, im2):
    """Return a new RGB image with ``im1`` on the left and ``im2`` on the right.

    The canvas is as wide as both images together and as tall as ``im1``;
    both images are pasted with their top edge at y = 0.
    """
    left_width = im1.width
    canvas = Image.new('RGB', (left_width + im2.width, im1.height))
    for tile, x_offset in ((im1, 0), (im2, left_width)):
        canvas.paste(tile, (x_offset, 0))
    return canvas


def _load_projection_matrices(cams_path):
    """Read the camera JSON and build one 3x4 matrix per camera id.

    The JSON must provide ``"world_2_cam"`` (a mapping from integer-like
    string ids to 4x4 matrices) and ``"intrinsics"`` (assigned into a 3x3
    block, so presumably a 3x3 matrix).

    Returns:
        Tuple ``(frames, matrices)``: the camera ids as an integer array in
        ascending order, and the matching matrices stacked along axis 0.
    """
    with open(cams_path) as cams_file:
        cams_params = json.load(cams_file)
    world_2_cam = cams_params["world_2_cam"]

    # The intrinsics are shared by every camera, so scale them once:
    # fx, fy, cx and cy are all divided by 4.
    intrinsics = np.eye(4)
    intrinsics[0:3, 0:3] = np.array(cams_params["intrinsics"])
    intrinsics[0, 0] /= 4  # fx
    intrinsics[1, 1] /= 4  # fy
    intrinsics[0, 2] /= 4  # cx
    intrinsics[1, 2] /= 4  # cy

    frames = []
    matrices = []
    # Sort numerically but keep the original key for the lookup, so ids such
    # as "007" still resolve.
    for key in sorted(world_2_cam, key=int):
        frames.append(int(key))
        pose = np.array(world_2_cam[key])  # (4, 4)
        w2c = np.eye(4)
        w2c[0:3, 0:3] = pose[0:3, 0:3]
        w2c[0:3, 3] = -pose[0:3, 0:3] @ pose[0:3, 3]
        # NOTE(review): the intrinsics are multiplied with the *inverse* of
        # the matrix assembled above; confirm this matches the convention
        # of the "world_2_cam" entries expected by render_color.py.
        c2w = np.linalg.inv(w2c)
        projection_matrix = intrinsics @ c2w
        matrices.append(projection_matrix[:3])

    return np.array(frames), np.stack(matrices)


def main(blender_path, input_path, cams_path, exp_name_1, exp_name_3, strand_length, speed_up, max_frames):
    """Export cameras, head meshes and hair strands, then render each frame.

    Everything is written below
    ``{input_path}/curves_reconstruction/{exp_name_3}/blender``. For every
    ``*.ply`` file in the sibling ``strands`` directory, the matching head
    mesh and strand array are exported and Blender is run in background mode
    with ``main.blend`` and ``render_color.py`` (both resolved by the shell
    relative to the current working directory).

    Args:
        blender_path: Blender executable, inserted verbatim into the command.
        input_path: Root directory of the dataset.
        cams_path: JSON file with ``"world_2_cam"`` and ``"intrinsics"``.
        exp_name_1: Not used by this function.
        exp_name_3: Experiment sub-directory under ``curves_reconstruction``.
        strand_length: Number of points per strand used to reshape the
            strand vertices.
        speed_up: Not used by this function.
        max_frames: Not used by this function.
    """
    frames, cameras_interp = _load_projection_matrices(cams_path)

    blender_dir = f'{input_path}/curves_reconstruction/{exp_name_3}/blender'
    # Create the output tree before anything is saved into it; np.save does
    # not create missing parent directories.
    for sub_dir in ('images', 'raw_head', 'head', 'hair'):
        os.makedirs(f'{blender_dir}/{sub_dir}', exist_ok=True)

    np.save(f'{blender_dir}/frames.npy', frames)
    np.save(f'{blender_dir}/cameras.npy', cameras_interp)
    # Random permutation of 0..19999, saved for the render script.
    hair_index = np.random.choice(20000, 20000, replace=False)
    np.save(f'{blender_dir}/hair_index.npy', hair_index)

    # Sorted so that frames are processed in a reproducible order.
    strands_ply_paths = sorted(glob.glob(f'{input_path}/curves_reconstruction/{exp_name_3}/strands/*.ply'))
    for ply_path in tqdm(strands_ply_paths):
        basename = os.path.basename(ply_path)
        tqdm.write(basename)
        # File names look like "<prefix>_<frame>[...].ply".
        frame_index = int(basename.split('.')[0].split('_')[1])
        mesh_index = frame_index - 1  # mesh files are numbered one below the strand files
        frame_name = f'frame_{frame_index:05d}'
        raw_head_path = f'{blender_dir}/raw_head/{frame_name}.ply'
        head_path = f'{blender_dir}/head/{frame_name}.ply'
        hair_path = f'{blender_dir}/hair/{frame_name}.npy'

        # Convert the head mesh from OBJ to PLY.
        verts, faces, _ = load_obj(f'{input_path}/flame_fitting/raw_data/eval_30/mesh/frame_{mesh_index:05d}.obj')
        save_ply(raw_head_path, verts=verts, faces=faces.verts_idx)

        # Re-write the PLY with x/y/z stored as float64 vertex properties,
        # keeping the second element of the raw file unchanged.
        head_ply = PlyData.read(raw_head_path)
        head_data = head_ply.elements[0].data
        head_xyz = np.stack([head_data['x'], head_data['y'], head_data['z']], axis=1)
        head_vertex = np.array(
            [tuple(vertex) for vertex in head_xyz.tolist()],
            dtype=[('x', float), ('y', float), ('z', float)],
        )
        head_new_ply = PlyData([PlyElement.describe(head_vertex, 'vertex'), head_ply.elements[1]])
        head_new_ply.write(head_path)

        # Strand points are stored as (x, -z, y) and grouped per strand.
        strands_ply = PlyData.read(ply_path).elements[0].data
        strands_npy = np.stack(
            [strands_ply['x'], -strands_ply['z'], strands_ply['y']], axis=1
        ).reshape(-1, strand_length, 3)
        tqdm.write(str(strands_npy.shape))
        np.save(hair_path, strands_npy)

        # NOTE(review): the command is passed to the shell unquoted, so paths
        # containing spaces or shell metacharacters will break it.
        status = os.system(
            f'{blender_path} -b main.blend -P render_color.py -- --args '
            f'{blender_dir}/cameras.npy '
            f'{head_path} '
            f'{hair_path} '
            f'{blender_dir}/images '
            f'{blender_dir}/frames.npy '
            f'{frame_name} 256 '
            f'{blender_dir}/hair_index.npy'
        )
        if status != 0:
            # Keep going with the remaining frames, but do not hide the failure.
            tqdm.write(f'blender exited with status {status} for {frame_name}')
    


if __name__ == "__main__":
    # Command-line entry point: collect the options and forward them to main().
    parser = ArgumentParser(conflict_handler='resolve')

    parser.add_argument('--blender_path', default='/home/ezakharov/Libraries/blender-3.6.11-linux-x64/blender', type=str)
    parser.add_argument('--input_path', default='/home/ezakharov/Datasets/hair_reconstruction/NeuralHaircut/jenya', type=str)
    parser.add_argument('--exp_name_1', default='stage1_lor=0.1', type=str)
    parser.add_argument('--exp_name_3', default='stage3_lor=0.1', type=str)
    parser.add_argument('--strand_length', default=100, type=int)
    parser.add_argument('--speed_up', default=1, type=int)
    parser.add_argument('--max_frames', default=200, type=int)
    parser.add_argument('--cams_path', default="", type=str)

    # A preceding parse_known_args() call was redundant: its result was
    # overwritten and parse_args() rejects unknown arguments anyway.
    args = parser.parse_args()

    # Fail with a usage message instead of a traceback from open('').
    if not args.cams_path:
        parser.error('--cams_path is required')

    main(args.blender_path, args.input_path, args.cams_path, args.exp_name_1, args.exp_name_3, args.strand_length, args.speed_up, args.max_frames)