# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
GAplanes model implementation (appears to be derived from the TensoRF model).
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Literal, Tuple, Type, cast, Sequence

import numpy as np
import torch
from torch.nn import Parameter

from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.configs.config_utils import to_immutable_dict
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
from nerfstudio.field_components.encodings import NeRFEncoding, GaplaneEncoding, Identity, TensorVMEncoding ###
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.fields.gaplanes_field import GAplanesField ###
from nerfstudio.fields.gaplanes_shared_field import GAplanesSharedField ###
from nerfstudio.model_components.losses import MSELoss, scale_gradients_by_distance_squared, tv_loss
from nerfstudio.model_components.ray_samplers import PDFSampler, UniformSampler
from nerfstudio.model_components.renderers import AccumulationRenderer, DepthRenderer, RGBRenderer
from nerfstudio.model_components.scene_colliders import AABBBoxCollider
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils import colormaps, colors, misc

import ipdb
@dataclass
class GAplanesModelConfig(ModelConfig):
    """Configuration for :class:`GAplanesModel`.

    Groups the encoding-grid resolutions and upsampling schedule, the loss
    weights, the sampler sizes, the per-encoding component counts, the
    regularization choice, the camera optimizer config and the background
    color used when rendering.
    """

    _target: Type = field(default_factory=lambda: GAplanesModel)
    """target class to instantiate"""
    # Passed as `resolution` to GaplaneEncoding in GAplanesModel.populate_modules,
    # and used as the start point of the resolution schedule in __init__.
    init_resolution: Sequence[int] = (128, 128, 64)
    """initial render resolution"""
    # End point of the log-spaced resolution schedule built in GAplanesModel.__init__.
    final_resolution: Sequence[int] = (300, 300, 64)
    """final render resolution"""
    # NOTE(review): the training callback that would act on these steps is
    # commented out in GAplanesModel.get_training_callbacks; at the moment only
    # update_to_step and the schedule length in __init__ consume this field.
    upsampling_iters: Tuple[int, ...] = (2000, 3000, 4000, 5500, 7000)
    """specifies a list of iteration step numbers to perform upsampling"""
    # Per-term weights handed to misc.scale_dict at the end of get_loss_dict.
    loss_coefficients: Dict[str, float] = to_immutable_dict(
        {
            "rgb_loss": 1.0,
            "tv_reg_density": 1e-3,
            "tv_reg_color": 1e-4,
            "l1_reg": 5e-4,
        }
    )
    """Loss specific weights."""
    # Sample count given to the PDFSampler (second, fine pass in get_outputs).
    num_samples: int = 50
    """Number of samples in field evaluation"""
    # Sample count given to the UniformSampler (first, coarse pass in get_outputs).
    num_uniform_samples: int = 200
    """Number of samples in density evaluation"""
    # Only read when shared_features is False (separate density encoding).
    # The trailing comment presumably records an earlier scalar default — TODO confirm.
    num_den_components: Sequence[int] = (16, 16, 8) # int = 16
    """Number of components in density encoding"""
    # Component counts of the color encoding; with shared_features=True that
    # same encoding is passed to the field as `density_rgb_encoding`.
    # The trailing comment presumably records an earlier scalar default — TODO confirm.
    num_color_components: Sequence[int] = (48, 48, 16) # int = 48
    """Number of components in color encoding"""
    # Forwarded unchanged as `reduce` to every GaplaneEncoding.
    reduce: Literal["concat", "product"] = "product"
    """How to form the features"""
    # True -> GAplanesSharedField with one encoding; False -> GAplanesField
    # with separate density and color encodings.
    shared_features: bool = True
    """sharing features across density and color tasks"""
    # Selects the extra loss term(s) added in GAplanesModel.get_loss_dict.
    regularization: Literal["none", "l1", "tv"] = "l1"
    """Regularization method"""
    camera_optimizer: CameraOptimizerConfig = field(default_factory=lambda: CameraOptimizerConfig(mode="SO3xR3"))
    """Config of the camera optimizer to use"""
    # Checked in get_outputs before the fine-pass field outputs are rescaled.
    use_gradient_scaling: bool = False
    """Use gradient scaler where the gradients are lower for points closer to the camera."""
    # NOTE(review): the help text below speaks of randomizing, but this field
    # selects one of four background modes and is passed to RGBRenderer as-is.
    background_color: Literal["random", "last_sample", "black", "white"] = "white"
    """Whether to randomize the background color."""


class GAplanesModel(Model):
    """GAplanes Model

    Args:
        config: GAplanes configuration to instantiate model
    """

    config: GAplanesModelConfig

    def __init__(
        self,
        config: GAplanesModelConfig,
        **kwargs,
    ) -> None:
        self.init_resolution = config.init_resolution
        self.upsampling_iters = config.upsampling_iters
        self.num_den_components = config.num_den_components
        self.reduce = config.reduce
        self.shared_features = config.shared_features
        self.num_color_components = config.num_color_components
        self.upsampling_steps = (
            np.round(
                np.exp(
                    np.linspace(
                        np.log(config.init_resolution),
                        np.log(config.final_resolution),
                        len(config.upsampling_iters) + 1,
                    )
                )
            )
            .astype("int")
            .tolist()[1:]
        )
        super().__init__(config=config, **kwargs)

    def get_training_callbacks(
        self, training_callback_attributes: TrainingCallbackAttributes
    ) -> List[TrainingCallback]:
        # the callback that we want to run every X iterations after the training iteration
        def reinitialize_optimizer(self, training_callback_attributes: TrainingCallbackAttributes, step: int):
            assert training_callback_attributes.optimizers is not None
            assert training_callback_attributes.pipeline is not None
            index = self.upsampling_iters.index(step)
            resolution = self.upsampling_steps[index]

            if self.shared_features:
                self.field.density_rgb_encoding.upsample_grid(resolution)
            else:
                # upsample the position and direction grids
                self.field.density_encoding.upsample_grid(resolution) #### del this [0]
                self.field.color_encoding.upsample_grid(resolution)

            # reinitialize the encodings optimizer
            optimizers_config = training_callback_attributes.optimizers.config
            enc = training_callback_attributes.pipeline.get_param_groups()["encodings"]
            lr_init = optimizers_config["encodings"]["optimizer"].lr

            training_callback_attributes.optimizers.optimizers["encodings"] = optimizers_config["encodings"][
                "optimizer"
            ].setup(params=enc)
            if optimizers_config["encodings"]["scheduler"]:
                training_callback_attributes.optimizers.schedulers["encodings"] = (
                    optimizers_config["encodings"]["scheduler"]
                    .setup()
                    .get_scheduler(
                        optimizer=training_callback_attributes.optimizers.optimizers["encodings"], lr_init=lr_init
                    )
                )

        # callbacks = [
        #     TrainingCallback(
        #         where_to_run=[TrainingCallbackLocation.AFTER_TRAIN_ITERATION],
        #         iters=self.upsampling_iters,
        #         func=reinitialize_optimizer,
        #         args=[self, training_callback_attributes],
        #     )
        # ]
        callbacks = []
        return callbacks

    def update_to_step(self, step: int) -> None:
        if step < self.upsampling_iters[0]:
            return

        new_iters = list(self.upsampling_iters) + [step + 1]
        new_iters.sort()

        index = new_iters.index(step + 1)
        new_grid_resolution = self.upsampling_steps[index - 1]

        if self.shared_features:
            self.field.density_rgb_encoding.upsample_grid(new_grid_resolution)
        else:
            self.field.density_encoding.upsample_grid(new_grid_resolution)  # delete [0] later!
            self.field.color_encoding.upsample_grid(new_grid_resolution)  # type: ignore

    def populate_modules(self):
        """Set the fields and modules"""
        super().populate_modules()



        color_encoding = GaplaneEncoding(
            resolution=self.init_resolution,
            num_components=self.num_color_components,
            reduce=self.reduce,
        )


        # feature_encoding = NeRFEncoding(in_dim=self.appearance_dim, num_frequencies=2, min_freq_exp=0, max_freq_exp=2)
        feature_encoding = Identity(in_dim=3) ## no feature encoding (in dim wrong)
        # direction_encoding = NeRFEncoding(in_dim=3, num_frequencies=2, min_freq_exp=0, max_freq_exp=2) ## tensorf
        direction_encoding = NeRFEncoding(
            in_dim=3, num_frequencies=4, min_freq_exp=0.0, max_freq_exp=4.0, include_input=True
        ) ### from vanilla nerf

        if self.shared_features:
            self.field = GAplanesSharedField(
                self.scene_box.aabb,
                direction_encoding=direction_encoding,
                density_rgb_encoding=color_encoding,
                head_mlp_num_layers=2,
                head_mlp_layer_width=128,
            )
        else:
            density_encoding = GaplaneEncoding(
                resolution=self.init_resolution,
                num_components=self.num_den_components,
                reduce=self.reduce,
            )
            self.field = GAplanesField(
                self.scene_box.aabb,
                feature_encoding=feature_encoding,
                direction_encoding=direction_encoding,
                density_encoding=density_encoding,
                color_encoding=color_encoding,
                head_mlp_num_layers=2,
                head_mlp_layer_width=128,
                use_sh=False,
            )

        # samplers
        self.sampler_uniform = UniformSampler(num_samples=self.config.num_uniform_samples, single_jitter=True)
        self.sampler_pdf = PDFSampler(num_samples=self.config.num_samples, single_jitter=True, include_original=False)

        # renderers
        self.renderer_rgb = RGBRenderer(background_color=self.config.background_color)
        self.renderer_accumulation = AccumulationRenderer()
        self.renderer_depth = DepthRenderer()

        # losses
        self.rgb_loss = MSELoss()

        # metrics
        from torchmetrics.functional import structural_similarity_index_measure
        from torchmetrics.image import PeakSignalNoiseRatio
        from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

        self.psnr = PeakSignalNoiseRatio(data_range=1.0)
        self.ssim = structural_similarity_index_measure
        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)

        # colliders
        if self.config.enable_collider:
            # print("initializing collider")
            self.collider = AABBBoxCollider(scene_box=self.scene_box)

        # regularizations
        # if self.config.tensorf_encoding == "cp" and self.config.regularization == "tv":
        #     raise RuntimeError("TV reg not supported for CP decomposition")

        # (optional) camera optimizer
        self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup(
            num_cameras=self.num_train_data, device="cpu"
        )

    def get_param_groups(self) -> Dict[str, List[Parameter]]:
        param_groups = {}

        param_groups["fields"] = (
            list(self.field.density_mlp.parameters())
            + list(self.field.color_mlp.parameters())
            + list(self.field.field_output_rgb.parameters())
        )
        if self.shared_features:
            param_groups["encodings"] = list(self.field.density_rgb_encoding.parameters())
        else:
            param_groups["encodings"] = list(self.field.color_encoding.parameters()) + list(
                self.field.density_encoding.parameters()
            )
        self.camera_optimizer.get_param_groups(param_groups=param_groups)

        return param_groups

    def get_outputs(self, ray_bundle: RayBundle):
        # uniform sampling
        if self.training:
            self.camera_optimizer.apply_to_raybundle(ray_bundle)
        ray_samples_uniform = self.sampler_uniform(ray_bundle)
        dens = self.field.get_density(ray_samples_uniform) ##### dens
        # ipdb.set_trace()
        weights = ray_samples_uniform.get_weights(dens)
        coarse_accumulation = self.renderer_accumulation(weights)
        acc_mask = torch.where(coarse_accumulation < 0.0001, False, True).reshape(-1)

        # pdf sampling
        ray_samples_pdf = self.sampler_pdf(ray_bundle, ray_samples_uniform, weights)

        # fine field:
        field_outputs_fine = self.field.forward(
            ray_samples_pdf, mask=acc_mask, bg_color=colors.WHITE.to(weights.device)
        )
        if self.config.use_gradient_scaling:
            field_outputs_fine = scale_gradients_by_distance_squared(field_outputs_fine, ray_samples_pdf)

        weights_fine = ray_samples_pdf.get_weights(field_outputs_fine[FieldHeadNames.DENSITY])

        accumulation = self.renderer_accumulation(weights_fine)
        depth = self.renderer_depth(weights_fine, ray_samples_pdf)

        rgb = self.renderer_rgb(
            rgb=field_outputs_fine[FieldHeadNames.RGB],
            weights=weights_fine,
        )

        rgb = torch.where(accumulation < 0, colors.WHITE.to(rgb.device), rgb)
        accumulation = torch.clamp(accumulation, min=0)

        outputs = {"rgb": rgb, "accumulation": accumulation, "depth": depth}
        return outputs

    def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
        # Scaling metrics by coefficients to create the losses.
        device = outputs["rgb"].device
        image = batch["image"].to(device)
        pred_image, image = self.renderer_rgb.blend_background_for_loss_computation(
            pred_image=outputs["rgb"],
            pred_accumulation=outputs["accumulation"],
            gt_image=image,
        )

        rgb_loss = self.rgb_loss(image, pred_image)

        loss_dict = {"rgb_loss": rgb_loss}

        if self.config.regularization == "l1":
            l1_parameters = []
            if self.shared_features:
                for parameter in self.field.density_rgb_encoding.parameters():
                    l1_parameters.append(parameter.view(-1))
            else:
                for parameter in self.field.density_encoding.parameters():
                    l1_parameters.append(parameter.view(-1))

            loss_dict["l1_reg"] = torch.abs(torch.cat(l1_parameters)).mean()
        elif self.config.regularization == "tv":
            if self.shared_features:
                plane_coef = self.field.density_rgb_encoding.plane_coef
                loss_dict["tv_reg_density"] = tv_loss(plane_coef)
            else:
                density_plane_coef = self.field.density_encoding.plane_coef
                color_plane_coef = self.field.color_encoding.plane_coef
                assert isinstance(color_plane_coef, torch.Tensor) and isinstance(
                    density_plane_coef, torch.Tensor
                ), "TV reg only supported for encoding types with plane_coef attribute"
                loss_dict["tv_reg_density"] = tv_loss(density_plane_coef)
                loss_dict["tv_reg_color"] = tv_loss(color_plane_coef)
        elif self.config.regularization == "none":
            pass
        else:
            raise ValueError(f"Regularization {self.config.regularization} not supported")

        self.camera_optimizer.get_loss_dict(loss_dict)

        loss_dict = misc.scale_dict(loss_dict, self.config.loss_coefficients)
        return loss_dict

    def get_image_metrics_and_images(
        self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]
    ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]:
        image = batch["image"].to(outputs["rgb"].device)
        image = self.renderer_rgb.blend_background(image)
        rgb = outputs["rgb"]
        acc = colormaps.apply_colormap(outputs["accumulation"])
        assert self.config.collider_params is not None
        depth = colormaps.apply_depth_colormap(
            outputs["depth"],
            accumulation=outputs["accumulation"],
            near_plane=self.config.collider_params["near_plane"],
            far_plane=self.config.collider_params["far_plane"],
        )

        combined_rgb = torch.cat([image, rgb], dim=1)

        # Switch images from [H, W, C] to [1, C, H, W] for metrics computations
        image = torch.moveaxis(image, -1, 0)[None, ...]
        rgb = torch.moveaxis(rgb, -1, 0)[None, ...]

        psnr = self.psnr(image, rgb)
        ssim = cast(torch.Tensor, self.ssim(image, rgb))
        lpips = self.lpips(image, rgb)

        metrics_dict = {
            "psnr": float(psnr.item()),
            "ssim": float(ssim.item()),
            "lpips": float(lpips.item()),
        }
        self.camera_optimizer.get_metrics_dict(metrics_dict)

        images_dict = {"img": combined_rgb, "accumulation": acc, "depth": depth}
        return metrics_dict, images_dict
