import torch
from gen_neg_toy.utils import get_register_fn
from gen_neg_toy.utils.nn import ModelBase, zero_module
from torch import nn

from gen_neg_toy.models.modules import ResidualBlock, timestep_embedding
from gen_neg_toy.utils import expand_tensor_dims_as

from .models.precond import get_precond

# Registry of classifier classes. dispatch_model() indexes it with
# `model_config.classifier_param` to pick the class to instantiate.
_CLASSIFIERS = {}
# Used below as `@register_classifier(name=...)` on each classifier class;
# presumably stores the decorated class in _CLASSIFIERS under that name
# (see get_register_fn to confirm).
register_classifier = get_register_fn(_CLASSIFIERS)


class ClassifierNet(ModelBase):
    """Noise-conditioned residual MLP that emits one raw score per input.

    The network is a linear input layer, a stack of ``ResidualBlock`` modules
    that each receive a timestep embedding, and a linear output layer with a
    single output feature. Subclasses define how that raw score is turned
    into a (log-)probability and which loss is used.

    Args:
        precond: Identifier handed to ``get_precond``; the callable it returns
            is invoked without arguments to build the preconditioner.
        in_dim: Number of input features of the first linear layer.
        h_dim: Width of the hidden representation.
        res_blocks: Number of residual blocks.
        t_dim: Size passed to ``timestep_embedding`` and to each residual
            block as ``t_dim``.
    """

    def __init__(self, precond, in_dim=2, h_dim=256, res_blocks=2, t_dim=128):
        super().__init__()
        self.precond = get_precond(precond)()
        self.h_dim = h_dim
        self.t_dim = t_dim
        self.in_layer = nn.Linear(in_dim, self.h_dim)

        self.act = nn.SiLU()

        blocks = []
        for _ in range(res_blocks):
            blocks.append(
                ResidualBlock(
                    dim=self.h_dim,
                    t_dim=self.t_dim,
                    h_dim=self.h_dim,
                    act="silu",
                )
            )
        self.res_layers = nn.ModuleList(blocks)

        # Zero the last layer (fc2) of every residual block so that each
        # block's output initially equals its input.
        for block in self.res_layers:
            zero_module(block.fc2)

        self.out_layer = nn.Linear(self.h_dim, 1)

    def forward(self, x, sigma):
        """Compute the raw classifier output for inputs ``x`` at noise ``sigma``.

        Args:
            x: Input tensor; its last dimension must match ``in_dim``.
            sigma: Noise level(s) forwarded to the preconditioner's
                ``get_params``.

        Returns:
            The output of the final linear layer (one feature per input).
        """
        params = self.precond.get_params(sigma)

        # Input scaling: c_in is expanded against x before multiplying.
        scaled_x = expand_tensor_dims_as(params["c_in"], x) * x
        # The preconditioned noise value, not sigma itself, is what gets embedded.
        temb = timestep_embedding(params["c_noise"], self.t_dim)

        h = self.act(self.in_layer(scaled_x))
        for block in self.res_layers:
            h = self.act(block(h, temb))
        return self.out_layer(h)


@register_classifier(name="softplus")
class SoftplusClassifier(ClassifierNet):
    """Classifier whose log-probability is ``-softplus(pred)``.

    The probability is therefore ``exp(-softplus(pred))``; all of these
    computations are carried out in float64. ``sigma`` is accepted for
    interface parity with the other classifiers but is not used here.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def out_to_log_p(self, pred, sigma):
        """Return ``-softplus(pred)`` evaluated in float64."""
        pred64 = pred.to(torch.float64)
        return -torch.nn.functional.softplus(pred64)

    def out_to_p(self, pred, sigma):
        """Return the probability, i.e. the exponential of ``out_to_log_p``."""
        return torch.exp(self.out_to_log_p(pred, sigma))

    def criterion(self, pred, target, sigma, **kwargs):
        """Binary cross-entropy between the predicted probability and ``target``.

        ``target`` is cast to float64 to match the probability's dtype; extra
        keyword arguments are forwarded to ``binary_cross_entropy``.
        """
        prob = self.out_to_p(pred, sigma)
        label = target.to(torch.float64)
        return torch.nn.functional.binary_cross_entropy(prob, label, **kwargs)


@register_classifier(name="sigmoid")
class SigmoidClassifier(ClassifierNet):
    """Classifier that treats the network output as a logit.

    The probability is ``sigmoid(pred)``. ``sigma`` is accepted for interface
    parity with the other classifiers but is not used here.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def out_to_log_p(self, pred, sigma):
        """Return ``log(sigmoid(pred))``."""
        log_p = torch.nn.functional.logsigmoid(pred)
        return log_p

    def out_to_p(self, pred, sigma):
        """Return ``sigmoid(pred)``."""
        return pred.sigmoid()

    def criterion(self, pred, target, sigma, **kwargs):
        """Binary cross-entropy computed directly on the logits.

        Extra keyword arguments are forwarded to
        ``binary_cross_entropy_with_logits``.
        """
        bce_with_logits = torch.nn.functional.binary_cross_entropy_with_logits
        return bce_with_logits(pred, target, **kwargs)


@register_classifier(name="rescaled_softplus")
class RescaledSoftplusClassifier(ClassifierNet):
    """Classifier whose log-probability is ``-softplus(pred) / sigma**2``.

    ``out_to_log_p_scaled`` exposes the numerator alone, i.e. the
    log-probability multiplied by ``sigma**2``. Computations are carried out
    in float64.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def out_to_log_p(self, pred, sigma):
        """Return ``-softplus(pred) / sigma**2``."""
        # NOTE(review): sigma enters through plain broadcasting; confirm that
        # callers pass it in a shape compatible with pred.
        scaled_log_p = self.out_to_log_p_scaled(pred, sigma)
        return scaled_log_p / (sigma**2)

    def out_to_log_p_scaled(self, pred, sigma):
        """Return ``-softplus(pred)`` in float64, without the sigma division."""
        pred64 = pred.to(torch.float64)
        return -torch.nn.functional.softplus(pred64)

    def out_to_p(self, pred, sigma):
        """Return the probability, i.e. the exponential of ``out_to_log_p``."""
        return torch.exp(self.out_to_log_p(pred, sigma))

    def criterion(self, pred, target, sigma, **kwargs):
        """Binary cross-entropy between the predicted probability and ``target``.

        ``target`` is cast to float64 to match the probability's dtype; extra
        keyword arguments are forwarded to ``binary_cross_entropy``.
        """
        prob = self.out_to_p(pred, sigma)
        label = target.to(torch.float64)
        return torch.nn.functional.binary_cross_entropy(prob, label, **kwargs)


def dispatch_model(model_config):
    """Instantiate the registered classifier selected by ``model_config``.

    Args:
        model_config: Object exposing ``classifier_param`` (the registry key)
            and ``precond`` (passed to the classifier constructor).

    Returns:
        A new instance of the selected classifier class.

    Raises:
        KeyError: If ``classifier_param`` is not a registered name.
    """
    classifier_cls = _CLASSIFIERS[model_config.classifier_param]
    return classifier_cls(model_config.precond)


def dispatch_model_from_path(path):
    """Rebuild a classifier from a checkpoint file.

    The checkpoint is read once to obtain its stored ``"config"`` entry, the
    model described by ``config.model`` is built and moved to
    ``config.device``, and then ``model.load(path)`` is called on it
    (presumably restoring the saved weights; see ``ModelBase.load``).

    Args:
        path: Location of the checkpoint file.

    Returns:
        A ``(model, config)`` tuple.
    """
    # SECURITY: torch.load unpickles arbitrary Python objects; only use this
    # on checkpoints from a trusted source.
    # Only the config is needed from this read, so map tensors to the CPU:
    # this avoids failing when the checkpoint's original device is not
    # available here and avoids an extra copy of the weights on an accelerator.
    checkpoint_data = torch.load(path, map_location="cpu")
    config = checkpoint_data["config"]
    model = dispatch_model(config.model).to(config.device)
    model.load(path)
    return model, config
