import pickle
from loguru import logger
import os
import numpy as np
from scipy import interpolate, integrate
from functools import partial

def fetch_gaussian_hyperplanes(num_planes, latent_dim):
    """
        Load/generate specified no of gaussian hyperplanes.

        latent_dim: dimension of the latent space where random normal hyperplanes are generated
        num_planes: no of hyperplanes to generate

        On a cache miss the hyperplanes are drawn as a (latent_dim, num_planes) array of
        standard-normal values and pickled to ./data/, keyed by (num_planes, latent_dim);
        later calls with the same arguments load that file instead of re-drawing.
    """
    logger.info('Fetching gaussian hplanes ...')
    fp = f'./data/gauss_hplanes_numPlanes{num_planes}_latentDim{latent_dim}.pkl'
    if os.path.exists(fp):
        # NOTE: pickle.load can execute arbitrary code; only load locally-written caches.
        with open(fp, 'rb') as f:
            hplanes = pickle.load(f)
        # loguru formats with str.format, so use an f-string rather than a %s placeholder.
        logger.info(f"Loading random hyperplanes from {fp}")
    else:
        hplanes = np.random.normal(size=(latent_dim, num_planes))
        # The cache directory may not exist yet on a fresh checkout.
        os.makedirs(os.path.dirname(fp), exist_ok=True)
        with open(fp, "wb") as f:
            pickle.dump(hplanes, f)
        logger.info(f"Generated random hyperplanes and dumped to {fp}")
    return hplanes

def fetch_omega_samples(conf):
    """
        Fetch samples of w given T, a, b.

        The sample count is embed_dim * m_use; the actual fetching (cache lookup or
        generation) is delegated to fetch_n_omega_samples.
    """
    return fetch_n_omega_samples(conf, conf.dataset.embed_dim * conf.hashing.m_use)

def fetch_n_omega_samples(conf, n):
    """
        Fetch n samples of w given T, a, b.

        Samples are cached as a pickle under ./allPklDumps/, keyed by the smoothing type
        (conf.hashing.Sm), n, a, b, T and conf.hashing.name. On a cache miss they are drawn
        with generate_samples(conf, n, T, a, b) and written to that file.

        Returns (samples, pdf), both converted with np.float32.
    """
    logger.info('Fetching samples ...')
    hashing = conf.hashing
    # Keep this exact filename layout: existing cache files are looked up by it.
    samples_fp = "./allPklDumps/" + hashing.Sm + \
        "samples_num" + str(n) + "_a" + str(hashing.a) + \
        "_b" + str(hashing.b) + "_T" + str(hashing.T) + \
        "hashing_name" + hashing.name + ".pkl"
    if os.path.exists(samples_fp):
        # NOTE: pickle.load can execute arbitrary code; only load locally-written caches.
        with open(samples_fp, "rb") as f:
            all_d = pickle.load(f)
        logger.info(f"Fetching samples from  {samples_fp}")
    else:
        logger.info(f"Samples not found, so generating samples and dumping to {samples_fp}")
        all_d = generate_samples(conf, n, hashing.T, hashing.a, hashing.b)
        # The dump directory may not exist yet on a fresh checkout.
        os.makedirs(os.path.dirname(samples_fp), exist_ok=True)
        with open(samples_fp, "wb") as f:
            pickle.dump(all_d, f)
    logger.info('Samples fetched')
    return np.float32(all_d['samples']), np.float32(all_d['pdf'])

def pdf_basic(w, T):
    """
        Unnormalized density |Re G(w)| + |Im G(w)| for parameter T, with
            Re G = 2 sin^2(wT/2) / w^2 + T sin(wT) / w
            Im G = sin(wT) / w^2 - T cos(wT) / w
        Evaluated elementwise with numpy, so w may be a scalar or an array.
        w == 0 is a 0/0 point and is not special-cased here.
    """
    wT = w * T
    w_sq = w**2
    sin_wT = np.sin(wT)
    real_part = 2 * np.sin(wT / 2)**2 / w_sq + T * sin_wT / w
    imag_part = sin_wT / w_sq - T * np.cos(wT) / w
    return np.abs(real_part) + np.abs(imag_part)

def pdf_expdecay(w, T, a):
    """
        Unnormalized density |G(w)| * a / sqrt(a^2 + w^2), where
            Re G = 2 sin^2(wT/2) / w^2 + T sin(wT) / w
            Im G = sin(wT) / w^2 - T cos(wT) / w
        i.e. the modulus of G scaled by the modulus-style factor a / sqrt(a^2 + w^2).
        Evaluated elementwise with numpy; w == 0 is not special-cased.
    """
    wT = w * T
    w_sq = w**2
    sin_wT = np.sin(wT)
    re_g = 2 * np.sin(wT / 2)**2 / w_sq + T * sin_wT / w
    im_g = sin_wT / w_sq - T * np.cos(wT) / w
    magnitude = np.sqrt(re_g * re_g + im_g * im_g)
    return magnitude * a / np.sqrt(a * a + w * w)

def pdf_expdecay_corr(w, T, a):
    """
        Unnormalized density |Re(G*S)| + |Im(G*S)|, where
            Re G = 2 sin^2(wT/2) / w^2 + T sin(wT) / w
            Im G = sin(wT) / w^2 - T cos(wT) / w
            S    = a^2 / (a^2 + w^2) - i * a w / (a^2 + w^2)   (= a / (a + i w))
        i.e. G is multiplied by the complex smoothing factor S before taking |Re| + |Im|.
        Evaluated elementwise with numpy; w == 0 is not special-cased.
    """
    wT = w * T
    w_sq = w**2
    sin_wT = np.sin(wT)
    g_re = 2 * np.sin(wT / 2)**2 / w_sq + T * sin_wT / w
    g_im = sin_wT / w_sq - T * np.cos(wT) / w

    denom = a * a + w * w
    s_re = a * a / denom
    s_im = -a * w / denom

    # Complex product (g_re + i g_im) * (s_re + i s_im), written out by parts.
    prod_re = g_re * s_re - g_im * s_im
    prod_im = g_re * s_im + g_im * s_re
    return np.abs(prod_re) + np.abs(prod_im)

def get_pdf(conf):
    """
        Select the unnormalized pdf f(w, T) for conf.hashing.Sm.

        'none' -> pdf_basic
        'EDIN' -> pdf_expdecay with a = conf.hashing.b
        'ED'   -> pdf_expdecay_corr with a = conf.hashing.b
        Raises ValueError for any other smoothing type.
    """
    smoothing = conf.hashing.Sm
    if smoothing == 'none':
        return pdf_basic
    smoothed_pdfs = {
        'EDIN': pdf_expdecay,
        'ED': pdf_expdecay_corr,
    }
    if smoothing not in smoothed_pdfs:
        raise ValueError("Unknown smoothing type")
    return partial(smoothed_pdfs[smoothing], a=conf.hashing.b)

def get_smoothing_factor(conf, ws, T):
    """
        Real and imaginary parts of the smoothing factor a / (a + i*ws), with a = conf.hashing.b.

        Both 'EDIN' and 'ED' use the same factor. Any other smoothing type
        (including 'none') raises ValueError. T is accepted but not used.

        Returns (Re_s, Im_s) = (a^2 / (a^2 + ws^2), -a*ws / (a^2 + ws^2)).
    """
    if conf.hashing.Sm not in ('EDIN', 'ED'):
        raise ValueError("Unknown smoothing type")
    a = conf.hashing.b
    denom = a*a + ws*ws
    return a*a / denom, -a*ws / denom

def get_pdf_normalization(pdf, T, a1, b1):
    """
        Integral of pdf(w, T) over [a1, b1], computed with scipy.integrate.quad.

        points=0 passes w = 0 to quad as a break point (quad only uses it when it
        lies strictly inside (a1, b1)); limit raises the subinterval budget.
        Returns only the integral value; quad's error estimate is discarded.
    """
    integral, _abs_err = integrate.quad(pdf, a1, b1, args=T, points=0, limit=10000)
    return integral



def generate_samples(conf,num_samples, T, a1, b1):
    """
        Draw num_samples values of w in [a1, b1] from the pdf selected by get_pdf(conf),
        using inverse-transform sampling on a 100000-point grid.

        Returns a dict with:
            'samples':          drawn w values
            'pdf':              pdf(w, T) / N for each sample, N = quad integral over [a1, b1]
            'samples_with_pdf': list of (sample, pdf) pairs
            'a', 'b', 'T', 'num': the generation parameters
    """
    x1 = np.linspace(a1, b1, 100000)
    pdf = get_pdf(conf)
    y1 = pdf(x1, T)
    # The pdfs are 0/0 at w == 0; one non-finite grid value would turn the whole
    # cumulative sum into nan, so give such points zero weight instead.
    y1 = np.where(np.isfinite(y1), y1, 0.0)
    # Discrete CDF on the grid (running sum), scaled so its last value is 1.
    cdf_y = np.cumsum(y1)
    cdf_y = cdf_y / cdf_y.max()
    # cdf_y[0] = y1[0] / sum(y1) is > 0 and can exceed the 1e-5 lower bound of the
    # uniform draws; clamp out-of-range draws to the grid ends instead of raising.
    inverse_cdf = interpolate.interp1d(
        cdf_y, x1, bounds_error=False, fill_value=(x1[0], x1[-1])
    )
    N = get_pdf_normalization(pdf, T, a1, b1)
    uniform_samples = np.random.uniform(1e-5, 1, num_samples)
    cdf_samples = inverse_cdf(uniform_samples)
    # pdf is elementwise numpy, so all samples are evaluated in one vectorized call.
    pdfs = pdf(cdf_samples, T) / N
    samples_with_pdf = list(zip(cdf_samples, pdfs))
    all_data = {
            'samples': cdf_samples,
            'pdf': pdfs,
            'samples_with_pdf': samples_with_pdf,
            'a': a1,
            'b': b1,
            'T': T,
            'num': num_samples
        }
    return all_data

