import sys
import numpy as np
from scipy.stats import qmc
from scipy.optimize import nnls
from recombination2 import *
from celer import Lasso
from sklearn.linear_model import LinearRegression
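# Fourier moments for the compression algorithm
# Parameters:
#  t: frequency (scalar in 1-D, length-d vector when y has d > 1 columns)
#  y: input points, shape (N,) or (N, d)
#  ep: epsilon
#  c: kernel type: 'sqeuclidean', 'sqeuclidean_gau' or 'euclidean'
# Returns the complex vector F of moments at frequency t (one entry per point)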
def Four_gaussian(t,y,ep,c='sqeuclidean'):
    """Return the per-point Fourier moments at frequency ``t``.

    Parameters
    ----------
    t : float or ndarray
        Frequency: a scalar (or length-1 array) when ``y`` is 1-D, a
        length-``d`` vector when ``y`` has ``d > 1`` columns.
    y : ndarray, shape (N,) or (N, d)
        Points at which the moments are evaluated.
    ep : float
        Epsilon.
    c : str
        Kernel type: ``'sqeuclidean'`` (default), ``'sqeuclidean_gau'`` or
        ``'euclidean'``.

    Returns
    -------
    ndarray of complex
        One moment per point of ``y``.

    Raises
    ------
    ValueError
        If ``c`` is not one of the supported kernel types.
    """
    import math  # gamma() for the d-dimensional 'euclidean' constant

    if c not in ('sqeuclidean', 'sqeuclidean_gau', 'euclidean'):
        # previously fell through and failed later with UnboundLocalError
        raise ValueError(
            f"unknown kernel type c={c!r}; expected 'sqeuclidean', "
            "'sqeuclidean_gau' or 'euclidean'")

    d = 1 if y.ndim == 1 else y.shape[1]
    if d == 1:
        if c == 'sqeuclidean':
            F = np.sqrt(ep*np.pi)*np.exp(-ep*t**2/4 - 1j*t*y)
        elif c == 'sqeuclidean_gau':
            # the Gaussian damping factor is omitted: the frequencies are
            # presumably drawn from a matching Gaussian instead
            F = np.sqrt(ep*np.pi)*np.exp(- 1j*t*y)
        else:  # 'euclidean'
            F = np.exp(-1j*t*y) * (2*np.pi*ep) / (1+ep**2*t**2)
    else:
        if c == 'sqeuclidean':
            F = np.sqrt(ep*np.pi)*np.exp(-ep*t@t/4 - 1j*y@t)
        elif c == 'sqeuclidean_gau':
            F = np.sqrt(ep*np.pi)*np.exp(- 1j*y@t)
        else:  # 'euclidean'; reduces to the 1-D formula when d == 1
            F = np.exp(-1j*y@t) * (1/ep*(2*np.pi)**d*math.gamma(d/2)) / (np.pi**(d/2)*(1/ep**2 + t@t))**((d+1)/2)
    return F
#
# Same moments as Four_gaussian, but returns their inner product with the weight vector w
# y has no_points rows and d columns
# w has no_points entries
# i.e. calculates F(y) . w for a d-dimensional frequency t
def Four_gaussian_rhs(t,y,w,ep,c='sqeuclidean'):
    """Return the moment of the weighted measure ``(w, y)`` at frequency ``t``.

    Computes ``np.inner(F, w)`` (no complex conjugation), where ``F`` is the
    per-point moment vector for kernel type ``c``.

    Parameters
    ----------
    t : float or ndarray
        Frequency: scalar (or length-1 array) for 1-D ``y``, length-``d``
        vector when ``y`` has ``d > 1`` columns.
    y : ndarray, shape (N,) or (N, d)
        Points of the measure.
    w : ndarray, shape (N,)
        Weights of the measure.
    ep : float
        Epsilon.
    c : str
        Kernel type: ``'sqeuclidean'`` (default), ``'sqeuclidean_gau'`` or
        ``'euclidean'``.

    Returns
    -------
    complex
        The weighted moment.

    Raises
    ------
    ValueError
        If ``c`` is not one of the supported kernel types.
    """
    import math  # gamma() for the d-dimensional 'euclidean' constant

    if c not in ('sqeuclidean', 'sqeuclidean_gau', 'euclidean'):
        # previously fell through and failed later with UnboundLocalError
        raise ValueError(
            f"unknown kernel type c={c!r}; expected 'sqeuclidean', "
            "'sqeuclidean_gau' or 'euclidean'")

    d = 1 if y.ndim == 1 else y.shape[1]
    if d == 1:
        if c == 'sqeuclidean':
            F = np.sqrt(ep*np.pi)*np.inner(np.exp(-ep*t**2/4 - 1j*t*y),w)
        elif c == 'sqeuclidean_gau':
            F = np.sqrt(ep*np.pi)*np.inner(np.exp(- 1j*t*y),w)
        else:  # 'euclidean'
            F = np.inner(np.exp(-1j*t*y),w) * (2*np.pi*ep) / (1+ep**2*t**2)
    else:
        if c == 'sqeuclidean':
            F = np.sqrt(ep*np.pi)*np.inner(np.exp(-ep*t@t/4 - 1j*y@t),w)
        elif c == 'sqeuclidean_gau':
            F = np.sqrt(ep*np.pi)*np.inner(np.exp(- 1j*y@t),w)
        else:  # 'euclidean'
            F = np.inner(np.exp(-1j*y@t),w) * (1/ep*(2*np.pi)**d*math.gamma(d/2)) / (np.pi**(d/2)*(1/ep**2 + t@t))**((d+1)/2)
    return F



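# Sample frequencies and generate moments for the compression algorithm
# Parameters:
#  num_freq: number of frequencies; each moment vector has 2*num_freq entries (real and imaginary parts)
#  y: quadrature points
#  ep: epsilon
#  c: frequency sampling type: 'uniform' (default), 'qmc_uniform', 'qmc_gau'
#  L: scaling factor (in 1-D the points are divided by L)
#  w, y_rhs: optional weights and points of a measure whose moments are returned as rhs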
def Four_poly(ep,y,num_freq,c='uniform',L=1.,w=None,y_rhs=None):
    """Sample ``num_freq`` frequencies and build the moment matrix at ``y``.

    Parameters
    ----------
    ep : float
        Epsilon; forwarded to the moment functions and used to set the
        frequency range (``'uniform'``, ``'qmc_uniform'``) or spread
        (``'qmc_gau'``).
    y : ndarray, shape (N,) or (N, d)
        Quadrature points.
    num_freq : int
        Number of frequencies to sample.
    c : str
        Frequency sampling type: ``'uniform'`` (default), ``'qmc_uniform'``
        or ``'qmc_gau'``.
    L : float
        Scaling factor. In 1-D the points are divided by ``L`` (and the
        ``'qmc_gau'`` frequencies are scaled by ``L``); unused for d > 1.
    w : ndarray, optional
        Weights of a measure on ``y_rhs``. When given, its moments are
        returned as ``rhs``.
    y_rhs : ndarray, optional
        Points of that measure; defaults to ``y`` when ``w`` is given.

    Returns
    -------
    A : ndarray, shape (N, 2*num_freq)
        Column ``i`` holds the real part and column ``i + num_freq`` the
        imaginary part of the moments at the ``i``-th frequency.
    rhs : ndarray, shape (2*num_freq,) or None
        Moments of ``(w, y_rhs)`` in the same layout, or None if ``w`` is
        None.

    Raises
    ------
    ValueError
        If ``c`` is not a known sampling type.
    """
    N = y.shape[0]
    d = 1 if y.ndim == 1 else y.shape[1]

    # --- sample frequencies t and pick the matching kernel type cc ---
    if c == 'uniform':
        cc = 'sqeuclidean'  # previously never set -> UnboundLocalError
        S = 10*int(np.ceil(np.log(1/ep)))
        if d == 1:
            # evenly spaced in (0, S], truncated to integers.
            # NOTE(review): truncation repeats frequencies (and can give 0)
            # when num_freq > S - confirm this is intended.
            t = np.asarray([i*S/num_freq for i in range(1, num_freq+1)], dtype='int')
        else:
            t = np.random.uniform(0, S, (num_freq, d))
    elif c == 'qmc_uniform':
        cc = 'sqeuclidean'
        S = 10*int(np.ceil(np.log(1/ep)))  # previously undefined on this path
        t = qmc.Halton(d).random(num_freq)*S
    elif c == 'qmc_gau':
        cc = 'sqeuclidean_gau'
        scale = 2*L**2/ep if d == 1 else 2/ep
        t = qmc.MultivariateNormalQMC(mean=np.zeros(d)).random(num_freq)*scale**0.5
    else:
        raise ValueError(
            f"unknown sampling type c={c!r}; expected 'uniform', "
            "'qmc_uniform' or 'qmc_gau'")

    # Only the 1-D path rescales the points by L.
    def _scaled(pts):
        return pts/L if d == 1 else pts

    # --- moment matrix at points y ---
    A = np.zeros([N, 2*num_freq])
    y_s = _scaled(y)
    for i in range(num_freq):
        F = Four_gaussian(t[i], y_s, ep, cc)
        A[:, i] = np.real(F)
        A[:, i+num_freq] = np.imag(F)

    # --- optional moments of the weighted measure (w, y_rhs) ---
    if w is None:
        return A, None
    if y_rhs is None:
        y_rhs = y
    yr_s = _scaled(y_rhs)
    rhs = np.zeros([2*num_freq])
    for i in range(num_freq):
        F = Four_gaussian_rhs(t[i], yr_s, w, ep, cc)
        rhs[i] = np.real(F)
        rhs[i+num_freq] = np.imag(F)
    return A, rhs

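# Main recombination compression
# Parameters:
#   ep: epsilon
#   weights, x: weights and nodes of input measure
#   m: target size for compression
#   params: dict with keys "fourier_solver" (compression method), "t_sampling"
#           (frequency sampling type - see Four_poly()) and optional "L_scale"
# Returns:
#  reduced set of weights and points (target size m), and the indices of the kept points in x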
def rec_compress(ep,weights,x,m,params):
    """Compress the weighted point set ``(weights, x)`` towards ``m`` points.

    Builds Fourier moments with ``Four_poly``. Then either a recombination
    algorithm is run on them, or non-negative weights are fitted by least
    squares to the moments of the input measure.

    Parameters
    ----------
    ep : float
        Epsilon, forwarded to ``Four_poly``.
    weights : ndarray, shape (N,)
        Weights of the input measure.
    x : ndarray, shape (N,) or (N, d)
        Nodes of the input measure.
    m : int
        Target size; ``ceil((m-1)/2)`` frequencies are sampled.
    params : dict
        ``"fourier_solver"``: one of ``"fast_car"``, ``"recomb_combined"``,
        ``"recomb_log"``, ``"tl"`` (recombination routines, presumably from
        ``recombination2``) or ``"nnls"``, ``"celer"``, ``"sklearn"``
        (least-squares solvers).
        ``"t_sampling"``: sampling type passed to ``Four_poly``.
        ``"L_scale"`` (optional, default 1.0): scale passed to ``Four_poly``.

    Returns
    -------
    new_weights, new_pts, idx_star
        Kept weights, kept points ``x[idx_star]`` and their indices.

    Raises
    ------
    ValueError
        If ``params["fourier_solver"]`` is not a known method.
    RuntimeError
        If ``recomb_log`` reports a non-zero error code.
    """
    method = params["fourier_solver"]
    lsq_methods = ("nnls", "celer", "sklearn")
    recomb_methods = ("fast_car", "recomb_combined", "recomb_log", "tl")
    n_pts = x.shape[0]  # number of nodes (x.size would count N*d for d > 1)
    print("Compress from ",n_pts, " to ", m, " points (",method,")") # 2m-1
    # Validate before the (expensive) moment computation.
    if method not in lsq_methods + recomb_methods:
        print("Unknown method",method)
        raise ValueError(f"Unknown fourier_solver {method!r}")

    c = params["t_sampling"]
    L = params.get("L_scale", 1.0)
    num_freq = int(np.ceil((m-1)/2))

    # Reduced mode (least-squares only): the candidate points are the first
    # r nodes, while the target moments rhs use the whole input measure.
    use_reduced = True
    if method in lsq_methods and use_reduced:
        print("Reduced compression method")
        r = int(min(n_pts, 1.5*m))
        px, rhs = Four_poly(ep, x[:r], num_freq, c, L, weights, x)
    else:
        print("Full compression method")
        px, rhs = Four_poly(ep, x, num_freq, c, L)
        if method in lsq_methods:
            rhs = px.T@weights

    # --- main compression algorithm ---
    if method == "fast_car":
        new_weights = Fast_Caratheodory(px, weights, m+1)
        # keep only the points that received a non-zero weight
        idx_star = np.nonzero(new_weights)  # previously never set
        new_weights = new_weights[idx_star]
        new_pts = x[idx_star]
    elif method == "recomb_combined":
        max_its = 5
        new_weights, idx_star, _, _, _, _, _ = recomb_combined(px, max_its, weights)
        new_pts = x[idx_star]
    elif method == "recomb_log":
        max_its = 0
        new_weights, idx_star, _, _, ERR, _, _ = recomb_log(px, max_its, weights)
        if ERR != 0:
            # previously left new_pts unbound -> UnboundLocalError below
            raise RuntimeError(f"recomb_log failed with ERR={ERR}")
        new_pts = x[idx_star]
    elif method == "tl":
        new_weights, idx_star, _, _, _, _, _ = Tchernychova_Lyons(px, weights)
        new_pts = x[idx_star]
    else:
        # least-squares: non-negative weights w with px.T @ w ~= rhs
        if method == "sklearn":
            estimator = LinearRegression(positive=True, fit_intercept=False).fit(px.T, rhs)
            new_weights = estimator.coef_
        elif method == "nnls":
            new_weights, _ = nnls(px.T, rhs)
        else:  # "celer"
            estimator = Lasso(alpha=1e-3, positive=True, tol=1e-4, max_iter=100, fit_intercept=False)
            estimator.fit(px.T, rhs)
            new_weights = estimator.coef_
        print("Relative residual=", np.linalg.norm(px.T@new_weights-rhs)/np.linalg.norm(rhs))
        # drop (near-)zero weights; indices refer to x since r <= len(x)
        idx_star = np.where(new_weights > 1e-8)
        new_pts = x[idx_star]
        new_weights = new_weights[idx_star]

    print("Compressed to ",len(new_pts)," points")
    return new_weights, new_pts, idx_star
##

def rec_compress_both(ep,wy,y,wx,x,m,c):
    """Compress two weighted point sets, ``(wy, y)`` then ``(wx, x)``.

    ``c`` is passed unchanged to ``rec_compress`` as its ``params`` argument,
    so both compressions use the same solver settings and target size ``m``.

    Returns ``(wy, yy, idx_y, wx, xx, idx_x)``: the compressed weights,
    points and kept indices for ``y``, followed by the same three for ``x``.
    """
    y_weights, y_points, y_index = rec_compress(ep, wy, y, m, c)
    x_weights, x_points, x_index = rec_compress(ep, wx, x, m, c)
    return (y_weights, y_points, y_index,
            x_weights, x_points, x_index)

# for parallel runs
def rec_worker(inp):
    """Run ``rec_compress`` on a packed argument sequence (for parallel maps).

    ``inp`` is ``(ep, weights, x, m, params, ...)``. Only the first five
    entries are forwarded, because ``rec_compress`` takes exactly five
    parameters. The previous version forwarded six and always raised
    TypeError.
    NOTE(review): confirm what, if anything, callers place in a sixth slot.

    Returns the ``(new_weights, new_pts, idx_star)`` tuple from
    ``rec_compress``.
    """
    return rec_compress(*inp[:5])

