import argparse
import os
import json
import numpy as np

from tqdm import tqdm
from scipy.interpolate import interp1d


# Tolerance (in target-timestep units) that absorbs binary floating-point
# error when snapping source times onto the target grid, so that a source
# time lying exactly on a grid point is never dropped by ceil/floor.
_GRID_EPS = 1e-6


def _target_grid(source_timesteps, source_freq, target_freq):
    """Compute the target sampling grid that covers one track.

    Args:
        source_timesteps: Frame indices sampled at ``source_freq``; index
            ``t`` is treated as time ``t / source_freq`` seconds.
        source_freq: Source sampling rate (Hz).
        target_freq: Target sampling rate (Hz).

    Returns:
        ``(source_times, target_timesteps, target_times)``: the source sample
        times in seconds (numpy array, input order), the list of every integer
        ``k`` with ``k / target_freq`` inside ``[min(source_times),
        max(source_times)]``, and the matching target times in seconds.

    Raises:
        ValueError: If ``source_timesteps`` is empty.
    """
    source_times = np.asarray(source_timesteps, dtype=float) / source_freq
    if source_times.size == 0:
        raise ValueError("source_timesteps must not be empty")

    # Snap to absolute multiples of 1 / target_freq rather than to the track's
    # own start time, so target timestep k means time k / target_freq for
    # every instance and tracks with different start phases stay aligned.
    first = int(np.ceil(source_times.min() * target_freq - _GRID_EPS))
    last = int(np.floor(source_times.max() * target_freq + _GRID_EPS))
    target_timesteps = list(range(first, last + 1))
    target_times = np.asarray(target_timesteps, dtype=float) / target_freq
    return source_times, target_timesteps, target_times


def resample(data, source_timesteps, source_freq, target_freq):
    """Linearly resample a multi-dimensional track onto the target grid.

    Args:
        data: N samples of D values each (array-like of shape (N, D)),
            aligned with ``source_timesteps``.
        source_timesteps: N frame indices sampled at ``source_freq``.
        source_freq: Source sampling rate (Hz).
        target_freq: Target sampling rate (Hz).

    Returns:
        ``(resampled_data, target_timesteps)``: a list of lists of floats with
        one row per target timestep, and the integer target timesteps (frame
        indices at ``target_freq``). Both are empty when no multiple of
        ``1 / target_freq`` falls inside the track's time span.

    Raises:
        ValueError: If ``source_timesteps`` is empty.
    """
    values = np.asarray(data, dtype=float)  # shape: (N, D)
    source_times, target_timesteps, target_times = _target_grid(
        source_timesteps, source_freq, target_freq)

    # np.interp requires increasing sample points; a stable sort keeps input
    # order among identical times.
    order = np.argsort(source_times, kind="stable")
    source_times = source_times[order]
    values = values[order]

    # Interpolate each dimension separately. Target times stay inside the
    # source span (up to float tolerance), so no extrapolation is needed, and
    # unlike interp1d, np.interp also accepts a single-sample track.
    columns = [
        np.interp(target_times, source_times, values[:, d])
        for d in range(values.shape[1])
    ]

    # Stack back into list of lists of plain Python floats
    resampled_data = np.stack(columns, axis=1).tolist()
    return resampled_data, target_timesteps


def resample_categorical(data, source_timesteps, source_freq, target_freq):
    """Resample labels onto the target grid by nearest-neighbour lookup.

    The grid is identical to the one ``resample`` produces for the same
    timesteps and frequencies, so their outputs line up element by element.

    Args:
        data: N labels aligned with ``source_timesteps``.
        source_timesteps: N frame indices sampled at ``source_freq``.
        source_freq: Source sampling rate (Hz).
        target_freq: Target sampling rate (Hz).

    Returns:
        ``(resampled, target_timesteps)``: the nearest source label for each
        target timestep, converted to plain Python objects so the result is
        JSON-serializable, and the integer target timesteps. When two source
        samples are equally near, the one listed first wins.

    Raises:
        ValueError: If ``source_timesteps`` is empty.
    """
    labels = np.asarray(data)
    source_times, target_timesteps, target_times = _target_grid(
        source_timesteps, source_freq, target_freq)

    # (M targets x N sources) distance matrix; argmin returns the first
    # minimum on ties, matching a per-target np.argmin.
    distances = np.abs(target_times[:, None] - source_times[None, :])
    nearest = distances.argmin(axis=1)

    # .tolist() turns numpy scalars (e.g. np.int64) into native Python types.
    return labels[nearest].tolist(), target_timesteps


def _parse_args():
    """Parse and validate the command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--source_freq',
        type=int,
        required=True,
        help="Frequency of measurements of source dictionary (Hz)"
    )
    parser.add_argument(
        '--target_freq',
        type=int,
        default=2,
        help="Target frequency (Hz)"
    )
    parser.add_argument(
        '--json_dir',
        type=str,
        required=True,
        help="Path to the directory of source, full-scene JSON files"
    )
    parser.add_argument(
        '--out_dir',
        type=str,
        required=True,
        help="Path to the directory to save the resampled JSON file"
    )
    args = parser.parse_args()

    # Both frequencies are used as divisors/multipliers of time; reject
    # values that would divide by zero or produce a meaningless grid.
    if args.source_freq <= 0 or args.target_freq <= 0:
        parser.error("--source_freq and --target_freq must be positive")
    return args


def _resample_scene(scene_dict, source_freq, target_freq):
    """Resample every instance of one scene to ``target_freq``.

    Each instance's 'translation', 'rotation' and 'size' are linearly
    interpolated, 'attribute_label' is taken from the nearest source sample,
    and 'timestep' is replaced by the target timesteps. Instance dicts are
    updated in place and collected into the returned dict; instances with no
    samples on the target grid are left out.
    """
    scene_dict_resampled = {}
    for instance_token, instance in scene_dict.items():
        # Capture the source timesteps before 'timestep' is overwritten below.
        timesteps = instance['timestep']

        instance['translation'], target_ts = resample(
            instance['translation'], timesteps, source_freq, target_freq)
        if not target_ts:
            # The track never overlaps a target-grid time; nothing to keep.
            continue

        # NOTE(review): if 'rotation' holds quaternions, component-wise linear
        # interpolation yields non-unit quaternions and mishandles q / -q sign
        # flips; consider scipy.spatial.transform.Slerp — confirm the format.
        instance['rotation'], _ = resample(
            instance['rotation'], timesteps, source_freq, target_freq)
        instance['size'], _ = resample(
            instance['size'], timesteps, source_freq, target_freq)

        # Sample attribute labels based on nearest neighbours at target frequency
        instance['attribute_label'], _ = resample_categorical(
            instance['attribute_label'], timesteps, source_freq, target_freq)

        instance['timestep'] = target_ts
        scene_dict_resampled[instance_token] = instance
    return scene_dict_resampled


def main():
    """Resample every scene JSON in ``--json_dir`` into ``--out_dir``."""
    args = _parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    # Only regular files with a .json suffix; sorted for a deterministic order.
    scene_files = sorted(
        name for name in os.listdir(args.json_dir)
        if name.endswith('.json')
        and os.path.isfile(os.path.join(args.json_dir, name))
    )

    for scene_json in tqdm(scene_files):
        with open(os.path.join(args.json_dir, scene_json), 'r', encoding='utf-8') as f:
            scene_dict = json.load(f)

        scene_dict_resampled = _resample_scene(
            scene_dict, args.source_freq, args.target_freq)

        # Save resampled scene dict under the same file name
        with open(os.path.join(args.out_dir, scene_json), 'w', encoding='utf-8') as json_file:
            json.dump(scene_dict_resampled, json_file, indent=4)


if __name__ == '__main__':
    main()
