import pickle
import time

import numpy as np
try:
    import kornia
except:
    pass
import torch
import tqdm

from pcdet.models import load_data_to_gpu
from pcdet.utils import common_utils
from pcdet.utils.spconv_utils import spconv

from pcdet.ops.chamfer_distance import ChamferDistance
from pcdet.ops.Density_aware_Chamfer_Distance.utils.model_utils import calc_cd_full

from pcdet.ops.roiaware_pool3d import roiaware_pool3d_utils

from pathlib import Path

def load_data_to_gpu(batch_dict):
    """Convert the numpy arrays in ``batch_dict`` to CUDA tensors, in place.

    Values that are not ``np.ndarray`` are left untouched, as are the keys
    ``"frame_id"``, ``"metadata"`` and ``"calib"``. For the remaining arrays:

    * ``"images"`` goes through ``kornia.image_to_tensor`` and is made
      contiguous as a float32 CUDA tensor;
    * ``"image_shape"`` becomes an int32 CUDA tensor;
    * everything else becomes a float32 CUDA tensor.

    Note: this definition shadows the ``load_data_to_gpu`` imported from
    ``pcdet.models`` at the top of the file.

    Args:
        batch_dict: mapping of batch entries; modified in place.

    Returns:
        None.
    """
    untouched_keys = ("frame_id", "metadata", "calib")
    for name, value in batch_dict.items():
        if not isinstance(value, np.ndarray) or name in untouched_keys:
            continue
        if name == "images":
            converted = kornia.image_to_tensor(value).float().cuda().contiguous()
        elif name == "image_shape":
            converted = torch.from_numpy(value).int().cuda()
        else:
            converted = torch.from_numpy(value).float().cuda()
        # Only existing keys are reassigned, so iterating items() stays valid.
        batch_dict[name] = converted

# Nested keys into the backbone output dict that select the feature fed to the
# inversion model, for each supported cfg.INVERSION_MODEL.LAYER value.
_INVERSION_LAYER_SOURCES = {
    "xconv1_naive": ("multi_scale_3d_features", "x_conv1"),
    "xconv1_2_naive": ("multi_scale_3d_features", "x_conv1_2"),
    "xconv2_naive": ("multi_scale_3d_features", "x_conv2_1"),
    "xconv3_naive": ("multi_scale_3d_features", "x_conv3_1"),
    "xconv4_naive": ("multi_scale_3d_features", "x_conv4_1"),
    "xconv4_naive_voxelresbackbone": ("multi_scale_3d_features", "x_conv4_1"),
    "xconvout_naive": ("encoded_spconv_tensor",),
    "xconvout_naive_voxelresbackbone": ("encoded_spconv_tensor",),
}


def _select_inversion_input(backbone_output, source_keys):
    """Return the entry of ``backbone_output`` reached by following ``source_keys``."""
    feature = backbone_output
    for key in source_keys:
        feature = feature[key]
    return feature


def _occupied_voxels(dense_volume):
    """Return the non-empty rows of sample 0's first three channels.

    ``dense_volume[0, :3]`` is flattened to shape ``(num_cells, 3)``. Only the
    rows with at least one non-zero entry are kept, in ascending row order.
    """
    cells = dense_volume[0, :3].reshape(3, -1).permute(1, 0)
    return cells[cells.nonzero()[:, 0].unique()]


def eval_one_epoch(
    cfg,
    detector_meanVFE,
    detector_backbone,
    inversion_model_naive,
    dataloader,
    epoch_id,
    logger,
    dist_test=False,
    save_to_file=False,
    result_dir=None,
):
    """Score an inversion model's voxel reconstructions over one pass of ``dataloader``.

    Each batch goes through the following steps:

    1. It is moved to the GPU with ``load_data_to_gpu`` and voxelized by
       ``detector_meanVFE``.
    2. It is encoded by ``detector_backbone``.
    3. The backbone feature selected by ``cfg.INVERSION_MODEL.LAYER`` is decoded
       by ``inversion_model_naive``.
    4. The non-empty voxels (first three feature channels) of the input and of
       the reconstruction are compared with ``calc_cd_full``.

    Per-batch voxel counts and the hd / f1 / d_cd / cd metrics are saved as
    ``.npy`` files under
    ``<two dirs above this file>/output/inversion_attack_spconv/evaluation/
    <cfg.INVERSION_MODEL.OUTPUT_DIR>``.

    Args:
        cfg: config object. Reads ``LOCAL_RANK``, ``DATA_CONFIG.GRID_SIZE``,
            ``F1_THR``, ``INVERSION_MODEL.LAYER``, ``INVERSION_MODEL.OUTPUT_DIR``
            and ``INVERSION_MODEL.LOSS.NAME``.
        detector_meanVFE: module returning a dict with ``voxel_features``,
            ``voxel_coords`` and ``batch_size``.
        detector_backbone: module applied to that dict. Its output must hold
            the entry named in ``_INVERSION_LAYER_SOURCES`` for the layer.
        inversion_model_naive: module whose output supports ``.dense()``.
        dataloader: iterable of batch dicts. ``len()`` is used on rank 0 for
            the progress bar.
        epoch_id: only used in the log header.
        logger: logger used for progress messages.
        dist_test: unused; kept for interface compatibility.
        save_to_file: if true, ``result_dir/final_result/data`` is created.
        result_dir: ``pathlib.Path`` that is created. Must not be None despite
            the default.

    Returns:
        ``None`` on rank 0, ``{}`` on every other rank.

    Raises:
        ValueError: if ``cfg.INVERSION_MODEL.LAYER`` is not a supported layer.
    """
    # Validate the layer once, before any data is processed.
    layer_name = cfg.INVERSION_MODEL.LAYER
    try:
        feature_source = _INVERSION_LAYER_SOURCES[layer_name]
    except KeyError:
        raise ValueError(
            "Wrong cfg.inversion_model.layer! Got {!r}; expected one of: {}".format(
                layer_name, ", ".join(sorted(_INVERSION_LAYER_SOURCES))
            )
        ) from None

    result_dir.mkdir(parents=True, exist_ok=True)
    final_output_dir = result_dir / "final_result" / "data"
    if save_to_file:
        final_output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("*************** EPOCH %s EVALUATION *****************", epoch_id)

    detector_meanVFE.eval()
    # The backbone must be in eval mode too; otherwise layers such as
    # BatchNorm/Dropout behave as in training during evaluation.
    detector_backbone.eval()
    inversion_model_naive.eval()

    is_main_rank = cfg.LOCAL_RANK == 0
    if is_main_rank:
        progress_bar = tqdm.tqdm(
            total=len(dataloader), leave=True, desc="eval", dynamic_ncols=True
        )

    # Loop-invariant spatial shape for the input sparse tensor: the reversed
    # grid size with +1 on its first axis (presumably z in (z, y, x) order —
    # confirm against the backbone's own sparse_shape).
    grid_size = np.array(cfg.DATA_CONFIG.GRID_SIZE)
    sparse_shape = grid_size[::-1] + [1, 0, 0]

    original_num_list = []
    reconstruction_num_list = []
    cd_list = []
    d_cd_list = []
    f1_list = []
    hd_list = []

    for batch_dict in dataloader:
        load_data_to_gpu(batch_dict)
        with torch.no_grad():
            voxel_dict = detector_meanVFE(batch_dict)
            input_sp_tensor = spconv.SparseConvTensor(
                features=voxel_dict["voxel_features"],
                indices=voxel_dict["voxel_coords"].int(),
                spatial_shape=sparse_shape,
                batch_size=voxel_dict["batch_size"],
            )

            backbone_output = detector_backbone(voxel_dict)
            input_feature = _select_inversion_input(backbone_output, feature_source)
            reconstruction_result = inversion_model_naive(input_feature).dense()

            # NOTE(review): only sample 0 of each batch is scored, so the
            # metrics below assume batch_size == 1.
            original_voxel = _occupied_voxels(input_sp_tensor.dense())
            original_num = original_voxel.shape[0]
            original_num_list.append(original_num)

            reconstruction_voxel = _occupied_voxels(reconstruction_result)
            reconstruction_num = reconstruction_voxel.shape[0]
            reconstruction_num_list.append(reconstruction_num)

            cd_all = calc_cd_full(
                reconstruction_voxel.unsqueeze(0),
                original_voxel.unsqueeze(0),
                f1_thr=cfg.F1_THR,
            )
            # Index order of calc_cd_full's outputs as consumed by this script.
            hd = cd_all[0].detach().cpu().numpy()
            f1 = cd_all[1].detach().cpu().numpy()
            d_cd = cd_all[2].detach().cpu().numpy()
            cd_p = cd_all[3].detach().cpu().numpy()
            cd_list.append(cd_p)
            d_cd_list.append(d_cd)
            f1_list.append(f1)
            hd_list.append(hd)

        if is_main_rank:
            progress_bar.update()
            progress_bar.set_postfix(
                {
                    "cd": cd_p,
                    "d_cd": d_cd,
                    "f1": f1,
                    "hd": hd,
                    "num point": "({}/{})".format(reconstruction_num, original_num),
                }
            )
            progress_bar.refresh()

    root_dir = (Path(__file__).resolve().parent / "../../").resolve()
    output_dir = (
        root_dir
        / "output"
        / "inversion_attack_spconv"
        / "evaluation"
        / cfg.INVERSION_MODEL.OUTPUT_DIR
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): every rank writes to the same file names; under
    # distributed testing the last writer wins.
    prefix = layer_name + "_" + cfg.INVERSION_MODEL.LOSS.NAME
    saved_arrays = {
        "original_num": original_num_list,
        "reconstruction_num": reconstruction_num_list,
        "cd": cd_list,
        "d_cd": d_cd_list,
        "f1": f1_list,
        "hd": hd_list,
    }
    for suffix, values in saved_arrays.items():
        np.save(output_dir / (prefix + "_" + suffix + ".npy"), np.array(values))
    print("cd, d_cd, f1, and hd of (" + prefix + ") are saved!")

    if is_main_rank:
        progress_bar.close()

    if not is_main_rank:
        return {}

    logger.info("Result is saved to %s", output_dir)
    logger.info("****************Evaluation done.*****************")
    return None


if __name__ == "__main__":
    # Nothing to do when this file is run directly.
    ...