from pandas import wide_to_long
import torch
import torch.nn as nn
import torch.nn.functional as F

class Identity(nn.Module):
    """Pass-through layer: ``forward`` hands back its input untouched.

    The constructor argument ``d`` is recorded as both ``in_features`` and
    ``out_features``, mirroring the attributes an ``nn.Linear`` layer exposes.
    """

    def __init__(self, d):
        super().__init__()
        self.in_features = self.out_features = d

    def forward(self, x):
        # No computation at all: the very same object that came in goes out.
        return x

def initialize_model(config, d_out, is_featurizer=False):
    """
    Initializes models according to the config
        Args:
            - config: config object read via attribute access; uses ``config.model``
              (model name), ``config.model_kwargs`` (dict of constructor kwargs) and,
              for variable-width ResNets, ``config.width``.
            - d_out (int): the dimensionality of the model output
            - is_featurizer (bool): whether to return a model or a (featurizer, classifier) pair that constitutes a model.
        Output:
            If is_featurizer=True:
            - featurizer: a model that outputs feature Tensors of shape (batch_size, ..., feature dimensionality)
            - classifier: a model that takes in feature Tensors and outputs predictions. In most cases, this is a linear layer.

            If is_featurizer=False:
            - model: a model that is equivalent to nn.Sequential(featurizer, classifier)
        Raises:
            - ValueError: if ``config.model`` is not a recognized model name.
            - AssertionError: if is_featurizer=True is requested for logistic_regression.
        Notes:
            The 'cifar_*' and 'mlp'/'convnet'/'lenet' branches ignore ``d_out`` and
            ``is_featurizer`` and always return a single model (NOTE(review): callers
            requesting a featurizer pair from these branches will not get a tuple).
    """
    if config.model in ('resnet50', 'resnet34', 'resnet18', 'wideresnet50', 'densenet121'):
        if is_featurizer:
            featurizer = initialize_torchvision_model(
                name=config.model,
                d_out=None,
                **config.model_kwargs)
            classifier = nn.Linear(featurizer.d_out, d_out)
            model = (featurizer, classifier)
        else:
            model = initialize_torchvision_model(
                name=config.model,
                d_out=d_out,
                **config.model_kwargs)
    elif config.model in ('resnet10vw', 'resnet18vw', 'resnet34vw', 'resnet50vw'):  # VARIABLE WIDTH ResNet
        import models.variable_width_resnet as variable_width_resnet
        constructor = getattr(variable_width_resnet, config.model)
        # NOTE: this writes into config.model_kwargs, so the caller's config is modified.
        config.model_kwargs['width'] = config.width
        if is_featurizer:
            featurizer = constructor(num_classes=None, **config.model_kwargs)
            classifier = nn.Linear(featurizer.d_out, d_out)
            model = (featurizer, classifier)
        else:
            model = constructor(num_classes=d_out, **config.model_kwargs)
    elif config.model.startswith('cifar'):
        from models import resnet_cifar
        # Map names to constructors (not instances) so that only the requested
        # network is built, instead of instantiating all five ResNets.
        cifar_constructors = {
            'cifar_resnet18': resnet_cifar.ResNet18,
            'cifar_resnet34': resnet_cifar.ResNet34,
            'cifar_resnet50': resnet_cifar.ResNet50,
            'cifar_resnet101': resnet_cifar.ResNet101,
            'cifar_resnet152': resnet_cifar.ResNet152,
        }
        if config.model not in cifar_constructors:
            raise ValueError(f'Model: {config.model} not recognized.')
        model = cifar_constructors[config.model]()
    elif config.model in ('mlp', 'convnet', 'lenet'):
        from models.simple_models import Net, ConvNet, LeNet
        # Same lazy lookup: build only the selected network.
        simple_constructors = {'mlp': Net, 'convnet': ConvNet, 'lenet': LeNet}
        model = simple_constructors[config.model](width=config.model_kwargs['width'])
    elif config.model == 'logistic_regression':
        assert not is_featurizer, "Featurizer not supported for logistic regression"
        model = nn.Linear(out_features=d_out, **config.model_kwargs)
    else:
        raise ValueError(f'Model: {config.model} not recognized.')

    # The `needs_y` attribute specifies whether the model's forward function
    # needs to take in both (x, y).
    # If False, Algorithm.process_batch will call model(x).
    # If True, Algorithm.process_batch() will call model(x, y) during training,
    # and model(x, None) during eval.
    if not hasattr(model, 'needs_y'):
        # Sometimes model is a tuple of (featurizer, classifier); tag each part.
        submodels = model if isinstance(model, tuple) else (model,)
        for submodel in submodels:
            submodel.needs_y = False

    return model

def initialize_torchvision_model(name, d_out, **kwargs):
    """
    Build a torchvision model and replace its final layer.
        Args:
            - name (str): one of 'resnet18', 'resnet34', 'resnet50', 'wideresnet50', 'densenet121'.
            - d_out (int or None): output dimensionality of the new final nn.Linear layer.
              If None, the final layer is replaced by an Identity layer so the model
              outputs features.
            - **kwargs: forwarded to the torchvision constructor, except the optional
              'lin' flag: if truthy, every parameter except the final layer's weight
              and bias gets requires_grad=False.
        Output:
            - model: the constructed model, with an added ``d_out`` attribute holding its
              output dimensionality (the feature dimensionality when d_out is None).
        Raises:
            - ValueError: if ``name`` is not a supported model name.
    """
    # Supported name -> (torchvision constructor name, attribute holding the final layer).
    specs = {
        'wideresnet50': ('wide_resnet50_2', 'fc'),
        'densenet121': ('densenet121', 'classifier'),
        'resnet50': ('resnet50', 'fc'),
        'resnet34': ('resnet34', 'fc'),
        'resnet18': ('resnet18', 'fc'),
    }
    if name not in specs:
        raise ValueError(f'Torchvision model {name} not recognized')
    constructor_name, last_layer_name = specs[name]

    # Imported after validation so an unknown name fails fast without loading torchvision.
    import torchvision

    # 'lin' is this module's own flag, not a torchvision argument: strip it before forwarding.
    freeze_feats = kwargs.pop('lin', False)

    # Construct the default model; its stock last layer tells us the feature size.
    model = getattr(torchvision.models, constructor_name)(**kwargs)
    d_features = getattr(model, last_layer_name).in_features
    if d_out is None:  # want to initialize a featurizer model
        last_layer = Identity(d_features)
        model.d_out = d_features
    else:  # want to initialize a classifier for a particular num_classes
        last_layer = nn.Linear(d_features, d_out)
        model.d_out = d_out
    setattr(model, last_layer_name, last_layer)

    if freeze_feats:
        # Keep only the replacement last layer trainable. When d_out is None the
        # Identity layer has no parameters, so every parameter ends up frozen.
        trainable = {f'{last_layer_name}.weight', f'{last_layer_name}.bias'}
        for param_name, param in model.named_parameters():
            if param_name not in trainable:
                param.requires_grad = False
    return model
