# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
TensorRF implementation.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Literal, Tuple, Type, cast, Sequence, Optional

import numpy as np
import torch
from torch.nn import Parameter


from nerfstudio.cameras.rays import RayBundle
from nerfstudio.configs.config_utils import to_immutable_dict
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
from nerfstudio.field_components.encodings import NeRFEncoding, GaplaneEncoding, Identity, TensorVMEncoding ###
from nerfstudio.field_components.field_heads import FieldHeadNames

from nerfstudio.fields.gaplanes_field_v3 import GAplanesV3Field ###
# from nerfstudio.fields.gaplanes_shared_field import GAplanesSharedField
from nerfstudio.fields.gaplanes_field_v2 import GAPlanesDensityField

from nerfstudio.model_components.losses import MSELoss, scale_gradients_by_distance_squared, tv_loss, distortion_loss, interlevel_loss
from nerfstudio.model_components.ray_samplers import PDFSampler, UniformSampler, ProposalNetworkSampler

from nerfstudio.model_components.renderers import AccumulationRenderer, DepthRenderer, RGBRenderer
from nerfstudio.model_components.scene_colliders import AABBBoxCollider
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils import colormaps, colors, misc

import ipdb
# GAplanes V3: proposal-network sampling, without multi-resolution grids.
@dataclass
class GAplanesV3ModelConfig(ModelConfig):
    """GAplanes V3 model config (GA-Planes field rendered with proposal-network sampling)."""

    _target: Type = field(default_factory=lambda: GAplanesV3Model)
    """target class to instantiate"""
    init_resolution: Sequence[int] = (128, 128, 64)
    """Initial resolution of the shared density/RGB GaplaneEncoding (per-entry meaning is defined by GaplaneEncoding)."""
    final_resolution: Sequence[int] = (300, 300, 64)
    """Resolution after the last upsampling iteration; the schedule in between is spaced evenly in log space."""

    upsampling_iters: Tuple[int, ...] = (2000, 3000, 4000, 5500, 7000)
    """Iteration numbers at which the encoding grid is scheduled to be upsampled."""
    loss_coefficients: Dict[str, float] = to_immutable_dict(
        {
            "rgb_loss": 1.0,
            "plane_tv": 1e-3,
            "plane_tv_proposal_net": 1e-4,
            "l1_reg": 5e-4,
            "interlevel": 1.0,
            "distortion": 0.001,
        }
    )
    """Weight applied to each loss term, keyed by loss name."""

    num_components: Sequence[int] = (48, 48, 16)
    """Number of components of the shared density/RGB GaplaneEncoding."""
    reduce: Literal["concat", "product"] = "product"
    """How features are combined, both in the main encoding and in the proposal density fields."""

    regularization: Literal["none", "l1", "tv"] = "l1"
    """Regularization: 'l1' (mean |param| of the main encoding), 'tv' (total variation of plane
    coefficients of the main encoding and the proposal networks), or 'none'."""
    use_gradient_scaling: bool = False
    """Use gradient scaler where the gradients are lower for points closer to the camera."""
    background_color: Literal["random", "last_sample", "black", "white"] = "white"
    """Background color used by the RGB renderer."""

    # proposal sampling arguments
    num_proposal_iterations: int = 2
    """Number of proposal network iterations."""
    use_same_proposal_network: bool = False
    """Use the same proposal network. Otherwise use different ones."""
    proposal_net_args_list: List[Dict] = field(
        default_factory=lambda: [
            {"num_output_coords": 8, "resolution": [64, 64, 64]},
            {"num_output_coords": 8, "resolution": [128, 128, 128]},
        ]
    )
    """Arguments for the proposal density fields (the last entry is reused if there are more iterations)."""
    # Variable-length tuple: one sample count per proposal iteration.
    num_proposal_samples: Optional[Tuple[int, ...]] = (256, 128)
    """Number of samples per ray for each proposal network iteration."""

    num_samples: Optional[int] = 48
    """Number of samples per ray used for rendering."""
    single_jitter: bool = False
    """Whether use single jitter or not for the proposal networks."""
    proposal_warmup: int = 5000
    """Scales n from 1 to proposal_update_every over this many steps."""
    proposal_update_every: int = 5
    """Sample every n steps after the warmup."""
    use_proposal_weight_anneal: bool = True
    """Whether to use proposal weight annealing."""
    proposal_weights_anneal_slope: float = 10.0
    """Slope of the annealing function for the proposal weights."""
    proposal_weights_anneal_max_num_iters: int = 1000
    """Max num iterations for the annealing function."""


class GAplanesV3Model(Model):
    """GAplanes V3 model.

    Renders a GAplanesV3Field whose density and color come from one shared
    GaplaneEncoding. Samples along rays with a ProposalNetworkSampler driven by
    GAPlanesDensityField proposal networks.

    Args:
        config: GAplanes configuration to instantiate model
    """

    config: GAplanesV3ModelConfig

    def __init__(
        self,
        config: GAplanesV3ModelConfig,
        **kwargs,
    ) -> None:
        # populate_modules() reads these attributes, so they are set before
        # super().__init__() (which presumably invokes populate_modules()).
        self.init_resolution = config.init_resolution
        self.upsampling_iters = config.upsampling_iters
        self.reduce = config.reduce
        self.num_components = config.num_components

        # Grid resolutions spaced evenly in log space from init_resolution to
        # final_resolution. Entry k is the resolution after upsampling_iters[k];
        # the first linspace point (init_resolution itself) is dropped.
        self.upsampling_steps = (
            np.round(
                np.exp(
                    np.linspace(
                        np.log(config.init_resolution),
                        np.log(config.final_resolution),
                        len(config.upsampling_iters) + 1,
                    )
                )
            )
            .astype("int")
            .tolist()[1:]
        )
        super().__init__(config=config, **kwargs)

    def get_training_callbacks(
        self, training_callback_attributes: TrainingCallbackAttributes
    ) -> List[TrainingCallback]:
        """Build the callbacks run around each training iteration.

        Callbacks are registered only when ``use_proposal_weight_anneal`` is set:
        one anneals the proposal weights before each iteration, and one steps the
        proposal sampler after it.

        NOTE: the TensoRF-style callback that upsampled the encoding grid at
        ``upsampling_iters`` and re-created the "encodings" optimizer is disabled,
        so this method registers nothing that upsamples the grid during training.

        Args:
            training_callback_attributes: unused; kept for the Model interface.

        Returns:
            The list of callbacks (possibly empty).
        """
        callbacks: List[TrainingCallback] = []
        if not self.config.use_proposal_weight_anneal:
            return callbacks

        # Anneal the weights of the proposal network before doing PDF sampling.
        anneal_iters = self.config.proposal_weights_anneal_max_num_iters

        def set_anneal(step):
            # https://arxiv.org/pdf/2111.12077.pdf eq. 18
            train_frac = np.clip(step / anneal_iters, 0, 1)
            slope = self.config.proposal_weights_anneal_slope
            anneal = (slope * train_frac) / ((slope - 1) * train_frac + 1)
            self.proposal_sampler.set_anneal(anneal)

        callbacks.append(
            TrainingCallback(
                where_to_run=[TrainingCallbackLocation.BEFORE_TRAIN_ITERATION],
                update_every_num_iters=1,
                func=set_anneal,
            )
        )
        callbacks.append(
            TrainingCallback(
                where_to_run=[TrainingCallbackLocation.AFTER_TRAIN_ITERATION],
                update_every_num_iters=1,
                func=self.proposal_sampler.step_cb,
            )
        )
        return callbacks

    def update_to_step(self, step: int) -> None:
        """Resize the encoding grid to the resolution scheduled for ``step``.

        Let ``k`` be the number of ``upsampling_iters`` that are <= ``step``. The
        target resolution is ``upsampling_steps[k - 1]``. Nothing happens before the
        first upsampling iteration, or when no upsampling iterations are configured.

        NOTE(review): the training-time upsampling callback is disabled (see
        get_training_callbacks), so grids saved after ``upsampling_iters[0]`` may
        still be at ``init_resolution``. If this runs before restoring such a
        checkpoint, the resized grid would not match it. Confirm whether this
        method should be a no-op.
        """
        if not self.upsampling_iters or step < self.upsampling_iters[0]:
            return

        num_upsamplings = sum(1 for it in self.upsampling_iters if it <= step)
        if num_upsamplings == 0:
            # Only possible when upsampling_iters is not sorted ascending.
            return
        new_grid_resolution = self.upsampling_steps[num_upsamplings - 1]
        self.field.density_rgb_encoding.upsample_grid(new_grid_resolution)

    def populate_modules(self):
        """Set the field, proposal networks, samplers, renderers, losses, metrics and collider."""
        super().populate_modules()

        # Shared encoding feeding both density and color of the main field.
        density_rgb_encoding = GaplaneEncoding(
            resolution=self.init_resolution,
            num_components=self.num_components,
            reduce=self.reduce,
        )
        # View-direction encoding as in vanilla NeRF: 4 frequencies plus the raw input.
        direction_encoding = NeRFEncoding(
            in_dim=3, num_frequencies=4, min_freq_exp=0.0, max_freq_exp=4.0, include_input=True
        )
        self.field = GAplanesV3Field(
            self.scene_box.aabb,
            direction_encoding=direction_encoding,
            density_rgb_encoding=density_rgb_encoding,
        )

        # Build the proposal network(s) and the density functions the sampler queries.
        self.density_fns = []
        num_prop_nets = self.config.num_proposal_iterations
        self.proposal_networks = torch.nn.ModuleList()
        if self.config.use_same_proposal_network:
            assert len(self.config.proposal_net_args_list) == 1, "Only one proposal network is allowed."
            prop_net_args = self.config.proposal_net_args_list[0]
            network = GAPlanesDensityField(
                aabb=self.scene_box.aabb,
                reduce=self.config.reduce,
                **prop_net_args,
            )
            self.proposal_networks.append(network)
            # A single shared network serves every proposal iteration.
            self.density_fns.extend([network.density_fn for _ in range(num_prop_nets)])
        else:
            args_list = self.config.proposal_net_args_list
            for i in range(num_prop_nets):
                # Reuse the last args entry when there are more iterations than entries.
                prop_net_args = args_list[min(i, len(args_list) - 1)]
                network = GAPlanesDensityField(
                    aabb=self.scene_box.aabb,
                    reduce=self.config.reduce,
                    **prop_net_args,
                )
                self.proposal_networks.append(network)
            self.density_fns.extend([net.density_fn for net in self.proposal_networks])

        # Samplers
        def update_schedule(step):
            # Steps between proposal updates: ramps linearly toward proposal_update_every
            # over proposal_warmup steps, clipped to [1, proposal_update_every].
            return np.clip(
                np.interp(step, [0, self.config.proposal_warmup], [0, self.config.proposal_update_every]),
                1,
                self.config.proposal_update_every,
            )

        initial_sampler = UniformSampler(single_jitter=self.config.single_jitter)
        self.proposal_sampler = ProposalNetworkSampler(
            num_nerf_samples_per_ray=self.config.num_samples,
            num_proposal_samples_per_ray=self.config.num_proposal_samples,
            num_proposal_network_iterations=self.config.num_proposal_iterations,
            single_jitter=self.config.single_jitter,
            update_sched=update_schedule,
            initial_sampler=initial_sampler,
        )

        # renderers
        self.renderer_rgb = RGBRenderer(background_color=self.config.background_color)
        self.renderer_accumulation = AccumulationRenderer()
        self.renderer_depth = DepthRenderer()

        # losses
        self.rgb_loss = MSELoss()

        # metrics
        from torchmetrics.functional import structural_similarity_index_measure
        from torchmetrics.image import PeakSignalNoiseRatio
        from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

        self.psnr = PeakSignalNoiseRatio(data_range=1.0)
        self.ssim = structural_similarity_index_measure
        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True)

        # colliders
        if self.config.enable_collider:
            self.collider = AABBBoxCollider(scene_box=self.scene_box)

    def get_param_groups(self) -> Dict[str, List[Parameter]]:
        """Group parameters for the optimizers: field MLPs, encoding, and proposal networks."""
        param_groups = {}
        param_groups["fields"] = list(self.field.density_mlp.parameters()) + list(
            self.field.color_mlp.parameters()
        )
        param_groups["encodings"] = list(self.field.density_rgb_encoding.parameters())
        param_groups["proposal_networks"] = list(self.proposal_networks.parameters())
        return param_groups

    def get_outputs(self, ray_bundle: RayBundle):
        """Render a ray bundle.

        Samples points with the proposal sampler and evaluates the field. Returns
        rgb, accumulation, depth, and one depth map per proposal iteration
        (``prop_depth_{i}``). In training mode the per-level weights and ray samples
        are also returned for the interlevel/distortion losses.
        """
        ray_samples, weights_list, ray_samples_list = self.proposal_sampler(
            ray_bundle, density_fns=self.density_fns
        )

        field_outputs = self.field(ray_samples)
        if self.config.use_gradient_scaling:
            field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples)

        weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
        weights_list.append(weights)
        ray_samples_list.append(ray_samples)

        outputs = {
            "rgb": self.renderer_rgb(rgb=field_outputs[FieldHeadNames.RGB], weights=weights),
            "accumulation": self.renderer_accumulation(weights=weights),
            "depth": self.renderer_depth(weights=weights, ray_samples=ray_samples),
        }

        # These use a lot of GPU memory, so we avoid storing them for eval.
        if self.training:
            outputs["weights_list"] = weights_list
            outputs["ray_samples_list"] = ray_samples_list

        for i in range(self.config.num_proposal_iterations):
            outputs[f"prop_depth_{i}"] = self.renderer_depth(
                weights=weights_list[i], ray_samples=ray_samples_list[i]
            )

        return outputs

    def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
        """Compute the training losses, scaled by ``config.loss_coefficients``.

        Always includes ``rgb_loss``. Adds ``l1_reg``, or ``plane_tv`` plus
        ``plane_tv_proposal_net``, according to ``config.regularization``. In
        training mode it also adds ``interlevel`` and ``distortion``.

        Raises:
            ValueError: if ``config.regularization`` is not a supported value.
        """
        device = outputs["rgb"].device
        image = batch["image"].to(device)
        pred_image, image = self.renderer_rgb.blend_background_for_loss_computation(
            pred_image=outputs["rgb"],
            pred_accumulation=outputs["accumulation"],
            gt_image=image,
        )
        loss_dict = {"rgb_loss": self.rgb_loss(image, pred_image)}

        regularization = self.config.regularization
        if regularization == "l1":
            # reshape (not view) so non-contiguous parameters are handled as well.
            flat_params = [p.reshape(-1) for p in self.field.density_rgb_encoding.parameters()]
            loss_dict["l1_reg"] = torch.abs(torch.cat(flat_params)).mean()
        elif regularization == "tv":
            loss_dict["plane_tv"] = tv_loss(self.field.density_rgb_encoding.plane_coef)
            # Average over the proposal networks that actually exist. With
            # use_same_proposal_network a single network serves all iterations, so
            # dividing by num_proposal_iterations would under-weight this term.
            prop_tvs = [tv_loss(net.feature_encoding.plane_coef) for net in self.proposal_networks]
            if prop_tvs:
                loss_dict["plane_tv_proposal_net"] = sum(prop_tvs) / len(prop_tvs)
        elif regularization != "none":
            raise ValueError(f"Regularization {self.config.regularization} not supported")

        if self.training:
            loss_dict["interlevel"] = interlevel_loss(outputs["weights_list"], outputs["ray_samples_list"])
            loss_dict["distortion"] = distortion_loss(outputs["weights_list"], outputs["ray_samples_list"])

        return misc.scale_dict(loss_dict, self.config.loss_coefficients)

    def get_image_metrics_and_images(
        self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]
    ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]:
        """Compute PSNR/SSIM/LPIPS against the ground-truth image and build visualization images.

        Requires ``config.collider_params`` (its near/far planes colorize the depth map).
        """
        image = batch["image"].to(outputs["rgb"].device)
        image = self.renderer_rgb.blend_background(image)
        rgb = outputs["rgb"]
        acc = colormaps.apply_colormap(outputs["accumulation"])
        assert self.config.collider_params is not None
        depth = colormaps.apply_depth_colormap(
            outputs["depth"],
            accumulation=outputs["accumulation"],
            near_plane=self.config.collider_params["near_plane"],
            far_plane=self.config.collider_params["far_plane"],
        )

        # Ground truth and prediction side by side along dim 1.
        combined_rgb = torch.cat([image, rgb], dim=1)

        # Switch images from [H, W, C] to [1, C, H, W] for metrics computations
        image = torch.moveaxis(image, -1, 0)[None, ...]
        rgb = torch.moveaxis(rgb, -1, 0)[None, ...]

        psnr = self.psnr(image, rgb)
        ssim = cast(torch.Tensor, self.ssim(image, rgb))
        lpips = self.lpips(image, rgb)

        metrics_dict = {
            "psnr": float(psnr.item()),
            "ssim": float(ssim.item()),
            "lpips": float(lpips.item()),
        }

        images_dict = {"img": combined_rgb, "accumulation": acc, "depth": depth}
        for i in range(self.config.num_proposal_iterations):
            key = f"prop_depth_{i}"
            images_dict[key] = colormaps.apply_depth_colormap(
                outputs[key],
                accumulation=outputs["accumulation"],
            )

        return metrics_dict, images_dict
