import functools
import glob
import os
import os.path as osp
import pdb
import sys

import numpy as np

sys.path.append(os.getcwd())

import torch

from smpl_sim.smpllib.smpl_parser import (
    SMPL_Parser,
    SMPLH_Parser,
    SMPLX_Parser,
)

import argparse
from scipy.ndimage import uniform_filter1d

def load_motion(data_path):
    """Load a motion ``.npz`` archive and keep its body pose.

    Args:
        data_path: path to an ``.npz`` archive expected to contain
            ``mocap_framerate``, ``trans``, ``poses``, ``betas`` and ``gender``.

    Returns:
        ``None`` when the archive has no ``mocap_framerate`` entry. Otherwise a
        dict with:

        * ``pose_aa``: the first 66 columns of ``poses`` followed by six zero
          columns (72 values per frame).
        * ``gender``, ``trans``, ``betas``: copied unchanged from the archive.
        * ``fps``: the stored ``mocap_framerate``.
    """
    # NOTE(review): allow_pickle=True lets a crafted .npz execute code while it
    # is loaded. Only point this at trusted datasets.
    # The NpzFile context manager closes the underlying file once every array
    # has been materialised by dict().
    with np.load(data_path, allow_pickle=True) as archive:
        entry_data = dict(archive)

    if 'mocap_framerate' not in entry_data:
        return None

    poses = entry_data['poses']
    # Keep 66 pose values and append 6 zeros (72 per frame). Presumably these
    # are 22 body joints plus two zeroed hand joints of SMPL's 24-joint layout;
    # confirm. The padding is sized from the poses themselves.
    pose_aa = np.concatenate([poses[:, :66], np.zeros((poses.shape[0], 6))], axis=-1)
    return {
        "pose_aa": pose_aa,
        "gender": entry_data['gender'],
        "trans": entry_data['trans'],
        "betas": entry_data['betas'],
        "fps": entry_data['mocap_framerate'],
    }

def _foot_contact(positions, joint_ids, velfactor, heightfactor):
    """Return an (N, 1) float64 contact flag for one foot from its joint tracks.

    A joint is in contact at frame t (t >= 1) when two conditions hold:

    * its squared displacement from frame t-1 is below ``velfactor``;
    * its height (axis 2) at frame t is below ``heightfactor``.

    Both thresholds are per joint and broadcast over ``joint_ids``. The foot is
    in contact when any of its joints is. Frame 0 has no predecessor and is
    always marked as contact. An empty track yields an empty (0, 1) array.
    """
    joints = positions[:, joint_ids]  # indexed [frame, joint, axis]
    if joints.shape[0] == 0:
        return np.zeros((0, 1))
    step = joints[1:] - joints[:-1]
    sq_disp = step[..., 0] ** 2 + step[..., 1] ** 2 + step[..., 2] ** 2
    height = joints[1:, :, 2]
    contact = ((sq_disp < velfactor) & (height < heightfactor)).astype(np.float32)
    # The float64 first row promotes the result to float64, matching the
    # historical output dtype of foot_detect.
    contact = np.concatenate([np.ones((1, len(joint_ids))), contact], axis=0)
    return np.max(contact, axis=1, keepdims=True)


def foot_detect(positions, thres=0.002):
    """Label per-frame foot contact from joint trajectories.

    Args:
        positions: joint positions indexed ``[frame, joint, axis]``, with axis 2
            treated as height. This can be a torch tensor (any device, with or
            without grad) or anything ``np.asarray`` accepts. Joints 7 and 10
            are used for the left foot and 8 and 11 for the right foot
            (presumably ankle and toe in the SMPL skeleton).
        thres: upper bound on the squared frame-to-frame displacement for a
            joint to count as stationary.

    Returns:
        ``(feet_l, feet_r)``: two (N, 1) float64 arrays of 0/1 contact flags.
        The height limits are 0.2 for the first joint of each foot and 0.17 for
        the second. Frame 0 is always marked as contact.
    """
    if isinstance(positions, torch.Tensor):
        positions = positions.detach().cpu().numpy()
    else:
        positions = np.asarray(positions)
    velfactor = np.array([thres, thres])
    heightfactor = np.array([0.2, 0.17])
    feet_l = _foot_contact(positions, [7, 10], velfactor, heightfactor)
    feet_r = _foot_contact(positions, [8, 11], velfactor, heightfactor)
    return feet_l, feet_r

def EMA_smooth(trans, alpha=0.3):
    """Exponentially smooth ``trans`` along its first axis.

    ``ema[0] = trans[0]`` and ``ema[i] = alpha * trans[i] + (1 - alpha) * ema[i - 1]``,
    so a larger ``alpha`` follows the input more closely.

    Args:
        trans: array-like whose first axis is time. This can be anything
            ``np.zeros_like`` accepts, including a CPU torch tensor.
        alpha: weight given to the current sample.

    Returns:
        A new NumPy array with the same shape as ``trans``. Integer or bool
        input produces float64, because an integer accumulator would truncate
        the average at every step. Empty input returns an empty array.
    """
    ema = np.zeros_like(trans)
    if not np.issubdtype(ema.dtype, np.inexact):
        ema = ema.astype(np.float64)
    if len(ema) == 0:
        return ema
    ema[0] = trans[0]
    for i in range(1, len(ema)):
        ema[i] = alpha * trans[i] + (1 - alpha) * ema[i - 1]
    return ema

def moving_average(data, window_size=5):
    """Smooth ``data`` along axis 0 with a centred box filter.

    An even ``window_size`` is bumped up by one so the window has a middle
    sample. Edges use ``mode="nearest"``, which repeats the boundary samples.
    """
    if window_size % 2 == 0:
        window_size += 1
    return uniform_filter1d(data, size=window_size, axis=0, mode="nearest")

def correct_motion(contact_mask, verts, trans):
    """Ground the root translation on contact frames, then smooth its height.

    Args:
        contact_mask: (N, K) array. A frame counts as in contact when any of
            its columns is non-zero.
        verts: torch tensor of mesh vertices indexed ``[frame, vertex, axis]``,
            with axis 2 treated as height.
        trans: (N, 3) torch tensor of root translations; column 2 is the
            height. It is not modified.

    Returns:
        A new tensor equal to ``trans``, with two changes:

        * On contact frames, the height is lowered by that frame's lowest
          vertex height.
        * The whole height column is then passed through ``EMA_smooth``.
    """
    contact_frames = np.flatnonzero(np.any(np.asarray(contact_mask) != 0, axis=1))
    # Work on a copy. The caller's tensor may share memory with its source
    # NumPy array (torch.from_numpy), so in-place edits would leak back there.
    corrected = trans.clone()
    # Shift each contact frame down so its lowest vertex lands at height 0.
    lowest = torch.min(verts[contact_frames, :, 2], dim=1).values
    corrected[contact_frames, 2] -= lowest
    # Smooth the full height track to blend the shifted and unshifted frames.
    corrected[:, 2] = torch.from_numpy(EMA_smooth(corrected[:, 2]))
    return corrected

@functools.lru_cache(maxsize=1)
def _neutral_smpl_parser():
    """Return the neutral SMPL parser, loading the model from data/smpl on first use only."""
    return SMPL_Parser(model_path="data/smpl", gender="neutral")


def process_motion(motion, correction):
    """Downsample a motion towards 30 fps and attach per-frame foot-contact labels.

    Args:
        motion: dict shaped like ``load_motion``'s output, with ``pose_aa``,
            ``trans``, ``betas``, ``gender`` and ``fps`` entries as NumPy
            arrays.
        correction: when truthy, the root height is grounded on contact frames
            via ``correct_motion``.

    Returns:
        ``None`` if fewer than 10 frames remain after downsampling. Otherwise a
        dict with:

        * ``pose_aa``: downsampled pose, float32 tensor.
        * ``trans``: downsampled root translation, tensor.
        * ``betas``: tensor with a leading batch axis.
        * ``gender``: passed through unchanged.
        * ``contact_mask``: (N, 2) array; column 0 is the left foot and
          column 1 the right foot.
        * ``fps``: the effective frame rate after downsampling (input fps
          divided by the stride).
    """
    fps = float(motion['fps'])
    # Integer stride that brings the capture rate down towards 30 fps. Clamp
    # it to 1: sources slower than 30 fps would otherwise get stride 0, and
    # the slices below would raise ValueError.
    skip = max(1, int(fps // 30))
    trans = torch.from_numpy(motion['trans'][::skip])
    gender = motion['gender']
    pose_aa = torch.from_numpy(motion['pose_aa'][::skip]).float()
    betas = torch.from_numpy(motion['betas']).unsqueeze(0)
    N = trans.shape[0]

    if N < 10:
        print("too short")
        return None

    # Cached: loading the body model once per process rather than per file.
    smpl_parser_n = _neutral_smpl_parser()
    with torch.no_grad():
        origin_verts, origin_global_trans = smpl_parser_n.get_joints_verts(pose_aa, betas, trans)
        # Shift heights so the lowest vertex of the first frame sits at 0
        # before thresholding foot heights.
        origin_global_trans[..., 2] -= origin_verts[0, :, 2].min().item()
        feet_l, feet_r = foot_detect(origin_global_trans)
        contact_mask = np.concatenate([feet_l, feet_r], axis=-1)

        if correction:
            trans = correct_motion(contact_mask, origin_verts, trans)

    return {
        "pose_aa": pose_aa,
        "gender": gender,
        "trans": trans,
        "betas": betas,
        "contact_mask": contact_mask,
        # Report the rate actually produced: e.g. 50 fps input gives stride 1,
        # so the output is still 50 fps, not 30.
        "fps": fps / skip,
    }
    

if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` would treat any non-empty string, including
        "False", as True.
        """
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--folder', type=str, required=True)
    # nargs="?" with const=True: a bare `--correction` enables correction,
    # while the old `--correction True` / `--correction False` forms still work.
    parser.add_argument('--correction', type=_str2bool, nargs='?', const=True, default=False)
    args = parser.parse_args()
    folder = args.folder
    correction = args.correction

    # normpath strips a trailing separator, so "dir/" maps to "dir_contact_mask"
    # next to the input rather than "dir/_contact_mask" inside it.
    output_folder = os.path.normpath(folder) + '_contact_mask'
    os.makedirs(output_folder, exist_ok=True)

    for filename in sorted(os.listdir(folder)):
        if not filename.endswith('.npz'):
            continue
        motion = load_motion(os.path.join(folder, filename))
        if motion is None:
            print(f"skipping {filename}: no mocap_framerate")
            continue
        processed_motion = process_motion(motion, correction)
        if processed_motion is None:
            print(f"skipping {filename}")
            continue
        np.savez(os.path.join(output_folder, filename),
                 trans=processed_motion['trans'],
                 poses=processed_motion['pose_aa'],
                 gender=processed_motion['gender'],
                 mocap_framerate=processed_motion['fps'],
                 contact_mask=processed_motion['contact_mask'],
                 betas=processed_motion['betas'])
