import functools
import math
from typing import Callable, Optional, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from hypernet_core import FIN_FOUT, OptBlock, TargetNet
from hypernet_core import get_embedder
from utils.examples.radiance_fields.ngp import trunc_exp
from utils.model_tools import hook_fn_decorator


class MLP(nn.Module):
    """Stack of linear layers with optional re-injection of the input.

    The network consists of ``net_depth`` hidden ``nn.Linear`` layers of width
    ``net_width``, each followed by ``hidden_activation``. After every hidden
    layer whose index is a positive multiple of ``skip_layer`` the original
    input is concatenated onto the features (last dimension). An optional
    linear output layer followed by ``output_activation`` maps the trunk
    features to ``output_dim`` channels.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels. Only used when
            ``output_enabled`` is true; otherwise ``self.output_dim`` is
            overwritten with the width of the trunk features.
        net_depth: Number of hidden layers.
        net_width: Width of every hidden layer.
        skip_layer: Period of the input skip connections, or ``None`` to
            disable them.
        hidden_init: In-place initializer applied to hidden weights, or
            ``None`` to keep the ``nn.Linear`` default.
        hidden_activation: Callable applied after every hidden layer.
        output_enabled: Whether to append the linear output layer.
        output_init: In-place initializer applied to the output weight, or
            ``None`` to keep the ``nn.Linear`` default.
        output_activation: Callable applied after the output layer.
        bias_enabled: Whether the linear layers carry a bias term.
        bias_init: In-place initializer applied to biases, or ``None``.
    """

    def __init__(
        self,
        input_dim: int,  # The number of input tensor channels.
        output_dim: int = None,  # The number of output tensor channels.
        net_depth: int = 8,  # The depth of the MLP.
        net_width: int = 256,  # The width of the MLP.
        skip_layer: int = 4,  # The layer to add skip layers to.
        hidden_init: Callable = nn.init.xavier_uniform_,
        hidden_activation: Callable = nn.ReLU(),
        output_enabled: bool = True,
        output_init: Optional[Callable] = nn.init.xavier_uniform_,
        output_activation: Optional[Callable] = nn.Identity(),
        bias_enabled: bool = True,
        bias_init: Callable = nn.init.zeros_,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.net_depth = net_depth
        self.net_width = net_width
        self.skip_layer = skip_layer
        self.hidden_init = hidden_init
        self.hidden_activation = hidden_activation
        self.output_enabled = output_enabled
        self.output_init = output_init
        self.output_activation = output_activation
        self.bias_enabled = bias_enabled
        self.bias_init = bias_init

        # fan_ins[i] is the input width of hidden layer i; the trailing entry
        # is the width of the features leaving the last hidden layer.
        fan_ins = [self.input_dim]
        for idx in range(self.net_depth):
            if self._concat_inputs_after(idx):
                fan_ins.append(self.net_width + self.input_dim)
            else:
                fan_ins.append(self.net_width)

        self.hidden_layers = nn.ModuleList(
            [
                nn.Linear(fan_in, self.net_width, bias=bias_enabled)
                for fan_in in fan_ins[:-1]
            ]
        )

        trunk_dim = fan_ins[-1]
        if self.output_enabled:
            self.output_layer = nn.Linear(
                trunk_dim, self.output_dim, bias=bias_enabled
            )
        else:
            # Without an output head the module emits the trunk features.
            self.output_dim = trunk_dim

        self.initialize()

    def _concat_inputs_after(self, layer_idx: int) -> bool:
        """Return whether the input is concatenated after hidden layer ``layer_idx``."""
        if self.skip_layer is None:
            return False
        return layer_idx % self.skip_layer == 0 and layer_idx > 0

    def _init_linear(self, module, weight_init):
        """Initialize ``module`` in place if it is an ``nn.Linear``."""
        if not isinstance(module, nn.Linear):
            return
        if weight_init is not None:
            weight_init(module.weight)
        if self.bias_enabled and self.bias_init is not None:
            self.bias_init(module.bias)

    def initialize(self):
        """(Re-)initialize all hidden layers, then the output layer if present."""
        self.hidden_layers.apply(
            lambda module: self._init_linear(module, self.hidden_init)
        )
        if self.output_enabled:
            self.output_layer.apply(
                lambda module: self._init_linear(module, self.output_init)
            )

    def forward(self, x):
        """Run the network on ``x`` (channels in the last dimension)."""
        skip_input = x
        for idx, layer in enumerate(self.hidden_layers):
            x = self.hidden_activation(layer(x))
            if self._concat_inputs_after(idx):
                x = torch.cat([x, skip_input], dim=-1)
        if not self.output_enabled:
            return x
        return self.output_activation(self.output_layer(x))


class DenseLayer(MLP):
    """A single linear projection, expressed as an ``MLP`` with no hidden layers.

    Any additional keyword arguments are forwarded to ``MLP.__init__``.
    ``net_depth`` is fixed to 0 here and must not be passed by the caller.
    """

    def __init__(self, input_dim, output_dim, **kwargs):
        # With zero hidden layers only the output layer of the MLP remains.
        fixed_args = dict(input_dim=input_dim, output_dim=output_dim, net_depth=0)
        super().__init__(**fixed_args, **kwargs)


class NerfMLP(nn.Module):
    """Two-headed MLP producing raw colour and raw density.

    A shared trunk (``self.base``) feeds a one-channel density head. When
    ``condition_dim > 0`` the colour head is a second MLP that consumes a
    bottleneck projection of the trunk features concatenated with a
    conditioning vector; otherwise the colour head is a single linear layer
    on the trunk features.

    Args:
        input_dim: Number of input channels of the trunk.
        condition_dim: Number of channels of the conditioning tensor.
        net_depth: Depth of the trunk.
        net_width: Width of the trunk (and of the bottleneck projection).
        skip_layer: Skip-connection period of the trunk.
        net_depth_condition: Depth of the conditioned colour MLP.
        net_width_condition: Width of the conditioned colour MLP.
    """

    def __init__(
        self,
        input_dim: int,  # The number of input tensor channels.
        condition_dim: int,  # The number of condition tensor channels.
        net_depth: int = 8,  # The depth of the MLP.
        net_width: int = 256,  # The width of the MLP.
        skip_layer: int = 4,  # The layer to add skip layers to.
        net_depth_condition: int = 1,  # The depth of the second part of MLP.
        net_width_condition: int = 128,  # The width of the second part of MLP.
    ):
        super().__init__()
        # Trunk without an output head: it returns its last hidden features.
        self.base = MLP(
            input_dim=input_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            output_enabled=False,
        )
        trunk_dim = self.base.output_dim
        self.sigma_layer = DenseLayer(trunk_dim, 1)

        if condition_dim <= 0:
            # Unconditioned colour: a plain projection of the trunk features.
            self.rgb_layer = DenseLayer(trunk_dim, 3)
            return

        self.bottleneck_layer = DenseLayer(trunk_dim, net_width)
        self.rgb_layer = MLP(
            input_dim=net_width + condition_dim,
            output_dim=3,
            net_depth=net_depth_condition,
            net_width=net_width_condition,
            skip_layer=None,
        )

    def query_density(self, x):
        """Return the raw (pre-activation) density for ``x``."""
        return self.sigma_layer(self.base(x))

    def forward(self, x, condition=None):
        """Return ``(raw_rgb, raw_sigma)`` for ``x``.

        Args:
            x: Trunk input, channels last.
            condition: Optional conditioning tensor. If its leading
                dimensions differ from those of the trunk features it is
                unpacked as a 2-D ``[num_rays, n_dim]`` tensor and
                broadcast across the remaining dimensions.
        """
        features = self.base(x)
        raw_sigma = self.sigma_layer(features)

        rgb_input = features
        if condition is not None:
            lead_shape = features.shape[:-1]
            if condition.shape[:-1] != lead_shape:
                num_rays, n_dim = condition.shape
                # Insert singleton axes so the per-ray condition can be
                # expanded over the extra dimensions of the features.
                padding = [1] * (features.dim() - condition.dim())
                condition = condition.view([num_rays, *padding, n_dim])
                condition = condition.expand([*lead_shape, n_dim])
            bottleneck = self.bottleneck_layer(features)
            rgb_input = torch.cat([bottleneck, condition], dim=-1)

        return self.rgb_layer(rgb_input), raw_sigma


class SinusoidalEncoder(nn.Module):
    """Sinusoidal positional encoding as used in NeRF.

    Each input channel is multiplied by ``2**i`` for ``i`` in
    ``[min_deg, max_deg)`` and mapped through sine and a quarter-period
    shifted sine (i.e. cosine). The raw input is optionally prepended.
    """

    def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True):
        super().__init__()
        self.x_dim = x_dim
        self.min_deg = min_deg
        self.max_deg = max_deg
        self.use_identity = use_identity
        frequencies = [2**deg for deg in range(min_deg, max_deg)]
        # Stored as a buffer so it follows the module across devices.
        self.register_buffer("scales", torch.tensor(frequencies))

    @property
    def latent_dim(self) -> int:
        """Number of output channels: sin and cos per degree, plus identity."""
        num_degrees = self.max_deg - self.min_deg
        channels_per_coord = int(self.use_identity) + 2 * num_degrees
        return channels_per_coord * self.x_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``.

        Args:
            x: [..., x_dim]
        Returns:
            latent: [..., latent_dim]. When ``max_deg == min_deg`` the input
            is returned unchanged.
        """
        num_degrees = self.max_deg - self.min_deg
        if num_degrees == 0:
            return x

        # [..., 1, x_dim] * [num_degrees, 1] -> [..., num_degrees, x_dim]
        scaled = x.unsqueeze(-2) * self.scales.unsqueeze(-1)
        flat_shape = [*x.shape[:-1], num_degrees * self.x_dim]
        xb = scaled.reshape(flat_shape)

        # sin(a + pi/2) == cos(a), so one sin call yields both halves.
        phases = torch.cat([xb, xb + 0.5 * math.pi], dim=-1)
        latent = torch.sin(phases)
        if not self.use_identity:
            return latent
        return torch.cat([x, latent], dim=-1)


class VanillaNeRFRadianceField(nn.Module):
    """Radiance field built from sinusoidal encoders and a ``NerfMLP``.

    Positions are encoded with degrees ``[0, 10)`` and the optional
    condition with degrees ``[0, 4)``, both over 3 input channels and with
    the identity term included.

    Args:
        net_depth: Depth of the trunk MLP.
        net_width: Width of the trunk MLP.
        skip_layer: Skip-connection period of the trunk MLP.
        net_depth_condition: Depth of the conditioned colour MLP.
        net_width_condition: Width of the conditioned colour MLP.
    """

    def __init__(
        self,
        net_depth: int = 8,  # The depth of the MLP.
        net_width: int = 256,  # The width of the MLP.
        skip_layer: int = 4,  # The layer to add skip layers to.
        net_depth_condition: int = 1,  # The depth of the second part of MLP.
        net_width_condition: int = 128,  # The width of the second part of MLP.
    ) -> None:
        super().__init__()
        self.posi_encoder = SinusoidalEncoder(3, 0, 10, True)
        self.view_encoder = SinusoidalEncoder(3, 0, 4, True)
        mlp_config = dict(
            input_dim=self.posi_encoder.latent_dim,
            condition_dim=self.view_encoder.latent_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            net_depth_condition=net_depth_condition,
            net_width_condition=net_width_condition,
        )
        self.mlp = NerfMLP(**mlp_config)

    def query_opacity(self, x, step_size):
        """Return ``density * step_size`` at ``x``.

        This is the small-density approximation of
        ``1 - exp(-density * step_size)``.
        """
        return self.query_density(x) * step_size

    def query_density(self, x):
        """Return the ReLU-activated density at positions ``x``."""
        encoded = self.posi_encoder(x)
        raw_sigma = self.mlp.query_density(encoded)
        return F.relu(raw_sigma)

    def forward(self, x, condition=None):
        """Return ``(rgb, sigma)``: sigmoid colour and ReLU density at ``x``."""
        encoded = self.posi_encoder(x)
        encoded_condition = (
            None if condition is None else self.view_encoder(condition)
        )
        raw_rgb, raw_sigma = self.mlp(encoded, condition=encoded_condition)
        return torch.sigmoid(raw_rgb), F.relu(raw_sigma)


class TNeRFRadianceField(nn.Module):
    """Time-conditioned radiance field.

    A small MLP predicts a 3-channel offset from the encoded position and
    encoded time; the offset is added to the query position, which is then
    evaluated by a ``VanillaNeRFRadianceField``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.posi_encoder = SinusoidalEncoder(3, 0, 4, True)
        self.time_encoder = SinusoidalEncoder(1, 0, 4, True)
        warp_input_dim = (
            self.posi_encoder.latent_dim + self.time_encoder.latent_dim
        )
        # The output weight is drawn from uniform_(b=1e-4), so the initial
        # offsets are small.
        self.warp = MLP(
            input_dim=warp_input_dim,
            output_dim=3,
            net_depth=4,
            net_width=64,
            skip_layer=2,
            output_init=functools.partial(torch.nn.init.uniform_, b=1e-4),
        )
        self.nerf = VanillaNeRFRadianceField()

    def _deform(self, x, t):
        """Return ``x`` displaced by the predicted offset at time ``t``."""
        warp_input = torch.cat(
            [self.posi_encoder(x), self.time_encoder(t)], dim=-1
        )
        return x + self.warp(warp_input)

    def query_opacity(self, x, timestamps, step_size):
        """Return ``density * step_size`` with a random timestamp per sample.

        One entry of ``timestamps`` is drawn uniformly at random for every
        row of ``x``. The product is the small-density approximation of
        ``1 - exp(-density * step_size)``.
        """
        num_samples = x.shape[0]
        picks = torch.randint(
            0, len(timestamps), (num_samples,), device=x.device
        )
        density = self.query_density(x, timestamps[picks])
        return density * step_size

    def query_density(self, x, t):
        """Return the density at ``x`` after displacing it for time ``t``."""
        return self.nerf.query_density(self._deform(x, t))

    def forward(self, x, t, condition=None):
        """Return ``(rgb, sigma)`` at ``x`` after displacing it for time ``t``."""
        return self.nerf(self._deform(x, t), condition=condition)


class NDRTNeRFRadianceField(nn.Module):
    """Invertible NN from https://arxiv.org/pdf/2206.15258.pdf

    The query position is passed through three warp stages, with its
    coordinates permuted between stages. Each stage first shifts the last
    coordinate by a value predicted from the first two, then rotates and
    translates the first two by values predicted from the last. The warped
    position is evaluated by a ``VanillaNeRFRadianceField``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.time_encoder = SinusoidalEncoder(1, 0, 4, True)
        self.warp_layers_1 = nn.ModuleList()
        self.time_layers_1 = nn.ModuleList()
        self.warp_layers_2 = nn.ModuleList()
        self.time_layers_2 = nn.ModuleList()
        self.posi_encoder_1 = SinusoidalEncoder(2, 0, 4, True)
        self.posi_encoder_2 = SinusoidalEncoder(1, 0, 4, True)

        num_stages = 3
        time_feature_dim = 64
        # Output weights are drawn from uniform_(b=1e-4), so the initial
        # predicted shift / rotation / translation values are small.
        small_uniform = functools.partial(torch.nn.init.uniform_, b=1e-4)
        for _ in range(num_stages):
            # Predicts the 1-channel shift of the last coordinate.
            shift_net = MLP(
                input_dim=self.posi_encoder_1.latent_dim + time_feature_dim,
                output_dim=1,
                net_depth=2,
                net_width=128,
                skip_layer=None,
                output_init=small_uniform,
            )
            self.warp_layers_1.append(shift_net)

            # Predicts one rotation angle plus a 2-channel translation.
            rigid_net = MLP(
                input_dim=self.posi_encoder_2.latent_dim + time_feature_dim,
                output_dim=1 + 2,
                net_depth=1,
                net_width=128,
                skip_layer=None,
                output_init=small_uniform,
            )
            self.warp_layers_2.append(rigid_net)

            self.time_layers_1.append(
                DenseLayer(
                    input_dim=self.time_encoder.latent_dim,
                    output_dim=time_feature_dim,
                )
            )
            self.time_layers_2.append(
                DenseLayer(
                    input_dim=self.time_encoder.latent_dim,
                    output_dim=time_feature_dim,
                )
            )

        self.nerf = VanillaNeRFRadianceField()

    def _warp(self, x, t_enc, i_layer):
        """Apply warp stage ``i_layer`` to the 2-D tensor ``x``.

        The first two columns of ``x`` are treated as ``uv`` and the
        remaining columns as ``w``.
        """
        uv = x[:, :2]
        w = x[:, 2:]

        shift_input = torch.cat(
            [self.posi_encoder_1(uv), self.time_layers_1[i_layer](t_enc)],
            dim=-1,
        )
        w = w + self.warp_layers_1[i_layer](shift_input)

        rigid_input = torch.cat(
            [self.posi_encoder_2(w), self.time_layers_2[i_layer](t_enc)],
            dim=-1,
        )
        rigid = self.warp_layers_2[i_layer](rigid_input)
        rotation = self._euler2rot_2dinv(rigid[:, :1])
        translation = rigid[:, 1:]

        # Batched 2x2 matrix-vector product on the translated uv.
        shifted_uv = (uv - translation).unsqueeze(-1)
        uv = torch.bmm(rotation, shifted_uv).squeeze(-1)
        return torch.cat([uv, w], dim=-1)

    def warp(self, x, t):
        """Run all three warp stages on ``x`` for time ``t``."""
        t_enc = self.time_encoder(t)
        # Coordinate permutation applied after each stage (none after the last).
        permutations = ([1, 2, 0], [2, 0, 1], None)
        for stage, permutation in enumerate(permutations):
            x = self._warp(x, t_enc, stage)
            if permutation is not None:
                x = x[..., permutation]
        return x

    def query_opacity(self, x, timestamps, step_size):
        """Return ``density * step_size`` with a random timestamp per sample.

        One entry of ``timestamps`` is drawn uniformly at random for every
        row of ``x``. The product is the small-density approximation of
        ``1 - exp(-density * step_size)``.
        """
        num_samples = x.shape[0]
        picks = torch.randint(
            0, len(timestamps), (num_samples,), device=x.device
        )
        density = self.query_density(x, timestamps[picks])
        return density * step_size

    def query_density(self, x, t):
        """Return the density at the warped position of ``x`` at time ``t``."""
        return self.nerf.query_density(self.warp(x, t))

    def forward(self, x, t, condition=None):
        """Return ``(rgb, sigma)`` at the warped position of ``x`` at time ``t``."""
        return self.nerf(self.warp(x, t), condition=condition)

    def _euler2rot_2dinv(self, euler_angle):
        """Build ``[[cos, sin], [-sin, cos]]`` per angle: (B, 1) -> (B, 2, 2)."""
        theta = euler_angle.reshape(-1, 1, 1)
        cos_t = theta.cos()
        sin_t = theta.sin()
        # Each inner cat yields one (B, 2, 1) column of the matrix.
        first_column = torch.cat((cos_t, -sin_t), 1)
        second_column = torch.cat((sin_t, cos_t), 1)
        return torch.cat((first_column, second_column), 2)


def density_activation(x):
    """Map a raw density value to a non-negative density.

    The input is rescaled to ``x / 20 - 1``, clamped to ``[-11, 11]`` and
    passed through ``trunc_exp``.
    """
    rescaled = x / 20 - 1
    # Clamp the exponent so the exponential stays in a bounded range.
    bounded = rescaled.clamp(max=11, min=-11)
    return trunc_exp(bounded)


class MLPDensityField(TargetNet):
    """Density field backed by a small bias-free MLP over embedded positions.

    Positions are normalized by the axis-aligned bounding box ``aabb``,
    passed through the embedding returned by ``get_embedder`` and a
    three-layer bias-free MLP, then activated. Density is forced to zero
    for positions that fall outside the open unit cube after normalization.

    Args:
        aabb: Bounding box; split along its last dimension into a minimum
            and a maximum part of ``num_dim`` entries each.
        num_dim: Number of spatial dimensions.
        density_activation: Callable applied to the raw MLP output.
        multires: First positional argument forwarded to ``get_embedder``.
        include_input: Forwarded to ``get_embedder`` as ``include_input``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(
        self,
        aabb: Union[torch.Tensor, List[float]],
        num_dim: int = 3,
        density_activation: Callable = lambda x: trunc_exp(x - 1),
        multires=24,
        include_input=False,
        **kwargs
    ) -> None:
        super().__init__()
        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)
        self.register_buffer("aabb", aabb)
        self.num_dim = num_dim
        self.density_activation = density_activation
        # Honour the caller's ``include_input``; it used to be hard-coded to
        # False here, which silently discarded the constructor argument.
        self.embed_fn, self.input_ch = get_embedder(
            multires, include_input=include_input
        )

        self.mlp_base = nn.Sequential(
            nn.Linear(self.input_ch, 64, bias=False),
            nn.ReLU(),
            nn.Linear(64, 64, bias=False),
            nn.ReLU(),
            nn.Linear(64, 1, bias=False),
        )
        # NOTE: the custom initialization in _init_weight_ is intentionally
        # not applied here; the nn.Linear defaults are kept.

    def _init_weight_(self):
        """Re-initialize linear weights from N(0, 0.05) and zero any biases."""
        for layer in self.mlp_base:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=0.05)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, positions: torch.Tensor):
        """Return the density at ``positions``.

        Args:
            positions: Tensor whose last dimension holds ``num_dim``
                coordinates.

        Returns:
            Tensor with the leading shape of ``positions`` and a trailing
            dimension of 1. Entries for positions outside the bounding box
            are zero.
        """
        aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
        # Normalize to the unit cube spanned by the bounding box.
        positions = (positions - aabb_min) / (aabb_max - aabb_min)
        # True only where every coordinate lies strictly inside (0, 1).
        selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1)
        density_before_activation = (
            self.mlp_base(self.embed_fn(positions.view(-1, self.num_dim)))
            .view(list(positions.shape[:-1]) + [1])
        )
        density = (
            self.density_activation(density_before_activation)
            * selector[..., None]
        )
        return density

    def get_submodules(self):
        """Yield the ``nn.Linear`` layers of ``mlp_base`` in order."""
        for layer in self.mlp_base:
            if isinstance(layer, nn.Linear):
                yield layer

    def construct_opt_blocks(self, ftask_dim, weight_dim, deriv_hidden_dim, driv_num_layers,
                             *args, in_dim=64, out_dim=64, **kwargs):
        """Build and chain one ``OptBlock`` per linear layer.

        The blocks are linked in layer order and bracketed by a ``FIN_FOUT``
        module at each end.

        Returns:
            ``(opt_blocks, forward_in, dloss_dout)``, each an
            ``nn.ModuleList``.
        """
        forward_in = nn.ModuleList([
            FIN_FOUT(ftask_dim, in_dim, hidden_dim=deriv_hidden_dim, num_layers=driv_num_layers)
        ])
        # NOTE(review): this end module is sized with in_dim, not out_dim —
        # confirm that is intended.
        dloss_dout = nn.ModuleList([
            FIN_FOUT(ftask_dim, in_dim, hidden_dim=deriv_hidden_dim, num_layers=driv_num_layers)
        ])
        opt_blocks = nn.ModuleList([
            OptBlock(module, ftask_dim, out_dim, weight_dim, deriv_hidden_dim, driv_num_layers, *args, **kwargs)
            for module in self.get_submodules()
        ])

        # Link each block to its successor.
        chain = list(opt_blocks)
        for block, successor in zip(chain, chain[1:]):
            block.link(successor)

        forward_in[0].link(opt_blocks[0])
        opt_blocks[-1].link(dloss_dout[0])

        return opt_blocks, forward_in, dloss_dout
