import numpy as np
import os
import ipdb
from scipy.special import softmax
import random
import matplotlib.pyplot as plt
import pandas as pd
from utils import *
import copy

# Load the per-split file lists consumed by the calibration and evaluation
# loops below: the calibration split is read from 06.json and the validation
# split from 09.json.  Both come from fixed JSON files; nothing here samples
# a random subset.
import json

data_split_json = {}
for split_name, json_file_path in (
    ('calibration', '/root/autodl-tmp/vox/mmdetection3d/VoxFormer_UQ/06.json'),
    ('validation', '/root/autodl-tmp/vox/mmdetection3d/VoxFormer_UQ/09.json'),
):
    with open(json_file_path, 'r') as file:
        data_split_json[split_name] = json.load(file)

def compute_kl_divergence(pred_dist, ideal_dist, epsilon=1e-10):
    """Return the per-row score ``sum_k p_k * log(p_k / q_k)``.

    This is the score that the calibration and evaluation loops threshold.

    Args:
        pred_dist: 2-D array of per-row predicted class scores; the sum runs
            over axis 1, so the result has one value per row.
        ideal_dist: reference distribution broadcastable against one row of
            ``pred_dist``.  It does not have to be normalised (the module's
            ``ideal_distribution`` is 1.0 for every non-empty class).  In that
            case the value is a KL divergence only up to that normalisation.
            Entries should be > 0 to keep the log finite.
        epsilon: floor applied to ``pred_dist`` before taking the log, so that
            zero entries do not produce ``0 * log(0) = nan``.  It defaults to
            the same 1e-10 the module used before, which was previously read
            from a global defined after this function.

    Returns:
        1-D array of scores, one per row of ``pred_dist``.
    """
    pred_dist = np.maximum(pred_dist, epsilon)
    return np.sum(pred_dist * np.log(pred_dist / ideal_dist), axis=1)

# Reference "distribution" for the KL score: 1.0 for each of the 19 non-empty
# classes and a tiny floor for class 0 (empty), so the log never sees zero.
# It is intentionally left unnormalised.
epsilon = 1e-10
ideal_distribution = np.full(20, epsilon, dtype=np.float32)
ideal_distribution[1:] = 1.0

# Prediction / label locations: sequence 09 is the validation split and
# sequence 06 the calibration split (matching 09.json / 06.json above).
val_output_path = "/root/autodl-tmp/vox/mmdetection3d/VoxFormer_UQ/voxformer_360_1355_tg/sequences/09/predictions/"
val_target_path = "/root/autodl-tmp/vox/mmdetection3d/VoxFormer_UQ/kitti360/preprocess_new/labels/2013_05_28_drive_0009_sync/"

cal_output_path = "/root/autodl-tmp/vox/mmdetection3d/VoxFormer_UQ/voxformer_360_1355_tg/sequences/06/predictions/"
cal_target_path = "/root/autodl-tmp/vox/mmdetection3d/VoxFormer_UQ/kitti360/preprocess_new/labels/2013_05_28_drive_0006_sync/"
results_path = "./kitti360/ours/cp1"

quantile_file = "class_balanced_entropy_quantile.npy"
R = 3  # number of evaluation rounds
# Levels passed to compute_quantile: 0.99, 0.98, ..., ~0.01 (float steps).
alphas = np.arange(0.99, 0.0, -0.01)

# Confusion counters indexed [alpha, round, class].  There are 21 class
# columns.  The evaluation loop uses column 0 for "empty", 1..19 for the
# semantic classes and 20 for overall occupancy.
class_nums = 21
# dtype=int instead of the np.int alias: np.int was just the builtin int
# (same array dtype) and was removed in NumPy 1.24, where it raises
# AttributeError.
tps = np.zeros((len(alphas), R, class_nums), dtype=int)
fps = np.zeros((len(alphas), R, class_nums), dtype=int)
fns = np.zeros((len(alphas), R, class_nums), dtype=int)
tns = np.zeros((len(alphas), R, class_nums), dtype=int)
# From here on class_nums is the number of model output channels (20, the
# length of ideal_distribution); index class_nums is the occupancy column.
class_nums = class_nums - 1
# Classes whose calibrated thresholds define the prediction set.
considered_class = [6]


store_result_name = "cp_class_based_entropy_results_only_rare"
store_result_file = f"{store_result_name}.npy"


# Calibration: per-class quantiles of the KL score over the calibration files.
#   occupied_quantiles[i, r, cls] = compute_quantile(scores of voxels whose
#                                   label is cls, alphas[i])
# The calibration split is fixed (no resampling), so every round would see
# identical data.  The quantiles are computed once and copied into all R
# rounds.  Previously only round 0 was filled and rounds 1..R-1 stayed zero,
# even though the evaluation loop reads occupied_quantiles[i, r] for every
# r < R.
from tqdm import tqdm

occupied_quantiles = np.zeros((len(alphas), R, class_nums))
cal_files = data_split_json['calibration']
print("The 0 Round")

# Collect the per-file score arrays and concatenate once at the end.
# Growing one array with np.concatenate per file re-copies it every time.
score_chunks = [[] for _ in range(class_nums)]
for file in tqdm(cal_files):
    target, output = get_output(cal_output_path, cal_target_path, file)
    # The score is computed independently per row, so scoring the whole file
    # once and masking per class yields the same values as scoring each
    # class subset separately.
    file_scores = compute_kl_divergence(output, ideal_distribution)
    for cls in range(1, class_nums):
        score_chunks[cls].append(file_scores[target == cls])

for cls in range(1, class_nums):
    chunks = score_chunks[cls]
    # `chunks` is empty only when there are no calibration files.  The old
    # code crashed there with len(None).
    cls_scores = np.concatenate(chunks, axis=0) if chunks else np.empty(0)
    if cls_scores.size == 0:
        # No calibration voxel carries this label: fall back to a single huge
        # sentinel score (same value as before).
        cls_scores = 1e7 * np.ones(1)
    for i, alpha in enumerate(alphas):
        occupied_quantiles[i, :, cls] = compute_quantile(cls_scores, alpha)

os.makedirs(results_path, exist_ok=True)
np.save(os.path.join(results_path, quantile_file), occupied_quantiles)

# Re-read from disk, presumably so the calibration above can be skipped and a
# previously saved quantile file reused.
occupied_quantiles = np.load(os.path.join(results_path, quantile_file))

# # #########################################################################################################

# Evaluation: for every round r and level alphas[i], threshold the per-voxel
# KL score with the calibrated quantiles and accumulate confusion counts into
# tps/fps/fns/tns[i, r, cls].
#   * cls in 1..class_nums-1 (semantic classes):
#       - TP/FN are counted against the occupancy prediction (any class).
#       - FP/TN are counted over ground-truth-empty voxels against the
#         predicted class.
#   * cls 0: "empty" is the positive label.
#   * cls class_nums (20): "occupied" (any target != 0) is the positive label.
from tqdm import tqdm

files = data_split_json['validation']
for r in range(0, R):
    print(f"The {r} Round")

    for file in tqdm(files):
        print(file)
        target, output = get_output(val_output_path, val_target_path, file)

        # Everything below up to the alpha loop depends only on the file.
        y_empty = (target == 0)
        n_empty = np.count_nonzero(y_empty)
        n_nonempty = y_empty.size - n_empty
        class_masks = [None] + [target == cls for cls in range(1, class_nums)]
        class_counts = [0] + [np.count_nonzero(mask) for mask in class_masks[1:]]
        kl_scores = compute_kl_divergence(output, ideal_distribution)
        # Most likely non-empty class per voxel (channel 0 is excluded).
        best_class = np.argmax(output[:, 1:], axis=1) + 1

        for i in range(len(alphas)):
            occupied_thresholds = handle_thresholds(occupied_quantiles[i, r], considered_class)
            threshold = np.array(occupied_thresholds[considered_class]).max()

            # Voxels whose score is within the threshold are predicted
            # occupied and labelled with their arg-max non-empty class; all
            # other voxels are predicted empty (class 0).
            occupancy_predict = kl_scores <= threshold
            class_predict = np.where(occupancy_predict, best_class, 0)

            for cls in range(1, class_nums):
                y_true = class_masks[cls]
                y_pred = (class_predict == cls)
                tp = np.count_nonzero(y_true & occupancy_predict)
                fp = np.count_nonzero(y_empty & y_pred)
                tps[i, r, cls] += tp
                fps[i, r, cls] += fp
                fns[i, r, cls] += class_counts[cls] - tp  # y_true & ~occupied
                tns[i, r, cls] += n_empty - fp            # y_empty & ~y_pred

            # The four cells of the empty/occupied confusion matrix.
            occ_empty = np.count_nonzero(y_empty & occupancy_predict)
            occ_nonempty = np.count_nonzero(occupancy_predict) - occ_empty
            free_empty = n_empty - occ_empty
            free_nonempty = n_nonempty - occ_nonempty

            # Column 0: "empty" is the positive label.
            tps[i, r, 0] += free_empty
            fps[i, r, 0] += free_nonempty
            fns[i, r, 0] += occ_empty
            tns[i, r, 0] += occ_nonempty

            # Column class_nums (20): "occupied" is the positive label.
            tps[i, r, class_nums] += occ_nonempty
            fps[i, r, class_nums] += occ_empty
            fns[i, r, class_nums] += free_nonempty
            tns[i, r, class_nums] += free_empty

    # Checkpoint after every round.  The saved array has shape
    # (4, len(alphas), R, class_nums + 1) and stacks TP, FP, FN, TN in that
    # order.
    store_data = np.stack((tps, fps, fns, tns), axis=0)
    np.save(os.path.join(results_path, store_result_file), store_data)
#######################################################################################################
# Re-read the results file written by the evaluation loop, presumably so the
# export below can run from saved results (assumed to hold the stacked
# TP/FP/FN/TN counts; see utils.read_previous_occupanied_results).
tps, fps, fns, tns = read_previous_occupanied_results(
    os.path.join(results_path, store_result_file),
)

def getAverage(data):
    """Return the mean taken over every element of the NumPy array ``data``.

    Used to average one metric over the evaluation rounds.
    """
    mean_value = data.mean()
    return mean_value

# Export one spreadsheet per alpha level.  Each column is a class name (keys
# taken from class_name_dict for classes 0..class_nums).  The five rows are
# accuracy, precision, recall, F1 and IoU, averaged over the first R rounds.
accuracy, precision, recall, f1_score, IoU = compute_metrics(tps, fps, fns, tns)
R = 3  # number of rounds averaged
metric_arrays = (accuracy, precision, recall, f1_score, IoU)
for i, alpha in enumerate(alphas):
    excel_list = {
        class_name_dict[cls]: [getAverage(metric[i, :R, cls]) for metric in metric_arrays]
        for cls in range(class_nums + 1)
    }
    pd_data = pd.DataFrame(excel_list)
    # round(), not int(): alphas come from np.arange with a float step, so
    # alpha * 100 can land just below an integer.  Truncating it would give
    # two alphas the same file name, and one spreadsheet would silently
    # overwrite the other.
    alpha_pct = int(round(float(alpha) * 100))
    pd_data.to_excel(os.path.join(results_path, f"{store_result_name}_{alpha_pct}.xlsx"), index=False)

# #########################################################################################################
