import itertools
import logging as log
from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F

from plenoxels.models.utils import Encoding, Network
from plenoxels.ops.interpolation import grid_sample_wrapper
from plenoxels.raymarching.spatial_distortions import SpatialDistortion


def get_normalized_directions(directions):
    """Affinely remap direction components from [-1, 1] onto [0, 1].

    The spherical-harmonics encoding expects its inputs in the [0, 1] range.

    Args:
        directions: batch of directions

    Returns:
        ``(directions + 1) / 2``, same shape as the input.
    """
    shifted = directions + 1.0
    return shifted / 2.0


def normalize_aabb(pts, aabb):
    """Linearly map points so that ``aabb[0]`` goes to -1 and ``aabb[1]`` to +1.

    Args:
        pts: points to normalize.
        aabb: indexable pair; ``aabb[0]`` is the lower corner and ``aabb[1]``
            the upper corner of the bounding box.

    Returns:
        ``pts`` rescaled per coordinate; values inside the box land in [-1, 1].
    """
    lower, upper = aabb[0], aabb[1]
    per_axis_scale = 2.0 / (upper - lower)
    return (pts - lower) * per_axis_scale - 1.0


def init_grid_param(
        grid_nd: int,
        in_dim: int,
        out_dim: int,
        reso: Sequence[int],
        a: float = 0.1,
        b: float = 0.5):
    """Create one learnable feature grid per ``grid_nd``-subset of the input axes.

    Args:
        grid_nd: dimensionality of each grid (2 gives planes).
        in_dim: number of input coordinates; must equal ``len(reso)``.
        out_dim: number of feature channels stored in every grid.
        reso: resolution along each input coordinate.
        a: lower bound of the uniform initialisation.
        b: upper bound of the uniform initialisation.

    Returns:
        An ``nn.ParameterList`` with one parameter per axis combination (in
        ``itertools.combinations`` order), each of shape
        ``[1, out_dim, *reversed resolutions of the combination]``.
        When ``in_dim == 4``, grids involving axis 3 (time) are filled with
        ones; every other grid is drawn from ``U(a, b)``.
    """
    assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension"
    assert grid_nd <= in_dim
    is_space_time = in_dim == 4

    planes = nn.ParameterList()
    for axes in itertools.combinations(range(in_dim), grid_nd):
        # Resolutions are listed in reversed axis order.
        shape = [1, out_dim, *(reso[axis] for axis in reversed(axes))]
        plane = nn.Parameter(torch.empty(shape))
        if is_space_time and 3 in axes:
            # Initialize time planes to 1
            nn.init.ones_(plane)
        else:
            nn.init.uniform_(plane, a=a, b=b)
        planes.append(plane)

    return planes


# DEFAULT: (35.9871, 0.9822, 0.9950)
# 34.8908 (x1) 33.7284 (x2)
@torch.no_grad()
def get_ellipses(covs):
    """Compute the ellipse of a rotated spheroid in each coordinate plane.

    The last dimension of ``covs`` is read as ``(a, b, beta, gamma)``: two
    axis lengths followed by two rotation angles. The spheroid
    ``(X^2 + Y^2) / A^2 + Z^2 / B^2 = 1`` is rotated by ``beta`` and
    ``gamma`` (yaw assumed 0); for every pair of coordinates the quadric's
    coefficients for that pair are handed to ``get_ellipse_info``.

    Axis lengths are first divided by ``sqrt(a) * sqrt(b)`` and the resulting
    semi-axes are multiplied by the same factor before being returned.

    Args:
        covs: tensor whose last dimension holds at least four channels.

    Returns:
        Three ``(semi_axis_a, semi_axis_b, theta)`` tuples, for the (x, y),
        (x, z) and (y, z) planes respectively. No gradients are tracked.
    """
    # Shared normaliser: applied to the inputs here and undone on the outputs.
    n_scale = covs[..., :2].sqrt().prod(dim=-1)

    a2 = (covs[..., 0] / n_scale).square()
    b2 = (covs[..., 1] / n_scale).square()
    assert (a2 == 0).sum() == 0
    assert (b2 == 0).sum() == 0

    beta, gamma = covs[..., 2], covs[..., 3]
    sin_b, cos_b = torch.sin(beta), torch.cos(beta)
    sin_g, cos_g = torch.sin(gamma), torch.cos(gamma)

    sin2_b = torch.square(sin_b)
    cos2_b = 1 - sin2_b
    sin2_g = torch.square(sin_g)
    cos2_g = 1 - sin2_g

    sincos_b = sin_b * cos_b
    sincos_g = sin_g * cos_g

    # Rotation matrix (assuming yaw = 0)
    # [cos B,  sin B sin G, sin B cos G]
    # [0,      cos G,       - sin G    ]
    # [-sin B, cos B sin G, cos B cos G]
    #
    # Its inverse
    # [cos B, sin B sin G,   -sin B cos G]
    # [0,     cos G,         sin G       ]
    # [sin B, - cos B sin G, cos B cos G ]
    #
    # Substituting into (X2 + Y2) / A2 + Z2 / B2 = 1:
    #   (cos B x + sin B z)2 / A2
    #   + (sin B sin G x + cos G y - cos B sin G z)2 / A2
    #   + (- sin B cos G x + sin G y + cos B cos G z)2 / B2
    #   - 1 = 0
    #
    # Collecting terms:
    #   (cos2 B / A2 + sin2 B sin2 G / A2 + sin2 B cos2 G / B2) x2
    #   + (cos2 G / A2 + sin2 G / B2) y2
    #   + (sin2 B / A2 + cos2 B sin2 G / A2 + cos2 B cos2 G / B2) z2
    #   + 2 (sin B sin G cos G / A2 - sin B sin G cos G / B2) xy
    #   + 2 (cos B sin B / A2 - sin B cos B sin2 G / A2 - sin B cos B cos2 G / B2) xz
    #   + 2 (- cos B sin G cos G / A2 + cos B sin G cos G / B2) yz
    #   - 1 = 0

    # Squared-term coefficients.
    c_x2 = (cos2_b + sin2_b * sin2_g) / a2 + sin2_b * cos2_g / b2
    c_y2 = cos2_g / a2 + sin2_g / b2
    c_z2 = (sin2_b + cos2_b * sin2_g) / a2 + cos2_b * cos2_g / b2

    # Cross-term coefficients.
    c_xy = 2 * (sin_b * sincos_g * (1 / a2 - 1 / b2))
    c_xz = 2 * (sincos_b * cos2_g / a2 - sincos_b * cos2_g / b2)
    c_yz = 2 * (cos_b * sincos_g * (1 / b2 - 1 / a2))

    plane_coefs = (
        (c_x2, c_xy, c_y2),  # xy
        (c_x2, c_xz, c_z2),  # xz
        (c_y2, c_yz, c_z2),  # yz
    )
    ellipses = []
    for sq_first, cross, sq_second in plane_coefs:
        axis_a, axis_b, theta = get_ellipse_info(sq_first, cross, sq_second, -1)
        ellipses.append((axis_a * n_scale, axis_b * n_scale, theta))
    return tuple(ellipses)


def get_ellipse_info(A, B, C, F, eps=1e-5):
    """Semi-axes and an orientation angle for ``A x^2 + B xy + C y^2 + F = 0``.

    With ``D = B^2 - 4AC`` and ``R = sqrt((A - C)^2 + B^2)`` the function
    evaluates::

        a     = sqrt(2F / D) * sqrt((A + C) + R)
        b     = sqrt(2F / D) * sqrt((A + C) - R)
        theta = atan2(A - C + R, B)

    Every radicand is clamped to at least ``eps`` before the square root, so
    finite inputs never produce a zero or NaN axis.

    Args:
        A: coefficient of ``x^2`` (tensor).
        B: coefficient of ``xy`` (tensor).
        C: coefficient of ``y^2`` (tensor).
        F: constant term. Inside this function the name refers to this
            parameter, not to the module-level ``torch.nn.functional`` alias.
        eps: lower clamp applied to each radicand.

    Returns:
        ``(a, b, theta)`` tensors broadcast from the inputs.

    Raises:
        AssertionError: if an input contains NaN, or if an output contains
            NaN or (for the axes) an exact zero.
    """
    for coef in (A, B, C):
        assert coef.isnan().sum() == 0

    cross_sq = B.square()
    root = ((A - C).square() + cross_sq).clamp(min=0).sqrt()
    trace = A + C
    discriminant = cross_sq - 4 * A * C

    common = (2 * F / discriminant).clamp(min=eps).sqrt()
    a = common * (trace + root).clamp(min=eps).sqrt()
    b = common * (trace - root).clamp(min=eps).sqrt()
    theta = torch.atan2(A - C + root, B)

    for axis in (a, b):
        assert axis.isnan().sum() == 0
        assert (axis == 0).sum() == 0
    assert theta.isnan().sum() == 0

    return a, b, theta


def init_gaussian(out_chan, k_size: tuple, sigma=1):
    """Build a sum-normalised separable Gaussian kernel, tiled per channel.

    Args:
        out_chan: number of copies of the kernel stacked along dim 0.
        k_size: kernel extent along each axis, e.g. ``(11, 11)``.
        sigma: standard deviation (in cells) shared by all axes.

    Returns:
        For a 2-D ``k_size`` a tensor of shape ``(out_chan, 1, *k_size)``
        whose every ``[i, 0]`` slice sums to 1.
    """
    axes = [torch.arange(extent, dtype=torch.float32) for extent in k_size]
    coords = torch.meshgrid(axes, indexing='ij')

    # 1-D Gaussian normalising constant; it cancels in the final division
    # but is kept so intermediate values match the analytic density.
    coef = 1 / (sigma * ((2 * torch.pi) ** 0.5))

    kernel = 1
    for extent, coord in zip(k_size, coords):
        center = (extent - 1) / 2
        z_score = (coord - center) / sigma
        kernel = kernel * (coef * torch.exp(-0.5 * z_score ** 2))

    normalised = kernel / torch.sum(kernel)
    return normalised.repeat(out_chan, 1, 1, 1)


def interpolate_ms_features(pts: torch.Tensor,
                            covs: torch.Tensor,
                            ms_grids: Collection[Iterable[nn.Module]],
                            grid_dimensions: int,
                            concat_features: bool,
                            num_levels: Optional[int],
                            kernels: Optional[Collection[Iterable[nn.Module]]],
                            ) -> torch.Tensor:
    """Sample every feature grid at ``pts`` and combine planes and scales.

    For each scale, the features sampled from all coordinate-combination
    grids are multiplied together; the per-scale results are then either
    concatenated along the feature axis or summed.

    Args:
        pts: sample coordinates; the size of the last dimension is the number
            of input coordinates.
        covs: per-sample footprint descriptors. Only the last channel is read,
            and only when ``kernels`` is given.
        ms_grids: one entry per scale; each entry holds one grid of shape
            ``(1, out_dim, *reso)`` per coordinate combination.
        grid_dimensions: number of coordinates indexing each grid.
        concat_features: concatenate (True) or sum (False) across scales.
        num_levels: number of leading scales to use; ``None`` uses all.
        kernels: optional per-scale, per-grid convolution kernels. When given,
            each grid is filtered into ``n_scales`` variants and the last
            channel of ``covs`` is appended as an extra sampling coordinate.

    Returns:
        Tensor of shape ``(n_points, feature_dim)`` when summing, or with the
        per-scale feature dims concatenated when ``concat_features`` is True.

    Raises:
        AssertionError: if a grid or any sampled feature contains NaN, or the
            last sampled plane of a scale contains Inf.
    """
    coo_combs = list(itertools.combinations(
        range(pts.shape[-1]), grid_dimensions)
    )
    if num_levels is None:
        num_levels = len(ms_grids)
    multi_scale_interp = [] if concat_features else 0.
    grid: nn.ParameterList

    # Last covariance channel; used as the extra sampling coordinate that
    # selects among the kernel-filtered variants of each grid.
    res = covs[..., -1:]

    # NOTE: an ellipse-footprint experiment (per-plane sampling offsets
    # derived from get_ellipses) used to live here as disabled code and has
    # been removed; it had no effect on the computation.

    for scale_id, grid in enumerate(ms_grids[:num_levels]):
        interp_space = 1.
        for ci, coo_comb in enumerate(coo_combs):
            # interpolate in plane
            # shape of grid[ci]: 1, out_dim, *reso
            feature_dim = grid[ci].shape[1]

            new_grid = grid[ci]
            new_pts = pts[..., coo_comb]

            if kernels is not None:
                kernel = kernels[scale_id][ci]
                # The kernel bank holds n_scales filters per feature channel.
                n_scales = kernel.shape[0] // new_grid.shape[1]
                # Depth-wise convolution (one group per kernel) of the
                # channel-repeated plane with kernels normalised to unit L2
                # norm over their spatial dims, then regrouped to
                # (1, out_dim, n_scales, H, W).
                new_grid = F.conv2d(
                    new_grid.repeat(1, n_scales, 1, 1),
                    F.normalize(kernel, dim=(-2, -1)),
                    padding='same', groups=kernel.shape[0]) \
                    .reshape(1, n_scales, -1, *new_grid.shape[-2:]) \
                    .transpose(1, 2)
                # Third coordinate indexes the n_scales axis of new_grid.
                new_pts = torch.cat([new_pts, res.reshape(*pts.shape[:-1], 1)],
                                    -1)

            interp_out_plane = (
                grid_sample_wrapper(new_grid, new_pts).view(-1, feature_dim))

            assert interp_out_plane.isnan().sum() == 0
            assert grid[ci].isnan().sum() == 0

            # compute product over planes
            interp_space = interp_space * interp_out_plane

        assert interp_out_plane.isnan().sum() == 0, torch.where(interp_out_plane.isnan())
        # Report the positions of the Inf entries (previously this message
        # listed NaN positions, which are always empty at this point).
        assert interp_out_plane.isinf().sum() == 0, torch.where(interp_out_plane.isinf())

        # combine over scales
        if concat_features:
            multi_scale_interp.append(interp_space)
        else:
            multi_scale_interp = multi_scale_interp + interp_space

    if concat_features:
        multi_scale_interp = torch.cat(multi_scale_interp, dim=-1)
    assert multi_scale_interp.isnan().sum() == 0
    return multi_scale_interp


class KPlaneField(nn.Module):
    """Radiance field backed by multi-scale factorised feature planes.

    Points are normalised (via a spatial distortion or the scene AABB),
    features are interpolated from the planes with
    ``interpolate_ms_features`` and decoded into density and RGB by either a
    learned linear basis (``linear_decoder=True``) or a pair of MLPs.
    """

    def __init__(
        self,
        aabb,
        grid_config: Union[str, List[Dict]],
        concat_features_across_scales: bool,
        multiscale_res: Optional[Sequence[int]],
        use_appearance_embedding: bool,
        appearance_embedding_dim: int,
        spatial_distortion: Optional[SpatialDistortion],
        density_activation: Callable,
        linear_decoder: bool,
        linear_decoder_layers: Optional[int],
        num_images: Optional[int],
        use_grid_kernel=False,
        grid_kernel_size=11,
        num_scales=4,
        **kwargs,
    ) -> None:
        """Build the feature grids, optional kernels, embeddings and decoders.

        Args:
            aabb: scene bounding box; ``aabb[0]``/``aabb[1]`` are its corners.
            grid_config: list whose first entry is a dict with the keys
                ``grid_dimensions``, ``input_coordinate_dim``,
                ``output_coordinate_dim`` and ``resolution``.
            concat_features_across_scales: concatenate (True) or sum (False)
                the features of the different scales.
            multiscale_res: resolution multipliers, one per scale; falsy
                values fall back to ``[1]``.
            use_appearance_embedding: learn a per-image appearance code.
            appearance_embedding_dim: size of the appearance code.
            spatial_distortion: optional distortion applied instead of the
                AABB normalisation.
            density_activation: maps raw density outputs to densities.
            linear_decoder: use the learned-basis decoder instead of MLPs.
            linear_decoder_layers: hidden layers of the basis network;
                required when ``linear_decoder`` is True.
            num_images: number of appearance codes; required when
                ``use_appearance_embedding`` is True.
            use_grid_kernel: create a learnable kernel bank per grid.
            grid_kernel_size: side length of each (square) kernel.
            num_scales: number of kernels per feature channel.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        self.aabb = nn.Parameter(aabb, requires_grad=False)
        self.spatial_distortion = spatial_distortion
        self.grid_config = grid_config

        self.multiscale_res_multipliers: List[int] = multiscale_res or [1]
        self.concat_features = concat_features_across_scales
        self.density_activation = density_activation
        self.linear_decoder = linear_decoder

        # 1. Init planes
        self.grids = nn.ModuleList()
        self.feature_dim = 0
        for res in self.multiscale_res_multipliers:
            # initialize coordinate grid
            config = self.grid_config[0].copy()
            # Resolution fix: multi-res only on spatial planes
            config["resolution"] = [
                r * res for r in config["resolution"][:3]
            ] + config["resolution"][3:]
            gp = init_grid_param(
                grid_nd=config["grid_dimensions"],
                in_dim=config["input_coordinate_dim"],
                out_dim=config["output_coordinate_dim"],
                reso=config["resolution"],
            )
            # shape[1] is out-dim - Concatenate over feature len for each scale
            if self.concat_features:
                self.feature_dim += gp[-1].shape[1]
            else:
                self.feature_dim = gp[-1].shape[1]
            self.grids.append(gp)

        # Optional learnable kernels: ``num_scales`` Gaussian-initialised
        # filters per feature channel of every grid.
        if use_grid_kernel:
            self.kernels = nn.ModuleList()
            for gp in self.grids:
                k = nn.ParameterList()
                for plane in gp:
                    # Wrap explicitly so the kernel is registered as a
                    # learnable parameter.
                    k.append(nn.Parameter(init_gaussian(
                        plane.shape[1] * num_scales,
                        (grid_kernel_size, grid_kernel_size))))
                self.kernels.append(k)
        else:
            self.kernels = None

        # 2. Init appearance code-related parameters
        self.use_average_appearance_embedding = True  # for test-time
        self.use_appearance_embedding = use_appearance_embedding
        self.num_images = num_images
        self.appearance_embedding = None
        if use_appearance_embedding:
            assert self.num_images is not None
            self.appearance_embedding_dim = appearance_embedding_dim
            # this will initialize as normal_(0.0, 1.0)
            self.appearance_embedding = nn.Embedding(self.num_images, self.appearance_embedding_dim)
        else:
            self.appearance_embedding_dim = 0

        # 3. Init direction encoder
        self.direction_encoder = Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "SphericalHarmonics",
                "degree": 4,
            },
        )

        # 4. Init decoder network
        if self.linear_decoder:
            assert linear_decoder_layers is not None
            # The NN learns a basis that is used instead of spherical harmonics
            # Input is a raw (un-encoded) view direction plus the optional
            # appearance code, output is weights for combining the color
            # features into RGB
            # This architecture is based on instant-NGP
            self.color_basis = Network(
                n_input_dims=3 + self.appearance_embedding_dim,
                n_output_dims=3 * self.feature_dim,
                network_config={
                    "otype": "FullyFusedMLP",
                    "activation": "ReLU",
                    "output_activation": "None",
                    "n_neurons": 128,
                    "n_hidden_layers": linear_decoder_layers,
                },
            )
            # sigma_net just does a linear transformation on the features to get density
            self.sigma_net = Network(
                n_input_dims=self.feature_dim,
                n_output_dims=1,
                network_config={
                    "otype": "CutlassMLP",
                    "activation": "None",
                    "output_activation": "None",
                    "n_neurons": 128,
                    "n_hidden_layers": 0,
                },
            )
        else:
            self.geo_feat_dim = 15
            # Outputs geo_feat_dim geometry features plus one raw density.
            self.sigma_net = Network(
                n_input_dims=self.feature_dim,
                n_output_dims=self.geo_feat_dim + 1,
                network_config={
                    "otype": "FullyFusedMLP",
                    "activation": "ReLU",
                    "output_activation": "None",
                    "n_neurons": 64,
                    "n_hidden_layers": 1,
                },
            )
            self.in_dim_color = (
                    self.direction_encoder.n_output_dims
                    + self.geo_feat_dim
                    + self.appearance_embedding_dim
            )
            self.color_net = Network(
                n_input_dims=self.in_dim_color,
                n_output_dims=3,
                network_config={
                    "otype": "FullyFusedMLP",
                    "activation": "ReLU",
                    "output_activation": "Sigmoid",
                    "n_neurons": 64,
                    "n_hidden_layers": 2,
                },
            )

    def get_density(self, pts: torch.Tensor, covs: torch.Tensor,
                    timestamps: Optional[torch.Tensor] = None):
        """Computes and returns the densities.

        Args:
            pts: sample positions, ``[n_rays, n_samples, 3]``.
            covs: per-sample footprint descriptors. Never modified in place:
                when the AABB normalisation is used, a scaled copy is passed
                on to the interpolation.
            timestamps: optional per-ray times, ``[n_rays]``; appended to
                ``pts`` as a fourth coordinate.

        Returns:
            ``(density, features)`` where ``density`` is
            ``[n_rays, n_samples, 1]`` and ``features`` are the interpolated
            grid features (linear decoder) or the geometry features produced
            by ``sigma_net`` (MLP decoder).
        """
        if self.spatial_distortion is not None:
            pts = self.spatial_distortion(pts)
            pts = pts / 2  # from [-2, 2] to [-1, 1]
        else:
            pts = normalize_aabb(pts, self.aabb)
            scale = 2 / (self.aabb[1] - self.aabb[0]).mean()
            # Work on a copy: scaling the caller's tensor in place would
            # compound on every call that reuses the same ``covs``.
            covs = covs.clone()
            covs[..., :2] *= scale
        n_rays, n_samples = pts.shape[:2]
        if timestamps is not None:
            timestamps = timestamps[:, None].expand(-1, n_samples)[..., None]  # [n_rays, n_samples, 1]
            pts = torch.cat((pts, timestamps), dim=-1)  # [n_rays, n_samples, 4]

        pts = pts.reshape(-1, pts.shape[-1])
        features = interpolate_ms_features(
            pts, covs,
            ms_grids=self.grids,  # noqa
            grid_dimensions=self.grid_config[0]["grid_dimensions"],
            concat_features=self.concat_features, num_levels=None,
            kernels=self.kernels)
        assert features.isnan().sum() + features.isinf().sum() == 0
        if len(features) < 1:
            features = torch.zeros((0, 1)).to(features.device)
        if self.linear_decoder:
            density_before_activation = self.sigma_net(features)  # [batch, 1]
        else:
            features = self.sigma_net(features)
            features, density_before_activation = torch.split(
                features, [self.geo_feat_dim, 1], dim=-1)

        assert features.isnan().sum() + features.isinf().sum() == 0
        assert density_before_activation.isnan().sum() + density_before_activation.isinf().sum() == 0

        # Cast back to the dtype/device of the points before activating.
        density = self.density_activation(
            density_before_activation.to(pts)
        ).view(n_rays, n_samples, 1)
        return density, features

    def forward(self,
                pts: torch.Tensor,
                covs: torch.Tensor,
                directions: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):
        """Evaluate colour and density for a batch of ray samples.

        Args:
            pts: sample positions, ``[n_rays, n_samples, 3]``.
            covs: per-sample footprint descriptors, forwarded to
                ``get_density``.
            directions: one view direction per ray (viewed as ``[-1, 1, 3]``).
            timestamps: per-ray times; when appearance embeddings are enabled
                this argument instead carries the appearance (camera) ids. A
                float32 tensor of ids is treated as an interpolation weight
                between two hard-coded embeddings.

        Returns:
            ``{"rgb": [n_rays, n_samples, 3], "density": [n_rays, n_samples, 1]}``.

        Raises:
            AttributeError: if appearance embeddings are enabled but
                ``timestamps`` is None.
        """
        camera_indices = None
        if self.use_appearance_embedding:
            if timestamps is None:
                raise AttributeError("timestamps (appearance-ids) are not provided.")
            camera_indices = timestamps
            timestamps = None
        density, features = self.get_density(pts, covs, timestamps)
        n_rays, n_samples = pts.shape[:2]

        # Broadcast the per-ray direction to every sample of that ray.
        directions = directions.view(-1, 1, 3).expand(pts.shape).reshape(-1, 3)
        if not self.linear_decoder:
            directions = get_normalized_directions(directions)
            encoded_directions = self.direction_encoder(directions)

        assert features.isnan().sum() + features.isinf().sum() == 0

        if self.linear_decoder:
            color_features = [features]
        else:
            color_features = [encoded_directions, features.view(-1, self.geo_feat_dim)]

        if self.use_appearance_embedding:
            if camera_indices.dtype == torch.float32:
                # Interpolate between two embeddings. Currently they are hardcoded below.
                #emb1_idx, emb2_idx = 100, 121  # trevi
                emb1_idx, emb2_idx = 11, 142  # sacre
                emb_fn = self.appearance_embedding
                emb1 = emb_fn(torch.full_like(camera_indices, emb1_idx, dtype=torch.long))
                emb1 = emb1.view(emb1.shape[0], emb1.shape[2])
                emb2 = emb_fn(torch.full_like(camera_indices, emb2_idx, dtype=torch.long))
                emb2 = emb2.view(emb2.shape[0], emb2.shape[2])
                embedded_appearance = torch.lerp(emb1, emb2, camera_indices)
            elif self.training:
                embedded_appearance = self.appearance_embedding(camera_indices)
            else:
                if hasattr(self, "test_appearance_embedding"):
                    embedded_appearance = self.test_appearance_embedding(camera_indices)
                elif self.use_average_appearance_embedding:
                    # nn.Embedding has no ``mean``; average its weight table.
                    # Built per ray (not per sample) so the expansion below
                    # applies to it like to the other branches.
                    embedded_appearance = torch.ones(
                        (n_rays, self.appearance_embedding_dim), device=directions.device
                    ) * self.appearance_embedding.weight.mean(dim=0)
                else:
                    embedded_appearance = torch.zeros(
                        (n_rays, self.appearance_embedding_dim), device=directions.device
                    )

            # expand embedded_appearance from n_rays, dim to n_rays*n_samples, dim
            ea_dim = embedded_appearance.shape[-1]
            embedded_appearance = embedded_appearance.view(-1, 1, ea_dim).expand(n_rays, n_samples, -1).reshape(-1, ea_dim)
            if not self.linear_decoder:
                color_features.append(embedded_appearance)

        color_features = torch.cat(color_features, dim=-1)
        assert color_features.isnan().sum() == 0

        if self.linear_decoder:
            # The appearance code conditions the basis, not the features.
            if self.use_appearance_embedding:
                basis_values = self.color_basis(torch.cat([directions, embedded_appearance], dim=-1))
            else:
                basis_values = self.color_basis(directions)  # [batch, color_feature_len * 3]
            basis_values = basis_values.view(color_features.shape[0], 3, -1)  # [batch, 3, color_feature_len]
            rgb = torch.sum(color_features[:, None, :] * basis_values, dim=-1)  # [batch, 3]
            rgb = rgb.to(directions)
            rgb = torch.sigmoid(rgb).view(n_rays, n_samples, 3)
        else:
            rgb = self.color_net(color_features).to(directions).view(n_rays, n_samples, 3)

        rgb = torch.nan_to_num(rgb)
        assert rgb.isnan().sum() == 0
        assert density.isnan().sum() == 0

        return {"rgb": rgb, "density": density}

    def get_params(self):
        """Group the module's parameters for the optimiser.

        Returns:
            Dict with three lists: ``"nn"`` (sigma net, direction encoder and
            the active colour decoder), ``"field"`` (the feature grids) and
            ``"other"`` (every remaining parameter of the module).
        """
        field_params = dict(self.grids.named_parameters(prefix="grids"))
        nn_params = [
            self.sigma_net.named_parameters(prefix="sigma_net"),
            self.direction_encoder.named_parameters(prefix="direction_encoder"),
        ]
        if self.linear_decoder:
            nn_params.append(self.color_basis.named_parameters(prefix="color_basis"))
        else:
            nn_params.append(self.color_net.named_parameters(prefix="color_net"))
        nn_params = {k: v for plist in nn_params for k, v in plist}
        other_params = {
            k: v for k, v in self.named_parameters()
            if k not in nn_params and k not in field_params
        }
        return {
            "nn": list(nn_params.values()),
            "field": list(field_params.values()),
            "other": list(other_params.values()),
        }


if __name__ == '__main__':
    # Smoke test: two descriptors with axis lengths (1, 2); the second one
    # has beta = pi, gamma = 0.
    demo_covs = torch.tensor([
        [1, 2, 0, 0],
        [1, 2, torch.pi, 0],
    ])
    print(get_ellipses(demo_covs))

