from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import threestudio
from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *


@threestudio.register("volume-grid")
class VolumeGrid(BaseImplicitGeometry):
    @dataclass
    class Config(BaseImplicitGeometry.Config):
        grid_size: Tuple[int, int, int] = field(default_factory=lambda: (100, 100, 100))
        n_feature_dims: int = 3
        density_activation: Optional[str] = "softplus"
        density_bias: Union[float, str] = "blob"
        density_blob_scale: float = 5.0
        density_blob_std: float = 0.5
        normal_type: Optional[
            str
        ] = "finite_difference"  # in ['pred', 'finite_difference', 'finite_difference_laplacian']

        # automatically determine the threshold
        isosurface_threshold: Union[float, str] = "auto"

    cfg: Config

    def configure(self) -> None:
        super().configure()
        self.grid_size = self.cfg.grid_size

        self.grid = nn.Parameter(
            torch.zeros(1, self.cfg.n_feature_dims + 1, *self.grid_size)
        )
        if self.cfg.density_bias == "blob":
            self.register_buffer("density_scale", torch.tensor(0.0))
        else:
            self.density_scale = nn.Parameter(torch.tensor(0.0))

        if self.cfg.normal_type == "pred":
            self.normal_grid = nn.Parameter(torch.zeros(1, 3, *self.grid_size))

    def get_density_bias(self, points: Float[Tensor, "*N Di"]):
        if self.cfg.density_bias == "blob":
            # density_bias: Float[Tensor, "*N 1"] = self.cfg.density_blob_scale * torch.exp(-0.5 * (points ** 2).sum(dim=-1) / self.cfg.density_blob_std ** 2)[...,None]
            density_bias: Float[Tensor, "*N 1"] = (
                self.cfg.density_blob_scale
                * (
                    1
                    - torch.sqrt((points.detach() ** 2).sum(dim=-1))
                    / self.cfg.density_blob_std
                )[..., None]
            )
            return density_bias
        elif isinstance(self.cfg.density_bias, float):
            return self.cfg.density_bias
        else:
            raise AttributeError(f"Unknown density bias {self.cfg.density_bias}")

    def get_trilinear_feature(
        self, points: Float[Tensor, "*N Di"], grid: Float[Tensor, "1 Df G1 G2 G3"]
    ) -> Float[Tensor, "*N Df"]:
        points_shape = points.shape[:-1]
        df = grid.shape[1]
        di = points.shape[-1]
        out = F.grid_sample(
            grid, points.view(1, 1, 1, -1, di), align_corners=False, mode="bilinear"
        )
        out = out.reshape(df, -1).T.reshape(*points_shape, df)
        return out

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
        points_unscaled = points  # points in the original scale
        points = contract_to_unisphere(
            points, self.bbox, self.unbounded
        )  # points normalized to (0, 1)
        points = points * 2 - 1  # convert to [-1, 1] for grid sample

        out = self.get_trilinear_feature(points, self.grid)
        density, features = out[..., 0:1], out[..., 1:]
        density = density * torch.exp(self.density_scale)  # exp scaling in DreamFusion

        # breakpoint()
        density = get_activation(self.cfg.density_activation)(
            density + self.get_density_bias(points_unscaled)
        )

        output = {
            "density": density,
            "features": features,
        }

        if output_normal:
            if (
                self.cfg.normal_type == "finite_difference"
                or self.cfg.normal_type == "finite_difference_laplacian"
            ):
                eps = 1.0e-3
                if self.cfg.normal_type == "finite_difference_laplacian":
                    offsets: Float[Tensor, "6 3"] = torch.as_tensor(
                        [
                            [eps, 0.0, 0.0],
                            [-eps, 0.0, 0.0],
                            [0.0, eps, 0.0],
                            [0.0, -eps, 0.0],
                            [0.0, 0.0, eps],
                            [0.0, 0.0, -eps],
                        ]
                    ).to(points_unscaled)
                    points_offset: Float[Tensor, "... 6 3"] = (
                        points_unscaled[..., None, :] + offsets
                    ).clamp(-self.cfg.radius, self.cfg.radius)
                    density_offset: Float[Tensor, "... 6 1"] = self.forward_density(
                        points_offset
                    )
                    normal = (
                        -0.5
                        * (density_offset[..., 0::2, 0] - density_offset[..., 1::2, 0])
                        / eps
                    )
                else:
                    offsets: Float[Tensor, "3 3"] = torch.as_tensor(
                        [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
                    ).to(points_unscaled)
                    points_offset: Float[Tensor, "... 3 3"] = (
                        points_unscaled[..., None, :] + offsets
                    ).clamp(-self.cfg.radius, self.cfg.radius)
                    density_offset: Float[Tensor, "... 3 1"] = self.forward_density(
                        points_offset
                    )
                    normal = -(density_offset[..., 0::1, 0] - density) / eps
                normal = F.normalize(normal, dim=-1)
            elif self.cfg.normal_type == "pred":
                normal = self.get_trilinear_feature(points, self.normal_grid)
                normal = F.normalize(normal, dim=-1)
            else:
                raise AttributeError(f"Unknown normal type {self.cfg.normal_type}")
            output.update({"normal": normal, "shading_normal": normal})
        return output

    def forward_density(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]:
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
        points = points * 2 - 1  # convert to [-1, 1] for grid sample

        out = self.get_trilinear_feature(points, self.grid)
        density = out[..., 0:1]
        density = density * torch.exp(self.density_scale)

        density = get_activation(self.cfg.density_activation)(
            density + self.get_density_bias(points_unscaled)
        )
        return density

    def forward_field(
        self, points: Float[Tensor, "*N Di"]
    ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
        if self.cfg.isosurface_deformable_grid:
            threestudio.warn(
                f"{self.__class__.__name__} does not support isosurface_deformable_grid. Ignoring."
            )
        density = self.forward_density(points)
        return density, None

    def forward_level(
        self, field: Float[Tensor, "*N 1"], threshold: float
    ) -> Float[Tensor, "*N 1"]:
        return -(field - threshold)

    def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        if self.cfg.n_feature_dims == 0:
            return out
        points_unscaled = points
        points = contract_to_unisphere(points, self.bbox, self.unbounded)
        points = points * 2 - 1  # convert to [-1, 1] for grid sample
        features = self.get_trilinear_feature(points, self.grid)[..., 1:]
        out.update(
            {
                "features": features,
            }
        )
        return out
