"""
Density proposal field
"""
from typing import Optional, Callable
import logging as log

import torch
import torch.nn as nn

from plenoxels.models.utils import Network
from plenoxels.models.kplane_field import interpolate_ms_features, normalize_aabb, init_grid_param, init_gaussian
from plenoxels.raymarching.spatial_distortions import SpatialDistortion


class KPlaneDensityField(nn.Module):
    """Density-only field built from k-plane feature grids and a small decoder.

    Points are normalized with the stored bounding box, features are looked up
    through ``interpolate_ms_features`` on a single set of 2D grids, and a
    one-hidden-layer network maps the features to one density value per sample.
    """

    def __init__(self,
                 aabb,
                 resolution,
                 num_input_coords,
                 num_output_coords,
                 density_activation: Callable,
                 spatial_distortion: Optional["SpatialDistortion"] = None,
                 linear_decoder: bool = True,
                 use_grid_kernel=False,
                 grid_kernel_size=11,
                 num_scales=4,
                 **kwargs):
        """Build the feature grids, optional grid kernels and the density decoder.

        Args:
            aabb: Bounding-box tensor handed to ``normalize_aabb``; stored as a
                non-trainable parameter.
            resolution: Grid resolution forwarded to ``init_grid_param``.
            num_input_coords: Number of input coordinates; a value of 4 marks
                the field as a hexplane (time-dependent) model.
            num_output_coords: Feature dimension of the grids, also used as the
                decoder's input width.
            density_activation: Callable applied to the raw decoder output.
            spatial_distortion: Optional distortion. It is stored, but
                ``get_density`` does not support it yet and raises if it is set.
            linear_decoder: If True the decoder uses no hidden activation,
                otherwise ReLU.
            use_grid_kernel: If True, create one kernel per grid through
                ``init_gaussian``.
            grid_kernel_size: Side length of the square kernels.
            num_scales: Multiplier applied to each grid's ``shape[1]`` when
                sizing its kernel.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        self.aabb = nn.Parameter(aabb, requires_grad=False)
        self.spatial_distortion = spatial_distortion
        self.hexplane = num_input_coords == 4
        self.feature_dim = num_output_coords
        self.density_activation = density_activation
        self.linear_decoder = linear_decoder
        # The network config takes the literal string "None" to disable the
        # hidden activation.
        activation = "None" if self.linear_decoder else "ReLU"

        self.grids = init_grid_param(
            grid_nd=2, in_dim=num_input_coords, out_dim=num_output_coords,
            reso=resolution, a=0.1, b=0.15)
        if use_grid_kernel:
            # One ParameterList of kernels wrapped in a ModuleList, mirroring
            # the single-scale ``ms_grids=[self.grids]`` layout passed to
            # ``interpolate_ms_features``.
            self.kernels = nn.ModuleList()
            scale_kernels = nn.ParameterList()
            for grid in self.grids:
                scale_kernels.append(
                    init_gaussian(grid.shape[1] * num_scales,
                                  (grid_kernel_size, grid_kernel_size)))
            self.kernels.append(scale_kernels)
        else:
            self.kernels = None

        self.sigma_net = Network(
            n_input_dims=self.feature_dim,
            n_output_dims=1,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": activation,
                "output_activation": "None",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            },
        )
        log.info(f"Initialized KPlaneDensityField. hexplane={self.hexplane} - "
                 f"resolution={resolution}")

    def get_density(self, pts: torch.Tensor, covs: torch.Tensor,
                    timestamps: Optional[torch.Tensor] = None):
        """Evaluate the density at every sample.

        Args:
            pts: Sample positions, ``[n_rays, n_samples, 3]``.
            covs: Per-sample covariances; the first two dimensions are
                flattened together (presumably ``[n_rays, n_samples, ...]``
                -- confirm against ``interpolate_ms_features``).
            timestamps: Optional per-ray timestamps, ``[n_rays]``. Only used
                when the field is a hexplane model.

        Returns:
            Activated densities of shape ``[n_rays, n_samples, 1]``.

        Raises:
            NotImplementedError: If a spatial distortion is configured.
        """
        if self.spatial_distortion is not None:
            # This path was guarded by a debugger breakpoint. Fail explicitly
            # instead: only the points would be distorted, not ``covs``. A
            # future implementation must also rescale the distorted points
            # from [-2, 2] to [-1, 1] (``pts / 2``).
            raise NotImplementedError(
                "KPlaneDensityField.get_density does not support "
                "spatial_distortion yet: covariances are not distorted."
            )

        pts = normalize_aabb(pts, self.aabb)
        # NOTE(review): ``covs`` is left in world scale while ``pts`` is
        # normalized. A rescale of the form
        #   covs *= 2 / (self.aabb[1] - self.aabb[0]).mean()
        # (assuming 'x, 'y, 'z = ax, ay, az) was drafted but is disabled --
        # confirm whether ``interpolate_ms_features`` expects it.

        n_rays, n_samples = pts.shape[:2]
        if timestamps is not None and self.hexplane:
            timestamps = timestamps[:, None].expand(-1, n_samples)[..., None]  # [n_rays, n_samples, 1]
            pts = torch.cat((pts, timestamps), dim=-1)  # [n_rays, n_samples, 4]

        pts = pts.reshape(-1, pts.shape[-1])
        covs = covs.flatten(0, 1)

        features = interpolate_ms_features(
            pts, covs, ms_grids=[self.grids], grid_dimensions=2,
            concat_features=False, num_levels=None, kernels=self.kernels)
        # Cast the decoder output back to the dtype/device of ``pts`` before
        # applying the activation.
        raw_density = self.sigma_net(features).to(pts)
        return self.density_activation(raw_density).view(n_rays, n_samples, 1)

    def forward(self, pts: torch.Tensor, covs: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):
        """Module entry point; forwards all arguments to ``get_density``.

        ``covs`` is required because ``get_density`` cannot run without it.
        """
        return self.get_density(pts, covs, timestamps)

    def get_params(self):
        """Group the module's parameters for the optimizer.

        Returns:
            A dict with three lists: ``"nn"`` (decoder parameters), ``"field"``
            (grid parameters) and ``"other"`` (every remaining parameter, for
            example ``aabb`` and any grid kernels).
        """
        field_params = dict(self.grids.named_parameters(prefix="grids"))
        nn_params = dict(self.sigma_net.named_parameters(prefix="sigma_net"))
        other_params = {
            name: param for name, param in self.named_parameters()
            if name not in nn_params and name not in field_params
        }
        return {
            "nn": list(nn_params.values()),
            "field": list(field_params.values()),
            "other": list(other_params.values()),
        }
