import pickle
import time

import numpy as np
try:
    import kornia
except:
    pass
import torch
import tqdm

from pcdet.models import load_data_to_gpu
from pcdet.utils import common_utils
from pcdet.utils.spconv_utils import spconv

import torch.nn as nn
import torch.nn.functional as F

from pcdet.ops.chamfer_distance import ChamferDistance
from pcdet.ops.Density_aware_Chamfer_Distance.utils.model_utils import calc_cd_full
from pathlib import Path

def load_data_to_gpu(batch_dict):
    """Convert the numpy arrays held in ``batch_dict`` to CUDA tensors, in place.

    Values that are not ``np.ndarray`` and the keys ``frame_id``, ``metadata``
    and ``calib`` are left untouched. ``images`` is converted with
    ``kornia.image_to_tensor`` (then cast to float and made contiguous),
    ``image_shape`` becomes an int tensor, and every other array becomes a
    float tensor.

    Note: ``kornia`` is imported optionally at the top of this file, so an
    ``images`` entry raises ``NameError`` when kornia is not installed.
    """
    untouched_keys = ("frame_id", "metadata", "calib")
    for name, array in batch_dict.items():
        # Reassigning an existing key does not change the dict size, so it is
        # safe while iterating.
        if isinstance(array, np.ndarray) and name not in untouched_keys:
            if name == "images":
                batch_dict[name] = kornia.image_to_tensor(array).float().cuda().contiguous()
            elif name == "image_shape":
                batch_dict[name] = torch.from_numpy(array).int().cuda()
            else:
                batch_dict[name] = torch.from_numpy(array).float().cuda()


def cls_filtering(reconstruction_binary, reconstruction):
    """Keep only the voxels of ``reconstruction`` selected by ``reconstruction_binary``.

    A voxel is kept when channel 0 of ``reconstruction_binary.features`` is
    strictly positive (presumably an occupancy logit, i.e. sigmoid > 0.5 —
    TODO confirm against the inversion model).

    Args:
        reconstruction_binary: sparse tensor whose feature rows are assumed to
            be aligned one-to-one with ``reconstruction``'s rows (same indices,
            same order) — NOTE(review): confirm with the inversion model.
        reconstruction: sparse tensor to filter.

    Returns:
        A new ``spconv.SparseConvTensor`` with the surviving features/indices
        and the same spatial shape and batch size as ``reconstruction``.
    """
    keep = reconstruction_binary.features[:, 0] > 0
    return spconv.SparseConvTensor(
        features=reconstruction.features[keep],
        indices=reconstruction.indices[keep].int(),
        spatial_shape=reconstruction.spatial_shape,
        # ``dense().shape[0]`` was used before. It equals ``batch_size`` but
        # materialises the whole dense volume just to read one number.
        batch_size=reconstruction.batch_size,
    )


def _voxel_center_grid(point_cloud_range, grid_size, num_z, voxel_size_xy=0.05, voxel_size_z=0.1):
    """Build the grid of voxel-centre coordinates used to turn occupancy into points.

    Args:
        point_cloud_range: ``[x_min, y_min, z_min, x_max, y_max, z_max]``.
        grid_size: sequence whose first two entries are the x and y cell counts.
        num_z: number of z cells in the dense reconstruction.
        voxel_size_xy, voxel_size_z: voxel edge lengths. The defaults are the
            constants the script has always hard-coded; they are NOT read from
            the config — NOTE(review): keep in sync with the dataset config.

    Returns:
        CPU float32 tensor of shape ``(1, 3, num_z, ny, nx)``; channels 0/1/2
        hold the x/y/z coordinate of each voxel centre.
    """
    nx, ny = int(grid_size[0]), int(grid_size[1])
    xs = np.linspace(point_cloud_range[0] + voxel_size_xy / 2, point_cloud_range[3] - voxel_size_xy / 2, num=nx)
    ys = np.linspace(point_cloud_range[1] + voxel_size_xy / 2, point_cloud_range[4] - voxel_size_xy / 2, num=ny)
    # The z range is extended by one voxel, matching the extra z cell that
    # ``sparse_shape`` adds on top of GRID_SIZE.
    zs = np.linspace(
        point_cloud_range[2] + voxel_size_z / 2,
        (point_cloud_range[5] + voxel_size_z) - voxel_size_z / 2,
        num=num_z,
    )
    X, Y, Z = np.meshgrid(xs, ys, zs)  # each (ny, nx, num_z) with 'xy' indexing
    grid = np.stack([c.transpose(2, 0, 1) for c in (X, Y, Z)], axis=0)  # (3, num_z, ny, nx)
    return torch.as_tensor(grid, dtype=torch.float32).unsqueeze(0)


def _nonzero_rows(points):
    """Return the rows of an ``(N, 3)`` tensor that have at least one non-zero entry."""
    return points[(points != 0).any(dim=1)]


def _invert_features(start_layer, models, encoded, x_conv4, x_conv3, x_conv2):
    """Run the inversion decoders from ``start_layer`` down and return dense logits.

    ``models`` is ``(model_out, model_4, model_3, model_2)``. ``start_layer``
    2/3/4 starts from ``x_conv2``/``x_conv3``/``x_conv4``; any other value runs
    the full chain from ``encoded``. Only the stages that contribute to the
    result are executed.
    """
    model_out, model_4, model_3, model_2 = models
    if start_layer == 2:
        return model_2(x_conv2).dense()
    if start_layer == 3:
        binary_2, recon_2 = model_3(x_conv3)
    else:
        if start_layer == 4:
            binary_3, recon_3 = model_4(x_conv4)
        else:
            binary_4, recon_4 = model_out(encoded)
            binary_3, recon_3 = model_4(cls_filtering(binary_4, recon_4))
        binary_2, recon_2 = model_3(cls_filtering(binary_3, recon_3))
    return model_2(cls_filtering(binary_2, recon_2)).dense()


def _save_metrics(cfg, start_layer, metrics):
    """Save each ``(suffix, values)`` pair as ``<LAYER>_from_xconv_<start_layer>_<suffix>.npy``.

    Files go to ``<this file>/../../output/inversion_attack_spconv/evaluation/
    <cfg.INVERSION_MODEL.OUTPUT_DIR>``.
    """
    root_dir = (Path(__file__).resolve().parent / "../../").resolve()
    output_dir = root_dir / "output" / "inversion_attack_spconv" / "evaluation" / cfg.INVERSION_MODEL.OUTPUT_DIR
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): this mutates the caller's cfg; kept for backward compatibility.
    if cfg.INVERSION_MODEL.SCHEDULER == 'LAYER1_FLAG':
        cfg.INVERSION_MODEL.LAYER = 'conv_1_spconv_cls'
    prefix = cfg.INVERSION_MODEL.LAYER + '_from_xconv_' + str(start_layer)
    for suffix, values in metrics:
        np.save(output_dir / (prefix + '_' + suffix + '.npy'), np.array(values))
    print("cd, d_cd, f1, and hd of (" + prefix + ") are saved!")


def eval_one_epoch(
    cfg,
    detector_meanVFE,
    detector_backbone,
    inversion_model_out,
    inversion_model_4,
    inversion_model_3,
    inversion_model_2,
    dataloader,
    epoch_id,
    logger,
    dist_test=False,
    save_to_file=False,
    result_dir=None,
    start_layer=5,
):
    """Evaluate the spconv inversion models on ``dataloader`` and save metrics.

    For each batch:

    1. The detector's mean-VFE and 3D backbone are run.
    2. The inversion chain reconstructs a dense occupancy volume, starting from
       the feature level chosen by ``start_layer``:
       - 2/3/4 use ``x_conv2_1``/``x_conv3_1``/``x_conv4_1``;
       - any other value uses ``encoded_spconv_tensor``.
    3. Occupied voxels (sigmoid > 0.5) are mapped to their centre coordinates.
    4. The values ``calc_cd_full`` returns are collected. Indices 3/2/1/0 are
       stored as ``cd``/``d_cd``/``f1``/``hd``, and are computed against the
       first 3 feature channels of the input voxels.

    Only sample 0 of every batch is scored.

    Args:
        dist_test: unused; kept for interface compatibility.
        save_to_file: when true (and ``result_dir`` is given), also creates
            ``result_dir/final_result/data``.
        result_dir: ``pathlib.Path`` created if given; may be ``None``.
        start_layer: feature level the inversion starts from (see above).

    Returns:
        ``{}`` on ranks other than 0, ``None`` on rank 0.

    NOTE(review): every rank writes the same metric files; in distributed
    runs the last writer wins.
    """
    if result_dir is not None:
        result_dir.mkdir(parents=True, exist_ok=True)
        if save_to_file:
            (result_dir / "final_result" / "data").mkdir(parents=True, exist_ok=True)

    logger.info("*************** EPOCH %s EVALUATION *****************" % epoch_id)

    # The backbone used to be left in the caller's mode. In train mode
    # BatchNorm uses batch statistics and updates its running stats, so
    # evaluation results would depend on batch composition.
    for module in (
        detector_meanVFE,
        detector_backbone,
        inversion_model_out,
        inversion_model_4,
        inversion_model_3,
        inversion_model_2,
    ):
        module.eval()
    inversion_models = (inversion_model_out, inversion_model_4, inversion_model_3, inversion_model_2)

    progress_bar = None
    if cfg.LOCAL_RANK == 0:
        progress_bar = tqdm.tqdm(
            total=len(dataloader), leave=True, desc="eval", dynamic_ncols=True
        )

    # Loop-invariant geometry, computed once instead of per batch.
    grid_size = np.array(cfg.DATA_CONFIG.GRID_SIZE)
    sparse_shape = grid_size[::-1] + [1, 0, 0]  # (z, y, x) with one extra z cell
    voxel_center_grid = _voxel_center_grid(
        cfg.DATA_CONFIG.POINT_CLOUD_RANGE, grid_size, num_z=int(sparse_shape[0])
    ).cuda()

    original_num_list = []
    reconstruction_num_list = []
    cd_list = []
    d_cd_list = []
    f1_list = []
    hd_list = []

    for batch_dict in dataloader:
        load_data_to_gpu(batch_dict)
        with torch.no_grad():
            voxel_dict = detector_meanVFE(batch_dict)
            batch_size = voxel_dict["batch_size"]
            input_sp_tensor = spconv.SparseConvTensor(
                features=voxel_dict["voxel_features"],
                indices=voxel_dict["voxel_coords"].int(),
                spatial_shape=sparse_shape,
                batch_size=batch_size,
            )

            backbone_output = detector_backbone(voxel_dict)
            multi_scale = backbone_output["multi_scale_3d_features"]
            reconstruction_score = _invert_features(
                start_layer,
                inversion_models,
                backbone_output["encoded_spconv_tensor"],
                multi_scale["x_conv4_1"],
                multi_scale["x_conv3_1"],
                multi_scale["x_conv2_1"],
            )
            occupied = torch.sigmoid(reconstruction_score) > 0.5
            reconstruction_result = voxel_center_grid.expand(batch_size, -1, -1, -1, -1) * occupied

            # CD / D-CD on sample 0: channels 0..2 are treated as xyz and all
            # all-zero rows (empty voxels) are dropped.
            original_voxel = _nonzero_rows(input_sp_tensor.dense()[0, :3].reshape(3, -1).permute(1, 0))
            reconstruction_voxel = _nonzero_rows(reconstruction_result[0, :3].reshape(3, -1).permute(1, 0))
            original_num = original_voxel.shape[0]
            reconstruction_num = reconstruction_voxel.shape[0]
            original_num_list.append(original_num)
            reconstruction_num_list.append(reconstruction_num)

            cd_all = calc_cd_full(
                reconstruction_voxel.unsqueeze(0), original_voxel.unsqueeze(0), f1_thr=cfg.F1_THR
            )
            hd, f1, d_cd, cd_p = (cd_all[k].detach().cpu().numpy() for k in range(4))
            cd_list.append(cd_p)
            d_cd_list.append(d_cd)
            f1_list.append(f1)
            hd_list.append(hd)

        if progress_bar is not None:
            progress_bar.update()
            progress_bar.set_postfix(
                {
                    "cd": cd_p,
                    "d_cd": d_cd,
                    "f1": f1,
                    "hd": hd,
                    "num point": f"({reconstruction_num}/{original_num})",
                }
            )
            progress_bar.refresh()

    _save_metrics(
        cfg,
        start_layer,
        [
            ("reconstruction_num", reconstruction_num_list),
            ("cd", cd_list),
            ("d_cd", d_cd_list),
            ("f1", f1_list),
            ("hd", hd_list),
        ],
    )

    if progress_bar is not None:
        progress_bar.close()

    if cfg.LOCAL_RANK != 0:
        return {}

    logger.info("Result is save to %s" % result_dir)
    logger.info("****************Evaluation done.*****************")
    return


if __name__ == "__main__":
    pass