from __future__ import print_function
from __future__ import absolute_import

from torch.nn import DataParallel
from .resnet import ResNet, resnet18, resnet50, cifar_resnet18, cifar_resnet50
# from .pyramidnet import PyramidNet
from .wideresnet import wrn40_2, wrn28_10
from . import preactresnet as parn
from . import masknet as mn
import torchvision.models as models
import torch.nn as nn

import os
# Redirect TORCH_HOME for this process (overwriting any value already set).
# NOTE(review): the path is relative, so it is interpreted against the process's
# current working directory, not this file's location; presumably this is where
# torch.hub caches the weights fetched by the pretrained=True models below — confirm.
os.environ.update(TORCH_HOME='../hub')

def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when ``feature_extracting`` is truthy.

    With a falsy flag the function returns immediately without touching the
    model, so it never *re-enables* gradients on already-frozen parameters.

    Args:
        model: object exposing ``parameters()`` (e.g. an ``nn.Module``).
        feature_extracting: when truthy, ``requires_grad`` is set to ``False``
            on every parameter yielded by ``model.parameters()``.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

def get_model(model_name, num_class, datamixer=None, use_cuda=True, data_parallel=False,
              device_ids=(0, 1)):
    """Build a classification network by name and optionally place it on GPU(s).

    Args:
        model_name: one of ``'imagenet_teacher'``, ``'resnet50_teacher'``,
            ``'preactresnet18'``, ``'pretrained_resnet18'``,
            ``'pretrained_resnet50'``, ``'cifar_resnet18'``,
            ``'cifar_resnet50'``, ``'resnet50'``, ``'resnet200'``,
            ``'wresnet40_2'``, ``'wresnet28_10'``.
        num_class: number of output classes. Not used by the two ``*_teacher``
            names, which build torchvision's ``resnet50`` without passing it.
        datamixer: forwarded to the project models that accept it
            (``cifar_resnet18``, ``cifar_resnet50``, ``resnet50``,
            ``wresnet40_2``, ``wresnet28_10``); ignored by the others.
        use_cuda: move the model to the default CUDA device. Ignored when
            ``data_parallel`` is set, since that path always uses CUDA.
        data_parallel: wrap the model in ``torch.nn.DataParallel``.
        device_ids: GPUs handed to ``DataParallel`` (default ``(0, 1)``).
            ``None`` lets ``DataParallel`` use every visible GPU.

    Returns:
        The constructed module, possibly wrapped in ``DataParallel``.

    Raises:
        NameError: if ``model_name`` is not recognised.
    """
    name = model_name

    if name in ('imagenet_teacher', 'resnet50_teacher'):
        # Both aliases build the same torchvision pretrained ResNet-50, unmodified.
        model = models.resnet50(pretrained=True)
    elif name == 'preactresnet18':
        model = parn.PreActResNet18(num_classes=num_class)
    elif name in ('pretrained_resnet18', 'pretrained_resnet50'):
        arch = models.resnet18 if name == 'pretrained_resnet18' else models.resnet50
        model = arch(pretrained=True)
        # feature_extracting=False leaves every parameter trainable; only the
        # final classifier is replaced so its output size matches num_class.
        set_parameter_requires_grad(model, False)
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif name == 'cifar_resnet18':
        model = cifar_resnet18(num_classes=num_class, datamixer=datamixer)
    elif name == 'cifar_resnet50':
        # Use the project's CIFAR variant (like cifar_resnet18) so num_class and
        # datamixer are honoured; torchvision's resnet50 received neither.
        # NOTE(review): assumes .resnet.cifar_resnet50 takes the same keyword
        # arguments as cifar_resnet18 — confirm.
        model = cifar_resnet50(num_classes=num_class, datamixer=datamixer)
    elif name == 'resnet50':
        model = resnet50(num_classes=num_class, datamixer=datamixer)
    elif name == 'resnet200':
        model = ResNet(dataset='imagenet', depth=200, num_classes=num_class, bottleneck=True)
    elif name == 'wresnet40_2':
        model = wrn40_2(dropout_rate=0.0, num_class=num_class, datamixer=datamixer)
    elif name == 'wresnet28_10':
        model = wrn28_10(dropout_rate=0.0, num_class=num_class, datamixer=datamixer)
    else:
        raise NameError('no model named, %s' % name)

    if data_parallel:
        ids = None if device_ids is None else list(device_ids)
        # DataParallel requires the module's parameters to live on device_ids[0];
        # with ids=None it picks all visible GPUs itself.
        model = model.cuda(ids[0]) if ids else model.cuda()
        model = DataParallel(model, device_ids=ids)
    elif use_cuda:
        model = model.cuda()
    return model


def get_mask_model(model_name, n_channel, k=2, use_cuda=True):
    """Build a mask network by name.

    Args:
        model_name: mask-network variant; only ``'K1'`` is supported.
        n_channel: forwarded to the mask network as ``n_channel``.
        k: forwarded as ``k``; the network's ``in_planes`` is set to ``2 * k``.
        use_cuda: move the network to the default CUDA device (default
            ``True``); pass ``False`` to keep it on CPU.

    Returns:
        The constructed mask network.

    Raises:
        ValueError: if ``model_name`` is not recognised.
    """
    if model_name != 'K1':
        raise ValueError(f'Mask Net Model {model_name} Not Found!')
    masknet = mn.MaskNet_K1(k=k, in_planes=2*k, n_channel=n_channel)
    if use_cuda:
        masknet = masknet.cuda()
    return masknet
