import torch
import numpy as np


class LinearOperator(torch.nn.Module):
    """Abstract base for a linear operator ``A`` with an adjoint ``A^*``.

    Subclasses must override ``forward`` (apply ``A``) and ``adjoint``
    (apply ``A^*``). ``gramian`` is derived from those two.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the operator ``A`` to ``x``.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        # Raising (rather than returning None) surfaces a missing override at
        # the call site instead of as a confusing error further downstream.
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )

    def adjoint(self, x):
        """Apply the adjoint operator ``A^*`` to ``x``.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement adjoint()"
        )

    def gramian(self, x):
        """Return ``A^*(A(x))``, i.e. the adjoint applied to the forward output."""
        return self.adjoint(self.forward(x))


class SelfAdjointLinearOperator(LinearOperator):
    """Linear operator satisfying ``A^* = A``.

    Subclasses only need to provide ``forward``; ``adjoint`` delegates to it.
    """

    def adjoint(self, x):
        """Apply the adjoint, which for this operator is just ``forward``."""
        forward_result = self.forward(x)
        return forward_result


class Identity(SelfAdjointLinearOperator):
    """The identity operator: both ``forward`` and ``adjoint`` return the input."""

    def forward(self, x):
        """Return ``x`` unchanged (the same object, no copy)."""
        unchanged = x
        return unchanged


class OperatorPlusNoise(torch.nn.Module):
    """Wrap an operator so that its output is corrupted by additive Gaussian noise.

    Computes ``internal_operator(x) + noise_sigma * eps`` where ``eps`` is drawn
    via ``torch.randn_like`` (standard normal, same shape/dtype/device as the
    operator output).

    Args:
        operator: Callable applied to the input; stored as ``internal_operator``.
        noise_sigma: Scale applied to the standard-normal noise. Presumably a
            float or a broadcastable tensor; confirm against callers.
    """

    def __init__(self, operator, noise_sigma):
        super().__init__()
        self.internal_operator = operator
        self.noise_sigma = noise_sigma

    def forward(self, x):
        """Return the operator output plus a fresh noise sample.

        A new noise tensor is drawn from torch's global RNG on every call.
        """
        A_x = self.internal_operator(x)
        return A_x + self.noise_sigma * torch.randn_like(A_x)


# class OperatorPlusNoise_val(torch.nn.Module):
#     def __init__(self, operator, noise_sigma):
#         super(OperatorPlusNoise_val, self).__init__()
#         self.internal_operator = operator
#         self.noise_sigma = noise_sigma
#
#     def forward(self, x, seed):
#         A_x = self.internal_operator(x)
#         torch.manual_seed(seed=seed)
#         return A_x + self.noise_sigma * torch.randn_like(A_x)


def normalize(X, bs):
    """Min-max normalize each sample of a batch independently to [0, 1].

    Each of the ``bs`` samples is shifted by its own minimum and divided by
    its own (max - min), computed over all of that sample's elements.

    Args:
        X: Tensor whose leading dimension indexes the ``bs`` samples; any
            number of trailing dimensions is accepted.
        bs: Number of samples (size of the leading dimension).

    Returns:
        Tensor with the same shape as ``X``. A sample whose values are all
        equal (zero range) maps to all zeros instead of NaN.
    """
    flat = X.reshape(bs, -1)
    minVal = flat.min(dim=1).values
    maxVal = flat.max(dim=1).values
    value_range = maxVal - minVal
    # Constant samples would give 0/0 = NaN; their numerator is already zero,
    # so dividing by 1 yields zeros.
    value_range = torch.where(
        value_range > 0, value_range, torch.ones_like(value_range)
    )
    # Shape (bs, 1, ..., 1) so the per-sample statistics broadcast over every
    # non-batch dimension, whatever the rank of X.
    stat_shape = (bs,) + (1,) * (X.dim() - 1)
    return (X - minVal.reshape(stat_shape)) / value_range.reshape(stat_shape)
