import copy
import os
import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils import load_instance_data, instance_data_scored_timesteps, add_score, is_collision, instance_motion_category, out_of_map, static_instance_dict

def _realign_to_forecast_timesteps(pred_instance, forecast_timesteps, scored_timesteps):
    """Give ``pred_instance`` an entry at every forecast timestep (in place).

    For each forecast timestep that has no scored prediction, the prediction at the
    nearest predicted timestep is reused (ties are resolved towards the smaller
    timestep). Predictions at timesteps that were already scored are kept.

    Args:
        pred_instance: Per-instance prediction dict with array-valued 'timestep',
            'translation', 'rotation' and optionally 'size' / 'attribute_label'.
        forecast_timesteps: Ground-truth timesteps the instance must be scored at.
        scored_timesteps: Predicted timesteps that were matched for scoring.
    """
    pred_timesteps = pred_instance['timestep']
    missing = set(forecast_timesteps) - set(scored_timesteps)
    nearest = [
        min(pred_timesteps, key=lambda t, target=target: (abs(t - target), t))
        for target in sorted(missing)
    ]

    kept = np.array(list(set(pred_timesteps) & set(scored_timesteps)))
    source_timesteps = np.sort(np.hstack((nearest, kept)))
    # Map each source timestep back to its row in the original prediction arrays
    source_idxs = np.searchsorted(pred_timesteps, source_timesteps, side='left')
    source_idxs = np.where(
        (source_idxs < len(pred_timesteps)) & (pred_timesteps[source_idxs] == source_timesteps),
        source_idxs,
        source_idxs - 1
    )

    pred_instance['translation'] = pred_instance['translation'][source_idxs]
    pred_instance['rotation'] = pred_instance['rotation'][source_idxs]
    for optional_key in ('size', 'attribute_label'):
        if optional_key in pred_instance:
            pred_instance[optional_key] = pred_instance[optional_key][source_idxs]
    pred_instance['timestep'] = np.sort(np.hstack((scored_timesteps, list(missing))))


def eval_scene(history_file, forecast_file, prediction_file, map_file):
    """Evaluate the forecast(s) of one scene against the ground truth.

    When several predicted "worlds" (K guesses) are given, the one with the lowest
    mean ADD over all scored instances is selected and all metrics are computed in it.
    Only instances present in both the history and the ground-truth forecast are
    scored; an instance missing from the prediction is assumed static at its last
    historical state.

    Args:
        history_file: Path to the historical data file of the scene.
        forecast_file: Path to the ground-truth forecast file of the scene.
        prediction_file: Path to a prediction '.json' file, to a directory with one
            prediction file per guess, or to a prediction file without its '.json'
            extension.
        map_file: Path to a pickled map boundary, or None to skip the Out of Map Rate.

    Returns:
        Dict with the keys 'overall', 'static', 'linear' and 'nonlinear', each mapping
        metric names to values. Instance-level metrics (formatting accuracy, precision,
        recall, F1, collision rate) are only reported under 'overall'. A motion category
        without any scored instance has no trajectory metrics.

    Raises:
        ValueError: If no prediction could be loaded, if the ADD score cannot be
            computed for an instance, or if the formatting flags lack an instance.
    """
    categories = ('overall', 'static', 'linear', 'nonlinear')

    historical_data, _ = load_instance_data(history_file, numpy=True)
    forecasted_data, _ = load_instance_data(forecast_file, numpy=True)

    if os.path.isdir(prediction_file):
        # os.listdir() returns bare names: prefix the directory so the files resolve from any
        # working directory, and sort them so the order of the guesses is deterministic.
        predicted_files = [
            load_instance_data(os.path.join(prediction_file, name), numpy=True, historical_data=historical_data)
            for name in sorted(os.listdir(prediction_file))
        ]
    elif prediction_file.endswith('.json'):
        predicted_files = [load_instance_data(prediction_file, numpy=True, historical_data=historical_data)]
    else:
        print('!!!!!!!!!!!!!', history_file, forecast_file, f'{prediction_file}.json')
        predicted_files = [load_instance_data(f'{prediction_file}.json', numpy=True, historical_data=historical_data)]

    if not predicted_files:
        raise ValueError(f'No prediction files found in {prediction_file}')

    # If an agent appears in the forecast times, we don't expect it to be predicted by the model, so we don't count it
    valid_future_agents = list(set(historical_data.keys()) & set(forecasted_data.keys()))

    # Iterate over the multiple (K) guesses to find the "best world".
    # Work on a deep copy because missing instances are filled in while scoring.
    add_worlds = []
    for predicted_data, _ in copy.deepcopy(predicted_files):
        for instance_id, instance_dict in predicted_data.items():
            for key in ('timestep', 'translation', 'rotation', 'size'):
                if key in instance_dict:
                    # np.str_ replaces np.unicode_, which was removed in NumPy 2.0
                    assert not np.issubdtype(instance_dict[key].dtype, np.str_), (instance_id, key)

        # Collect ADD scores for each instance in this world
        add_k = []
        for instance_id in valid_future_agents:
            if instance_id not in predicted_data:
                # Assume the instance to be static at the last historical positions
                predicted_data[instance_id] = static_instance_dict(historical_data[instance_id])

            translations_gt, rotations_gt, sizes_gt, translations_pred, rotations_pred, sizes_pred, scored_timesteps = instance_data_scored_timesteps(forecasted_data, predicted_data, instance_id, verbose=False)
            if sizes_pred is None:
                sizes_pred = np.array([sizes_gt[-1]] * len(translations_pred))

            try:
                add = add_score(translations_pred, rotations_pred, sizes_pred, translations_gt, rotations_gt, sizes_gt)
            except Exception as err:
                raise ValueError(translations_pred.shape, translations_pred.dtype, translations_gt.shape, translations_gt.dtype, scored_timesteps) from err
            add_k.append(add)

        # Average ADD score over all instances in this world (NaN if nothing was scored)
        add_worlds.append(np.mean(add_k) if add_k else np.nan)

    # Find the "best world" as the one with minimal ADD and compute all metrics in it
    k_star = np.argmin(add_worlds)
    predicted_data, correct_format = predicted_files[k_star]

    metrics = {category: {} for category in categories}

    # Compute Instance Precision (P_instance) and Recall (R_instance)
    predicted_agents = set(predicted_data.keys())
    valid_predicted_agents = set(valid_future_agents) & predicted_agents

    try:
        # Compute Formatting Accuracy (ACC_f)
        correctly_formatted = sum(1 for instance_id in valid_predicted_agents if correct_format[instance_id])
        metrics['overall']['Formatting Accuracy (ACC_f)'] = correctly_formatted / len(valid_predicted_agents) if len(valid_predicted_agents) > 0 else 0.
    except KeyError as e:
        raise ValueError(valid_predicted_agents, '____________________', correct_format, '____________________', correct_format.keys(), historical_data.keys()) from e

    # Load the map boundary once instead of once per instance.
    # NOTE: pickle can execute arbitrary code on load; only use trusted map files.
    concave_hull = None
    if map_file is not None and valid_future_agents:
        with open(map_file, 'rb') as f:
            concave_hull = pickle.load(f)

    colliding_agents = set()
    missed_agents = dict.fromkeys(categories, 0)
    add_agents = {category: [] for category in categories}
    fde_agents = {category: [] for category in categories}
    ade_agents = {category: [] for category in categories}
    re_agents = {category: [] for category in categories}
    vhe_agents = {category: [] for category in categories}
    omr_agents = dict.fromkeys(categories, 0)

    for instance_id in valid_future_agents:
        if instance_id not in predicted_data:
            # Assume the instance to be static at the last historical positions
            predicted_data[instance_id] = static_instance_dict(historical_data[instance_id])

        translations_gt, rotations_gt, sizes_gt, translations_pred, rotations_pred, sizes_pred, scored_idxs = instance_data_scored_timesteps(forecasted_data, predicted_data, instance_id, verbose=False)
        scored_timesteps = predicted_data[instance_id]['timestep'][scored_idxs]

        # If an instance is not forecasted at any target timestep, count it as a miss and consider its latest position
        forecast_timesteps = forecasted_data[instance_id]['timestep']
        if np.any(~np.isin(forecast_timesteps, scored_timesteps)):
            # discard (not remove): a statically filled instance was never counted as predicted
            valid_predicted_agents.discard(instance_id)
            _realign_to_forecast_timesteps(predicted_data[instance_id], forecast_timesteps, scored_timesteps)

            translations_gt, rotations_gt, sizes_gt, translations_pred, rotations_pred, sizes_pred, scored_idxs = instance_data_scored_timesteps(forecasted_data, predicted_data, instance_id, verbose=False)
            scored_timesteps = predicted_data[instance_id]['timestep'][scored_idxs]

        if sizes_pred is None:
            sizes_pred = np.array([sizes_gt[-1]] * len(translations_pred))

        # Get motion category ('static', 'linear', 'nonlinear') of instance
        motion_category, _ = instance_motion_category(historical_data[instance_id]['translation'], translations_gt, historical_data[instance_id]['timestep'], scored_timesteps)

        # Collision checking to compute Collision Rate (CR)
        if instance_id not in colliding_agents:
            for instance_id_j in valid_future_agents:
                if instance_id_j not in predicted_data:
                    predicted_data[instance_id_j] = static_instance_dict(historical_data[instance_id_j])
                if (instance_id_j in colliding_agents) or (instance_id_j == instance_id):
                    continue
                _, _, sizes_gt_j, translations_pred_j, rotations_pred_j, sizes_pred_j, scored_idxs_j = instance_data_scored_timesteps(forecasted_data, predicted_data, instance_id_j, verbose=False)
                if sizes_pred_j is None:
                    sizes_pred_j = np.array([sizes_gt_j[-1]] * len(translations_pred_j))
                scored_timesteps_j = predicted_data[instance_id_j]['timestep'][scored_idxs_j]
                # Select, for each agent, the rows of its own scored arrays that fall on the shared
                # timesteps, so both agents are compared at the same points in time.
                # NOTE(review): assumes the arrays returned by instance_data_scored_timesteps are
                # row-aligned with the scored timesteps - confirm against utils.
                common_i = np.isin(scored_timesteps, scored_timesteps_j)
                common_j = np.isin(scored_timesteps_j, scored_timesteps)
                if is_collision(translations_pred[common_i], rotations_pred[common_i], sizes_pred[common_i],
                                translations_pred_j[common_j], rotations_pred_j[common_j], sizes_pred_j[common_j]):
                    colliding_agents.update((instance_id, instance_id_j))

        # Compute Average Distance of Model Points (ADD)
        add = add_score(translations_pred, rotations_pred, sizes_pred, translations_gt, rotations_gt, sizes_gt)
        add_agents['overall'].append(add)
        add_agents[motion_category].append(add)

        # Compute Miss Rate (MR): a miss is an ADD above 10% of the travelled path length (at least 1)
        trajectory = np.vstack((historical_data[instance_id]['translation'], translations_gt))
        distances = np.linalg.norm(np.diff(trajectory, axis=0), axis=1)
        total_length = np.sum(distances)
        if add > max(0.1*total_length, 1):
            missed_agents['overall'] += 1
            missed_agents[motion_category] += 1

        # Compute Final Displacement Error (FDE) & Average Displacement Error (ADE)
        fde = np.linalg.norm(translations_pred[-1] - translations_gt[-1])
        fde_agents['overall'].append(fde)
        fde_agents[motion_category].append(fde)
        ade = np.mean(np.linalg.norm(translations_pred - translations_gt, axis=1))
        ade_agents['overall'].append(ade)
        ade_agents[motion_category].append(ade)

        # Compute Rotation Error (RE); arccos(cos(.)) wraps the angular difference into [0, pi]
        rot_error = np.mean(np.arccos(np.cos(rotations_pred - rotations_gt)), axis=0)
        re_agents['overall'].append(rot_error)
        re_agents[motion_category].append(rot_error)

        # Compute Velocity Heading Shift (VHS), ignoring steps with a displacement below 0.5
        velocities = translations_pred[1:] - translations_pred[:-1]
        not_static = np.linalg.norm(velocities, axis=1) >= 0.5
        vel_angle = np.arctan2(velocities[not_static, 1], velocities[not_static, 0])
        if len(vel_angle) > 0:
            vhe = np.abs(np.mean(np.arccos(np.cos(vel_angle - rotations_pred[:-1, -1][not_static]))))
            vhe_agents['overall'].append(vhe)
            vhe_agents[motion_category].append(vhe)

        # Compute Out of Map Rate (OMR)
        if map_file is not None:
            if out_of_map(translations_pred, concave_hull):
                omr_agents['overall'] += 1
                omr_agents[motion_category] += 1

    # Compute metrics for all motion categories
    if len(valid_predicted_agents) == len(predicted_agents) == 0:
        instance_precision = 0.
    else:
        instance_precision = len(valid_predicted_agents) / len(predicted_agents)
    if len(valid_predicted_agents) == len(valid_future_agents) == 0:
        instance_recall = 0.
    else:
        instance_recall = len(valid_predicted_agents) / len(valid_future_agents)
    # Compute F1 score
    if instance_precision + instance_recall == 0:
        instance_f1 = 0.0
    else:
        instance_f1 = 2 * instance_precision * instance_recall / (instance_precision + instance_recall)
    metrics['overall']['Instance F1 Score (F1_instance)'] = instance_f1
    metrics['overall']['Instance Precision (P_instance)'] = instance_precision
    metrics['overall']['Instance Recall (R_instance)'] = instance_recall
    # An empty world has no agents that could collide
    metrics['overall']['Collision Rate (CR)'] = len(colliding_agents) / len(predicted_data) if len(predicted_data) > 0 else 0.
    for motion_category in metrics:
        num_agents = len(fde_agents[motion_category])
        if num_agents == 0:
            continue
        metrics[motion_category]['Final Displacement Error (FDE)'] = np.mean(fde_agents[motion_category])
        metrics[motion_category]['Average Distance of Model Points (ADD)'] = np.mean(add_agents[motion_category])
        metrics[motion_category]['Average Displacement Error (ADE)'] = np.mean(ade_agents[motion_category])
        metrics[motion_category]['Rotation Error (RE)'] = np.mean(re_agents[motion_category])
        metrics[motion_category]['Velocity Heading Shift (VHS)'] = np.mean(vhe_agents[motion_category]) if len(vhe_agents[motion_category]) > 0 else 0.
        metrics[motion_category]['Miss Rate (MR)'] = missed_agents[motion_category] / num_agents
        if map_file is not None:
            metrics[motion_category]['Out of Map Rate (OMR)'] = omr_agents[motion_category] / num_agents

    return metrics


if __name__ == '__main__':
    # Command-line entry point: evaluate every scene found in the prediction directory,
    # then save the mean and the median of each metric over all scenes as JSON files.
    parser = argparse.ArgumentParser(description="Evaluate vehicle forecasting")
    parser.add_argument('-H', '--history_files', type=str, required=True, help="Path to the directory containing historical data files")
    parser.add_argument('-F', '--forecast_files', type=str, required=True, help="Path to the directory containing ground-truth forecast data files")
    parser.add_argument('-P', '--prediction_files', type=str, required=True, help="Path to the directory containing predicted forecast data files/directories (in case of K>1)")
    parser.add_argument('-O', '--output_dir', type=str, required=True, help="Path to the directory to save summary of metrics")
    parser.add_argument('-M', '--map_files', type=str, default=None, help="Path to the directory containing map boundaries files (if None, don't compute Out-of-Map Rate)")
    parser.add_argument('--save_add', action='store_true', help="Whether to save a plot of the per-scene ADD score (for outlier detection)")
    args = parser.parse_args()

    add_key = 'Average Distance of Model Points (ADD)'
    all_metrics = []
    add_summary = {}

    for scene_file in tqdm(sorted(os.listdir(args.prediction_files))):
        print(scene_file)
        # splitext only strips the final extension, so scene names containing dots stay intact
        scene_id = os.path.splitext(scene_file)[0]
        # The map is shared by all scenes that only differ in their last two '_'-separated fields
        main_scene_id = '_'.join(scene_id.split('_')[:-2])
        history_file = os.path.join(args.history_files, scene_file)
        forecast_file = os.path.join(args.forecast_files, scene_file)
        prediction_file = os.path.join(args.prediction_files, scene_id)
        map_file = os.path.join(args.map_files, f'{main_scene_id}.pkl') if args.map_files is not None else None

        metrics = eval_scene(history_file, forecast_file, prediction_file, map_file)
        print(json.dumps(metrics, sort_keys=False, indent=4))
        print()
        all_metrics.append(metrics)
        # A scene without any scored agent reports no ADD: leave it out of the summary
        if add_key in metrics['overall']:
            add_summary[scene_id] = metrics['overall'][add_key]

    # Aggregate metrics across all files: collect the per-scene values once,
    # then reduce them with both the mean and the median.
    collected = {'overall': {}, 'static': {}, 'linear': {}, 'nonlinear': {}}
    for metrics in all_metrics:
        for motion_category, sub_dict in metrics.items():
            for metric, val in sub_dict.items():
                collected[motion_category].setdefault(metric, []).append(val)
    out_metrics = {
        motion_category: {metric: np.mean(vals) for metric, vals in sub_dict.items()}
        for motion_category, sub_dict in collected.items()
    }
    out_metrics_median = {
        motion_category: {metric: np.median(vals) for metric, vals in sub_dict.items()}
        for motion_category, sub_dict in collected.items()
    }

    os.makedirs(args.output_dir, exist_ok=True)
    # Name the outputs after the prediction directory, ignoring any trailing separator
    model_name = os.path.basename(os.path.normpath(args.prediction_files))
    out_file = os.path.join(args.output_dir, f'{model_name}.json')
    out_file_median = os.path.join(args.output_dir, f'{model_name}_median.json')
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(out_metrics, f, sort_keys=False, indent=4)
        print(json.dumps(out_metrics, sort_keys=False, indent=4))
        print('Summary of evaluation metrics saved to: ', out_file)
    with open(out_file_median, "w", encoding="utf-8") as f:
        json.dump(out_metrics_median, f, sort_keys=False, indent=4)
        print(json.dumps(out_metrics_median, sort_keys=False, indent=4))
        print('Summary of evaluation metrics (median) saved to: ', out_file_median)

    # Skip the plot when there is nothing to show (a zero-height figure is invalid)
    if args.save_add and add_summary:
        # Save barchart of ADD values per scene
        plt.figure(figsize=(6, len(add_summary) * 0.5))
        plt.barh(list(add_summary.keys()), list(add_summary.values()), color='skyblue')
        plt.xlabel('ADD')
        plt.ylabel('scene_id')
        plt.title('Average Distance of Model Points (ADD) by scene_id')
        plt.tight_layout()
        plt.savefig(os.path.join(args.output_dir, f'{model_name}_add.png'))
        plt.close()
