# mean-scale Gaussian EM with element-wise CDF
import numpy as np
import scipy.stats

import torch
import torch.nn as nn
import torch.nn.functional as F

import yaecl
from compressai.ops import LowerBound

# Half-width of the unit quantization bin around each integer symbol; used as
# the +/- offset when integrating the Gaussian over a bin.
HALF = 0.5
# Module-wide switch: when True, MSGau_CDF also clamps scales from above.
SCALE_UPPER_BOUND = False

class UpperBound(nn.Module):
  """Clamp a tensor from above at ``bound``.

  Implemented as the mirror image of compressai's ``LowerBound``:
  ``min(x, bound) == -max(-x, -bound)``, so whatever gradient rule
  ``LowerBound`` applies is reused with the sign flipped.
  """

  def __init__(self, bound):
    super().__init__()
    negated_bound = -bound
    self.lower_bound = LowerBound(negated_bound)

  def forward(self, x):
    flipped = torch.neg(x)
    clamped = self.lower_bound(flipped)
    return torch.neg(clamped)

class MSGau_CDF(nn.Module):
  """Mean-scale Gaussian entropy model over a bounded integer alphabet.

  Each element is modelled as N(mean, scale) discretized to the integer
  symbols ``sym_min .. sym_max``; the probability of symbol ``x`` is the
  Gaussian mass of the bin ``[x - 0.5, x + 0.5]``. ``forward`` returns these
  likelihoods, while ``compress`` / ``decompress`` build one 16-bit CDF table
  per element and drive the ``yaecl`` arithmetic coder with it.

  Args:
    *args, **kwargs: accepted and ignored.
    sym_min, sym_max: inclusive symbol range (cast to ``int``).
    scale_bound: lower clamp applied to scales; must be > 0.
    scale_upper_bound: upper clamp applied to scales, only when the module
      flag ``SCALE_UPPER_BOUND`` is True.
    ll_bound: lower clamp applied to the likelihoods returned by ``forward``.
  """

  def __init__(self, 
    *args,
    sym_min=-128, sym_max=127, scale_bound=0.11, scale_upper_bound=256.0, ll_bound=0.5**16,
    **kwargs,
  ):
    super().__init__()
    assert sym_min + 1 < sym_max
    assert scale_bound > 0
    assert ll_bound > 0
    self.sym_min = int(sym_min)
    self.sym_max = int(sym_max)
    self.sym_lower_bound = LowerBound(self.sym_min)
    self.sym_upper_bound = UpperBound(self.sym_max)
    self.scale_lower_bound = LowerBound(scale_bound)
    self.scale_upper_bound = UpperBound(scale_upper_bound) if SCALE_UPPER_BOUND else nn.Identity()
    self.likelihood_lower_bound = LowerBound(ll_bound)

  def sym_bound(self, sym):
    """Clamp ``sym`` into ``[sym_min, sym_max]``."""
    return self.sym_lower_bound(self.sym_upper_bound(sym))

  @staticmethod
  def _standardized_cumulative(x):
    """Standard normal CDF, Phi(x) = 0.5 * erfc(-x / sqrt(2))."""
    # Using the complementary error function maximizes numerical precision, considering erf(x) == -erf(-x)
    const = float(-(2**-0.5))
    return HALF * torch.erfc(const * x)

  def _likelihood(self, symbols, scales, means):
    """Gaussian mass of the unit bin centred on each symbol.

    Computes ``Phi((x - mu + 0.5) / s) - Phi((x - mu - 0.5) / s)`` after
    clamping ``s``; the three inputs broadcast against each other.

    Raises:
      RuntimeError: if any upper CDF value is below its lower counterpart
        (NaN compares False, so NaN inputs also trigger this).
    """
    # cdf N(u, s) at x == cdf N(0,1) at (x-u)/s
    scales = self.scale_lower_bound(scales)
    scales = self.scale_upper_bound(scales)
    centered = symbols - means
    upper = self._standardized_cumulative((centered + HALF) / scales)
    lower = self._standardized_cumulative((centered - HALF) / scales)
    if not (upper >= lower).all():
      # A bare `raise` here (no active exception) only produced
      # "No active exception to reraise"; report the diagnostics instead.
      raise RuntimeError(
        'invalid likelihood (upper CDF < lower CDF): possible extreme scales '
        f'(min={scales.min().item()}, max={scales.max().item()}) or '
        f'outrange y symbols (min={symbols.min().item()}, '
        f'max={symbols.max().item()})')
    return upper - lower
  
  def forward(self, symbols, scales, means):
    """Return likelihoods of ``symbols``, clamped below at ``ll_bound``."""
    # range of symbols must be limited
    assert (symbols >= self.sym_min).all()
    assert (symbols <= self.sym_max).all()
    return self.likelihood_lower_bound(self._likelihood(symbols, scales, means))
  
  def cal_cdf(self, scales, means):
    """Build per-element 16-bit CDF tables for the arithmetic coder.

    Returns a float CPU tensor of shape ``(N, sym_max - sym_min + 2)``, where
    N is the broadcast length of the flattened ``scales`` and ``means``.
    Every row starts at 0, ends at 65536 and is strictly increasing, so each
    symbol keeps a non-zero frequency.
    """
    num_syms = self.sym_max - self.sym_min + 1
    samples = torch.arange(
      self.sym_min, self.sym_max + 1, 
      device='cpu', dtype=scales.dtype)
    scales_col = scales.detach().cpu().flatten()[:, None]
    means_col = means.detach().cpu().flatten()[:, None]
    # WARNING: materializes an N x num_syms tensor, which can be extremely huge
    pmf = self._likelihood(samples[None, :], scales_col, means_col)
    # A row can be all zeros when its mean lies far outside the symbol range
    # and its scale is small. Put the mass on the end of the range nearest the
    # mean so the normalization below never divides by zero.
    degenerate = pmf.max(dim=1)[0] <= 0
    if degenerate.any():
      # expand handles a broadcast (single-element) `means`
      row_means = means_col.expand(pmf.size(0), 1)[degenerate, 0]
      midpoint = 0.5 * (self.sym_min + self.sym_max)
      # column -1 (sym_max) for means above the midpoint, column 0 (sym_min) otherwise
      cols = -(row_means >= midpoint).long()
      pmf[degenerate, cols] += 1.0
    # pmf to cdf, reserving one count per symbol to avoid zero-probability
    cdf = pmf.cumsum(dim=1)
    cdf = (cdf / cdf.max(dim=1, keepdim=True)[0] * (2**16 - num_syms)).round()
    cdf = cdf + torch.arange(num_syms)[None, :] + 1
    cdf = torch.cat([
      torch.zeros(cdf.size(0), 1, dtype=cdf.dtype, device=cdf.device),
      cdf
    ], dim=1)
    return cdf

  def _check_row_count(self, num_symbols, cdf):
    """Raise ValueError unless there is exactly one CDF row per symbol."""
    # The native coder reads symbols and CDF rows in lockstep; a mismatch
    # must be caught before crossing into yaecl.
    if num_symbols != cdf.shape[0]:
      raise ValueError(
        f'got {num_symbols} symbols but {cdf.shape[0]} CDF rows; '
        'symbols, scales and means must describe the same elements')
  
  def compress(self, symbols, scales, means):
    """Arithmetic-code ``symbols`` with per-element Gaussian CDFs.

    ``symbols`` must lie in ``[sym_min, sym_max]`` and provide one element per
    CDF row built by ``cal_cdf``. Returns ``bit_stream`` of the yaecl encoder.

    Raises:
      ValueError: if the symbol count differs from the CDF row count.
    """
    # range of symbols must be limited
    assert (symbols >= self.sym_min).all()
    assert (symbols <= self.sym_max).all()
    flat = symbols.detach().cpu().flatten()
    if flat.is_floating_point():
      # .to(int32) truncates toward zero; round first so float noise such as
      # 2.9999998 is coded as 3 rather than 2
      flat = flat.round()
    sym = flat.to(torch.int32).numpy() - self.sym_min
    cdf = self.cal_cdf(scales, means).to(torch.int32).numpy()
    self._check_row_count(sym.size, cdf)
    ac_enc = yaecl.ac_encoder_t()
    ac_enc.encode_nxn(sym, cdf, 16)
    ac_enc.flush()
    return ac_enc.bit_stream
  
  def decompress(self, bit_stream, scales, means):
    """Decode symbols written by ``compress`` using the same scales/means.

    Returns a tensor shaped like ``scales``, on its device and with its dtype.

    Raises:
      ValueError: if ``scales`` does not have one element per CDF row.
    """
    cdf = self.cal_cdf(scales, means).to(torch.int32).numpy()
    sym = np.zeros(scales.shape, dtype=np.int32).flatten()
    self._check_row_count(sym.size, cdf)
    ac_dec = yaecl.ac_decoder_t(bit_stream)
    ac_dec.decode_nxn(self.sym_max - self.sym_min + 1, memoryview(cdf), 16, memoryview(sym))
    sym = torch.from_numpy(sym + self.sym_min).to(scales.device).to(scales.dtype).reshape(scales.shape)
    return sym