import argparse
import os
import json
import numpy as np
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from tqdm import tqdm
from sklearn.cluster import KMeans

from utils import empty_instance_dict

def _mean_translations(history_dict: dict, forecast_dict: dict) -> dict:
    """Return token -> representative position for every instance in either dict.

    For each instance, the mean translation is taken separately over its history
    window and over its forecast window (whichever it appears in). Those window
    means are then averaged with equal weight. Keys follow insertion order:
    history tokens first, then forecast-only tokens.
    """
    per_window_means = {}
    for source in (history_dict, forecast_dict):
        for token, instance in source.items():
            # assumes instance['translation'] is a list of equal-length position vectors
            per_window_means.setdefault(token, []).append(np.mean(instance['translation'], axis=0))
    return {token: np.mean(means, axis=0).tolist() for token, means in per_window_means.items()}


def _group_by_label(tokens: list, labels) -> list:
    """Group tokens by cluster label, ordered by each label's first appearance."""
    groups = {}
    for token, label in zip(tokens, labels):
        groups.setdefault(label, []).append(token)
    return list(groups.values())


def split_spatial(history_dict: dict, forecast_dict: dict, num_instances: int, random_state=None):
    """Partition a scene into spatial sub-scenes of roughly ``num_instances`` instances.

    Each instance is summarized by a single point: the equal-weight average of its
    mean history translation and its mean forecast translation (or just the one it
    has). The points are clustered with k-means into
    ``len(points) // num_instances`` clusters. Both dicts are then split so that
    each instance goes to the sub-scene of its cluster.

    Args:
        history_dict: instance token -> per-instance dict containing a 'translation' list.
        forecast_dict: same structure as ``history_dict``, for the forecast window.
        num_instances: target number of instances per sub-scene. If ``history_dict``
            has at most this many instances, the scene is returned unsplit.
        random_state: forwarded to ``KMeans``; pass an int for reproducible splits.
            The default ``None`` keeps KMeans' default (non-deterministic) seeding.

    Returns:
        ``(history_dicts, forecast_dicts)``: two parallel lists with one dict per
        sub-scene. A sub-scene may contain no history, or no forecast, instances.
    """
    # Don't cluster if there are already at most num_instances (history) instances
    if len(history_dict) <= num_instances:
        return [history_dict], [forecast_dict]

    mean_translations = _mean_translations(history_dict, forecast_dict)
    tokens = list(mean_translations)

    # Cluster so that there are ~num_instances instances per sub-scene.
    # len(tokens) >= len(history_dict) > num_instances, so n_clusters >= 1 here.
    kmeans = KMeans(
        n_clusters=len(tokens) // num_instances,
        random_state=random_state,
    ).fit(list(mean_translations.values()))

    # Split into separate dictionaries based on cluster assignment
    history_dicts = []
    forecast_dicts = []
    for cluster_tokens in _group_by_label(tokens, kmeans.labels_):
        history_dicts.append({t: history_dict[t] for t in cluster_tokens if t in history_dict})
        forecast_dicts.append({t: forecast_dict[t] for t in cluster_tokens if t in forecast_dict})

    return history_dicts, forecast_dicts


if __name__ == '__main__':
    # Split each full-scene JSON in --json_dir into consecutive, non-overlapping
    # chunks of num_history + num_forecast frames. Each chunk is further split
    # spatially (split_spatial) and written as
    #   out_dir/history/{scene_id}_{chunk_id}_{cluster_i}.json
    #   out_dir/forecast_gt/{scene_id}_{chunk_id}_{cluster_i}.json
    # Timesteps in the output are relative to the start of the chunk. A trailing
    # chunk shorter than num_history + num_forecast frames is not written.
    # Each scene JSON is expected to map instance token -> dict of parallel lists
    # 'timestep', 'translation', 'rotation', 'size', 'attribute_label'.

    # Set up argument parsing
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--num_history', 
        type=int, 
        default=8,
        help="Number of history steps"
    )
    parser.add_argument(
        '--num_forecast', 
        type=int, 
        default=8,
        help="Number of forecast steps"
    )
    parser.add_argument(
        '--json_dir', 
        type=str, 
        required=True, 
        help="Path to the directory of source, full-scene JSON files"
    )
    parser.add_argument(
        '--out_dir', 
        type=str, 
        required=True, 
        help="Path to the directory to save divided JSON files (in out_dir/history and out_dir/forecast_gt)"
    )
    parser.add_argument(
        '--num_instances',
        type=int,
        default=10,
        help="Approximate number of instances per spatial sub-scene"
    )
    args = parser.parse_args()

    num_history = args.num_history
    num_forecast = args.num_forecast
    chunk_len = num_history + num_forecast
    if num_history < 0 or num_forecast < 0 or chunk_len < 1:
        parser.error("--num_history and --num_forecast must be non-negative and sum to at least 1")

    history_out_dir = os.path.join(args.out_dir, 'history')
    forecast_out_dir = os.path.join(args.out_dir, 'forecast_gt')
    os.makedirs(history_out_dir, exist_ok=True)
    os.makedirs(forecast_out_dir, exist_ok=True)

    per_frame_keys = ('translation', 'rotation', 'size', 'attribute_label')

    # Sorted for a deterministic processing order across runs/filesystems
    for scene_json in tqdm(sorted(os.listdir(args.json_dir))):
        if not scene_json.endswith('.json'):
            continue

        scene_id = scene_json[:-len('.json')]

        with open(os.path.join(args.json_dir, scene_json), 'r', encoding='utf-8') as f:
            scene_dict = json.load(f)

        # Instances with no timesteps carry no data; an empty scene would crash max()
        last_timesteps = [max(instance['timestep']) for instance in scene_dict.values() if instance['timestep']]
        if not last_timesteps:
            print(f'WARNING: skipping scene {scene_id} as it contains no annotated frames')
            continue
        max_time = max(last_timesteps)
        if max_time + 1 < chunk_len:
            print(f'WARNING: skipping scene {scene_id} as it has fewer frames than the intended number')
            continue

        # Per instance: frame number -> position in its per-frame lists. This gives
        # O(1) lookups instead of scanning the timestep list for every frame.
        # setdefault keeps the first occurrence, matching list.index().
        frame_index = {}
        for instance_token, instance in scene_dict.items():
            index_of = {}
            for i, t in enumerate(instance['timestep']):
                index_of.setdefault(t, i)
            frame_index[instance_token] = index_of

        chunk_id = 0

        # Initialize history and forecast dicts
        history_dict = {}
        forecast_dict = {}

        # max_time is the last frame index, so iterate over max_time + 1 frames
        for frame in range(max_time + 1):
            # Computed per frame (not per instance) so it is defined and current
            # even when no instance is visible in this frame.
            local_timestep = frame - chunk_id * chunk_len
            target_dict = history_dict if local_timestep < num_history else forecast_dict

            # Iterate over instances in the scene
            for instance_token, instance in scene_dict.items():
                eff_time_index = frame_index[instance_token].get(frame)
                if eff_time_index is None:
                    continue  # instance not observed in this frame

                if instance_token not in target_dict:
                    target_dict[instance_token] = empty_instance_dict()
                entry = target_dict[instance_token]
                entry['timestep'].append(local_timestep)
                for key in per_frame_keys:
                    entry[key].append(instance[key][eff_time_index])

            if local_timestep == chunk_len - 1:
                # Perform spatial splitting and save each sub-dict
                history_dicts, forecast_dicts = split_spatial(history_dict, forecast_dict, num_instances=args.num_instances)
                for cluster_i, (sub_history_dict, sub_forecast_dict) in enumerate(zip(history_dicts, forecast_dicts)):
                    out_name = f'{scene_id}_{chunk_id}_{cluster_i}.json'
                    with open(os.path.join(history_out_dir, out_name), 'w', encoding='utf-8') as json_file:
                        json.dump(sub_history_dict, json_file, indent=4)
                    with open(os.path.join(forecast_out_dir, out_name), 'w', encoding='utf-8') as json_file:
                        json.dump(sub_forecast_dict, json_file, indent=4)

                history_dict = {}
                forecast_dict = {}
                chunk_id += 1