import torch.nn as nn

from abc import ABC, abstractmethod
from typing import List

from convexrobust.utils import torch_utils


class ConvexModule(nn.Module, ABC):
    """Abstract ``nn.Module`` base for networks that expose a weight projection.

    Subclasses must implement ``forward`` and ``project``; the class cannot be
    instantiated until both are provided.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are forwarded untouched to the next class in the MRO.
        super().__init__(**kwargs)

    @abstractmethod
    def forward(self, z, x):
        """Compute the module output from ``z`` and ``x`` (subclass-defined)."""

    @abstractmethod
    def project(self):
        """Project the module's parameters in place (subclass-defined)."""

    def init_project(self):
        """Initial projection hook; by default it simply runs ``project``."""
        self.project()


def project_weight_positive(weight):
    """Clamp ``weight`` in place so that no entry is below zero.

    Operates on ``weight.data``, so the update is not tracked by autograd.
    """
    weight.data.clamp_(min=0.0)


def init_weight_positive(weight, linear=False, strategy='scaled'):
    """Initialise ``weight`` in place with non-negative values.

    Args:
        weight: Parameter/tensor whose ``.data`` is overwritten.
        linear: Only used by the ``'scaled'`` strategy; selects the smaller
            upper bound (0.003 instead of 0.005).
        strategy: ``'simple'`` clamps the existing values at zero;
            ``'scaled'`` resamples every entry uniformly from ``[0, bound)``.
            Any other value leaves ``weight`` untouched.
    """
    if strategy == 'simple':
        weight.data.clamp_(0)
        return

    if strategy == 'scaled':
        upper_bound = 0.003 if linear else 0.005
        weight.data.uniform_(0.0, upper_bound)


class ConvexMLP(ConvexModule):
    """MLP over ``z`` with optional per-layer linear skip connections from ``x``.

    Layer ``i`` computes ``W_z[i](z)`` (plus ``W_x[i](x)`` when
    ``skip_connections`` is set); every layer but the last is followed by the
    nonlinearity and, if ``batchnorms`` is set, a ``BatchNorm1d``.
    ``project`` keeps the ``W_z`` weights non-negative and the batch-norm
    scales positive — presumably the constraints needed for the network to
    stay convex, as the class name suggests.

    Args:
        in_n: Width of the ``z`` input.
        out_n: Width of the output.
        feature_ns: Hidden layer widths.
        nonlin: Zero-argument factory for the nonlinearity module.
        in_n_orig: Width of ``x`` for the skip connections; defaults to ``in_n``.
        skip_connections: Add ``W_x[i](x)`` at every layer. When set, the
            ``W_z`` layers carry no bias (the ``W_x`` layers provide one).
        batchnorms: Insert a ``BatchNorm1d`` after each hidden nonlinearity.
    """

    def __init__(self, in_n: int, out_n: int, feature_ns: List[int], nonlin=nn.ReLU,
                 in_n_orig=None, skip_connections=True, batchnorms=True):
        super().__init__()
        if in_n_orig is None:
            in_n_orig = in_n
        feature_ns = [in_n] + feature_ns + [out_n]
        self.layer_n = len(feature_ns) - 1

        self.skip_connections = skip_connections
        self.batchnorms = batchnorms
        # Not used by forward (which uses self.nonlins); kept so the registered
        # submodules stay the same.
        self.nonlin = nonlin().to(torch_utils.device())

        W_z, W_x, nonlins, bns = [], [], [], []
        for i, (prev_feature_n, curr_feature_n) in enumerate(zip(feature_ns, feature_ns[1:])):
            W_z.append(nn.Linear(prev_feature_n, curr_feature_n, bias=not skip_connections))
            if self.skip_connections:
                W_x.append(nn.Linear(in_n_orig, curr_feature_n))

            # The final layer is purely linear: no nonlinearity or batch norm.
            if i < self.layer_n - 1:
                nonlins.append(nonlin().to(torch_utils.device()))
                if batchnorms:
                    bns.append(nn.BatchNorm1d(curr_feature_n))

        self.W_z = nn.ModuleList(W_z).to(torch_utils.device())
        self.W_x = nn.ModuleList(W_x).to(torch_utils.device())
        self.nonlins = nn.ModuleList(nonlins).to(torch_utils.device())
        self.bns = nn.ModuleList(bns).to(torch_utils.device())

    def forward(self, z, x):
        """Run the layers on ``z``; ``x`` is only read when skip connections are on."""
        for i in range(self.layer_n):
            if self.skip_connections:
                z = self.W_z[i](z) + self.W_x[i](x)
            else:
                z = self.W_z[i](z)

            if i < self.layer_n - 1:
                z = self.nonlins[i](z)
                if self.batchnorms:
                    z = self.bns[i](z)

        return z

    def project(self):
        """Clamp ``W_z`` weights at zero and batch-norm scales at 0.05, in place."""
        for W_z in self.W_z:
            project_weight_positive(W_z.weight)

        # A negative batch-norm scale would flip the sign of the features fed
        # into the next non-negative W_z layer, undoing the effect of the
        # projection above. Use the same 0.05 floor that
        # ConvexConvNet.project applies to its batch norms.
        for bn in self.bns:
            bn.weight.data.clamp_(0.05)

    def init_project(self, strategy='scaled'):
        """Initialise every ``W_z`` weight with ``init_weight_positive``."""
        for W_z in self.W_z:
            init_weight_positive(W_z.weight, linear=True, strategy=strategy)


class ConvexConvNet(ConvexModule):
    """Convolutional network producing one scalar per input image.

    Pipeline: ``BatchNorm2d`` on the input, a first convolution plus
    nonlinearity, ``depth`` residual ``ConvexBlock``s (optionally each followed
    by a scaled convolutional skip from the normalised input), max pooling, a
    final ``BatchNorm2d`` and a linear readout to a single value.

    Args:
        image_size: Side length of the (square) input images.
        channel_n: Number of input channels.
        feature_n: Number of feature channels after the first convolution.
        depth: Number of ``ConvexBlock``s.
        conv_1_stride: Stride of the first convolution.
        conv_1_kernel_size: Kernel size of the first convolution; must be odd.
        conv_1_dilation: Dilation of the first convolution.
        deep_kernel_size: Kernel size of the block and skip convolutions. They
            pad by ``deep_kernel_size // 2``, so only odd sizes preserve the
            spatial size that the residual additions rely on.
        pool_size: Kernel size and stride of the max pool.
        skip_connections: Add ``0.1 * skips[i](x)`` after every block.
        nonlin: Zero-argument factory for the nonlinearity module.

    NOTE(review): ``forward`` takes only ``x``, unlike the two-argument
    ``forward(z, x)`` declared on ``ConvexModule``.
    NOTE(review): the skip convolutions use stride 1, so their output only
    matches ``z`` spatially when ``conv_1_stride == 1`` — confirm intended use.
    """

    def __init__(self, image_size=224, channel_n=3, feature_n=32, depth=5,
                 conv_1_stride=1, conv_1_kernel_size=15, conv_1_dilation=1,
                 deep_kernel_size=5, pool_size=1, skip_connections=False, nonlin=nn.ReLU):
        super().__init__()

        assert (conv_1_kernel_size % 2) == 1

        # "Same"-style padding for an odd, possibly dilated, kernel.
        conv_1_padding = (conv_1_kernel_size // 2) * conv_1_dilation

        self.bn_1 = nn.BatchNorm2d(channel_n)
        self.conv_1 = nn.Conv2d(
            channel_n, feature_n, kernel_size=conv_1_kernel_size,
            stride=conv_1_stride, dilation=conv_1_dilation, padding=conv_1_padding
        )
        self.nonlin_1 = nonlin()

        self.blocks = nn.ModuleList(
            [ConvexBlock(feature_n, deep_kernel_size, nonlin) for _ in range(depth)]
        )

        self.skip_connections = skip_connections
        if skip_connections:
            self.skips = nn.ModuleList(
                [nn.Conv2d(channel_n, feature_n, kernel_size=deep_kernel_size,
                           padding=deep_kernel_size // 2) for _ in range(depth)]
            )

        self.max_pool = nn.MaxPool2d(pool_size, pool_size)
        self.bn_last = nn.BatchNorm2d(feature_n)

        # With the padding above, conv_1 outputs ceil(image_size / stride)
        # pixels per side, and the pool then floor-divides that by pool_size.
        # Dividing by (pool_size * stride) in one step under-counts whenever
        # the stride does not divide image_size, which makes the readout width
        # disagree with the tensor produced in forward.
        conv_1_out_size = (image_size - 1) // conv_1_stride + 1
        final_image_size = conv_1_out_size // pool_size
        self.readout = nn.Linear(feature_n * (final_image_size ** 2), 1, bias=True)

    def forward(self, x):
        """Map a batch of images ``x`` to a 1-D tensor with one value per sample."""
        batch_n = x.shape[0]

        x = self.bn_1(x)
        z = self.conv_1(x)
        z = self.nonlin_1(z)

        for i, block in enumerate(self.blocks):
            z = block(z)
            if self.skip_connections:
                # Skips read the normalised input, not the running features.
                z = z + 0.1 * self.skips[i](x)

        z = self.max_pool(z)
        z = self.bn_last(z)
        z = self.readout(z.reshape(batch_n, -1))

        return z.squeeze(1)

    def project(self):
        """Clamp block/readout weights at zero and batch-norm scales at 0.05, in place.

        ``bn_1``, ``conv_1`` and the skip convolutions are left unconstrained.
        """
        for block in self.blocks:
            project_weight_positive(block.conv.weight)
            block.bn.weight.data.clamp_(0.05)

        self.bn_last.weight.data.clamp_(0.05)
        project_weight_positive(self.readout.weight)

    def init_project(self, strategy='scaled'):
        """Initialise block convolution and readout weights with ``init_weight_positive``."""
        for block in self.blocks:
            init_weight_positive(block.conv.weight, strategy=strategy)

        init_weight_positive(self.readout.weight, linear=True, strategy=strategy)


class ConvexBlock(nn.Module):
    """Residual unit: batch norm, then ``nonlin(conv(h) + h)`` on the normalised ``h``.

    Args:
        channels: Channel count, identical for input and output.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        nonlin: Zero-argument factory for the nonlinearity module.
    """

    def __init__(self, channels, kernel_size, nonlin):
        super().__init__()
        half_kernel = kernel_size // 2
        self.bn = nn.BatchNorm2d(channels)
        self.conv = nn.Conv2d(
            channels, channels, kernel_size=kernel_size, padding=half_kernel
        )
        self.nonlin = nonlin()

    def forward(self, x):
        normed = self.bn(x)
        # The residual branch carries the normalised features, not the raw input.
        residual_sum = self.conv(normed) + normed
        return self.nonlin(residual_sum)


class StandardMLP(nn.Module):
    """Plain feed-forward stack: ``linear -> nonlin`` per hidden width, then a final ``linear``.

    Args:
        in_n: Input width.
        out_n: Output width.
        feature_ns: Hidden layer widths.
        linear: Factory called as ``linear(in_features, out_features)``.
        nonlin: Zero-argument factory for the nonlinearity placed after each
            hidden layer.
    """

    def __init__(self, in_n: int, out_n: int, feature_ns: List[int],
                 linear=nn.Linear, nonlin=nn.ReLU):
        super().__init__()
        widths = [in_n] + feature_ns
        # Always a ReLU regardless of ``nonlin``; ``forward`` does not use it.
        self.nonlin = nn.ReLU().to(torch_utils.device())

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [linear(fan_in, fan_out), nonlin()]
        # Output projection from the last hidden width, with no nonlinearity.
        stack.append(linear(widths[-1], out_n))

        self.sequential = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the sequential stack to ``x``."""
        output = self.sequential(x)
        return output
