import pickle
import gzip
import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.cluster import KMeans
from itertools import zip_longest


# Input / output locations (placeholders -- point these at your local copies).
data_root = '/path/to/pandaset/'
output_root = '/your/out/dir'

# Each saved sample consists of num_history_steps frames followed by
# num_forecast_steps frames.
num_history_steps = 8
num_forecast_steps = 8
# Frames per second to keep; the script subsamples annotation frames with a
# stride of 10 // fps (presumably the source annotations are 10 Hz -- confirm).
fps = 2

# Pandaset cuboid labels that are kept; all other labels are filtered out.
admissible_instances = [
    'Car',
    'Medium-sized Truck',
    'Other Vehicle - Construction Vehicle',
    'Pickup Truck',
]


def transform_group(row, timestep):
    """Convert one pandaset cuboid annotation row into a Car4Cast instance dict.

    Args:
        row: Indexable annotation row (e.g. a pandas Series or a dict) that
            provides the pandaset columns ``position.x/y/z``, ``yaw``,
            ``dimensions.x/y/z`` and ``label``.
        timestep: Index of the frame inside the current scene chunk.

    Returns:
        A dict with keys ``timestep``, ``translation``, ``rotation``, ``size``
        and ``attribute_label`` (in that order).
    """
    position = [row["position.x"], row["position.y"], row["position.z"]]
    # Only the z component of the rotation is populated; the yaw is shifted
    # by a quarter turn relative to the pandaset value.
    heading = row["yaw"] + 0.5*np.pi
    # Sizes are emitted in (dimensions.z, dimensions.x, dimensions.y) order;
    # the original note says pandaset stores sizes as w-l-h.
    # NOTE(review): confirm this matches the ordering the consumer expects.
    extent = [row["dimensions.z"], row["dimensions.x"], row["dimensions.y"]]
    label = row["label"]

    instance = {
        "timestep": timestep,
        "translation": position,
        "rotation": [0., 0., heading],
        "size": extent,
        "attribute_label": label,
    }
    return instance


# Create directories for history and forecast data
out_history_dir = os.path.join(output_root, 'history')
out_forecast_dir = os.path.join(output_root, 'forecast_gt')
os.makedirs(out_history_dir, exist_ok=True)
os.makedirs(out_forecast_dir, exist_ok=True)

# Annotation rate assumed by the frame stride below (the script keeps every
# (_SOURCE_FPS // fps)-th frame). NOTE(review): if fps does not divide this
# evenly, the effective sampling rate is not exactly ``fps``.
_SOURCE_FPS = 10
# Target number of instances per spatial sub-scene written to disk.
_INSTANCES_PER_CLUSTER = 10
# Per-instance output columns, in output order. Each becomes a list with one
# entry per frame in which the instance appears.
_AGG_SPEC = {"timestep": list, "translation": list, "rotation": list, "size": list, "attribute_label": list}


def _load_frame_records(path, timestep):
    """Load one gzipped pandaset cuboid frame and convert its vehicles.

    Returns a list with one ``transform_group`` dict per admissible cuboid,
    each extended with the cuboid's ``uuid``. The list is empty when the frame
    contains no admissible instance. (The previous ``DataFrame.apply`` approach
    returned an empty DataFrame instead of a Series in that case.)
    """
    # NOTE: pickle.load can execute arbitrary code -- only run on trusted data.
    with gzip.open(path, 'rb') as f:
        frame_ann = pickle.load(f)
    # Select only vehicles
    frame_ann = frame_ann[frame_ann['label'].isin(admissible_instances)]
    return [{"uuid": row["uuid"], **transform_group(row, timestep)}
            for row in frame_ann.to_dict(orient="records")]


def _group_by_instance(records):
    """Collapse non-empty per-frame records into one row per uuid with list-valued columns."""
    return pd.DataFrame(records).groupby("uuid").agg(_AGG_SPEC)


def _write_chunk(scene_id, chunk_id, history_records, forecast_records):
    """Cluster one history/forecast chunk spatially and save each cluster as JSON.

    Only instances that appear in both the history and the forecast window are
    kept. Instances are grouped with KMeans on their mean position so that each
    sub-scene holds roughly ``_INSTANCES_PER_CLUSTER`` instances. For every
    non-empty cluster, ``{scene_id}_{chunk_id}_{cluster}.json`` is written to
    both the history and the forecast output directory.
    """
    if not history_records or not forecast_records:
        return  # no vehicle in one of the windows -> no shared instances
    history_df = _group_by_instance(history_records)
    forecast_df = _group_by_instance(forecast_records)
    all_df = _group_by_instance(history_records + forecast_records)

    # Filter dataframes so that they share all instances
    common_uuids = history_df.index.intersection(forecast_df.index)
    if common_uuids.empty:
        return  # nothing to cluster (KMeans cannot fit zero samples)
    history_df = history_df.loc[common_uuids]
    forecast_df = forecast_df.loc[common_uuids]
    all_df = all_df.loc[common_uuids]

    # Average position of each instance over the whole chunk (NaNs ignored).
    # Row order matches history_df/forecast_df because all use common_uuids.
    mean_translations = np.array(
        [np.nanmean(np.asarray(track, dtype=float), axis=0) for track in all_df['translation']]
    )

    # ~_INSTANCES_PER_CLUSTER instances per sub-scene, but at least one cluster:
    # KMeans rejects n_clusters=0, which the plain integer division produced
    # for chunks with fewer than _INSTANCES_PER_CLUSTER shared instances.
    # A fixed seed makes the split reproducible across runs.
    n_clusters = max(1, len(common_uuids) // _INSTANCES_PER_CLUSTER)
    labels = KMeans(n_clusters=n_clusters, random_state=0).fit(mean_translations).labels_

    for frames_df, out_dir in ((history_df, out_history_dir), (forecast_df, out_forecast_dir)):
        # Visit every cluster id 0..n_clusters-1, including the last one.
        for cluster_i in range(n_clusters):
            cluster_df = frames_df[labels == cluster_i]
            if cluster_df.empty:
                continue  # KMeans may leave a cluster unused for degenerate inputs
            out_path = os.path.join(out_dir, f'{scene_id}_{chunk_id}_{cluster_i}.json')
            with open(out_path, 'w', encoding='utf-8') as json_file:
                json.dump(cluster_df.to_dict(orient="index"), json_file, indent=4)


frame_stride = _SOURCE_FPS // fps  # keep every frame_stride-th annotation frame
chunk_len = num_history_steps + num_forecast_steps

# If a scene is longer than one chunk, it is split into consecutive chunks.
# Trailing frames that do not fill a whole chunk are dropped.
for scene_id in tqdm(sorted(os.listdir(data_root))):
    ann_files = os.path.join(data_root, scene_id, 'annotations/cuboids')
    if not os.path.isdir(ann_files):
        continue  # not a scene directory, or no cuboid annotations
    history_records, forecast_records = [], []
    for frame_idx, frame in enumerate(sorted(os.listdir(ann_files))[::frame_stride]):
        # chunk_id: which chunk of the scene; local_step: timestep inside that chunk
        chunk_id, local_step = divmod(frame_idx, chunk_len)
        records = _load_frame_records(os.path.join(ann_files, frame), local_step)
        if local_step < num_history_steps:
            history_records.extend(records)
        else:
            forecast_records.extend(records)
        if local_step == chunk_len - 1:
            _write_chunk(scene_id, chunk_id, history_records, forecast_records)
            history_records, forecast_records = [], []