import os
import math
import json
import numpy as np

from tqdm import tqdm
from nuscenes.nuscenes import NuScenes

# Input / output locations -- adjust these paths to the local setup
data_root = '/path/to/nuscenes'
out_root = '/your/out/dir'

# Per-scene JSON files are written into the 'full' sub-directory of out_root
save_dir = os.path.join(out_root, 'full')

# Load the v1.0-trainval split of the dataset found under data_root
nusc = NuScenes(
    dataroot=data_root,
    version='v1.0-trainval',
    verbose=True,
)

# The script proceeds as follows:
# Iterate over the scenes in the dataset
# |---For each scene, collect all instance tokens
# |---|---For each instance token in the scene, collect its data in temporal order

# Annotation period used to turn timestamps into integer timesteps
# (N.B. the annotation rate is 2 Hz = 1 frame / 500000 µs)
_ANNOTATION_PERIOD_US = 500000.


def _quaternion_to_euler(quaternion):
    """Convert a (w, x, y, z) quaternion into [roll, pitch, yaw] in radians."""
    w, x, y, z = quaternion
    roll = math.atan2(2 * (w * x + y * z), 1 - 2 * (x ** 2 + y ** 2))
    # Floating-point rounding can push this value marginally outside [-1, 1]
    # when the pitch is close to +/-90 degrees; math.asin would then raise
    # ValueError, so clamp it first.
    sin_pitch = max(-1.0, min(1.0, 2 * (w * y - z * x)))
    pitch = math.asin(sin_pitch)
    yaw = math.atan2(2 * (w * z + x * y), 1 - 2 * (y ** 2 + z ** 2))
    return [roll, pitch, yaw]


# Make sure the output directory exists before any scene is processed;
# otherwise the first open(..., 'w') below fails with FileNotFoundError.
os.makedirs(save_dir, exist_ok=True)

# Iterate over all scenes in dataset
for scene in tqdm(nusc.scene):
    # Earliest sample timestamp in the scene (used as the origin when
    # normalizing timestamps to timesteps)
    min_time = float('inf')

    # Map instance token -> list of (sample timestamp, annotation record).
    # A dict keeps instances in first-seen order, so the output file layout is
    # the same on every run.
    # NOTE(review): this relies on every annotation of an instance belonging
    # to the scene being traversed (nuScenes does not track instances across
    # scenes), which makes it equivalent to a per-instance table scan with
    # nusc.field2token but without re-reading the whole annotation table.
    annotations_by_instance = {}

    # Iterate over sample tokens (i.e., frames) in the scene, starting at the
    # first one and following the 'next' links until the empty token
    sample_token = scene['first_sample_token']
    while sample_token != '':
        # Get sample (i.e., frame data)
        sample = nusc.get('sample', sample_token)
        timestamp = sample['timestamp']

        # Update the min timestamp for every frame, annotated or not
        min_time = min(min_time, timestamp)

        # Group the frame's annotations (i.e., frame-specific instance
        # identifiers) by instance token (i.e., unique instance identifier)
        for annotation_token in sample['anns']:
            annotation_metadata = nusc.get('sample_annotation', annotation_token)
            annotations_by_instance.setdefault(
                annotation_metadata['instance_token'], []
            ).append((timestamp, annotation_metadata))

        # Go to next sample (i.e., frame)
        sample_token = sample['next']

    # Scene dictionary (i.e., content of the JSON file): instance token -> data
    scene_dict = {}

    # Traverse the scene instance-wise and accumulate data for each instance
    for instance_token, records in annotations_by_instance.items():
        # Fetch category
        category_token = nusc.get('instance', instance_token)['category_token']
        category_name = nusc.get('category', category_token)['name']
        # Skip instance if not a vehicle
        if 'vehicle' not in category_name:
            continue

        # Records were collected in frame order; sort by timestamp anyway so
        # the temporal ordering of the output does not depend on that.
        records.sort(key=lambda record: record[0])

        # Instance dict containing timesteps, translations, rotations, etc.
        instance_dict = {
            'timestep': [],
            'translation': [],
            'rotation': [],
            'size': [],
            'attribute_label': []
        }

        for timestamp, annotation_metadata in records:
            # Normalize timestamps to integers starting from 0
            instance_dict['timestep'].append(
                round((timestamp - min_time) / _ANNOTATION_PERIOD_US)
            )

            # Translation is copied as-is
            instance_dict['translation'].append(annotation_metadata['translation'])

            # Convert rotation from quaternion to roll, pitch, yaw
            instance_dict['rotation'].append(
                _quaternion_to_euler(annotation_metadata['rotation'])
            )

            # Size is re-ordered to [size[2], size[0], size[1]].
            # NOTE(review): nuScenes documents size as (width, length, height),
            # which would make this (height, width, length) -- confirm against
            # the consumer of these files.
            size = annotation_metadata['size']
            instance_dict['size'].append([size[2], size[0], size[1]])

            # The 'attribute_label' field carries the category name
            instance_dict['attribute_label'].append(category_name)

        # Insert into scene dict
        scene_dict[instance_token] = instance_dict

    # Save scene dict as JSON
    with open(os.path.join(save_dir, f'{scene["token"]}.json'), 'w') as json_file:
        json.dump(scene_dict, json_file, indent=4)