import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# -----------------------------
# Config
# -----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Riemann data: left/right states, final time and spatial domain.
# With u_L > u_R the solution is a single shock moving at (u_L + u_R) / 2.
u_L, u_R = 1.0, 0.0
T_final = 0.6
X_MIN, X_MAX = -1.0, 1.0

# Sampling sizes
N_f = 8000   # interior PDE points
N0  = 1024   # initial condition points
Nb  = 512    # boundary time samples

# Training
EPOCHS = 2000
LR = 1e-3           # base LR for non-activation params
LR_ACT = LR * 3.0   # LR for CauchyAct internal params (raw_l1/raw_l2/raw_d/gamma)
PRINT_EVERY = 100

# Loss weights: Wb -> PDE residual, Wi -> initial condition, Wc -> boundary condition
Wb, Wi, Wc = 1.0, 1.0, 5.0

# Viscosity in u_t + u*u_x = NU*u_xx; set NU = 0.0 for the inviscid equation.
NU = 1e-3

# "LSNN-style" progress splits for matching log format
NB = 3
T_EDGES = torch.linspace(0.0, T_final, NB+1)                # [0.0, 0.2, 0.4, 0.6]
EP_EDGES = [int(EPOCHS * k / NB) for k in range(1, NB+1)]   # epoch indices hitting 1/3, 2/3, 3/3

# -----------------------------
# Exact solution (shock case)
# -----------------------------
@torch.no_grad()
def exact_riemann_burgers(x, t, uL=u_L, uR=u_R):
    """
    Entropy solution of the Burgers Riemann problem u_t + (u^2 / 2)_x = 0.

    Initial data: u = uL for x < 0 and u = uR for x >= 0.

    * uL >= uR (shock): one discontinuity moving at the Rankine-Hugoniot
      speed s = (uL + uR) / 2; u = uL for x < s*t, otherwise uR.
    * uL < uR (rarefaction): u = uL for x < uL*t, u = x/t inside the fan
      uL*t <= x <= uR*t, and u = uR for x > uR*t.

    Args:
        x: tensor of positions.
        t: time, a Python float or a tensor broadcastable against ``x``.
        uL, uR: left/right states (default to the module config u_L, u_R).

    Returns:
        Tensor with the exact solution (shaped like ``x`` when ``t`` is a scalar).
    """
    left = torch.full_like(x, uL)
    right = torch.full_like(x, uR)
    if uL >= uR:
        s = 0.5 * (uL + uR)  # Rankine-Hugoniot speed for Burgers
        xs = s * t
        return torch.where(x < xs, left, right)
    # Rarefaction fan. Clamp t away from 0 so x/t is finite. At t = 0 the fan
    # has zero width, so fall back to uR (matching the x >= 0 initial state).
    t_tensor = torch.as_tensor(t, dtype=x.dtype, device=x.device)
    t_safe = t_tensor.clamp_min(torch.finfo(x.dtype).tiny)
    fan = torch.where(t_tensor > 0, x / t_safe, right)
    return torch.where(x < uL * t, left, torch.where(x > uR * t, right, fan))

# -----------------------------
# CauchyAct & Net
# -----------------------------
class CauchyAct(nn.Module):
    """
    Learnable per-channel Cauchy-type activation.

        phi(z) = (l1 * z + l2) / (z**2 + d**2) + gamma * z

    The constrained coefficients come from unconstrained raw parameters:
    l1 = tanh(raw_l1) and l2 = tanh(raw_l2) lie in (-1, 1), and
    d = softplus(raw_d) + eps is at least eps, so the denominator never vanishes.
    gamma is a learnable scalar when ``use_gamma`` is True; otherwise it is a
    constant zero buffer.

    Because raw_d is initialised to the inverse softplus of ``d0``, the initial
    d equals d0 + eps.
    """

    def __init__(self, dim: int, d0: float = 0.5, eps: float = 1e-3,
                 l1_init: float = 0.8, l2_init: float = 0.0,
                 gamma0: float = 0.1, use_gamma: bool = False):
        super().__init__()

        def _inverse_tanh_clipped(value: float) -> float:
            # Clip into (-1, 1) so arctanh stays finite.
            clipped = float(np.clip(value, -0.99, 0.99))
            return float(np.arctanh(clipped))

        raw_l1_start = _inverse_tanh_clipped(l1_init)
        raw_l2_start = _inverse_tanh_clipped(l2_init)
        # Parameter registration order (raw_l1, raw_l2, raw_d, gamma) is kept
        # stable so state_dict keys and optimizer grouping are unchanged.
        self.raw_l1 = nn.Parameter(torch.full((dim,), raw_l1_start))
        self.raw_l2 = nn.Parameter(torch.full((dim,), raw_l2_start))
        # inverse softplus: softplus(log(expm1(d0))) == d0
        self.raw_d = nn.Parameter(torch.full((dim,), math.log(math.expm1(d0))))
        self.eps = eps
        self.use_gamma = use_gamma
        if not use_gamma:
            self.register_buffer("gamma", torch.tensor(0.0))
        else:
            self.gamma = nn.Parameter(torch.tensor(gamma0))

    @property
    def l1(self):
        return torch.tanh(self.raw_l1)

    @property
    def l2(self):
        return torch.tanh(self.raw_l2)

    @property
    def d(self):
        return F.softplus(self.raw_d) + self.eps

    def forward(self, z):
        scale = self.d
        numerator = self.l1 * z + self.l2
        denominator = z * z + scale * scale
        return numerator / denominator + self.gamma * z

class Net(nn.Module):
    """
    Fully connected PINN u_theta(x, t).

    Layout: fc_in -> act -> [fc -> act] * (hidden - 1) -> fc_out.
    With ``act="cauchy"`` each hidden layer gets its own CauchyAct, with
    independent learnable parameters. Otherwise a single stateless torch
    activation module is shared by all hidden layers.

    Before the first layer, inputs are multiplied column-wise by the buffer
    ``in_scale = [1, t_scale]``. With the default t_scale = 1/T_final, the time
    column in [0, T_final] is mapped to [0, 1].

    Args:
        in_dim: number of input features. NOTE: ``in_scale`` has two entries,
            so forward assumes inputs of shape [B, 2] = [x, t].
        hidden: number of activated hidden layers (>= 1).
        width: units per hidden layer.
        out_dim: number of outputs.
        act: "cauchy" or one of "tanh", "gelu", "relu", "silu", "sigmoid".
        t_scale: multiplier applied to the time column.

    Raises:
        ValueError: if ``hidden < 1`` or ``act`` is not recognised.
    """

    def __init__(self, in_dim=2, hidden=1, width=20, out_dim=1,
                 act="cauchy", t_scale=1.0/T_final):
        super().__init__()
        # Explicit exceptions instead of `assert`, which `python -O` strips.
        if hidden < 1:
            raise ValueError(f"hidden must be >= 1, got {hidden}")
        act_map = {
            "tanh": nn.Tanh,
            "gelu": nn.GELU,
            "relu": nn.ReLU,
            "silu": nn.SiLU,
            "sigmoid": nn.Sigmoid,
        }
        use_cauchy = (act == "cauchy")
        if not use_cauchy and act not in act_map:
            raise ValueError(f"unknown act: {act}")

        # Submodule registration order (fc_in, fcs, fc_out, acts/act) matches
        # the original, so state_dict keys are unchanged.
        self.fc_in = nn.Linear(in_dim, width)
        self.fcs   = nn.ModuleList([nn.Linear(width, width) for _ in range(hidden-1)])
        self.fc_out = nn.Linear(width, out_dim)
        self.use_cauchy = use_cauchy
        if self.use_cauchy:
            self.acts = nn.ModuleList([CauchyAct(width, use_gamma=False) for _ in range(hidden)])
        else:
            self.act = act_map[act]()
        self.register_buffer("in_scale", torch.tensor([1.0, t_scale], dtype=torch.float32))
        # Xavier-initialise every linear layer, including the hidden ones
        # (previously only fc_in/fc_out). For hidden == 1 this is identical
        # to the old initialisation.
        for layer in (self.fc_in, *self.fcs, self.fc_out):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        x = x * self.in_scale              # [B,2] = [x, t]
        h = self.fc_in(x)
        if self.use_cauchy:
            h = self.acts[0](h)
            for i, fc in enumerate(self.fcs, start=1):
                h = self.acts[i](fc(h))
        else:
            h = self.act(h)
            for fc in self.fcs:
                h = self.act(fc(h))
        return self.fc_out(h)

# -----------------------------
# Autograd & Losses (MSE only)
# -----------------------------
def gradients(y, x):
    """
    Return dy/dx for a tensor ``y`` computed from ``x``, keeping the graph.

    ``grad_outputs`` is all ones, so each entry of ``y`` contributes its own
    gradient. The graph is retained and made differentiable (create_graph), so
    the result can be differentiated again (e.g. for u_xx).
    """
    (grad,) = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=torch.ones_like(y),
        create_graph=True,
        retain_graph=True,
    )
    return grad

def pde_residual(model, x_f, t_f, nu):
    """
    Strong-form Burgers residual R(u) = u_t + u*u_x - nu*u_xx at (x_f, t_f).

    x_f and t_f are concatenated along dim 1 into the model input, so column 0
    is x and column 1 is t. The second derivative is only computed when nu > 0;
    otherwise the diffusion term is zero.
    """
    inputs = torch.cat([x_f, t_f], dim=1).requires_grad_(True)
    u = model(inputs)
    grad_u = gradients(u, inputs)            # columns: [du/dx, du/dt]
    u_x, u_t = grad_u[:, [0]], grad_u[:, [1]]
    u_xx = gradients(u_x, inputs)[:, [0]] if nu > 0.0 else torch.zeros_like(u_x)
    return u_t + u * u_x - nu * u_xx

def loss_bulk(model, x_f, t_f, nu):
    """Mean squared PDE residual over the interior collocation points."""
    residual = pde_residual(model, x_f, t_f, nu)
    return residual.pow(2).mean()

def loss_ic(model, x0):
    """
    MSE between the model at t = 0 and the Riemann initial data.

    The target is u_L where x0 < 0 and u_R elsewhere.
    """
    u_pred = model(torch.cat([x0, torch.zeros_like(x0)], dim=1))
    left_state, right_state = torch.full_like(x0, u_L), torch.full_like(x0, u_R)
    u_gt = torch.where(x0 < 0, left_state, right_state)
    return torch.mean((u_pred - u_gt) ** 2)

def loss_bc(model, tb, enforce_right=True):
    """
    Dirichlet boundary loss at x = X_MIN (target u_L) and optionally x = X_MAX.

    With ``enforce_right`` the two edge MSEs are averaged. Otherwise only the
    left-edge MSE is returned (not halved).
    """
    def _edge_mse(x_edge, target):
        # Evaluate the model along the vertical line x = x_edge at times tb.
        xt = torch.cat([torch.full_like(tb, x_edge), tb], dim=1)
        return torch.mean((model(xt) - target) ** 2)

    left_loss = _edge_mse(X_MIN, u_L)
    if not enforce_right:
        return left_loss
    right_loss = _edge_mse(X_MAX, u_R)
    return 0.5 * (left_loss + right_loss)

# -----------------------------
# Uniform samplers (no method tricks)
# -----------------------------
def sample_interior_uniform(N):
    """Draw N uniform points in [X_MIN, X_MAX) x [0, T_final) as two [N, 1] tensors."""
    span = X_MAX - X_MIN
    x = X_MIN + span * torch.rand(N, 1, device=DEVICE)
    t = T_final * torch.rand(N, 1, device=DEVICE)
    return x, t

def sample_ic(N):
    """Draw N uniform positions in [X_MIN, X_MAX) as an [N, 1] tensor (for t = 0)."""
    return X_MIN + (X_MAX - X_MIN) * torch.rand(N, 1, device=DEVICE)

def sample_bc(N):
    """Draw N uniform boundary times in [0, T_final) as an [N, 1] tensor."""
    return T_final * torch.rand(N, 1, device=DEVICE)

# -----------------------------
# Metrics & LSNN-style prints
# -----------------------------
@torch.no_grad()
def shock_position_halfheight(model, t_scalar, uL=u_L, uR=u_R, nx=4001):
    """
    Estimate the shock location at time ``t_scalar`` as the half-height crossing.

    The model is evaluated on ``nx`` uniform points in [X_MIN, X_MAX]. The
    estimate is the first point where u crosses (uL + uR) / 2, refined by linear
    interpolation between the two bracketing grid points. If there is no strict
    sign change (no crossing, or an exact hit on a grid node), it returns the
    grid point where u is closest to the mid level.

    ``uL``/``uR`` default to the module config (u_L, u_R), consistent with
    exact_riemann_burgers, so the mid level follows the configured Riemann data.

    Returns:
        Estimated shock position as a Python float.
    """
    x = torch.linspace(X_MIN, X_MAX, nx, device=DEVICE).view(-1, 1)
    t = torch.full_like(x, t_scalar)
    u = model(torch.cat([x, t], dim=1)).view(-1)
    y = u - 0.5 * (uL + uR)
    sgn = torch.sign(y)
    # Indices i where y changes strictly from one sign to the other between i and i+1.
    idx = torch.nonzero(sgn[:-1] * sgn[1:] < 0, as_tuple=False).view(-1)
    if idx.numel() == 0:
        j = torch.argmin(torch.abs(y))
        return x[j].item()
    i = idx[0].item()
    x0, x1 = x[i].item(), x[i + 1].item()
    y0, y1 = y[i].item(), y[i + 1].item()
    # y0 and y1 have opposite signs, so y1 - y0 != 0.
    return x0 - y0 * (x1 - x0) / (y1 - y0)

@torch.no_grad()
def rel_l2(model, t, nx=4001):
    """
    Relative L2 error ||u_model - u_exact|| / ||u_exact|| at time ``t``.

    Both solutions are evaluated on ``nx`` uniform points in [X_MIN, X_MAX].
    Returns a Python float.
    """
    grid = torch.linspace(X_MIN, X_MAX, nx, device=DEVICE).view(-1, 1)
    times = torch.full_like(grid, t)
    u_pred = model(torch.cat([grid, times], 1)).view(-1)
    u_ref = exact_riemann_burgers(grid, t).view(-1)
    ratio = torch.linalg.norm(u_pred - u_ref) / torch.linalg.norm(u_ref)
    return ratio.item()

@torch.no_grad()
def print_lsnn_style_block_evals(model, k):
    """
    Print predicted vs. exact shock position for time block ``k`` (1-based).

    Block k spans [T_EDGES[k-1], T_EDGES[k]]. Its midpoint and right end are
    reported.
    """
    speed = 0.5 * (u_L + u_R)
    t_lo, t_hi = T_EDGES[k - 1].item(), T_EDGES[k].item()
    for t in (0.5 * (t_lo + t_hi), t_hi):
        xs_true = speed * t
        xs_pred = shock_position_halfheight(model, t)
        err = abs(xs_true - xs_pred)
        print(f"[eval after block {k}] t={t:.3f}  x_s(true)={xs_true:.3f}  x_s(pred)={xs_pred:.3f}  |err|={err:.3e}")

@torch.no_grad()
def print_lsnn_style_final_evals(model):
    """Print predicted vs. exact shock position at the fixed times 0.2, 0.4, 0.6."""
    speed = 0.5 * (u_L + u_R)
    for t in (0.2, 0.4, 0.6):
        xs_true = speed * t
        xs_pred = shock_position_halfheight(model, t)
        err = abs(xs_true - xs_pred)
        print(f"t={t:.1f}  x_s(true)={xs_true:.3f}  x_s(pred)={xs_pred:.3f}  |err|={err:.3e}")

# -----------------------------
# Optim: split LR for CauchyAct params
# -----------------------------
def split_params_for_cauchy(model):
    """
    Partition ``model`` parameters into (activation_params, other_params).

    A parameter goes into the activation group when its qualified name contains
    any of "raw_l1", "raw_l2", "raw_d" or "gamma" (substring match). Order within
    each list follows ``model.named_parameters()``.
    """
    tags = ("raw_l1", "raw_l2", "raw_d", "gamma")
    act_group, other_group = [], []
    for name, param in model.named_parameters():
        target = act_group if any(tag in name for tag in tags) else other_group
        target.append(param)
    return act_group, other_group

# -----------------------------
# Training
# -----------------------------
def _build_optimizer(model):
    """Build Adam with a separate LR_ACT group for CauchyAct parameters, if there are any."""
    act_params, other_params = split_params_for_cauchy(model)
    if not act_params:
        return optim.Adam(model.parameters(), lr=LR)
    return optim.Adam([
        {"params": other_params, "lr": LR},
        {"params": act_params,   "lr": LR_ACT},
    ])


def main(seed=0, width=20, hidden=1, act="cauchy", log_per_epoch=False,
         ckpt_path="burgers_riemann_pinn_cauchy_params_only.pt"):
    """
    Train the Burgers Riemann-problem PINN and return the trained model.

    Each epoch redraws interior, initial-condition and boundary points
    uniformly. It then takes one Adam step on
    Wb * PDE-residual MSE + Wi * IC MSE + Wc * BC MSE, and halves the LR
    every 500 steps. At the epochs in EP_EDGES it prints shock-position
    diagnostics. On normal exit, Ctrl-C, or an exception, it prints the
    final shock evaluations and saves the state_dict to ``ckpt_path``.

    Args:
        seed: torch RNG seed.
        width, hidden, act: forwarded to Net.
        log_per_epoch: if True, print losses and relative L2 errors every
            PRINT_EVERY epochs (and at epoch 1).
        ckpt_path: where to save the state_dict (defaults to the original
            hard-coded filename).
    """
    torch.manual_seed(seed)

    model = Net(in_dim=2, hidden=hidden, width=width, out_dim=1, act=act).to(DEVICE)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"trainable_params = {n_params}")

    # Built once. Previously the optimizer was rebuilt unconditionally, which
    # discarded this choice and added an empty param group for non-Cauchy nets.
    opt = _build_optimizer(model)
    sched = optim.lr_scheduler.StepLR(opt, step_size=500, gamma=0.5)

    next_edge_idx = 0
    try:
        for ep in range(1, EPOCHS+1):
            # uniform resample each epoch
            x_f, t_f = sample_interior_uniform(N_f)
            x0 = sample_ic(N0)
            tb = sample_bc(Nb)

            opt.zero_grad()
            lb  = loss_bulk(model, x_f, t_f, NU)             # MSE strong-form residual
            lic = loss_ic(model, x0)
            lbc = loss_bc(model, tb, enforce_right=True)
            loss = Wb*lb + Wi*lic + Wc*lbc
            loss.backward()
            opt.step()
            sched.step()

            # optional per-epoch log
            if log_per_epoch and (ep % PRINT_EVERY == 0 or ep == 1):
                with torch.no_grad():
                    r2 = rel_l2(model, 0.2); r4 = rel_l2(model, 0.4); r6 = rel_l2(model, 0.6)
                    print(f"[{ep:4d}] loss={loss.item():.4e} | bulk={lb.item():.4e} ic={lic.item():.4e} bc={lbc.item():.4e} "
                          f"| RelL2: t=0.2 {r2:.3e}, t=0.4 {r4:.3e}, t=0.6 {r6:.3e}")

            # LSNN-style block evals at 1/3, 2/3, 3/3 progress
            if next_edge_idx < len(EP_EDGES) and ep == EP_EDGES[next_edge_idx]:
                k = next_edge_idx + 1
                print_lsnn_style_block_evals(model, k)
                next_edge_idx += 1

    except KeyboardInterrupt:
        print("Interrupted — saving checkpoint...")
    finally:
        print_lsnn_style_final_evals(model)
        torch.save(model.state_dict(), ckpt_path)
    return model

if __name__ == "__main__":
    # Default experiment: one hidden layer of width 20 with CauchyAct activations.
    main(seed=0, width=20, hidden=1, act="cauchy", log_per_epoch=False)
    # Standard-activation baselines (width 35); uncomment one to run it.
    # NOTE: by default main() saves to the same checkpoint file for every `act`.
    #main(seed=0, width=35, hidden=1, act="tanh", log_per_epoch=False)
    #main(seed=0, width=35, hidden=1, act="gelu", log_per_epoch=False)


