from __future__ import print_function, division, absolute_import
import os
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .resnext_features import resnext101_32x4d_features
from .resnext_features import resnext101_64x4d_features

# Public API: the two model classes and their factory functions.
__all__ = [
    'ResNeXt101_32x4d',
    'resnext101_32x4d',
    'ResNeXt101_64x4d',
    'resnext101_64x4d',
]

# Pretrained checkpoint locations, keyed by architecture name and then by the
# dataset the weights were trained on.
# NOTE(review): these URLs are plain HTTP and are loaded without an integrity
# check; the hex suffix in each filename may be a hash prefix usable with
# ``load_url(..., check_hash=True)`` — TODO confirm before relying on it.
pretrained_settings = {
    arch: {'imagenet': {'url': url}}
    for arch, url in (
        ('resnext101_32x4d',
         'http://data.lip6.fr/cadene/pretrainedmodels/resnext101_32x4d-29e315fa.pth'),
        ('resnext101_64x4d',
         'http://data.lip6.fr/cadene/pretrainedmodels/resnext101_64x4d-e77a0586.pth'),
    )
}

class ResNeXt101_32x4d(nn.Module):
    """ResNeXt-101 (32x4d) backbone, used as a feature extractor by default.

    Args:
        num_classes: Output size of the optional linear classifier; only used
            when ``include_top`` is True.
        include_top: When True, build an average-pool + linear head and make
            ``forward`` return class logits. Defaults to False, in which case
            ``forward`` returns the backbone output unchanged and the model
            has no ``last_linear`` parameters.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, num_classes=1000, include_top=False, **kwargs):
        super(ResNeXt101_32x4d, self).__init__()
        self.num_classes = num_classes
        self.include_top = include_top
        # NOTE(review): ``resnext101_32x4d_features`` is imported as an object
        # and assigned here without being called, so every instance of this
        # class appears to share the same backbone module (and weights) —
        # confirm in resnext_features before building several models.
        self.features = resnext101_32x4d_features
        if include_top:
            # A 7x7 window with stride 1 reduces a 7x7 feature map to 1x1;
            # larger maps leave extra spatial positions, so the flattened
            # width would no longer match the 2048 inputs of the Linear layer.
            self.avg_pool = nn.AvgPool2d((7, 7), (1, 1))
            self.last_linear = nn.Linear(2048, num_classes)

    def logits(self, input):
        """Pool the backbone output, flatten it and apply the classifier.

        Raises:
            NotImplementedError: if the model was built without the head
                (``include_top=False``).
        """
        if not self.include_top:
            raise NotImplementedError(
                "ResNeXt101_32x4d was built with include_top=False and has "
                "no classification head; pass include_top=True to use logits()")
        x = self.avg_pool(input)
        x = x.view(x.size(0), -1)
        return self.last_linear(x)

    def forward(self, input):
        """Return the backbone output, or class logits when the head is enabled."""
        x = self.features(input)
        if self.include_top:
            x = self.logits(x)
        return x


class ResNeXt101_64x4d(nn.Module):
    """ResNeXt-101 (64x4d) backbone, used as a feature extractor by default.

    Args:
        num_classes: Output size of the optional linear classifier; only used
            when ``include_top`` is True.
        include_top: When True, build an average-pool + linear head and make
            ``forward`` return class logits. Defaults to False, in which case
            ``forward`` returns the backbone output unchanged and the model
            has no ``last_linear`` parameters.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, num_classes=1000, include_top=False, **kwargs):
        super(ResNeXt101_64x4d, self).__init__()
        self.num_classes = num_classes
        self.include_top = include_top
        # NOTE(review): ``resnext101_64x4d_features`` is imported as an object
        # and assigned here without being called, so every instance of this
        # class appears to share the same backbone module (and weights) —
        # confirm in resnext_features before building several models.
        self.features = resnext101_64x4d_features
        if include_top:
            # A 7x7 window with stride 1 reduces a 7x7 feature map to 1x1;
            # larger maps leave extra spatial positions, so the flattened
            # width would no longer match the 2048 inputs of the Linear layer.
            self.avg_pool = nn.AvgPool2d((7, 7), (1, 1))
            self.last_linear = nn.Linear(2048, num_classes)

    def logits(self, input):
        """Pool the backbone output, flatten it and apply the classifier.

        Raises:
            NotImplementedError: if the model was built without the head
                (``include_top=False``).
        """
        if not self.include_top:
            raise NotImplementedError(
                "ResNeXt101_64x4d was built with include_top=False and has "
                "no classification head; pass include_top=True to use logits()")
        x = self.avg_pool(input)
        x = x.view(x.size(0), -1)
        return self.last_linear(x)

    def forward(self, input):
        """Return the backbone output, or class logits when the head is enabled."""
        x = self.features(input)
        if self.include_top:
            x = self.logits(x)
        return x


def resnext101_32x4d(num_classes=1000, pretrained=False, **kwargs):
    """Build a ``ResNeXt101_32x4d`` model, optionally with ImageNet weights.

    Args:
        num_classes: Forwarded to ``ResNeXt101_32x4d``; must be 1000 when
            ``pretrained`` is True.
        pretrained: If True, download the ImageNet checkpoint listed in
            ``pretrained_settings`` and load it after stripping the
            ``last_linear`` classifier parameters, which the model built here
            does not have by default.
        **kwargs: Forwarded to the ``ResNeXt101_32x4d`` constructor.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``pretrained`` is True and ``num_classes`` is not 1000.
    """
    if pretrained and num_classes != 1000:
        # Explicit check rather than ``assert`` so it still runs under -O.
        raise ValueError(
            "num_classes should be {}, but is {}".format(1000, num_classes))
    model = ResNeXt101_32x4d(num_classes=num_classes, **kwargs)
    if pretrained:
        settings = pretrained_settings['resnext101_32x4d']['imagenet']
        state_dict = model_zoo.load_url(settings['url'])
        model.load_state_dict(remove_fc(state_dict))
    return model


def resnext101_64x4d(num_classes=1000, pretrained=False, **kwargs):
    """Build a ``ResNeXt101_64x4d`` model, optionally with ImageNet weights.

    Args:
        num_classes: Forwarded to ``ResNeXt101_64x4d``; must be 1000 when
            ``pretrained`` is True.
        pretrained: If True, download the ImageNet checkpoint listed in
            ``pretrained_settings`` and load it after stripping the
            ``last_linear`` classifier parameters, which the model built here
            does not have by default.
        **kwargs: Forwarded to the ``ResNeXt101_64x4d`` constructor.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``pretrained`` is True and ``num_classes`` is not 1000.
    """
    if pretrained and num_classes != 1000:
        # Explicit check rather than ``assert`` so it still runs under -O.
        raise ValueError(
            "num_classes should be {}, but is {}".format(1000, num_classes))
    model = ResNeXt101_64x4d(num_classes=num_classes, **kwargs)
    if pretrained:
        settings = pretrained_settings['resnext101_64x4d']['imagenet']
        state_dict = model_zoo.load_url(settings['url'])
        model.load_state_dict(remove_fc(state_dict))
    return model


def remove_fc(state_dict, prefix='last_linear.'):
    """Strip classifier-head parameters from a checkpoint state dict.

    Every key starting with ``prefix`` (by default the ``last_linear`` weight
    and bias) is removed. Keys that are absent are simply skipped, so an
    already-stripped checkpoint passes through unchanged instead of raising
    ``KeyError``.

    Args:
        state_dict: Mapping of parameter names to values; modified in place.
        prefix: Key prefix identifying the parameters to drop.

    Returns:
        The same ``state_dict`` object, so the call can be passed straight to
        ``load_state_dict``.
    """
    # Snapshot the matching keys first: deleting entries while iterating the
    # dict itself raises RuntimeError.
    for key in [k for k in state_dict if k.startswith(prefix)]:
        del state_dict[key]
    return state_dict
