import numpy as np
import os
import ipdb
from scipy.special import softmax
import random
import matplotlib.pyplot as plt
import pandas as pd
from utils import *
import copy
from tqdm import tqdm
# Frame lists for the two conformal splits: sequence 06 is used for
# calibration and sequence 09 for validation. (An earlier note described the
# calibration list as a random 30% subset — presumably how 06.json was built;
# confirm.)
import json

data_split_json = {}
for split_name, json_file_path in (
    ('calibration', '/root/autodl-tmp/vox/mmdetection3d/VoxFormer_UQ/06.json'),
    ('validation', '/root/autodl-tmp/vox/mmdetection3d/VoxFormer_UQ/09.json'),
):
    # The file presumably holds a JSON list of frame ids; [:] stores a
    # shallow copy of whatever sequence json.load returns.
    with open(json_file_path, 'r') as file:
        data_split_json[split_name] = json.load(file)[:]


def compute_kl_divergence(pred_dist, ideal_dist, eps=1e-10):
    """Return ``sum_i p_i * log(p_i / q_i)`` along axis 1 of ``pred_dist``.

    This equals KL(pred || ideal) when both arguments are normalised
    distributions. The ``ideal_distribution`` used in this script is not
    normalised, so the value acts as a per-row score. Callers compare it
    against a precomputed quantile.

    Args:
        pred_dist: array with class probabilities along axis 1, presumably
            one row per voxel.
        ideal_dist: reference weights, broadcast against every row. They must
            be strictly positive, otherwise the ratio divides by zero.
        eps: lower clamp applied to ``pred_dist`` so that zero probabilities
            do not produce ``log(0)``. Defaults to 1e-10, the value of the
            module-level ``epsilon``. It is a parameter so the function no
            longer depends on a global defined after it.

    Returns:
        Array with one divergence value per row (axis 1 reduced).
    """
    pred_dist = np.maximum(pred_dist, eps)
    return np.sum(pred_dist * np.log(pred_dist / ideal_dist), axis=1)

# Reference weights for compute_kl_divergence. Every label except 0 gets
# weight 1. Label 0 is clamped to epsilon rather than left at 0, which keeps
# the log-ratio finite. The vector is not normalised to sum to 1.
epsilon = 1e-10
ideal_distribution = np.ones(20, dtype=np.float32)
ideal_distribution[0] = 0.0
ideal_distribution = np.maximum(ideal_distribution, epsilon)

# Root of the VoxFormer uncertainty-quantification workspace.
_UQ_ROOT = "/root/autodl-tmp/vox/mmdetection3d/VoxFormer_UQ"

# Calibration split: VoxFormer predictions for sequence 06 and the matching
# KITTI-360 drive 0006 labels.
cal_output_path = f"{_UQ_ROOT}/voxformer_360_1355_tg/sequences/06/predictions/"
cal_target_path = f"{_UQ_ROOT}/kitti360/preprocess_new/labels/2013_05_28_drive_0006_sync/"

# Validation split: sequence 09 predictions and drive 0009 labels.
val_output_path = f"{_UQ_ROOT}/voxformer_360_1355_tg/sequences/09/predictions/"
val_target_path = f"{_UQ_ROOT}/kitti360/preprocess_new/labels/2013_05_28_drive_0009_sync/"

# Directory for every artefact this script writes; created if missing.
results_path = "./kitti360/ours/sem_1"
os.makedirs(results_path, exist_ok=True)

geometry_quantile_file = "./kitti360/ours/cp1/class_balanced_entropy_quantile.npy"

# Number of rounds. Round r selects the r-th slice of the precomputed
# geometry quantiles.
R = 1

# Grid of candidate miscoverage levels, from 0.99 down towards 0 in steps of
# 0.01 (the stop value is excluded).
alphas = np.arange(0.99, 0.00, -0.01)

# Per-round, per-label confusion counts (tp / fp / fn / tn) in a 21-column
# layout. class_nums is then reduced to 20, the label range the loops iterate
# over. np.int was removed in NumPy 1.24 (AttributeError). It aliased the
# platform int, which is int64 on Linux, so int64 is used explicitly; it is
# also wide enough for voxel counts accumulated over many frames.
class_nums = 21
tps = np.zeros((R, class_nums), dtype=np.int64)
fps = np.zeros((R, class_nums), dtype=np.int64)
fns = np.zeros((R, class_nums), dtype=np.int64)
tns = np.zeros((R, class_nums), dtype=np.int64)
class_nums = class_nums - 1
# Label(s) given special treatment; label 6 is the "person" class tied to
# person_alpha. NOTE(review): not referenced elsewhere in the visible code.
considered_class = [6]

# Per-label target miscoverage rates (percent / 100), indexed by label id.
target_aplha = np.array([
    6, 39, 36, 37, 43, 41, 76, 94, 17, 7, 7,
    10, 18, 37, 31, 41, 49, 14, 65, 76, 30,
]) / 100

store_result_name = "cp_class_based_entropy_results_only_rare"
store_result_file = f"{store_result_name}.npy"
person_alpha = 0.7

# Integer percentage embedded in the output file names. round() is used
# instead of int() because binary floating point makes e.g.
# int(0.29 * 100) == 28, which would silently produce the wrong file name.
person_alpha_pct = round(person_alpha * 100)
store_semantic_score_file = f"semantic_scores_for_person_{person_alpha_pct}.npy"
store_semantic_quantile_file = f"semantic_quantile_for_person_{person_alpha_pct}.npy"
store_coverage_result_file = f"coverage_for_HCP_person{person_alpha_pct}.xlsx"

# Precomputed geometry quantiles. Later code indexes them as
# [alpha index, round, label].
occupied_quantiles = np.load(geometry_quantile_file)

# Position of person_alpha on the alphas grid. It selects the first axis of
# occupied_quantiles. Fail loudly instead of with an opaque IndexError when
# the value is not on the grid.
_alpha_matches = np.flatnonzero(np.isclose(person_alpha, alphas))
if _alpha_matches.size == 0:
    raise ValueError(f"person_alpha={person_alpha} is not on the alphas grid")
index_per = _alpha_matches[0]

# Calibration stage.
# For every label, collect non-conformity scores (1 - predicted probability of
# the true label) on calibration voxels that pass the geometry (KL-divergence)
# filter, then turn them into per-label semantic quantiles. Only round 0 is
# calibrated: the quantiles go to a single file that validation reuses.
for r in range(0, 1):
    print(f"The {r} Round")
    scores = [None] * class_nums
    cal_files = data_split_json['calibration']

    # Geometry threshold for this round, taken from the label-6 (person) column
    # of the precomputed quantiles. It does not depend on the file, so it is
    # looked up once.
    person_quantile = occupied_quantiles[index_per, r, 6]

    # ---- generate scores ---------------------------------------------------
    for file in tqdm(cal_files):
        target, output = get_output(cal_output_path, cal_target_path, file)

        skl = compute_kl_divergence(output, ideal_distribution)
        filtered_index = skl <= person_quantile
        for cls in range(1, class_nums):
            index = target == cls
            # Voxels of this label that pass / fail the geometry filter.
            tps_index = np.logical_and(index, filtered_index)
            fns_index = np.logical_and(index, ~filtered_index)
            tps[r, cls] += tps_index.sum()
            fns[r, cls] += fns_index.sum()

            # Non-conformity score of the true label on the passing voxels.
            filtered_scores = 1 - output[tps_index, cls]

            if scores[cls] is None:
                scores[cls] = filtered_scores
            else:
                scores[cls] = np.concatenate((scores[cls], filtered_scores), axis=0)

    # Shape (4, R, 21): tp / fp / fn / tn stacked along a new leading axis.
    store_data = np.stack((tps, fps, fns, tns), axis=0)
    np.save(os.path.join(results_path, store_result_file), store_data)

    # Save the scores. The per-label arrays differ in length and scores[0] is
    # None, so build a 1-D object array explicitly. Passing the ragged list to
    # np.save raises ValueError on NumPy >= 1.24.
    scores_to_store = np.empty(len(scores), dtype=object)
    for cls_idx, cls_scores in enumerate(scores):
        scores_to_store[cls_idx] = cls_scores
    np.save(os.path.join(results_path, store_semantic_score_file), scores_to_store, allow_pickle=True)

    # ---- per-label semantic quantiles --------------------------------------
    tps, fps, fns, tns = read_previous_occupanied_results(os.path.join(results_path, store_result_file))
    scores = np.load(os.path.join(results_path, store_semantic_score_file), allow_pickle=True)

    semantic_quantiles = np.zeros(class_nums)
    for cls in range(1, class_nums):
        if cls != 6:
            # Empirical fraction of this label's voxels dropped by the
            # geometry filter.
            alpha_oy = 1 - (tps[r, cls]) / (tps[r, cls] + fns[r, cls] + 1e-20)
        else:
            # Label 6 uses the level its geometry quantile was selected for.
            alpha_oy = person_alpha

        alpha_oy = max(alpha_oy, 1e-20)
        # Target miscoverage: 1.1 * alpha_oy + 0.05, truncated to 2 decimals.
        target_aplha_class = int(alpha_oy * 1.1 * 100 + 5) / 100
        target_aplha[cls] = target_aplha_class
        print("target_aplha_class", target_aplha_class)

        # Semantic budget left after the geometry stage, chosen so that
        # (1 - alpha_oy) * (1 - alpha_sy) == 1 - target_aplha_class.
        # NOTE(review): a label absent from calibration has alpha_oy == 1, so
        # this divides by zero (NumPy returns inf with a RuntimeWarning).
        alpha_sy = 1 - (1 - target_aplha_class) / (1 - alpha_oy)

        alpha_sy = max(alpha_sy, 1e-10)

        print(f"For class {cls}: alpha_y {target_aplha[cls]}, alpha_oy {alpha_oy}, alpha_sy {alpha_sy}")
        print(scores[cls])
        semantic_quantiles[cls] = compute_quantile(scores[cls], alpha_sy)
    np.save(os.path.join(results_path, store_semantic_quantile_file), semantic_quantiles)

# # ############################################################################################################################

# Validation stage.
# Apply the geometry filter and the per-label semantic quantiles to the
# validation frames, then report per-label coverage, the mean coverage gap
# and the average prediction-set size.
semantic_quantiles = np.load(os.path.join(results_path, store_semantic_quantile_file))

for r in range(0, R):
    print(f"The {r} Round")

    # Reset the accumulators at the start of every round so that each round's
    # AvgSize / coverage only reflect that round's geometry threshold.
    all_set_sizes = 0
    voxels_nums = 0
    class_group_sizes = np.zeros(class_nums)
    class_coverage_sizes = np.zeros(class_nums)

    # Geometry threshold for this round (label-6 column); file-independent.
    person_quantile = occupied_quantiles[index_per, r, 6]

    val_files = data_split_json['validation']
    for file in tqdm(val_files):
        target, output = get_output(val_output_path, val_target_path, file)

        skl = compute_kl_divergence(output, ideal_distribution)
        filtered_index = skl <= person_quantile
        occupied_output = output[filtered_index]
        occupied_target = target[filtered_index]
        # Label c enters a voxel's prediction set when 1 - output[:, c] is
        # within the calibrated quantile for c.
        prediction_sets = (1 - occupied_output) <= semantic_quantiles

        # Set sizes count labels 1.. only (column 0 excluded). The denominator
        # counts every voxel of the frame, whether filtered or not.
        nonempty_prediction_sets = prediction_sets[:, 1:]
        voxels_nums += output.shape[0]
        all_set_sizes += nonempty_prediction_sets.sum()

        for cls in range(1, class_nums):
            # All ground-truth voxels of the label form the denominator. Voxels
            # removed by the geometry filter therefore count as not covered.
            class_group_sizes[cls] += (target == cls).sum()

            class_prediction_sets = prediction_sets[occupied_target == cls]
            class_coverage_sizes[cls] += class_prediction_sets[:, cls].sum()

    AvgSize = all_set_sizes / voxels_nums

    empirical_coverage = np.zeros(class_nums)
    coverage_gap = np.zeros(class_nums)
    excel_list = {}

    for cls in range(1, class_nums):
        empirical_coverage[cls] = class_coverage_sizes[cls] / (class_group_sizes[cls] + 1e-20)
        coverage_gap[cls] = abs(empirical_coverage[cls] - (1 - target_aplha[cls]))
        print(f"For class {cls}: empirical_coverage: {empirical_coverage[cls]}, coverage_gap: {coverage_gap[cls]}")
        excel_list[class_name_dict[cls]] = [empirical_coverage[cls], coverage_gap[cls], 1 - target_aplha[cls]]

    CovGap = coverage_gap[1:].mean()
    print(f"For all, AvgSize {AvgSize}, CovGap {CovGap}")
    excel_list["CovGap"] = [CovGap, -1, -1]
    excel_list["AvgSize"] = [AvgSize, -1, -1]
    pd_data = pd.DataFrame(excel_list)
    pd_data.to_excel(os.path.join(results_path, store_coverage_result_file), index=False)

# # # # ############################################################################################################################
