import torch
from torch.func import vmap, jacrev
from scipy.stats import binom,norm

def gaussian_ppf_torch(p: torch.Tensor, loc: torch.Tensor | float = 0.0, scale: torch.Tensor | float = 1.0, alpha_exp=1):

    p = 0.5 * (torch.as_tensor(p)**alpha_exp)

    loc = torch.as_tensor(loc, device=p.device, dtype=p.dtype)
    scale = torch.as_tensor(scale, device=p.device, dtype=p.dtype)
    eps = torch.finfo(p.dtype).eps
    p = p.clamp(min=eps, max=1 - eps)

    z = torch.sqrt(torch.tensor(2.0, device=p.device, dtype=p.dtype)) * torch.special.erfinv(2 * p - 1)

    v = loc + scale * z
    v = torch.clamp(v, max=0.0)
    return v

def gauss_newton_projection_gpu(
    x: torch.Tensor,
    t: torch.Tensor,
    projection_info,
    num_iter: int = 10,
    tol: float = 1e-13,
    bound_tolerance: float = 1e-13,
    alpha_exp: float = 1,
):
    """Project ``x`` onto initial-condition and mass-balance bounds.

    The first time slice is clamped to ``ic +/- (bound_tolerance - v)`` where
    ``v = gaussian_ppf_torch(t, loc=0, scale=1 - t, alpha_exp)``.  The
    remaining slices are corrected with regularised Gauss-Newton steps on the
    violated (positive) mass residuals:

        delta_u = -J^T (J J^T + 1e-6 I)^{-1} relu(h(u))

    Progress is printed to stdout on every iteration.

    Args:
        x: Field to project; it is copied and detached, never modified.
            Expected to have the same shape as ``projection_info['solution']``
            (commented in the code as ``(nsim, nx, nt)``).
        t: Per-sample scalar; a one-element tensor is expanded to ``nsim``.
        projection_info: Mapping with entries ``'solution'`` (only its shape
            is used), ``'initial_condition'``, ``'bc_flux_left'``,
            ``'bc_flux_right'`` and ``'rho'``.  Each entry is sliced with
            ``[:]``; all but ``'solution'`` are moved to the device/dtype of
            ``x`` and are batched along their first dimension.
        num_iter: Maximum number of Gauss-Newton iterations.
        tol: Stop once the norm of the active residuals drops below this.
        bound_tolerance: Half-width added to the bounds.
        alpha_exp: Exponent forwarded to ``gaussian_ppf_torch``.

    Returns:
        The projected tensor.  When there is only one time slice
        (``nt < 2``) only the initial-condition clamp is applied.
    """
    sol = projection_info['solution'][:]               # shape (nsim, nx, nt)
    ic_ref = projection_info['initial_condition'][:]   # shape (nsim, nx)
    g_L = projection_info['bc_flux_left'][:]           # shape (nsim,)
    g_R = projection_info['bc_flux_right'][:]          # shape (nsim,)
    rho = projection_info['rho'][:]
    nsim, nx, nt = sol.shape

    device = x.device
    dtype = x.dtype

    ic_ref = ic_ref.to(device, dtype=dtype)
    g_L = g_L.to(device, dtype=dtype)
    g_R = g_R.to(device, dtype=dtype)
    rho = rho.to(device, dtype=dtype)

    if t.numel() == 1:
        t = t.expand(nsim)
    t = t.to(device=device, dtype=dtype)
    v = gaussian_ppf_torch(p=t, loc=0, scale=1 - t, alpha_exp=alpha_exp)  # (nsim,)

    u = x.clone().detach()

    # Box constraint on the initial slice: ic + term1 <= u[..., 0] <= ic + term2.
    term1 = -bound_tolerance + v
    term2 = bound_tolerance - v
    low = ic_ref + term1[:, None]
    high = ic_ref + term2[:, None]
    u[:, :, 0] = torch.max(torch.min(u[:, :, 0], high), low)

    # With a single time slice there are no mass residuals, and the time step
    # 1 / (nt - 1) below would divide by zero.
    if nt < 2:
        return u

    dx = 1.0 / nx
    dt = 1.0 / (nt - 1)
    flux = g_L - g_R  # (nsim,)

    def compute_residuals_single(u_single, ic_single, flux_single, rho_single, v_single, t_single):
        """Return the stacked upper/lower mass residuals of one sample.

        u_single is [nx, nt] and ic_single is [nx]; the remaining arguments
        are per-sample scalars.  The result has 2 * (nt - 1) entries ordered
        as (upper_1, lower_1, upper_2, lower_2, ...).
        """
        mass0 = torch.sum(ic_single) * dx
        acc = torch.zeros_like(mass0)

        # Reaction term of the previous slice; the very first one uses the
        # reference initial condition rather than u_single[:, 0].
        R_prev = rho_single * dx * torch.sum(ic_single - ic_single**2)

        residuals = []
        for k in range(1, nt):
            mk = torch.sum(u_single[:, k]) * dx

            # Predicted mass: mass0 + dt * sum_{j<k} (R_j + flux).
            acc = acc + (R_prev + flux_single) * dt
            m_pred_k = mass0 + acc

            h_upper = (t_single * mk - m_pred_k) - (t_single * bound_tolerance - v_single)
            residuals.append(h_upper)

            h_lower = -(t_single * mk - m_pred_k) + (-t_single * bound_tolerance + v_single)
            residuals.append(h_lower)

            R_prev = rho_single * dx * torch.sum(u_single[:, k] - u_single[:, k]**2)

        return torch.stack(residuals)  # Shape: [2 * (nt - 1)]

    compute_residuals_batched = vmap(compute_residuals_single)
    # jacrev differentiates with respect to the first argument (u_single) only.
    compute_jacobian_batched = vmap(jacrev(compute_residuals_single))

    m_constraints = 2 * (nt - 1)
    # The upper and lower rows of J are negatives of each other, so J J^T is
    # singular; the small diagonal shift keeps the linear solve well-posed.
    reg = torch.eye(m_constraints, device=device, dtype=dtype) * 1e-6

    for i in range(num_iter):
        u[:, :, 0] = torch.max(torch.min(u[:, :, 0], high), low)
        h_full = compute_residuals_batched(u, ic_ref, flux, rho, v, t)

        # Only violated constraints (positive residuals) drive the update.
        h_active = torch.relu(h_full)

        active_residual_norm = torch.linalg.norm(h_active)
        if active_residual_norm < tol:
            print(f"Iteration {i}: Converged with active residual norm {active_residual_norm:.2e}")
            break
        if i == 0:
             print(f"Iteration {i}: Initial active residual norm {active_residual_norm:.2e}")
        else:
             print(f"Iteration {i}: Active residual norm {active_residual_norm:.2e}")

        J = compute_jacobian_batched(u, ic_ref, flux, rho, v, t)
        # reshape (not view) so the flattening does not depend on the memory
        # layout produced by vmap/jacrev.
        J_flat = J.reshape(nsim, m_constraints, -1)  # (nsim, m, nx*nt)

        JJT = J_flat @ J_flat.transpose(-1, -2)
        update_direction = torch.linalg.solve(JJT + reg, h_active)

        delta_u_flat = -J_flat.transpose(-1, -2) @ update_direction.unsqueeze(-1)
        delta_u = delta_u_flat.squeeze(-1).reshape(nsim, nx, nt)

        u = u + delta_u

    return u

@torch.no_grad()
def gauss_newton_projection_ns2d_gpu(
    x: torch.Tensor,
    t: torch.Tensor,
    projection_info,
    bound_tolerance: float = 1e-13,
    alpha_exp: float = 1,
):
    """Clamp the first time slice and shift later slices into a mass band.

    Works on a copy of ``x`` (unpacked as ``(N, s, s, T)``; the two spatial
    sizes must match).  The first slice is clamped to
    ``ic +/- (bound_tolerance - v)`` with
    ``v = gaussian_ppf_torch(t, loc=0, scale=1 - t, alpha_exp)``.  For every
    later slice the summed mass is compared with ``t * mass0``; whatever
    exceeds the same band is removed by subtracting one constant per slice.
    A one-line diagnostic is printed when ``T >= 2``.

    Args:
        x: Field to project; left untouched.
        t: Per-sample scalar; a one-element tensor is expanded to ``N``.
        projection_info: Mapping whose ``'initial_condition'`` entry is sliced
            with ``[:]`` and moved to the device/dtype of ``x``.
        bound_tolerance: Half-width added to the bounds.
        alpha_exp: Exponent forwarded to ``gaussian_ppf_torch``.

    Returns:
        The projected copy of ``x``.
    """
    u = x.clone()  # (N, s, s, T)
    device, dtype = u.device, u.dtype
    N, s1, s2, T = u.shape
    assert s1 == s2, "Only square grids are assumed."
    s = s1

    ic = projection_info['initial_condition'][:].to(device=device, dtype=dtype)  # (N, s, s)

    # Uniform grid on the unit square.
    dx = torch.as_tensor(1.0 / s, device=device, dtype=dtype)
    dy = torch.as_tensor(1.0 / s, device=device, dtype=dtype)
    area_elem = dx * dy
    total_area = (s * s) * area_elem

    if t.numel() == 1:
        t = t.expand(N)
    t = t.to(device=device, dtype=dtype)
    v = gaussian_ppf_torch(p=t, loc=0, scale=1 - t, alpha_exp=alpha_exp)  # (N,)

    # Signed offsets of the admissible band around the reference values.
    lower_offset = -bound_tolerance + v
    upper_offset = bound_tolerance - v
    ic_floor = ic + lower_offset[:, None, None]
    ic_ceiling = ic + upper_offset[:, None, None]
    u[:, :, :, 0] = torch.max(torch.min(u[:, :, :, 0], ic_ceiling), ic_floor)

    mass0 = (ic.view(N, -1).sum(dim=1)) * area_elem  # (N,)
    if T < 2:
        return u

    def later_slice_mass(field):
        # Mass of every slice after the first one -> (N, T-1).
        return (field[:, :, :, 1:].reshape(N, -1, T - 1).sum(dim=1)) * area_elem

    target_mass = t.view(N, 1) * mass0.view(N, 1)  # (N, 1)
    diff_k = later_slice_mass(u) - target_mass       # (N, T-1)

    # Positive parts are the amounts by which each slice leaves the band.
    excess_above = (diff_k - upper_offset.view(N, 1)).clamp(min=0.0)
    excess_below = (-diff_k + lower_offset.view(N, 1)).clamp(min=0.0)

    # A constant shift over the whole grid changes the mass by shift * total_area.
    shift = (excess_above - excess_below) / total_area  # (N, T-1)
    u[:, :, :, 1:] = u[:, :, :, 1:] - shift.view(N, 1, 1, T - 1)

    max_abs_shift = shift.abs().max().item()
    drift_after = (later_slice_mass(u) - target_mass).abs().max().item()
    print(f"[Mass] max |shift| per-slice = {max_abs_shift:.3e}; max drift after = {drift_after:.3e}")

    return u


def ns2d_constraint_loss(
    x: torch.Tensor,
    projection_info,
    bound_tolerance: float = 1e-15,
):
    """Penalise initial-condition and mass-conservation violations of ``x``.

    ``x`` is unpacked as ``(N, s, s, T)`` on a square grid.  Two squared
    hinge penalties are computed:

    * ``loss_ic``: mean squared distance of ``x[..., 0]`` outside
      ``ic +/- bound_tolerance``;
    * ``loss_mass``: mean squared amount by which the mass of each later
      slice differs from the initial mass by more than ``bound_tolerance``
      (zero when ``T <= 1``).

    Args:
        x: Predicted field.
        projection_info: Mapping whose ``'initial_condition'`` entry is sliced
            with ``[:]`` and converted to the device/dtype of ``x``.
        bound_tolerance: Slack allowed before a violation is penalised.

    Returns:
        ``(loss, parts)`` where ``loss = (loss_ic + loss_mass) / N`` and
        ``parts`` holds the detached ``"loss_ic"`` and ``"loss_mass"`` terms.
    """
    N, s1, s2, T = x.shape
    assert s1 == s2, "Only square grids are assumed for NS 2D."
    device, dtype = x.device, x.dtype

    ic = torch.as_tensor(
        projection_info['initial_condition'][:], device=device, dtype=dtype
    )  # (N, s, s)

    # Cell area of a uniform s-by-s grid on the unit square.
    spacing = 1.0 / s1
    area_elem = torch.as_tensor(spacing * spacing, device=device, dtype=dtype)

    # Initial-condition hinge: only the part outside the tolerance band counts.
    first_slice = x[..., 0]  # (N, s, s)
    lower_bound = ic - bound_tolerance
    upper_bound = ic + bound_tolerance
    overshoot = torch.relu(first_slice - upper_bound)
    undershoot = torch.relu(lower_bound - first_slice)
    loss_ic = (overshoot.pow(2) + undershoot.pow(2)).mean()

    # Mass hinge for every slice after the first.
    mass0 = ic.view(N, -1).sum(dim=1) * area_elem  # (N,)
    if T > 1:
        later_mass = x[..., 1:].reshape(N, -1, T - 1).sum(dim=1) * area_elem  # (N, T-1)
        excess = (later_mass - mass0.unsqueeze(1)).abs() - bound_tolerance
        loss_mass = torch.relu(excess).pow(2).mean()
    else:
        loss_mass = torch.zeros((), device=device, dtype=dtype)

    total = (loss_ic + loss_mass) / first_slice.shape[0]
    parts = {"loss_ic": loss_ic.detach(), "loss_mass": loss_mass.detach()}
    return total, parts




def rd_constraint_loss(
    x: torch.Tensor,
    projection_info,
    bound_tolerance: float = 1e-15,
):
    """Penalise initial-condition and mass-balance violations of ``x``.

    ``x`` is unpacked as ``(N, nx, T)``.  Two squared hinge penalties are
    computed:

    * ``loss_ic``: mean squared distance of ``x[:, :, 0]`` outside
      ``ic +/- bound_tolerance``;
    * ``loss_mass``: mean squared amount by which the mass of each later
      slice differs by more than ``bound_tolerance`` from the predicted mass
      ``mass0 + dt * cumsum(R + flux)``, with
      ``R = rho * dx * sum(u - u**2)`` taken over the preceding slices of
      ``x`` and ``flux = bc_flux_left - bc_flux_right`` (zero when ``T <= 1``).

    Args:
        x: Predicted field.
        projection_info: Mapping with entries ``'initial_condition'``,
            ``'bc_flux_left'``, ``'bc_flux_right'`` and ``'rho'``; each is
            sliced with ``[:]`` and converted to the device/dtype of ``x``.
            ``'rho'`` holds either one value or ``N`` values.
        bound_tolerance: Slack allowed before a violation is penalised.

    Returns:
        ``(loss, parts)`` where ``loss = (loss_ic + loss_mass) / N`` and
        ``parts`` holds the detached ``"loss_ic"`` and ``"loss_mass"`` terms.
    """
    device, dtype = x.device, x.dtype
    N, nx, T = x.shape

    def fetch(key):
        # Read one entry of projection_info as a tensor matching x.
        return torch.as_tensor(projection_info[key][:], device=device, dtype=dtype)

    ic = fetch('initial_condition')                       # (N, nx)
    flux = fetch('bc_flux_left') - fetch('bc_flux_right')  # (N,)
    rho = fetch('rho')

    dx = torch.as_tensor(1.0 / nx, device=device, dtype=dtype)
    step = 0.0 if T <= 1 else 1.0 / (T - 1)
    dt = torch.as_tensor(step, device=device, dtype=dtype)

    # Initial-condition hinge: only the part outside the tolerance band counts.
    first_slice = x[:, :, 0]  # (N, nx)
    overshoot = torch.relu(first_slice - (ic + bound_tolerance))
    undershoot = torch.relu((ic - bound_tolerance) - first_slice)
    loss_ic = (overshoot.pow(2) + undershoot.pow(2)).mean()

    mass0 = ic.sum(dim=1) * dx  # (N,)
    if T > 1:
        observed_mass = x[:, :, 1:].sum(dim=1) * dx  # (N, T-1)

        # Reaction source evaluated on the slice preceding each observed one.
        previous = x[:, :, :T - 1]  # (N, nx, T-1)

        if rho.numel() == 1:
            rho_col = rho.expand(N, 1)
        else:
            rho_flat = rho.reshape(-1)
            assert rho_flat.shape[0] == N, "rho batch size mismatch with x"
            rho_col = rho_flat.view(N, 1)
        reaction = rho_col * dx * (previous - previous.pow(2)).sum(dim=1)  # (N, T-1)

        accumulated = torch.cumsum(reaction + flux.view(N, 1), dim=1) * dt  # (N, T-1)
        predicted_mass = mass0.view(N, 1) + accumulated

        excess = (observed_mass - predicted_mass).abs() - bound_tolerance
        loss_mass = torch.relu(excess).pow(2).mean()
    else:
        loss_mass = torch.zeros((), device=device, dtype=dtype)

    total = (loss_ic + loss_mass) / first_slice.shape[0]
    parts = {"loss_ic": loss_ic.detach(), "loss_mass": loss_mass.detach()}
    return total, parts

    

    