"""Generalized linear operators."""

import torch

import time
import warnings


class Operator(torch.nn.Module):
    """General class for a linear operator.

    Wraps a forward map ``A`` and its adjoint so that ``op(x)`` applies ``A``
    and ``op.T(y)`` applies the adjoint. ``norm()`` lazily estimates the
    operator norm with :meth:`normest` and caches the result.

    Args:
        operator: Callable applying the forward map (e.g. a ``torch.nn.Conv2d``).
        adjoint_operator: Callable applying the adjoint of ``operator``.
        dimension: Stored as ``self.dimension``; not used by this class itself.
        channels: Stored as ``self.channels``; not used by this class itself.
    """

    def __init__(self, operator, adjoint_operator, dimension=1, channels=1):
        super().__init__()
        self.fwd_operator = operator
        self.adj_operator = adjoint_operator
        # Cached norm estimate; filled on the first call to ``norm()``.
        self.norm_val = None

        self.dimension = dimension
        self.channels = channels

    def forward(self, inputs):
        """Apply the forward operator to ``inputs``."""
        return self.fwd_operator(inputs)

    def T(self, inputs):
        """Apply the adjoint operator to ``inputs``."""
        return self.adj_operator(inputs)

    def norm(self):
        """Return the cached operator-norm estimate, computing it on first use.

        The estimate comes from :meth:`normest` with probe inputs of shape
        ``(1, C, 64, 64)``, where ``C`` is the number of channels the forward
        operator consumes.
        """
        if self.norm_val is None:
            # ``Conv*d`` / ``ConvTranspose*d`` expose ``in_channels``, the
            # channel count the forward map actually consumes. ``weight.shape[0]``
            # is the *output* channel count of ``Conv*d`` layers (it only equals
            # the input count for ``ConvTranspose*d``), so it is kept solely as a
            # fallback for operators without ``in_channels``.
            in_channels = getattr(self.fwd_operator, 'in_channels', None)
            if in_channels is None:
                in_channels = self.fwd_operator.weight.shape[0]
            shape = [in_channels, 64, 64]
            self.norm_val = self.normest(self, shape)
        return self.norm_val

    @staticmethod
    def _probe_device_dtype(operator):
        """Return the ``(device, dtype)`` for ``normest``'s random probe vector."""
        fwd = operator.fwd_operator
        # Tensors (e.g. an explicit matrix) carry device/dtype directly.
        try:
            return fwd.device, fwd.dtype
        except AttributeError:
            pass
        # Layers such as ``torch.nn.Conv2d`` expose them through ``weight``.
        try:
            return fwd.weight.device, fwd.weight.dtype
        except AttributeError:
            pass
        # Plain callables: use the first tensor registered on the operator
        # module, otherwise torch's defaults.
        if isinstance(operator, torch.nn.Module):
            for tensors in (operator.parameters(), operator.buffers()):
                first = next(tensors, None)
                if first is not None:
                    return first.device, first.dtype
        return torch.device('cpu'), torch.get_default_dtype()

    @staticmethod
    def normest(operator, shape=(1, 64, 64), tol=1e-7, max_iterations=5000, verbose=False):
        """Estimate the 2-norm of a linear operator by power iteration.

        normest implementation, following the matlab baseline at
        MATLAB/R2018b/toolbox/matlab/matfun/normest.m

        Each iteration computes ``y = A x`` and ``x' = A^T y``, then updates the
        estimate ``e = ||A^T A x|| / ||A x||``. When ``operator.T`` is the true
        adjoint, ``e`` converges to the largest singular value of ``A``.

        Args:
            operator: Operator-like object. Calling it applies ``A``, and
                ``operator.T`` applies the adjoint. ``operator.fwd_operator`` is
                inspected to choose the probe's device and dtype.
            shape: Shape of a single input sample; a batch axis of size 1 is
                prepended.
            tol: Relative tolerance on the change of the estimate between two
                consecutive iterations.
            max_iterations: Maximum number of power iterations.
            verbose: If True, print iteration count and elapsed time on
                convergence.

        Returns:
            0-dim tensor holding the estimate. If the tolerance is not reached
            within ``max_iterations``, a warning is emitted and the last
            estimate is returned.
        """
        device, dtype = Operator._probe_device_dtype(operator)
        with torch.no_grad():
            start_time = time.time()
            x = torch.randn(1, *shape, device=device, dtype=dtype)
            e = x.norm()
            if e == 0:
                return e
            x = x / e
            e0 = 0

            for iteration in range(1, max_iterations + 1):
                e0 = e.clone()
                Ax = operator(x)
                Ax_norm = Ax.norm()
                if Ax_norm == 0:
                    # As in MATLAB's normest: if A x vanishes (x fell into the
                    # null space), restart from a random vector instead of
                    # dividing by zero below.
                    Ax = torch.rand_like(Ax)
                    Ax_norm = Ax.norm()
                x = operator.T(Ax)
                xnorm = x.norm()
                if xnorm == 0:
                    # The adjoint annihilated its input. With a true adjoint this
                    # only happens after the random restart above, i.e. when the
                    # operator is zero; report a zero norm instead of NaNs.
                    return xnorm
                e = xnorm / Ax_norm
                x = x / xnorm
                if torch.abs(e - e0) < tol * e:
                    if verbose:
                        minutes, seconds = divmod(time.time() - start_time, 60)
                        print(f'Tolerance {tol} reached after {iteration} iterations in {minutes:.0f} minutes and {seconds:.2f} seconds')
                    return e
            # Return anyway if not converged
            final_tol = torch.abs(e - e0) / e
            warnings.warn(f'Normest did not converge to tolerance {tol} within {max_iterations} iterations.'
                          f' Final tol was {final_tol}')
            return e


class MatrixOperator(Operator):
    """Linear operator given explicitly by a matrix ``A`` (presumably 2-D).

    ``forward`` returns ``inputs @ A.T`` and ``T`` returns ``inputs @ A``, so
    both act on the last axis of ``inputs`` (row-vector convention). The matrix
    is registered as the buffer ``op``, which is what both maps read.
    """

    def __init__(self, matrix, dimension=1, channels=1):
        adjoint = matrix.T
        # NOTE(review): the raw tensors handed to ``Operator.__init__`` become
        # plain attributes, not buffers, so ``fwd_operator``/``adj_operator`` do
        # not follow ``.to(device)``; only the ``op`` buffer does.
        super().__init__(matrix, adjoint, dimension, channels)
        self.register_buffer('op', matrix)
        # Precomputed so that ``norm()`` never falls back to ``normest``.
        # NOTE(review): this is ||A^T A||_2 = sigma_max(A)**2, whereas
        # ``Operator.normest`` estimates ||A||_2 = sigma_max(A); confirm which
        # convention callers of ``norm()`` expect.
        gram = adjoint @ matrix
        self.norm_val = torch.linalg.norm(gram, ord=2)

    def forward(self, inputs):
        """Apply ``A``: return ``inputs @ A.T``."""
        matrix = self.op
        return inputs @ matrix.T

    def T(self, inputs):
        """Apply the adjoint ``A^T``: return ``inputs @ A``."""
        matrix = self.op
        return inputs @ matrix
