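"""Fixed-point continuation (FPC) method for l1-regularized least squares.

Reference: E. T. Hale, W. Yin, and Y. Zhang, "A Fixed-Point Continuation Method
for l1-Regularized Minimization with Applications to Compressed Sensing",
Rice University CAAM Technical Report TR07-07:
https://www.cmor-faculty.rice.edu/~zhang/reports/tr0707.pdf

Solves, independently for each projected vector b, the problem:
    min_x ||x||_1 + mu/2 * ||Ax - b||_2^2
where A is the random projection.
"""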
from typing import List, Optional, Tuple

import numpy as np
import torch
from tqdm import tqdm

from npeff_torch.peis import random_projectors

###############################################################################


class Fpc:
    """Batched fixed-point continuation (FPC) solver.

    For each row ``b`` of ``projected`` this approximately solves

        min_x ||x||_1 + mu/2 * ||A x - b||_2^2

    where ``A x`` is computed by ``random_projector.project`` and ``A^T y`` by
    ``random_projector.transposed_project``. Each row follows its own
    continuation schedule ``mu_i = min(mu_0 * beta**i, mu)`` (outer loop). For
    each ``mu_i`` the shrinkage fixed-point iteration
    ``x <- s_nu(x - tau * g(x))`` with ``nu = tau / mu_i`` runs until the
    xtol/gtol stopping tests pass (inner loop).
    See https://www.cmor-faculty.rice.edu/~zhang/reports/tr0707.pdf.
    """

    def __init__(
        self,
        projected: torch.Tensor,
        random_projector: "random_projectors.RandomProjector",
        *,
        d_original: int,
        mu: float,
        tau: Optional[float] = None,
        gamma: float = 0.99,
        beta: float = 4.0,
    ) -> None:
        """Stores the problem data and solver hyperparameters.

        See page 26 of https://www.cmor-faculty.rice.edu/~zhang/reports/tr0707.pdf
        for more information on what some of the arguments mean.

        Args:
            projected: The right-hand sides ``b``. ``shape[0]`` is taken as the
                number of vectors and ``shape[-1]`` as the projected dimension
                (presumably ``[n_vecs, d_projection]`` -- TODO confirm).
            random_projector: Provides ``project`` (x -> A x) and
                ``transposed_project`` (y -> A^T y, written into ``out``).
            d_original: Dimension of the unknowns ``x``.
            mu: Target weight of the least-squares term; must be positive.
            tau: Fixed-point step size. Defaults to Equation (73) of the paper.
            gamma: Sets the initial shrinkage threshold
                ``nu_0 = tau / mu_0 = gamma * ||x_0||_inf`` (see ``run``).
            beta: Factor by which ``mu`` grows per outer iteration; must be > 1,
                otherwise continuation could never reach ``mu``.

        Raises:
            ValueError: If ``mu`` is not positive or ``beta`` is not > 1.
        """
        if not mu > 0:
            raise ValueError(f'mu must be positive, got {mu}.')
        if not beta > 1:
            raise ValueError(f'beta must be greater than 1, got {beta}.')

        self._projected = projected
        self._random_projector = random_projector
        self._d_original = d_original

        self._mu = mu
        self._gamma = gamma
        self._beta = beta

        self._n_vecs = projected.shape[0]
        self._d_projection = projected.shape[-1]

        self._tau = self._compute_default_tau() if tau is None else tau

    #######################################################

    def _compute_default_tau(self) -> float:
        """Returns the default step size, Equation (73) of the paper."""
        # tau = min(1 + 1.665 * (1 - delta), 1.999) with delta = m / n.
        delta = float(self._d_projection) / float(self._d_original)
        return min(1 + 1.665 * (1 - delta), 1.999)

    #######################################################

    def _compute_default_initial_value(self) -> torch.Tensor:
        """Returns x_0 = tau * A^T b with shape [n_vecs, d_original]."""
        value = torch.zeros(
            [self._n_vecs, self._d_original],
            dtype=self._projected.dtype,
            device=self._projected.device,
        )
        self._random_projector.transposed_project(self._projected, out=value)
        value.mul_(self._tau)
        return value

    def _compute_mu_i(self, mu_0: torch.Tensor, i: int) -> torch.Tensor:
        """Returns the per-row continuation value min(mu_0 * beta**i, mu)."""
        # torch.clamp keeps rows with mu_0 == inf (an all-zero initial value)
        # at exactly mu. The previous `minimum(..., 0 * mu_0 + mu)` turned them
        # into NaN (0 * inf), which kept the outer loop in `run` from ever
        # terminating.
        return torch.clamp(mu_0 * self._beta ** i, max=self._mu)

    def _compute_outer_iterations_counts(self, mu_0: torch.Tensor) -> List[int]:
        """Returns, per row, how many times mu must be scaled by beta to reach mu.

        This is ceil(log(mu / mu_0) / log(beta)), clipped at 0. Rows with
        mu_0 >= mu or a non-finite result get 0. Used only for progress
        reporting; rounding may make it off by one when mu / mu_0 is an exact
        power of beta.
        """
        # {mu_0, ret}.shape = [n_vecs]
        # .double() first so low-precision dtypes (e.g. bfloat16) convert to numpy.
        mu_0_np = mu_0.detach().double().cpu().numpy()
        with np.errstate(divide='ignore', invalid='ignore'):
            counts = np.ceil((np.log(self._mu) - np.log(mu_0_np)) / np.log(self._beta))
        counts = np.where(np.isfinite(counts), np.maximum(counts, 0), 0)
        return [int(c) for c in counts]

    def _inner_step(self, value: torch.Tensor, mu_i: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Performs one shrinkage fixed-point step for every row.

        Args:
            value: Current iterate x, shape [n_vecs, d_original].
            mu_i: Current continuation value per row, shape [n_vecs].

        Returns:
            (s, xtol_metric, gtol_metric) where s = s_nu(x - tau * g(x)) is the
            next iterate, xtol_metric = ||s - x|| / max(||x||, 1) and
            gtol_metric = mu_i * ||g(x)||_inf - 1, both of shape [n_vecs].
        """
        # TODO: Try to make this memory efficient on the GPU, maybe make custom CUDA kernel for (most of) this.
        nu = self._tau / mu_i

        # diff = Ax - b
        diff = self._random_projector.project(value) - self._projected

        # g(x) = A^T(Ax - b)
        g = torch.zeros_like(value)
        self._random_projector.transposed_project(diff, out=g)
        del diff

        # Gradient stopping metric; must be computed before g is overwritten below.
        gtol_metric = mu_i * torch.linalg.vector_norm(g, ord=float('inf'), dim=-1) - 1

        # h(x) = x - tau * g(x). Computed in place in g's buffer to lower peak
        # memory on very large models; (-tau * g) + x is bit-identical to x - tau * g.
        h = g.mul_(-self._tau).add_(value)

        # s(y) = sgn(y) * relu(abs(y) - nu), again with in-place ops.
        s = torch.abs(h).sub_(nu[:, None]).relu_().mul_(torch.sign(h))
        del h, g

        # Relative-change metric used to terminate the inner loop.
        xtol_metric = torch.linalg.vector_norm(s - value, dim=-1)
        value_norms = torch.linalg.vector_norm(value, dim=-1)
        xtol_metric /= torch.maximum(value_norms, torch.ones_like(value_norms))

        return s, xtol_metric, gtol_metric

    #######################################################

    @torch.no_grad()
    def run(self, *, max_inner_iters: int, xtol: float = 1e-4, gtol: float = 0.2) -> torch.Tensor:
        """Runs FPC with continuation and returns the solutions.

        Args:
            max_inner_iters: Maximum fixed-point steps per continuation stage.
            xtol: Threshold on the relative change ||x_{k+1} - x_k|| / max(||x_k||, 1).
            gtol: Threshold on mu_i * ||g(x_k)||_inf - 1.

        Returns:
            The solutions, shape [n_vecs, d_original].
        """
        value = self._compute_default_initial_value()
        device = value.device

        # Initial shrinkage threshold per row: nu_0 = tau / mu_0 = gamma * ||x_0||_inf.
        value0_infty = torch.max(torch.abs(value), dim=-1).values
        mu_0 = self._tau / (self._gamma * value0_infty)

        outer_iteration_counts = self._compute_outer_iterations_counts(mu_0)
        print(f'outer_iteration_counts: {outer_iteration_counts}')
        print(f'max(outer_iteration_counts): {max(outer_iteration_counts, default=0)}')

        # TODO: Do not perform computation for vectors that have already converged. IDK
        # if there is an easy/efficient way to do in pytorch, but should be relatively easy
        # to do when implemented in a cuda kernel.
        i = 0
        outer_completed_mask = torch.zeros([self._n_vecs], dtype=torch.bool, device=device)
        while not torch.all(outer_completed_mask):
            mu_i = self._compute_mu_i(mu_0, i)

            inner_completed_mask = outer_completed_mask.clone()
            for _ in tqdm(range(max_inner_iters)):
                next_value, xtol_metric, gtol_metric = self._inner_step(value, mu_i)
                # Do not change solutions that have already converged.
                next_value[inner_completed_mask] = value[inner_completed_mask]
                # Accept the step before testing convergence so the update that
                # satisfied the stopping criteria is kept (it used to be dropped
                # on the final iteration).
                value = next_value

                inner_completed_mask |= (xtol_metric < xtol) & (gtol_metric < gtol)
                if torch.all(inner_completed_mask):
                    break

            # mu_i is clamped to mu, so a row is finished once it reaches the target.
            outer_completed_mask = mu_i >= self._mu
            i += 1

        return value
