import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import distrax

from jax.scipy.special import logsumexp
import smarter_jax as sj

import math

class Flow(eqx.Module):
  """Base class for invertible (normalizing-flow) layers.

  Subclasses implement ``_forward(x, ldj, **kwargs)`` and
  ``_inverse(x, ldj, **kwargs)``. Each returns the transformed ``x`` together
  with the running log-determinant-of-Jacobian ``ldj``, updated by the layer.
  Call ``add_prior`` before using ``sample`` or ``log_prob``.
  """
  # Base distribution of the latent. It is static, so it is kept out of the
  # pytree leaves.
  prior: distrax.Distribution = eqx.field(init=False, static=True, default=None)
  # Value passed as `sample_shape` to `prior.sample` in `sample`.
  prior_shape: tuple[int, ...] = eqx.field(init=False, default=None)

  def add_prior(self, prior, shape):
    """Attach ``prior`` and its sample ``shape`` in place and return ``self``.

    Uses ``object.__setattr__`` to bypass the module's frozen-dataclass
    attribute protection, so the existing instance is mutated.
    """
    object.__setattr__(self, "prior", prior)
    object.__setattr__(self, "prior_shape", shape)
    return self

  def encode(self, x, **kwargs):
    """Map ``x`` through the forward direction and discard the ldj."""
    return self._forward(x, 0, **kwargs)[0]

  @sj.with_subkeys
  def sample(self, key, **kwargs):
    """Draw ``z`` from the prior and map it through the inverse direction.

    ``key`` is consumed with ``next(key)``. Presumably ``sj.with_subkeys``
    turns the caller's PRNG key into an iterator of subkeys (TODO confirm).
    """
    assert self.prior is not None, \
      "Must call `add_prior` before calling `sample`"
    z = self.prior.sample(seed=next(key), sample_shape=self.prior_shape)
    return self._inverse(z, 0, key=next(key), **kwargs)[0]

  def log_prob(self, x, reps=1, **kwargs):
    """Log-likelihood of ``x`` via the change-of-variables formula.

    With ``reps == 1`` the result is ``sum(prior.log_prob(z)) + ldj``, where
    ``z, ldj = _forward(x, 0)``.

    With ``reps > 1`` the forward pass is vmapped ``reps`` times. ``sj.vmap``
    with ``splitkey=reps`` presumably gives each repetition its own subkey of
    ``kwargs['key']``. The per-repetition likelihoods are then averaged in
    probability space: ``logsumexp(lp) - log(reps)``. Repetitions differ only
    when the forward pass is stochastic (e.g. ``Dequantize``).
    """
    assert self.prior is not None, \
      "Must call `add_prior` before calling `log_prob`"
    if reps > 1:
      assert 'key' in kwargs, \
        "Must provide key if reps > 1"
      z, ldj = sj.vmap(self._forward, splitkey=reps,
                         in_axes=None, kw_axes=None)(x, 0, **kwargs)
      # Sum the prior log-density within each repetition, then average the
      # repetitions' likelihoods (not log-likelihoods).
      return (logsumexp(self.prior.log_prob(z)
                        .reshape(z.shape[0], -1)
                        .sum(axis=-1) + ldj)
              - jnp.log(reps))
    else:
      z, ldj = self._forward(x, 0, **kwargs)
      return self.prior.log_prob(z).sum() + ldj

  def loss(self, *args, **kwargs):
    """Negative log-likelihood; see ``log_prob``."""
    return -self.log_prob(*args, **kwargs)

  
class Identity(Flow):
  """No-op flow: input and ldj pass through unchanged in both directions."""

  def _forward(self, x, ldj, **kwargs):
    unchanged = (x, ldj)
    return unchanged

  def _inverse(self, x, ldj, **kwargs):
    unchanged = (x, ldj)
    return unchanged

  
class Inverse(Flow):
  """Swap the forward and inverse directions of the wrapped ``flow``.

  ``Inverse(Inverse(f))`` returns ``f`` itself rather than a doubly-wrapped
  object.
  """
  flow: Flow

  def __new__(cls, flow, *args, **kwargs):
    # Collapse a double inversion to the original flow. The returned object is
    # not an Inverse instance, so the standard type.__call__ skips __init__.
    # This assumes equinox's metaclass preserves that behaviour (TODO confirm).
    if flow.__class__ is not Inverse:
      return super().__new__(cls, *args, **kwargs)
    return flow.flow

  def _forward(self, x, ldj, **kwargs):
    inner = self.flow
    return inner._inverse(x, ldj, **kwargs)

  def _inverse(self, x, ldj, **kwargs):
    inner = self.flow
    return inner._forward(x, ldj, **kwargs)

  
class Reverse(Flow):
  """Reverse the order of all axes of ``x`` (a full transpose).

  The transform only permutes elements, so ``ldj`` is returned unchanged.
  Reversing the axes twice is the identity, so forward and inverse are the
  same operation.
  """

  def _forward(self, x, ldj, **kwargs):
    # With axes=None, jnp.transpose reverses the axis order. Building the
    # permutation with jnp.arange makes it a traced array under jit, but
    # transpose needs static integer axes.
    return jnp.transpose(x), ldj

  def _inverse(self, x, ldj, **kwargs):
    return jnp.transpose(x), ldj

  
class Rescale(Flow):
  """Elementwise scaling ``x -> alpha * x``.

  Each of the ``x.size`` elements contributes ``log|alpha|`` to the
  log-determinant.
  """
  alpha: float

  def _forward(self, x, ldj, **kwargs):
    # |alpha|: a negative scale is still invertible, and log(alpha) would be
    # NaN for it.
    ldj += jnp.log(jnp.abs(self.alpha)) * x.size
    x = x * self.alpha
    return x, ldj

  def _inverse(self, x, ldj, **kwargs):
    ldj -= jnp.log(jnp.abs(self.alpha)) * x.size
    x = x / self.alpha
    return x, ldj


class Reshape(Flow):
  """Reshape between ``in_shape`` (forward input) and ``out_shape``.

  Reshaping only rearranges elements, so ``ldj`` passes through unchanged.
  """
  in_shape: tuple
  out_shape: tuple

  def _forward(self, x, ldj, **kwargs):
    reshaped = jnp.reshape(x, self.out_shape)
    return reshaped, ldj

  def _inverse(self, x, ldj, **kwargs):
    restored = jnp.reshape(x, self.in_shape)
    return restored, ldj


class Sigmoid(Flow):
  """Elementwise logistic sigmoid, optionally stretched by ``eps``.

  Forward: ``y = sigmoid(x)``, then, if ``eps`` is not None,
  ``y = (y - eps/2) / (1 - eps)``.

  Inverse: ``y * (1 - eps) + eps/2``, then the logit. When ``eps > 0``, inputs
  of exactly 0 or 1 therefore map to finite values.
  """
  eps: float = 0.

  def _forward(self, x, ldj, **kwargs):
    # log sigmoid'(x) = log sigmoid(x) + log sigmoid(-x), which equals
    # -x + 2*log sigmoid(x). log_sigmoid avoids log(0) = -inf when sigmoid(x)
    # underflows for very negative x.
    ldj += (jax.nn.log_sigmoid(x) + jax.nn.log_sigmoid(-x)).sum()
    x = jax.nn.sigmoid(x)

    if self.eps is not None:
      ldj -= jnp.log(1 - self.eps) * x.size
      x = (x - 0.5 * self.eps) / (1 - self.eps)

    return x, ldj

  def _inverse(self, x, ldj, **kwargs):
    if self.eps is not None:
      ldj += jnp.log(1 - self.eps) * x.size
      x = x * (1 - self.eps) + 0.5 * self.eps

    # logit(x) = log(x) - log(1 - x), and its log-derivative is
    # -(log(x) + log(1 - x)). log1p(-x) keeps precision for x near 1, where
    # the former `x - x**2` cancelled catastrophically.
    log_x = jnp.log(x)
    log_1mx = jnp.log1p(-x)
    ldj -= (log_x + log_1mx).sum()
    x = log_x - log_1mx

    return x, ldj

  
class Sequential(Flow):
  """Composition of flows.

  ``_forward`` applies ``flows`` in order and ``_inverse`` applies them in
  reverse order. If any element of ``flows`` has its own prior, ``loss``
  switches to ``residual_loss``.
  """
  flows: list[Flow]

  def loss(self, *args, **kwargs):
    has_inner_prior = any(sub.prior is not None for sub in self.flows)
    if not has_inner_prior:
      return super().loss(*args, **kwargs)
    return self.residual_loss(*args, **kwargs)

  @sj.with_subkeys
  def _forward(self, x, ldj, *, key=None, **kwargs):
    # `key` is consumed with next(); presumably an iterator of subkeys
    # produced by sj.with_subkeys (TODO confirm).
    for layer in self.flows:
      x, ldj = layer._forward(x, ldj, key=next(key), **kwargs)
    return x, ldj

  @sj.with_subkeys
  def _inverse(self, x, ldj, *, key=None, **kwargs):
    for layer in reversed(self.flows):
      x, ldj = layer._inverse(x, ldj, key=next(key), **kwargs)
    return x, ldj

  @sj.with_subkeys
  def residual_loss(self, x, *, reps=1, key=None, **kwargs):
    """Negative sum of per-stage log-likelihoods.

    Each sub-flow that has a prior scores its output under that prior, plus
    the ldj accumulated since the previous such stage. Its output is then
    passed through ``stop_gradient`` and the ldj is reset. Whatever remains
    after the last flow is scored under ``self.prior`` if it is set. Only
    ``reps == 1`` is supported.
    """
    assert reps == 1

    total = 0
    stage_ldj = 0
    for layer in self.flows:
      x, stage_ldj = layer._forward(x, stage_ldj, key=next(key), **kwargs)
      if layer.prior is None:
        continue
      total += layer.prior.log_prob(x).sum() + stage_ldj
      # Later stages must not backpropagate into earlier ones through x.
      x = jax.lax.stop_gradient(x)
      stage_ldj = 0

    if self.prior is not None:
      total += self.prior.log_prob(x).sum() + stage_ldj

    return -total


class Dequantize(Flow):
  """Map discrete values to continuous ones by adding noise and rescaling.

  Forward: ``x -> (x + v) / (max_val + 1)``, where ``v`` is uniform noise on
  [0, 1). If ``var_flow`` is set, ``v`` is first pushed through it,
  conditioned on ``sideinfo = x / max_val * 2 - 1`` (i.e. [-1, 1] for x in
  [0, max_val]). The flow's log-determinant is added to ``ldj``.

  Inverse: undo the rescale, take the floor to recover the discrete value,
  and (if ``var_flow`` is set) run the fractional part back through
  ``var_flow`` to account for its log-determinant.
  """
  max_val: ... = 255
  in_dtype: ... = jnp.uint8
  out_dtype: ... = jnp.float32
  var_flow: Flow = None
  eps: float = 1e-5
  squash: Flow = eqx.field(init=False, default=None)

  def __post_init__(self):
    if self.var_flow is not None:
      # The wrapped var_flow sees logit(v), and its output is squashed back
      # into (0, 1). Sigmoid's eps keeps the logit finite at v == 0.
      self.var_flow = Sequential(
          [ Inverse(Sigmoid(self.eps)), self.var_flow, Sigmoid(self.eps) ])

    self.squash = Rescale(1 / (self.max_val + 1))

  @sj.with_subkeys
  def _forward(self, x, ldj, *, key, **kwargs):
    v = jax.random.uniform(next(key), x.shape,
                           dtype = self.out_dtype,
                           minval=0.0, maxval=1.0)
    if self.var_flow is not None:
      sideinfo = x / self.max_val * 2 - 1
      v, ldj = self.var_flow._forward(v, ldj, key=next(key), sideinfo=sideinfo)
    x = x + v

    x, ldj = self.squash._forward(x, ldj)

    return x, ldj

  @sj.with_subkeys
  def _inverse(self, x, ldj, key=None, **kwargs):
    x, ldj = self.squash._inverse(x, ldj)

    floor_x = jnp.floor(x)
    v = x - floor_x  # fractional part, in [0, 1)
    # Clip before the integer cast. Float-to-int conversion of out-of-range
    # values (e.g. from samples outside [0, max_val + 1)) is not guaranteed
    # to be consistent across backends.
    x = jnp.clip(floor_x, 0, self.max_val).astype(self.in_dtype)

    if self.var_flow is not None:
      sideinfo = x / self.max_val * 2 - 1
      v, ldj = self.var_flow._inverse(v, ldj, key=next(key), sideinfo=sideinfo)

    return x, ldj


class ParameterizedAffine(Flow):
  """One-dimensional (elementwise) affine transformation.

  ``x -> x * exp(a) + b`` with ``(a, b) = params[0:2]``. The log-scale ``a``
  is soft-clipped by a tanh to ``(-exp(scale), exp(scale))``, where ``scale``
  is a learnable array of the given ``shape`` initialised to zeros.

  Parameters come from the ``params`` keyword of ``_forward`` / ``_inverse``
  (as ``Coupling`` passes them) or else from ``self.params``.
  """
  params: jnp.ndarray
  scale: jnp.ndarray

  def __init__(self, shape, params=None):
    self.params = params
    self.scale = jnp.zeros(shape)

  def _get_params(self, params):
    """Return the soft-clipped log-scale ``a`` and the shift ``b``."""
    assert params is not None or self.params is not None, \
      "No parameters provided"
    params = params if params is not None else self.params
    a, b = params[0:2]

    scale = jnp.exp(self.scale)
    a = jnp.tanh(a / scale) * scale
    return a, b

  def _forward(self, x, ldj, *, params=None, **kwargs):
    a, b = self._get_params(params)

    x = x * jnp.exp(a) + b
    ldj += a.sum()

    return x, ldj

  def _inverse(self, x, ldj, *, params=None, **kwargs):
    a, b = self._get_params(params)

    x = (x - b) / jnp.exp(a)
    ldj -= a.sum()

    return x, ldj


class ParameterizedHinge(Flow):
  """One-dimensional invertible hinge (piecewise-linear) transformation.

  ``x -> x * exp(a)`` for ``x < 0`` and ``x -> x * exp(b)`` otherwise.
  ``(a, b)`` come from ``params[:2]`` and are soft-clipped by a tanh to
  ``(-exp(scale), exp(scale))``, where ``scale`` is a learnable array of
  shape ``(2,) + shape`` initialised to zeros.

  Both slopes are positive, so the sign of ``x`` is preserved and the inverse
  can branch on the sign of its own input.
  """
  params: jnp.ndarray = None
  scale: jnp.ndarray = None

  def __init__(self, shape, params=None):
    self.params = params
    self.scale = jnp.zeros((2,) + shape)

  def _get_params(self, params):
    """Return the soft-clipped log-slopes ``(a, b)`` for x < 0 and x >= 0."""
    assert params is not None or self.params is not None, \
      "No parameters provided"
    params = params if params is not None else self.params

    scale = jnp.exp(self.scale)
    a, b = jnp.tanh(params[:2] / scale) * scale
    return a, b

  def _forward(self, x, ldj, *, params=None, **kwargs):
    a, b = self._get_params(params)

    ldj += jnp.where(x < 0, a, b).sum()
    x = jnp.where(x < 0, x * jnp.exp(a), x * jnp.exp(b))

    return x, ldj

  def _inverse(self, x, ldj, *, params=None, **kwargs):
    a, b = self._get_params(params)

    ldj -= jnp.where(x < 0, a, b).sum()
    x = jnp.where(x < 0, x / jnp.exp(a), x / jnp.exp(b))

    return x, ldj


class ParameterizedNLSq(Flow):
  """One-dimensional non-linear squared transformation

  f(x) = a + bx + c / (1 + (dx + f)^2)

  This is guaranteed to be invertible under certain requirements
  which are guaranteed by the parameter representation.

  Specifically (see `_get_params`): b = exp(_b) > 0, d = exp(_d) > 0 and
  |c| < 0.95 * (8 * sqrt(3) / 9) * b / d. The magnitude of u / (1 + u^2)^2
  is at most 3 * sqrt(3) / 16 (attained at |u| = 1 / sqrt(3)), so this bound
  keeps f'(x) = b - 2 c d u / (1 + u^2)^2 > 0 and f is strictly increasing.

  _b and _d are soft-clipped by a tanh to (-exp(scale_b), exp(scale_b)) and
  (-exp(scale_d), exp(scale_d)). scale_b and scale_d are learnable arrays
  initialised to zeros.
  """
  params: jnp.ndarray
  scale_b: jnp.ndarray
  scale_d: jnp.ndarray

  # log of the bound on |c| * d / b, with a 0.95 safety margin. It has no
  # annotation, so it is a plain class attribute rather than a dataclass field.
  _logC = math.log(8 * math.sqrt(3)/9 * 0.95)

  def __init__(self, shape, params=None):
    self.params = params
    self.scale_b = jnp.zeros(shape)
    self.scale_d = jnp.zeros(shape)

  def _get_params(self, params):
    """Map raw params[0:5] to the constrained (a, b, c, d, f) of f(x)."""
    assert params is not None or self.params is not None, \
      "No parameters provided"
    params = params if params is not None else self.params
    _a, _b, _c, _d, _f = params[0:5]

    # Soft-clip the log-scales of b and d.
    scale_b, scale_d = jnp.exp(self.scale_b), jnp.exp(self.scale_d)
    _b = jnp.tanh(_b / scale_b) * scale_b
    _d = jnp.tanh(_d / scale_d) * scale_d

    a = _a
    b = jnp.exp(_b)
    d = jnp.exp(_d)
    # exp(_logC + _b - _d) = 0.95 * (8*sqrt(3)/9) * b / d; tanh keeps |c| below it.
    c = jnp.tanh(_c) * jnp.exp(self._logC + _b - _d)
    f = _f

    return a, b, c, d, f

  def _forward(self, x, ldj, *, params=None, **kwargs):
    a, b, c, d, f = self._get_params(params)

    u = d * x + f
    # log f'(x) with f'(x) = b - 2 c d u / (1 + u^2)^2 (> 0 by construction).
    # TODO(): avoid both exp and log calls for b and d?
    ldj += jnp.log(b - (2*c*d*u/((1 + u**2)**2))).sum()
    x = a + b * x + c / (1 + u ** 2)

    return x, ldj

  def _inverse(self, x, ldj, *, params=None, **kwargs):
    a, b, c, d, f = self._get_params(params)

    # Solve y = a + b*x' + c / (1 + (d*x' + f)^2) for x' (here `x` holds y).
    # Multiplying out (y - a - b*x') * (1 + (d*x' + f)^2) = c gives the cubic
    # A*x'^3 + B*x'^2 + C*x' + D = 0 with these coefficients. f is strictly
    # increasing, so this cubic has exactly one real root.
    A = -b * d**2
    B = (x - a) * d**2 - 2 * b * d * f
    C = 2 * d * f * (x - a) - b * (f**2 + 1)
    D = (x - a) * (f**2 + 1) - c

    # Closed-form real root. The names appear to follow Nickalls'
    # formulation of the cubic:
    #   xN      - the inflection point,
    #   yN      - the cubic's value at xN,
    #   deltasq - the squared distance from xN to the turning points
    #             (negative when there are none).
    xN = -B / (3 * A)
    yN = A * xN**3 + B * xN**2 + C * xN + D
    deltasq = (B**2 - 3 * A * C)/(9 * A**2)

    delta = jnp.sqrt(jnp.abs(deltasq))
    h = 2 * A * delta**3
    # NOTE(review): h == 0 when deltasq == 0, so yN/h divides by zero there.

    sign = jnp.sign(yN/h)

    # Both branches are evaluated. jnp.select keeps the cosh form where
    # deltasq >= 0 and the sinh form where deltasq < 0. The unselected branch
    # may be NaN (e.g. arccosh of an argument < 1). That does not affect the
    # returned value.
    # NOTE(review): it could produce NaN gradients if this inverse is ever
    # differentiated.
    real_x = xN - 2 * sign * delta * jnp.cosh(jnp.arccosh(sign * yN/h)/3)
    imag_x = xN - 2 * delta * jnp.sinh(jnp.arcsinh(yN/h)/3)
    x= jnp.select([deltasq >= 0, deltasq < 0], [real_x, imag_x])

    # Subtract log f'(x) evaluated at the recovered input.
    u = d * x + f
    # TODO(): avoid both exp and log calls for b and d?
    ldj -= jnp.log(b - (2*c*d*u/((1 + u**2)**2))).sum()

    return x, ldj


class Coupling(Flow):
  """Masked coupling layer.

  ``transform_param_net`` receives ``x * mask`` (plus ``sideinfo`` if given).
  Its output is reshaped into a stack of parameters for ``transform`` and
  multiplied by ``1 - mask``, so the parameters are zero wherever ``mask`` is
  1. For the default ``ParameterizedAffine``, zero parameters are the
  identity. With ``dual=True``, a second pass repeats this with ``1 - mask``
  in place of ``mask``.
  """
  mask: jax.Array
  transform_param_net: eqx.Module
  transform: Flow = None
  dual: bool = False

  def __post_init__(self):
    if self.transform is None:
      self.transform = ParameterizedAffine(self.mask.shape)

  def _get_params(self, x, mask, sideinfo=None, **kwargs):
    masked_x = x * mask
    net_inputs = (masked_x,) if sideinfo is None else (masked_x, sideinfo)
    raw = self.transform_param_net(*net_inputs)
    # The leading axis indexes the individual transform parameters.
    keep = jnp.expand_dims(1 - mask, 0)
    return raw.reshape((-1,) + x.shape) * keep

  def _forward(self, x, ldj, **kwargs):
    passes = [self.mask, 1 - self.mask] if self.dual else [self.mask]
    for cond_mask in passes:
      params = self._get_params(x, cond_mask, **kwargs)
      x, ldj = self.transform._forward(x, ldj, params=params)
    return x, ldj

  def _inverse(self, x, ldj, **kwargs):
    passes = [self.mask, 1 - self.mask] if self.dual else [self.mask]
    # Undo the passes in the opposite order to _forward.
    for cond_mask in reversed(passes):
      params = self._get_params(x, cond_mask, **kwargs)
      x, ldj = self.transform._inverse(x, ldj, params=params)
    return x, ldj

  
def create_mask(shape, block_size, dtype=jnp.uint8):
  """Generates masks for use with Coupling layers.

  The value at index ``idx`` is ``sum_i(idx[i] // size[i]) % 2``, i.e. a
  checkerboard of constant blocks.

  shape: shape of the mask
  block_size: size of a constant-mask block along each axis, one entry per
    axis. If block_size[i] is negative, axis i is instead partitioned into
    -block_size[i] blocks of size shape[i] // -block_size[i].
  dtype: dtype of the returned array.

  Raises ValueError if block_size does not have one entry per axis, or if a
  non-empty axis resolves to a block size that is not positive.

  Examples:
  >>> print(create_mask((8, 8), (2, 2)))
  [[0 0 1 1 0 0 1 1]
   [0 0 1 1 0 0 1 1]
   [1 1 0 0 1 1 0 0]
   [1 1 0 0 1 1 0 0]
   [0 0 1 1 0 0 1 1]
   [0 0 1 1 0 0 1 1]
   [1 1 0 0 1 1 0 0]
   [1 1 0 0 1 1 0 0]]

  >>> print(create_mask((8, 8), (-1, 2)))
  [[0 0 1 1 0 0 1 1]
   [0 0 1 1 0 0 1 1]
   [0 0 1 1 0 0 1 1]
   [0 0 1 1 0 0 1 1]
   [0 0 1 1 0 0 1 1]
   [0 0 1 1 0 0 1 1]
   [0 0 1 1 0 0 1 1]
   [0 0 1 1 0 0 1 1]]

  >>> print(create_mask((8, 8), (-2, -4)))
  [[0 0 1 1 0 0 1 1]
   [0 0 1 1 0 0 1 1]
   [0 0 1 1 0 0 1 1]
   [0 0 1 1 0 0 1 1]
   [1 1 0 0 1 1 0 0]
   [1 1 0 0 1 1 0 0]
   [1 1 0 0 1 1 0 0]
   [1 1 0 0 1 1 0 0]]
  """
  if len(block_size) != len(shape):
    raise ValueError(
        f"block_size needs one entry per axis: got {len(block_size)} "
        f"entries for shape {tuple(shape)}")

  # Resolve negative entries ("split into -d blocks") into block sizes.
  sizes = [n // -d if d < 0 else d for n, d in zip(shape, block_size)]
  if any(s <= 0 and n > 0 for n, s in zip(shape, sizes)):
    # A zero block size would otherwise silently divide by zero below.
    raise ValueError(
        f"block sizes must be positive, got {sizes} for shape "
        f"{tuple(shape)} and block_size {tuple(block_size)}")

  ndim = len(shape)
  parity = 0
  for axis, (n, size) in enumerate(zip(shape, sizes)):
    # Block index along `axis`, shaped to broadcast against the other axes.
    broadcast_shape = [1] * ndim
    broadcast_shape[axis] = -1
    parity = parity + jnp.reshape(jnp.arange(n) // size, broadcast_shape)

  # Honour the requested dtype (the arange arithmetic yields the default int).
  return jnp.asarray(parity % 2, dtype=dtype)

def checkerboard_mask(shape, dtype=jnp.uint8):
  """Alternating mask with a block size of 1 along every axis of ``shape``."""
  unit_blocks = tuple(1 for _ in shape)
  return create_mask(shape, unit_blocks, dtype)

def channel_mask(shape, dtype=jnp.uint8):
  """Mask that splits the last axis of ``shape`` into two halves.

  Every other axis is a single block, so the mask is 0 on the first half of
  the last axis and 1 on the second (e.g. ``shape=(2, 4)`` gives rows
  ``[0 0 1 1]``).
  """
  # create_mask needs exactly one block-size entry per axis: -1 (a single
  # block) on every leading axis and -2 (two blocks) on the last axis.
  return create_mask(shape, (-1,) * (len(shape) - 1) + (-2,), dtype)


@eqx.filter_jit
def train(model, opt, opt_state, x, weights=None, **kwargs):
  """Run one optimisation step on a batch.

  Args:
    model: flow whose ``loss`` is evaluated per example via ``sj.vmap``.
      Presumably this maps over the leading axis of ``x`` (TODO confirm).
    opt: optimiser with an optax-style ``update(grads, state, params)``.
    opt_state: current optimiser state.
    x: batch of inputs.
    weights: optional per-example weights. If given, the loss is
      ``mean(loss * weights)``.
    **kwargs: forwarded to ``model.loss``, e.g. ``key``. ``splitkey=True``
      presumably splits the key per example.

  Returns:
    ``(model, opt_state, loss_value)`` after applying the update.
  """
  @eqx.filter_value_and_grad
  def compute_loss(model, x, weights, **kwargs):
    batch_loss = sj.vmap(model.loss, splitkey=True)(x, **kwargs)
    if weights is None:
      return jnp.mean(batch_loss)
    return jnp.mean(batch_loss * weights)

  loss_value, grads = compute_loss(model, x, weights, **kwargs)
  # Pass the current parameters so optimisers that need them (e.g. decoupled
  # weight decay) work. They are filtered with the same inexact-array
  # predicate filter_value_and_grad differentiates, so the tree matches grads.
  params = eqx.filter(model, eqx.is_inexact_array)
  updates, opt_state = opt.update(grads, opt_state, params)
  model = eqx.apply_updates(model, updates)

  return model, opt_state, loss_value
