import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.optim.lr_scheduler import ExponentialLR

from convexrobust.data import datamodules
from convexrobust.model.modules import StandardMLP
from convexrobust.model.cayley_certifiable import CayleyCertifiable

from convexrobust.utils import torch_utils

from lib.orthconv.models import ResNet9
from lib.orthconv.layers import CayleyLinear, CayleyConv, GroupSort


class CayleyMnist(CayleyCertifiable):
    """Two-output classifier for the ``'mnist_38'`` dataset using Cayley layers.

    The input is passed through the dataset's normalize layer (built with
    ``mean_only=True``), flattened to ``(batch, -1)``, and fed to a
    ``StandardMLP`` with ``CayleyLinear`` layers and ``GroupSort``
    nonlinearities.

    NOTE(review): ``'mnist_38'`` presumably denotes the MNIST 3-vs-8 subset and
    ``784`` the flattened 28x28 input size -- confirm against ``datamodules``.
    """

    def __init__(self, **kwargs):
        # All keyword options are forwarded untouched to CayleyCertifiable.
        super().__init__(**kwargs)

        # Presumably the hidden-layer widths of the MLP; verify against StandardMLP.
        hidden_widths = [200, 50]

        # The normalize layer is assigned before the model, matching the original
        # submodule registration order.
        self.normalize = datamodules.get_normalize_layer('mnist_38', mean_only=True)
        self.model = StandardMLP(
            784, 2, hidden_widths,
            linear=CayleyLinear,
            nonlin=GroupSort,
        )

    def forward(self, x):
        """Normalize ``x``, flatten all non-batch dims, and return the MLP output."""
        centered = self.normalize(x)
        batch_size = centered.shape[0]
        flat = centered.reshape(batch_size, -1)
        return self.model(flat)

    def configure_optimizers(self):
        """Return a one-element list holding an Adam optimizer (lr = 1e-3) over all parameters."""
        return [torch.optim.Adam(self.parameters(), lr=1e-3)]


class CayleyCifar(CayleyCertifiable):
    """Two-output classifier for the ``'cifar10_catsdogs'`` dataset using ``ResNet9``.

    The input is passed through the dataset's normalize layer (built with
    ``mean_only=True``) and then straight into a ``ResNet9`` with ``out_n=2``.
    Unlike the MNIST variant, no flattening is done here.

    NOTE(review): ``'cifar10_catsdogs'`` presumably denotes the CIFAR-10
    cat-vs-dog subset -- confirm against ``datamodules``.
    """

    def __init__(self, **kwargs):
        # All keyword options are forwarded untouched to CayleyCertifiable.
        super().__init__(**kwargs)

        # The normalize layer is assigned before the model, matching the original
        # submodule registration order.
        normalizer = datamodules.get_normalize_layer('cifar10_catsdogs', mean_only=True)
        self.normalize = normalizer
        self.model = ResNet9(out_n=2)

    def forward(self, x):
        """Apply the normalize layer to ``x`` and return the ResNet9 output."""
        return self.model(self.normalize(x))

    def configure_optimizers(self):
        """Return a one-element list holding an Adam optimizer (lr = 1e-3) over all parameters."""
        adam = torch.optim.Adam(self.parameters(), lr=1e-3)
        return [adam]
