import os
import json
import argparse
import numpy as np
import soundfile as sf
from scipy.interpolate import interp1d


def load_arkit_data(arkit_path):
    """Read a raw ARKit capture and split it into per-channel arrays.

    The file is expected to hold a JSON list of frame records, each with a
    'timestamp' entry plus the channels "/W", "/HR", "/ELR" and "/ERR".

    Args:
        arkit_path: Path of the JSON file to read.

    Returns:
        A ``(timestamps, blendshapes)`` tuple: ``timestamps`` is an array
        with one entry per frame, and ``blendshapes`` maps each channel name
        to an array stacking that channel's per-frame values.

    Raises:
        KeyError: If a frame record lacks 'timestamp' or one of the channels.
    """
    with open(arkit_path, 'r') as handle:
        frames = json.load(handle)

    timestamps = np.array([frame['timestamp'] for frame in frames])

    blendshapes = {}
    for channel in ("/W", "/HR", "/ELR", "/ERR"):
        blendshapes[channel] = np.array([frame[channel] for frame in frames])

    return timestamps, blendshapes


def get_duration(audio_path):
    """Return the duration of a recording in seconds.

    The duration is taken from the audio file's header metadata, so the
    samples themselves are never decoded. If that fails for any reason
    (missing file, unreadable format, ...), the value stored under the
    "duration" key of a ``log.json`` file located in the same directory as
    ``audio_path`` is used instead.

    Args:
        audio_path: Path of the audio file to inspect.

    Returns:
        The duration in seconds (or whatever ``log.json`` stores under
        "duration" when the fallback is used).

    Raises:
        OSError: If the audio cannot be read and ``log.json`` cannot be
            opened either.
        KeyError: If ``log.json`` has no "duration" entry.
    """
    try:
        # Header-only query: frames / samplerate without loading the samples.
        duration = sf.info(audio_path).duration
    except Exception:
        # Deliberate best-effort: any failure to read the audio falls back to
        # the duration recorded in log.json next to it.
        log_path = os.path.join(os.path.dirname(audio_path), "log.json")
        with open(log_path, 'rb') as f:
            duration = json.load(f)["duration"]
    print(f"Duration: {duration:.2f} seconds")
    return duration


def interpolate_blendshapes(timestamps, blendshapes, duration, target_fps):
    """Resample every blendshape channel onto an evenly spaced time grid.

    The grid holds ``int(duration * target_fps)`` samples spanning
    ``[0, duration]``. Because ``np.linspace`` includes the endpoint, the
    samples are ``duration / (n - 1)`` apart rather than exactly
    ``1 / target_fps``.

    Args:
        timestamps: Source sample times, one per frame.
        blendshapes: Mapping of channel name to an array whose first axis
            runs over the frames in ``timestamps``.
        duration: Length of the target grid in seconds.
        target_fps: Frame rate used to derive the number of target samples.

    Returns:
        A ``(aligned_blendshapes, target_times)`` tuple, where
        ``aligned_blendshapes`` has the same keys as ``blendshapes`` with the
        values linearly interpolated (and linearly extrapolated outside the
        source time range) at ``target_times``.
    """
    frame_count = int(duration * target_fps)
    target_times = np.linspace(0, duration, frame_count)

    def resample(series):
        # Linear interpolation along the frame axis; times outside the
        # captured range are extrapolated instead of raising.
        curve = interp1d(
            timestamps,
            series,
            axis=0,
            bounds_error=False,
            fill_value="extrapolate",
        )
        return curve(target_times)

    aligned_blendshapes = {
        channel: resample(series) for channel, series in blendshapes.items()
    }
    return aligned_blendshapes, target_times


def save_output(aligned_blendshapes, timestamps, output_path):
    """Write the aligned channels to disk as a JSON list of frame records.

    Each record carries a "timestamp" followed by the "/W", "/HR", "/ELR"
    and "/ERR" values of that frame, converted to plain Python numbers/lists.

    Args:
        aligned_blendshapes: Mapping of channel name to an array indexed by
            frame along its first axis.
        timestamps: Frame times; one record is written per entry.
        output_path: Destination file, overwritten if it exists.
    """
    channels = ("/W", "/HR", "/ELR", "/ERR")
    records = []
    for frame_idx, stamp in enumerate(timestamps):
        record = {"timestamp": float(stamp)}
        for channel in channels:
            record[channel] = aligned_blendshapes[channel][frame_idx].tolist()
        records.append(record)

    with open(output_path, 'w') as f:
        json.dump(records, f, indent=2)


def _mirror_blendshapes(aligned_blendshapes):
    """Mirror the aligned channels left<->right, modifying the arrays in place.

    Every array is expected to be 2-D with frames along the first axis.

    * "/W": the listed column pairs are exchanged.
    * "/HR": components 1 and 2 are negated.
    * "/ELR" / "/ERR": the two channels trade places and component 1 of each
      is negated.
    """
    # NOTE(review): these pairs look like the Left/Right coefficients of the
    # 52 ARKit blendshapes in alphabetical order -- confirm against the
    # capture format.
    columns_to_swap = [
        [0, 1], [3, 4], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15], [16, 17],
        [18, 19], [20, 21], [23, 25], [27, 28], [29, 30], [32, 38], [33, 34],
        [35, 36], [43, 44], [45, 46], [47, 48], [49, 50]
    ]
    first = [a for a, _ in columns_to_swap]
    second = [b for _, b in columns_to_swap]

    # The pairs are disjoint, so all swaps can be done in one assignment; the
    # fancy-indexed right-hand side is a copy, which makes this safe in place.
    weights = aligned_blendshapes["/W"]
    weights[:, first + second] = weights[:, second + first]

    # NOTE(review): presumably the two rotation components that change sign
    # under a left/right flip -- confirm the axis convention.
    aligned_blendshapes["/HR"][:, [1, 2]] *= -1

    # Copy both eye channels first so neither is overwritten before it is read.
    left_eye = aligned_blendshapes["/ELR"].copy()
    right_eye = aligned_blendshapes["/ERR"].copy()
    left_eye[:, 1] *= -1
    right_eye[:, 1] *= -1
    aligned_blendshapes["/ELR"][:] = right_eye
    aligned_blendshapes["/ERR"][:] = left_eye


def main(root, fps, mirror):
    """Align the raw ARKit capture in ``root`` to ``fps`` and save the result.

    Reads ``arkit_raw.json`` from ``root``, determines the clip duration from
    ``audio.wav`` (via ``get_duration``), resamples all channels, optionally
    mirrors them, and writes ``arkit.json`` back into ``root``.

    Args:
        root: Directory holding the input files; also receives the output.
        fps: Target frame rate for the resampled data.
        mirror: When true, mirror the resampled data before saving.
    """
    arkit_path = os.path.join(root, "arkit_raw.json")
    audio_path = os.path.join(root, "audio.wav")
    output_path = os.path.join(root, "arkit.json")
    print("mirror:", mirror)
    timestamps, blendshapes = load_arkit_data(arkit_path)
    duration = get_duration(audio_path)

    aligned_blendshapes, target_times = interpolate_blendshapes(timestamps, blendshapes, duration, fps)
    if mirror:
        _mirror_blendshapes(aligned_blendshapes)

    print(f"Aligned blendshapes to {fps} FPS.")
    save_output(aligned_blendshapes, target_times, output_path)
    print(f"Saved aligned blendshapes to {output_path}")


if __name__ == "__main__":
    # Command-line entry point: parse the options and run the alignment.
    parser = argparse.ArgumentParser(description="Align ARKit blendshapes to a given FPS.")
    # --root is mandatory: without it the path joins in main() would fail on None.
    parser.add_argument("--root", type=str, required=True,
                        help="Path to the folder containing arkit_raw.json and audio.wav")
    parser.add_argument("--fps", type=float, default=25.0, help="Target FPS for alignment (default: 25.0)")
    parser.add_argument("--mirror", action='store_true', help="Whether to mirror the ARKit data")

    args = parser.parse_args()
    main(args.root, args.fps, args.mirror)
