import os
import sys
import torch
import torch.nn as nn
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import numpy as np
from scipy.integrate import quad, dblquad
from scipy.stats import hypsecant
from scipy.stats import norm

# Default absolute/relative tolerances and integration-grid settings.
# NOTE(review): in the visible code these are only referenced as (unused)
# keyword defaults of qmap_d; zmin/zmax are not read by any visible function.
epsabs = epsrel = 1e-2
dz = 0.05
zmin, zmax = -10.0, 10.0

@torch.no_grad()
def get_p0_thm(args, sigma_ws, sigma_bs, etas):
    """Iterate the batched p / beta / q recursion over a (sigma_w, sigma_b, eta) grid.

    Every combination of ``sigma_w``, ``sigma_b`` and ``eta`` becomes one
    channel (``n = n_w * n_b * n_et``). The per-step updates of the L x L
    layer-pair matrices ``p`` and ``beta`` are expressed as depthwise
    (``groups=n``) ``nn.Conv2d`` layers with fixed weights built from
    ``rho = sigma_w * eta`` and ``kappa = 1 - (1 + sigma_w**2) * eta``, so the
    whole grid advances with one batched call per step. Runs under
    ``torch.no_grad()``.

    Args:
        args: object with ``device``, ``n_layers`` (L) and ``T`` (number of
            recorded steps, including the initial one).
        sigma_ws: sequence of weight scales sigma_w.
        sigma_bs: sequence of bias scales sigma_b; their squares sigma_b**2
            fill the diagonal of ``gamma`` and enter ``q``.
        etas: sequence of learning rates eta.

    Returns:
        Tuple ``(p_diag, q, r, q)`` of numpy arrays. ``p_diag`` and ``q`` have
        shape (n_w, n_b, n_et, T, L); ``r`` has shape (n_w, n_b, n_et, T, L-1)
        and equals ``p_diag[..., :-1] * q[..., 1:]``. ``q`` is returned twice
        so callers that unpack four values keep working.
    """
    device = args.device
    ws = torch.FloatTensor(sigma_ws).to(device)
    wss = ws*ws
    bs = torch.FloatTensor(sigma_bs).to(device)
    bss = bs*bs
    etas = torch.FloatTensor(etas).to(device)
    L = args.n_layers; T = args.T
    n_w = len(sigma_ws); n_b = len(sigma_bs); n_et = len(etas)
    n = n_w*n_b*n_et  # one conv channel per (sigma_w, sigma_b, eta) combination
    p_diag = torch.zeros((n, T, L)).to(device)
    q = torch.zeros((n, T, L)).to(device)

    ## set initial values of p, beta, gamma, kappa, rho
    p = torch.eye(L, L).repeat(n_w, n_b, n_et, 1, 1).to(device) # n_w, n_b, n_et, L, L
    beta = torch.zeros((n_w, n_b, n_et, L, L)).to(device)
    gamma = torch.eye(L, L).repeat(1, n_b, 1, 1, 1).to(device) \
        * bss[None, :, None, None, None] # 1, n_b, 1, L, L
    gamma = gamma.to(device)
    _kappa = (1-torch.matmul((1+wss)[:,None], etas[None,:])) # n_w, n_et
    kappa = _kappa.repeat(n_b, 1, 1).permute(1,0,2).to(device) # n_w, n_b, n_et
    rho = torch.matmul(ws[:,None], etas[None,:]).repeat(n_b, 1, 1).permute(1,0,2)  # n_w, n_b, n_et

    ## set weight (coef)
    coef_1d = torch.stack([rho, kappa, rho], dim=-1) # n_w, n_b, n_et, 3
    coef_2d = torch.matmul(coef_1d.unsqueeze(-1), coef_1d.unsqueeze(-2)) # n_w, n_b, n_et, 3, 3
    coef_1d_g = torch.stack([etas.repeat(n_w, n_b, 1), -rho], dim=-1) # n_w, n_b, n_et, 2

    coef_1d_10 =  torch.stack([torch.ones(n_w, n_b, n_et), torch.zeros(n_w, n_b, n_et)], dim=-1).to(device) # n_w, n_b, n_et, 2
    coef_2d_g10 = torch.matmul(coef_1d_10.unsqueeze(-1), coef_1d_g.unsqueeze(-2))

    coef_1d_q = torch.stack([ -ws.repeat(n_b, n_et, 1).permute(2,0,1),
                             torch.ones_like(kappa).to(device) ], dim=-1) # n_w, n_b, n_et, 2
    coef_2d_q = torch.matmul(coef_1d_q.unsqueeze(-1), coef_1d_q.unsqueeze(-2)) # n_w, n_b, n_et, 2, 2

    ## get weight (coef): depthwise convs, one fixed kernel per grid channel
    lyr_b = nn.Conv2d(in_channels=n, out_channels=n, groups=n, kernel_size = (1, 3),
                      padding=(0, 1), bias=False).to(device)
    lyr_b.weight = nn.Parameter(coef_1d.reshape(n, 1, 1, 3))
    lyr_g = nn.Conv2d(in_channels=n, out_channels=n, groups=n, kernel_size = (1, 2),
                      padding=(0, 1), bias=False).to(device)
    lyr_g.weight = nn.Parameter(coef_1d_g.reshape(n, 1, 1, 2))
    lyr_gg = nn.Conv2d(in_channels=n, out_channels=n, groups=n, kernel_size = 2,
                      padding=1, bias=False).to(device)
    lyr_gg.weight = nn.Parameter(coef_2d_g10.reshape(n, 1, 2, 2))
    lyr_p = nn.Conv2d(in_channels=n, out_channels=n, groups=n, kernel_size = 3,
                      padding=1, bias=False).to(device)
    lyr_p.weight = nn.Parameter(coef_2d.reshape(n, 1, 3, 3)) # n_w*n_b*n_et, 1, 3, 3
    lyr_q = nn.Conv2d(in_channels=n, out_channels=n, groups=n, kernel_size = (2, 2),
                      padding=(1, 1), bias=False).to(device)
    lyr_q.weight = nn.Parameter(coef_2d_q.reshape(n, 1, 2, 2))
    lyr_qb = nn.Conv2d(in_channels=n, out_channels=n, groups=n, kernel_size = (1, 2),
                      padding=(0, 1), bias=False).to(device)
    lyr_qb.weight = nn.Parameter(-coef_1d_q.reshape(n, 1, 1, 2))

    # gamma
    gamma_r = gamma.repeat(n_w, 1, n_et, 1, 1).reshape(1, n, L, L) # batch_siz, channel, w, h
    # Bias contribution to q: the diagonal of gamma, i.e. the variance
    # sigma_b**2 (consistent with gamma above and with qmap_d's vb).
    # Loop-invariant, so computed once.
    gamma_diag = torch.diagonal(gamma_r[0], dim1=-2, dim2=-1) # n, L

    # t = 1 init cond
    p_diag[:, 0, :] = torch.diagonal(p.reshape(1, n, L, L), dim1=-2, dim2=-1)

    for t in range(1, T):
        # beta
        beta_r = beta.reshape(1, n, L, L)
        _beta_sum = lyr_b(beta_r)
        _gamma_sum = lyr_gg(gamma_r)[:, :, :-1, :-1]
        new_beta = _beta_sum + _gamma_sum
        ### boundary case
        new_beta[:, :, :, 0] = 0
        new_beta[:, :, :, -1] = 0

        # p
        p_r = p.reshape(1, n, L, L)
        _p_sum = lyr_p(p_r)
        _beta_ss = lyr_g(_beta_sum.transpose(-1, -2))[:, :, :, 1:] ## transposed to apply lyr_g with (2, 1) kernel
        _g_sum = lyr_g(_gamma_sum.transpose(-1, -2))[:, :, :, 1:]
        new_p = _p_sum + _beta_ss + _beta_ss.transpose(-1, -2) + _g_sum
        ## boundary case: first/last rows and columns use the 1-D update only
        _p_b = lyr_b(p_r)
        new_p[:, :, :, 0] = _p_b[:, :, :, 0]
        new_p[:, :, :, -1] = _p_b[:, :, :, -1]
        _p_b_t = _p_b.transpose(-1, -2)
        new_p[:, :, 0, :] = _p_b_t[:, :, 0, :]
        new_p[:, :, -1, :] = _p_b_t[:, :, -1, :]
        ## corner
        new_p[:, :, 0, 0] = 1
        new_p[:, :, -1, -1] = 1
        new_p[:, :, 0, -1] = 0
        new_p[:, :, -1, 0] = 0
        ## update
        p = new_p
        beta = new_beta

        # q (built from p and beta at step t, stored at index t-1)
        _p_sum_for_q = lyr_q(p)[0, :, :-1, :-1] # n, L, L
        _b_sum_for_q = lyr_qb(beta)[0, :, :, :-1] # n, L, L
        q[:, t-1, :] = torch.diagonal(_p_sum_for_q, dim1=-2, dim2=-1) + \
                     torch.diagonal(2*_b_sum_for_q, dim1=-2, dim2=-1) + \
                     gamma_diag
        q[:, t-1, 0] = 0
        # record diag(p) at step t
        p_diag[:, t, :] = torch.diagonal(p, dim1=-2, dim2=-1)
    # NOTE(review): q[:, T-1, :] is never written (q at index t-1 comes from
    # p at step t), so it stays zero and so does r[:, T-1, :].
    r = p_diag[:, :, :-1] * q[:, :, 1:]

    # reshape from (n, T, L) to (n_w, n_b, n_et, T, L)
    p_diag = p_diag.reshape(n_w, n_b, n_et, T, L).cpu().numpy()
    q = q.reshape(n_w, n_b, n_et, T, L).cpu().numpy()
    r = r.reshape(n_w, n_b, n_et, T, L-1).cpu().numpy()
    # NOTE(review): q appears twice; the 4-tuple is kept for existing callers.
    return p_diag, q, r, q


def get_p0_thm_v0(args, sigma_ws, sigma_bs, etas):
    """Earlier, unfinished draft of ``get_p0_thm`` using explicit slicing.

    NOTE(review): this draft is incomplete and appears non-functional as written:
      * ``coef`` is created as all zeros and never filled in;
      * ``new_p`` is computed but never stored back into ``p``, and the
        function has no ``return`` statement, so a completed call returns None;
      * the "p boundary" / "beta update" / "beta boundary" steps exist only as
        placeholder comments;
      * ``etas`` is only used for its length;
      * ``sigma_bs[None, :, None, None, None]`` needs an array-like supporting
        multi-axis indexing (a plain list raises TypeError);
      * the stacked slices do not line up with the ``coef`` slices they are
        multiplied with (e.g. ``p_3x3`` is (..., L-3, L-3, 3, 3) while
        ``coef[..., :-1, :-1]`` is (..., L, L, 2, 2)), so the products are
        not broadcast-compatible.
    ``get_p0_thm`` is the version exercised by this module's ``__main__``.
    """
    device = args.device
    ws = torch.FloatTensor(sigma_ws).to(device)
    wss = ws*ws
    L = args.n_layers; T = args.T
    n_w = len(sigma_ws); n_b = len(sigma_bs); n_et = len(etas)
    # initial p: identity scaled by sigma_w**2 for every grid point
    p = torch.eye(L, L).repeat(n_w, n_b, n_et, 1, 1).to(device) \
        * wss[:, None, None, None, None] # n_w, n_b, n_et, L, L
    beta = torch.zeros((n_w, n_b, n_et, L, L)).to(device)
    # NOTE(review): placeholder, never populated in this draft
    coef = torch.zeros((n_w, n_b, n_et, L, L, 3, 3)).to(device)
    gamma = torch.eye(L, L).repeat(1, n_b, 1, 1, 1).to(device) \
        * sigma_bs[None, :, None, None, None] # 1, n_b, 1, L, L
    rho = torch.zeros((n_w, n_b, n_et, L, L)).to(device)
    for t in range(T-1):
        # p update: neighbouring layers l-1, l, l+1 along the row axis
        p_lm1 = p[:, :, :, 0:-3, :]
        p_l = p[:, :, :, 1:-2, :]
        p_lp1 = p[:, :, :, 2:-1, :]

        #if option == 0:
        p_ls = torch.stack([p_lm1, p_l, p_lp1], dim=-1)
        p_km2 = p_ls[:, :, :, :, 0:-3, :]
        p_km1 = p_ls[:, :, :, :, 1:-2, :]
        p_k = p_ls[:, :, :, :, 2:-1, :] # where k >= 2
        p_3x3 = torch.stack([p_km2, p_km1, p_k], dim=-1) # where k >= 2
        p_wsum = (coef[:, :, :, :, :, :-1, :-1] * p_3x3).sum((-1, -2))

        # elif option == 2: directly get 16 terms

        rho_lm1 = rho[:, :, :, 1:-3, 0:-3] # need to check when k = 0 (then, k-2 becomes -2)
        rho_l = rho[:, :, :, 1:-3, 1:-2]
        rho_lp1 = rho[:, :, :, 1:-3, 2:-1]
        rho_ls = torch.stack([rho_lm1, rho_l, rho_lp1], dim=-1)
        rho_wsum = (coef[:, :, :, :, :, :-1, -1] * rho_ls).sum((-1))

        beta_k = beta[:, :, :, :, 0:-3] # check
        beta_kp1 = beta[:, :, :, :, 1:-2]
        beta_kp2 = beta[:, :, :, :, 2:-1]
        beta_ks = torch.stack([beta_k, beta_kp1, beta_kp2], dim=-1)
        beta_wsum = (coef[:, :, :, :, :, -1, :-1] * beta_ks).sum((-1))
        # NOTE(review): new_p is not assigned back to p; the draft stops here
        new_p = p_wsum + rho_wsum + beta_wsum + coef[:, :, :, :, :, -1, -1] * gamma
        # p boundary
        # beta update
        # beta boundary


def get_pmaps_theory(args, sigma_ws, sigma_bs, etas, grid = False):
    """Run the scalar ``pmap_0612`` / ``qmap_d`` recursion for every (sigma_w, sigma_b).

    ``args`` supplies ``eta``, ``n_layers`` (L), ``T`` (number of steps) and
    ``ds`` (per-layer widths, used for the ratios ``ds[l+1]/ds[l]`` and
    ``ds[l-1]/ds[l]``). ``etas`` and ``grid`` are accepted but not read; the
    learning rate comes from ``args.eta``.

    The state keeps four statistics (p, ph, pt, beta) on axis 0; p starts at
    1 and the others at 0. Step t+1 of each layer is written from step t of
    that layer and its neighbours. q is filled only for layers l > 0 and
    steps t < T-1.

    Returns:
        ndarray of shape (4, len(sigma_ws), len(sigma_bs), T, L) holding p, ph,
        pt and then q (beta is dropped).
    """
    eta = args.eta
    n_layers = args.n_layers
    n_steps = args.T
    state = np.zeros((4, len(sigma_ws), len(sigma_bs), n_steps, n_layers))
    state[0] = 1.0
    q_out = np.zeros((len(sigma_ws), len(sigma_bs), n_steps, n_layers))
    for iw, sw in enumerate(sigma_ws):
        for ib, sb in enumerate(sigma_bs):
            for t in range(n_steps - 1):
                now = state[:, iw, ib, t]        # (4, L) view of step t
                later = state[:, iw, ib, t + 1]  # (4, L) view of step t+1
                for l in range(n_layers):
                    ratio = prev_p = next_p = next_ph = next_beta = None
                    if l < n_layers - 1:
                        ratio = args.ds[l + 1] / args.ds[l]
                        next_p, next_ph, _, next_beta = now[:, l + 1]
                    if l > 0:
                        back_ratio = args.ds[l - 1] / args.ds[l]
                        prev_p = now[0, l - 1]
                    p, ph, pt, beta = now[:, l]
                    # pmap_0612 writes into the view, i.e. into state[..., t+1, l]
                    pmap_0612(l, (p, ph, pt, beta, prev_p, next_p, next_ph, next_beta),
                              sw, sb, eta, ratio, n_layers, later[:, l])
                    if l > 0:
                        q_out[iw, ib, t, l] = qmap_d((p, ph, prev_p), sw, sb, eta, back_ratio)
    return np.concatenate((state[:-1], q_out[np.newaxis]))


def pmap(l, pin, sigma_w, sigma_b, eta, ratio, L, ps_):
    """Write one step of the (older) scalar p-map for layer ``l`` into ``ps_``.

    ``pin`` is ``(p, ph, pt, pc, prev_p, next_p, next_ph)``: the current
    statistics of layer ``l`` plus ``p`` of layer ``l-1`` and ``p``/``ph`` of
    layer ``l+1``. Only entries valid for this layer position are written:
    ``ps_[0]`` for interior layers, ``ps_[1]`` when ``l > 0`` and ``ps_[2]``
    when ``l < L-1``; all other entries of ``ps_`` are left untouched.
    Returns None.
    """
    p, ph, pt, pc, prev_p, next_p, next_ph = pin
    var_w = sigma_w**2
    var_b = sigma_b**2
    has_prev = l > 0
    has_next = l < L-1
    if has_prev and has_next:
        ps_[0] = (1-2*eta)*p + 2*eta*(ph + pt - pc)
    if has_prev:
        ps_[1] = (1-eta*(2+var_w))*ph + eta*(var_w*(p + prev_p) + var_b)
    if has_next:
        ps_[2] = (1-2*eta)*pt + eta*(p + var_w * ratio * (next_p - next_ph))


def pmap_0612(l, pin, sigma_w, sigma_b, eta, ratio, L, ps_):
    """Write one step of the p / ph / pt / beta map for layer ``l`` into ``ps_``.

    ``pin`` is ``(p, ph, pt, beta, prev_p, next_p, next_ph, next_beta)``: the
    current statistics of layer ``l``, ``p`` of layer ``l-1`` and
    ``p``/``ph``/``beta`` of layer ``l+1``. Writes ``ps_[0]`` for interior
    layers, ``ps_[1]`` and ``ps_[3]`` when ``l > 0``, and ``ps_[2]`` when
    ``l < L-1``; other entries are left untouched. Returns None.
    """
    p, ph, pt, beta, prev_p, next_p, next_ph, next_beta = pin
    var_w, var_b = sigma_w**2, sigma_b**2
    below = l > 0
    above = l < L-1
    if below and above:
        ps_[0] = (1-2*eta*(1+var_w * ratio))*p + 2*eta*(ph + pt)
    if below:
        ps_[1] = (1-2*eta*(1+var_w))*ph + eta*(beta + var_w*(p + prev_p) + var_b)
        ps_[3] = (1-eta)*beta + eta*var_b
    if above:
        ps_[2] = (1-2*eta*(1+var_w))*pt + eta*(var_w * ratio * (next_p - next_ph - next_beta))


def pmap_guess(l, pin, sigma_w, sigma_b, eta, ratio, L, ps_):
    """Write one step of the heuristic ("guess") p-map for layer ``l`` into ``ps_``.

    ``pin`` is ``(p, ph, pt, pc, pre_p, pre_ph, pst_p, pst_ph, pst_pt)`` where
    ``pre_*`` belong to layer ``l-1`` and ``pst_*`` to layer ``l+1``; ``pc`` is
    not used. Writes ``ps_[0]`` for interior layers, ``ps_[1]`` when ``l > 0``
    and ``ps_[2]`` when ``l < L-1``. Returns None.

    NOTE(review): divides by ``p``, ``pre_p`` and ``pst_p``; zero values raise
    ZeroDivisionError for Python floats.
    """
    p, ph, pt, pc, pre_p, pre_ph, pst_p, pst_ph, pst_pt = pin
    var_w = sigma_w**2
    var_b = sigma_b**2
    has_prev = l > 0
    has_next = l < L-1
    if has_prev and has_next:
        ps_[0] = (1-2*eta)*p + 2*eta*(ph + pt - var_w * p)
    if has_prev:
        prev_terms = pre_ph*ph/pre_p + p - ph + pre_p + ph*pt/p - ph
        ps_[1] = (1-2*eta)*ph + eta*(var_w*prev_terms + var_b)
    if has_next:
        next_terms = p + pst_pt*pt/pst_p - pt + ph*pt/p + ratio * (pst_p - pst_ph)
        ps_[2] = (1-2*eta)*pt + eta * var_w*next_terms


def qmap_d(pin, sigma_w, sigma_b, eta, ratio,
         epsabs=epsabs, epsrel=epsrel, zmin=-10, zmax=10, dz=dz, fast=True):
    """Return the closed-form q of one layer.

    ``q = p + sigma_w**2 * ratio * prev_p + sigma_b**2 - 2 * ph`` with
    ``pin = (p, ph, prev_p)``. ``eta`` and the integration-style keyword
    arguments (``epsabs``, ``epsrel``, ``zmin``, ``zmax``, ``dz``, ``fast``)
    are accepted for interface compatibility but are not used.
    """
    p_cur, p_hat, p_prev = pin
    weighted_prev = (sigma_w**2)*ratio*p_prev
    return p_cur + weighted_prev + sigma_b**2 - 2*p_hat

# TODO: compute over linspace of sigma_w and sigma_b
def get_pmaps_theory_grid(args, weight_sigmas, bias_sigmas):
    """Run the older scalar ``pmap`` / ``qmap_d`` recursion on a sigma_w x sigma_b grid.

    ``args`` supplies ``eta``, ``n_layers`` (L), ``T`` (number of iterations,
    independent of L) and ``ds`` (per-layer widths). State is laid out as
    (stat, w, b, L, T) with stats (p, ph, pt, pc); p starts at 1. For layers
    l > 0 the ``pc`` entry of step t+1 is set to sigma_w**2 times the updated p.

    Returns:
        ``(ps_last, qs_last)``: the statistics at the final step, shape
        (4, n_w, n_b, L), and the q computed at the last iteration (step T-2),
        shape (n_w, n_b, L).
    """
    eta = args.eta
    n_layers = args.n_layers
    n_steps = args.T
    n_w, n_b = len(weight_sigmas), len(bias_sigmas)
    state = np.zeros((4, n_w, n_b, n_layers, n_steps))
    state[0] = 1.0
    q_hist = np.zeros((n_w, n_b, n_layers, n_steps - 1))
    for iw, sw in enumerate(weight_sigmas):
        for ib, sb in enumerate(bias_sigmas):
            var_w = sw**2
            cell = state[:, iw, ib]  # (4, L, T) view for this grid point
            for t in range(n_steps - 1):
                for l in range(n_layers):
                    ratio = prev_p = next_p = next_ph = None
                    if l < n_layers - 1:
                        ratio = args.ds[l+1] / args.ds[l]
                        next_p, next_ph = cell[:2, l+1, t]
                    if l > 0:
                        back_ratio = args.ds[l-1] / args.ds[l]
                        prev_p = cell[0, l-1, t]
                    p, ph, pt, pc = cell[:, l, t]
                    # pmap writes the first three stats of step t+1 in place
                    pmap(l, (p, ph, pt, pc, prev_p, next_p, next_ph),
                         sw, sb, eta, ratio, n_layers, cell[:-1, l, t+1])
                    if l > 0:
                        cell[3, l, t+1] = var_w * cell[0, l, t+1]
                        q_hist[iw, ib, l, t] = qmap_d((p, ph, prev_p), sw, sb, eta, back_ratio)
    return state[:, :, :, :, -1], q_hist[:, :, :, -1]


# Disabled sketch (kept for reference): a Gaussian-weighted integrand and how it would be passed to fast_integral.
# def integrand(z, vs):
    # sign = (vs/abs(vs)) if vs else 1
    # return norm.pdf(z[:, None]) * (sign * \
#         np.sqrt(np.expand_dims([abs(vs)],0)) * z[:, None])**2
# int_p, int_ph, int_pt, int_pc, int_pp, int_np, int_nph = \
#     [fast_integral(integrand, pin[i], zmin, zmax, dz=dz) \
#         for i in range(len(pin))]


def fast_integral(integrand, vs, zmin, zmax, dz, ndim=1):
    """Approximate an ``ndim``-dimensional integral with a uniform Riemann sum.

    The 1-D sample points are ``np.r_[zmin:zmax:dz]`` (``zmax`` excluded); for
    ``ndim > 1`` they are expanded with ``np.meshgrid``. ``integrand`` is
    called as ``integrand(*grid, vs)``; its result is summed over the first
    ``ndim`` axes (trailing axes are kept) and scaled by ``dz**ndim``.
    """
    axis_pts = np.r_[zmin:zmax:dz]
    grids = np.meshgrid(*[axis_pts] * ndim) if ndim > 1 else (axis_pts,)
    values = integrand(*grids, vs)
    return values.sum(tuple(range(ndim))) * dz**ndim


def p_fixed_point(weight_sigma, bias_sigma, nonlinearity, max_iter=500, tol=1e-9, pinit=3.0, fast=True, tol_frac=0.01,
                  pmap_fn=None):
    """Iterate a scalar map to its fixed point and report how quickly it gets close.

    Starting from ``pinit``, repeatedly applies
    ``pmap_fn(p, weight_sigma, bias_sigma, nonlinearity, fast=fast)`` until two
    successive values differ by less than ``tol`` or ``max_iter`` steps elapse.

    Args:
        weight_sigma, bias_sigma, nonlinearity: forwarded unchanged to ``pmap_fn``.
        max_iter: maximum number of map applications.
        tol: absolute convergence tolerance between successive iterates.
        pinit: starting value.
        fast: forwarded to ``pmap_fn`` as the ``fast`` keyword.
        tol_frac: threshold on the squared fractional error used to compute ``t``.
        pmap_fn: the scalar map to iterate. Defaults to this module's ``pmap``
            for backward compatibility -- NOTE: that ``pmap`` has the signature
            ``(l, pin, sigma_w, sigma_b, eta, ratio, L, ps_)`` and no ``fast``
            keyword, so the default call raises TypeError; pass a compatible map.

    Returns:
        ``(t, p)``: ``t`` is the index of the first recorded iterate whose
        squared fractional error relative to the final ``p`` is below
        ``tol_frac``; ``p`` is the last recorded iterate. If no iterate meets
        ``tol_frac`` (possible when ``max_iter`` is hit), indexing raises IndexError.
    """
    if pmap_fn is None:
        pmap_fn = pmap
    p = pinit
    ps = []
    for _ in range(max_iter):
        pnew = pmap_fn(p, weight_sigma, bias_sigma, nonlinearity, fast=fast)
        err = np.abs(pnew - p)
        ps.append(p)
        if err < tol:
            break
        p = pnew
    # Find the first iterate within tol_frac (squared) fractional error of the final p
    frac_err = (np.array(ps) - p)**2 / (1e-9 + p**2)
    t = np.flatnonzero(frac_err < tol_frac)[0]
    return t, p


if __name__ == '__main__':
    # Smoke run of the batched theory recursion over a small hyper-parameter grid.
    from collections import namedtuple
    Args = namedtuple('Args', ['eta', 'sigma_b', 'n_conds',
                                'n_layers', 'T', 'ds', 'device'])
    # Fall back to CPU so the script also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args = Args(eta=0.1, sigma_b=0.03, n_conds=4, n_layers=5, T=100,
                ds=[100]*5, device=device)
    sigma_ws = [1.0, 5.0]
    sigma_bs = [0.01, 0.03, 0.09]
    etas = [0.01, 0.02, 0.04, 0.08, 0.16]
    p, q, u, v = get_p0_thm(args, sigma_ws, sigma_bs, etas)
