import json
import os
import glob

import hydra
import numpy as np
from scipy.spatial import distance


class MetricsComputer:
    """Aggregate per-partition evaluation results and compute summary metrics.

    On construction, every ``partition*.json`` file found in the
    ``scene_results`` directory that sits next to ``cfg.eval.model_path`` is
    loaded, and the values stored under each result key are concatenated
    across files into NumPy arrays. ``compute_metrics`` then reduces those
    arrays to scalar metrics.

    Each JSON file must provide the keys ``ade``, ``fde``, ``goal_success``,
    ``collision``, ``off_road`` and, for each of ``lin_speed``, ``ang_speed``,
    ``accel`` and ``nearest_dist``, a ``*_sim`` and a ``*_gt`` entry.
    """

    def __init__(self, cfg):
        """Load and concatenate all partition result files.

        Args:
            cfg: Configuration object. Reads ``cfg.eval.model_path`` to locate
                the results directory and ``cfg.datasets.rl_waymo`` for the
                acceleration settings used later by ``compute_metrics``.

        Raises:
            ValueError: If no ``partition*.json`` file is found.
            KeyError: If a partition file lacks one of the expected keys.
        """
        self.cfg = cfg
        self.cfg_rl_waymo = self.cfg.datasets.rl_waymo
        # Results are stored in a "scene_results" directory beside the model
        # file. os.path.dirname keeps a leading "/" for paths such as
        # "/model.ckpt", which manual split/join on "/" would drop.
        output_dir = os.path.join(os.path.dirname(self.cfg.eval.model_path), "scene_results")
        # Sort so the aggregation order does not depend on filesystem order.
        self.all_files = sorted(glob.glob(os.path.join(output_dir, "partition*.json")))
        self.ades_all = []
        self.fdes_all = []
        self.goal_achieved_all = []
        # distributional realism metrics
        self.lin_speed_sim_all = []
        self.lin_speed_gt_all = []
        self.ang_speed_sim_all = []
        self.ang_speed_gt_all = []
        self.accel_sim_all = []
        self.accel_gt_all = []
        self.nearest_dist_sim_all = []
        self.nearest_dist_gt_all = []
        # common sense metrics
        self.collision_rate_scenario = []
        self.offroad_rate_scenario = []
        print(self.all_files)
        if not self.all_files:
            # Fail with an actionable message instead of the generic
            # "need at least one array to concatenate" from NumPy below.
            raise ValueError(
                "No partition*.json result files found in '{}'".format(output_dir)
            )

        for json_file in self.all_files:
            with open(json_file, 'r') as f:
                data = json.load(f)
            self.ades_all.append(data['ade'])
            self.fdes_all.append(data['fde'])
            self.goal_achieved_all.append(data["goal_success"])
            # Each of these entries is a list of arrays that are joined along
            # their first axis.
            self.lin_speed_sim_all.append(np.concatenate(data["lin_speed_sim"]))
            self.lin_speed_gt_all.append(np.concatenate(data['lin_speed_gt']))
            self.ang_speed_sim_all.append(np.concatenate(data['ang_speed_sim']))
            self.ang_speed_gt_all.append(np.concatenate(data["ang_speed_gt"]))
            # Acceleration entries are flattened element by element, so their
            # items may differ in shape.
            self.accel_sim_all.append(self.nested_data_handler(data["accel_sim"]))
            self.accel_gt_all.append(self.nested_data_handler(data["accel_gt"]))
            self.nearest_dist_sim_all.append(np.concatenate(data["nearest_dist_sim"]))
            self.nearest_dist_gt_all.append(np.concatenate(data["nearest_dist_gt"]))
            self.collision_rate_scenario.append(data["collision"])
            self.offroad_rate_scenario.append(data["off_road"])

        # Merge the per-file pieces into one array per quantity.
        self.ades_all = np.concatenate(self.ades_all, axis=0)
        self.fdes_all = np.concatenate(self.fdes_all, axis=0)
        self.goal_achieved_all = np.concatenate(self.goal_achieved_all)
        self.lin_speed_sim_all = np.concatenate(self.lin_speed_sim_all)
        self.lin_speed_gt_all = np.concatenate(self.lin_speed_gt_all)
        self.ang_speed_sim_all = np.concatenate(self.ang_speed_sim_all)
        self.ang_speed_gt_all = np.concatenate(self.ang_speed_gt_all)
        self.accel_sim_all = np.concatenate(self.accel_sim_all)
        self.accel_gt_all = np.concatenate(self.accel_gt_all)
        self.nearest_dist_sim_all = np.concatenate(self.nearest_dist_sim_all)
        self.nearest_dist_gt_all = np.concatenate(self.nearest_dist_gt_all)
        self.collision_rate_scenario = np.concatenate(self.collision_rate_scenario, axis=0)
        self.offroad_rate_scenario = np.concatenate(self.offroad_rate_scenario, axis=0)

    def nested_data_handler(self, data):
        """Flatten every item of ``data`` and join them into one 1-D array.

        Args:
            data: Iterable of array-likes; each item is converted with
                ``np.array`` and reshaped to one dimension.

        Returns:
            A 1-D NumPy array holding all items' values in order.

        Raises:
            ValueError: If ``data`` is empty (nothing to concatenate).
        """
        return np.concatenate([np.array(d).reshape(-1) for d in data])

    @staticmethod
    def _histogram_jsd(sim_values, gt_values, bin_edges, clip_range=None):
        """Return the Jensen-Shannon distance between two value histograms.

        Both inputs are flattened, optionally clipped to ``clip_range``
        (a ``(low, high)`` pair), binned with ``bin_edges`` and divided by
        their sample counts before being compared.
        """
        sim_values = np.asarray(sim_values).reshape(-1)
        gt_values = np.asarray(gt_values).reshape(-1)
        if clip_range is not None:
            low, high = clip_range
            sim_values = np.clip(sim_values, low, high)
            gt_values = np.clip(gt_values, low, high)
        sim_hist = np.histogram(sim_values, bins=bin_edges)[0] / sim_values.size
        gt_hist = np.histogram(gt_values, bins=bin_edges)[0] / gt_values.size
        # scipy's jensenshannon returns the JS *distance* (square root of the
        # divergence) and normalises its inputs to sum to one.
        return distance.jensenshannon(sim_hist, gt_hist)

    def _quantize_accels(self, accels):
        """Snap accelerations to the configured discretisation grid.

        Values are clipped to ``[min_accel, max_accel]``, rounded to the
        nearest of ``accel_discretization`` evenly spaced levels, and mapped
        back to acceleration units.
        """
        min_accel = self.cfg_rl_waymo.min_accel
        max_accel = self.cfg_rl_waymo.max_accel
        accel_span = max_accel - min_accel
        last_level = self.cfg_rl_waymo.accel_discretization - 1
        unit = (np.clip(accels, a_min=min_accel, a_max=max_accel) - min_accel) / accel_span
        unit = np.round(unit * last_level) / last_level
        return (unit * accel_span) + min_accel

    def compute_metrics(self):
        """Reduce the loaded results to scalar metrics.

        Returns:
            A tuple ``(metrics_dict, metrics_str)``. ``metrics_dict`` maps the
            metric names ``goal``, ``collision_rate``, ``offroad_rate``,
            ``fde``, ``ade``, ``lin_speed_jsd``, ``ang_speed_jsd``,
            ``accel_jsd`` and ``nearest_dist_jsd`` to their values;
            ``metrics_str`` is a list of ``"name: value"`` strings with six
            decimals, in the same order.
        """
        metrics_dict = {}

        metrics_dict['goal'] = np.asarray(self.goal_achieved_all).mean()
        metrics_dict['collision_rate'] = np.asarray(self.collision_rate_scenario).mean()
        metrics_dict['offroad_rate'] = np.asarray(self.offroad_rate_scenario).mean()

        metrics_dict['fde'] = np.asarray(self.fdes_all).mean()
        metrics_dict['ade'] = np.asarray(self.ades_all).mean()

        # lin speed jsd
        # NOTE(review): these edges span 0 to ~333 while values are clipped to
        # [0, 30], so only the first ~18 bins can be populated. The factor
        # may have been meant as (30 / 100); confirm before changing, since
        # it alters reported values.
        metrics_dict['lin_speed_jsd'] = self._histogram_jsd(
            self.lin_speed_sim_all,
            self.lin_speed_gt_all,
            bin_edges=np.arange(201) * 0.5 * (100 / 30),
            clip_range=(0, 30),
        )

        # ang speed jsd: 200 bins of width 0.5 covering [-50, 50]
        metrics_dict['ang_speed_jsd'] = self._histogram_jsd(
            self.ang_speed_sim_all,
            self.ang_speed_gt_all,
            bin_edges=np.arange(201) * 0.5 - 50,
            clip_range=(-50, 50),
        )

        # accel jsd
        # discretize then undiscretize gt actions; simulated accelerations
        # are histogrammed as stored.
        accel_levels = self.cfg_rl_waymo.accel_discretization
        metrics_dict['accel_jsd'] = self._histogram_jsd(
            self.accel_sim_all,
            self._quantize_accels(self.accel_gt_all),
            # Edges run from -accel_discretization to +accel_discretization
            # in steps of 2.
            bin_edges=np.arange(accel_levels + 1) * 2 - accel_levels,
        )

        # nearest dist jsd
        # NOTE(review): edges span 0 to 250 while values are clipped to
        # [0, 40]; as with linear speed, the factor may be inverted - confirm.
        metrics_dict['nearest_dist_jsd'] = self._histogram_jsd(
            self.nearest_dist_sim_all,
            self.nearest_dist_gt_all,
            bin_edges=np.arange(201) * 0.5 * (100 / 40),
            clip_range=(0, 40),
        )
        return metrics_dict, ["{}: {:.6f}".format(k, v) for (k, v) in metrics_dict.items()]
    


@hydra.main(version_base=None, config_path="/home/git_repos/ctrl-sim-dev/cfgs/", config_name="config")
def main(cfg):
    """Build a MetricsComputer from the Hydra config and print its formatted metrics."""
    # Only the list of formatted strings is printed; the raw dict is unused here.
    _, formatted_metrics = MetricsComputer(cfg).compute_metrics()
    print(formatted_metrics)

# Run the metrics computation only when this file is executed as a script;
# the hydra.main decorator on main() supplies the cfg argument.
if __name__ == "__main__":
    main()