#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from modules import PredsegModule, PredsegSequential
from resnet import get_resnet101_deeplab, get_predseg1


def get_shifts(model):
    """Return the per-stage ``shift`` and ``subsamp`` settings for ``model``.

    The model names match those accepted by ``get_module``. For the
    single-conv models, each ``[dy, dx]`` shift equals
    ``(kernel_size - 1) // 2`` of the corresponding conv in ``get_module``
    (e.g. ``linear3`` uses a 3x3 conv and shift ``[1, 1]``). The
    ``subsamp`` entry for ``conv1`` matches that conv's stride of 2. The
    settings for ``resdl``/``predseg1*`` cannot be checked from this file.
    NOTE(review): the exact meaning of shift/subsamp is defined by the
    consumer; confirm against its callers.

    Args:
        model: Model name, e.g. ``'pixel'``, ``'linear'``, ``'conv1'``,
            ``'resdl'``, ``'predseg1'`` or ``'predseg1_norm'``.

    Returns:
        A ``(shift, subsamp)`` tuple. ``shift`` is a list of ``[dy, dx]``
        pairs and ``subsamp`` is a list of ints, with one entry per stage.
        Fresh lists are returned on every call, so callers may mutate them.

    Raises:
        ValueError: If ``model`` is not a recognised model name. (Previously
            an unknown name surfaced as an ``UnboundLocalError``.)
    """
    # Built per call so the returned lists are never shared between calls.
    table = {
        'pixel': ([[0, 0]], [1]),
        'linear': ([[0, 0], [2, 2]], [1, 1]),
        'linear3': ([[1, 1]], [1]),
        'linearbig': ([[5, 5]], [1]),
        'conv1': ([[0, 0], [3, 3]], [1, 2]),
        'resdl': ([[0, 0], [0, 0], [0, 0], [0, 0]], [1, 2, 2, 2]),
        'predseg1': ([[0, 0], [0, 0], [0, 0], [0, 0]], [1, 3, 6, 12]),
        'predseg1_norm': ([[0, 0], [0, 0], [0, 0], [0, 0]], [1, 3, 6, 12]),
    }
    try:
        shift, subsamp = table[model]
    except KeyError:
        raise ValueError(
            f"unknown model {model!r}; expected one of {sorted(table)}"
        ) from None
    return shift, subsamp


def get_module(model, neighbors):
    """Build the network for ``model`` and report its output channel count.

    The model names match those accepted by ``get_shifts``.

    Args:
        model: Model name, e.g. ``'pixel'``, ``'linear'``, ``'linear3'``,
            ``'linearbig'``, ``'conv1'``, ``'resdl'``, ``'predseg1'`` or
            ``'predseg1_norm'``.
        neighbors: Passed through unchanged to every ``PredsegModule`` /
            factory call. NOTE(review): its type and meaning are defined in
            ``modules``/``resnet``, which are not visible here.

    Returns:
        A ``(module, dim_out)`` tuple. ``module`` is the constructed network
        and ``dim_out`` is the channel count of its final stage. For the
        conv models, ``dim_out`` equals the last ``Conv2d``'s ``out_channels``.

    Raises:
        ValueError: If ``model`` is not a recognised model name. (Previously
            an unknown name surfaced as an ``UnboundLocalError``.)
    """
    # The third PredsegModule argument matches the channel count produced by
    # its base module. For Identity this is 3, which presumably means RGB
    # input, consistent with the Conv2d(3, ...) layers below. TODO confirm.
    # Construction order is kept deliberate: changing it could change the
    # RNG state consumed by parameter initialisation.
    if model == 'pixel':
        base_module = torch.nn.Identity()
        module = PredsegModule(base_module, neighbors, 3)
        dim_out = 3
    elif model == 'linear':
        # Stage 1: raw input; stage 2: 5x5 conv, stride 1, 3 -> 10 channels.
        base_module = torch.nn.Identity()
        module1 = PredsegModule(base_module, neighbors, 3)
        base_module = torch.nn.Conv2d(3, 10, 5, 1)
        module2 = PredsegModule(base_module, neighbors, 10)
        module = PredsegSequential(module1, module2)
        dim_out = 10
    elif model == 'linear3':
        # Single 3x3 conv, stride 1, 3 -> 3 channels.
        base_module = torch.nn.Conv2d(3, 3, 3, 1)
        module = PredsegModule(base_module, neighbors, 3)
        dim_out = 3
    elif model == 'linearbig':
        # Single 11x11 conv, stride 1, 3 -> 50 channels.
        base_module = torch.nn.Conv2d(3, 50, 11, 1)
        module = PredsegModule(base_module, neighbors, 50)
        dim_out = 50
    elif model == 'conv1':
        # Stage 1: raw input; stage 2: 7x7 conv, stride 2, 3 -> 64 channels.
        base_module = torch.nn.Identity()
        module1 = PredsegModule(base_module, neighbors, 3)
        base_module = torch.nn.Conv2d(3, 64, 7, 2)
        module2 = PredsegModule(base_module, neighbors, 64)
        module = PredsegSequential(module1, module2)
        dim_out = 64
    elif model == 'resdl':
        module = get_resnet101_deeplab(neighbors=neighbors)
        dim_out = 2048
    elif model == 'predseg1':
        module = get_predseg1(neighbors=neighbors, normalize_output=False)
        dim_out = 256
    elif model == 'predseg1_norm':
        module = get_predseg1(neighbors=neighbors, normalize_output=True)
        dim_out = 256
    else:
        raise ValueError(
            f"unknown model {model!r}; expected one of "
            "['conv1', 'linear', 'linear3', 'linearbig', 'pixel', "
            "'predseg1', 'predseg1_norm', 'resdl']"
        )
    return module, dim_out
