
import torch
import sde_lib
import numpy as np
import os


# Registry mapping a model name to its class; written by `register_model`
# and read by `get_model`.
_MODELS = {}


def register_model(cls=None, *, name=None):
  """Class decorator that records a model class in the `_MODELS` registry.

  Works both bare (``@register_model``) and with arguments
  (``@register_model(name='foo')``).

  Args:
    cls: The decorated class when used without parentheses; None when the
      decorator is called with keyword arguments.
    name: Registry key to use. Defaults to ``cls.__name__``.

  Returns:
    The class itself (bare form), or a decorator that registers and returns
    its argument (parameterised form).

  Raises:
    ValueError: If the chosen key is already present in the registry.
  """

  def _add_to_registry(model_cls):
    key = model_cls.__name__ if name is None else name
    if key in _MODELS:
      raise ValueError(f'Already registered model with name: {key}')
    _MODELS[key] = model_cls
    return model_cls

  # Bare usage hands us the class directly; parameterised usage needs the
  # inner decorator returned so Python can apply it to the class.
  return _add_to_registry if cls is None else _add_to_registry(cls)


def get_model(name):
  """Look up a model class previously registered with `register_model`.

  Args:
    name: The registry key (the class name, or the explicit ``name=`` given
      to `register_model`).

  Returns:
    The registered class (not an instance).

  Raises:
    KeyError: If nothing is registered under ``name``. The message lists the
      names that are registered, which makes config typos easy to spot.
  """
  try:
    return _MODELS[name]
  except KeyError:
    # Re-raise as KeyError so existing callers that catch it keep working,
    # but with a message that says what *is* available.
    available = ', '.join(sorted(map(str, _MODELS)))
    raise KeyError(
      f'Unknown model name {name!r}; registered models: [{available}]'
    ) from None


def create_score_model(config):
  """Instantiate the score network named by ``config.model.name``.

  The registered class is constructed with the full ``config``, moved to
  ``config.device``, and wrapped in ``torch.nn.DataParallel``.

  Args:
    config: Configuration object; reads ``config.model.name`` and
      ``config.device``.

  Returns:
    The ``torch.nn.DataParallel``-wrapped model.
  """
  model_cls = get_model(config.model.name)
  network = model_cls(config).to(config.device)
  return torch.nn.DataParallel(network)

def get_model_fn(model, train=False):
  """Wrap ``model`` in a callable that sets its train/eval mode on every call.

  Args:
    model: A module exposing ``train()``/``eval()`` and called as
      ``model(x, labels)`` (e.g. a ``torch.nn.Module``).
    train: If truthy, ``model.train()`` runs before each forward pass;
      otherwise ``model.eval()`` does.

  Returns:
    A function ``model_fn(x, labels)`` that returns ``model(x, labels)``.
  """

  def model_fn(x, labels):
    # The mode is re-applied on every call, so a mode change made elsewhere
    # between calls does not leak into this wrapper.
    if train:
      model.train()
    else:
      model.eval()
    return model(x, labels)

  return model_fn


def get_score_fn(sde, model, train=False):
  """Build a function that turns the network output into a score estimate.

  Args:
    sde: An ``sde_lib.VPSDE``, ``sde_lib.subVPSDE`` or ``sde_lib.VESDE``
      instance. Its ``marginal_prob(...)[1]`` supplies the noise std at
      time ``t``.
    model: The score network, called as ``model(x, labels)``.
    train: Whether the network runs in train mode (see `get_model_fn`).

  Returns:
    ``score_fn(x, t)`` returning the score estimate for ``x`` at times ``t``.

  Raises:
    NotImplementedError: If ``sde`` is not one of the supported classes.
      Without this check the function would fail later with an
      ``UnboundLocalError`` on ``score_fn``.
  """
  model_fn = get_model_fn(model, train=train)

  if isinstance(sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
    def score_fn(x, t):
      # VP networks receive t scaled by 999. This presumably matches a
      # 1000-step discrete time embedding; TODO confirm against the models.
      labels = t * 999
      score = model_fn(x, labels)
      std = sde.marginal_prob(torch.zeros_like(x), t)[1]
      # Reshape std (assumed one value per batch element) to (B, 1, ..., 1)
      # so it broadcasts against x of any rank; for 2-D x this is (B, 1).
      std = std.reshape(-1, *([1] * (x.dim() - 1)))
      # Presumably the network predicts noise; -output / std gives the score.
      return -score / std

  elif isinstance(sde, sde_lib.VESDE):
    def score_fn(x, t):
      # VE networks are conditioned directly on the noise std at time t and
      # their output is used as the score without rescaling.
      labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
      return model_fn(x, labels)

  else:
    raise NotImplementedError(
      f'SDE class {sde.__class__.__name__} not yet supported.')

  return score_fn

def get_conditional_model_fn(model, train=False):
  """Conditional variant of `get_model_fn`: the wrapped model also receives ``h``.

  Args:
    model: A module exposing ``train()``/``eval()`` and called as
      ``model(h, x, labels)`` (e.g. a ``torch.nn.Module``).
    train: If truthy, ``model.train()`` runs before each forward pass;
      otherwise ``model.eval()`` does.

  Returns:
    A function ``model_fn(h, x, labels)`` that returns
    ``model(h, x, labels)``.
  """

  def model_fn(h, x, labels):
    # Re-assert the requested mode on every call (see get_model_fn).
    if train:
      model.train()
    else:
      model.eval()
    return model(h, x, labels)

  return model_fn


def get_conditional_score_fn(sde, model, train=False):
  """Build a conditional score function; the network also receives ``h``.

  Args:
    sde: An ``sde_lib.VPSDE``, ``sde_lib.subVPSDE`` or ``sde_lib.VESDE``
      instance. Its ``marginal_prob(...)[1]`` supplies the noise std at
      time ``t``.
    model: The conditional score network, called as ``model(h, x, labels)``.
    train: Whether the network runs in train mode.

  Returns:
    ``score_fn(h, x, t)`` returning the score estimate for ``x`` at times
    ``t``, given the conditioning input ``h``. ``h`` is passed through to
    the model unchanged; its meaning is defined by the model. Every
    supported SDE type uses this same signature.

  Raises:
    NotImplementedError: If ``sde`` is not one of the supported classes.
  """
  model_fn = get_conditional_model_fn(model, train=train)

  if isinstance(sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
    def score_fn(h, x, t):
      # VP networks receive t scaled by 999. This presumably matches a
      # 1000-step discrete time embedding; TODO confirm against the models.
      labels = t * 999
      score = model_fn(h, x, labels)
      std = sde.marginal_prob(torch.zeros_like(x), t)[1]
      # Reshape std (assumed one value per batch element) to (B, 1, ..., 1)
      # so it broadcasts against x of any rank; for 2-D x this is (B, 1).
      std = std.reshape(-1, *([1] * (x.dim() - 1)))
      # Presumably the network predicts noise; -output / std gives the score.
      return -score / std

  elif isinstance(sde, sde_lib.VESDE):
    # Takes h like the VP branch: the conditional model_fn requires
    # (h, x, labels), so a two-argument score_fn could never call it.
    def score_fn(h, x, t):
      labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
      return model_fn(h, x, labels)

  else:
    raise NotImplementedError(
      f'SDE class {sde.__class__.__name__} not yet supported.')

  return score_fn

def get_sigmas(config):
  """Return a geometric noise-level schedule running from sigma_max to sigma_min.

  The ``config.model.num_scales`` values are evenly spaced in log space, so
  consecutive sigmas share a constant ratio.

  Args:
    config: Configuration object; reads ``config.model.sigma_max``,
      ``config.model.sigma_min`` and ``config.model.num_scales``.

  Returns:
    For scalar sigma bounds, a 1-D ``np.ndarray`` of length ``num_scales``.
    It is ordered from the largest noise level down to the smallest.
  """
  model_cfg = config.model
  log_hi = np.log(model_cfg.sigma_max)
  log_lo = np.log(model_cfg.sigma_min)
  return np.exp(np.linspace(log_hi, log_lo, model_cfg.num_scales))