import torch
import torchvision
from torch import nn


# Number of output classes for each supported dataset name; the model
# builders below index this by ``args.dataset_name`` to size the final
# linear layer.
CLASS_NUM = dict(
    flickr=8,
    twitter=8,
    fbp5500=5,
    raf=6,
    emotion6=7,
)


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Adapted from ``torchvision.ops.misc.FrozenBatchNorm2d`` with ``eps`` added
    before ``rsqrt``. Without it, models other than
    torchvision.models.resnet[18,34,50,101] produce NaNs.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered
    as buffers rather than parameters. They round-trip through the state dict
    but are not returned by ``parameters()``, so an optimizer never updates
    them.

    Args:
        n: number of channels (the size of dim 1 of the input).
        eps: value added to ``running_var`` before ``rsqrt`` for numerical
            stability. Defaults to ``1e-5``, the value previously hard-coded
            in ``forward``.
    """

    def __init__(self, n, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # State dicts saved from a regular nn.BatchNorm2d carry a
        # ``num_batches_tracked`` entry. This module registers no such
        # buffer, so drop the entry to keep strict loading from reporting it
        # as an unexpected key.
        state_dict.pop(prefix + 'num_batches_tracked', None)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        """Return ``(x - running_mean) / sqrt(running_var + eps) * weight + bias`` per channel."""
        # Reshape the per-channel vectors to (1, C, 1, 1) up front so they
        # broadcast along dim 1 of the input. Doing all reshapes first keeps
        # the arithmetic below fuser-friendly.
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        # Fold normalisation and affine transform into one scale and one
        # shift: (x - rm) * scale + b == x * scale + (b - rm * scale).
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

    def extra_repr(self):
        # Shown inside ``repr(module)``, e.g. ``FrozenBatchNorm2d(64, eps=1e-05)``.
        return f"{self.weight.shape[0]}, eps={self.eps}"
    

def build_model(args, other=False):
    """Build an ImageNet-pretrained torchvision backbone with a fresh classifier head.

    The backbone is ``getattr(torchvision.models, args.backbone)``. It is
    constructed with :class:`FrozenBatchNorm2d` as its normalisation layer and
    ResNet-50 ImageNet weights. Its ``fc`` layer is then replaced by a new
    ``nn.Linear`` sized for ``CLASS_NUM[args.dataset_name]``.

    Args:
        args: namespace-like object. ``args.backbone`` names a
            ``torchvision.models`` builder and ``args.dataset_name`` must be
            a key of ``CLASS_NUM``.
        other: if True, use the ``IMAGENET1K_V1`` ResNet-50 weights instead
            of the default ``IMAGENET1K_V2`` weights.

    Returns:
        The backbone module with its ``fc`` layer replaced.

    Raises:
        KeyError: if ``args.dataset_name`` is not a key of ``CLASS_NUM``.
        AttributeError: if ``args.backbone`` is not an attribute of
            ``torchvision.models``.
    """
    # Resolve the class count first so an unknown dataset fails before any
    # pretrained weights are loaded.
    num_classes = CLASS_NUM[args.dataset_name]

    # Choose the weights before building, so the backbone is constructed (and
    # its pretrained weights loaded) exactly once. Previously a V2 model was
    # built and then discarded whenever ``other`` was True.
    weights = (torchvision.models.ResNet50_Weights.IMAGENET1K_V1 if other
               else torchvision.models.ResNet50_Weights.IMAGENET1K_V2)
    # NOTE(review): this weights enum is ResNet-50 specific. Builders for other
    # architectures presumably reject it, which would effectively limit
    # ``args.backbone`` to "resnet50". Confirm before using other backbones.
    model = getattr(torchvision.models, args.backbone)(
        weights=weights,
        norm_layer=FrozenBatchNorm2d,
    )
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def build_linear_model(args, feature_dim):
    """Return a single ``nn.Linear`` classifier for the configured dataset.

    Args:
        args: namespace-like object; ``args.dataset_name`` selects the output
            size from ``CLASS_NUM``.
        feature_dim: size of each input feature vector (``in_features``).
    """
    num_classes = CLASS_NUM[args.dataset_name]
    return nn.Linear(in_features=feature_dim, out_features=num_classes)

