import numpy as np
import torch
import math
import ipdb
import os
import uncertainty_toolbox as uct
import properscoring as ps
import matplotlib.pyplot as plt

def compute_quantile(scores, alpha):
    """Return the split-conformal quantile of the nonconformity ``scores``.

    Uses the finite-sample corrected level ``ceil((n + 1) * (1 - alpha)) / n``
    with the "higher" quantile method. Under exchangeability, intervals scaled
    by the returned value cover a new sample with probability >= 1 - alpha.

    Args:
        scores: array-like of calibration scores. It is flattened, so ``n``
            counts every element (not just the length of the first axis).
        alpha: miscoverage level, e.g. 0.1 for 90% coverage.

    Returns:
        The quantile as a NumPy scalar, or ``np.inf`` when there are too few
        scores for the requested level (the corrected level exceeds 1).

    Raises:
        ValueError: if ``scores`` is empty.
    """
    scores = np.asarray(scores).ravel()
    n = scores.size
    if n == 0:
        raise ValueError("compute_quantile() requires at least one score")
    q_level = np.ceil((n + 1) * (1 - alpha)) / n
    if q_level > 1:
        # No finite calibration score achieves the guarantee with this few samples.
        return np.inf
    try:
        return np.quantile(scores, q_level, method="higher")
    except TypeError:
        # NumPy < 1.22 only knows the (since deprecated) `interpolation` keyword.
        return np.quantile(scores, q_level, interpolation="higher")

# Input roots. Per-scene predictions are read from
# <data_path>/<scene>/<frame>.npy and <data_path>/<scene>/<frame>_uq.npy;
# LiDAR ground truth is read from <gt_path>/<scene>/lidar_depth/<frame>.npz.
data_path, gt_path = "./kitti/depth_uq/sequences", "./kitti/lidar_depth"

# # scene = "00"
# # scene_path = os.path.join(data_path, scene)
# # file = "000000_uq.npy"
# # file_path = os.path.join(scene_path, file)
# # print(file_path)
# # data = np.load(file_path, allow_pickle=True)
# # data_mean = np.load(os.path.join(scene_path, "000000.npy"), allow_pickle=True)
# # print(data.shape)

# scores = None
# scenes = ["00", "01", "02", "03", "04", "05", "06", "07", "09", "10"]
# for scene in scenes:
#     scene_path = os.path.join(data_path, scene)
#     files_in_folder = os.listdir(scene_path)
#     print(scene)
#     for file in files_in_folder:
#         # file_num = '{:0{width}d}'.format(i, width=6)
#         if file[-3:] != "npy":
#             continue
#         file = file.split(".")[0]
#         if file[-2:] == "uq":
#             continue
#         data_mean = np.load(os.path.join(scene_path, f"{file}.npy"), allow_pickle=True)
#         data_cov = np.load(os.path.join(scene_path, f"{file}_uq.npy"), allow_pickle=True)
#         data_cov = np.exp(data_cov/2)
#         data_gt = np.load(os.path.join(gt_path, scene, "lidar_depth", f"{file}.npz"), allow_pickle=True)
#         data_gt = data_gt['depth_image']
#         data_gt = data_gt.reshape(-1)
#         data_mean = data_mean.reshape(-1)
#         data_cov = data_cov.reshape(-1)
#         index = data_gt != -1
#         data_gt = data_gt[index]
#         data_mean = data_mean[index]
#         data_cov = data_cov[index]
#         score = np.abs(data_gt-data_mean) / data_cov
#         if scores is None:
#             scores = score
#         else:
#             scores = np.concatenate((scores, score))

# confident_range = 0.6827
# quantile = compute_quantile(scores, 1-confident_range)
# print(f"For {confident_range}, quantile: {quantile}")

# confident_range = 0.9545
# quantile = compute_quantile(scores, 1-confident_range)
# print(f"For {confident_range}, quantile: {quantile}")

# confident_range = 0.9973
# quantile = compute_quantile(scores, 1-confident_range)
# print(f"For {confident_range}, quantile: {quantile}")


def nll(mean, std, target):
    """Return the average Gaussian negative log-likelihood of ``target``.

    Each sample is scored under N(mean, std**2) as
    ``0.5*log(2*pi) + log|std| + 0.5*((target - mean) / std)**2``. This equals
    ``0.5*log(2*pi*var) + 0.5*(mean - target)**2 / var`` but never forms
    ``var = std**2``, so very small or very large ``std`` does not underflow or
    overflow into ``log(0)`` / ``0/0`` (nan). Lower is better.

    Args:
        mean: predicted means (array-like).
        std: predicted standard deviations (array-like, broadcastable).
        target: observed values (array-like, broadcastable).

    Returns:
        The mean per-sample NLL as a NumPy float.
    """
    mean = np.asarray(mean, dtype=float)
    std = np.asarray(std, dtype=float)
    target = np.asarray(target, dtype=float)
    z = (target - mean) / std
    # |std| keeps the sign-insensitivity of the original std**2 formulation.
    per_sample = 0.5 * np.log(2 * np.pi) + np.log(np.abs(std)) + 0.5 * z ** 2
    return per_sample.mean()


def crps(mean, std, target):
    """Average continuous ranked probability score of Gaussian forecasts.

    Every forecast N(mean, std**2) is scored against its target with
    ``properscoring.crps_gaussian`` (argument order: observation, mu, sigma),
    and the per-sample scores are averaged. Lower is better.
    """
    return ps.crps_gaussian(target, mean, std).mean()

def ece(mean, std, target, num_bins=100):
    """Return the mean absolute calibration error of Gaussian predictions.

    For each expected coverage ``p`` in ``np.linspace(0.01, 0.99, num_bins)``,
    the central interval ``mean +/- z_p * std`` with
    ``z_p = Phi^{-1}((1 + p) / 2)`` is built. The fraction of targets that fall
    inside it (inclusive) is the observed coverage. The result is the mean of
    ``|observed - p|`` over the grid: 0 means perfectly calibrated, and larger
    values mean worse calibration.

    Args:
        mean: predicted means (array-like).
        std: predicted standard deviations (array-like, broadcastable).
        target: observed values (array-like, broadcastable).
        num_bins: number of expected-coverage levels evaluated.

    Returns:
        The calibration error as a Python float.

    Raises:
        ValueError: if there are no samples.
    """
    from statistics import NormalDist

    mean = np.asarray(mean, dtype=float)
    std = np.asarray(std, dtype=float)
    target = np.asarray(target, dtype=float)
    standardized = np.ravel(np.abs(target - mean) / std)
    if standardized.size == 0:
        raise ValueError("ece() requires at least one sample")

    expected = np.linspace(0.01, 0.99, num_bins)
    standard_normal = NormalDist()
    # Half-width, in units of std, of the central interval with coverage p.
    half_width = np.array([standard_normal.inv_cdf((1 + p) / 2) for p in expected])
    # Sorting once lets searchsorted count covered samples for every level.
    sorted_z = np.sort(standardized)
    observed = np.searchsorted(sorted_z, half_width, side="right") / sorted_z.size
    return float(np.mean(np.abs(observed - expected)))


mean = None
std = None
target = None
scenes = ["08"]
# Scale factor applied to the predicted std when quantile_flag is set.
# NOTE(review): hard-coded from an earlier conformal calibration run,
# presumably compute_quantile at the 0.6827 level — confirm before reuse.
quantile = 1.6550744771957397
quantile_flag = True
# Evaluate at most this many frames per scene (frames, not directory entries).
max_frames_per_scene = 100


def _list_frame_ids(scene_path, limit=None):
    """Return up to ``limit`` sorted ids of frames that have a mean ``.npy`` file.

    Companion ``<id>_uq.npy`` files are skipped. Sorting makes the selection
    reproducible, because ``os.listdir`` order is arbitrary.
    """
    frame_ids = sorted(
        name.split(".")[0]
        for name in os.listdir(scene_path)
        if name.endswith("npy") and not name.split(".")[0].endswith("uq")
    )
    return frame_ids[:limit]


def _load_frame(scene_path, gt_scene_path, frame_id):
    """Load one frame and return ``(pred_mean, pred_std, gt)`` as 1-D arrays.

    Only pixels whose ground-truth depth is not -1 are kept. The ``_uq`` map is
    converted with ``exp(x / 2)``. This assumes it stores a log-variance, so the
    result is a standard deviation — confirm against the predictor.
    """
    pred_mean = np.load(os.path.join(scene_path, f"{frame_id}.npy"), allow_pickle=True)
    pred_uq = np.load(os.path.join(scene_path, f"{frame_id}_uq.npy"), allow_pickle=True)
    pred_std = np.exp(pred_uq / 2)
    # np.load on an .npz returns an NpzFile that holds an open file; close it.
    gt_file = os.path.join(gt_scene_path, f"{frame_id}.npz")
    with np.load(gt_file, allow_pickle=True) as gt_archive:
        gt = gt_archive["depth_image"].reshape(-1)
    valid = gt != -1
    return pred_mean.reshape(-1)[valid], pred_std.reshape(-1)[valid], gt[valid]


mean_parts, std_parts, target_parts = [], [], []
for scene in scenes:
    scene_path = os.path.join(data_path, scene)
    gt_scene_path = os.path.join(gt_path, scene, "lidar_depth")
    print(scene)
    for frame_id in _list_frame_ids(scene_path, limit=max_frames_per_scene):
        print(frame_id)
        frame_mean, frame_std, frame_gt = _load_frame(scene_path, gt_scene_path, frame_id)
        if quantile_flag:
            frame_std = frame_std * quantile
        mean_parts.append(frame_mean)
        std_parts.append(frame_std)
        target_parts.append(frame_gt)

if not mean_parts:
    raise RuntimeError(f"no prediction frames found under {data_path} for scenes {scenes}")

# Concatenate once; concatenating inside the loop copies all prior data each frame.
mean = np.concatenate(mean_parts, axis=0)
std = np.concatenate(std_parts, axis=0)
target = np.concatenate(target_parts, axis=0)

# Plot and score once over all scenes (previously redone per scene on cumulative
# data, overwriting the same file).
uct.viz.plot_calibration(mean, std, target)
plt.gcf().set_size_inches(4, 4)
os.makedirs("figure", exist_ok=True)  # savefig does not create directories
if quantile_flag:
    plt.savefig("figure/depth_cp_calibration.png")
else:
    plt.savefig("figure/depth_wo_cp_calibration.png")
plt.close()

metrics = uct.metrics.get_all_metrics(mean, std, target)
print(metrics)

