# Much of the code is adapted from:
# Reference: https://github.com/DifferentiableUniverseInitiative/sbi_lens/blob/main/sbi_lens/simulator/LogNormal_field.py
# and https://github.com/DifferentiableUniverseInitiative/sbi_lens/blob/main/sbi_lens/simulator/redshift.py

import jax
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from jax.scipy.ndimage import map_coordinates
from jax.tree_util import register_pytree_node_class
from tensorflow_probability.substrates.jax.math import find_root_chandrupatla
from jax_cosmo.scipy.integrate import _difftrap1, _difftrapn, _romberg_diff, simps
from functools import partial
from haiku._src.nets.resnet import ResNet18
import haiku as hk

# NOTE(review): presumably the dimension of the compressed summary statistic.
# lensing_simulator hard-codes the same value (6) as the ResNet18 output size
# instead of reading this constant, and nothing in the visible part of this
# file references it -- confirm and keep the two in sync.
DIM_SUMMARY = 6

def sample_lensing_prior(key):
    """Draw one parameter vector from the lensing prior.

    Each parameter gets its own sub-key from a single split of ``key``.
    ``omega_c`` and ``w_0`` are drawn from truncated normals, the others from
    plain normals. ``omega_c``, ``omega_b`` and ``sigma_8`` are returned as
    their natural logarithm; ``h_0``, ``n_s`` and ``w_0`` are returned as drawn.

    NOTE(review): ``lensingLogNormal`` in this file only exponentiates
    ``sigma_8``; confirm how callers are expected to undo the log on
    ``omega_c`` and ``omega_b``.

    Args:
        key: JAX PRNG key.

    Returns:
        Array of six values ordered as
        ``[omega_c, omega_b, sigma_8, h_0, n_s, w_0]``.
    """
    # (name, mean, std, truncation bounds or None, return log of the draw?)
    prior_table = (
        ('omega_c', 0.2664, 0.2, (0.0, jnp.inf), True),
        ('omega_b', 0.0492, 0.006, None, True),
        ('sigma_8', 0.831, 0.14, None, True),
        ('h_0', 0.6727, 0.063, None, False),
        ('n_s', 0.9645, 0.08, None, False),
        ('w_0', -1.0, 0.9, (-2.0, -0.3), False),
    )
    subkeys = jax.random.split(key, len(prior_table))

    draws = []
    for position, (_, mean, std, bounds, in_log) in enumerate(prior_table):
        subkey = subkeys[position]
        if bounds is None:
            unit_draw = jax.random.normal(subkey)
        else:
            # Express the truncation limits in units of the standard normal.
            low, high = bounds
            unit_draw = jax.random.truncated_normal(
                subkey,
                lower=(low - mean) / std,
                upper=(high - mean) / std,
            )
        value = mean + std * unit_draw
        draws.append(jnp.log(value) if in_log else value)

    return jnp.array(draws)

def lensing_simulator(key,
                    params,
                    n_sim,
                    opt_state_resnet,
                    parameters_compressor,
                    lognormal_shifts_params,
                    compress = False):
    """Simulate ``n_sim`` noisy log-normal lensing maps for one parameter vector.

    ``key`` is split into ``n_sim`` sub-keys and ``lensingLogNormal`` is
    vmapped over them with a fixed survey configuration. When ``compress`` is
    truthy the maps are passed through a Haiku ``ResNet18`` with 6 outputs,
    applied in inference mode.

    Args:
        key: JAX PRNG key.
        params: parameter vector forwarded as ``theta`` to ``lensingLogNormal``.
        n_sim: number of independent simulations to draw.
        opt_state_resnet: passed as the Haiku *state* argument of the
            compressor's ``apply``; only used when ``compress`` is truthy.
        parameters_compressor: Haiku parameters of the compressor; only used
            when ``compress`` is truthy.
        lognormal_shifts_params: forwarded to ``lensingLogNormal``.
        compress: whether to return compressed summaries instead of maps.

    Returns:
        The stacked simulated maps, or the compressor output for those maps.
    """
    simulate_one = partial(
        lensingLogNormal,
        theta=params,
        N=256,
        map_size=10,
        gal_per_arcmin2=27,
        sigma_e=0.26,
        nbins=5,
        a=2,
        b=0.68,
        z0=0.11,
        model_type='lognormal',
        lognormal_shifts_params=lognormal_shifts_params,
        with_noise=True,
    )
    sim_keys = jax.random.split(key, n_sim)
    maps = jax.vmap(simulate_one)(sim_keys)

    if not compress:
        return maps

    def forward(batch):
        return ResNet18(6)(batch, is_training=False)

    network = hk.transform_with_state(forward)
    # No RNG is needed at inference time, hence the explicit None.
    summaries, _ = network.apply(
        parameters_compressor, opt_state_resnet, None, maps
    )
    return summaries

def lensingLogNormal(
    key,
    theta,
    lognormal_shifts_params,
    N=256,
    map_size=10,
    gal_per_arcmin2=27,
    sigma_e=0.26,
    nbins=5,
    a=2,
    b=0.68,
    z0=0.11,
    model_type="lognormal",
    with_noise=True
):
    """Simulate one tomographic lensing map of shape ``(N, N, nbins)``.

    Angular power spectra are computed with jax_cosmo for ``nbins`` redshift
    bins, turned into 2-D power maps, assembled into a per-pixel covariance
    matrix, and used to colour white Gaussian noise in Fourier space. For the
    log-normal model the field is then exponentiated and scaled by the
    per-bin shift.

    Args:
        key: JAX PRNG key; split once for the latent field and once more for
            the observational noise.
        theta: six values unpacked as
            ``(omega_c, omega_b, sigma_8, h_0, n_s, w_0)``. ``sigma_8`` is
            exponentiated before use, i.e. it is expected in log space.
            NOTE(review): ``sample_lensing_prior`` in this file also returns
            ``omega_c`` and ``omega_b`` in log space, but they are passed to
            the cosmology unchanged here -- confirm which side is intended.
        lognormal_shifts_params: table handed to ``shift_fn``; only used when
            ``model_type == "lognormal"``.
        N: number of pixels per map side.
        map_size: map side length, treated as degrees (converted to radians
            internally; the pixel area is computed in arcmin^2).
        gal_per_arcmin2: galaxy density passed to ``smail_nz``.
        sigma_e: value passed to the ``WeakLensing`` probe and used in the
            noise variance.
        nbins: number of redshift bins.
        a, b, z0: parameters passed to ``smail_nz``.
        model_type: ``"lognormal"`` or ``"gaussian"``.
        with_noise: when truthy, add independent Gaussian noise per bin with
            variance ``sigma_e**2 / (gals_per_arcmin2 * pix_area)``.

    Returns:
        Array of shape ``(N, N, nbins)``.

    Raises:
        ValueError: if ``model_type`` is not one of the supported values.
    """
    # Fail early: previously an unknown model only surfaced as an opaque
    # ``jnp.stack`` error after the expensive C_ell computation.
    if model_type not in ("lognormal", "gaussian"):
        raise ValueError(
            f"model_type must be 'lognormal' or 'gaussian', got {model_type!r}"
        )
    is_lognormal = model_type == "lognormal"

    pix_area = (map_size * 60 / N) ** 2
    map_size = map_size / 180 * jnp.pi

    omega_c, omega_b, sigma_8, h_0, n_s, w_0 = theta
    sigma_8 = jnp.exp(sigma_8)

    cosmo = jc.Planck15(
        Omega_c=omega_c, Omega_b=omega_b, h=h_0, n_s=n_s, sigma8=sigma_8, w0=w_0
    )

    if is_lognormal:
        shift = shift_fn(
            lognormal_shifts_params, jnp.array([cosmo.Omega_m, cosmo.sigma8, cosmo.w0])
        )
        shift_array = fill_shift_array(shift)

    nz = jc.redshift.smail_nz(a, b, z0, gals_per_arcmin2=gal_per_arcmin2)
    nz_bins = subdivide(nz, nbins=nbins, zphot_sigma=0.05)
    tracer = jc.probes.WeakLensing(nz_bins, sigma_e=sigma_e)

    ell_tab = 2 * jnp.pi * abs(jnp.fft.fftfreq(2 * N, d=map_size / (2 * N)))
    cell_tab = jc.angular_cl.angular_cl(cosmo, ell_tab, [tracer])

    def interpolated_cl(cl):
        """Return a callable evaluating ``cl`` at arbitrary wavenumbers."""

        def P(k):
            return jc.scipy.interpolate.interp(k.flatten(), ell_tab, cl).reshape(
                k.shape
            )

        return P

    power = []
    if is_lognormal:
        for cl, l_shift in zip(cell_tab, shift_array):
            power_map = make_power_map(interpolated_cl(cl), N, map_size)
            power_map = make_lognormal_power_map(power_map, l_shift)
            power.append(power_map)
    else:
        for cl in cell_tab:
            power_map = make_power_map(interpolated_cl(cl), N, map_size)
            power.append(power_map)
    power = jnp.stack(power, axis=-1)

    @jax.vmap
    def fill_cov_mat(m):
        # Fill the upper triangle, transpose, and fill it again so that the
        # matrix is symmetric.
        idx = np.triu_indices(nbins)
        cov_mat = jnp.zeros((nbins, nbins)).at[idx].set(m).T.at[idx].set(m)
        return cov_mat

    cov_mat = fill_cov_mat(power.reshape(-1, len(cell_tab)))
    eigval, A = jnp.linalg.eigh(cov_mat)
    # Symmetric matrix square root per pixel; negative eigenvalues are floored
    # at zero. ``jnp.maximum`` replaces the deprecated ``a_min`` keyword of
    # ``jnp.clip``.
    L = jax.vmap(
        lambda M, v: M.dot(jnp.diag(jnp.sqrt(jnp.maximum(v, 0))).dot(M.T))
    )(A, eigval)
    L = L.reshape([N, N, nbins, nbins])
    L = L.at[0, 0].set(jnp.zeros((nbins, nbins)))
    L = L.transpose([2, 3, 0, 1])

    key, subkey = jax.random.split(key)
    z = jax.random.multivariate_normal(subkey, jnp.zeros((nbins, N, N)), jnp.eye(N))
    field = jnp.fft.fft2(z) * L
    field = jnp.fft.ifft2(jnp.sum(field, axis=1)).real
    if is_lognormal:
        field = jnp.einsum(
            "i, ijk -> ijk",
            shift,
            jnp.exp(field - jnp.var(field, axis=(1, 2), keepdims=True) / 2) - 1,
        )

    field = jnp.transpose(field, [1, 2, 0])

    # Plain truth test: ``with_noise is True`` silently skipped the noise for
    # truthy values that are not the ``True`` singleton (e.g. ``np.True_``).
    if with_noise:
        key, subkey = jax.random.split(key)
        noise_var = sigma_e**2 / (
            jnp.array([zbin.gals_per_arcmin2 for zbin in nz_bins]) * pix_area
        )
        x = jax.random.multivariate_normal(subkey, field, jnp.diag(noise_var))
    else:
        x = field

    return x
    
@jax.jit
def shift_fn(params, theta):
    """Linearly interpolate the per-bin shift values at ``theta``.

    ``params`` is read as a gridded table: its leading axes (all but the last
    two) form the grid, the second-to-last axis indexes the bins and the last
    axis holds, per entry, the grid parameter values followed by the shift in
    the final column.

    Args:
        params: table of shape ``(*grid_shape, nbins, n_columns)``.
        theta: point to evaluate at, one value per grid axis.

    Returns:
        Array of ``nbins`` interpolated shifts.
    """
    n_axes = len(params.shape[:-2])
    n_bins = params.shape[-2]

    def fractional_index(axis):
        # Rescale theta[axis] from the value range of column `axis` to
        # fractional grid-index units along that axis.
        column = params[..., axis]
        return (
            (theta[axis] - column.min())
            / (column.max() - column.min())
            * params.shape[axis]
            - 0.5
        )

    # The lookup coordinates do not depend on the bin, so build them once.
    coords = jnp.stack(
        [fractional_index(axis) for axis in range(n_axes)], axis=0
    ).reshape([n_axes, -1])

    per_bin = []
    for bin_index in range(n_bins):
        table = params[..., bin_index, -1]
        value = map_coordinates(table, coords, order=1, mode="nearest")
        per_bin.append(value.squeeze())
    return jnp.stack(per_bin)


def fill_shift_array(shifts):
    """Return the pairwise products ``shifts[i] * shifts[j]`` for ``i <= j``.

    The products are listed in row-major order of the upper triangle
    (diagonal included) of the outer product of ``shifts`` with itself.
    """
    rows, cols = np.triu_indices(len(shifts))
    pair_products = jnp.outer(shifts, shifts)
    return pair_products[rows, cols]


def make_power_map(pk_fn, N, map_size, zero_freq_val=0.0):
    """Evaluate a power spectrum on the 2-D FFT wavenumber grid.

    Args:
        pk_fn: callable mapping an array of wavenumber magnitudes to power.
        N: number of pixels per side.
        map_size: side length of the map (sets the sample spacing
            ``map_size / N``).
        zero_freq_val: value written at the zero-frequency pixel before the
            normalisation is applied.

    Returns:
        ``(N, N)`` array equal to ``pk_fn(|k|)`` with the ``[0, 0]`` entry
        replaced, multiplied by ``(N / map_size) ** 2``.
    """
    freqs = 2 * jnp.pi * jnp.fft.fftfreq(N, d=map_size / N)
    # Magnitude of the 2-D wavevector at every pixel, via broadcasting.
    k_norm = jnp.sqrt(freqs[None, :] ** 2 + freqs[:, None] ** 2)
    spectrum = pk_fn(k_norm).at[0, 0].set(zero_freq_val)
    return spectrum * (N / map_size) ** 2


def make_lognormal_power_map(power_map, pshift, zero_freq_val=0.0):
    """Transform a power map into the power map of the underlying log field.

    The map is taken to real space with an inverse FFT, passed through
    ``log(1 + xi / pshift)``, and transformed back; the modulus of the result
    is returned.

    Args:
        power_map: 2-D power map (as produced by ``make_power_map``).
        pshift: value dividing the real-space map inside the logarithm.
        zero_freq_val: value written at the zero-frequency pixel ``[0, 0]``
            of the result. Previously this argument was accepted but ignored
            (0.0 was hard-coded); the default keeps that behaviour.

    Returns:
        Real, non-negative array with the same shape as ``power_map``, apart
        from the ``[0, 0]`` entry which equals ``zero_freq_val``.
    """
    real_space = jnp.fft.ifft2(power_map).real
    log_real_space = jnp.log(1 + real_space / pshift)
    log_power = jnp.abs(jnp.fft.fft2(log_real_space))
    return log_power.at[0, 0].set(zero_freq_val)

@register_pytree_node_class
class photoz_bin(jc.redshift.redshift_distribution):
    """Redshift distribution of one bin cut out of a parent distribution.

    ``self.params`` is unpacked as
    ``(parent_pz, zphot_min, zphot_max, zphot_sig)``.
    """

    def pz_fn(self, z):
        """Return the parent density at ``z`` weighted by an erf window.

        The window is half the difference of two error functions evaluated
        at the bin edges, with a width of ``zphot_sig * (1 + z)``.
        NOTE(review): presumably a Gaussian photometric-redshift scatter
        model -- confirm against the upstream implementation.
        """
        parent_pz, z_low, z_high, z_sigma = self.params
        erf = jax.scipy.special.erf

        inv_width = 1.0 / (jnp.sqrt(2.0) * z_sigma * (1.0 + z))
        window = erf((z_high - z) * inv_width) - erf((z_low - z) * inv_width)
        return 0.5 * parent_pz(z) * window

    @property
    def gals_per_arcmin2(self):
        """Parent galaxy density times the parent integral over the bin.

        The integral of the parent distribution between the two bin edges is
        evaluated with ``simps`` using 256 intervals.
        """
        parent_pz, z_low, z_high, _ = self.params

        def parent_density(t):
            return parent_pz(t)

        bin_fraction = simps(parent_density, z_low, z_high, 256)
        return parent_pz._gals_per_arcmin2 * bin_fraction

    @property
    def gals_per_steradian(self):
        """``gals_per_arcmin2`` multiplied by jax_cosmo's ``steradian_to_arcmin2``."""
        return self.gals_per_arcmin2 * jc.redshift.steradian_to_arcmin2


def subdivide(pz, nbins, zphot_sigma):
    """Split ``pz`` into ``nbins`` bins of equal integrated weight.

    Bin edges start at 0 and end at ``pz.zmax``. Each interior edge is the
    root of ``romb_jax(pz, 0, z) - k / nbins`` for ``k = 1 .. nbins - 1``,
    searched between the previous edge and ``pz.zmax``.

    Args:
        pz: parent redshift distribution (callable, exposing ``zmax``).
        nbins: number of bins to create.
        zphot_sigma: value forwarded to every ``photoz_bin``.

    Returns:
        List of ``nbins`` ``photoz_bin`` instances ordered by redshift.
    """
    fraction_per_bin = 1.0 / nbins

    edges = [0.0]
    for k in range(1, nbins):
        target = float(k) * fraction_per_bin

        def cumulative_excess(z, target=target):
            return romb_jax(pz, 0.0, z) - target

        root = find_root_chandrupatla(
            cumulative_excess, edges[-1], pz.zmax
        ).estimated_root
        edges.append(root)
    edges.append(pz.zmax)

    return [
        photoz_bin(pz, lower, upper, zphot_sigma)
        for lower, upper in zip(edges[:-1], edges[1:])
    ]

def romb_jax(function, a, b, args=(), divmax=6):
    """Integrate ``function`` from ``a`` to ``b`` with a fixed-depth Romberg scheme.

    A fixed number of refinement levels is always performed (no early exit on
    convergence), so the limits may be traced JAX values.

    Args:
        function: callable evaluated as ``function(x, *args)``.
        a: lower integration limit (scalar or JAX scalar).
        b: upper integration limit (scalar or JAX scalar).
        args: extra positional arguments forwarded to ``function``.
        divmax: number of refinement levels; level ``i`` uses ``2**i``
            subdivisions.

    Returns:
        The entry of the final extrapolation row at index ``divmax``.
    """
    a = jnp.asarray(a, dtype=jnp.promote_types(float, getattr(a, 'dtype', float)))
    b = jnp.asarray(b, dtype=jnp.promote_types(float, getattr(b, 'dtype', float)))

    def vfunc(x):
        x = jnp.asarray(x, dtype=a.dtype)
        return function(x, *args)

    interval = jnp.array([a, b])
    intrange = b - a
    ordsum = _difftrap1(vfunc, interval)
    result = intrange * ordsum

    # `state` is the previous row of the extrapolation table, initialised with
    # the coarsest estimate repeated so that its length is static.
    state = jnp.repeat(jnp.atleast_1d(result), divmax + 1, axis=-1)

    def scan_fn(carry, prev_state_val):
        x, k = carry
        new_val = _romberg_diff(prev_state_val, x, k + 1)
        return (new_val, k + 1), new_val

    for i in range(1, divmax + 1):
        n = 2**i
        ordsum = ordsum + _difftrapn(vfunc, interval, n)
        x = jnp.atleast_1d(intrange * ordsum / n)

        # One scan builds the whole new row from the previous one. The
        # original ran this identical scan twice and discarded the first
        # result.
        _, extrapolated = jax.lax.scan(scan_fn, (x[0], 0), state[:-1])
        state = jnp.concatenate([x, extrapolated])

    return state[divmax]
    