import sys
import timeit 
import random 
import math 
# 
import numpy as np
from numpy.polynomial.hermite import hermgauss
from scipy.spatial import distance
from scipy.special import logsumexp
from scipy.linalg import sqrtm
from scipy.stats import qmc
#from scipy.ndimage import fourier_gaussian
#
from pykeops.numpy import LazyTensor
# GQ compression
import gq
# Fourier compression
import compression 
# to do parallel compression with recombination
import multiprocessing
from multiprocessing import Pool
# Report how many CPUs multiprocessing detects on this machine.
print(multiprocessing.cpu_count())
# Module-level pool of two worker processes; rec_compress_both_parallel maps its
# two compression jobs onto this pool.
# NOTE(review): the pool is created as a side effect of importing this module.
# Presumably this relies on a start method where worker processes do not
# re-execute this import (e.g. "fork") - confirm for the target platform.
pool = Pool(processes=2)
# Banner printed once the imports and pool creation have completed.
print("Importing algorithms.py")
#         
def rec_compress_both_parallel(ep,wy,y,wx,x,m,params):
    """Compress the (wy, y) and (wx, x) weighted point sets in parallel.

    Two jobs are mapped onto the module-level ``pool``; each job calls
    ``compression.rec_worker`` with the argument list
    ``[ep, weights, points, m, c, method, L]``.

    Parameters:
      ep: epsilon, forwarded unchanged to the worker
      wy, y: weights and points of the first set
      wx, x: weights and points of the second set
      m: forwarded unchanged to the worker (the caller in online_sinkhorn
         passes the number of points to compress to)
      params: mapping providing "t_sampling", "fourier_solver" and, optionally,
         "L_scale" (defaults to 1.0 when absent)

    Returns:
      (wy, yy, idx_y, wx, xx, idx_x): the first three entries of each worker's
      result, first for the y-job and then for the x-job.
    """
    c = params["t_sampling"]
    method = params["fourier_solver"]
    L = params["L_scale"] if "L_scale" in params else 1.0
    # One job per point set; the order fixes which result belongs to which set.
    jobs = [[ep, wy, y, m, c, method, L], [ep, wx, x, m, c, method, L]]
    out_y, out_x = pool.map(compression.rec_worker, jobs)
    wy, yy, idx_y = out_y[0], out_y[1], out_y[2]
    wx, xx, idx_x = out_x[0], out_x[1], out_x[2]
    return wy, yy, idx_y, wx, xx, idx_x
# evaluate soft C-transform
# Parameters:
#   q,x: define the potential at nodes x
#   t: points at which to evaluate the C-transform
#   ep: epsilon
# 
    
# def evaluate_potential(q, x, t, ep):
#     C = cost(x, t)
#     f = -ep*logsumexp((q[:,None]-C)/ep, axis=0)
#     return f

# fast version of above using Keops
# x has dimension (m,d)
# q has dimension (m)
# t has dimension (n,d)
def evaluate_potential_orig(q, x, t, ep, c='sqeuclidean'):
  """Earlier variant of the KeOps-based soft C-transform evaluation.

  Builds LazyTensor expressions for (q - cost(x, t)) / ep and returns
  -ep * logsumexp of that expression along dim=1, squeezed.

  Parameters:
    q: potential values attached to the nodes x
    x: nodes; a 1-D array selects the 1-D branch, otherwise x is used as
       a 2-D array of points
    t: points at which the potential is evaluated
    ep: epsilon
    c: 'sqeuclidean' (default) or 'euclidean'

  Raises:
    ValueError: if c is not one of the two supported costs.
  """
  # Reject unknown costs up front instead of failing later on an unbound A_ij.
  if c not in ('euclidean', 'sqeuclidean'):
    raise ValueError("unsupported cost %r; expected 'euclidean' or 'sqeuclidean'" % (c,))
  if x.ndim == 1:
    # NOTE(review): this branch lays the axes out differently from the N-D
    # branch (x on axis 0, t on axis 1, q with an extra leading axis) while the
    # reduction below is still taken along dim=1 - confirm that this reduces
    # over the x-nodes as intended; evaluate_potential uses a layout that
    # matches the N-D branch.
    x_i = LazyTensor(x[:,None, None])  # array of shape (M, 1, 1)
    q_i = LazyTensor(q[None,:,None,None]) # array of shape (1, M, 1, 1)
    t_j = LazyTensor(t[None, :, None])  # array of shape (1, N, 1)
  else: 
    x_i = LazyTensor(x[None, :, :])  # array of shape (1, M, D)
    q_i = LazyTensor(q[None,:,None]) # array of shape (1, M, 1)
    t_j = LazyTensor(t[:, None, :])  # array of shape (N, 1, D)
  #
  D_ij = x_i - t_j  # pairwise differences between nodes and evaluation points
  C_ij = (D_ij**2).sum(-1)  # squared Euclidean distances (sum over the last axis)
  if (c == 'euclidean'):
    A_ij= (q_i-C_ij**(0.5))/ep # euclidean
  elif (c == 'sqeuclidean'):
    A_ij= (q_i-C_ij)/ep # sqeuclidean
  # 
  f=-ep*A_ij.logsumexp(dim=1)
  return f.squeeze()

def evaluate_potential(q, x, t, ep, c='sqeuclidean'):
  """Evaluate the soft C-transform of the potential (q, x) at the points t.

  For every evaluation point t_j this computes
      f(t_j) = -ep * log( sum_i exp( (q_i - cost(x_i, t_j)) / ep ) )
  using KeOps LazyTensors, so the pairwise cost array is expressed lazily
  rather than written out as a NumPy array here.

  Parameters:
    q: potential values, one per node in x (length M)
    x: nodes; either a 1-D array of length M or a 2-D array of shape (M, D)
    t: evaluation points; 1-D of length N or 2-D of shape (N, D), matching x
    ep: epsilon
    c: 'sqeuclidean' (default) for squared Euclidean cost, or 'euclidean'

  Returns:
    The reduced result with singleton axes squeezed out (one value per
    evaluation point).

  Raises:
    ValueError: if c is not one of the two supported costs.
  """
  # Reject unknown costs up front instead of failing later on an unbound A_ij.
  if c not in ('euclidean', 'sqeuclidean'):
    raise ValueError("unsupported cost %r; expected 'euclidean' or 'sqeuclidean'" % (c,))
  if x.ndim == 1:
    x_i = LazyTensor(x[None,:, None])  # array of shape (1, M, 1)
    q_i = LazyTensor(q[None,:,None]) # array of shape (1, M, 1)
    t_j = LazyTensor(t[:,None, None])  # array of shape (N, 1, 1)
    C_ij = (x_i - t_j)**2  # squared distance in 1-D
  else: 
    x_i = LazyTensor(x[None, :, :])  # array of shape (1, M, D)
    q_i = LazyTensor(q[None,:,None]) # array of shape (1, M, 1)
    t_j = LazyTensor(t[:, None, :])  # array of shape (N, 1, D)
    D_ij = x_i - t_j  # pairwise differences, evaluation points on axis 0, nodes on axis 1
    C_ij = (D_ij**2).sum(-1)  # squared Euclidean distances (sum over the D axis)
  #
  if (c == 'euclidean'):
    A_ij= (q_i-C_ij**(0.5))/ep # euclidean
  elif (c == 'sqeuclidean'):
    A_ij= (q_i-C_ij)/ep # sqeuclidean
  # log-sum-exp over the node axis (axis 1), leaving one value per evaluation point
  f=-ep*A_ij.logsumexp(dim=1)
  return f.squeeze()


# cost matrix
def cost(xt, yt, c='sqeuclidean'):
    """Return the pairwise cost matrix between the point sets xt and yt.

    1-D inputs are treated as sets of scalar points and turned into
    single-column arrays; higher-dimensional inputs are passed on as copies.
    The metric name c is handed straight to scipy's cdist.
    """
    assert(xt.ndim==yt.ndim)
    if xt.ndim != 1:
        return distance.cdist(xt.copy(), yt.copy(), c)
    # scalar points: one row per point, one column
    return distance.cdist(xt.reshape(-1, 1), yt.reshape(-1, 1), c)

         

# OT objective
# Parameters:
#   fx,gy,x,y: define potentials at nodes x,y
#   samples of target distribution: a,b
#   ep: epsilon
def obj_func(fx, gy, x, y, a, b, ep):
    """Return the linear part of the dual objective, <fx, a> + <gy, b>.

    Only the two inner products are evaluated; x, y and ep are accepted
    but not used by this implementation.
    """
    f_term = np.dot(fx, a)
    g_term = np.dot(gy, b)
    return f_term + g_term

#
#
#
#
def sinkhorn(x,y, maxits, ep, simultaneous=False):
  """Classical Sinkhorn iteration in the log domain.

  Both point sets must have the same shape and carry uniform weights 1/n.
  Each sweep updates f and then g; with simultaneous=True the g-update
  uses the f from the previous sweep instead of the fresh one.  Iteration
  stops early once the summed oscillation of the changes in f and g drops
  below 1e-6, in which case the remaining objective entries are filled
  with the final value.

  Returns (f, g, obj) where obj holds the objective after every sweep.
  """
  tol = 1e-6
  assert(x.shape == y.shape)
  n = x.shape[0]
  # uniform marginals and their logs
  a = np.ones(n)/n
  b = np.ones(n)/n
  log_a = np.log(a)
  log_b = np.log(b)
  C = cost(x, y)
  Ct = C.T
  g = np.zeros(n)
  f_prev = np.ones(n)
  g_prev = np.ones(n)
  obj = np.zeros(maxits)
  for it in range(maxits):
    f = -ep*(log_a + logsumexp((g[:,None]-Ct)/ep, axis=0))
    f_used = f_prev if simultaneous else f
    g = -ep*(log_b + logsumexp((f_used[:,None]-C)/ep, axis=0))
    # oscillation (max - min) of the change in each potential
    df = f - f_prev
    dg = g - g_prev
    err = dg.max() - dg.min() + df.max() - df.min()
    f_prev, g_prev = f, g
    value = obj_func(f, g, x, y, a, b, ep)
    if err < tol:
      obj[it:] = value
      break
    obj[it] = value
    print('objective is', obj[it], '; error is', err)
  return (f, g, obj)

#
def online_sinkhorn(params):
    """Run online Sinkhorn, optionally compressing the stored potentials.

    The potentials are represented by weighted point sets (q, y) and (p, x)
    and evaluated through evaluate_potential.  Every iteration draws fresh
    samples, appends them (or replaces the stored sets when eta == 1) and,
    when enabled, periodically compresses the stored sets with either the
    "Fourier" or the "GQ" method.

    Keys read from params:
      required: "compress", "epsilon", "dim", "test_samples1_var",
        "test_samples2_var", "get_samples1", "get_samples2", "maxits",
        "a", "b", "zeta", "eta_decay_const" (read from the second iteration on)
      objective: either "measure" (callable taking the two potential
        functions) or "test_samples1_obj" and "test_samples2_obj"
      compression (only read when "compress" is True): "method",
        "min_compress", "compression_skip", "compression_const"; the
        "Fourier" method passes params on to compression.rec_compress
      optional: "init_batch" (default 1000), "batch_const" (default 1),
        "no_initial_sinkhorn_its", "fourier_solver" (printed only),
        "exact_f" / "exact_g" (evaluated but currently not used for the error)

    Returns:
      (f, g, x, y, obj, err): callables f and g evaluating the final
      potentials, the final stored point sets x and y, and per-iteration
      arrays of objective values and errors.

    Raises:
      ValueError: if an iteration needs the objective but neither "measure"
        nor the objective test samples were provided.
    """
    print("Online Sinkhorn with Compression=",params["compress"])
    if params["compress"]:
        print("Compression method: ",params["method"])
        if ("fourier_solver" in params) and (params["method"]=="Fourier"):
            print("Fourier solver=",params["fourier_solver"])

    ep = params['epsilon']
    d = params['dim']
    # fixed test points used to measure the change of the potentials
    xtt_var = params["test_samples1_var"]
    ytt_var = params["test_samples2_var"]
    # optional fixed test points used for the sample-average objective
    xtt_obj = ytt_obj = None
    Nx = Ny = 0
    if "test_samples1_obj" in params:
        xtt_obj = params["test_samples1_obj"]
        ytt_obj = params["test_samples2_obj"]
        Nx = len(xtt_obj)
        Ny = len(ytt_obj)
    # routines for generating new samples
    get_samps1 = params["get_samples1"]
    get_samps2 = params["get_samples2"]
    # method parameters
    maxits = params['maxits']
    b = params['b'] # eta_t=t^{b} for -1<b<0 (eta=1 Sinkhorn; check limits)
    a = params['a'] # n_t=t^{2a} for a>1+b (number of samples upto iteration t)
    zeta = params['zeta']
    #
    init_batch = params["init_batch"] if "init_batch" in params else 1000
    batch_const = params["batch_const"] if "batch_const" in params else 1
    batch = [int(init_batch+batch_const*pow(i, 2*a)) for i in range(1, maxits+2)]
    # Comparison against exact potentials is switched off; the exact values
    # are still evaluated when the callables are supplied.
    exact = False
    f_exact = g_exact = None
    if "exact_f" in params:
        f_exact = params["exact_f"](xtt_var)
        g_exact = params["exact_g"](ytt_var)
    # initialise
    x = get_samps1(init_batch,d)
    y = get_samps2(init_batch,d)
    q = np.ones(init_batch)
    p = ep*np.log(1./init_batch) + evaluate_potential(q, y, x, ep)
    if "no_initial_sinkhorn_its" in params:
        print("Running",params["no_initial_sinkhorn_its"], "Sinkhorn iterations to start-up.")
        for _ in range(params["no_initial_sinkhorn_its"]):
            q = ep*np.log(1./init_batch) + evaluate_potential(p, x, y, ep)
            p = ep*np.log(1./init_batch) + evaluate_potential(q, y, x, ep)
    #
    compression_runtime = 0.
    obj = np.zeros(maxits)
    err = np.zeros(maxits)
    #
    fx_old = evaluate_potential(q,y,xtt_var,ep)
    gy_old = evaluate_potential(p,x,ytt_var,ep)
    #
    x_sample_factor = 1.
    y_sample_factor = 1.
    # online Sinkhorn iteration
    for t in range(maxits):
        if t>0:
            etat = ((1+t/params['eta_decay_const']))**b
        else:
            etat = 1.
        # generate new samples
        bt_x = int(x_sample_factor* batch[t+1])
        bt_y = int(y_sample_factor* batch[t+1])
        # NOTE(review): the next batch size is scaled by the variance of the
        # current batch - confirm this is the intended heuristic.
        xt = get_samps1(bt_x,d)
        x_sample_factor = np.var(xt)
        yt = get_samps2(bt_y,d)
        y_sample_factor = np.var(yt)
        # compressed Online Sinkhorn
        num = len(x)
        # Short-circuit so the compression settings are only required (and
        # only read) when compression is actually switched on.
        do_compress = (
            params['compress'] == True  # noqa: E712 - keep the original equality test
            and num > params['min_compress']
            and t % params['compression_skip'] == 0
            and t > params["compression_const"]
            and t < maxits-params["compression_skip"]
        )
        if do_compress:
            start = timeit.default_timer()
            # number of points to compress each stored set down to
            m = int(10+ (((t/params['compression_const'])**a)/etat)**(1/zeta))
            if (params['method'] == 'Fourier') :
                print('Compressing with Fourier from ',num,' to ',m,' points.')
                e_pot_y = evaluate_potential(p, x, y, ep)
                e_pot_x = evaluate_potential(q, y, x, ep)
                wy = np.exp((q - e_pot_y)/ep)
                wx = np.exp((p - e_pot_x)/ep)
                if True: # old style
                    wy,y,idx_y = compression.rec_compress(ep, wy, y, m, params) # compressed weights and points
                    wx,x,idx_x = compression.rec_compress(ep, wx, x, m, params) #
                else: # parallel (currently disabled)
                    wy,y,idx_y,wx,x,idx_x = rec_compress_both_parallel(ep,wy,y,wx,x,m,params)
                # rebuild the potentials on the retained points
                q = ep*np.log(wy) + e_pot_y[idx_y]
                p = ep*np.log(wx) + e_pot_x[idx_x]
            elif (params['method'] == 'GQ'):
                print('Compressing with GQ...')
                print('Compressing from ', num, 'to ', m)
                wy = np.exp((q - evaluate_potential(p, x, y, ep))/ep)
                yy, wy = gq.compress(y, wy, m) # compressed weights and points
                yy, wy = np.squeeze(yy), np.squeeze(wy)
                qq = ep*np.log(wy) + evaluate_potential(p, x, yy, ep) # calculate q
                #
                wx = np.exp((p - evaluate_potential(q, y, x, ep))/ep)
                xx, wx = gq.compress(x, wx, m) # compressed weights and points
                xx, wx = np.squeeze(xx), np.squeeze(wx)
                # p is computed from the uncompressed (q, y) before they are replaced
                p = ep*np.log(wx) + evaluate_potential(q, y, xx, ep) # calculate p
                q = qq
                x = xx
                y = yy
            compression_runtime = timeit.default_timer()-start
            print("Compression time {:.2f}".format(compression_runtime))
        start = timeit.default_timer()
        # update (q,y)
        gt = evaluate_potential(p, x, yt, ep)
        qt = (ep*np.log(etat/bt_y) + gt)
        if etat<1.:
            q = np.concatenate((q+ep*np.log(1-etat), qt))
            y = np.concatenate((y, yt))
        else:
            q = qt
            y = yt
        # update (p,x)
        ft = evaluate_potential(q, y, xt, ep)
        pt = (ep*np.log(etat/bt_x) + ft)
        if etat<1.:
            p = np.concatenate((p+ep*np.log(1-etat), pt))
            x = np.concatenate((x, xt))
        else:
            p = pt
            x = xt
        # error calculations: oscillation (max - min) of the change in each potential
        fx = evaluate_potential(q,y,xtt_var,ep)
        gy = evaluate_potential(p,x,ytt_var,ep)
        f_ref = f_exact if exact else fx_old
        g_ref = g_exact if exact else gy_old
        err_total = np.ptp(fx - f_ref) + np.ptp(gy - g_ref)
        err[t] = err_total
        if err_total < 1e-6:
            # reported only; the iteration deliberately continues
            print('Error is smaller than the tolerance.')
        #
        if "measure" in params:
            obj[t] = params["measure"](lambda zz:evaluate_potential(q, y, zz, ep),lambda zz:evaluate_potential(p, x, zz, ep))
        else:
            if xtt_obj is None:
                raise ValueError('params must provide either "measure" or "test_samples1_obj"/"test_samples2_obj" to evaluate the objective')
            fx_obj = evaluate_potential(q,y,xtt_obj,ep)
            gy_obj = evaluate_potential(p,x,ytt_obj,ep)
            obj[t] = obj_func(fx_obj, gy_obj, xtt_obj, ytt_obj, np.ones(Nx)/Nx, np.ones(Ny)/Ny, ep)
        OS_runtime = timeit.default_timer()-start
        # fx, gy still hold the potentials at the variance test points, so they
        # become the reference for the next iteration without re-evaluation.
        fx_old = fx
        gy_old = gy
        if (compression_runtime>0):
            print("OS step time {:.2f} (as fraction of last compression {:.2f})".format(OS_runtime,OS_runtime/compression_runtime))
            print("step is", t, 'obj is', obj[t], 'err is', err[t], "; eta=",etat)
        else:
            print("step is", t, 'obj is', obj[t], 'err is', err[t], "; eta=",etat,"OS step time {:.2f}".format(OS_runtime))
    # final potentials as callables
    def f(t):
        return evaluate_potential(q, y, t, ep)
    def g(t):
        return evaluate_potential(p, x, t, ep)
    return (f, g, x, y, obj, err)
##############
def get_gaussian_potentials(A,B,a,b,ep):
    """Closed-form potentials f, g for two Gaussians N(a, A) and N(b, B).

    Parameters:
      A, B: covariance matrices of shape (d, d), or scalars in the 1-D case
      a, b: mean vectors of length d, or scalars in the 1-D case
      ep: epsilon

    Scalar inputs (detected via a having zero dimensions, so Python ints,
    floats and NumPy scalars all qualify) are promoted to 1x1 matrices and
    length-1 vectors.

    Returns:
      (f, g, loss): f and g are callables accepting either a 1-D array of
      scalar points or an (n, d) array and returning n values; loss is
      trace(Q_f A + Q_g B) + 2 * constant as computed below.
    """
    if np.ndim(a) == 0:
        A = np.atleast_2d(A)
        B = np.atleast_2d(B)
        a = np.atleast_1d(a)
        b = np.atleast_1d(b)
    print("Exact form")
    print("A mean=",a, "covariance=",A)
    print("B mean=",b, "covariance=",B)
    d = A.shape[0]
    eye = np.eye(d)
    #
    invA = np.linalg.inv(A)
    invB = np.linalg.inv(B)
    #
    D = (1/2)*sqrtm(ep**2*eye+16*A@B)
    S = (1/2)*invA@invB@((ep/2)*eye+D)
    # matrices of the quadratic forms defining f and g
    Q_f = (eye+(ep/2)*invA-S@B)
    Q_g = (eye+(ep/2)*invB-S.T@A)
    constant = (1/4)*ep*np.log((2/ep)**d*np.linalg.det(B@S.T@A))+(1/2)*np.dot(a-b,a-b)
    #
    loss = np.trace(Q_f@A+Q_g@B)+2*constant
    #
    def f(x):
        # treat a 1-D input as n scalar points
        if len(x.shape)==1:
            x = x[:,np.newaxis]
        return np.einsum('ij,jk,ik->i',x-a,Q_f,x-a)+2*np.einsum('ij,j->i',x-a,a-b)+constant
    def g(y):
        # treat a 1-D input as n scalar points
        if len(y.shape)==1:
            y = y[:,np.newaxis]
        return np.einsum('ij,jk,ik->i',y-b,Q_g,y-b)+2*np.einsum('ij,j->i',y-b,b-a)+constant
    return f,g,loss
##############       
def integrate(a,A,f,n):
   """Integrate f against the 1-D Gaussian N(a, A) by Gauss-Hermite quadrature.

   n is the number of quadrature nodes requested from hermgauss; f is called
   once on the array of shifted and scaled nodes and must return an array.
   """
   nodes, weights = hermgauss(n)
   # map the Hermite nodes onto the Gaussian with mean a and variance A
   samples = nodes*np.sqrt(2*A)+a
   weighted = weights*f(samples).squeeze()
   return np.sum(weighted)/np.sqrt(np.pi)
###
def GaussHermiteIntegrate(a,A,f,b,B,g,deg):
  """Evaluate <f,alpha> + <g,beta> for Gaussians alpha=N(a,A), beta=N(b,B).

  Each term is computed by integrate with deg Gauss-Hermite nodes.
  """
  alpha_term = integrate(a,A,f,deg)
  beta_term = integrate(b,B,g,deg)
  return alpha_term+beta_term
####
def random_cov_matrix(matrixSize):
  """Return a random matrixSize x matrixSize covariance matrix.

  A matrix with entries drawn by np.random.rand is multiplied by its own
  transpose, which makes the result symmetric.
  """
  root = np.random.rand(matrixSize, matrixSize)
  return np.dot(root, root.T)
####
def get_qmc_samples(a,A,num_pts):
       """Draw num_pts quasi-Monte Carlo samples for the Gaussian N(a, A).

       Standard-normal QMC points of dimension len(a) are multiplied by the
       matrix square root of A and shifted by the mean a; the result has one
       row per sample.
       """
       dim = a.shape[0]
       engine = qmc.MultivariateNormalQMC(mean=np.zeros(dim))
       standard = engine.random(num_pts)
       return standard @ sqrtm(A) + a[np.newaxis,:]

####
def get_qmc_samples1(a,A,num_pts):
       """Draw num_pts quasi-Monte Carlo samples for the 1-D Gaussian N(a, A).

       One-dimensional standard-normal QMC points are scaled by sqrt(A),
       shifted by a, and returned with singleton axes squeezed out.
       """
       engine = qmc.MultivariateNormalQMC(mean=np.zeros(1))
       scaled = engine.random(num_pts)*np.sqrt(A)+a
       return scaled.squeeze()
