from spaghettini import quick_register
import torch
from torch import nn
from torch.nn import Linear

from src.dl.models.fully_connected import FCNetFixedWidth

# Large finite magnitude (1e10). Not referenced by the definitions visible in this file.
# NOTE(review): presumably a stand-in for infinity used by code that imports it -- confirm before removing.
VERY_BIG = 1e10


@quick_register
class BinaryErasureFCVerifier(nn.Module):
    """Fully connected verifier that scores an input together with a proof.

    Inputs and proofs are flattened per sample, concatenated along the feature axis and passed through an
    ``FCNetFixedWidth`` network whose input width is ``input_dim + proof_dim``.
    """

    def __init__(self, input_dim, proof_dim, hid_dim, output_dim, num_layers, activation,
                 layer_norm=True, use_spectral_norm=False):
        """Store the configuration and build the underlying fully connected network.

        Args:
            input_dim: Per-sample feature count contributed by the flattened inputs.
            proof_dim: Per-sample feature count contributed by the flattened proofs. Together with ``input_dim``
                it determines the network's input width (``input_dim + proof_dim``).
            hid_dim: Forwarded to ``FCNetFixedWidth`` as ``num_hidden_dim``.
            output_dim: Forwarded to ``FCNetFixedWidth`` as ``num_outputs``.
            num_layers: Forwarded to ``FCNetFixedWidth`` as ``num_hidden_layers``.
            activation: Forwarded to ``FCNetFixedWidth`` as ``activation_init``.
            layer_norm: Forwarded to ``FCNetFixedWidth`` as ``use_layer_norm``.
            use_spectral_norm: Forwarded to ``FCNetFixedWidth`` as ``use_spectral_norm``.
        """
        super().__init__()
        self.input_dim = input_dim
        self.proof_dim = proof_dim
        # Kept alongside the other constructor arguments so the full configuration is inspectable.
        self.hid_dim = hid_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.act = activation
        self.layer_norm = layer_norm
        self.use_spectral_norm = use_spectral_norm

        self.fcnet = FCNetFixedWidth(num_inputs=input_dim + proof_dim, num_hidden_dim=hid_dim, num_outputs=output_dim,
                                     num_hidden_layers=num_layers, activation_init=self.act,
                                     use_layer_norm=layer_norm, use_spectral_norm=use_spectral_norm)

    def forward(self, xs, proofs_list, **kwargs):
        """Run the verifier on a batch of inputs and their proofs.

        Args:
            xs: Input tensor; its leading dimension must match the proofs' leading (batch) dimension.
            proofs_list: Either a proof tensor, or a one-element list/tuple wrapping one. Sequences of any other
                length are not unwrapped and will fail on the ``.shape`` access below.
            **kwargs: Accepted and ignored.

        Returns:
            A tuple ``(zs, aux_outs, model_dict)``: ``zs`` is the network output, ``aux_outs`` is an empty dict
            (placeholder for auxiliary head outputs) and ``model_dict`` holds ``zs`` under ``"verifier_logits"``
            for logging.
        """
        # Unwrap a single proof delivered inside a one-element list or tuple.
        if isinstance(proofs_list, (list, tuple)) and len(proofs_list) == 1:
            proofs = proofs_list[0]
        else:
            proofs = proofs_list
        bs = proofs.shape[0]

        # Flatten the inputs and proofs. Concatenate.
        # reshape (unlike view) also works on non-contiguous tensors, copying only when it has to.
        xs_flat, proofs_flat = xs.reshape(bs, -1), proofs.reshape(bs, -1)
        xs_cat = torch.cat([xs_flat, proofs_flat], dim=1)

        # Make decision.
        zs = self.fcnet(xs_cat)

        # Create a model dict for further logging.
        model_dict = dict(verifier_logits=zs)

        # Create dummy dictionary for auxiliary head outputs.
        aux_outs = dict()

        return zs, aux_outs, model_dict


@quick_register
class BinaryErasureLinearProofGenerator(nn.Linear):
    """Linear layer whose forward pass also returns a dictionary for logging."""

    def __init__(self, in_features, out_features, bias=True):
        """Initialise the underlying ``nn.Linear`` with the given sizes and bias flag."""
        super().__init__(in_features, out_features, bias)

    def forward(self, feats):
        """Apply the linear layer to ``feats``.

        Returns:
            A pair ``(prover_outs, model_dict)`` where ``model_dict`` stores the same output tensor under the
            key ``"prover_logits"``.
        """
        logits = super().forward(feats)
        return logits, {"prover_logits": logits}
