"""Reconstructors based on matching pursuit."""
import math

import torch
from tqdm import tqdm

from npeff_torch.compressed_sensing import compressed_sensing_common
from npeff_torch.peis import random_projectors


###############################################################################


def _argmax_magnitude_vector(x: torch.Tensor) -> torch.Tensor:
    """Returns the index of the entry with the largest absolute value."""
    assert len(x.shape) == 1, 'Must be a vector.'
    argmax = torch.argmax(x)
    argmin = torch.argmin(x)
    return torch.where(x[argmax] > -x[argmin], argmax, argmin)


###############################################################################


class MatchingPursuitReconstructor(compressed_sensing_common.ReconstructorAbc):
    """Reconstructs vectors from their random projections via matching pursuit.

    The dictionary is implicit. The atom for original coordinate ``i`` is the
    forward projection of the standard basis vector ``e_i``. Each step:

    1. correlates the residual with every atom via the projector's
       ``transposed_project``;
    2. picks the coordinate with the largest-magnitude correlation;
    3. adds the (rescaled) coefficient to the reconstruction;
    4. subtracts ``coeff * project(e_i)`` from the residual.
    """

    def __init__(
        self, *,
        random_projector: 'random_projectors.RandomProjector',
        d_original: int,
        n_steps: int,
        use_tqdm: bool = False,
    ):
        """
        Args:
            random_projector: Projector mapping between the original and the
                projected spaces.
            d_original: Dimensionality of the original (unprojected) vectors.
            n_steps: Number of matching pursuit iterations to run.
            use_tqdm: Whether to show a progress bar over the iterations.
        """
        super().__init__(
            random_projector=random_projector,
            d_original=d_original,
        )
        self.n_steps = n_steps
        self.use_tqdm = use_tqdm

    # NOTE: Intended single-vector wrapper once `reconstruct_vectors` (batched,
    # see below) is finished; kept for reference.
    # @torch.no_grad()
    # def reconstruct_vector(self, x: torch.Tensor) -> torch.Tensor:
    #     assert len(x.shape) == 1, 'Must be a vector.'
    #     # Needs a dummy batch dimension.
    #     reconstructed = self.reconstruct_vectors(x[None, :])
    #     return torch.squeeze(reconstructed, dim=0)

    @torch.no_grad()
    def reconstruct_vector(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstructs a single vector from its random projection.

        Args:
            x: 1-D projected vector. Presumably of length
                ``random_projector.params.d_projection`` (TODO confirm). It is
                not modified.

        Returns:
            A 1-D tensor of length ``d_original`` with the dtype and device of
            ``x``. At most ``n_steps`` of its entries are non-zero, since each
            step updates a single coordinate.
        """
        assert len(x.shape) == 1, 'Must be a vector.'

        d_projection = self.random_projector.params.d_projection

        # Residual in the projected space. Cloned because it is updated in place.
        residual = x.clone()

        recon = torch.zeros((self.d_original,), dtype=x.dtype, device=x.device)

        # Scratch buffer with two roles per step. It first receives the
        # correlations of the residual with every atom. It is then zeroed and
        # reused as the scaled one-hot vector `coeff * e_i` that gets
        # forward-projected.
        dot_products = torch.empty((self.d_original,), dtype=x.dtype, device=x.device)
        # Receives the forward projection of the scaled one-hot vector.
        projected_atom = torch.empty((d_projection,), dtype=x.dtype, device=x.device)

        # TODO: Figure out why this scaling factor is needed, and find a more
        # robust way to set it. Presumably it compensates for the atoms
        # (projected basis vectors) not having unit norm; matching pursuit
        # assumes normalized dictionary columns. Alternatives tried:
        #   math.sqrt(self.d_original / d_projection)
        #   math.sqrt(self.d_original)
        coeff_rescaling_factor = self.d_original / d_projection

        for _ in tqdm(range(self.n_steps), disable=not self.use_tqdm):
            self.random_projector.transposed_project(
                residual[None, :], out=dot_products[None, :])

            selected_index = _argmax_magnitude_vector(dot_products)

            # Clone so the value survives `dot_products.zero_()` below.
            coeff = dot_products[selected_index].clone()
            coeff *= coeff_rescaling_factor

            # Accumulate: the same coordinate may be selected in several steps.
            recon[selected_index] += coeff

            # Build `coeff * e_i` in the scratch buffer and remove its projection
            # from the residual.
            dot_products.zero_()
            dot_products[selected_index] = coeff

            self.random_projector.project(
                dot_products[None, :], out=projected_atom[None, :])
            residual -= projected_atom

        return recon

    # NOTE: Unfinished batched variant; kept for reference.
    # @torch.no_grad()
    # def reconstruct_vectors(self, x: torch.Tensor) -> torch.Tensor:
    #     assert len(x.shape) == 2, 'Must be a batch of vectors.'
    #     n_vecs = x.shape[0]

    #     # Make a copy since we will be modifying this in place.
    #     x = x.clone()

    #     # TODO: Technically, I don't need two buffers with a small computational overhead
    #     # if I recompute the dot product using only the selected component.
    #     dot_products_buffer = torch.empty((n_vecs, self.d_original), dtype=x.dtype, device=x.device)
    #     scores_buffer = torch.empty_like(dot_products_buffer)

    #     for step in range(self.n_steps):
    #         self.random_projector.transposed_project(x, out=dot_products_buffer)
    #         torch.abs(dot_products_buffer, out=scores_buffer)

    #         selected_indices = torch.argmax(scores_buffer, dim=-1)
    #         # torch.gather(dot_products_buffer, dim=)


