import torch
from torch.optim.lr_scheduler import ExponentialLR

import math

from convexrobust.data import datamodules
from convexrobust.model.modules import ConvexConvNet, ConvexMLP
from convexrobust.model.convex_certifiable import ConvexCertifiable
from convexrobust.model.base_certifiable import Norm


class ConvexCifar(ConvexCertifiable):
    """Input-convex conv-net classifier for the ``cifar10_catsdogs`` data (32x32 images).

    Implements the two hooks used by ``ConvexCertifiable``:

    * ``lipschitz_forward`` applies a normalize layer and, optionally, appends ``|x|``.
    * ``convex_forward`` runs the input-convex ``ConvexConvNet``.

    NOTE(review): the base class presumably composes these two hooks and uses the
    returned Lipschitz constants for certification; confirm in ConvexCertifiable.

    Args:
        convnet_params: optional overrides for the ``ConvexConvNet`` defaults below.
            Keys that are not given keep their default values. ``None`` (the
            default) means no overrides.
        augment_input: if True, the normalized input is concatenated with its
            absolute value along dim 1, so the conv net sees 6 channels instead of 3.
        **kwargs: forwarded unchanged to ``ConvexCertifiable.__init__``.
    """

    def __init__(self, convnet_params=None, augment_input=True, **kwargs):
        super().__init__(**kwargs)

        self.normalize = datamodules.get_normalize_layer('cifar10_catsdogs', mean_only=True)

        default_params = {
            'feature_n': 16, 'depth': 4,
            'conv_1_stride': 1, 'conv_1_kernel_size': 11, 'conv_1_dilation': 1,
            'deep_kernel_size': 3, 'pool_size': 1
        }

        # Caller-supplied entries override the defaults. A None default avoids
        # sharing one mutable dict across every call.
        combined_params = {**default_params, **(convnet_params or {})}

        self.augment_input = augment_input
        # Augmentation appends |x| as extra channels, doubling 3 input channels to 6.
        channel_n = 6 if augment_input else 3
        self.net_convex = ConvexConvNet(image_size=32, channel_n=channel_n, **combined_params)

        self.net_convex.init_project()

    def optimizer_step(self, *args, **kwargs):
        """Run the parent optimizer step, then call ``net_convex.project()``.

        NOTE(review): ``project()`` presumably restores the weight constraints that
        keep ``net_convex`` convex in its input; confirm in ConvexConvNet.
        """
        super().optimizer_step(*args, **kwargs)
        self.net_convex.project()

    def lipschitz_forward(self, x):
        """Normalize ``x`` and, if ``augment_input`` is set, append ``|x|`` along dim 1.

        Returns:
            A ``(features, constants)`` pair. ``constants`` maps each ``Norm`` to a
            Lipschitz constant of this map. Concatenating ``x`` with ``|x|`` scales
            the L1 bound by 2 and the L2 bound by sqrt(2). The LInf bound stays 1,
            because ``||a| - |b|| <= |a - b|`` holds elementwise.

        NOTE(review): the constants assume the ``mean_only`` normalize layer is a
        pure shift, i.e. 1-Lipschitz; confirm in ``datamodules.get_normalize_layer``.
        """
        x = self.normalize(x)
        if not self.augment_input:
            return x, {Norm.L1: 1, Norm.L2: 1, Norm.LInf: 1}

        x = torch.cat([x, x.abs()], dim=1)
        return x, {Norm.L1: 2, Norm.L2: math.sqrt(2), Norm.LInf: 1}

    def convex_forward(self, x):
        """Apply the input-convex conv net to the output of ``lipschitz_forward``."""
        return self.net_convex(x)

    def configure_optimizers(self):
        """Return SGD (lr=0.001, momentum=0.9) paired with an ExponentialLR (gamma=0.99).

        Each ``scheduler.step()`` multiplies the learning rate by 0.99. How often it
        is stepped is decided by the trainer (presumably once per epoch).
        """
        optimizer = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        scheduler = ExponentialLR(optimizer, gamma=0.99)
        return [optimizer], [scheduler]


class ConvexMnist(ConvexCertifiable):
    """Input-convex MLP classifier for the ``mnist_38`` data, on flattened 784-feature inputs.

    The Lipschitz feature map is the identity. Normalization is done inside
    ``convex_forward`` instead: a mean shift (assumed; see the note there) keeps
    the map convex.

    Args:
        mlp_params: optional overrides for the ``ConvexMLP`` defaults below.
            Keys that are not given keep their default values. ``None`` (the
            default) means no overrides.
        **kwargs: forwarded unchanged to ``ConvexCertifiable.__init__``.
    """

    def __init__(self, mlp_params=None, **kwargs):
        super().__init__(**kwargs)

        self.normalize = datamodules.get_normalize_layer('mnist_38', mean_only=True)

        default_params = {
            'feature_ns': [200, 50], 'skip_connections': True, 'batchnorms': True
        }

        # Caller-supplied entries override the defaults. A None default avoids
        # sharing one mutable dict across every call.
        combined_params = {**default_params, **(mlp_params or {})}

        self.net_convex = ConvexMLP(in_n=784, out_n=1, **combined_params)

        self.net_convex.init_project()

    def lipschitz_forward(self, x):
        """Return ``x`` unchanged, with Lipschitz constant 1 in every norm."""
        return x, {Norm.L1: 1, Norm.L2: 1, Norm.LInf: 1}

    def convex_forward(self, x):
        """Normalize, flatten each example to one row, and return one score per example.

        NOTE(review): assumes the ``mean_only`` normalize layer is a shift, which
        preserves convexity; confirm in ``datamodules.get_normalize_layer``.
        """
        batch_n = x.shape[0]

        x = self.normalize(x)
        # ConvexMLP was built with in_n=784, so each example must flatten to 784 values.
        x = x.reshape(batch_n, -1)
        # ConvexMLP.forward takes the input twice; presumably the second copy feeds
        # the skip connections. Confirm in ConvexMLP.
        z = self.net_convex.forward(x, x)

        # out_n=1, so drop the singleton output dimension.
        return z.squeeze(1)

    def optimizer_step(self, *args, **kwargs):
        """Run the parent optimizer step, then call ``net_convex.project()``.

        NOTE(review): ``project()`` presumably restores the weight constraints that
        keep ``net_convex`` convex in its input; confirm in ConvexMLP.
        """
        super().optimizer_step(*args, **kwargs)
        self.net_convex.project()

    def configure_optimizers(self):
        """Return SGD (lr=0.001, momentum=0.9) paired with an ExponentialLR (gamma=0.99)."""
        optimizer = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        scheduler = ExponentialLR(optimizer, gamma=0.99)
        return [optimizer], [scheduler]


class ConvexKaggle(ConvexCertifiable):
    """Input-convex conv-net classifier for the ``kaggle_catsdogs`` data (224x224 images).

    Mirrors ``ConvexCifar``:

    * ``lipschitz_forward`` normalizes the input and, by default, appends ``|x|``.
    * ``convex_forward`` runs the input-convex ``ConvexConvNet``.

    Args:
        convnet_params: optional overrides for the ``ConvexConvNet`` defaults below.
            Keys that are not given keep their default values. ``None`` (the
            default) means no overrides.
        augment_input: if True (the default, and previously the only behavior),
            the normalized input is concatenated with its absolute value along
            dim 1, so the conv net sees 6 channels instead of 3.
        **kwargs: forwarded unchanged to ``ConvexCertifiable.__init__``.
    """

    def __init__(self, convnet_params=None, augment_input=True, **kwargs):
        super().__init__(**kwargs)

        self.normalize = datamodules.get_normalize_layer('kaggle_catsdogs', mean_only=True)

        default_params = {
            'feature_n': 32, 'depth': 4,
            'conv_1_stride': 1, 'conv_1_kernel_size': 15, 'conv_1_dilation': 2,
            'deep_kernel_size': 3, 'pool_size': 8
        }

        # Caller-supplied entries override the defaults. A None default avoids
        # sharing one mutable dict across every call.
        combined_params = {**default_params, **(convnet_params or {})}

        self.augment_input = augment_input
        # Augmentation appends |x| as extra channels, doubling 3 input channels to 6.
        channel_n = 6 if augment_input else 3
        self.net_convex = ConvexConvNet(image_size=224, channel_n=channel_n, **combined_params)

        self.net_convex.init_project()

    def optimizer_step(self, *args, **kwargs):
        """Run the parent optimizer step, then call ``net_convex.project()``.

        NOTE(review): ``project()`` presumably restores the weight constraints that
        keep ``net_convex`` convex in its input; confirm in ConvexConvNet.
        """
        super().optimizer_step(*args, **kwargs)
        self.net_convex.project()

    def lipschitz_forward(self, x):
        """Normalize ``x`` and, if ``augment_input`` is set, append ``|x|`` along dim 1.

        Returns:
            A ``(features, constants)`` pair. Appending ``|x|`` gives Lipschitz
            constants L1=2, L2=sqrt(2), LInf=1; without augmentation every constant
            is 1.

        NOTE(review): the constants assume the ``mean_only`` normalize layer is a
        pure shift, i.e. 1-Lipschitz; confirm in ``datamodules.get_normalize_layer``.
        """
        x = self.normalize(x)
        if not self.augment_input:
            return x, {Norm.L1: 1, Norm.L2: 1, Norm.LInf: 1}

        x = torch.cat([x, x.abs()], dim=1)
        return x, {Norm.L1: 2, Norm.L2: math.sqrt(2), Norm.LInf: 1}

    def convex_forward(self, x):
        """Apply the input-convex conv net to the output of ``lipschitz_forward``."""
        return self.net_convex(x)

    def configure_optimizers(self):
        """Return SGD (lr=0.001, momentum=0.9) paired with an ExponentialLR (gamma=0.99)."""
        optimizer = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        scheduler = ExponentialLR(optimizer, gamma=0.99)
        return [optimizer], [scheduler]


class ConvexMalimg(ConvexCertifiable):
    """Input-convex conv-net classifier on single-channel 512x512 inputs.

    No normalization or augmentation is applied: ``lipschitz_forward`` is the
    identity.

    Args:
        convnet_params: optional overrides for the ``ConvexConvNet`` defaults below.
            Keys that are not given keep their default values. ``None`` (the
            default) means no overrides.
        **kwargs: forwarded unchanged to ``ConvexCertifiable.__init__``.
    """

    def __init__(self, convnet_params=None, **kwargs):
        super().__init__(**kwargs)

        default_params = {
            'feature_n': 32, 'depth': 3,
            'conv_1_stride': 2, 'conv_1_kernel_size': 21, 'conv_1_dilation': 1,
            'deep_kernel_size': 3, 'pool_size': 4
        }

        # Caller-supplied entries override the defaults. A None default avoids
        # sharing one mutable dict across every call.
        combined_params = {**default_params, **(convnet_params or {})}

        self.net_convex = ConvexConvNet(image_size=512, channel_n=1, **combined_params)

        self.net_convex.init_project()

    def optimizer_step(self, *args, **kwargs):
        """Run the parent optimizer step, then call ``net_convex.project()``.

        NOTE(review): ``project()`` presumably restores the weight constraints that
        keep ``net_convex`` convex in its input; confirm in ConvexConvNet.
        """
        super().optimizer_step(*args, **kwargs)
        self.net_convex.project()

    def lipschitz_forward(self, x):
        """Return ``x`` unchanged, with Lipschitz constant 1 in every norm."""
        return x, {Norm.L1: 1, Norm.L2: 1, Norm.LInf: 1}

    def convex_forward(self, x):
        """Apply the input-convex conv net directly to ``x``."""
        return self.net_convex(x)

    def configure_optimizers(self):
        """Return SGD (lr=0.001, momentum=0.9) paired with an ExponentialLR (gamma=0.99)."""
        optimizer = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        scheduler = ExponentialLR(optimizer, gamma=0.99)
        return [optimizer], [scheduler]


class ConvexSimple(ConvexCertifiable):
    """Small input-convex MLP (one hidden layer of 50 units) for a low-dimensional task.

    ``lipschitz_forward`` appends ``|x|`` along dim 1, doubling the feature count.
    The MLP is therefore built with 4 inputs; presumably the raw data has 2
    features (confirm against the datamodule this model is trained on).

    Args:
        **kwargs: forwarded unchanged to ``ConvexCertifiable.__init__``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Positional args are presumably (in_n, out_n, feature_ns): 4 augmented
        # inputs, a single output score, and one hidden layer of width 50.
        self.mlp = ConvexMLP(4, 1, [50], batchnorms=True, skip_connections=True, in_n_orig=4)
        self.mlp.init_project()

    def optimizer_step(self, *args, **kwargs):
        """Run the parent optimizer step, then call ``mlp.project()``.

        NOTE(review): ``project()`` presumably restores the weight constraints that
        keep the MLP convex in its input; confirm in ConvexMLP.
        """
        super().optimizer_step(*args, **kwargs)
        self.mlp.project()

    def lipschitz_forward(self, x):
        """Append ``|x|`` to ``x`` along dim 1.

        Returns:
            The augmented features and their Lipschitz constants: L1=2,
            L2=sqrt(2), LInf=1.
        """
        augmented = torch.cat((x, torch.abs(x)), dim=1)
        constants = {Norm.L1: 2, Norm.L2: math.sqrt(2), Norm.LInf: 1}
        return augmented, constants

    def convex_forward(self, x):
        """Flatten each example to one row and return one MLP score per example."""
        rows = x.reshape(x.shape[0], -1)
        # The MLP takes the input twice: as features and as the original input.
        return self.mlp.forward(rows, rows).squeeze(1)

    def configure_optimizers(self):
        """Return SGD (lr=0.01, momentum=0.9, weight_decay=5e-4) with an ExponentialLR (gamma=0.95)."""
        sgd = torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
        decay = ExponentialLR(sgd, gamma=0.95)
        return [sgd], [decay]
