# lsnn_burgers_riemann_cai_style.py
# Cai-style LSNN for inviscid Burgers Riemann problem

import math, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import copy

# =============== global setup ===============
# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_dtype(torch.float32)

# Seed the torch, stdlib and NumPy RNGs so runs are repeatable.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

# Problem/domain: x in [X_MIN, X_MAX], t in [T0, T1]; Riemann data u_L for x < 0
# and u_R otherwise (see sample_initial).
X_MIN = -1.0
X_MAX = 1.0
T0 = 0.0
T1 = 0.6
u_L = 1.0
u_R = 0.0

# Time blocking: [T0, T1] is split into NB equal blocks trained one after another.
NB = 3
T_EDGES = torch.linspace(T0, T1, NB + 1)  # [0.0, 0.2, 0.4, 0.6]

# Hyper-params (numerics unchanged)
ALPHA = 20.0             # weight on the boundary + initial/interface loss terms
ITERS_PER_BLOCK = 30000  # optimizer steps per time block
LR0 = 3e-3               # initial Adam learning rate
BATCH_CELLS = 4096       # control volumes sampled per training step

# Composite trapezoid & mesh
P_SPACE = 2   # trapezoid sub-intervals along x on each cell
P_TIME = 2    # trapezoid sub-intervals along t on each cell
H = 0.01      # cell width in x
DT = 0.01     # cell height in t



# =============== network ===============
class NetSigmoid(nn.Module):
    """Fully connected 2-10-10-1 network with Sigmoid activations (Cai notebook style).

    Input is a batch of (x, t) pairs with shape (N, 2); output is u(x, t) with
    shape (N, 1). Weights use Xavier-uniform initialisation and biases start at 0.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are part of the state_dict keys, so they stay fc1/fc2/fc3.
        self.fc1 = nn.Linear(2, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 1)
        self.act = nn.Sigmoid()
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, xt):
        """Return u(x, t) for each row of ``xt``; no activation on the output layer."""
        hidden = self.fc1(xt)
        hidden = self.act(hidden)
        hidden = self.fc2(hidden)
        hidden = self.act(hidden)
        return self.fc3(hidden)  # u(x,t)

# To use ReLU instead: implement a NetReLU class (ReLU activations + Kaiming
# initialisation) and change the line below to `Net = NetReLU`.
Net = NetSigmoid

# =============== composite trapezoid ===============
def trap_nodes_1d(a, b, p):
    """Build composite-trapezoid nodes and weights on each interval [a_i, b_i].

    Args:
        a, b: tensors of interval end points; the callers in this file pass
            shape (B, 1).
        p: number of equal sub-intervals (converted with ``int``).

    Returns:
        ``(nodes, weights)``, each of shape (B, p + 1): ``p + 1`` equally spaced
        nodes from ``a`` to ``b`` and the matching trapezoid weights (the
        sub-interval width, halved at the two end nodes).
    """
    n_sub = int(p)
    count = n_sub + 1
    step = (b - a) / n_sub
    offsets = torch.arange(0, count, device=a.device, dtype=a.dtype).view(1, -1)
    nodes = a + step * offsets
    weights = step.repeat(1, count)
    # End nodes belong to a single sub-interval, so they get half weight.
    for end in (0, -1):
        weights[:, end] *= 0.5
    return nodes, weights

def _eval_on_grid(model, X, T):
    """Evaluate ``model`` at the points (X, T) and return values shaped like ``X``.

    ``X`` and ``T`` are same-shaped tensors of coordinates. They are paired
    element-wise, flattened to an (N, 2) batch for the model, and the model
    output is reshaped back to ``X.shape``.
    """
    points = torch.stack((X, T), dim=-1).reshape(-1, 2)
    return model(points).reshape(X.shape)

# =============== discrete divergence on control volumes ===============
def divT_on_cells(model, x0, x1, t0, t1):
    """Discrete space-time divergence residual on cells [x0, x1] x [t0, t1].

    The Burgers flux 0.5*u^2 is integrated in time along the two vertical faces
    (x = x0, x = x1) and u is integrated in space along the two horizontal
    faces (t = t0, t = t1), both with the composite trapezoid rule
    (P_TIME / P_SPACE sub-intervals).

    Args:
        model: callable mapping an (N, 2) batch of (x, t) to (N, 1) values.
        x0, x1, t0, t1: cell bounds, each of shape (B, 1).

    Returns:
        Tensor of shape (B, 1): time-averaged flux difference plus
        space-averaged difference of u.

    NOTE(review): the flux term is divided by (t1 - t0) only and the u term by
    (x1 - x0) only, so the result is the cell-area-normalised divergence times
    the cell size only when the cell is square (H == DT, as configured here).
    Confirm before using H != DT.
    """
    def burgers_flux(u):
        return 0.5*(u*u)

    # time integral on vertical faces
    tq, tq_w = trap_nodes_1d(t0, t1, P_TIME)          # (B, PT1)
    n_tq = tq.shape[1]
    flux_right = burgers_flux(_eval_on_grid(model, x1.repeat(1, n_tq), tq))
    flux_left = burgers_flux(_eval_on_grid(model, x0.repeat(1, n_tq), tq))
    time_int = torch.sum((flux_right - flux_left) * tq_w, dim=1, keepdim=True) / (t1 - t0)

    # space integral on horizontal faces
    xq, xq_w = trap_nodes_1d(x0, x1, P_SPACE)         # (B, PX1)
    n_xq = xq.shape[1]
    u_top = _eval_on_grid(model, xq, t1.repeat(1, n_xq))
    u_bottom = _eval_on_grid(model, xq, t0.repeat(1, n_xq))
    space_int = torch.sum((u_top - u_bottom) * xq_w, dim=1, keepdim=True) / (x1 - x0)

    return time_int + space_int                       # (B,1)

# =============== sampling utils (fix generator usage) ===============
def _rand_uniform(shape, lo, hi, generator=None):
    """Draw a tensor of the given shape, uniform on [lo, hi), placed on DEVICE.

    Without ``generator`` the sample is drawn directly on DEVICE from the
    global RNG. With a ``generator`` the sample is drawn on the CPU with that
    generator and then moved to DEVICE (compatible with older PyTorch versions).
    """
    span = hi - lo
    if generator is not None:
        cpu_sample = lo + span * torch.rand(*shape, device="cpu", generator=generator)
        return cpu_sample.to(DEVICE)
    return lo + span * torch.rand(*shape, device=DEVICE)

def sample_cells(B, t_a, t_b, *, generator=None):
    """Sample ``B`` mesh-aligned control volumes inside [X_MIN, X_MAX] x [t_a, t_b].

    Lower-left corners are drawn uniformly and snapped to the nearest multiple
    of H (in x) and DT (in t); the upper corners are one mesh step further,
    clamped to the domain / block end.

    Returns:
        ``(x0, x1, t0, t1)``, each a (B, 1) tensor on DEVICE.
    """
    shape = (B, 1)
    # Draw x before t so the random stream is consumed in a fixed order.
    x_lo = _rand_uniform(shape, X_MIN, X_MAX - H, generator)
    t_lo = _rand_uniform(shape, t_a, t_b - DT, generator)
    # Snap the corners onto the H x DT mesh.
    x_lo = torch.round(x_lo / H) * H
    t_lo = torch.round(t_lo / DT) * DT
    x_hi = torch.clamp(x_lo + H, max=X_MAX)
    t_hi = torch.clamp(t_lo + DT, max=t_b - 1e-12)
    return x_lo, x_hi, t_lo, t_hi

def sample_inflow_left(M, t_a, t_b, *, generator=None):
    """Sample ``M`` points on the left boundary x = X_MIN with t uniform in [t_a, t_b).

    Returns:
        ``(XT, y)``: (M, 2) points (x, t) and the (M, 1) target values u_L.
    """
    times = _rand_uniform((M, 1), t_a, t_b, generator)
    points = torch.cat([torch.full_like(times, X_MIN), times], dim=1)
    target = torch.full_like(times, u_L)
    return points, target

def sample_inflow_right(M, t_a, t_b, *, generator=None):
    """Sample ``M`` points on the right boundary x = X_MAX with t uniform in [t_a, t_b).

    Returns:
        ``(XT, y)``: (M, 2) points (x, t) and the (M, 1) target values u_R.
    """
    times = _rand_uniform((M, 1), t_a, t_b, generator)
    points = torch.cat([torch.full_like(times, X_MAX), times], dim=1)
    target = torch.full_like(times, u_R)
    return points, target

def sample_initial(M, *, generator=None):
    """Sample ``M`` points on the initial line t = 0 with x uniform in [X_MIN, X_MAX).

    Returns:
        ``(XT, y)``: (M, 2) points (x, 0) and the (M, 1) Riemann initial data,
        u_L where x < 0 and u_R elsewhere.
    """
    xs = _rand_uniform((M, 1), X_MIN, X_MAX, generator)
    points = torch.cat([xs, torch.zeros_like(xs)], dim=1)
    left_state = torch.full_like(xs, u_L)
    right_state = torch.full_like(xs, u_R)
    target = torch.where(xs < 0, left_state, right_state)
    return points, target

# =============== Cai-style utils: exp LR & fixed val-batch ===============
def exp_lr_scheduler(optimizer, it, *, base_lr=LR0, decay_gamma=0.9, decay_every=5000):
    """Set a step-wise exponentially decayed learning rate on ``optimizer``.

    The rate is ``base_lr * decay_gamma ** (it // decay_every)`` and is written
    into every parameter group.

    Returns:
        The learning rate that was applied.
    """
    n_decays = it // decay_every
    lr = base_lr * (decay_gamma ** n_decays)
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr

@torch.no_grad()
def make_val_batch(t_a, t_b):
    """Build the fixed validation pack for the time block [t_a, t_b].

    All samples come from a dedicated CPU generator seeded with
    ``123456 + int(1000*t_a)``, so the pack does not depend on the global RNG
    state and works across environments.

    Returns:
        ``(cells, left, right, XT_if)`` where ``cells`` is the 4-tuple from
        ``sample_cells`` (4096 cells), ``left`` / ``right`` are the
        ``(XT, y)`` pairs of 1024 boundary points each, and ``XT_if`` holds
        2048 points (x, t_a) on the block's starting time line.
    """
    gen = torch.Generator().manual_seed(123456 + int(1000*t_a))
    # Keep this draw order: it fixes which numbers each part of the pack receives.
    cells = sample_cells(4096, t_a, t_b, generator=gen)
    left = sample_inflow_left(1024, t_a, t_b, generator=gen)
    right = sample_inflow_right(1024, t_a, t_b, generator=gen)
    x_if = _rand_uniform((2048, 1), X_MIN, X_MAX, gen)
    XT_if = torch.cat([x_if, torch.full_like(x_if, t_a)], dim=1)
    return cells, left, right, XT_if

@torch.no_grad()
def val_score(model, prev_model, t_a, t_b, val_pack):
    """Score ``model`` on the fixed validation pack of one time block.

    Args:
        model: network being trained on the block.
        prev_model: frozen network of the previous block, or ``None`` for the
            first block (the start-of-block data is then the initial condition).
        t_a, t_b: block time bounds; not used by the computation, kept so the
            call signature stays stable.
        val_pack: tuple ``(cells, left, right, XT_if)`` as returned by
            ``make_val_batch``.

    Returns:
        ``(total, bulk, bc, start)`` as Python floats, where
        ``total = bulk + ALPHA*(bc + start)``.
    """
    (x0, x1, t0, t1), (XT_L, yL), (XT_R, yR), XT_if = val_pack
    Rv = divT_on_cells(model, x0, x1, t0, t1)
    vb = torch.mean(Rv**2)
    vbc = torch.mean((model(XT_L)-yL)**2) + torch.mean((model(XT_R)-yR)**2)
    if prev_model is None:  # block 1: initial condition
        # Reuse the pack's fixed x-samples at t = 0 instead of drawing new
        # points: a fresh draw would make successive validation scores
        # incomparable and would consume the global RNG used for training.
        x_if = XT_if[:, :1]
        XT0 = torch.cat([x_if, torch.zeros_like(x_if)], dim=1)
        y0 = torch.where(x_if < 0, torch.full_like(x_if, u_L), torch.full_like(x_if, u_R))
        vif = torch.mean((model(XT0)-y0)**2)
    else:
        # Later blocks: match the previous block's network on the interface.
        y_if = prev_model(XT_if)
        vif = torch.mean((model(XT_if)-y_if)**2)
    return (vb + ALPHA*(vbc + vif)).item(), vb.item(), vbc.item(), vif.item()

# =============== training one block ===============
def train_block(k, model, prev_model=None):
    """Train ``model`` on time block ``k`` and return a frozen copy of the best state.

    The loss is ``bulk + ALPHA*(bc + start)``: the squared discrete-divergence
    residual on random cells, the boundary mismatch at x = X_MIN / X_MAX, and
    either the initial-condition mismatch (``k == 1``) or the mismatch against
    ``prev_model`` on the line t = t_a (``k > 1``).

    Every 5000 iterations the training losses and the score on a fixed
    validation pack are printed, and the state with the lowest validation
    total seen so far is remembered. After training, ``model`` is reset to
    that state.

    Args:
        k: 1-based block index into ``T_EDGES``.
        model: network to train; modified in place.
        prev_model: frozen network of block ``k - 1``; called when ``k > 1``.

    Returns:
        A new ``Net`` on DEVICE holding the best state, with gradients disabled.
    """
    t_a, t_b = T_EDGES[k-1].item(), T_EDGES[k].item()
    opt = optim.Adam(model.parameters(), lr=LR0)

    def mse(pred, target):
        return torch.mean((pred - target)**2)

    # fixed validation batch for this block
    val_pack = make_val_batch(t_a, t_b)
    best_state = copy.deepcopy(model.state_dict())
    best_val = val_score(model, prev_model, t_a, t_b, val_pack)[0]

    for it in range(1, ITERS_PER_BLOCK + 1):
        cur_lr = exp_lr_scheduler(opt, it, base_lr=LR0, decay_gamma=0.9, decay_every=5000)
        opt.zero_grad()

        # (1) bulk residual
        cells = sample_cells(BATCH_CELLS, t_a, t_b)
        loss_bulk = torch.mean(divT_on_cells(model, *cells)**2)

        # (2) inflow boundaries (both sampled before any model call)
        XT_L, yL = sample_inflow_left(1024, t_a, t_b)
        XT_R, yR = sample_inflow_right(1024, t_a, t_b)
        loss_bc = mse(model(XT_L), yL) + mse(model(XT_R), yR)

        # (3) initial condition (first block) vs interface with previous block
        if k == 1:
            XT_start, y_start = sample_initial(2048)
        else:
            x_if = _rand_uniform((2048, 1), X_MIN, X_MAX)  # global RNG; no generator is passed here
            XT_start = torch.cat([x_if, torch.full_like(x_if, t_a)], dim=1)
            with torch.no_grad():
                y_start = prev_model(XT_start)
        loss_if = mse(model(XT_start), y_start)

        loss = loss_bulk + ALPHA*(loss_bc + loss_if)
        loss.backward()
        opt.step()

        if it % 5000 != 0:
            continue

        # Periodic report + best-checkpoint tracking.
        print(f"[block {k}/{NB}] it={it:5d}  lr={cur_lr:.2e}  "
              f"bulk={loss_bulk.item():.3e}  bc={loss_bc.item():.3e}  if={loss_if.item():.3e}  "
              f"ALPHA*(bc+if)={(ALPHA*(loss_bc+loss_if)).item():.3e}  total={loss.item():.3e}")
        vtot, vb, vbc, vif = val_score(model, prev_model, t_a, t_b, val_pack)
        print(f"          [val] total={vtot:.3e}  bulk={vb:.3e}  bc={vbc:.3e}  if={vif:.3e}")
        if vtot < best_val:
            best_val = vtot
            best_state = copy.deepcopy(model.state_dict())

    # best-of-block
    model.load_state_dict(best_state)

    # freeze copy for next block
    snapshot = Net().to(DEVICE)
    snapshot.load_state_dict(model.state_dict())
    for param in snapshot.parameters():
        param.requires_grad_(False)
    return snapshot

# =============== robust shock locator ===============
@torch.no_grad()
def shock_pos(model, t, nx=4001):
    """Locate where ``model``'s profile at time ``t`` crosses the mid-level.

    The profile is sampled on ``nx`` equally spaced points of [X_MIN, X_MAX].
    The first strict sign change of ``u - 0.5*(u_L + u_R)`` between neighbouring
    samples is refined by linear interpolation. If there is no strict sign
    change, the sample closest to the mid-level is returned instead.

    Returns:
        The x position as a Python float.
    """
    grid = torch.linspace(X_MIN, X_MAX, nx, device=DEVICE).view(-1, 1)
    points = torch.cat([grid, torch.full_like(grid, t)], dim=1)
    offset = model(points).view(-1) - 0.5*(u_L + u_R)

    signs = torch.sign(offset)
    crossings = torch.nonzero(signs[:-1] * signs[1:] < 0, as_tuple=False).view(-1)
    if crossings.numel() == 0:
        # No bracketing pair: fall back to the nearest sample.
        nearest = torch.argmin(torch.abs(offset))
        return grid[nearest].item()

    i = crossings[0].item()
    xa, xb = grid[i].item(), grid[i+1].item()
    ya, yb = offset[i].item(), offset[i+1].item()
    # Linear interpolation of the zero between the two bracketing samples.
    return xa - ya*(xb - xa)/(yb - ya)

# =============== main ===============
def main():
    """Train all time blocks in sequence, report shock-position errors, save weights.

    After each block the predicted shock position is compared with the exact
    one at the block's mid time and end time. A final summary repeats the
    comparison at every block end time, using the frozen snapshot of the block
    that ends there. The saved file holds the state of ``model`` after the
    last block, i.e. the network for the final time block.
    """
    model = Net().to(DEVICE)

    # Rankine-Hugoniot speed for the Burgers flux 0.5*u^2; the exact shock sits
    # at x = speed * t (a shock requires u_L > u_R).
    shock_speed = 0.5 * (u_L + u_R)

    # `model` is retrained in place for every block, so after the loop it only
    # represents the last block. Keep each block's frozen snapshot for the summary.
    snapshots = []
    prev = None
    for k in range(1, NB+1):
        prev = train_block(k, model, prev_model=prev)
        snapshots.append(prev)
        for t in [(T_EDGES[k-1]+T_EDGES[k]).item()*0.5, T_EDGES[k].item()]:
            xs_true = shock_speed * t
            xs_pred = shock_pos(model, t)
            print(f"[eval after block {k}] t={t:.3f}  x_s(true)={xs_true:.3f}  x_s(pred)={xs_pred:.3f}  |err|={abs(xs_true-xs_pred):.3e}")

    # Summary at the block end times, each evaluated with the network that was
    # actually trained on the block ending at that time.
    for k, block_model in enumerate(snapshots, start=1):
        t = T_EDGES[k].item()
        xs_true = shock_speed * t
        xs_pred = shock_pos(block_model, t)
        print(f"t={t:.1f}  x_s(true)={xs_true:.3f}  x_s(pred)={xs_pred:.3f}  |err|={abs(xs_true-xs_pred):.3e}")

    torch.save(model.state_dict(), "lsnn_burgers_riemann_cai_style.pt")

# Start the block-by-block training only when this file is run as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
