from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F

import threestudio
from threestudio.models.background.base import BaseBackground
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *


@threestudio.register("textured-background")
class TexturedBackground(BaseBackground):
    @dataclass
    class Config(BaseBackground.Config):
        n_output_dims: int = 3
        height: int = 64
        width: int = 64
        color_activation: str = "sigmoid"

    cfg: Config

    def configure(self) -> None:
        self.texture = nn.Parameter(
            torch.randn((1, self.cfg.n_output_dims, self.cfg.height, self.cfg.width))
        )

    def spherical_xyz_to_uv(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B 2"]:
        x, y, z = dirs[..., 0], dirs[..., 1], dirs[..., 2]
        xy = (x**2 + y**2) ** 0.5
        u = torch.atan2(xy, z) / torch.pi
        v = torch.atan2(y, x) / (torch.pi * 2) + 0.5
        uv = torch.stack([u, v], -1)
        return uv

    def forward(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B Nc"]:
        dirs_shape = dirs.shape[:-1]
        uv = self.spherical_xyz_to_uv(dirs)
        uv = 2 * uv - 1  # rescale to [-1, 1] for grid_sample
        uv = uv.reshape(1, -1, 1, 2)
        color = (
            F.grid_sample(
                self.texture,
                uv,
                mode="bilinear",
                padding_mode="reflection",
                align_corners=False,
            )
            .reshape(self.cfg.n_output_dims, -1)
            .T.reshape(*dirs_shape, self.cfg.n_output_dims)
        )
        color = get_activation(self.cfg.color_activation)(color)
        return color
