import types
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import wide_resnet50_2, wide_resnet101_2, vgg16
from torchvision.models.resnet import _resnet, Bottleneck

from .vaes import *

from .wideresnet import *
from .resnet import resnet18, resnet50, resnet152, resnet101, resnext101_32x8d
from .resnet import ResNet50Layer3
from .vgg import Vgg19, Vgg16, Vgg16BN, Vgg19BN, Vgg16Norm02
from .mlps import VggMLP, MLP, MLPv2, LargeMLP, LargeMLPv2
from .meminf import shadow_attack_model, whitebox_attack_model
from .alt_resnet import (
    altResNet20,
    altResNet20Norm02,
    altResNet32,
    altResNet32Norm02,
    altResNet110,
    altResNet110Norm02,
)


def tvgg16(n_features, n_classes, n_channels):
    """Build a randomly initialised torchvision VGG-16 with ``n_classes`` outputs.

    ``n_features`` and ``n_channels`` are accepted but not used by this builder.
    """
    model = vgg16(num_classes=n_classes, pretrained=False)
    return model

def tWRN50_2(n_features, n_classes, n_channels):
    """Build a torchvision ``wide_resnet50_2`` with ``n_classes`` outputs.

    ``n_features`` and ``n_channels`` are accepted but not used by this builder.
    """
    model = wide_resnet50_2(num_classes=n_classes)
    return model

def tWRN101_2(n_features, n_classes, n_channels):
    """Build a torchvision ``wide_resnet101_2`` with ``n_classes`` outputs.

    ``n_features`` and ``n_channels`` are accepted but not used by this builder.
    """
    model = wide_resnet101_2(num_classes=n_classes)
    return model

def wide_resnet50_4(pretrained=False, progress=True, **kwargs):
    """Build a Bottleneck ResNet with stage depths [3, 4, 6, 3] and 4x base width.

    ``width_per_group`` is always forced to ``64 * 4``; any value supplied by
    the caller through ``kwargs`` is overwritten. All remaining arguments are
    forwarded unchanged to torchvision's ``_resnet`` helper.
    """
    kwargs.update(width_per_group=64 * 4)
    arch_name = 'wide_resnet50_4'
    stage_depths = [3, 4, 6, 3]
    return _resnet(arch_name, Bottleneck, stage_depths, pretrained, progress, **kwargs)

def wide_resnet50_5(pretrained=False, progress=True, **kwargs):
    """Build a Bottleneck ResNet with stage depths [3, 4, 6, 3] and 5x base width.

    ``width_per_group`` is always forced to ``64 * 5``; any value supplied by
    the caller through ``kwargs`` is overwritten. All remaining arguments are
    forwarded unchanged to torchvision's ``_resnet`` helper.
    """
    kwargs.update(width_per_group=64 * 5)
    arch_name = 'wide_resnet50_5'
    stage_depths = [3, 4, 6, 3]
    return _resnet(arch_name, Bottleneck, stage_depths, pretrained, progress, **kwargs)

def tWRN50_4(n_features, n_classes, n_channels):
    """Build a ``wide_resnet50_4`` with ``n_classes`` outputs.

    ``n_features`` and ``n_channels`` are accepted but not used by this builder.
    """
    model = wide_resnet50_4(num_classes=n_classes)
    return model

def tWRN50_5(n_features, n_classes, n_channels):
    """Build a ``wide_resnet50_5`` with ``n_classes`` outputs.

    ``n_features`` and ``n_channels`` are accepted but not used by this builder.
    """
    model = wide_resnet50_5(num_classes=n_classes)
    return model


####################
###### ResNet ######
####################

def preResNet50Norm02(n_features, n_classes, n_channels):
    """Build the project ``resnet50`` with ``pretrained=True`` and fixed normalisation stats.

    ``n_features`` is accepted but not used. The mean/std tuples are handed to
    ``resnet50`` as ``normalize_mean`` / ``normalize_std`` (presumably
    per-channel input normalisation -- confirm in ``.resnet``).
    """
    # Alternative statistics kept for reference:
    #   mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
    return resnet50(
        pretrained=True,
        n_channels=n_channels,
        num_classes=n_classes,
        normalize_mean=mean,
        normalize_std=std,
    )

def preResNet18Norm02(n_features, n_classes, n_channels):
    """Build the project ``resnet18`` with ``pretrained=True`` and fixed normalisation stats.

    ``n_features`` is accepted but not used. The mean/std tuples are handed to
    ``resnet18`` as ``normalize_mean`` / ``normalize_std`` (presumably
    per-channel input normalisation -- confirm in ``.resnet``).
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
    return resnet18(
        pretrained=True,
        n_channels=n_channels,
        num_classes=n_classes,
        normalize_mean=mean,
        normalize_std=std,
    )

def ResNet50Norm01(n_features, n_classes, n_channels):
    """Build the project ``resnet50`` with ``pretrained=False`` and the "01" normalisation stats.

    ``n_features`` is accepted but not used. The mean/std tuples are handed to
    ``resnet50`` as ``normalize_mean`` / ``normalize_std`` (presumably
    per-channel input normalisation -- confirm in ``.resnet``).
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    return resnet50(
        pretrained=False,
        n_channels=n_channels,
        num_classes=n_classes,
        normalize_mean=mean,
        normalize_std=std,
    )

def ResNet50Norm02(n_features, n_classes, n_channels):
    """Build the project ``resnet50`` with ``pretrained=False`` and the "02" normalisation stats.

    ``n_features`` is accepted but not used. The mean/std tuples are handed to
    ``resnet50`` as ``normalize_mean`` / ``normalize_std`` (presumably
    per-channel input normalisation -- confirm in ``.resnet``).
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
    return resnet50(
        pretrained=False,
        n_channels=n_channels,
        num_classes=n_classes,
        normalize_mean=mean,
        normalize_std=std,
    )

def ResNet18Norm01(n_features, n_classes, n_channels):
    """Build the project ``resnet18`` with ``pretrained=False`` and the "01" normalisation stats.

    ``n_features`` is accepted but not used. The mean/std tuples are handed to
    ``resnet18`` as ``normalize_mean`` / ``normalize_std`` (presumably
    per-channel input normalisation -- confirm in ``.resnet``).
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    return resnet18(
        pretrained=False,
        n_channels=n_channels,
        num_classes=n_classes,
        normalize_mean=mean,
        normalize_std=std,
    )

def ResNet18Norm02(n_features, n_classes, n_channels):
    """Build the project ``resnet18`` with ``pretrained=False`` and the "02" normalisation stats.

    ``n_features`` is accepted but not used. The mean/std tuples are handed to
    ``resnet18`` as ``normalize_mean`` / ``normalize_std`` (presumably
    per-channel input normalisation -- confirm in ``.resnet``).
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
    return resnet18(
        pretrained=False,
        n_channels=n_channels,
        num_classes=n_classes,
        normalize_mean=mean,
        normalize_std=std,
    )

def ResNet18(n_features, n_classes, n_channels):
    """Build the project ``resnet18`` with ``pretrained=False``.

    ``n_features`` is accepted but not used; ``n_classes`` and ``n_channels``
    are forwarded as ``num_classes`` and ``n_channels``.
    """
    return resnet18(
        pretrained=False,
        n_channels=n_channels,
        num_classes=n_classes,
    )

def ResNet101(n_features, n_classes, n_channels):
    """Build the project ``resnet101`` with ``pretrained=False``.

    ``n_features`` is accepted but not used; ``n_classes`` and ``n_channels``
    are forwarded as ``num_classes`` and ``n_channels``.
    """
    return resnet101(
        pretrained=False,
        n_channels=n_channels,
        num_classes=n_classes,
    )

def preResNet50(n_classes, n_channels):
    """Build the project ``resnet50`` with ``pretrained=True``.

    NOTE(review): unlike most builders in this module, this one takes no
    leading ``n_features`` argument -- callers must pass only
    ``(n_classes, n_channels)``.
    """
    return resnet50(
        pretrained=True,
        n_channels=n_channels,
        num_classes=n_classes,
    )

def ResNet50(n_features, n_classes, n_channels):
    """Build the project ``resnet50`` with ``pretrained=False``.

    ``n_features`` is accepted but not used; ``n_classes`` and ``n_channels``
    are forwarded as ``num_classes`` and ``n_channels``.
    """
    return resnet50(
        pretrained=False,
        n_channels=n_channels,
        num_classes=n_classes,
    )

def ResNet152(n_classes, n_channels):
    """Build the project ``resnet152`` with ``pretrained=False``.

    NOTE(review): unlike most builders in this module, this one takes no
    leading ``n_features`` argument -- callers must pass only
    ``(n_classes, n_channels)``.
    """
    return resnet152(
        pretrained=False,
        n_channels=n_channels,
        num_classes=n_classes,
    )

def preResNeXt101(n_classes, n_channels):
    """Build the project ``resnext101_32x8d`` with ``pretrained=True``.

    NOTE(review): unlike most builders in this module, this one takes no
    leading ``n_features`` argument -- callers must pass only
    ``(n_classes, n_channels)``.
    """
    return resnext101_32x8d(
        pretrained=True,
        n_channels=n_channels,
        num_classes=n_classes,
    )



class LR(nn.Module):
    """Logistic Regression.

    A single linear layer mapping ``n_features`` inputs to ``n_classes``
    logits. No softmax/sigmoid is applied in ``forward``.

    Args:
        n_features: Size of each input sample (``in_features`` of the layer).
        n_classes: Number of output logits.
        n_channels: Accepted but unused.
    """

    def __init__(self, n_features, n_classes, n_channels=None):
        # The initialiser must name this class: ``super(MLP, self)`` raised
        # TypeError because LR is not a subclass of MLP.
        super(LR, self).__init__()
        self.fc = nn.Linear(n_features, n_classes)

    def forward(self, x):
        """Return the raw logits ``fc(x)``."""
        return self.fc(x)


class CNN001(nn.Module):
    """Two-convolution CNN with a two-layer fully connected head.

    ``conv1`` is hard-coded to one input channel and ``fc1`` to 9216 input
    features (64 channels x 12 x 12, which is what a 28x28 input yields after
    two unpadded 3x3 convolutions and a 2x2 max-pool).

    Args:
        n_features: Accepted but unused.
        n_classes: Number of output logits.
        n_channels: Accepted but unused; ``conv1`` always takes 1 channel.
        save_intermediates: Stored on the instance; ``forward`` does not
            read it.
    """

    def __init__(self, n_features, n_classes, n_channels=None, save_intermediates=False):
        super(CNN001, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D convolutional feature map.
        self.dropout1 = nn.Dropout2d(0.25)
        # The second dropout runs on a flattened 2-D tensor. Dropout2d on 2-D
        # input is deprecated (it warns and is slated to become an error);
        # plain element-wise Dropout is the documented replacement.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, n_classes)

        self.save_intermediates = save_intermediates
        self.intermediates = []

    def forward(self, x):
        """Return class logits for a batch of single-channel images."""
        x = F.relu(self.conv1(x))
        # No ReLU between conv2 and the pooling step.
        x = F.max_pool2d(self.conv2(x), 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        return self.fc2(x)

class CNN002(nn.Module):
    """Small CNN: four 3x3 convolutions with two max-pools, then a three-layer MLP head.

    Architecture source:
    https://github.com/yaodongyu/TRADES/blob/e20f7b9b99c79ed3cf0d1bb12a47c229ebcac24a/models/small_cnn.py#L5

    The head expects ``64 * 4 * 4`` flattened features, which is what a 28x28
    input produces after the convolution/pooling stack. ``n_features`` is
    accepted but unused; ``save_intermediates`` is only stored.
    """

    def __init__(self, n_features, n_classes, drop=0.5, n_channels=1, save_intermediates=False):
        super(CNN002, self).__init__()

        self.num_channels = n_channels
        self.save_intermediates = save_intermediates
        self.intermediates = []

        # A single in-place ReLU instance is registered under every relu* name.
        relu = nn.ReLU(True)

        conv_layers = OrderedDict()
        conv_layers['conv1'] = nn.Conv2d(self.num_channels, 32, 3)
        conv_layers['relu1'] = relu
        conv_layers['conv2'] = nn.Conv2d(32, 32, 3)
        conv_layers['relu2'] = relu
        conv_layers['maxpool1'] = nn.MaxPool2d(2, 2)
        conv_layers['conv3'] = nn.Conv2d(32, 64, 3)
        conv_layers['relu3'] = relu
        conv_layers['conv4'] = nn.Conv2d(64, 64, 3)
        conv_layers['relu4'] = relu
        conv_layers['maxpool2'] = nn.MaxPool2d(2, 2)
        self.feature_extractor = nn.Sequential(conv_layers)

        head_layers = OrderedDict()
        head_layers['fc1'] = nn.Linear(64 * 4 * 4, 200)
        head_layers['relu1'] = relu
        head_layers['drop'] = nn.Dropout(drop)
        head_layers['fc2'] = nn.Linear(200, 200)
        head_layers['relu2'] = relu
        head_layers['fc3'] = nn.Linear(200, n_classes)
        self.classifier = nn.Sequential(head_layers)

        # Kaiming-normal weights / zero bias for convolutions; unit weight /
        # zero bias for any 2-D batch-norm layer.
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

        # The output layer starts at exactly zero (weights and bias).
        output_layer = self.classifier.fc3
        for tensor in (output_layer.weight, output_layer.bias):
            nn.init.constant_(tensor, 0)

    def get_repr(self, x, mode="cnn_fet"):
        """Return an intermediate representation of ``x``.

        ``mode="cnn_fet"`` gives the flattened convolutional features;
        ``mode="last"`` gives the activations fed into ``fc3``. Any other
        mode raises ``ValueError``.
        """
        if mode not in ("cnn_fet", "last"):
            raise ValueError()

        hidden = self.feature_extractor(x).view(-1, 64 * 4 * 4)
        if mode == "cnn_fet":
            return hidden

        # Walk the head up to, but not including, the output layer.
        for layer_name in ("fc1", "relu1", "drop", "fc2", "relu2"):
            hidden = getattr(self.classifier, layer_name)(hidden)
        return hidden

    def forward(self, x):
        """Return class logits for a batch of images."""
        flat = self.feature_extractor(x).view(-1, 64 * 4 * 4)
        return self.classifier(flat)

class CNN003(nn.Module):
    """Minimal CNN: one 3x3 convolution + ReLU feeding a single linear classifier.

    Cut-down variant of the small CNN from
    https://github.com/yaodongyu/TRADES/blob/e20f7b9b99c79ed3cf0d1bb12a47c229ebcac24a/models/small_cnn.py#L5

    Args:
        n_features: Accepted but unused.
        n_classes: Number of output logits.
        drop: Accepted but unused; this model has no dropout layer.
        n_channels: Number of input image channels.
        save_intermediates: Stored on the instance; not otherwise read here.
        input_size: Spatial size of the input images, either an int (square
            images) or a ``(height, width)`` pair. Defaults to 28, the size
            the hard-coded ``64 * 4 * 4`` in ``CNN002`` corresponds to.

    Raises:
        ValueError: If ``input_size`` is too small for a 3x3 convolution.
    """

    def __init__(self, n_features, n_classes, drop=0.5, n_channels=1,
                 save_intermediates=False, input_size=28):
        # Must name this class; ``super(CNN002, self)`` raised TypeError.
        super(CNN003, self).__init__()

        self.num_channels = n_channels

        self.save_intermediates = save_intermediates
        self.intermediates = []

        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        if height < 3 or width < 3:
            raise ValueError(
                "input_size must be at least 3 in each dimension, got %r" % (input_size,)
            )
        # An unpadded, stride-1 3x3 convolution trims two pixels per spatial
        # dimension, and the extractor has 32 output channels and no pooling.
        self.feature_dim = 32 * (height - 2) * (width - 2)

        self.feature_extractor = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(self.num_channels, 32, 3)),
            ('relu1', nn.ReLU(True)),
        ]))

        self.classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(self.feature_dim, n_classes)),
        ]))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialise the output layer. Here that is ``fc1``; the previous
        # reference to ``fc3`` raised AttributeError.
        nn.init.constant_(self.classifier.fc1.weight, 0)
        nn.init.constant_(self.classifier.fc1.bias, 0)

    def _flat_features(self, x):
        """Run the convolutional extractor and flatten all non-batch dimensions."""
        return torch.flatten(self.feature_extractor(x), 1)

    def get_repr(self, x, mode="cnn_fet"):
        """Return an intermediate representation of ``x``.

        ``mode="cnn_fet"`` gives the flattened convolutional features.
        ``mode="last"`` gives the input of the output layer; because the
        classifier is a single linear layer, that is the same tensor.
        Any other mode raises ``ValueError``.
        """
        if mode in ("cnn_fet", "last"):
            return self._flat_features(x)
        raise ValueError("unknown mode: %r" % (mode,))

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.classifier(self._flat_features(x))

class SimpleCNN001(nn.Module):
    """One 3x3 convolution with ReLU, flattened straight into a linear classifier.

    ``conv1`` is hard-coded to one input channel and ``fc1`` to 128 input
    features, so the convolution output must flatten to 128 values per sample
    (32 channels x 4 spatial positions, e.g. a 4x4 input).
    ``n_features`` and ``n_channels`` are accepted but unused;
    ``save_intermediates`` is only stored.
    """

    def __init__(self, n_features, n_classes, n_channels=None, save_intermediates=False):
        super(SimpleCNN001, self).__init__()
        self.save_intermediates = save_intermediates
        self.intermediates = []

        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.fc1 = nn.Linear(128, n_classes)

    def forward(self, x):
        """Return class logits for a batch of single-channel images."""
        activated = F.relu(self.conv1(x))
        flat = torch.flatten(activated, 1)
        return self.fc1(flat)
