from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

from .toy_models import *
from .resnet import *
from .resnet_leaky import ResNet18, ResNet18Features
from .classifier_head import ImageClassifierDANN, ImageClassifierMDD, ImageClassifier
from path_learning.utils.log import get_logger

# Offset used by `reset_bias`: each selected bias tensor is filled with
# max(bias) + BIAS_VALUE. `pick_model` overwrites this global on every call
# from its *_bias_reset options (None when none of them is given, in which
# case `pick_model` skips the bias reset).
BIAS_VALUE = None

# Module-wide logger used by `pick_model`.
logger = get_logger("models")


class BasicCNN(nn.Module):
    """LeNet-style CNN for single-channel 32x32 images; returns log-probabilities.

    Two (5x5 conv -> ReLU -> 2x2 max-pool) stages followed by two fully connected
    layers. For 32x32 inputs the last feature map is 50 x 5 x 5, which is the
    size ``fc1`` expects. Other spatial sizes (e.g. 28x28 -> 50 x 4 x 4) are not
    supported and raise a shape error in ``fc1``.

    Args:
        noutputs: number of output classes.
    """

    def __init__(self, noutputs=10):
        super(BasicCNN, self).__init__()
        self.name = "CNN"
        self.nfeatures = 500
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(5 * 5 * 50, self.nfeatures)
        self.fc2 = nn.Linear(self.nfeatures, noutputs)

    def forward(self, x):
        """Map ``(batch, 1, 32, 32)`` (or unbatched ``(1, 32, 32)``) images to
        ``(batch, noutputs)`` log-probabilities."""
        if x.dim() == 3:
            # Unbatched (C, H, W) input: treat it as a batch of one, so the
            # output is still 2-D and log_softmax(dim=1) is valid.
            x = x.unsqueeze(0)
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        # Flatten per sample, keeping the batch dimension. A spatial-size
        # mismatch then fails loudly in fc1 instead of being silently folded
        # into the batch dimension by a view(-1, ...).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


# FCN described in the MNIST experiments of https://arxiv.org/abs/1711.08856
# (CLP = Critical Learning Periods)
class CLPFCN(nn.Module):
    """Deep fully connected classifier with batch norm; returns log-probabilities.

    Five hidden blocks of Linear -> BatchNorm1d -> ReLU with widths
    1024 -> 2500 -> 2000 -> 1500 -> 1000 -> 500, then a 500 -> 10 output layer.
    Each sample is reshaped to 1024 values (e.g. a flattened 32x32
    single-channel image — confirm against the data pipeline).
    """

    def __init__(self):
        super().__init__()
        self.name = "CLPFCN"
        widths = [1024, 2500, 2000, 1500, 1000, 500]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            # Same per-block order as a hand-written Sequential, so the
            # state_dict keys are net.0 ... net.14.
            blocks.extend([
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(num_features=fan_out),
                nn.ReLU(),
            ])
        self.net = nn.Sequential(*blocks)
        self.fc = nn.Linear(500, 10)

    def forward(self, x):
        features = self.net(x.view(-1, 1024))
        return F.log_softmax(self.fc(features), dim=1)


class BasicFCN(nn.Module):
    """Two-hidden-layer MLP (784 -> 500 -> 500 -> 10) returning log-probabilities.

    Each sample is reshaped to 784 values (e.g. a flattened 28x28 image).
    """

    def __init__(self):
        super().__init__()
        self.name = "FCN"
        # Right-hand side is evaluated left to right, so layers are created
        # (and registered) in the order fc1, fc2, fc3.
        self.fc1, self.fc2, self.fc3 = (
            nn.Linear(784, 500),
            nn.Linear(500, 500),
            nn.Linear(500, 10),
        )

    def forward(self, x):
        hidden = x.view(-1, 784)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1)


class LinearFCN(nn.Module):
    """Single linear layer (784 -> 10) followed by log-softmax.

    This is multinomial logistic regression on inputs reshaped to 784 values
    per sample (e.g. flattened 28x28 images).
    """

    def __init__(self):
        super().__init__()
        self.name = "LinearFCN"
        self.fc1 = nn.Linear(784, 10)

    def forward(self, x):
        flat = x.view(-1, 784)
        logits = self.fc1(flat)
        return F.log_softmax(logits, dim=1)


class AllConvNet(nn.Module):
    """All-convolutional classifier with batch norm; returns per-class scores.

    The network has eight conv -> BatchNorm2d -> ReLU stages; conv3 and conv6
    use stride 2 to downsample. These are followed by a 1x1 class convolution
    with a ReLU and then global average pooling. The layout appears to follow
    All-CNN-C (Springenberg et al., 2014) with batch norm added
    (NOTE(review): confirm).

    The output has shape ``(batch, n_classes)`` and holds non-negative scores,
    not log-probabilities.

    Args:
        input_size: number of input channels, despite the name (used as conv1's
            in_channels).
        n_classes: number of output classes.
        **kwargs: accepted and ignored.
    """

    def __init__(self, input_size=3, n_classes=10, **kwargs):
        super().__init__()
        # Stored by `set_noise_level`; not read anywhere in this class.
        self.noise_level = None

        # (in_channels, out_channels, kernel_size, stride, padding) for conv1..conv8.
        stage_specs = [
            (input_size, 96, 3, 1, 1),
            (96, 96, 3, 1, 1),
            (96, 192, 3, 2, 1),   # spatial downsample x2
            (192, 192, 3, 1, 1),
            (192, 192, 3, 1, 1),
            (192, 192, 3, 2, 1),  # spatial downsample x2
            (192, 192, 3, 1, 1),
            (192, 192, 1, 1, 0),  # 1x1 conv
        ]
        # Register conv{i} then bn{i} for each stage, interleaved, so checkpoint
        # keys read conv1.*, bn1.*, conv2.*, bn2.*, ...
        for idx, (c_in, c_out, k, s, p) in enumerate(stage_specs, start=1):
            setattr(self, f"conv{idx}", nn.Conv2d(c_in, c_out, k, stride=s, padding=p))
            setattr(self, f"bn{idx}", nn.BatchNorm2d(c_out))

        self.class_conv = nn.Conv2d(192, n_classes, 1)

    def set_noise_level(self, noise_level: float):
        """Store ``noise_level`` on the instance (forward does not use it)."""
        self.noise_level = noise_level

    def forward(self, x):
        h = x
        for idx in range(1, 9):
            conv = getattr(self, f"conv{idx}")
            bn = getattr(self, f"bn{idx}")
            h = F.relu(bn(conv(h)))
        scores = F.relu(self.class_conv(h))
        # Global average pool to 1x1, then drop the two trailing spatial dims.
        return F.adaptive_avg_pool2d(scores, 1).squeeze(-1).squeeze(-1)


class TextSentiment(nn.Module):
    """Bag-of-embeddings text classifier: an EmbeddingBag followed by one linear layer.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding width.
        num_class: number of output classes.
    """

    def __init__(self, vocab_size, embed_dim=32, num_class=4):
        super().__init__()
        # sparse=True: the embedding table gets sparse gradients, so the
        # optimizer must support them.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Draw embedding and fc weights from U(-0.5, 0.5) and zero the fc bias."""
        bound = 0.5
        nn.init.uniform_(self.embedding.weight, -bound, bound)
        nn.init.uniform_(self.fc.weight, -bound, bound)
        nn.init.zeros_(self.fc.bias)

    def forward(self, text_offsets):
        """Score a batch given as a ``(text, offsets)`` pair.

        Following EmbeddingBag's convention when offsets are passed, ``text``
        holds the concatenated token ids and ``offsets`` the start index of each
        sample. Returns unnormalised scores of shape ``(num_samples, num_class)``.
        """
        text, offsets = text_offsets
        pooled = self.embedding(text, offsets)
        return self.fc(pooled)


def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    ``pick_model`` calls this before replacing a classifier head. A head
    created afterwards is therefore still trainable. When
    ``feature_extracting`` is falsy the model is left unchanged.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad_(False)


def pick_model(**kwargs):
    """Build and return the model selected by ``kwargs["name"]``.

    Keys read here:
        name (required): architecture identifier (see the branches below).
        n_outputs (required): number of output classes.
        input_size (default 32): forwarded as ``input_size`` to the local
            ResNet18 / ResNet18Features. For ``text_sentiment`` it is used as
            the vocabulary size.
        no_of_features (default 2), no_of_neurons (default 1000): sizes for the
            toy networks.
        feature_extract (default False): freeze the torchvision model's weights
            before its classifier head is replaced.
        slope (default 0.0): negative slope for the toy / leaky-ReLU variants.
            For torchvision ResNets, ReLUs are swapped for LeakyReLU only if > 0.
        use_pretrained (default False): request torchvision pretrained weights.
        non_standard_pretrained_path (optional): checkpoint loaded into
            ``resnet34`` before its head is replaced.
        width, bottleneck_dim (default 1024 each): head sizes for the
            backbone-based models.
        adaptation (optional): ``"mdd"`` or ``"dann"`` head for the
            backbone-based models; anything else gives a plain ImageClassifier.
        use_bias_reset / only_first_bias_reset / wo_first_bias_reset: value
            stored in the module-level ``BIAS_VALUE``. The keys are checked in
            this order and the first one present wins.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: if ``name`` is not a recognised model name.
    """
    model_name = kwargs["name"]
    input_size = kwargs.get("input_size", 32)
    num_classes = kwargs["n_outputs"]
    no_of_features = kwargs.get("no_of_features", 2)
    no_of_neurons = kwargs.get("no_of_neurons", 1000)
    feature_extract = kwargs.get("feature_extract", False)
    negative_slope = kwargs.get("slope", 0.0)
    use_pretrained: bool = kwargs.get("use_pretrained", False)
    non_standard_pretrained_path = kwargs.get("non_standard_pretrained_path", None)

    width: int = kwargs.get("width", 1024)
    bottleneck_dim: int = kwargs.get("bottleneck_dim", 1024)

    # `reset_bias` is handed to `nn.Module.apply`, which passes it only the
    # module, so the offset is communicated through this module-level global.
    global BIAS_VALUE
    BIAS_VALUE = None
    for bias_key in ("use_bias_reset", "only_first_bias_reset", "wo_first_bias_reset"):
        if bias_key in kwargs:
            BIAS_VALUE = kwargs[bias_key]
            break

    if model_name == "clp_mnist_fcn":
        model_ft = CLPFCN()

    elif model_name == "basic_net":
        model_ft = ToyNet(negative_slope)

    elif model_name == "toy_two_linear_net":
        model_ft = ToyTwoLinear(n_input=no_of_features, n_features=no_of_neurons, noutputs=num_classes)

    elif model_name == "linear_reg_net":
        model_ft = LinearReg(negative_slope, nfeatures=no_of_features)

    elif model_name == "two_layer_linear_net":
        model_ft = TwoLayerLinear(negative_slope, nfeatures=no_of_features, nhiddenneurons=no_of_neurons)

    elif model_name == "two_layer_relu_net":
        model_ft = TwoLayerRelu(negative_slope, nfeatures=no_of_features, nhiddenneurons=no_of_neurons)

    elif model_name == "toy_batch_net":
        model_ft = ToyBatchNet(negative_slope)

    elif model_name == "toy_relu_net":
        model_ft = ToyReluNet(negative_slope)

    elif model_name == "toy_res_net":
        model_ft = ToyResNet(negative_slope)

    elif model_name == "toy_batch_relu_net":
        model_ft = ToyBatchReluNet(negative_slope)

    elif model_name == "toy_batch_res_net":
        model_ft = ToyBatchResNet(negative_slope)

    elif model_name == "toy_svm":
        model_ft = ToySVM(negative_slope)

    elif model_name == "toy_svm_poly":
        model_ft = ToySVMpoly(negative_slope)

    elif model_name == "toy_long_batch_relu_net":
        model_ft = ToyLongBatchReluNet(negative_slope)

    elif model_name == "toy_long_batch_res_net":
        model_ft = ToyLongBatchResNet(negative_slope)

    elif model_name == "toy_gaussian_naive":
        model_ft = ToyGaussianNaiveBayes(negative_slope)

    elif model_name == "allconv":
        model_ft = AllConvNet()

    elif model_name == "resnet":
        if use_pretrained:
            # torchvision's ImageNet ResNet18 (224x224 inputs), used because it
            # has pretrained weights. It is NOT the same architecture as the
            # local CIFAR-style ResNet18 below.
            model_ft = models.resnet18(pretrained=use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            num_ftrs = model_ft.fc.in_features
            model_ft.fc = nn.Linear(num_ftrs, num_classes)

            # Replace ReLUs with LeakyReLUs if requested.
            if negative_slope > 0.0:
                replace_relu_with_leaky_relu(model_ft, negative_slope=negative_slope)
        else:
            # Local ResNet18 adjusted for CIFAR10-style small inputs.
            model_ft = ResNet18(input_size=input_size, num_classes=num_classes, negative_slope=negative_slope)

        if BIAS_VALUE is not None and "wo_first_bias_reset" not in kwargs:
            model_ft.apply(reset_bias)

    elif model_name in ("standard-resnet50", "resnet18-backbone"):
        if model_name == "standard-resnet50":
            logger.info(f"Using standard resnet50 with option pretraining {use_pretrained}")
            backbone = resnet50(pretrained=use_pretrained)
        else:  # "resnet18-backbone": the only other name admitted by the elif above
            logger.info("Using Resnet18 backbone - adapted for small image sizes.")
            backbone = ResNet18Features(input_size, num_classes, negative_slope=negative_slope, use_groupnorm=False)
        logger.info("loaded backbone")

        model_adaptation = kwargs.get("adaptation", None)
        if model_adaptation == "mdd":
            model_ft = ImageClassifierMDD(backbone, num_classes, bottleneck_dim=bottleneck_dim,
                                          width=width)
        elif model_adaptation == "dann":
            model_ft = ImageClassifierDANN(backbone, num_classes, bottleneck_dim=bottleneck_dim)
        else:
            model_ft = ImageClassifier(backbone, num_classes, only_predictions=True, bottleneck_dim=bottleneck_dim)
        logger.info("loaded classifier")

    elif model_name == "alexnet":
        # torchvision AlexNet (224x224 inputs); replace the last classifier layer.
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)

    elif model_name == "vgg":
        # torchvision VGG11 with batch norm (224x224 inputs).
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)

    elif model_name == "squeezenet":
        # torchvision SqueezeNet 1.0 (224x224 inputs); its head is a 1x1 conv.
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes

    elif model_name == "densenet":
        # torchvision DenseNet-121 (224x224 inputs).
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)

    elif model_name == "inception":
        # torchvision Inception v3: expects 299x299 images and has an auxiliary
        # output, so both heads are replaced.
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)

    elif model_name == "resnet34":
        # torchvision ResNet34 (ImageNet layout, 224x224 inputs).
        model_ft = models.resnet34(pretrained=use_pretrained)

        if non_standard_pretrained_path is not None:
            # The model has not been moved to any device yet, so load the
            # checkpoint onto the CPU as well. This also lets GPU-saved
            # checkpoints load on CPU-only hosts. torch.load unpickles the file:
            # only point this at trusted checkpoints.
            state_dict = torch.load(str(non_standard_pretrained_path), map_location="cpu")
            model_ft.load_state_dict(state_dict)

        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)

        # Replace ReLUs with LeakyReLUs if requested.
        if negative_slope > 0.0:
            replace_relu_with_leaky_relu(model_ft, negative_slope=negative_slope)

        # NOTE(review): unlike the "resnet" branch, this reset is applied even
        # when `wo_first_bias_reset` is given. Confirm that is intended.
        if BIAS_VALUE is not None:
            model_ft.apply(reset_bias)

    elif model_name == "text_sentiment":
        # `input_size` doubles as the vocabulary size for this model.
        model_ft = TextSentiment(vocab_size=input_size)

    else:
        # Raise rather than exit(): a library call should not terminate the
        # process, and exit() would have reported success (status 0).
        raise ValueError(f"Invalid model name: {model_name!r}")

    return model_ft


def reset_bias(m):
    """Fill every bias entry of ``m`` with ``max(bias) + BIAS_VALUE``.

    This is intended for ``nn.Module.apply``. Only modules whose exact type is
    ``Linear``, ``Conv2d`` or ``BatchNorm2d`` are touched, and subclasses are
    deliberately excluded. Modules without a bias are skipped. Reads the
    module-level ``BIAS_VALUE``, which must not be None when a matching module
    has a bias.
    """
    if type(m) not in (torch.nn.Linear, torch.nn.Conv2d, torch.nn.BatchNorm2d):
        return
    bias = m.bias
    if bias is None:
        return
    bias.data.fill_(bias.data.max() + BIAS_VALUE)


def replace_relu_with_leaky_relu(m: nn.Module, negative_slope: float = 0.0):
    """Recursively swap every ``nn.ReLU`` child of ``m`` for ``nn.LeakyReLU``.

    Works in place and returns None. Non-ReLU children are searched
    recursively. ``m`` itself is never replaced, even if it is a ReLU.

    Args:
        m: module whose descendants are rewritten.
        negative_slope: slope given to each new LeakyReLU.
    """
    children = list(m.named_children())
    relu_names = [name for name, child in children if isinstance(child, nn.ReLU)]

    for _, child in children:
        if not isinstance(child, nn.ReLU):
            replace_relu_with_leaky_relu(child, negative_slope=negative_slope)

    # Swap after the scan so the child registry is never mutated mid-iteration.
    for name in relu_names:
        setattr(m, name, nn.LeakyReLU(negative_slope=negative_slope))
