

import logging
import time
from types import SimpleNamespace
from typing import Any, Callable, Dict, Tuple

import torch
import torch.nn as nn

from geocalib.camera import BaseCamera, camera_models
from geocalib.gravity import Gravity
from geocalib.misc import J_focal2fov
from geocalib.perspective_fields import J_perspective_field, get_perspective_field
from geocalib.utils import focal2fov, rad2deg

logger = logging.getLogger(__name__)


def get_trivial_estimation(data: Dict[str, torch.Tensor], camera_model: BaseCamera) -> BaseCamera:
    ref = data.get("up_field", data["latitude_field"])
    ref = ref.detach()

    h, w = ref.shape[-2:]
    batch_h, batch_w = (
        ref.new_ones((ref.shape[0],)) * h,
        ref.new_ones((ref.shape[0],)) * w,
    )

    init_r = ref.new_zeros((ref.shape[0],))
    init_p = ref.new_zeros((ref.shape[0],))

    focal = data.get("prior_focal", 0.7 * torch.max(batch_h, batch_w))
    init_vfov = focal2fov(focal, h)

    params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
    params |= {"scales": data["scales"]} if "scales" in data else {}
    params |= {"dist": data["prior_dist"]} if "prior_dist" in data else {}
    camera = camera_model.from_dict(params)
    camera = camera.float().to(ref.device)

    gravity = Gravity.from_rp(init_r, init_p).float().to(ref.device)

    if "prior_gravity" in data:
        gravity = data["prior_gravity"].float().to(ref.device)
        gravity = Gravity(gravity) if isinstance(gravity, torch.Tensor) else gravity

    return camera, gravity


def scaled_loss(
    x: torch.Tensor, fn: Callable, a: float
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    a2 = a**2
    loss, loss_d1, loss_d2 = fn(x / a2)
    return loss * a2, loss_d1, loss_d2 / a2


def huber_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """The classical robust Huber loss, with first and second derivatives."""
    mask = x <= 1
    sx = torch.sqrt(x + 1e-8)  # avoid nan in backward pass
    isx = torch.max(sx.new_tensor(torch.finfo(torch.float).eps), 1 / sx)
    loss = torch.where(mask, x, 2 * sx - 1)
    loss_d1 = torch.where(mask, torch.ones_like(x), isx)
    loss_d2 = torch.where(mask, torch.zeros_like(x), -isx / (2 * x))
    return loss, loss_d1, loss_d2


def early_stop(new_cost: torch.Tensor, prev_cost: torch.Tensor, atol: float, rtol: float) -> bool:
    """Early stopping criterion based on cost convergence."""
    return torch.allclose(new_cost, prev_cost, atol=atol, rtol=rtol)


def update_lambda(
    lamb: torch.Tensor,
    prev_cost: torch.Tensor,
    new_cost: torch.Tensor,
    lambda_min: float = 1e-6,
    lambda_max: float = 1e2,
) -> torch.Tensor:
    """Update damping factor for Levenberg-Marquardt optimization."""
    new_lamb = lamb.new_zeros(lamb.shape)
    new_lamb = lamb * torch.where(new_cost > prev_cost, 10, 0.1)
    lamb = torch.clamp(new_lamb, lambda_min, lambda_max)
    return lamb


def optimizer_step(
    G: torch.Tensor, H: torch.Tensor, lambda_: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    diag = H.diagonal(dim1=-2, dim2=-1)
    diag = diag * lambda_.unsqueeze(-1)  # (B, 3)

    H = H + diag.clamp(min=eps).diag_embed()

    H_, G_ = H.cpu(), G.cpu()
    try:
        U = torch.linalg.cholesky(H_)
    except RuntimeError:
        logger.warning("Cholesky decomposition failed. Stopping.")
        delta = H.new_zeros((H.shape[0], H.shape[-1]))  # (B, 3)
    else:
        delta = torch.cholesky_solve(G_[..., None], U)[..., 0]

    return delta.to(H.device)



class LMOptimizer(nn.Module):


    default_conf = {

        "camera_model": "pinhole", 
        "shared_intrinsics": False,  
        "num_steps": 30,
        "lambda_": 0.1,
        "fix_lambda": False,
        "early_stop": True,
        "atol": 1e-8,
        "rtol": 1e-8,
        "use_spherical_manifold": True, 
        "use_log_focal": True,  

        "up_loss_fn_scale": 1e-2,
        "lat_loss_fn_scale": 1e-2,
        "verbose": False,
    }

    def __init__(self, conf: Dict[str, Any]):

        super().__init__()
        self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
        self.num_steps = conf.num_steps

        self.set_camera_model(conf.camera_model)
        self.setup_optimization_and_priors(shared_intrinsics=conf.shared_intrinsics)

    def set_camera_model(self, camera_model: str) -> None:

        assert (
            camera_model in camera_models.keys()
        ), f"Unknown camera model: {camera_model} not in {camera_models.keys()}"
        self.camera_model = camera_models[camera_model]
        self.camera_has_distortion = hasattr(self.camera_model, "dist")

        logger.debug(
            f"Using camera model: {camera_model} (with distortion: {self.camera_has_distortion})"
        )

    def setup_optimization_and_priors(
        self, data: Dict[str, torch.Tensor] = None, shared_intrinsics: bool = False
    ) -> None:

        if data is None:
            data = {}
        self.shared_intrinsics = shared_intrinsics

        self.estimate_gravity = True
        if "prior_gravity" in data:
            self.estimate_gravity = False
            logger.debug("Using provided gravity as prior.")

        self.estimate_focal = True
        if "prior_focal" in data:
            self.estimate_focal = False
            logger.debug("Using provided focal as prior.")

        self.estimate_dist = self.camera_model.name() in [
            "radial",
            "simple_radial",
            "simple_divisional",
        ]
        if "prior_dist" in data:
            self.estimate_dist = False
            logger.debug("Using provided distortion as prior.")

        self.gravity_delta_dims = (0, 1) if self.estimate_gravity else (-1,)
        self.focal_delta_dims = (
            (max(self.gravity_delta_dims) + 1,) if self.estimate_focal else (-1,)
        )

        self.dist_delta_dims = None
        if self.estimate_dist:
            self.dist_delta_dims = tuple(
                range(
                    self.focal_delta_dims[-1] + 1,
                    self.focal_delta_dims[-1] + 1 + self.camera_model.num_dist_params(),
                )
            )

        self.n_intrinsic_params = self.estimate_focal + (
            self.camera_model.num_dist_params() if self.camera_has_distortion else 0
        )

        logger.debug(f"Camera Model:         {self.camera_model}")
        logger.debug(f"Optimizing gravity:   {self.estimate_gravity} ({self.gravity_delta_dims})")
        logger.debug(f"Optimizing focal:     {self.estimate_focal} ({self.focal_delta_dims})")
        logger.debug(f"Optimizing distortion:{self.estimate_dist} ({self.dist_delta_dims})")

        logger.debug(f"Shared intrinsics:  {self.shared_intrinsics}")

    def calculate_residuals(
        self, camera: BaseCamera, gravity: Gravity, data: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:

        perspective_up, perspective_lat = get_perspective_field(camera, gravity)
        perspective_lat = torch.sin(perspective_lat)

        residuals = {}
        if "up_field" in data:
            up_residual = (data["up_field"] - perspective_up).permute(0, 2, 3, 1)
            residuals["up_residual"] = up_residual.reshape(up_residual.shape[0], -1, 2)

        if "latitude_field" in data:
            target_lat = torch.sin(data["latitude_field"])
            lat_residual = (target_lat - perspective_lat).permute(0, 2, 3, 1)
            residuals["latitude_residual"] = lat_residual.reshape(lat_residual.shape[0], -1, 1)

        return residuals

    def calculate_costs(
        self, residuals: torch.Tensor, data: Dict[str, torch.Tensor]
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:

        costs, weights = {}, {}

        if "up_residual" in residuals:
            up_cost = (residuals["up_residual"] ** 2).sum(dim=-1)
            up_cost, up_weight, _ = scaled_loss(up_cost, huber_loss, self.conf.up_loss_fn_scale)

            if "up_confidence" in data:
                up_conf = data["up_confidence"].reshape(up_weight.shape[0], -1)
                up_weight = up_weight * up_conf
                up_cost = up_cost * up_conf

            costs["up_cost"] = up_cost
            weights["up_weights"] = up_weight

        if "latitude_residual" in residuals:
            lat_cost = (residuals["latitude_residual"] ** 2).sum(dim=-1)
            lat_cost, lat_weight, _ = scaled_loss(lat_cost, huber_loss, self.conf.lat_loss_fn_scale)

            if "latitude_confidence" in data:
                lat_conf = data["latitude_confidence"].reshape(lat_weight.shape[0], -1)
                lat_weight = lat_weight * lat_conf
                lat_cost = lat_cost * lat_conf

            costs["latitude_cost"] = lat_cost
            weights["latitude_weights"] = lat_weight

        return costs, weights

    def calculate_gradient_and_hessian(
        self,
        J: torch.Tensor,
        residuals: torch.Tensor,
        weights: torch.Tensor,
        shared_intrinsics: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        dims = ()
        if self.estimate_gravity:
            dims = (0, 1)
        if self.estimate_focal:
            dims += (2,)
        if self.camera_has_distortion:
            dims += tuple(range(3, 3 + self.camera_model.num_dist_params()))
        assert dims, "No parameters to optimize"

        J = J[..., dims]

        Grad = torch.einsum("...Njk,...Nj->...Nk", J, residuals)
        Grad = weights[..., None] * Grad
        Grad = Grad.sum(-2)  # (B, N_params)

        if shared_intrinsics:
      
            Grad_g = Grad[..., :2].reshape(1, -1)
            Grad_f = Grad[..., 2].reshape(1, -1).sum(-1, keepdim=True)
            Grad_dist = Grad[..., 3:].sum(-2).reshape(1, -1)
            Grad = torch.cat([Grad_g, Grad_f, Grad_dist], dim=-1)

        Hess = torch.einsum("...Njk,...Njl->...Nkl", J, J)
        Hess = weights[..., None, None] * Hess
        Hess = Hess.sum(-3)

        if shared_intrinsics:

            B = Hess.shape[0]
            H_g = torch.block_diag(*list(Hess[..., :2, :2]))
            J_intrinsics_g = Hess[..., :2, 2:].reshape(B * 2, -1)
            J_g_intrinsics = Hess[..., 2:, :2].permute(0, 2, 1).reshape(B * 2, -1).T

            H_intrinsics = Hess[..., 2:, 2:].sum(-3)

            dims = H_g.shape[-1] + self.n_intrinsic_params
            Hess = Hess.new_zeros((dims, dims), dtype=torch.float32)
            Hess[: -self.n_intrinsic_params, : -self.n_intrinsic_params] = H_g
            Hess[-self.n_intrinsic_params :, : -self.n_intrinsic_params] = J_g_intrinsics
            Hess[: -self.n_intrinsic_params, -self.n_intrinsic_params :] = J_intrinsics_g
            Hess[-self.n_intrinsic_params :, -self.n_intrinsic_params :] = H_intrinsics
            Hess = Hess.unsqueeze(0)

        return Grad, Hess

    def setup_system(
        self,
        camera: BaseCamera,
        gravity: Gravity,
        residuals: Dict[str, torch.Tensor],
        weights: Dict[str, torch.Tensor],
        as_rpf: bool = False,
        shared_intrinsics: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        J_up, J_lat = J_perspective_field(
            camera,
            gravity,
            spherical=self.conf.use_spherical_manifold and not as_rpf,
            log_focal=self.conf.use_log_focal and not as_rpf,
        )

        J_up = J_up.reshape(J_up.shape[0], -1, J_up.shape[-2], J_up.shape[-1])  # (B, N, 2, 3)
        J_lat = J_lat.reshape(J_lat.shape[0], -1, J_lat.shape[-2], J_lat.shape[-1])  # (B, N, 1, 3)

        n_params = (
            2 * self.estimate_gravity
            + self.estimate_focal
            + (self.camera_model.num_dist_params() if self.camera_has_distortion else 0)
        )
        Grad = J_up.new_zeros(J_up.shape[0], n_params)
        Hess = J_up.new_zeros(J_up.shape[0], n_params, n_params)

        if shared_intrinsics:
            N_params = (
                Grad.shape[0] * (n_params - self.n_intrinsic_params) + self.n_intrinsic_params
            )
            Grad = Grad.new_zeros(1, N_params)
            Hess = Hess.new_zeros(1, N_params, N_params)

        if "up_residual" in residuals:
            Up_Grad, Up_Hess = self.calculate_gradient_and_hessian(
                J_up, residuals["up_residual"], weights["up_weights"], shared_intrinsics
            )

            if self.conf.verbose:
                logger.info(f"Up J:\n{Up_Grad.mean(0)}")

            Grad = Grad + Up_Grad
            Hess = Hess + Up_Hess

        if "latitude_residual" in residuals:
            Lat_Grad, Lat_Hess = self.calculate_gradient_and_hessian(
                J_lat,
                residuals["latitude_residual"],
                weights["latitude_weights"],
                shared_intrinsics,
            )

            if self.conf.verbose:
                logger.info(f"Lat J:\n{Lat_Grad.mean(0)}")

            Grad = Grad + Lat_Grad
            Hess = Hess + Lat_Hess

        return Grad, Hess

    def estimate_uncertainty(
        self,
        camera_opt: BaseCamera,
        gravity_opt: Gravity,
        errors: Dict[str, torch.Tensor],
        weights: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:

        _, Hess = self.setup_system(
            camera_opt, gravity_opt, errors, weights, as_rpf=True, shared_intrinsics=False
        )
        Cov = torch.inverse(Hess)

        roll_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
        pitch_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
        gravity_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
        if self.estimate_gravity:
            roll_uncertainty = Cov[..., 0, 0]
            pitch_uncertainty = Cov[..., 1, 1]

            try:
                delta_uncertainty = Cov[..., :2, :2]
                eigenvalues = torch.linalg.eigvalsh(delta_uncertainty.cpu())
                gravity_uncertainty = torch.max(eigenvalues, dim=-1).values.to(Cov.device)
            except RuntimeError:
                logger.warning("Could not calculate gravity uncertainty")
                gravity_uncertainty = Cov.new_zeros(Cov.shape[0])

        focal_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
        fov_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
        if self.estimate_focal:
            focal_uncertainty = Cov[..., self.focal_delta_dims[0], self.focal_delta_dims[0]]
            fov_uncertainty = (
                J_focal2fov(camera_opt.f[..., 1], camera_opt.size[..., 1]) ** 2 * focal_uncertainty
            )

        return {
            "covariance": Cov,
            "roll_uncertainty": torch.sqrt(roll_uncertainty),
            "pitch_uncertainty": torch.sqrt(pitch_uncertainty),
            "gravity_uncertainty": torch.sqrt(gravity_uncertainty),
            "focal_uncertainty": torch.sqrt(focal_uncertainty) / 2,
            "vfov_uncertainty": torch.sqrt(fov_uncertainty / 2),
        }

    def update_estimate(
        self, camera: BaseCamera, gravity: Gravity, delta: torch.Tensor
    ) -> Tuple[BaseCamera, Gravity]:

        delta_gravity = (
            delta[..., self.gravity_delta_dims]
            if self.estimate_gravity
            else delta.new_zeros(delta.shape[:-1] + (2,))
        )
        new_gravity = gravity.update(delta_gravity, spherical=self.conf.use_spherical_manifold)

        delta_f = (
            delta[..., self.focal_delta_dims]
            if self.estimate_focal
            else delta.new_zeros(delta.shape[:-1] + (1,))
        )
        new_camera = camera.update_focal(delta_f, as_log=self.conf.use_log_focal)

        if self.camera_has_distortion and self.estimate_dist:
            delta_dist = delta[..., self.dist_delta_dims]
            new_camera = new_camera.update_dist(delta_dist)

        return new_camera, new_gravity

    def optimize(
        self,
        data: Dict[str, torch.Tensor],
        camera_opt: BaseCamera,
        gravity_opt: Gravity,
    ) -> Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]:

        key = list(data.keys())[0]
        B = data[key].shape[0]

        lamb = data[key].new_ones(B) * self.conf.lambda_
        if self.shared_intrinsics:
            lamb = data[key].new_ones(1) * self.conf.lambda_

        infos = {"stop_at": self.num_steps}
        for i in range(self.num_steps):
            if self.conf.verbose:
                logger.info(f"Step {i+1}/{self.num_steps}")

            errors = self.calculate_residuals(camera_opt, gravity_opt, data)
            costs, weights = self.calculate_costs(errors, data)

            if i == 0:
                prev_cost = sum(c.mean(-1) for c in costs.values())
                for k, c in costs.items():
                    infos[f"initial_{k}"] = c.mean(-1)

                infos["initial_cost"] = prev_cost

            Grad, Hess = self.setup_system(
                camera_opt,
                gravity_opt,
                errors,
                weights,
                shared_intrinsics=self.shared_intrinsics,
            )
            delta = optimizer_step(Grad, Hess, lamb)  # (B, N_params)

            if self.shared_intrinsics:
                delta_g = delta[..., : -self.n_intrinsic_params].reshape(B, 2)
                delta_f = delta[..., -self.n_intrinsic_params].expand(B, 1)
                delta_dist = delta[..., -self.n_intrinsic_params + 1 :].expand(B, -1)
                delta = torch.cat([delta_g, delta_f, delta_dist], dim=-1)

            # calculate new cost
            camera_opt, gravity_opt = self.update_estimate(camera_opt, gravity_opt, delta)
            new_cost, _ = self.calculate_costs(
                self.calculate_residuals(camera_opt, gravity_opt, data), data
            )
            new_cost = sum(c.mean(-1) for c in new_cost.values())

            if not self.conf.fix_lambda and not self.shared_intrinsics:
                lamb = update_lambda(lamb, prev_cost, new_cost)

            if self.conf.verbose:
                logger.info(f"Cost:\nPrev: {prev_cost}\nNew:  {new_cost}")
                logger.info(f"Camera:\n{camera_opt._data}")

            if early_stop(new_cost, prev_cost, atol=self.conf.atol, rtol=self.conf.rtol):
                infos["stop_at"] = min(i + 1, infos["stop_at"])

                if self.conf.early_stop:
                    if self.conf.verbose:
                        logger.info(f"Early stopping at step {i+1}")
                    break

            prev_cost = new_cost

            if i == self.num_steps - 1 and self.conf.early_stop:
                logger.warning("Reached maximum number of steps without convergence.")

        final_errors = self.calculate_residuals(camera_opt, gravity_opt, data)  # (B, N, 3)
        final_cost, weights = self.calculate_costs(final_errors, data)  # (B, N)

        if not self.training:
            infos |= self.estimate_uncertainty(camera_opt, gravity_opt, final_errors, weights)

        infos["stop_at"] = camera_opt.new_ones(camera_opt.shape[0]) * infos["stop_at"]
        for k, c in final_cost.items():
            infos[f"final_{k}"] = c.mean(-1)

        infos["final_cost"] = sum(c.mean(-1) for c in final_cost.values())

        return camera_opt, gravity_opt, infos

    def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the LM optimization."""
        camera_init, gravity_init = get_trivial_estimation(data, self.camera_model)

        self.setup_optimization_and_priors(data, shared_intrinsics=self.shared_intrinsics)

        start = time.time()
        camera_opt, gravity_opt, infos = self.optimize(data, camera_init, gravity_init)

        if self.conf.verbose:
            logger.info(f"Optimization took {(time.time() - start)*1000:.2f} ms")

            logger.info(f"Initial camera:\n{rad2deg(camera_init.vfov)}")
            logger.info(f"Optimized camera:\n{rad2deg(camera_opt.vfov)}")

            logger.info(f"Initial gravity:\n{rad2deg(gravity_init.rp)}")
            logger.info(f"Optimized gravity:\n{rad2deg(gravity_opt.rp)}")

        return {"camera": camera_opt, "gravity": gravity_opt, **infos}
