import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models  # https://github.com/pytorch/hub/issues/46
from transformers import BertModel, DistilBertModel, GPT2Model, AutoModel
from subpopbench.models import wide_resnet
import timm
from torch.hub import load_state_dict_from_url


class Identity(nn.Module):
    """Pass-through module whose ``forward`` returns its argument unchanged.

    Used in this file to replace a backbone's classification layer.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No copy is made: the very same object is handed back.
        return x


class MLP(nn.Module):
    """Fully connected network: Linear -> Dropout -> ReLU blocks, then a Linear head.

    Args:
        n_inputs: size of the last input dimension.
        n_outputs: size of the output (also stored as ``self.n_outputs``).
        hparams: must provide ``mlp_width`` (hidden size), ``mlp_dropout``
            (dropout probability) and ``mlp_depth``. For ``mlp_depth >= 2`` the
            network has ``mlp_depth`` linear layers in total: input layer,
            ``mlp_depth - 2`` hidden width-to-width layers, and output layer.
            Smaller depths yield no hidden layers.
    """

    def __init__(self, n_inputs, n_outputs, hparams):
        super().__init__()
        width = hparams["mlp_width"]
        n_hidden = hparams["mlp_depth"] - 2
        self.input = nn.Linear(n_inputs, width)
        self.dropout = nn.Dropout(hparams["mlp_dropout"])
        self.hiddens = nn.ModuleList(nn.Linear(width, width) for _ in range(n_hidden))
        self.output = nn.Linear(width, n_outputs)
        self.n_outputs = n_outputs

    def forward(self, x):
        # Every layer except the output head is followed by dropout then ReLU.
        for layer in (self.input, *self.hiddens):
            x = F.relu(self.dropout(layer(x)))
        return self.output(x)


class PretrainedImageModel(torch.nn.Module):
    """Base class for image featurizers wrapping a backbone in ``self.network``.

    Subclasses are expected to set ``self.network`` (the backbone) and
    ``self.dropout`` (applied to the backbone's output) in their constructor.
    """

    def forward(self, x):
        """Encode x into a feature vector of size n_outputs."""
        return self.dropout(self.network(x))

    def train(self, mode=True):
        """Set training mode while keeping every BatchNorm2d layer in eval mode.

        Returns ``self``, matching :meth:`torch.nn.Module.train`, so calls like
        ``model.train()`` and ``model.eval()`` can be chained.
        """
        super().train(mode)
        # The base call flips every submodule; re-apply eval mode to BN layers.
        self.freeze_bn()
        return self

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` inside ``self.network`` into eval mode.

        Eval mode makes BatchNorm normalise with its running statistics and stop
        updating them; ``requires_grad`` of its affine parameters is untouched.
        """
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()


class ResNet(PretrainedImageModel):
    """torchvision ResNet-18 or ResNet-50 backbone with its classifier removed.

    Args:
        input_shape: input shape; only ``input_shape[0]`` (channel count) is read.
        hparams: must provide ``resnet18`` (truthy selects ResNet-18, otherwise
            ResNet-50) and ``last_layer_dropout``.
        pretrained: forwarded to the torchvision constructor.
        freeze_bn: when True, BatchNorm2d layers are put in eval mode right away;
            when False, ``last_layer_dropout`` must be 0.0 (checked by ``assert``).
    """

    def __init__(self, input_shape, hparams, pretrained=True, freeze_bn=False):
        super().__init__()

        if hparams["resnet18"]:
            build_backbone, feature_dim = torchvision.models.resnet18, 512
        else:
            build_backbone, feature_dim = torchvision.models.resnet50, 2048
        self.network = build_backbone(pretrained=pretrained)
        self.n_outputs = feature_dim

        n_channels = input_shape[0]
        if n_channels != 3:
            # Rebuild the stem conv for n_channels inputs, initialising input
            # channel c from the original RGB filter c % 3.
            rgb_filters = self.network.conv1.weight.data.clone()
            self.network.conv1 = nn.Conv2d(
                n_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
            )
            source_channels = [c % 3 for c in range(n_channels)]
            self.network.conv1.weight.data[:] = rgb_filters[:, source_channels, :, :]

        # Replace the classification layer with a pass-through (freeing its
        # weights) so the network emits n_outputs-dimensional features.
        del self.network.fc
        self.network.fc = Identity()
        self.hparams = hparams
        self.dropout = nn.Dropout(hparams["last_layer_dropout"])

        if not freeze_bn:
            assert hparams["last_layer_dropout"] == 0.0
        else:
            self.freeze_bn()


class TimmModel(PretrainedImageModel):
    """Backbone built by ``timm.create_model(name, num_classes=0)``.

    ``num_classes=0`` requests the model without a classifier head, and the
    feature size is read from ``network.num_features``. ``input_shape`` is
    accepted for signature parity with the other featurizers but is unused.
    When ``freeze_bn`` is False, ``last_layer_dropout`` must be 0.0.
    """

    def __init__(self, name, input_shape, hparams, pretrained=True, freeze_bn=False):
        super().__init__()

        backbone = timm.create_model(name, pretrained=pretrained, num_classes=0)
        self.network = backbone
        self.n_outputs = backbone.num_features
        self.hparams = hparams
        drop_p = hparams["last_layer_dropout"]
        self.dropout = nn.Dropout(drop_p)

        if not freeze_bn:
            assert drop_p == 0.0
        else:
            self.freeze_bn()


class HubModel(PretrainedImageModel):
    """Backbone obtained via ``torch.hub.load(name1, name2, force_reload=True)``.

    ``force_reload=True`` asks torch.hub for a fresh download of the hub repo on
    every construction. The feature size is ``network.num_features`` when the
    backbone has that attribute, otherwise 2048 is assumed. ``input_shape`` and
    ``pretrained`` are accepted but unused. When ``freeze_bn`` is False,
    ``last_layer_dropout`` must be 0.0.
    """

    def __init__(
        self, name1, name2, input_shape, hparams, pretrained=True, freeze_bn=False
    ):
        super().__init__()

        self.network = torch.hub.load(name1, name2, force_reload=True)
        self.n_outputs = getattr(self.network, "num_features", 2048)
        self.hparams = hparams
        drop_p = hparams["last_layer_dropout"]
        self.dropout = nn.Dropout(drop_p)

        if not freeze_bn:
            assert drop_p == 0.0
        else:
            self.freeze_bn()


class ImportedModel(PretrainedImageModel):
    """Wrap an already-built backbone whose feature size the caller supplies.

    ``input_shape`` and ``pretrained`` are accepted for signature parity with the
    other featurizers but are unused. When ``freeze_bn`` is False,
    ``last_layer_dropout`` must be 0.0.
    """

    def __init__(
        self, network, n_outputs, input_shape, hparams, pretrained=True, freeze_bn=False
    ):
        super().__init__()

        self.network, self.n_outputs, self.hparams = network, n_outputs, hparams
        drop_p = hparams["last_layer_dropout"]
        self.dropout = nn.Dropout(drop_p)

        if not freeze_bn:
            assert drop_p == 0.0
        else:
            self.freeze_bn()


class MNIST_CNN(nn.Module):
    """Four-stage conv net mapping an image batch to 128-d feature vectors.

    Each stage is conv -> ReLU -> GroupNorm(8 groups); the second conv uses
    stride 2. The last feature map is adaptively average-pooled to 1x1 and
    flattened to (batch, 128).
    """

    n_outputs = 128

    def __init__(self, input_shape):
        super().__init__()
        in_channels = input_shape[0]
        self.conv1 = nn.Conv2d(in_channels, 64, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 128, 3, 1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, 1, padding=1)
        # Named "bn*" for historical reasons; these are GroupNorm layers.
        self.bn0 = nn.GroupNorm(8, 64)
        self.bn1 = nn.GroupNorm(8, 128)
        self.bn2 = nn.GroupNorm(8, 128)
        self.bn3 = nn.GroupNorm(8, 128)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        stages = (
            (self.conv1, self.bn0),
            (self.conv2, self.bn1),
            (self.conv3, self.bn2),
            (self.conv4, self.bn3),
        )
        for conv, norm in stages:
            x = norm(F.relu(conv(x)))
        pooled = self.avgpool(x)
        return pooled.view(len(pooled), -1)


class MNIST_MLP(nn.Module):
    """Three-layer MLP for two-channel 14x14 inputs (taken from the IRM paper).

    Args:
        grayscale: if truthy, the two input channels are summed into a single
            14*14 vector before the MLP; otherwise both channels are flattened
            into a 2*14*14 vector.

    The network emits one value per example, hence ``n_outputs == 1``.
    """

    n_outputs = 1

    def __init__(self, grayscale=False):
        # Must name this class: super(MLP, self) raises TypeError because
        # MNIST_MLP is not a subclass of MLP.
        super().__init__()
        self.grayscale = grayscale
        in_features = 14 * 14 if self.grayscale else 2 * 14 * 14
        lin1 = nn.Linear(in_features, 256)
        lin2 = nn.Linear(256, 256)
        lin3 = nn.Linear(256, 1)
        for lin in (lin1, lin2, lin3):
            nn.init.xavier_uniform_(lin.weight)
            nn.init.zeros_(lin.bias)
        self._main = nn.Sequential(lin1, nn.ReLU(True), lin2, nn.ReLU(True), lin3)

    def forward(self, input):
        """Map a batch with 2*14*14 values per example to shape (batch, 1)."""
        batch = input.shape[0]
        if self.grayscale:
            # Collapse the two channels by summing them pixel-wise.
            out = input.reshape(batch, 2, 14 * 14).sum(dim=1)
        else:
            out = input.reshape(batch, 2 * 14 * 14)
        # Both variants go through the MLP.
        return self._main(out)


class BertFeatureWrapper(torch.nn.Module):
    """Wrap a Hugging Face transformer so it yields one feature vector per sequence.

    The sequence vector is the model's ``pooler_output`` when that is present and
    not None, otherwise the last hidden state of the first token; dropout is
    applied to it. ``n_outputs`` is the model config's ``hidden_size``.

    Dropout probability is ``hparams["last_layer_dropout"]`` when non-zero;
    otherwise it falls back to the model config's own dropout setting (see
    ``_DROPOUT_CONFIG_KEYS``), or 0.0 if none of those attributes exist.
    """

    # Config attributes tried, in order, for the fallback dropout probability.
    # BERT-style configs use ``hidden_dropout_prob``; DistilBertConfig names it
    # ``dropout`` and GPT2Config ``resid_pdrop``.
    _DROPOUT_CONFIG_KEYS = ("hidden_dropout_prob", "dropout", "resid_pdrop")

    def __init__(self, model, hparams):
        super().__init__()
        self.model = model
        self.n_outputs = model.config.hidden_size
        classifier_dropout = hparams["last_layer_dropout"]
        if classifier_dropout == 0.0:
            classifier_dropout = self._config_dropout(model.config)
        self.dropout = nn.Dropout(classifier_dropout)

    @classmethod
    def _config_dropout(cls, config):
        """Return the first dropout probability found on ``config``, else 0.0."""
        for key in cls._DROPOUT_CONFIG_KEYS:
            value = getattr(config, key, None)
            if value is not None:
                return value
        return 0.0

    def forward(self, x):
        # x is indexed as (batch, seq, channel): channel 0 holds input ids,
        # channel 1 the attention mask, optional channel 2 the token type ids.
        kwargs = {"input_ids": x[:, :, 0], "attention_mask": x[:, :, 1]}
        if x.shape[-1] == 3:
            kwargs["token_type_ids"] = x[:, :, 2]
        output = self.model(**kwargs)
        pooled = getattr(output, "pooler_output", None)
        if pooled is None:
            pooled = output.last_hidden_state[:, 0, :]
        return self.dropout(pooled)


def replace_module_prefix(state_dict, prefix, replace_with=""):
    """Return a new dict with ``prefix`` swapped for ``replace_with`` at key starts.

    Only keys that start with ``prefix`` are renamed, and only that leading
    occurrence changes; other keys are copied through unchanged. The input
    mapping is not modified.
    """
    renamed = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = replace_with + key[len(prefix):]
        renamed[key] = value
    return renamed


def get_torchvision_state_dict(url):
    """Download the checkpoint at ``url`` and return its trunk weights.

    If the checkpoint contains a ``classy_state_dict`` entry, the weights are
    read from ``classy_state_dict -> base_model -> model -> trunk``; otherwise
    the checkpoint itself is treated as the state dict. A leading
    ``_feature_blocks.`` is stripped from every key that has it.
    """
    # map_location="cpu": without it, tensors saved from a GPU are restored onto
    # CUDA and loading fails on CPU-only hosts. The caller's load_state_dict
    # copies the values into the model's own parameters anyway.
    checkpoint = load_state_dict_from_url(url, map_location="cpu")
    if "classy_state_dict" in checkpoint:
        trunk = checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"]
    else:
        trunk = checkpoint
    return replace_module_prefix(trunk, "_feature_blocks.")


def imagenet_resnet50_ssl(URL):
    """Build an untrained torchvision ResNet-50 and load the weights at ``URL``.

    The classifier is replaced by ``torch.nn.Identity`` before loading, so the
    strict ``load_state_dict`` call expects a checkpoint without ``fc.*`` keys.
    The returned module has ``n_outputs == 2048``.
    """
    trunk = torchvision.models.resnet50(pretrained=False)
    trunk.fc = torch.nn.Identity()
    weights = get_torchvision_state_dict(URL)
    trunk.load_state_dict(weights)
    # Also record the feature width on the head; presumably for code that
    # reads ``fc.in_features`` — confirm against callers.
    trunk.fc.in_features = 2048
    trunk.n_outputs = 2048
    return trunk


def load_swag(URL):
    """Build a torchvision ViT-B/16 and load the SWAG checkpoint at ``URL``.

    Checkpoint keys are adapted to torchvision's naming: the first ``layer_`` in
    a key becomes ``encoder_layer_``, and ``encoder.pos_embedding`` is permuted
    with (1, 0, 2). The classification heads are replaced by
    ``torch.nn.Identity`` and the returned module has ``n_outputs == 768``.
    """
    m = torchvision.models.vit_b_16(pretrained=False)
    m.heads = torch.nn.Identity()
    # map_location="cpu" so GPU-saved checkpoints also load on CPU-only hosts;
    # load_state_dict then copies the values into the model's parameters.
    state_dict = load_state_dict_from_url(URL, map_location="cpu")
    renamed = {}
    for key, val in state_dict.items():
        if "layer_" in key:
            key = key.replace("layer_", "encoder_layer_", 1)
        if key == "encoder.pos_embedding":
            val = val.permute((1, 0, 2))
        renamed[key] = val
    m.load_state_dict(renamed)
    m.n_outputs = 768
    return m


# Self-supervised ResNet-50 checkpoints hosted in the VISSL model zoo.
_VISSL_MODEL_ZOO = "https://dl.fbaipublicfiles.com/vissl/model_zoo/"
SIMCLR_RN50_URL = _VISSL_MODEL_ZOO + (
    "simclr_rn50_800ep_simclr_8node_resnet_16_07_20.7e8feed1/"
    "model_final_checkpoint_phase799.torch"
)
BARLOWTWINS_RN50_URL = _VISSL_MODEL_ZOO + (
    "barlow_twins/barlow_twins_32gpus_4node_imagenet1k_1000ep_resnet50.torch"
)


# timm model names for the ``image_arch`` values built with TimmModel.
_TIMM_IMAGE_ARCHS = {
    "resnet_sup_in21k": "tresnet_m_miil_in21k",  # https://github.com/Alibaba-MIIL/ImageNet21K
    "vit_sup_in1k": "vit_base_patch32_224.augreg_in1k",  # https://arxiv.org/abs/2106.10270
    "vit_sup_in21k": "vit_base_patch32_224.augreg_in21k",
    "vit_clip_oai": "vit_base_patch32_clip_224.openai",
    "vit_clip_laion": "vit_base_patch32_clip_224.laion2b",
    "vit_dino_in1k": "vit_base_patch16_224.dino",  # https://github.com/facebookresearch/dino
}


def _pretrained_224_featurizer(input_shape, hparams):
    """Build the 224x224 image backbone named by ``hparams["image_arch"]``.

    For the ImportedModel-based archs, ``hparams["pretrained"]`` is forwarded but
    ImportedModel does not read it: those checkpoints are always downloaded.

    Raises:
        NotImplementedError: if ``image_arch`` is not one of the known names.
    """
    arch = hparams["image_arch"]
    if arch == "resnet_sup_in1k":
        return ResNet(input_shape, hparams, hparams["pretrained"])
    if arch in _TIMM_IMAGE_ARCHS:
        return TimmModel(
            _TIMM_IMAGE_ARCHS[arch], input_shape, hparams, hparams["pretrained"]
        )
    if arch == "resnet_dino_in1k":
        network = imagenet_resnet50_ssl(
            "https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
        )
        return ImportedModel(network, 2048, input_shape, hparams, hparams["pretrained"])
    if arch == "vit_sup_swag":
        # https://github.com/facebookresearch/SWAG
        network = load_swag("https://dl.fbaipublicfiles.com/SWAG/vit_b16.torch")
        return ImportedModel(network, 768, input_shape, hparams, hparams["pretrained"])
    if arch in ("resnet_barlow_in1k", "resnet_simclr_in1k"):
        url = SIMCLR_RN50_URL if arch == "resnet_simclr_in1k" else BARLOWTWINS_RN50_URL
        network = imagenet_resnet50_ssl(url)
        return ImportedModel(network, 2048, input_shape, hparams, hparams["pretrained"])
    # Fail loudly instead of silently returning None for an unknown arch.
    raise NotImplementedError(f"image_arch {arch!r} not supported for 224x224 inputs.")


def _image_featurizer(input_shape, hparams):
    """Pick an image featurizer from the rank and spatial size of ``input_shape``."""
    if len(input_shape) == 1:
        return MLP(input_shape[0], hparams["mlp_width"], hparams)
    # tuple() so list-valued shapes compare equal to the tuple literals below.
    spatial = tuple(input_shape[1:3])
    if spatial == (14, 14):
        # NOTE(review): input_shape is passed as MNIST_MLP's ``grayscale`` flag;
        # a non-empty shape is truthy, so the grayscale variant is built —
        # confirm this is intended.
        return MNIST_MLP(input_shape)
    if spatial == (28, 28):
        return MNIST_CNN(input_shape)
    if spatial == (32, 32):
        return wide_resnet.WideResNet(input_shape, 16, 2, 0.0)
    if spatial == (224, 224):
        return _pretrained_224_featurizer(input_shape, hparams)
    raise NotImplementedError(f"Image input shape {tuple(input_shape)} not supported.")


def _text_featurizer(hparams):
    """Load the transformer named by ``hparams["text_arch"]`` and wrap it."""
    arch = hparams["text_arch"]
    if arch == "bert-base-uncased":
        text_model = BertModel.from_pretrained(arch)
    elif arch in ("xlm-roberta-base", "allenai/scibert_scivocab_uncased"):
        text_model = AutoModel.from_pretrained(arch)
    elif arch == "gpt2":
        text_model = GPT2Model.from_pretrained("gpt2")
    elif arch == "distilbert-base-uncased":
        text_model = DistilBertModel.from_pretrained("distilbert-base-uncased")
    else:
        raise NotImplementedError(f"text_arch {arch!r} not supported.")
    return BertFeatureWrapper(text_model, hparams)


def Featurizer(data_type, input_shape, hparams):
    """Auto-select an appropriate featurizer for the given data type & input shape.

    Args:
        data_type: one of ``"images"``, ``"text"`` or ``"tabular"``.
        input_shape: shape of one example. For images it is
            ``(n_inputs,)`` or ``(channels, height, width)``; for tabular data
            ``input_shape[0]`` is the feature count; unused for text.
        hparams: hyper-parameter dict. Keys read depend on the branch, e.g.
            ``image_arch``, ``pretrained``, ``text_arch``, ``mlp_width``.

    Returns:
        The featurizer module.

    Raises:
        NotImplementedError: for an unsupported data type, image shape,
            ``image_arch`` or ``text_arch``.
    """
    if data_type == "images":
        return _image_featurizer(input_shape, hparams)
    if data_type == "text":
        return _text_featurizer(hparams)
    if data_type == "tabular":
        return MLP(input_shape[0], hparams["mlp_width"], hparams)
    raise NotImplementedError(f"{data_type} not supported.")


def Classifier(in_features, out_features, is_nonlinear=False):
    """Build the prediction head on top of the featurizer.

    Returns a single ``Linear(in_features, out_features)`` by default. With
    ``is_nonlinear`` it returns a three-layer MLP whose hidden widths are
    ``in_features // 2`` and ``in_features // 4``, with ReLUs in between.
    """
    if not is_nonlinear:
        return torch.nn.Linear(in_features, out_features)
    half, quarter = in_features // 2, in_features // 4
    return torch.nn.Sequential(
        torch.nn.Linear(in_features, half),
        torch.nn.ReLU(),
        torch.nn.Linear(half, quarter),
        torch.nn.ReLU(),
        torch.nn.Linear(quarter, out_features),
    )
