"""Compressed-sensing reconstruction via Bregman iteration.

Algorithm reference: https://epubs.siam.org/doi/pdf/10.1137/070703983
"""
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch

from npeff_torch.compressed_sensing import compressed_sensing_common
from npeff_torch.compressed_sensing import fpcs
from npeff_torch.peis import random_projectors

###############################################################################


class _BregmanReconstructor:
    """Batched Bregman-iteration solver for L1-regularized compressed sensing.

    Implements "Version 2" on page 8 of
    https://epubs.siam.org/doi/pdf/10.1137/070703983:

        f_0 = 0, u_0 = 0
        f_{k+1} = f + (f_k - A u_k)
        u_{k+1} = argmin_u mu * ||u||_1 + 1/2 * ||A u - f_{k+1}||_2^2

    Here f is the measured projection (`projected`), A is `random_projector`, and
    each u-subproblem is solved with `fpcs.Fpc`.
    """

    def __init__(
        self,
        projected: torch.Tensor,
        random_projector: 'random_projectors.RandomProjector',
        *,
        d_original: int,
        # NOTE: This mu has a somewhat different meaning than in the FPC code. Here,
        # it weights the L1 term while it weights the reconstruction term in the FPC code.
        mu: float,
        fpc_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Stores the problem data; no computation happens until `run`.

        Args:
            projected: Measurements, one per row. The first dimension is the number
                of vectors and the last is the projection dimension.
            random_projector: The projector A, used via `project(u)` and
                `project(u, out=...)`.
            d_original: Dimension of each reconstructed vector.
            mu: Weight on the L1 term of each Bregman subproblem.
            fpc_kwargs: Extra keyword arguments forwarded to `fpcs.Fpc`. A shallow
                copy is stored, so later changes to the caller's dict have no effect.
        """
        if fpc_kwargs is None:
            fpc_kwargs = {}

        self._projected = projected
        self._random_projector = random_projector
        self._d_original = d_original

        self._mu = mu
        # Dividing the subproblem by mu moves the weight off the L1 term and onto the
        # reconstruction term, which is the term FPC's mu weights.
        self._fpc_mu = 1.0 / mu
        self._fpc_kwargs = fpc_kwargs.copy()

        self._n_vecs = projected.shape[0]
        self._d_projection = projected.shape[-1]

    def _step(
        self,
        f: torch.Tensor,
        u: torch.Tensor,
        *,
        fpc_max_inner_iters: int,
        fpc_xtol: float,
        fpc_gtol: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Performs one Bregman iteration.

        Args:
            f: The current added-back measurement f_k.
            u: The current estimate u_k.
            fpc_max_inner_iters, fpc_xtol, fpc_gtol: Forwarded to `Fpc.run`.

        Returns:
            The pair (f_{k+1}, u_{k+1}).
        """
        # f_{k+1} = f + (f_k - A u_k). In the paper's notation the unsubscripted f is
        # the original measurement (self._projected); f_k is the `f` argument.
        f_next = torch.zeros_like(f)
        self._random_projector.project(u, out=f_next)
        f_next *= -1.0
        f_next += f
        f_next += self._projected

        # u_{k+1} = argmin_u mu * ||u||_1 + 1/2 * ||A u - f_{k+1}||_2^2, solved by FPC
        # with its own mu set to 1 / mu (see __init__).
        fpc = fpcs.Fpc(
            projected=f_next,
            random_projector=self._random_projector,
            d_original=self._d_original,
            mu=self._fpc_mu,
            **self._fpc_kwargs,
        )
        u_next = fpc.run(
            max_inner_iters=fpc_max_inner_iters,
            xtol=fpc_xtol,
            gtol=fpc_gtol,
        )

        return f_next, u_next

    def _compute_stop_tol_metric(self, u: torch.Tensor, norm_projected: torch.Tensor) -> torch.Tensor:
        """Returns the per-row relative residual ||A u - f||_2 / ||f||_2.

        Args:
            u: Current estimate, one vector per row.
            norm_projected: Per-row L2 norms of the measurements.

        Returns:
            A tensor with one entry per row. Rows whose measurement norm is zero use
            the absolute residual instead, so they do not produce NaN (0/0), which
            would never compare below the stopping tolerance.
        """
        residual = self._random_projector.project(u) - self._projected
        # dim=-1 gives one norm per row, matching how norm_projected is computed.
        residual_norm = torch.linalg.vector_norm(residual, dim=-1)
        safe_norm = torch.where(norm_projected > 0, norm_projected, torch.ones_like(norm_projected))
        return residual_norm / safe_norm

    @torch.no_grad()
    def run(
        self,
        *,
        bregman_max_iters: int,
        bregman_stop_tol: float = 1e-5,
        #
        fpc_max_inner_iters: int,
        fpc_xtol: float = 1e-4,
        fpc_gtol: float = 0.2,
    ) -> torch.Tensor:
        """Runs Bregman iterations starting from f_0 = 0, u_0 = 0.

        Stops after `bregman_max_iters` iterations, or earlier once every row's
        relative residual is below `bregman_stop_tol`.

        Args:
            bregman_max_iters: Maximum number of outer Bregman iterations.
            bregman_stop_tol: Relative-residual tolerance for early stopping.
            fpc_max_inner_iters, fpc_xtol, fpc_gtol: Forwarded to `Fpc.run`.

        Returns:
            The reconstruction u, with shape (n_vecs, d_original).
        """
        device = self._projected.device

        # Per-row measurement norms, reused by every stopping check.
        norm_projected = torch.linalg.vector_norm(self._projected, dim=-1)

        f = torch.zeros([self._n_vecs, self._d_projection], dtype=self._projected.dtype, device=device)
        u = torch.zeros([self._n_vecs, self._d_original], dtype=self._projected.dtype, device=device)

        for _ in range(bregman_max_iters):
            f, u = self._step(f, u, fpc_max_inner_iters=fpc_max_inner_iters, fpc_xtol=fpc_xtol, fpc_gtol=fpc_gtol)
            step_tol_metric = self._compute_stop_tol_metric(u, norm_projected)

            # The whole batch keeps iterating until every row has converged.
            # TODO: Better handling of multiple vectors?
            if torch.all(step_tol_metric < bregman_stop_tol):
                break

        return u


###############################################################################


class BregmanReconstructor(compressed_sensing_common.ReconstructorAbc):
    """Reconstructs single vectors from their random projections via Bregman iteration.

    This class only holds configuration. Each `reconstruct_vector` call builds a
    fresh `_BregmanReconstructor` on a batch of one and runs it.
    """

    def __init__(
        self, *,
        random_projector: 'random_projectors.RandomProjector',

        d_original: int,

        # NOTE: This bregman mu has a somewhat different meaning than in the FPC code. Here,
        # it weights the L1 term while it weights the reconstruction term in the FPC code.
        bregman_mu: float,

        bregman_max_iters: int,
        bregman_stop_tol: float = 1e-5,

        fpc_tau: Optional[float] = None,
        fpc_gamma: float = 0.99,
        fpc_beta: float = 4.0,

        fpc_max_inner_iters: int,
        fpc_xtol: float = 1e-4,
        fpc_gtol: float = 0.2,

        # TODO: Add other actual arguments here.
    ):
        super().__init__(random_projector=random_projector, d_original=d_original)

        # Outer Bregman loop configuration.
        self.bregman_mu = bregman_mu
        self.bregman_max_iters, self.bregman_stop_tol = bregman_max_iters, bregman_stop_tol

        # Passed to the FPC constructor as tau / gamma / beta.
        self.fpc_tau, self.fpc_gamma, self.fpc_beta = fpc_tau, fpc_gamma, fpc_beta

        # Passed to FPC's run() for each Bregman subproblem.
        self.fpc_max_inner_iters, self.fpc_xtol, self.fpc_gtol = fpc_max_inner_iters, fpc_xtol, fpc_gtol

    def reconstruct_vector(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstructs one vector from its 1-D projection `x`.

        Args:
            x: A single projected vector (must be 1-D).

        Returns:
            The reconstruction, with the temporary batch dimension removed.
        """
        assert len(x.shape) == 1, 'Must be a vector.'

        fpc_ctor_kwargs = dict(tau=self.fpc_tau, gamma=self.fpc_gamma, beta=self.fpc_beta)

        # The solver works on batches, so wrap x as a batch of one.
        solver = _BregmanReconstructor(
            projected=x.unsqueeze(0),
            random_projector=self.random_projector,
            d_original=self.d_original,
            mu=self.bregman_mu,
            fpc_kwargs=fpc_ctor_kwargs,
        )

        batched = solver.run(
            bregman_max_iters=self.bregman_max_iters,
            bregman_stop_tol=self.bregman_stop_tol,
            fpc_max_inner_iters=self.fpc_max_inner_iters,
            fpc_xtol=self.fpc_xtol,
            fpc_gtol=self.fpc_gtol,
        )
        return batched.squeeze(0)
