# This file is modified from official pycls repository

"""Model and loss construction functions."""

from pycls.core.net import SoftCrossEntropyLoss
from pycls.models.resnet import *
from pycls.models.vgg import *
from pycls.models.alexnet import *

import torch
from torch import nn
from torch.nn import functional as F
# Registry of supported model constructors, keyed by the name expected in
# cfg.MODEL.TYPE. register_model() can add entries at runtime.
_models = dict(
    # VGG-style architectures
    vgg11=vgg11,
    vgg11_bn=vgg11_bn,
    vgg13=vgg13,
    vgg13_bn=vgg13_bn,
    vgg16=vgg16,
    vgg16_bn=vgg16_bn,
    vgg19=vgg19,
    vgg19_bn=vgg19_bn,
    # ResNet-style architectures
    resnet18=resnet18,
    resnet34=resnet34,
    resnet50=resnet50,
    resnet101=resnet101,
    resnet152=resnet152,
    resnext50_32x4d=resnext50_32x4d,
    resnext101_32x8d=resnext101_32x8d,
    wide_resnet50_2=wide_resnet50_2,
    wide_resnet101_2=wide_resnet101_2,
    # AlexNet architecture
    alexnet=alexnet,
)

# Registry of supported loss classes, keyed by cfg.MODEL.LOSS_FUN.
# register_loss_fun() can add entries at runtime.
_loss_funs = dict(cross_entropy=SoftCrossEntropyLoss)


class FeaturesNet(nn.Module):
    """Linear classifier, optionally preceded by a two-layer MLP, over feature vectors.

    Args:
        in_layers: width of the input feature vectors (last dimension of ``x``).
        out_layers: number of output logits.
        use_mlp: if True, ``x`` first passes through two ReLU-activated
            ``in_layers -> in_layers`` linear layers before the final layer.
        penultimate_active: if True, ``forward`` returns ``(feats, out)`` where
            ``feats`` is the unmodified input tensor; otherwise only ``out``.
    """

    def __init__(self, in_layers, out_layers, use_mlp=False, penultimate_active=False):
        super().__init__()
        self.use_mlp = use_mlp
        self.penultimate_active = penultimate_active
        # lin1/lin2 are created unconditionally, so the parameter names in the
        # state_dict are the same whether or not use_mlp is set.
        self.lin1 = nn.Linear(in_layers, in_layers)
        self.lin2 = nn.Linear(in_layers, in_layers)
        self.final = nn.Linear(in_layers, out_layers)

    def forward(self, x):
        """Return logits for ``x``, or ``(x, logits)`` when ``penultimate_active``."""
        feats = x
        if self.use_mlp:
            x = F.relu(self.lin1(x))
            x = F.relu(self.lin2(x))
        out = self.final(x)
        if self.penultimate_active:
            return feats, out
        return out



# class SIMPLE_CNN(nn.Module):
#     def __init__(self, in_layers, out_layers, use_mlp=False, penultimate_active=False):
#         super().__init__()
#         self.use_mlp = use_mlp
#         self.penultimate_active = penultimate_active
#         self.lin1 = nn.Linear(in_layers, in_layers)
#         self.lin2 = nn.Linear(in_layers, in_layers)
#         self.final = nn.Linear(in_layers, out_layers)

#     def forward(self, x):
#         feats = x
#         if self.use_mlp:
#             x = F.relu(self.lin1(x))
#             x = F.relu((self.lin2(x)))
#         out = self.final(x)
#         if self.penultimate_active:
#             return feats, out
#         return out





def get_model(cfg):
    """Look up the model constructor registered under ``cfg.MODEL.TYPE``.

    Returns the callable stored in ``_models``. An unregistered name fails the
    assertion (when assertions are enabled).
    """
    model_type = cfg.MODEL.TYPE
    assert model_type in _models, "Model type '{}' not supported".format(model_type)
    return _models[model_type]


def get_loss_fun(cfg):
    """Look up the loss-function class registered under ``cfg.MODEL.LOSS_FUN``.

    Returns:
        The class stored in ``_loss_funs`` (not an instance).

    Raises:
        AssertionError: if ``cfg.MODEL.LOSS_FUN`` is not registered (when
            assertions are enabled).
    """
    loss_name = cfg.MODEL.LOSS_FUN
    # Report the key that was actually validated. The message previously read
    # cfg.TRAIN.LOSS, a different (and possibly absent) config field.
    assert loss_name in _loss_funs, "Loss function type '{}' not supported".format(loss_name)
    return _loss_funs[loss_name]


# Datasets that build_model always handles with a FeaturesNet head, before
# MODEL.LINEAR_FROM_FEATURES is consulted.
_FEATURE_VECTOR_DATASETS = frozenset([
    'TRPB', 'TRPB_umap',
    'octanoate', 'butyrate', 'acetate', '60butyrate', '90butyrate',
    'octanoate_700_2', 'butyrate_700_2', 'acetate_700_2',
    '60butyrate_700_2', '90butyrate_700_2',
    'octanoate_700_1028', 'butyrate_700_1028', 'acetate_700_1028',
    '60butyrate_700_1028', '90butyrate_700_1028',
    'sysdata',
])

# Subset of _FEATURE_VECTOR_DATASETS built with a 1280-wide FeaturesNet input.
_FEATURES_1280_DATASETS = frozenset([
    'octanoate_700_1028', 'butyrate_700_1028', 'acetate_700_1028',
    '60butyrate_700_1028', '90butyrate_700_1028',
])

# FeaturesNet input width per cfg.data_type for the "Q1_no700" experiment.
_Q1_NO700_DIMS = {"embed": 1280, "umap_2": 2, "umap_50": 50, "umap_100": 100, "umap_500": 500}

# With MODEL.LINEAR_FROM_FEATURES, these datasets use 384 input features; all others use 512.
_FEATURES_384_DATASETS = frozenset(['IMAGENET50', 'IMAGENET100', 'IMAGENET200'])


def _feature_dataset_num_features(cfg):
    """Return the FeaturesNet input width for a dataset in _FEATURE_VECTOR_DATASETS."""
    if cfg.EXP_NAME == "Q1_no700":
        # Raises KeyError if cfg.data_type is not a key of _Q1_NO700_DIMS.
        return _Q1_NO700_DIMS[cfg.data_type]
    if cfg.DATASET.NAME == 'TRPB':
        return 84
    if cfg.DATASET.NAME in _FEATURES_1280_DATASETS:
        return 1280
    return 2


def build_model(cfg):
    """Build the model described by ``cfg``.

    Dispatch order:
      1. Datasets in ``_FEATURE_VECTOR_DATASETS`` get a ``FeaturesNet`` whose
         input width comes from ``_feature_dataset_num_features``.
      2. Otherwise, if ``cfg.MODEL.LINEAR_FROM_FEATURES`` is set, a
         ``FeaturesNet`` with 384 (IMAGENET50/100/200) or 512 input features.
      3. Otherwise the constructor from ``get_model(cfg)`` is called with
         ``use_dropout=True``. For MNIST, ``conv1`` is replaced by a
         single-input-channel convolution.

    Raises:
        KeyError: for the "Q1_no700" experiment with an unknown ``cfg.data_type``.
        AssertionError: (via ``get_model``) for an unregistered ``cfg.MODEL.TYPE``.
    """
    if cfg.DATASET.NAME in _FEATURE_VECTOR_DATASETS:
        return FeaturesNet(_feature_dataset_num_features(cfg), cfg.MODEL.NUM_CLASSES)

    if cfg.MODEL.LINEAR_FROM_FEATURES:
        # NOTE: this branch previously also listed 'TRPB' (384) and the
        # *_700_1028 datasets (1279). Those names are in
        # _FEATURE_VECTOR_DATASETS, return above, and never reach this point.
        num_features = 384 if cfg.DATASET.NAME in _FEATURES_384_DATASETS else 512
        return FeaturesNet(num_features, cfg.MODEL.NUM_CLASSES)

    model = get_model(cfg)(num_classes=cfg.MODEL.NUM_CLASSES, use_dropout=True)
    if cfg.DATASET.NAME == 'MNIST':
        # NOTE(review): this only has an effect if the model's forward uses
        # `self.conv1` with 64 output channels (ResNet-style stem). Confirm
        # before using MNIST with the VGG/AlexNet entries.
        model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    return model


def build_loss_fun(cfg):
    """Instantiate the configured loss class with no constructor arguments."""
    loss_cls = get_loss_fun(cfg)
    return loss_cls()


def register_model(name, ctor):
    """Add (or overwrite) the model constructor ``ctor`` under key ``name``."""
    _models.update({name: ctor})


def register_loss_fun(name, ctor):
    """Add (or overwrite) the loss class ``ctor`` under key ``name``."""
    _loss_funs.update({name: ctor})
