"""Operations for regression thingies."""


import torch
import math

from collections import namedtuple
import copy

# Every operation name that has a factory entry in ``OPS``.
OP_LIST = [
    'grad', 'skip', 'shuffle', 'noise',
    'learnable_grad', 'net',
    'grad_rep', 'learnable_grad_rep',
    'grad_proj', 'learnable_grad_proj_a', 'learnable_grad_proj_b',
    'resnet_a', 'resnet_b',
    'roll', 'zero',
]

# Candidate operation subsets (each entry is a key of ``OPS``); presumably
# consumed by the search / genotype code elsewhere -- confirm at call sites.
PRIMITIVES = ['grad', 'skip', 'roll', 'noise']

LEARNABLE_PRIMITIVES = ['learnable_grad', 'net', 'roll', 'noise']

GOOD_PRIMITIVES = ['grad', 'skip']

GOOD_LEARNABLE_PRIMITIVES = ['learnable_grad', 'net']

GOOD_IMPROVED_LEARNABLE_PRIMITIVES = ['learnable_grad_proj_b', 'net']

LEARNABLE_IMPROVED_PRIMITIVES = ['learnable_grad_proj_b', 'net', 'roll', 'noise']

# Randomizing ops only (roll / additive noise).
BAD_PRIMITIVES = ['roll', 'noise']


# Factory table: op name -> callable(A, sublayers, channels) -> nn.Module.
# The lambdas resolve the class names only when called, which is what lets
# this table refer to classes defined further down in the module.
OPS = dict(
    skip=lambda A, sublayers, channels: Identity(),

    # data-term gradient steps (optionally with a learned regularizer)
    grad=lambda A, sublayers, channels: GradDescentStep(A, sublayers, channels),
    learnable_grad=lambda A, sublayers, channels: RegularizedGradDescentStep(A, sublayers, channels),
    grad_proj=lambda A, sublayers, channels: GradProjectionDescentStep(A, sublayers, channels),
    learnable_grad_proj_a=lambda A, sublayers, channels: RegularizedGradProjectionDescentStepA(A, sublayers, channels),
    learnable_grad_proj_b=lambda A, sublayers, channels: RegularizedGradProjectionDescentStepB(A, sublayers, channels),
    grad_rep=lambda A, sublayers, channels: GradReplicateDescentStep(A, sublayers, channels),
    learnable_grad_rep=lambda A, sublayers, channels: RegularizedGradReplicateDescentStep(A, sublayers, channels),

    # purely learned maps
    net=lambda A, sublayers, channels: DeepStepStraight(A, sublayers, channels),
    resnet_a=lambda A, sublayers, channels: DeepStep(A, sublayers, channels),
    resnet_b=lambda A, sublayers, channels: DeepStepResidual(A, sublayers, channels),

    # randomizing / parameter-free ops
    shuffle=lambda A, sublayers, channels: Identity_shuffle(),
    roll=lambda A, sublayers, channels: Identity_roll(),
    noise=lambda A, sublayers, channels: AddNoise(),

    zero=lambda A, sublayers, channels: Zero(),
)

# Immutable record with a single field ``operations`` (presumably a sequence of OPS keys -- confirm at call sites).
Genotype = namedtuple('Genotype', ['operations'])


def _get_learnable_map(channels, sublayers=2, inner_channels=64, dimension=1):
    """Build a small 3x3-convolutional network mapping ``channels`` -> ``channels``.

    Args:
        channels: number of input and output channels.
        sublayers: number of convolution layers; 1 gives a single conv.
        inner_channels: hidden width for ``dimension != 1``. NOTE: the 1D
            branch ignores this argument and always uses ``4 * channels``.
        dimension: 1 builds Conv1d layers separated by ReLU; any other value
            builds Conv2d + BatchNorm2d + ReLU blocks (DnCNN-style).

    Returns:
        torch.nn.Sequential. Conv2d weights are re-initialized orthogonally
        (gain 0.2) and BatchNorm2d scales uniformly in [0.1, 1.0] with zero
        shift; Conv1d layers keep PyTorch's default initialization.
    """
    if dimension == 1:
        width = 4 * channels  # overrides inner_channels in 1D
        if sublayers == 1:
            layers = [torch.nn.Conv1d(channels, channels, 3, padding=1)]
        else:
            layers = [torch.nn.Conv1d(channels, width, 3, padding=1)]
            for _ in range(sublayers - 2):
                layers.append(torch.nn.ReLU())
                layers.append(torch.nn.Conv1d(width, width, 3, padding=1))
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Conv1d(width, channels, 3, padding=1))
    elif sublayers == 1:
        layers = [torch.nn.Conv2d(channels, channels, 3, padding=1, bias=True)]
    else:
        def conv_bn_relu(c_in):
            # Bias-free conv followed by BatchNorm (which supplies the shift).
            # NOTE(review): PyTorch's BatchNorm momentum=0.9 weights the NEW batch
            # statistics by 0.9 (opposite of Keras' convention) -- confirm intent.
            return [torch.nn.Conv2d(c_in, inner_channels, 3, padding=1, bias=False),
                    torch.nn.BatchNorm2d(inner_channels, momentum=0.9, eps=1e-04, affine=True),
                    torch.nn.ReLU(inplace=True)]

        layers = conv_bn_relu(channels)
        for _ in range(sublayers - 2):
            layers += conv_bn_relu(inner_channels)
        layers.append(torch.nn.Conv2d(inner_channels, channels, 3, padding=1, bias=True))

    net = torch.nn.Sequential(*layers)

    # DnCNN-style initialization (2D layers only).
    for layer in net.modules():
        if isinstance(layer, torch.nn.Conv2d):
            torch.nn.init.orthogonal_(layer.weight.data, gain=0.2)
        elif isinstance(layer, torch.nn.BatchNorm2d):
            torch.nn.init.uniform_(layer.weight.data, 0.1, 1.0)
            torch.nn.init.constant_(layer.bias.data, 0.0)
    return net


class Identity(torch.nn.Module):
    """Pass-through op: hands back ``x`` untouched.

    ``y`` and ``x0`` are accepted only to match the ``forward(x, y, x0)``
    signature used by the other ops in this module.
    """

    def __init__(self, *args, **kwargs):
        """Ignore all construction arguments; the layer has no state."""
        super(Identity, self).__init__()

    def forward(self, x, y, x0):
        """Return the very same tensor object that was passed in."""
        del y, x0  # unused, present for signature compatibility
        return x

    def reset_parameters(self):
        """No parameters to (re)initialize."""
        return None


class Identity_shuffle(torch.nn.Module):
    """Randomly permute all entries of ``x`` across every axis, batch included.

    The output has the same shape and the same multiset of values as ``x``.
    ``y`` and ``x0`` are ignored.
    """

    def __init__(self, *args, **kwargs):
        """Ignore all construction arguments; the layer has no state."""
        super().__init__()

    def forward(self, x, y, x0):
        idx = torch.randperm(x.nelement())
        # reshape (not view) so non-contiguous inputs such as transposed or
        # sliced tensors are accepted instead of raising; the indexed result
        # is freshly allocated, so view() back to the original shape is safe.
        return x.reshape(-1)[idx].view(x.size())

    def reset_parameters(self):
        """No parameters to (re)initialize."""
        pass


class Identity_roll(torch.nn.Module):
    """Cyclically shift ``x`` by a random offset along every axis except axis 0.

    One offset per axis ``1 .. x.dim() - 1`` is drawn uniformly from
    ``[0, size)``. Axis 0 (presumably the batch axis) is never shifted, so every
    sample receives the same offsets. ``y`` and ``x0`` are ignored.
    """

    def __init__(self, *args, **kwargs):
        """Ignore all construction arguments; the layer has no state."""
        super(Identity_roll, self).__init__()

    def forward(self, x, y, x0):
        axes = list(range(1, x.dim()))
        shifts = []
        for size in x.shape[1:]:
            shifts.append(int(torch.randint(0, size, (1,))))
        return torch.roll(x, shifts, axes)

    def reset_parameters(self):
        """No parameters to (re)initialize."""
        return None


class Zero(torch.nn.Module):
    """Annihilate the input: return ``x * 0`` (same shape, dtype and device)."""

    def __init__(self, *args, **kwargs):
        """Ignore all construction arguments; the layer has no state."""
        super(Zero, self).__init__()

    def forward(self, x, *args, **kwargs):
        # Multiplying rather than allocating fresh zeros keeps x in the
        # autograd graph. Note that inf/NaN entries of x yield NaN, not 0.
        return torch.mul(x, 0.0)

    def reset_parameters(self):
        """No parameters to (re)initialize."""
        return None


class AddNoise(torch.nn.Module):
    """Perturb ``x`` with additive i.i.d. Gaussian noise.

    Returns ``x + sigma * eps`` with ``eps`` drawn from a standard normal via
    ``torch.randn_like`` (same shape, dtype and device as ``x``).
    ``y`` and ``x0`` are ignored.
    """

    def __init__(self, *args, sigma=1.0, **kwargs):
        """Store the noise standard deviation; all other arguments are ignored.

        Args:
            sigma: keyword-only standard deviation of the added noise. The
                default 1.0 reproduces the previous fixed behaviour.
        """
        super().__init__()
        self.sigma = sigma

    def forward(self, x, y, x0):
        return x + self.sigma * torch.randn_like(x)

    def reset_parameters(self):
        """No parameters to (re)initialize."""
        pass


class GradDescentStep(torch.nn.Module):
    """One gradient-descent step on the data term, applied to the leading channels.

    Only the first ``A.channels`` channels ``x_c`` of ``x`` are updated:
    ``x_c <- x_c - tau * (A^T(A(x_c + x0) - y) + net(x_c))``. All further
    channels are copied through unchanged. Here ``net`` is ``Zero()``, so the
    extra term vanishes for finite inputs; subclasses replace it with a learned map.
    """

    def __init__(self, A, sublayers=1, channels=1, data='euclidean', tau=0.01):
        """Store the operator and fix the step size to ``1 / A.norm()``.

        ``sublayers``, ``channels``, ``data`` and ``tau`` are accepted for
        signature compatibility but are not used by this class.
        """
        super(GradDescentStep, self).__init__()
        # Operator handle: called as A(.) and A.T(.), also provides .norm() and .channels.
        self.A = A
        self.tau = 1 / self.A.norm()
        self.net = Zero()

    def forward(self, x, y, x0):
        c = self.A.channels
        head = x[:, :c, ...].clone()
        step = self.A.T(self.A(head + x0) - y) + self.net(head)
        out = x.clone()  # never modify the caller's tensor in place
        out[:, :c, ...] = head - self.tau * step
        return out

    def reset_parameters(self):
        """Nothing to initialize."""
        return None


class RegularizedGradDescentStep(GradDescentStep):
    """Gradient step on the data term plus a learned regularizer gradient.

    Computes ``x_c - tau * (A^T(A(x_c + x0) - y) + net(x_c))`` on the first
    ``A.channels`` channels ``x_c``, where ``net`` is a small trainable CNN
    meant to represent the gradient of a learned regularizer R_theta.
    """

    def __init__(self, A, sublayers=1, channels=1, data='euclidean', tau=0.1):
        """Replace the parent's zero term by a ``sublayers``-layer learnable map.

        The map acts on ``A.channels`` channels. ``channels``, ``data`` and
        ``tau`` are only forwarded to the parent constructor.
        """
        super(RegularizedGradDescentStep, self).__init__(A, sublayers, channels, data, tau)
        learned = _get_learnable_map(A.channels, sublayers, dimension=A.dimension)
        self.net = learned


class GradProjectionDescentStep(torch.nn.Module):
    """Gradient step on the data term inside a learned 1x1-projected subspace.

    ``x`` (``channels`` channels) is mapped to ``A.channels`` channels by a
    bias-free 1x1 convolution giving ``z``. The step
    ``z - tau * (A^T(A(z + x0) - y) + net(z))`` is taken there, and the result
    is mapped back to ``channels`` channels by a second bias-free 1x1 convolution.
    """

    def __init__(self, A, sublayers=1, channels=1, data='euclidean', tau=0.01):
        """Set up the in/out projections; the step size is fixed to ``1 / A.norm()``.

        ``sublayers``, ``data`` and ``tau`` are accepted for signature
        compatibility but are not used by this class.
        """
        super(GradProjectionDescentStep, self).__init__()
        # Operator handle: called as A(.) and A.T(.), also provides .norm(), .channels, .dimension.
        self.A = A
        self.tau = 1 / self.A.norm()
        self.net = Zero()

        conv_cls = torch.nn.Conv1d if A.dimension != 2 else torch.nn.Conv2d
        # TODO(review): the projections are bias-free; the original author left "bias??" open.
        self.projection_in = conv_cls(channels, A.channels, kernel_size=1, bias=False)
        self.projection_out = conv_cls(A.channels, channels, kernel_size=1, bias=False)

    def forward(self, x, y, x0):
        z = self.projection_in(x)
        data_grad = self.A.T(self.A(z + x0) - y)
        z_step = z - self.tau * (data_grad + self.net(z))
        return self.projection_out(z_step)

    def reset_parameters(self):
        """Nothing to initialize beyond the default conv initialization."""
        return None


class RegularizedGradProjectionDescentStepA(GradProjectionDescentStep):
    """Projected gradient step with a learned regularizer in the reduced space.

    Identical to the parent except that ``net`` is a trainable map on
    ``A.channels`` channels applied to the projected iterate, i.e. the step is
    ``z - tau * (A^T(A(z + x0) - y) + net(z))`` with ``z = projection_in(x)``.
    """

    def __init__(self, A, sublayers=1, channels=1, data='euclidean', tau=0.1):
        """Swap the parent's zero term for a ``sublayers``-layer learnable map."""
        super(RegularizedGradProjectionDescentStepA, self).__init__(A, sublayers, channels, data, tau)
        learned = _get_learnable_map(A.channels, sublayers, dimension=A.dimension)
        self.net = learned


class RegularizedGradProjectionDescentStepB(GradProjectionDescentStep):
    """Projected data-term step plus a learned regularizer on the full input.

    Unlike variant A, ``net`` acts directly on all ``channels`` channels of
    ``x``. The output is
    ``projection_out(z - tau * A^T(A(z + x0) - y)) - tau * net(x)`` with
    ``z = projection_in(x)``.
    """

    def __init__(self, A, sublayers=1, channels=1, data='euclidean', tau=0.1):
        """Swap the parent's zero term for a learnable map on ``channels`` channels."""
        super(RegularizedGradProjectionDescentStepB, self).__init__(A, sublayers, channels, data, tau)
        learned = _get_learnable_map(channels, sublayers, dimension=A.dimension)
        self.net = learned

    def forward(self, x, y, x0):
        z = self.projection_in(x)
        z_step = z - self.tau * self.A.T(self.A(z + x0) - y)
        regularizer = self.net(x)  # evaluated on the unprojected input
        return self.projection_out(z_step) - self.tau * regularizer


class GradReplicateDescentStep(torch.nn.Module):
    """Take a data-term gradient step on all ``channels`` channels of ``x``.

    A private deep copy of the operator ``A`` is widened (for 2D operators
    whose ``fwd_operator`` is not ``nn.Identity``) by tiling its grouped
    convolution weights to ``channels`` output channels. ``forward`` returns
    ``x - tau * (A^T(A(x) - r) + net(x))`` with ``r = y - A(x0)`` tiled (or
    broadcast) across channels. For a linear ``A`` this matches the
    ``A(x + x0) - y`` residual used by the other gradient steps.
    """

    def __init__(self, A, sublayers=1, channels=1, data='euclidean', tau=0.01):
        """Build the widened operator copy.

        Args:
            A: operator exposing ``channels``, ``dimension``, ``norm()``,
                ``__call__`` and ``T()``. For ``dimension == 2`` it must also
                expose ``fwd_operator`` and ``adj_operator``; the latter may be
                an ``nn.Sequential`` whose element ``[1]`` is the convolution.
            sublayers, data, tau: unused here. The step size is
                ``1 / norm()`` of the widened copy.
            channels: number of channels of ``x``.

        Raises:
            ValueError: for non-2D operators with ``A.channels != 1`` (not implemented).
        """
        super().__init__()
        self.A_small = A  # original operator, applied to x0 in forward

        # Private copy so that widening its convolutions leaves ``A`` untouched.
        self.A = copy.deepcopy(A)
        if A.dimension == 2:
            if not isinstance(self.A.fwd_operator, torch.nn.Identity):
                fwd_op = self.A.fwd_operator
                # TODO: refactor into something sane if this op works
                if isinstance(self.A.adj_operator, torch.nn.Sequential):
                    adj_op = self.A.adj_operator[1]
                else:
                    adj_op = self.A.adj_operator
                fwd_op.groups = channels
                adj_op.groups = channels
                # NOTE(review): a conv bias (if any) is not tiled -- confirm ops are bias-free.
                self._tile_weight(fwd_op, channels, A.channels, A.dimension)
                self._tile_weight(adj_op, channels, A.channels, A.dimension)
        elif self.A.channels != 1:
            raise ValueError('Todo: Implement this case.')
        # else: single-channel non-2D operator -- matmul broadcasts over channels.

        self.channels = channels

        self.tau = 1 / self.A.norm()
        self.net = Zero()

    @staticmethod
    def _tile_weight(op, channels, base_channels, dimension):
        """Replace ``op.weight`` by its first ``channels`` rows after tiling along dim 0."""
        reps = math.ceil(channels / base_channels)
        tiled = op.weight.repeat(reps, 1, *[1] * dimension)[:channels]
        if isinstance(op.weight, torch.nn.Parameter):
            # nn.Module.__setattr__ raises TypeError when a registered Parameter
            # is overwritten by a plain tensor; re-wrap and keep its trainability.
            tiled = torch.nn.Parameter(tiled.detach(), requires_grad=op.weight.requires_grad)
        op.weight = tiled

    def forward(self, x, y, x0):
        residual = y - self.A_small(x0)
        if self.A.channels > 1:
            # Tile the residual along the channel axis to match x; for a
            # single-channel operator broadcasting takes care of it instead.
            reps = math.ceil(self.channels / self.A.channels)
            residual = residual.repeat(1, reps, *[1] * self.A.dimension)[:, :self.channels]
        return x - self.tau * (self.A.T(self.A(x) - residual) + self.net(x))

    def reset_parameters(self):
        """Initialize module."""
        pass


class RegularizedGradReplicateDescentStep(GradReplicateDescentStep):
    """Replicated gradient step with a learned regularizer on all channels.

    Output: ``x - tau * (A^T(A(x) - r) + net(x))``, where ``r = y - A(x0)`` is
    replicated across channels and ``net`` is a trainable map on ``channels``
    channels that stands in for the gradient of a learned regularizer.
    """

    def __init__(self, A, sublayers=1, channels=1, data='euclidean', tau=0.1):
        """Swap the parent's zero term for a ``sublayers``-layer learnable map."""
        super(RegularizedGradReplicateDescentStep, self).__init__(A, sublayers, channels, data, tau)
        learned = _get_learnable_map(channels, sublayers, dimension=A.dimension)
        self.net = learned


class DeepStep(torch.nn.Module):
    """Learned residual update: returns ``x - net(x)``.

    ``net`` is a ``sublayers``-layer convolutional map on ``channels`` channels.
    ``y`` and ``x0`` are ignored. Per the original author's note, stacking such
    layers could recover a DnCNN-type architecture if channel counts permit.
    """

    def __init__(self, A, sublayers=1, channels=1):
        """Build the learned map; only ``A.dimension`` is read from the operator."""
        super(DeepStep, self).__init__()
        self.A = A  # kept for interface parity with the gradient steps
        self.net = _get_learnable_map(channels, sublayers, dimension=A.dimension)
        self.reset_parameters()  # hook for subclasses; a no-op here

    def forward(self, x, y, x0):
        correction = self.net(x)
        return x - correction

    def reset_parameters(self):
        """No extra initialization beyond what the learned map already does."""
        return None


class DeepStepResidual(DeepStep):
    """Learned update anchored at ``x0``: returns ``x0 - net(x - x0)``."""

    def forward(self, x, y, x0):
        offset = x - x0
        return x0 - self.net(offset)


class DeepStepStraight(DeepStep):
    """Plain learned map with no skip connection: returns ``net(x)``."""

    def forward(self, x, y, x0):
        mapped = self.net(x)
        return mapped
