from .FeedForward import FeedForward, Dummy, Identity
from .lenet import MultiLeNetBackbone, MultiLeNetClassification, MultiLeNetRegression, SmallMultiLeNetBackbone, \
    SmallMultiLeNetClassification, SmallMultiLeNetRegression, SmallMultiLeNeBinaryClassification, \
    MultiLeNetBinaryClassification
from .resnet import ResNet18Backbone, ResNet18Classification
from .convceleba import CelebABackbone, CelebAClassification


import torch.nn as nn


class Cached(nn.Module):
    """Thin wrapper that lets several tasks share one network.

    Exactly one wrapper (``first=True``) owns the wrapped network as a real
    submodule. The other wrappers keep it inside a plain Python list. nn.Module
    does not register a list as a submodule, so those wrappers expose no
    parameters of their own and the shared weights are not duplicated in
    ``parameters()`` / ``state_dict()``.

    Args:
        net: the network to call in ``forward``.
        first: whether this wrapper is the one that registers ``net``.
    """

    def __init__(self, net, first):
        super().__init__()
        self.first = first
        if first:
            self.net = net
        else:
            # Hidden from submodule registration on purpose (see class docstring).
            self.net = [net]

    def forward(self, input):
        """Run the wrapped network on ``input`` and return its output."""
        target = self.net if self.first else self.net[0]
        return target(input)


def number_parameters(model):
    """Return the total number of scalar parameters in a model dictionary.

    Args:
        model: mapping from names to ``nn.Module`` objects.

    Returns:
        The sum of ``numel()`` over every tensor yielded by each module's
        ``parameters()``. A ``Cached`` wrapper built with ``first=False`` does
        not register its network, so a shared decoder is counted only once.
    """
    # ``numel()`` is 1 for 0-dim tensors; the previous ``reduce`` over
    # ``size()`` raised TypeError on an empty shape.
    return sum(p.numel() for module in model.values() for p in module.parameters())


def get_activation(name):
    """Return a new activation module for ``name``.

    Supported names are ``'identity'``, ``'relu'``, ``'sigmoid'`` and ``'tanh'``.
    Only the requested module is instantiated, and a fresh instance is returned
    on every call.

    Raises:
        KeyError: if ``name`` is not a supported activation name.
    """
    activations = {
        'identity': nn.Identity,
        'relu': nn.ReLU,
        'sigmoid': nn.Sigmoid,
        'tanh': nn.Tanh,
    }
    try:
        factory = activations[name]
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError(f'unknown activation {name!r}; expected one of {sorted(activations)}') from None
    return factory()


def build_dense(tasks, args):
    """Build a fully-connected encoder plus a decoder head for every task.

    Encoder settings are read from ``args.encoder`` and decoder settings from
    ``args.decoder`` when those attributes exist; otherwise both fall back to
    ``args`` itself.

    Args:
        tasks: sequence of task objects exposing ``name`` and ``loss``.
        args: configuration namespace. Reads ``input_size``, ``latent_size`` and
            the optional ``shared`` flag (default ``False``). The encoder section
            provides ``hidden_size``, ``num_layers`` and ``activation``. The
            decoder section additionally provides ``output_size`` and
            ``drop_last``. In the non-shared case, decoder ``hidden_size``,
            ``output_size`` and ``num_layers`` may each be an int (used for
            every task) or a per-task list.

    Returns:
        dict mapping ``'rep'`` to the encoder and each task name to its decoder.
        With ``shared=True`` all task names map to ``Cached`` wrappers around
        a single decoder; only the first wrapper registers its parameters.
        ``args`` is not modified.
    """
    model = {}
    shared = getattr(args, 'shared', False)

    enc_params = getattr(args, 'encoder', args)
    model['rep'] = FeedForward(args.input_size, args.latent_size, enc_params.hidden_size, enc_params.num_layers,
                               get_activation(enc_params.activation))

    dec_params = getattr(args, 'decoder', args)

    if shared:
        net = FeedForward(args.latent_size, dec_params.output_size, dec_params.hidden_size, dec_params.num_layers,
                          get_activation(dec_params.activation), dec_params.drop_last)

        for i, task_i in enumerate(tasks):
            model[task_i.name] = Cached(net, i == 0)
        return model

    def per_task(value):
        # Broadcast an int setting to one entry per task. Build a new list
        # instead of writing it back onto ``dec_params``: that object may be
        # ``args`` itself (shared with the encoder settings), so mutating it
        # would corrupt the encoder configuration on any later call.
        return [value] * len(tasks) if isinstance(value, int) else value

    hidden_sizes = per_task(dec_params.hidden_size)
    output_sizes = per_task(dec_params.output_size)
    num_layers = per_task(dec_params.num_layers)

    for i, task_i in enumerate(tasks):
        # Other builders treat ``loss`` as an object with ``__name__``; comparing
        # the object itself to 'mse' was always False for those. Fall back to
        # the raw value so string-valued losses keep working.
        loss_name = getattr(task_i.loss, '__name__', task_i.loss)
        model[task_i.name] = FeedForward(args.latent_size, output_sizes[i], hidden_sizes[i],
                                         num_layers[i], get_activation(dec_params.activation),
                                         dec_params.drop_last or loss_name == 'mse')

    return model


def build_lenet(tasks, args, size):
    """Build a LeNet-style multi-task model.

    Args:
        tasks: iterable of task objects exposing ``name`` and a ``loss`` whose
            ``__name__`` selects the head type.
        args: configuration namespace providing ``shape`` and ``dropout``.
        size: ``'original'`` or ``'small'``; selects the LeNet family.

    Returns:
        dict with the backbone under ``'rep'`` and one head per task name: a
        classification head for ``'nll'`` losses, a binary-classification head
        for ``'bce'`` losses, and a regression head for any other loss.

    Raises:
        KeyError: if ``size`` is not ``'original'`` or ``'small'``.
    """
    # (backbone, nll head, bce head, fallback regression head) per family.
    families = {
        'original': (MultiLeNetBackbone, MultiLeNetClassification,
                     MultiLeNetBinaryClassification, MultiLeNetRegression),
        'small': (SmallMultiLeNetBackbone, SmallMultiLeNetClassification,
                  SmallMultiLeNeBinaryClassification, SmallMultiLeNetRegression),
    }
    backbone_cls, nll_head, bce_head, regression_head = families[size]

    # Backbone is constructed before the heads, matching the original order.
    model = {'rep': backbone_cls(args.shape, args.dropout)}
    for task in tasks:
        loss_name = task.loss.__name__
        if loss_name == 'nll':
            head_cls = nll_head
        elif loss_name == 'bce':
            head_cls = bce_head
        else:
            head_cls = regression_head
        model[task.name] = head_cls(args.dropout)
    return model


def build_resnet(tasks, args):
    """Build a ResNet-18 backbone and one ResNet-18 classification head per task.

    Args:
        tasks: sized iterable of task objects exposing ``name``.
        args: configuration namespace providing ``pretrained``.

    Returns:
        dict with the backbone under ``'rep'`` and one head per task name. The
        ``cifar`` flag passed to the backbone and every head is ``True`` exactly
        when there are 10 tasks.
    """
    pretrained = args.pretrained
    # NOTE(review): presumably "10 tasks" identifies the CIFAR-10 setup; confirm.
    is_cifar = len(tasks) == 10

    model = {'rep': ResNet18Backbone(pretrained=pretrained, cifar=is_cifar)}
    model.update((task.name, ResNet18Classification(cifar=is_cifar)) for task in tasks)
    return model


def build_convceleba(tasks, args):
    """Build the CelebA convolutional backbone plus one head per task.

    Each head is constructed from the task's ``index``. ``args`` is accepted
    for signature parity with the other builders but is not read.

    Returns:
        dict with the backbone under ``'rep'`` and one head per task name.
    """
    # Backbone first, then heads, so construction order is unchanged.
    model = {'rep': CelebABackbone()}
    model.update({task.name: CelebAClassification(task.index) for task in tasks})
    return model


def build_svhn_lenet(tasks, args):
    """Build the SVHN LeNet variant: a backbone plus one head per task.

    Args:
        tasks: iterable of task objects exposing ``name`` and a ``loss`` whose
            ``__name__`` selects the head type.
        args: configuration namespace providing ``shape`` and ``dropout``.

    Returns:
        dict with the backbone under ``'rep'`` and one head per task name: a
        classification head for ``'nll'``, a binary-classification head for
        ``'bce'``, and a regression head for any other loss.
    """
    # Aliased so the SVHN classes do not shadow the module-level LeNet names.
    from .lenet_svhn import (
        MultiLeNetBackbone as SvhnBackbone,
        MultiLeNetClassification as SvhnClassification,
        MultiLeNetRegression as SvhnRegression,
        MultiLeNetBinaryClassification as SvhnBinaryClassification,
    )

    head_by_loss = {'nll': SvhnClassification, 'bce': SvhnBinaryClassification}

    model = {'rep': SvhnBackbone(args.shape, args.dropout)}
    for task in tasks:
        head_cls = head_by_loss.get(task.loss.__name__, SvhnRegression)
        model[task.name] = head_cls(args.dropout)
    return model


def build_model(name, tasks, args):
    """Build the model dictionary for the architecture called ``name``.

    Supported names: ``'dense'``, ``'lenet'``, ``'small_lenet'``,
    ``'svhn_lenet'``, ``'resnet'``, ``'convceleba'`` and ``'dummy'``.

    Args:
        name: architecture identifier.
        tasks: task objects forwarded to the chosen builder.
        args: configuration namespace forwarded to the chosen builder.

    Returns:
        dict with a shared representation under ``'rep'`` and one entry per
        task name.

    Raises:
        NameError: if ``name`` is not a supported architecture.
    """
    if name == 'dense':
        return build_dense(tasks, args)
    if name in ('lenet', 'small_lenet'):
        return build_lenet(tasks, args, size='original' if name == 'lenet' else 'small')
    if name == 'svhn_lenet':
        return build_svhn_lenet(tasks, args)
    if name == 'resnet':
        return build_resnet(tasks, args)
    if name == 'convceleba':
        return build_convceleba(tasks, args)
    if name == 'dummy':
        model = {'rep': Dummy()}
        model.update((task.name, Identity()) for task in tasks)
        return model

    raise NameError(f'model {name} does not exist.')


# Only the model factory is exported by ``from ... import *``; the individual
# builders and helpers remain importable by explicit name.
__all__ = ['build_model']
