from spaghettini import quick_register

from torch import nn
from torch.nn import ReLU, Linear, Conv1d

from src.dl.models.skeleton import Skeleton
from src.dl.models.cell_types.fully_connected import OneLayerFullyConnectedCell
from src.dl.models.cell_types.residual import ResNetLayer
from src.dl.models.hard_coded.prefix_sum import PrefixSumInputPreprocessor
from src.dl.models.building_blocks.convolutional import BasicBlock, Resnet1DLayer, Resnet1DBasicBlock


@quick_register
def get_default_prefix_sum_skeleton(forward_solver, backward_solver, weight_init_std=None, mask_input_injection=False,
                                    z0_init_method="zeros", weight_normalization=True,
                                    num_additional_unroll_steps_after_implicit_forward=0,
                                    deq_jacobian_scaling=1., use_single_layer_classifier=False,
                                    width=120, num_blocks=2, num_pretraining_layers=10):
    """Build the default ``Skeleton`` model for the prefix-sum task.

    The model is assembled from three parts:
      * an input preprocessor (``PrefixSumInputPreprocessor``) configured with ``width``,
      * a cell (``Resnet1DLayer`` of ``num_blocks`` ``Resnet1DBasicBlock``s) with ``width``
        input and output planes,
      * a 1D-convolutional classifier head that maps ``width`` channels to 2 output channels.

    Every convolution built here uses kernel size 3, stride 1 and padding 1. For ``nn.Conv1d``
    this keeps the sequence length unchanged.

    Args:
        forward_solver: Forwarded unchanged to ``Skeleton``.
        backward_solver: Forwarded unchanged to ``Skeleton``.
        weight_init_std: Forwarded unchanged to ``Skeleton``.
        mask_input_injection: Forwarded unchanged to ``Skeleton``.
        z0_init_method: Forwarded unchanged to ``Skeleton``.
        weight_normalization: Passed as ``wnorm`` to the ``Resnet1DLayer`` cell.
        num_additional_unroll_steps_after_implicit_forward: Forwarded unchanged to ``Skeleton``.
        deq_jacobian_scaling: Forwarded unchanged to ``Skeleton``.
        use_single_layer_classifier: If True, the classifier is a single ``Conv1d``
            (width -> 2). Otherwise it is three ``Conv1d`` layers separated by ReLUs
            (width -> width -> width // 2 -> 2).
        width: Channel count of the preprocessor, the cell and the classifier input.
            Defaults to 120, the value that was previously hard-coded.
        num_blocks: Number of residual blocks in the cell (previously hard-coded to 2).
        num_pretraining_layers: Forwarded to ``Skeleton`` (previously hard-coded to 10).

    Returns:
        Skeleton: the assembled model.
    """
    kernel = (3,)
    stride = 1
    out_features = 2

    def conv(in_channels, out_channels):
        # Shared settings for every classifier conv: kernel 3, stride 1, padding 1, no bias.
        return nn.Conv1d(in_channels, out_channels, kernel_size=kernel, stride=(stride,), padding=1, bias=False)

    input_preprocessor = PrefixSumInputPreprocessor(width=width, kernel_size=kernel, stride=stride,
                                                    padding=1, bias=False)
    cell = Resnet1DLayer(block=Resnet1DBasicBlock, planes=width, in_planes=width, num_blocks=num_blocks,
                         stride=stride, wnorm=weight_normalization)

    if use_single_layer_classifier:
        classifier_layer = nn.Sequential(conv(width, out_features))
    else:
        # Integer floor division; identical to the former int(width / 2) for positive widths.
        hidden = width // 2
        classifier_layer = nn.Sequential(
            conv(width, width),
            nn.ReLU(),
            conv(width, hidden),
            nn.ReLU(),
            conv(hidden, out_features),
        )

    return Skeleton(
        input_preprocessor=input_preprocessor,
        cell=cell,
        classifier_layer=classifier_layer,
        forward_solver=forward_solver,
        backward_solver=backward_solver,
        num_pretraining_layers=num_pretraining_layers,
        mask_input_injection=mask_input_injection,
        z0_init_method=z0_init_method,
        weight_init_std=weight_init_std,
        num_additional_unroll_steps_after_implicit_forward=num_additional_unroll_steps_after_implicit_forward,
        deq_jacobian_scaling=deq_jacobian_scaling
    )
