import torch
import torch.nn as nn
import torch.nn.functional as F

import deepinv as dinv
from deepinv.physics import Physics, LinearPhysics, Downsampling
from deepinv.utils import TensorList
from deepinv.utils.tensorlist import TensorList

# from huggingface_hub import hf_hub_download

import pdb

# Module-level device flag and the matching legacy float32 tensor type.
# NOTE(review): ``Tensor`` is a legacy per-device float32 type; ``isinstance`` checks against it
# do not match tensors of other dtypes or on the other device.
cuda = bool(torch.cuda.is_available())
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


def finetune(model, data, physics, supervised=False, validation=None, max_iter=50, noise_loss='SURE', transform='shift', lr=1e-4,
                 batch_size=1, device='cuda'):
    r"""
    Finetune a model on a dataset.

    :param model: RAM Model to finetune.
    :param data: Dataset to finetune on, it can be simply a tensor of measurements or
        a torch.utils.data.Dataset giving measurements (only self-supervised) or ground-truth references and measurements (supervised possible).
    :param physics: Physics model to use for the finetuning.
    :param supervised: If True, the model will be finetuned in a supervised way, otherwise it will be self-supervised (no ground-truth).
    :param validation: Validation dataset to use for early stopping, if None the validation set will be split automatically from the training set (80/20).
    :param max_iter: Maximum number of epochs to run the finetuning.
    :param noise_loss: Type of noise loss to use, can be 'noiseless', 'SURE' or 'splitting'.
        If the physics has no ``deepinv.physics.NoiseModel``, a plain measurement-consistency loss is used.
    :param transform: Type of transformation for the Equivariant Imaging loss, can be 'shift', 'rotate' (90 degrees) or None.
        The EI loss is not used for denoising physics.
    :param lr: Learning rate for the optimizer.
    :param batch_size: Batch size for the dataloader.
    :param device: Device to use for the finetuning.
    :return: The finetuned model (best checkpoint according to the first loss used as metric).
    :raises ValueError: if ``noise_loss`` or ``transform`` is not recognised, or if SURE is requested for
        a noise model that has no SURE loss.
    """

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    data = to_dataset(data)

    if validation is None:
        if len(data) > 1:
            train_split = int(len(data) * .8)
            data, validation = torch.utils.data.random_split(data, [train_split, len(data) - train_split])
        else:
            print('Warning: No validation set provided, use more than a single image to improve validation.')
            validation = data
    else:
        validation = to_dataset(validation)

    dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(validation, batch_size=batch_size, shuffle=False)

    denoising = isinstance(physics, dinv.physics.Denoising)

    if supervised:
        losses = [dinv.loss.SupLoss()]
    else:
        losses = [_consistency_loss(physics, noise_loss)]
        # Equivariant Imaging is pointless for denoising (the operator is the identity).
        if not denoising and transform is not None:
            losses.append(dinv.loss.EILoss(_ei_transform(transform), weight=.1))

    print('Finetuning with losses: ' + str([l.__class__.__name__ for l in losses]))

    # len(dataloader) counts the last partial batch and is >= 1 for any non-empty dataset,
    # so it never yields the zero divisor that len(data) // batch_size could.
    batches_per_epoch = max(len(dataloader), 1)
    # Evaluate every ceil(3 / batches_per_epoch) epochs: at least 3 gradient steps between evals.
    eval_interval = max(-(-3 // batches_per_epoch), 1)
    # Guard against a zero checkpoint interval when max_iter == 1.
    ckp_interval = max(max_iter - 1, 1)
    trainer = dinv.Trainer(model=model, physics=physics, eval_interval=eval_interval, ckp_interval=ckp_interval,
                           metrics=losses[0], early_stop=True, device=device,
                           losses=losses, epochs=max_iter, optimizer=optimizer, train_dataloader=dataloader, eval_dataloader=val_dataloader)

    # finetune
    trainer.train()
    # return best model
    return trainer.load_best_model()


def _consistency_loss(physics, noise_loss):
    r"""
    Build the measurement-consistency loss used for self-supervised finetuning.

    :param physics: Physics whose ``noise_model`` selects the SURE variant.
    :param str noise_loss: 'noiseless', 'SURE' or 'splitting'.
    :return: A deepinv loss instance.
    :raises ValueError: if ``noise_loss`` is unknown, or SURE is requested for an unsupported noise model.
    """
    noise_model = physics.noise_model
    if noise_loss == 'noiseless' or not isinstance(noise_model, dinv.physics.NoiseModel):
        return dinv.loss.MCLoss()
    if noise_loss == 'SURE':
        if isinstance(noise_model, dinv.physics.GaussianNoise):
            return dinv.loss.SureGaussianLoss(noise_model.sigma)
        if isinstance(noise_model, dinv.physics.PoissonNoise):
            return dinv.loss.SurePoissonLoss(noise_model.gain)
        if isinstance(noise_model, dinv.physics.PoissonGaussianNoise):
            return dinv.loss.SurePGLoss(gain=noise_model.gain, sigma=noise_model.sigma)
        # Previously this path left the loss unbound and crashed with UnboundLocalError.
        raise ValueError(
            f"No SURE loss available for noise model {type(noise_model).__name__}; "
            "use noise_loss='splitting' or 'noiseless'."
        )
    if noise_loss == 'splitting':
        return dinv.loss.SplittingLoss(split_ratio=.9)
    raise ValueError(f"Unknown noise_loss: {noise_loss}")


def _ei_transform(transform):
    r"""
    Return the Equivariant Imaging transform named ``transform``.

    :param str transform: 'shift' or 'rotate'.
    :raises ValueError: for any other name.
    """
    if transform == 'shift':
        return dinv.transform.Shift(shift_max=.1)
    if transform == 'rotate':
        return dinv.transform.Rotate(multiples=90)
    raise ValueError(f"Unknown transform: {transform}")


class UnsupDataset(torch.utils.data.Dataset):
    r"""
    Dataset view over a tensor of measurements, for unsupervised training.

    Items are ``(nan, measurement)`` pairs. The NaN stands in for the missing ground truth,
    so the item layout matches supervised ``(x, y)`` datasets.

    :param data: Tensor whose first dimension indexes the samples.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        n_samples = self.data.size(0)
        return n_samples

    def __getitem__(self, idx):
        measurement = self.data[idx]
        placeholder_target = torch.nan
        return placeholder_target, measurement


def to_dataset(x):
    r"""
    Converts a tensor, list/tuple of tensors, or dataset to a dataset.

    - A tensor is wrapped in an ``UnsupDataset`` (measurements only).
    - A list or tuple of tensors becomes a ``TensorDataset`` (e.g. ``[x, y]``).
    - A ``torch.utils.data.Dataset`` is returned unchanged.

    :param x: Input tensor, list/tuple of tensors or dataset.
    :return: Dataset object.
    :raises ValueError: for any other input type.
    """
    # torch.Tensor (not the legacy device-specific float32 type) so CPU tensors on CUDA machines
    # and non-float32 tensors are accepted too.
    if isinstance(x, torch.Tensor):
        return UnsupDataset(x)
    if isinstance(x, (list, tuple)):
        return torch.utils.data.TensorDataset(*x)
    if isinstance(x, torch.utils.data.Dataset):
        return x
    raise ValueError("Input must be a dataset, tensor or a list of tensors.")

class RAM(nn.Module):
    r"""
    RAM model

    This model is a convolutional neural network (CNN) designed for image reconstruction tasks.
    It is a U-Net-style encoder/decoder with three down/up-sampling stages. Its residual blocks
    are conditioned on the measurements through the physics operator.

    :param in_channels: Number(s) of image channels the model supports. A list (or tuple) creates one
        input head and one output tail per entry; a single int is treated as a one-element list.
        Defaults to ``[1, 2, 3]``.
    :param device: Device to which the model should be moved. If None, the model will be created on the default device.
    :param pretrained: Currently ignored: loading of pretrained weights is disabled in this file.
    """
    def __init__(
        self,
        in_channels=None,
        device=None,
        pretrained=True,
    ):
        super().__init__()

        if in_channels is None:  # avoid a shared mutable default
            in_channels = [1, 2, 3]
        self.separate_head = isinstance(in_channels, (list, tuple))
        # Every head/tail/conditioning module iterates over a list of channel counts, so a single
        # int is normalised to a one-element list (it used to raise NameError below).
        channels = list(in_channels) if self.separate_head else [int(in_channels)]

        nc = [64, 128, 256, 512]  # number of channels in the network
        self.in_channels = channels
        self.fact_realign = torch.nn.Parameter(torch.tensor([1.0], device=device))

        # base_conditioning appends a noise-level map and a gamma map to the image, hence "+ 2".
        self.m_head = InHead([c + 2 for c in channels], nc[0])

        self.m_down1 = BaseEncBlock(nc[0], nc[0], img_channels=channels, decode_upscale=1)
        self.m_down2 = BaseEncBlock(nc[1], nc[1], img_channels=channels, decode_upscale=2)
        self.m_down3 = BaseEncBlock(nc[2], nc[2], img_channels=channels, decode_upscale=4)
        self.m_body = BaseEncBlock(nc[3], nc[3], img_channels=channels, decode_upscale=8)
        self.m_up3 = BaseEncBlock(nc[2], nc[2], img_channels=channels, decode_upscale=4)
        self.m_up2 = BaseEncBlock(nc[1], nc[1], img_channels=channels, decode_upscale=2)
        self.m_up1 = BaseEncBlock(nc[0], nc[0], img_channels=channels, decode_upscale=1)

        self.pool1 = downsample_strideconv(nc[0], nc[1], bias=False, mode="2")
        self.pool2 = downsample_strideconv(nc[1], nc[2], bias=False, mode="2")
        self.pool3 = downsample_strideconv(nc[2], nc[3], bias=False, mode="2")
        self.up3 = upsample_convtranspose(nc[3], nc[2], bias=False, mode="2")
        self.up2 = upsample_convtranspose(nc[2], nc[1], bias=False, mode="2")
        self.up1 = upsample_convtranspose(nc[1], nc[0], bias=False, mode="2")

        self.m_tail = OutTail(nc[0], channels)

        # NOTE: loading pretrained weights (Hugging Face repo "mterris/ram", file "ram.pth.tar")
        # is disabled, so `pretrained` has no effect for now.

        if device is not None:
            self.to(device)

    def constant2map(self, value, x):
        r"""
        Converts a constant value to a map of shape ``(B, 1, H, W)`` matching the input tensor x.

        Scalars, 0-dim tensors and size-1 tensors are broadcast over the batch. A tensor with
        ``B`` elements gives one value per batch entry. The map is created on ``x``'s device and
        with ``x``'s dtype, because it is concatenated with ``x``.

        :params float value: constant value (float or tensor)
        :params torch.Tensor x: input tensor of shape (B, C, H, W)
        """
        ones = torch.ones((x.size(0), 1, x.size(2), x.size(3)), device=x.device, dtype=x.dtype)
        if isinstance(value, torch.Tensor):
            value = value.to(device=x.device, dtype=x.dtype).reshape(-1, 1, 1, 1)
        return ones * value

    def base_conditioning(self, x, sigma, gamma):
        r"""Concatenate the noise-level map (sigma) and the gamma map to x along channels."""
        noise_level_map = self.constant2map(sigma, x)
        gamma_map = self.constant2map(gamma, x)
        return torch.cat((x, noise_level_map, gamma_map), 1)

    def realign_input(self, x, physics, y):
        r"""
        Realign the input x based on the measurements y and the physics model.

        Applies ``physics.prox_l2`` to ``x``. The step ``gamma`` is data-dependent: it grows with the
        per-sample mean absolute value of ``y`` divided by the noise level, and with the squared
        ``factor`` attribute (if any) of the physics or of its wrapped ``base`` physics. It is also
        scaled by the learnable ``fact_realign``.

        :params torch.Tensor x: Input tensor
        :params deepinv.physics.Physics physics: Physics model
        :params torch.Tensor y: Measurements
        """
        if hasattr(physics, "factor"):
            f = physics.factor
        elif hasattr(physics, "base") and hasattr(physics.base, "factor"):
            f = physics.base.factor
        elif hasattr(physics, "base") and hasattr(physics.base, "base") and hasattr(physics.base.base, "factor"):
            f = physics.base.base.factor
        else:
            f = 1.0

        # The deepest wrapped noise model with a `sigma` wins.
        sigma = 1e-6  # default value
        if hasattr(physics.noise_model, 'sigma'):
            sigma = physics.noise_model.sigma
        if hasattr(physics, 'base') and hasattr(physics.base, 'noise_model') and hasattr(physics.base.noise_model, 'sigma'):
            sigma = physics.base.noise_model.sigma
        if hasattr(physics, 'base') and hasattr(physics.base, 'base') and hasattr(physics.base.base, 'noise_model') and hasattr(physics.base.base.noise_model, 'sigma'):
            sigma = physics.base.base.noise_model.sigma

        if isinstance(y, TensorList):
            num = (y[0].reshape(y[0].shape[0], -1).abs().mean(1))
        else:
            num = (y.reshape(y.shape[0], -1).abs().mean(1))

        snr = num / (sigma + 1e-4)  # SNR equivariant
        gamma = 1 / (1e-4 + 1 / (snr * f **2 ))  # TODO: check square-root / mean / check if we need to add a factor in front ?
        gamma = gamma[(...,) + (None,) * (x.dim() - 1)]  # (B,) -> (B, 1, ..., 1) to broadcast against x
        model_input = physics.prox_l2(x, y, gamma=gamma * self.fact_realign)

        return model_input

    def forward_unet(self, x0, sigma=None, gamma=None, physics=None, y=None):
        r"""
        Forward pass of the UNet model.

        :params torch.Tensor x0: init image
        :params float sigma: Gaussian noise level
        :params float gamma: Poisson noise gain
        :params deepinv.physics.Physics physics: physics measurement operator
        :params torch.Tensor y: measurements
        :raises ValueError: if the number of channels of ``x0`` has no matching head.
        """
        img_channels = x0.shape[1]
        physics = MultiScaleLinearPhysics(physics, x0.shape[-3:], device=x0.device)

        if img_channels not in self.in_channels:
            raise ValueError(f"Input image has {img_channels} channels, but the network only have heads for {self.in_channels} channels.")

        if y is not None:
            x0 = self.realign_input(x0, physics, y)

        x0 = self.base_conditioning(x0, sigma, gamma)

        x1 = self.m_head(x0)

        x1_ = self.m_down1(x1, physics=physics, y=y, img_channels=img_channels, scale=0)
        x2 = self.pool1(x1_)

        x3_ = self.m_down2(x2, physics=physics, y=y, img_channels=img_channels, scale=1)
        x3 = self.pool2(x3_)

        x4_ = self.m_down3(x3, physics=physics, y=y, img_channels=img_channels, scale=2)
        x4 = self.pool3(x4_)

        x = self.m_body(x4, physics=physics, y=y, img_channels=img_channels, scale=3)

        x = self.up3(x + x4)
        x = self.m_up3(x, physics=physics, y=y, img_channels=img_channels, scale=2)

        x = self.up2(x + x3)
        x = self.m_up2(x, physics=physics, y=y, img_channels=img_channels, scale=1)

        x = self.up1(x + x2)
        x = self.m_up1(x, physics=physics, y=y, img_channels=img_channels, scale=0)

        x = self.m_tail(x + x1, img_channels)

        return x

    def forward(self, y=None, physics=None, sigma=None, gamma=0):
        r"""
        Reconstructs a signal estimate from measurements y.

        :param torch.tensor y: measurements
        :param deepinv.physics.Physics physics: forward operator (defaults to noiseless denoising)
        :param sigma: noise level used for conditioning. If None, it is read from
            ``physics.noise_model.sigma``, falling back to ``1e-3``.
        :param gamma: Poisson gain used for conditioning (None is treated as 0).
        """
        if physics is None:
            physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)

        # base_conditioning needs numeric values: a None sigma/gamma used to crash in constant2map.
        if sigma is None:
            sigma = getattr(getattr(physics, "noise_model", None), "sigma", None)
            if sigma is None:
                sigma = 1e-3
        if gamma is None:
            gamma = 0

        # NOTE(review): the measurements are used directly as the initial image, which assumes y
        # lives in image space (same shape as x); the A_adjoint/padding-based init is disabled.
        x_in = y

        return self.forward_unet(x_in, sigma=sigma, gamma=gamma, physics=physics, y=y)



### --------------- MODEL ---------------
class BaseEncBlock(nn.Module):
    r"""
    Sequence of ``nb`` physics-conditioned ``ResBlock``s applied one after the other.

    :param in_channels: Feature channels entering each residual block.
    :param out_channels: Feature channels leaving each residual block.
    :param bias: Whether the residual convolutions use a bias.
    :param nb: Number of residual blocks.
    :param img_channels: Supported image channel counts, forwarded to each block.
    :param decode_upscale: Forwarded to each block's measurement conditioning.
    """
    def __init__(self, in_channels, out_channels, bias=False, nb=4, img_channels=None, decode_upscale=None):
        super().__init__()
        blocks = []
        for _ in range(nb):
            blocks.append(
                ResBlock(
                    in_channels,
                    out_channels,
                    bias=bias,
                    img_channels=img_channels,
                    decode_upscale=decode_upscale,
                )
            )
        self.enc = nn.ModuleList(blocks)

    def forward(self, x, physics=None, y=None, img_channels=None, scale=0):
        features = x
        for block in self.enc:
            features = block(features, physics=physics, y=y, img_channels=img_channels, scale=scale)
        return features


def krylov_embeddings(y, p, factor, v=None, N=4, x_init=None):
    r"""
    Stack the first ``N`` Krylov-type iterates of the rescaled normal operator along channels.

    With ``AtA(u) = factor**2 * p.A_adjoint(p.A(u))`` and ``x_0 = x_init`` (or ``p.A_adjoint(y)``
    when ``x_init`` is None), the iterates are ``x_{k+1} = AtA(x_k) - v``. The function returns
    ``torch.cat([x_0, ..., x_{N-1}], dim=1)``. For ``N <= 1`` only ``x_0`` is returned.

    :params torch.Tensor y: Input measurements.
    :params p: An object with A and A_adjoint methods (linear operator).
    :params float factor: Scaling factor; the normal operator is multiplied by ``factor**2``.
    :params torch.Tensor v: Values subtracted at every iteration. Defaults to None (nothing subtracted).
    :params int N: Number of Krylov iterates to stack. Defaults to 4.
    :params torch.Tensor x_init: Initial iterate. Defaults to None (uses ``p.A_adjoint(y)``).
    """
    x = p.A_adjoint(y) if x_init is None else x_init
    norm = factor ** 2

    iterates = [x]
    for _ in range(N - 1):
        nxt = p.A_adjoint(p.A(iterates[-1])) * norm
        if v is not None:
            nxt = nxt - v
        iterates.append(nxt)

    # A single concatenation instead of re-concatenating the growing output on every iteration.
    return torch.cat(iterates, dim=1)


class MeasCondBlock(nn.Module):
    r"""
    Measurement conditioning block for the RAM model.

    The feature map is decoded into ``c_mult`` candidate images. Krylov embeddings of the
    measurements and of each candidate image are then computed through ``physics`` at the current
    scale, concatenated, encoded back to ``out_channels`` feature channels and passed through a ReLU.

    :param out_channels: Number of output channels.
    :param img_channels: Number of input channels. If a list is provided, the model will have separate heads for each channel.
    :param decode_upscale: Upscaling factor for the decoding convolution (required; not otherwise used here).
    :param N: Number of Krylov iterations.
    :param depth_encoding: Depth of the encoding convolution.
    :param c_mult: Multiplier for the number of channels.
    :raises ValueError: if ``img_channels`` or ``decode_upscale`` is None.
    """
    def __init__(self, out_channels=64, img_channels=None, decode_upscale=None, N=4, depth_encoding=1, c_mult=1):
        super(MeasCondBlock, self).__init__()

        # Explicit checks rather than ``assert`` so validation is not stripped under ``python -O``.
        if img_channels is None:
            raise ValueError("img_channels should be provided")
        if decode_upscale is None:
            raise ValueError("decode_upscale should be provided")

        self.separate_head = isinstance(img_channels, list)

        self.N = N
        self.c_mult = c_mult
        self.relu_encoding = nn.ReLU(inplace=False)
        # Features -> c_mult candidate images with img_channels channels each.
        self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult)
        # Inputs: N Krylov iterates of y plus N iterates for each of the c_mult decoded images.
        self.encoding_conv = Heads(img_channels, out_channels,  depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult*N, c_add=N, relu_in=False, skip_in=True)

        # NOTE(review): these gains are not used in forward(); presumably kept so existing
        # checkpoints still load. Verify before removing.
        self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.gain_gradx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.gain_grady = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.gain_pinvx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.gain_pinvy = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)

    def forward(self, x, y, physics, img_channels=None, scale=1):
        physics.set_scale(scale)  # all physics calls below operate at this scale
        dec = self.decoding_conv(x, img_channels)
        factor = 2 ** scale
        embeddings = [krylov_embeddings(y, physics, factor, N=self.N)]
        # One embedding per decoded image (always at least the first one).
        for c in range(max(self.c_mult, 1)):
            start = img_channels * c
            embeddings.append(
                krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, start:start + img_channels])
            )
        cond = self.encoding_conv(torch.cat(embeddings, dim=1))
        return self.relu_encoding(cond)


class ResBlock(nn.Module):
    r"""
    Convolutional residual block with an additional measurement-conditioning branch.

    ``forward`` returns ``x + conv2(relu(conv1(x))) + gain * PhysicsBlock(relu(conv1(x)))``.

    :param in_channels: Number of input channels.
    :param out_channels: Number of output channels.
    :param kernel_size: Size of the convolution kernel.
    :param stride: Stride of the convolution.
    :param padding: Padding for the convolution.
    :param bias: Whether to use bias in the convolution.
    :param img_channels: Number of input channels. If a list is provided, the model will have separate heads for each channel.
    :param decode_upscale: Upscaling factor for the decoding convolution.
    :param head: Whether this is a head block.
    :param tail: Whether this is a tail block.
    :param N: Number of Krylov iterations.
    :param c_mult: Multiplier for the number of channels.
    :param depth_encoding: Depth of the encoding convolution.
    """
    def __init__(
        self,
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        img_channels=None,
        decode_upscale=None,
        head=False,
        tail=False,
        N=2,
        c_mult=2,
        depth_encoding=2,
    ):
        super().__init__()

        is_body = not (head or tail)
        if is_body:
            assert in_channels == out_channels, "Only support in_channels==out_channels."
        self.separate_head = isinstance(img_channels, list)
        self.is_head = head
        self.is_tail = tail

        if head:
            self.head = InHead(img_channels, out_channels, input_layer=True)

        # NOTE(review): conv1/nl/conv2 only exist for body blocks, but forward() always uses them.
        if is_body:
            shared = (kernel_size, stride, padding, bias, "C")
            self.conv1 = conv(in_channels, out_channels, *shared)
            self.nl = nn.ReLU(inplace=True)
            self.conv2 = conv(out_channels, out_channels, *shared)

        self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.PhysicsBlock = MeasCondBlock(
            out_channels=out_channels,
            c_mult=c_mult,
            img_channels=img_channels,
            decode_upscale=decode_upscale,
            N=N,
            depth_encoding=depth_encoding,
        )

    def forward(self, x, physics=None, y=None, img_channels=None, scale=0):
        hidden = self.nl(self.conv1(x))
        residual = self.conv2(hidden)
        conditioning = self.PhysicsBlock(hidden, y, physics, img_channels=img_channels, scale=scale)
        return x + residual + self.gain * conditioning


class InHead(torch.nn.Module):
    r"""
    One input convolution per supported channel count. The one to use is chosen from the input.

    :param in_channels_list: Supported input channel counts; ``conv{i}`` handles entry ``i``.
    :param out_channels: Output feature channels of every head.
    :param mode: Passed to ``AffineConv2d`` (``""`` means a plain convolution).
    :param bias: Whether the convolutions use a bias.
    :param input_layer: If True, the input carries one extra channel that is not counted when
        selecting the head.
    """
    def __init__(self, in_channels_list, out_channels, mode="", bias=False, input_layer=False):
        super().__init__()
        self.in_channels_list = in_channels_list
        self.input_layer = input_layer
        for idx, n_in in enumerate(in_channels_list):
            layer = AffineConv2d(
                in_channels=n_in,
                out_channels=out_channels,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{idx}", layer)

    def forward(self, x):
        n_channels = x.size(1)
        if self.input_layer:
            n_channels -= 1
        head = getattr(self, f"conv{self.in_channels_list.index(n_channels)}")
        return head(x)

class OutTail(torch.nn.Module):
    r"""
    One output convolution per supported channel count, selected by the requested output channels.

    :param in_channels: Feature channels entering every tail.
    :param out_channels_list: Supported output channel counts; ``conv{i}`` produces entry ``i``.
    :param mode: Passed to ``AffineConv2d`` (``""`` means a plain convolution).
    :param bias: Whether the convolutions use a bias.
    """
    def __init__(self, in_channels, out_channels_list, mode="", bias=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels_list = out_channels_list
        for idx, n_out in enumerate(out_channels_list):
            layer = AffineConv2d(
                in_channels=in_channels,
                out_channels=n_out,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{idx}", layer)

    def forward(self, x, out_channels):
        tail = getattr(self, f"conv{self.out_channels_list.index(out_channels)}")
        return tail(x)

class Heads(torch.nn.Module):
    r"""
    One ``HeadBlock`` per supported channel count, with optional downscaling before the head.

    For each base channel count ``c``, the head input has ``c * (c_mult + c_add)`` channels.

    :param in_channels_list: Base channel counts.
    :param out_channels: Output channels of every head.
    :param depth: Depth of each ``HeadBlock``.
    :param scale: Downscaling factor applied before the head (1 means none).
    :param bias: Whether the head convolutions use a bias.
    :param mode: ``"bilinear"`` (interpolation) or ``""`` (strided conv + ReLU) for downscaling.
    :param c_mult: Channel multiplier.
    :param c_add: Additional channel multiplier.
    :param relu_in: Passed to ``HeadBlock``.
    :param skip_in: Passed to ``HeadBlock``.
    :raises ValueError: if ``scale != 1`` and ``mode`` is neither ``"bilinear"`` nor ``""``.
    """
    def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, c_add=0, relu_in=False, skip_in=False):
        super(Heads, self).__init__()
        self.in_channels_list = [c * (c_mult + c_add) for c in in_channels_list]
        self.scale = scale
        self.mode = mode
        # Fail at construction rather than with an AttributeError in forward().
        if self.scale != 1 and self.mode not in ("", "bilinear"):
            raise ValueError(f"Unsupported mode {mode!r} for scale {scale}; use 'bilinear' or ''.")
        for i, in_channels in enumerate(self.in_channels_list):
            setattr(self, f"head{i}", HeadBlock(in_channels, out_channels, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in))

        if self.mode == "":
            self.nl = torch.nn.ReLU(inplace=False)
            if self.scale != 1:
                # Built on the expanded channel counts: forward() applies them to the full
                # c * (c_mult + c_add)-channel input, not to the base channel count.
                for i, in_channels in enumerate(self.in_channels_list):
                    setattr(self, f"down{i}", downsample_strideconv(in_channels, in_channels, bias=False, mode=str(self.scale)))

    def forward(self, x):
        # find index of the head matching the input channel count
        i = self.in_channels_list.index(x.size(1))

        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(x, scale_factor=1/self.scale, mode='bilinear', align_corners=False)
            else:
                x = getattr(self, f"down{i}")(x)
                x = self.nl(x)

        return getattr(self, f"head{i}")(x)

class Tails(torch.nn.Module):
    r"""
    One ``HeadBlock`` per supported output channel count, with optional upscaling after it.

    Tail ``i`` maps ``in_channels`` features to ``out_channels_list[i] * c_mult`` channels.

    :param in_channels: Input feature channels.
    :param out_channels_list: Supported output channel counts.
    :param depth: Depth of each ``HeadBlock``.
    :param scale: Upscaling factor applied after the tail (1 means none).
    :param bias: Whether the convolutions use a bias.
    :param mode: ``"bilinear"`` (interpolation) or ``""`` (transposed conv) for upscaling.
    :param c_mult: Output channel multiplier.
    :param relu_in: Passed to ``HeadBlock``.
    :param skip_in: Passed to ``HeadBlock``.
    """
    def __init__(self, in_channels, out_channels_list, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, relu_in=False, skip_in=False):
        super().__init__()
        self.out_channels_list = out_channels_list
        self.scale = scale
        for idx, n_out in enumerate(out_channels_list):
            block = HeadBlock(in_channels, n_out * c_mult, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in)
            setattr(self, f"tail{idx}", block)

        self.mode = mode
        if self.mode != "":
            return
        self.nl = torch.nn.ReLU(inplace=False)
        if self.scale == 1:
            return
        for idx, n_out in enumerate(out_channels_list):
            width = n_out * c_mult
            setattr(self, f"up{idx}", upsample_convtranspose(width, width, bias=bias, mode=str(self.scale)))

    def forward(self, x, out_channels):
        idx = self.out_channels_list.index(out_channels)
        decoded = getattr(self, f"tail{idx}")(x)
        if self.scale == 1:
            return decoded
        if self.mode == "bilinear":
            return torch.nn.functional.interpolate(decoded, scale_factor=self.scale, mode='bilinear', align_corners=False)
        return getattr(self, f"up{idx}")(decoded)

class HeadBlock(torch.nn.Module):
    r"""
    Small convolutional head: an input conv followed by ``depth - 1`` residual conv stages.

    The input conv maps ``in_channels`` to ``out_channels`` when ``depth < 2`` and keeps
    ``in_channels`` otherwise. Only the last residual stage switches to ``out_channels``.
    With ``skip_in`` a 1x1 skip conv is added to the input conv (after an optional ReLU when
    ``relu_in`` is also set). ``relu_in`` alone has no effect.

    :param in_channels: Input channels.
    :param out_channels: Output channels.
    :param kernel_size: Kernel size of the 'same'-padded convolutions.
    :param bias: Whether the KxK convolutions use a bias.
    :param depth: Number of stages (input conv + ``depth - 1`` residual stages).
    :param relu_in: ReLU after the input conv (only together with ``skip_in``).
    :param skip_in: Add a 1x1 skip conv around the input conv.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, bias=True, depth=2, relu_in=False, skip_in=False):
        super().__init__()

        pad = kernel_size // 2
        first_out = in_channels if depth >= 2 else out_channels

        self.convin = torch.nn.Conv2d(in_channels, first_out, kernel_size, padding=pad, bias=bias)
        self.zero_conv_skip = torch.nn.Conv2d(in_channels, first_out, 1, bias=False)
        self.depth = depth
        self.nl_1 = torch.nn.ReLU(inplace=False)
        self.nl_2 = torch.nn.ReLU(inplace=False)
        self.relu_in = relu_in
        self.skip_in = skip_in

        last_stage = depth - 2
        for stage in range(depth - 1):
            stage_out = out_channels if stage == last_stage else in_channels
            setattr(self, f"conv1{stage}", torch.nn.Conv2d(in_channels, in_channels, kernel_size, padding=pad, bias=bias))
            setattr(self, f"conv2{stage}", torch.nn.Conv2d(in_channels, stage_out, kernel_size, padding=pad, bias=bias))
            setattr(self, f"skipconv{stage}", torch.nn.Conv2d(in_channels, stage_out, 1, bias=False))

    def forward(self, x):
        h = self.convin(x)
        if self.skip_in:
            if self.relu_in:
                h = self.nl_1(h)
            h = h + self.zero_conv_skip(x)

        for stage in range(self.depth - 1):
            inner = self.nl_2(getattr(self, f"conv1{stage}")(h))
            h = getattr(self, f"conv2{stage}")(inner) + getattr(self, f"skipconv{stage}")(h)

        return h


# --------------------------------------------------------------------------------------
class AffineConv2d(nn.Conv2d):
    r"""
    ``nn.Conv2d`` variant whose kernels, in ``mode="affine"``, are re-parametrised to sum to 1.

    In affine mode the bias is disabled and each output kernel is ``roll(w) - w + 1/numel``. The
    result encodes an affine combination of its inputs, so a constant input maps to the same
    constant. With ``blind=False`` the last input channel keeps its raw weights. Any other
    ``mode`` behaves exactly like ``nn.Conv2d``.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        mode="affine",
        bias=False,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        padding_mode="circular",
        blind=True,
    ):
        use_bias = False if mode == "affine" else bias  # f(a*x + 1) = a*f(x) + 1 needs no bias
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            bias=use_bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
        )
        self.blind = blind
        self.mode = mode

    def affine(self, w):
        """returns new kernels that encode affine combinations"""
        rolled = w.view(self.out_channels, -1).roll(1, 1).view(w.size())
        return rolled - w + 1 / w[0, ...].numel()

    def forward(self, x):
        if self.mode != "affine":
            return super().forward(x)

        if self.blind:
            kernel = self.affine(self.weight)
        else:
            kernel = torch.cat(
                (self.affine(self.weight[:, :-1, :, :]), self.weight[:, -1:, :, :]),
                dim=1,
            )
        # Conv2d stores (pad_h, pad_w); F.pad expects (left, right, top, bottom).
        pad = tuple(p for p in reversed(self.padding) for _ in range(2))
        pad_mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
        return F.conv2d(
            F.pad(x, pad, mode=pad_mode),
            kernel,
            stride=self.stride,
            dilation=self.dilation,
            groups=self.groups,
        )

"""
Functional blocks below

Parts of code borrowed from
https://github.com/cszn/DPIR/tree/master/models
https://github.com/xinntao/BasicSR
"""
from collections import OrderedDict
import torch
import torch.nn as nn


"""
# --------------------------------------------
# Advanced nn.Sequential
# https://github.com/xinntao/BasicSR
# --------------------------------------------
"""


def sequential(*args):
    """Advanced nn.Sequential.

    A single argument is returned unchanged (an ``OrderedDict`` is rejected). Otherwise the
    children of any ``nn.Sequential`` argument are spliced in, other ``nn.Module`` arguments are
    kept, and non-module arguments are ignored.

    Args:
        nn.Sequential, nn.Module
    Returns:
        nn.Sequential
    """
    if len(args) == 1:
        (only,) = args
        if isinstance(only, OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return only  # No sequential is needed.

    flattened = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flattened.extend(item.children())
        elif isinstance(item, nn.Module):
            flattened.append(item)
    return nn.Sequential(*flattened)


def conv(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="CBR",
):
    """Build a stack of layers from a mode string, one letter per layer.

    ``"C"``: ``nn.Conv2d``; ``"T"``: ``nn.ConvTranspose2d``; ``"B"``: ``nn.BatchNorm2d``
    (DPIR settings: momentum 0.9, eps 1e-4); ``"R"``: in-place ``nn.ReLU``.
    The layers are combined with :func:`sequential`.

    :raises NotImplementedError: for any other letter.
    """
    layers = []
    for t in mode:
        if t == "C":
            layers.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "T":
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "B":
            # Needed by the default mode "CBR" and the "2BR"-style up/downsampling modes.
            layers.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
        elif t == "R":
            layers.append(nn.ReLU(inplace=True))
        else:
            # The previous message lacked a "{}" placeholder, so the offending letter was never shown.
            raise NotImplementedError(f"Undefined type: {t}")
    return sequential(*layers)


# --------------------------------------------
# convTranspose (+ relu)
# --------------------------------------------
def upsample_convtranspose(
    in_channels=64,
    out_channels=3,
    padding=0,
    bias=True,
    mode="2R",
):
    """Build a transposed convolution that upsamples by the factor given as ``mode[0]``.

    ``mode[0]`` (one of ``"2"``, ``"3"``, ``"4"``, ``"8"``) is used as both kernel size and stride.
    Every occurrence of that digit in ``mode`` is then replaced by ``"T"``, and the result is
    passed to :func:`conv` as its layer specification (e.g. ``"2R"`` -> ``"TR"``).
    """
    valid_factors = ["2", "3", "4", "8"]
    assert len(mode) < 4 and mode[0] in valid_factors, "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    layer_spec = mode.replace(mode[0], "T")
    return conv(in_channels, out_channels, factor, factor, padding, bias, layer_spec)


def downsample_strideconv(
    in_channels=64,
    out_channels=64,
    padding=0,
    bias=True,
    mode="2R",
):
    """Build a strided convolution that downsamples by the factor given as ``mode[0]``.

    ``mode[0]`` (one of ``"2"``, ``"3"``, ``"4"``, ``"8"``) is used as both kernel size and stride.
    Every occurrence of that digit in ``mode`` is then replaced by ``"C"``, and the result is
    passed to :func:`conv` as its layer specification (e.g. ``"2R"`` -> ``"CR"``).
    """
    valid_factors = ["2", "3", "4", "8"]
    assert len(mode) < 4 and mode[0] in valid_factors, "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    layer_spec = mode.replace(mode[0], "C")
    return conv(in_channels, out_channels, factor, factor, padding, bias, layer_spec)



class Upsampling(Downsampling):
    r"""
    Upsampling operator obtained by swapping ``A`` and ``A_adjoint`` of deepinv's ``Downsampling``.

    ``A`` is ``Downsampling.A_adjoint`` and ``A_adjoint`` is ``Downsampling.A``.
    """
    def A(self, x, **kwargs):
        # forward operator = adjoint of the parent downsampling operator
        return super().A_adjoint(x, **kwargs)

    def A_adjoint(self, y, **kwargs):
        # adjoint operator = forward of the parent downsampling operator
        return super().A(y, **kwargs)

    def prox_l2(self, z, y, gamma, **kwargs):
        # Delegates unchanged to Downsampling.prox_l2.
        # NOTE(review): whether that prox matches the swapped operator depends on how
        # Downsampling.prox_l2 is implemented (generic A/A_adjoint-based vs. closed form for the
        # downsampling operator). Verify before relying on it; it is not called in this file.
        return super().prox_l2(z, y, gamma, **kwargs)


class MultiScalePhysics(Physics):
    r"""
    Wraps a physics operator so that it can be applied to coarser versions of the image.

    At scale ``s > 0`` the image is first upsampled by ``scales[s - 1]`` with an ``Upsampling``
    operator, then passed to the wrapped physics. At scale 0 the wrapped physics is used directly.
    The current scale is stateful: ``set_scale`` / passing ``scale=`` changes it for later calls.

    :param physics: Wrapped physics (its ``noise_model`` is reused).
    :param img_shape: Full-resolution image shape, passed to the upsampling operators.
    :param filter: Filter name for the upsampling operators.
    :param scales: Upsampling factors for scales 1, 2, ... Defaults to ``[2, 4, 8]``.
    :param device: Device for the upsampling operators.
    """
    def __init__(self, physics, img_shape, filter="sinc", scales=None, device='cpu', **kwargs):
        super().__init__(noise_model=physics.noise_model, **kwargs)
        if scales is None:  # avoid a shared mutable default stored on the instance
            scales = [2, 4, 8]
        self.base = physics
        self.scales = scales
        self.img_shape = img_shape
        # ModuleList (not a plain list) registers the operators as submodules, so module-level
        # operations such as .to() are propagated to them.
        self.Upsamplings = nn.ModuleList(
            [Upsampling(img_size=img_shape, filter=filter, factor=factor, device=device) for factor in scales]
        )
        self.scale = 0

    def set_scale(self, scale):
        r"""Set the current scale; ``None`` keeps the current one."""
        if scale is not None:
            self.scale = scale

    def A(self, x, scale=None, **kwargs):
        self.set_scale(scale)
        if self.scale == 0:
            return self.base.A(x, **kwargs)
        return self.base.A(self.Upsamplings[self.scale - 1].A(x), **kwargs)

    def downsample(self, x, scale=None):
        self.set_scale(scale)
        if self.scale == 0:
            return x
        return self.Upsamplings[self.scale - 1].A_adjoint(x)

    def upsample(self, x, scale=None):
        self.set_scale(scale)
        if self.scale == 0:
            return x
        return self.Upsamplings[self.scale - 1].A(x)

    def update_parameters(self, **kwargs):
        self.base.update_parameters(**kwargs)


class MultiScaleLinearPhysics(MultiScalePhysics, LinearPhysics):
    r"""
    Linear version of ``MultiScalePhysics``, adding the matching adjoint.

    The adjoint of ``A = base.A ∘ Up.A`` is ``Up.A_adjoint ∘ base.A_adjoint``.

    :param scales: Upsampling factors for scales 1, 2, ... Defaults to ``[2, 4, 8]``.
    """
    def __init__(self, physics, img_shape, filter="sinc", scales=None, **kwargs):
        if scales is None:  # avoid a shared mutable default stored on the instance
            scales = [2, 4, 8]
        super().__init__(physics=physics, img_shape=img_shape, filter=filter, scales=scales, **kwargs)

    def A_adjoint(self, y, scale=None, **kwargs):
        self.set_scale(scale)
        x = self.base.A_adjoint(y, **kwargs)
        if self.scale == 0:
            return x
        return self.Upsamplings[self.scale - 1].A_adjoint(x)


class Pad(LinearPhysics):
    r"""
    Wraps a physics so that it acts on an image padded at the top and left.

    ``A`` crops the first ``pad[0]`` rows and ``pad[1]`` columns before applying the wrapped
    operator. ``A_adjoint`` zero-pads them back.

    :param physics: Wrapped linear physics (its ``noise_model`` is reused).
    :param pad: ``(rows, cols)`` of padding.
    """
    def __init__(self, physics, pad):
        super().__init__(noise_model=physics.noise_model)
        self.base = physics
        self.pad = pad

    def A(self, x, **kwargs):
        # Extra keyword arguments (e.g. physics parameters) are forwarded instead of rejected.
        return self.base.A(self.remove_pad(x), **kwargs)

    def A_adjoint(self, y, **kwargs):
        y = self.base.A_adjoint(y, **kwargs)
        # F.pad order is (left, right, top, bottom): pad only left by pad[1] and top by pad[0].
        return torch.nn.functional.pad(y, (self.pad[1], 0, self.pad[0], 0))

    def remove_pad(self, x):
        r"""Crop the padding added by ``A_adjoint``."""
        return x[..., self.pad[0]:, self.pad[1]:]

    def update_parameters(self, **kwargs):
        self.base.update_parameters(**kwargs)