import torch
from torch import nn
from torch.optim.lr_scheduler import ExponentialLR
import torchvision

from convexrobust.model.modules import StandardMLP

from convexrobust.data import datamodules
from convexrobust.model.randsmooth_certifiable import RandsmoothCertifiable
from convexrobust.utils import dirs

import sys
sys.path.append('../../lib')
sys.path.append('../../lib/smoothingSplittingNoise')
from smoothingSplittingNoise.src.models import ResNet
from smoothingSplittingNoise.src.lib.wide_resnet import WideResNet


class RandsmoothCifar(RandsmoothCertifiable):
    """Randomized-smoothing classifier backed by a two-class WideResNet-40-2.

    Inputs are passed through the normalization layer registered under the
    key ``'cifar10_catsdogs'`` before being fed to the backbone.
    """

    def __init__(self, **kwargs):
        """Build the normalization layer and the WideResNet backbone.

        Args:
            **kwargs: Forwarded unchanged to ``RandsmoothCertifiable``.
        """
        super().__init__(**kwargs)

        self.normalize = datamodules.get_normalize_layer('cifar10_catsdogs')
        self.model = WideResNet(depth=40, widen_factor=2, num_classes=2)

    def forward(self, x):
        """Return the backbone output for the normalized batch ``x``."""
        # Call the module instance (not ``.forward``) so that hooks
        # registered on the backbone are executed.
        return self.model(self.normalize(x))

    def configure_optimizers(self):
        """Create Nesterov SGD with a cosine-annealed learning rate.

        ``self.epochs_n`` is used as the annealing period (``T_max``); it is
        not assigned in this class -- presumably provided by the base class.

        Returns:
            A ``([optimizer], [scheduler])`` pair of single-element lists.
        """
        optimizer = torch.optim.SGD(
            self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.epochs_n)
        return [optimizer], [scheduler]


class RandsmoothMnist(RandsmoothCertifiable):
    """Randomized-smoothing classifier backed by an MLP on flattened inputs.

    Inputs are normalized with the layer registered under the key
    ``'mnist_38'``, flattened per sample, and fed to a ``StandardMLP``.
    """

    def __init__(self, **kwargs):
        """Build the normalization layer and the MLP.

        Args:
            **kwargs: Forwarded unchanged to ``RandsmoothCertifiable``.
        """
        super().__init__(**kwargs)

        self.normalize = datamodules.get_normalize_layer('mnist_38')
        # Presumably (in_features, out_features, hidden widths) -- confirm
        # against the StandardMLP signature.
        self.model = StandardMLP(784, 2, [200, 50])

    def forward(self, x):
        """Normalize ``x``, flatten each sample, and return the MLP output."""
        x = self.normalize(x)
        # Keep the leading (batch) dimension, collapse everything else.
        x = x.reshape(x.shape[0], -1)
        # Call the module instance (not ``.forward``) so that hooks
        # registered on the MLP are executed.
        return self.model(x)

    def configure_optimizers(self):
        """Create momentum SGD with an exponentially decaying learning rate.

        Returns:
            A ``([optimizer], [scheduler])`` pair of single-element lists.
        """
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9)
        scheduler = ExponentialLR(optimizer, gamma=0.95)
        return [optimizer], [scheduler]


class RandsmoothKaggle(RandsmoothCertifiable):
    """Randomized-smoothing classifier backed by a two-class ResNet-50.

    Inputs are passed through the normalization layer registered under the
    key ``'kaggle_catsdogs'`` before being fed to the backbone.
    """

    def __init__(self, **kwargs):
        """Build the normalization layer and a randomly initialised ResNet-50.

        Args:
            **kwargs: Forwarded unchanged to ``RandsmoothCertifiable``.
        """
        super().__init__(**kwargs)

        self.normalize = datamodules.get_normalize_layer('kaggle_catsdogs')
        # No pretrained weights are requested. The deprecated
        # ``pretrained=False`` flag is omitted because it is the default and
        # triggers a deprecation warning on recent torchvision releases.
        self.model = torchvision.models.resnet50(num_classes=2)

    def forward(self, x):
        """Return the backbone output for the normalized batch ``x``."""
        # Call the module instance (not ``.forward``) so that hooks
        # registered on the backbone are executed.
        return self.model(self.normalize(x))

    def configure_optimizers(self):
        """Create Nesterov SGD with a cosine-annealed learning rate.

        ``self.epochs_n`` is used as the annealing period (``T_max``); it is
        not assigned in this class -- presumably provided by the base class.

        Returns:
            A ``([optimizer], [scheduler])`` pair of single-element lists.
        """
        optimizer = torch.optim.SGD(
            self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.epochs_n)
        return [optimizer], [scheduler]


class RandsmoothMalimg(RandsmoothCertifiable):
    """Randomized-smoothing classifier backed by a single-channel ResNet-18.

    Inputs are passed through the normalization layer registered under the
    key ``'malimg'`` before being fed to the backbone.
    """

    def __init__(self, **kwargs):
        """Build the normalization layer and the adapted ResNet-18.

        Args:
            **kwargs: Forwarded unchanged to ``RandsmoothCertifiable``.
        """
        super().__init__(**kwargs)

        self.normalize = datamodules.get_normalize_layer('malimg')

        # No pretrained weights are requested. The deprecated
        # ``pretrained=False`` flag is omitted because it is the default and
        # triggers a deprecation warning on recent torchvision releases.
        self.model = torchvision.models.resnet18(num_classes=2)
        # Swap the stem convolution for one that accepts 1 input channel.
        # NOTE: ``bias`` is left at nn.Conv2d's default; changing it would
        # alter the state-dict keys of this model.
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3)

    def forward(self, x):
        """Return the backbone output for the normalized batch ``x``."""
        # Call the module instance (not ``.forward``) so that hooks
        # registered on the backbone are executed.
        return self.model(self.normalize(x))

    def configure_optimizers(self):
        """Create Nesterov SGD with a cosine-annealed learning rate.

        ``self.epochs_n`` is used as the annealing period (``T_max``); it is
        not assigned in this class -- presumably provided by the base class.

        Returns:
            A ``([optimizer], [scheduler])`` pair of single-element lists.
        """
        optimizer = torch.optim.SGD(
            self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.epochs_n)
        return [optimizer], [scheduler]


class RandsmoothSimple(RandsmoothCertifiable):
    """Randomized-smoothing classifier backed by a plain MLP, no normalization."""

    def __init__(self, **kwargs):
        """Build the MLP.

        Args:
            **kwargs: Forwarded unchanged to ``RandsmoothCertifiable``.
        """
        super().__init__(**kwargs)

        # Don't project, it's just a regular MLP. ``StandardMLP`` is the MLP
        # class this module imports; the previously referenced name ``MLP``
        # is not defined here and raised NameError on construction.
        self.mlp = StandardMLP(2, 2, [200, 200])

    def forward(self, x):
        """Flatten each sample of ``x`` and return the MLP output."""
        batch_n = x.shape[0]
        # Keep the leading (batch) dimension, collapse everything else.
        x = x.reshape(batch_n, -1)
        # Call the module instance (not ``.forward``) so that hooks
        # registered on the MLP are executed.
        return self.mlp(x)

    def configure_optimizers(self):
        """Create momentum SGD (with weight decay) and exponential LR decay.

        Returns:
            A ``([optimizer], [scheduler])`` pair of single-element lists.
        """
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
        scheduler = ExponentialLR(optimizer, gamma=0.95)
        return [optimizer], [scheduler]
