import numbers

import numpy as np
import torch
from torch import nn

import gen_neg_toy.classifier
from gen_neg_toy.utils import expand_tensor_dims_as
from gen_neg_toy.utils.nn import ModelBase, zero_module

from .modules import ResidualBlock, _act_init_dict, timestep_embedding
from .precond import get_precond


class ToyNet(ModelBase):
    """Residual MLP denoiser for low-dimensional data, with optional classifier guidance.

    The network output is produced by ``self.precond`` wrapping ``forward_core``. When one
    or more frozen classifiers are configured, ``forward`` adds a "bridge" term: the
    gradient of the classifier log-probability with respect to the input. The term can
    optionally be masked by ``bridge_sigma_max`` and weighted by ``bridge_scale_schedule``.
    """

    # Noise range over which the "linear" / "loglinear" bridge scale schedules go from
    # weight 1 (sigma <= 0.002) down to weight 0 (sigma >= 80).
    _BRIDGE_SCALE_SIGMA_MIN = 0.002
    _BRIDGE_SCALE_SIGMA_MAX = 80.0
    _BRIDGE_SCALE_SCHEDULES = ("const", "linear", "loglinear")

    def __init__(self, model_config, dim=None, in_dim=None, t_dim=128):
        """Build the network from ``model_config``.

        Args:
            model_config: Config object. Reads ``dim``, ``precond``, ``h_dim``, ``act``,
                ``group_norm``, ``hidden_layers`` and, if present, ``classifier``,
                ``bridge_sigma_max`` and ``bridge_scale_schedule``.
            dim: Output dimensionality; defaults to ``model_config.dim``.
            in_dim: Input dimensionality; defaults to ``dim``.
            t_dim: Size of the sigma embedding passed to the residual blocks.

        Raises:
            ValueError: If ``bridge_scale_schedule`` is set to an unknown value.
        """
        super().__init__()
        self.bridge_sigma_max = getattr(model_config, "bridge_sigma_max", None)
        self.bridge_scale_schedule = getattr(model_config, "bridge_scale_schedule", None)
        if (
            self.bridge_scale_schedule is not None
            and self.bridge_scale_schedule not in self._BRIDGE_SCALE_SCHEDULES
        ):
            # Fail at construction rather than on the first forward pass.
            raise ValueError("Unknown bridge scale schedule: {}".format(self.bridge_scale_schedule))
        dim = model_config.dim if dim is None else dim
        self.precond = get_precond(model_config.precond)()
        self.classifier = None
        self.bridge = None
        self._init_classifier(getattr(model_config, "classifier", None))

        self._dim = dim
        in_dim = self.dim if in_dim is None else in_dim
        self.h_dim = model_config.h_dim
        self.t_dim = t_dim
        self.in_layer = nn.Linear(in_dim, self.h_dim)

        self.act = _act_init_dict[model_config.act]()
        self.res_layers = nn.ModuleList(
            [
                ResidualBlock(
                    dim=self.h_dim,
                    t_dim=self.t_dim,
                    h_dim=self.h_dim,
                    act=model_config.act,
                    group_norm=(model_config.group_norm != 0),
                )
                for _ in range(model_config.hidden_layers // 2)
            ]
        )
        # Zero-out the last layer of the residual blocks such that their output is equal to their input
        for module in self.res_layers:
            if isinstance(module, ResidualBlock):
                zero_module(module.fc2)

        self.out_layer = nn.Linear(self.h_dim, self.dim)

    def _init_classifier(self, spec):
        """Load frozen guidance classifier(s) described by ``spec`` into ``self.classifier``.

        ``spec`` may be None or ``[]`` (no guidance), a single path, a comma-separated string
        of paths, or a list/tuple of paths. A string naming exactly one path yields a bare
        module. Anything else yields an ``nn.ModuleList``. That choice decides the state-dict
        key layout, so it must stay stable for existing checkpoints to load.
        """
        if spec is None or spec == []:
            return
        if isinstance(spec, str):
            # Tolerate whitespace around commas and stray empty entries ("a, b," -> ["a", "b"]).
            paths = [p.strip() for p in spec.split(",") if p.strip()]
            if not paths:
                return
            spec = paths[0] if len(paths) == 1 else paths
        if isinstance(spec, (list, tuple)):
            self.classifier = nn.ModuleList(
                [gen_neg_toy.classifier.dispatch_model_from_path(path)[0] for path in spec]
            )
        else:
            self.classifier, _ = gen_neg_toy.classifier.dispatch_model_from_path(spec)
        # Module.requires_grad_ recurses into children, so this freezes every classifier in a ModuleList too.
        self.classifier.requires_grad_(False)

    def bridge_update(self, x, sigma, neg_guidance=False, classifier=None, **model_kwargs):
        """Return the classifier-guidance term to add to the model output.

        Args:
            x: Model input; gradients are taken with respect to a detached copy of it.
            sigma: Noise level(s), passed to the classifier.
            neg_guidance: If True, negate the classifier output before converting it to a
                log-probability. This is only allowed for ``SigmoidClassifier``.
            classifier: Classifier (or ``nn.ModuleList`` of classifiers) to use; defaults to
                ``self.classifier``.
            **model_kwargs: Forwarded unchanged when recursing over a ModuleList.

        Returns:
            ``0`` when there is no classifier. Otherwise the detached gradient of the summed
            log-probability w.r.t. ``x``, summed over classifiers for a ModuleList. When the
            classifier lacks ``out_to_log_p_scaled``, the gradient is multiplied by
            ``sigma ** 2``.

        Raises:
            ValueError: If the classifier is of an unsupported type.
        """
        if classifier is None:
            classifier = self.classifier
        if classifier is None:
            return 0
        if isinstance(classifier, nn.ModuleList):
            res = 0
            for c in classifier:
                res += self.bridge_update(x, sigma, neg_guidance, c, **model_kwargs)
            return res
        if not isinstance(classifier, gen_neg_toy.classifier.ClassifierNet):
            raise ValueError("Unknown classifier type: {}".format(classifier))

        # Guidance needs input gradients even if the caller runs under torch.no_grad().
        with torch.enable_grad():
            x_var = x.detach().requires_grad_(True)
            pred = classifier(x_var.to(torch.float32), sigma.to(torch.float32))
            if neg_guidance:
                assert isinstance(classifier, gen_neg_toy.classifier.SigmoidClassifier)
                pred = -pred
            if hasattr(classifier, "out_to_log_p_scaled"):
                log_p = classifier.out_to_log_p_scaled(pred, sigma)
                to_scale = False
            else:
                log_p = classifier.out_to_log_p(pred, sigma)
                to_scale = True
            grad = torch.autograd.grad(log_p.sum(), x_var)[0]
        grad = grad.detach()
        if to_scale:
            grad = grad * expand_tensor_dims_as(sigma ** 2, grad)
        return grad

    def forward(self, x, sigma, **model_kwargs):
        """Return the preconditioned network output plus the (masked, scaled) guidance term.

        Assumes ``sigma`` is a tensor that ``expand_tensor_dims_as`` can broadcast against
        the guidance gradient (e.g. one noise level per sample). TODO confirm against callers.
        """
        res = self.precond(self.forward_core, x, sigma, **model_kwargs)
        bridge = self.bridge_update(x, sigma, **model_kwargs)
        if not torch.is_tensor(bridge):
            # No classifier: bridge_update returned the scalar 0. Masking/scaling would fail
            # on a Python int, and the term contributes nothing anyway.
            return res + bridge
        if self.bridge_sigma_max is not None:
            # Apply guidance only at noise levels strictly below bridge_sigma_max.
            bridge_update_mask = sigma < self.bridge_sigma_max
            bridge = bridge * expand_tensor_dims_as(bridge_update_mask, bridge)
        if self.bridge_scale_schedule is not None:
            bridge = bridge * expand_tensor_dims_as(self._bridge_scale(sigma, bridge), bridge)
        return res + bridge

    def _bridge_scale(self, sigma, bridge):
        """Return the guidance weight for ``self.bridge_scale_schedule``.

        "const" gives all ones. "linear" and "loglinear" interpolate (in sigma or log-sigma)
        from 1 at ``_BRIDGE_SCALE_SIGMA_MIN`` to 0 at ``_BRIDGE_SCALE_SIGMA_MAX``, with sigma
        clipped to that range.
        """
        lo, hi = self._BRIDGE_SCALE_SIGMA_MIN, self._BRIDGE_SCALE_SIGMA_MAX
        schedule = self.bridge_scale_schedule
        if schedule == "const":
            return torch.ones_like(bridge)
        if schedule == "linear":
            return (hi - sigma.clip(lo, hi)) / (hi - lo)
        if schedule == "loglinear":
            return (np.log(hi) - sigma.clip(lo, hi).log()) / (np.log(hi) - np.log(lo))
        raise ValueError("Unknown bridge scale schedule: {}".format(schedule))

    def forward_core(self, x, sigma):
        """Run the raw MLP: input layer, residual blocks conditioned on sigma, output layer."""
        temb = timestep_embedding(sigma, self.t_dim)
        h = self.act(self.in_layer(x))
        for layer in self.res_layers:
            h = self.act(layer(h, temb))
        return self.out_layer(h)

    def round_sigma(self, sigma):
        """Return ``sigma`` as a tensor; no rounding is applied for this model."""
        return torch.as_tensor(sigma)

    @property
    def dim(self):
        """Output dimensionality of the model."""
        return self._dim

    @property
    def sigma_min(self):
        """Minimum noise level, as defined by the preconditioner."""
        return self.precond.sigma_min

    @property
    def sigma_max(self):
        """Maximum noise level, as defined by the preconditioner."""
        return self.precond.sigma_max

    def update_with_observations(self, x, **kwargs):
        # This module is identity, but is different in child classes of this class
        return x

    def prepare_model_kwargs(self, xt):
        """Return extra forward kwargs; this model needs none."""
        return {}

    @property
    def mode(self):
        """Conditioning mode identifier; always "uncond" for this class."""
        return "uncond"


def dispatch_model(model_config):
    """Construct a ``ToyNet`` from the given model config.

    Args:
        model_config: Model section of the experiment config.

    Returns:
        A newly constructed ``ToyNet``.
    """
    return ToyNet(model_config)


def dispatch_model_from_path(path, strict=True, **adjusted_kwargs):
    """Rebuild a model from a checkpoint file and load its weights.

    Args:
        path: Checkpoint path. The file must contain a ``"config"`` entry whose ``model``
            section is passed to ``dispatch_model``.
        strict: Forwarded to ``model.load``.
        **adjusted_kwargs: Attribute overrides applied to ``config.model`` before the model
            is built (type checking is bypassed via ``ignore_type``).

    Returns:
        ``(model, config)``: the loaded model moved to ``config.device``, and the checkpoint
        config including any overrides.
    """
    # Only the config is read here, so map tensors to CPU. A checkpoint saved on a GPU can
    # then be opened on a host without that device; model.load is given the path separately.
    # NOTE: torch.load unpickles arbitrary objects -- only use with trusted checkpoints.
    checkpoint_data = torch.load(path, map_location="cpu")
    config = checkpoint_data["config"]
    if adjusted_kwargs:
        config.model.unlock()
        try:
            for k, v in adjusted_kwargs.items():
                print(f"(model dispatch) Modifying the attribute {k} from {getattr(config.model, k, None)} to {v}.")
                with config.model.ignore_type():
                    setattr(config.model, k, v)
        finally:
            # Re-lock even if an override fails so the config is not left mutable.
            config.model.lock()
    model = dispatch_model(config.model).to(config.device)
    model.load(path, strict=strict)
    return model, config