# -*- coding:utf-8 -*-
# author: Xinge

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import numba as nb
import multiprocessing
import torch_scatter
from torch_scatter import scatter


class CylinderVFE(nn.Module):
    """Cylindrical voxel feature encoder.

    Runs a point-wise MLP (``PPmodel``) over per-point features, max-pools the
    resulting features over all points that share the same grid index, and
    optionally compresses the pooled features with Linear + ReLU. It also
    reduces the per-point segmentation / instance labels to one label per
    occupied voxel by taking the median inside each voxel.
    """

    def __init__(self, model_cfg, runtime_cfg, **kwargs):
        """Build the encoder from ``model_cfg``.

        Args:
            model_cfg: mapping read with ``.get`` for the optional keys
                ``GRID_SIZE`` (default ``[480, 360, 32]``), ``FEA_DIM``
                (input point-feature width, default 3), ``OUT_PT_FEA_DIM``
                (MLP output width, default 64), ``MAX_PT_PER_ENCODE``
                (stored as ``self.max_pt``, default 64) and ``FEA_COMPRE``
                (width of the optional compression layer; default ``None``
                means no compression).
            runtime_cfg: accepted for interface compatibility; not used here.
            **kwargs: accepted for interface compatibility; not used here.
        """
        super().__init__()
        grid_size = model_cfg.get("GRID_SIZE", [480, 360, 32])
        fea_dim = model_cfg.get("FEA_DIM", 3)
        out_pt_fea_dim = model_cfg.get("OUT_PT_FEA_DIM", 64)
        max_pt_per_encode = model_cfg.get("MAX_PT_PER_ENCODE", 64)
        fea_compre = model_cfg.get("FEA_COMPRE", None)

        # Point-wise MLP. The layer order defines the state-dict keys
        # (PPmodel.0, PPmodel.1, ...), so it must not be rearranged.
        self.PPmodel = nn.Sequential(
            nn.BatchNorm1d(fea_dim),

            nn.Linear(fea_dim, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),

            nn.Linear(64, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Linear(256, out_pt_fea_dim)
        )

        self.max_pt = max_pt_per_encode
        self.fea_compre = fea_compre
        self.grid_size = grid_size
        kernel_size = 3
        # NOTE(review): not used by forward() in this class.
        self.local_pool_op = torch.nn.MaxPool2d(kernel_size, stride=1,
                                                padding=(kernel_size - 1) // 2,
                                                dilation=1)
        self.pool_dim = out_pt_fea_dim

        # Optional point-feature compression applied after pooling.
        if self.fea_compre is not None:
            self.fea_compression = nn.Sequential(
                nn.Linear(self.pool_dim, self.fea_compre),
                nn.ReLU())
            self.pt_fea_dim = self.fea_compre
        else:
            self.pt_fea_dim = self.pool_dim

    def get_output_feature_dim(self):
        """Return the width of the per-voxel features produced by forward()."""
        return self.pt_fea_dim

    @staticmethod
    def _voxel_median(val, voxel_index, counts):
        """Return, for every voxel, the median of ``val`` over its points.

        The median is the element at position ``count // 2`` of the voxel's
        sorted values (the upper median when the count is even).

        Args:
            val: tensor with one value per point (assumed 1-D).
            voxel_index: int tensor mapping each point to a voxel id in
                ``[0, len(counts))``.
            counts: number of points in each voxel; every entry must be >= 1
                (as guaranteed by ``torch.unique(..., return_counts=True)``).

        Returns:
            Tensor of length ``len(counts)``: int64 for integer ``val``, or
            ``val.dtype`` for floating-point ``val``.
        """
        num_voxels = counts.shape[0]
        is_float = val.is_floating_point()
        out_dtype = val.dtype if is_float else torch.int64
        if val.numel() == 0:
            return torch.zeros(num_voxels, dtype=out_dtype, device=val.device)

        # Work in a wide dtype so the voxel stride cannot overflow narrow
        # label dtypes such as uint8.
        work_dtype = torch.float64 if is_float else torch.int64
        work = val.to(work_dtype)
        min_val = work.min()
        # The stride is strictly larger than the value range, so each voxel owns
        # a disjoint key interval. A single sort then groups points by voxel
        # and orders them by value within each voxel.
        stride = work.max() - min_val + 1
        keys = (work - min_val) + voxel_index.to(work_dtype) * stride
        sorted_keys, _ = torch.sort(keys)

        # Voxel v occupies sorted positions [offset[v], offset[v] + counts[v]).
        offset = counts.cumsum(dim=0) - counts
        median_pos = offset + torch.div(counts, 2, rounding_mode='floor')
        voxel_ids = torch.arange(num_voxels, device=val.device, dtype=work_dtype)
        # Undo the voxel stride and the min shift to recover the original value.
        median = sorted_keys[median_pos] - voxel_ids * stride + min_val
        return median.to(out_dtype)

    def forward(self, batch_dict):
        """Encode points into voxel features and per-voxel labels.

        Reads from ``batch_dict``:
            cyl_pt_fea: per-point features, ``[N_pt, FEA_DIM]``.
            cyl_grid_ind: per-point grid-index rows (presumably ``[N_pt, 3]``).
            segmentation_label, instance_label: per-point labels (assumed
                1-D). ``instance_label_back`` replaces ``instance_label``
                when present.

        Writes back:
            cyl_segmentation_label / cyl_instance_label: median label of the
                points in each unique voxel.
            cyl_coords: unique grid-index rows (int64), one per voxel.
            cyl_features_3d: per-voxel max-pooled (and optionally compressed)
                features, row-aligned with ``cyl_coords``.

        Returns:
            The same ``batch_dict``, updated in place.
        """
        cat_pt_fea = batch_dict['cyl_pt_fea'].float()
        cat_pt_ind = batch_dict['cyl_grid_ind']
        seg_labels = batch_dict['segmentation_label']
        ins_labels = batch_dict['instance_label']
        if 'instance_label_back' in batch_dict:
            ins_labels = batch_dict['instance_label_back']

        point_wise_median_dict = {'segmentation_label': seg_labels, 'instance_label': ins_labels}
        # One entry per occupied voxel; unq_inv maps each point to its voxel
        # and unq_cnt is the number of points per voxel.
        unq, unq_inv, unq_cnt = torch.unique(cat_pt_ind, return_inverse=True, return_counts=True, dim=0)
        unq = unq.type(torch.int64)

        # Point-wise MLP, then max-pool the point features inside each voxel.
        processed_cat_pt_fea = self.PPmodel(cat_pt_fea)
        pooled_data = torch_scatter.scatter_max(processed_cat_pt_fea, unq_inv, dim=0)[0]

        for key, val in point_wise_median_dict.items():
            batch_dict['cyl_' + key] = self._voxel_median(val, unq_inv, unq_cnt)

        if self.fea_compre is not None:
            processed_pooled_data = self.fea_compression(pooled_data)
        else:
            processed_pooled_data = pooled_data
        batch_dict['cyl_coords'] = unq
        batch_dict['cyl_features_3d'] = processed_pooled_data

        return batch_dict

