import os
import json
import argparse
from tqdm import tqdm
import numpy as np
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils import static_instance_dict, load_instance_data


def linear_extrapolate(timesteps, values, future_timesteps):
    """Fit ``value = a * t + b`` by least squares and evaluate it at ``future_timesteps``.

    Each column of ``values`` is fit independently (one ``lstsq`` call with
    multiple right-hand sides).

    Args:
        timesteps: 1-D sequence of observed timesteps, length n.
        values: Observed values, shape (n,) or (n, k).
        future_timesteps: 1-D sequence of timesteps to predict, length m.

    Returns:
        Float array of shape (m,) or (m, k), matching the column layout of
        ``values``.

    Raises:
        ValueError: if there are no observations, or if the number of rows in
            ``values`` does not match the number of timesteps.

    With fewer than two distinct timesteps the slope is undetermined (the
    design matrix is rank-deficient and ``lstsq`` would return an arbitrary
    minimum-norm line), so the mean of the observations is held constant
    instead. For a single observation this is simply that observation.
    """
    t = np.asarray(timesteps, dtype=float).ravel()
    y = np.asarray(values, dtype=float)
    t_future = np.asarray(future_timesteps, dtype=float).ravel()

    if t.size == 0:
        raise ValueError("cannot extrapolate from an empty history")
    if y.shape[0] != t.size:
        raise ValueError(
            f"values has {y.shape[0]} rows but there are {t.size} timesteps"
        )

    if np.unique(t).size < 2:
        # Too few distinct timesteps to determine a slope: constant extrapolation.
        level = y.mean(axis=0)
        return np.broadcast_to(level, (t_future.size,) + y.shape[1:]).copy()

    # Design matrices with columns [t, 1] for y = a*t + b.
    design = np.column_stack([t, np.ones_like(t)])
    design_future = np.column_stack([t_future, np.ones_like(t_future)])
    coeffs, _, _, _ = np.linalg.lstsq(design, y, rcond=None)
    return design_future @ coeffs


def main():
    """Forecast every scene in ``--history_dir`` and write one JSON per scene to ``--out_dir``.

    Each instance's translation and rotation (first three components) are
    linearly extrapolated over timesteps ``np.arange(8, 16)``. The remaining
    fields come from ``static_instance_dict``. The output file keeps the
    input entry's name.
    """
    # Set up argument parsing
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--history_dir',
        type=str,
        required=True,
        help="Path to the directory of historical JSON files"
    )
    parser.add_argument(
        '--out_dir',
        type=str,
        required=True,
        help="Path to the output directory for forecasted JSON files"
    )
    args = parser.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # NOTE(review): presumably these must agree with the timesteps that
    # static_instance_dict writes into forecast_dict['timestep'] -- confirm in utils.
    future_timesteps = np.arange(8, 16)

    # os.listdir order is arbitrary; sort so scenes are processed in a stable order.
    for scene_id in tqdm(sorted(os.listdir(args.history_dir))):
        historical_data, _ = load_instance_data(
            os.path.join(args.history_dir, scene_id), numpy=True
        )
        forecasted_data = {}

        for instance_id in historical_data:
            history_dict = historical_data[instance_id]
            # Start from the static forecast, then overwrite the pose fields.
            forecast_dict = static_instance_dict(history_dict)
            timesteps = history_dict['timestep']

            translation = linear_extrapolate(
                timesteps, history_dict['translation'][:, :3], future_timesteps
            )
            # NOTE(review): rotation components are extrapolated independently,
            # like translations; if they are angles, wrap-around (e.g. at +/-pi)
            # is not handled -- confirm the rotation representation.
            rotation = linear_extrapolate(
                timesteps, history_dict['rotation'][:, :3], future_timesteps
            )

            forecast_dict['translation'] = translation.tolist()
            forecast_dict['rotation'] = rotation.tolist()
            # These fields are array-valued; convert them to lists for JSON output.
            forecast_dict['timestep'] = forecast_dict['timestep'].tolist()
            forecast_dict['size'] = forecast_dict['size'].tolist()
            forecast_dict['attribute_label'] = forecast_dict['attribute_label'].tolist()

            forecasted_data[instance_id] = forecast_dict

        # Save forecasted data
        out_path = os.path.join(args.out_dir, scene_id)
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(forecasted_data, f, indent=4)


if __name__ == '__main__':
    main()
            