# -*- coding: utf-8 -*-
# NLS PINN for i h_t + 0.5 h_xx + |h|^2 h = 0 with h = u + i v — single-layer CauchyNet built from a Cauchy-integral kernel (no tanh activations)

import math, time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# ----- setup -----
def set_seed(s=0):
    """Seed PyTorch's CPU generator and every CUDA generator with ``s``."""
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)


# Global configuration: fixed seed, double precision, best available device.
set_seed(0)
torch.set_default_dtype(torch.float64)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ----- problem -----
L = 5.0                    # half-width of the spatial domain
X_MIN, X_MAX = -L, L       # x in [-L, L]
T_MIN, T_MAX = 0.0, 2.0    # t in [0, 2]


def sech(x):
    """Element-wise hyperbolic secant, sech(x) = 1 / cosh(x), for tensors."""
    return torch.cosh(x).reciprocal()

def analytic_uv(x, t, eta=2.0):
    """Exact bright soliton h = eta*sech(eta*x)*exp(i*eta^2*t/2), split as (Re h, Im h).

    This h satisfies i h_t + 0.5 h_xx + |h|^2 h = 0. With the default
    eta = 2 the envelope is 2*sech(2x) and the phase is 2t.
    """
    omega = 0.5 * eta * eta                  # temporal frequency eta^2 / 2
    envelope = eta * sech(eta * x)
    phase = omega * t
    u = envelope * torch.cos(phase)
    v = envelope * torch.sin(phase)
    return u, v

def ic_uv(x, eta=2.0):
    """Initial condition at t = 0: u0 = eta*sech(eta*x), v0 = 0 (real profile)."""
    u0 = eta * sech(eta * x)
    v0 = torch.zeros_like(x)
    return u0, v0

# --- 2D Cauchy fitting layer: f(z,w) = (1/√M) Σ_k λ_k / ((ξ_k - z)(η_k - w)), with (z,w) the normalized (x,t) ---
class Cauchy2DFit(nn.Module):
    """Single-layer 2-D Cauchy-kernel fitting layer.

    Inputs are first mapped from the physical domain
    [X_MIN, X_MAX] x [T_MIN, T_MAX] onto [-1, 1] x [-1, 1] (x_n, t_n), then

        f(x, t) = (1 / sqrt(M)) * sum_k lam_k / ((xi_k - x_n) * (eta_k - t_n)),

    with complex weights lam_k and complex poles xi_k (x direction) and
    eta_k (t direction). The layer returns (u, v) = (Re f, Im f).

    Args:
        m: number of Cauchy terms M (converted with ``int``).
        init_std: standard deviation of the normal init of Re/Im(lam_k).
        eps: real constant added to every denominator.
        use_idelta: if True, add learnable offsets ``1j * softplus(delta)``
            (strictly positive imaginary parts) to both poles.
        delta_init: initial raw (pre-softplus) value of those offsets.
        x_range: ``(min, max)`` of the physical x domain; defaults to the
            module-level ``(X_MIN, X_MAX)``.
        t_range: ``(min, max)`` of the physical t domain; defaults to the
            module-level ``(T_MIN, T_MAX)``.

    The bounds are stored as the attributes ``X_MIN``, ``X_MAX``, ``T_MIN``
    and ``T_MAX``. ``forward`` reads them on every call, so reassigning them
    after construction (as a wrapping model may do) takes effect.

    Raises:
        ValueError: if a range does not satisfy ``max > min``.
    """

    def __init__(self, m=400, init_std=1e-2, eps=1e-12, use_idelta=False, delta_init=-1.6,
                 x_range=None, t_range=None):
        super().__init__()
        self.m = int(m)
        m = self.m  # use the integer count for every tensor shape below
        self.eps = float(eps)
        self.use_idelta = bool(use_idelta)

        # Physical domain used by forward() to normalise (x, t).
        x_lo, x_hi = (X_MIN, X_MAX) if x_range is None else x_range
        t_lo, t_hi = (T_MIN, T_MAX) if t_range is None else t_range
        if not (x_hi > x_lo and t_hi > t_lo):
            raise ValueError(
                f"invalid domain: x_range=({x_lo}, {x_hi}), t_range=({t_lo}, {t_hi}); "
                "need max > min")
        self.X_MIN, self.X_MAX = float(x_lo), float(x_hi)
        self.T_MIN, self.T_MAX = float(t_lo), float(t_hi)

        # Complex weights lam_k = l_re + i*l_im.
        self.l_re = nn.Parameter(torch.randn(m) * init_std)
        self.l_im = nn.Parameter(torch.randn(m) * init_std)

        # Complex poles xi_k = xi_re + i*xi_im (x direction) and
        # eta_k = eta_re + i*eta_im (t direction). Real parts start uniform in
        # [-1, 1]; imaginary parts start at zero.
        self.xi_re = nn.Parameter(torch.empty(m).uniform_(-1.0, 1.0))
        self.xi_im = nn.Parameter(torch.zeros(m))
        self.eta_re = nn.Parameter(torch.empty(m).uniform_(-1.0, 1.0))
        self.eta_im = nn.Parameter(torch.zeros(m))

        # Optional stabilising term i*delta with delta = softplus(raw) > 0.
        if self.use_idelta:
            self.delta_x_raw = nn.Parameter(torch.full((m,), float(delta_init)))
            self.delta_t_raw = nn.Parameter(torch.full((m,), float(delta_init)))

    @staticmethod
    def _norm(x, t, x_bounds=None, t_bounds=None):
        """Affinely map [x_min, x_max] x [t_min, t_max] onto [-1, 1] x [-1, 1].

        ``x_bounds`` / ``t_bounds`` are ``(min, max)`` pairs. Either one
        defaults to the module-level domain constants when omitted.
        """
        x_lo, x_hi = (X_MIN, X_MAX) if x_bounds is None else x_bounds
        t_lo, t_hi = (T_MIN, T_MAX) if t_bounds is None else t_bounds
        x_n = 2.0*(x - 0.5*(x_lo + x_hi))/(x_hi - x_lo)
        t_n = 2.0*(t - 0.5*(t_lo + t_hi))/(t_hi - t_lo)
        return x_n, t_n

    def forward(self, x, t):
        """Evaluate the layer at real points.

        Args:
            x, t: real tensors of shape [B, 1] in physical coordinates. The
                [B, 1] shape is needed for the [B, M] broadcasting below.

        Returns:
            (u, v): real and imaginary parts of f, each of shape [B, 1].
        """
        x_n, t_n = self._norm(x, t, (self.X_MIN, self.X_MAX), (self.T_MIN, self.T_MAX))  # [B,1]
        B, M = x_n.size(0), self.m

        lam = torch.complex(self.l_re, self.l_im).view(1, -1)               # [1,M]
        xi = (self.xi_re.view(1, -1) - x_n) + 1j*self.xi_im.view(1, -1)     # [B,M]
        eta = (self.eta_re.view(1, -1) - t_n) + 1j*self.eta_im.view(1, -1)  # [B,M]

        if self.use_idelta:
            # softplus > 0 keeps the added imaginary offsets strictly positive.
            dx = F.softplus(self.delta_x_raw).view(1, -1)
            dt = F.softplus(self.delta_t_raw).view(1, -1)
            xi = xi + 1j*dx
            eta = eta + 1j*dt

        # NOTE: eps is real, so it only rules out an exactly-zero denominator.
        # It does not bound 1/denom for near-zero values. Without use_idelta
        # the poles start on the real axis with real parts in [-1, 1], the
        # same interval as the normalised inputs.
        denom = xi * eta + self.eps
        fz = (lam / denom).sum(dim=1) / math.sqrt(M)  # [B]; 1/sqrt(M) scaling
        u = fz.real.view(B, 1)
        v = fz.imag.view(B, 1)
        return u, v


class PINN_CauchyNet(nn.Module):
    """PINN model whose whole network is a single ``Cauchy2DFit`` head.

    ``forward(x, t)`` returns the head's ``(u, v)`` unchanged.

    Args:
        neurons: number of Cauchy terms passed to ``Cauchy2DFit``.
        X_MIN, X_MAX, T_MIN, T_MAX: domain bounds, stored as same-named
            attributes on ``self.head``. NOTE(review): confirm that
            Cauchy2DFit's input normalisation reads these attributes rather
            than the module-level domain constants.
    """

    def __init__(self, neurons=200, X_MIN=-5., X_MAX=5., T_MIN=0., T_MAX=2.):
        super().__init__()
        head = Cauchy2DFit(neurons)
        bounds = (("X_MIN", X_MIN), ("X_MAX", X_MAX), ("T_MIN", T_MIN), ("T_MAX", T_MAX))
        for attr_name, attr_value in bounds:
            setattr(head, attr_name, attr_value)
        self.head = head

    def forward(self, x, t):
        return self.head(x, t)


# ----- derivatives -----
def derivatives(model, x, t):
    """Evaluate ``model`` and the partial derivatives used by the NLS residual.

    ``x`` and ``t`` are switched to ``requires_grad=True`` in place. Every
    gradient is built with ``create_graph=True``, so the results remain
    differentiable (e.g. by a loss's ``backward``).

    Returns:
        (u, v, u_t, v_t, u_x, v_x, u_xx, v_xx)
    """
    for leaf in (x, t):
        leaf.requires_grad_(True)
    u, v = model(x, t)

    def d(out, wrt):
        # Gradient of sum(out) w.r.t. `wrt`. This equals the per-point
        # derivative when each output row depends only on its own input row.
        return torch.autograd.grad(out, wrt, torch.ones_like(out), create_graph=True)[0]

    u_t = d(u, t)
    v_t = d(v, t)
    u_x = d(u, x)
    v_x = d(v, x)
    u_xx = d(u_x, x)
    v_xx = d(v_x, x)
    return u, v, u_t, v_t, u_x, v_x, u_xx, v_xx

# ----- losses -----
def pde_residual_mse(model, x_c, t_c):
    """Mean squared NLS residual at the collocation points (x_c, t_c).

    With h = u + i v, the two residuals are the imaginary part and the
    negated real part of  i h_t + 0.5 h_xx + |h|^2 h.
    """
    u, v, u_t, v_t, _, _, u_xx, v_xx = derivatives(model, x_c, t_c)
    mod2 = u*u + v*v                      # |h|^2
    res_im = u_t + 0.5*v_xx + v*mod2
    res_re = v_t - 0.5*u_xx - u*mod2
    return (res_im.pow(2) + res_re.pow(2)).mean()

def ic_mse(model, x0):
    """MSE between the model at t = 0 and the initial condition ``ic_uv(x0)``."""
    u_pred, v_pred = model(x0, torch.zeros_like(x0))
    u_true, v_true = ic_uv(x0)
    sq_err = (u_pred - u_true)**2 + (v_pred - v_true)**2
    return sq_err.mean()

def periodic_bc_mse(model, tb):
    """Periodic-boundary penalty between x = X_MIN and x = X_MAX at times ``tb``.

    Sums the mean squared mismatch of (u, v) and of (u_x, v_x) across the two
    edges. Slopes are built with ``create_graph=True`` so the penalty can be
    back-propagated.
    """
    def edge(x_edge):
        xe = torch.full_like(tb, x_edge, requires_grad=True)
        ue, ve = model(xe, tb)
        ue_x = torch.autograd.grad(ue, xe, torch.ones_like(ue), create_graph=True)[0]
        ve_x = torch.autograd.grad(ve, xe, torch.ones_like(ve), create_graph=True)[0]
        return ue, ve, ue_x, ve_x

    u_lo, v_lo, u_lo_x, v_lo_x = edge(X_MIN)
    u_hi, v_hi, u_hi_x, v_hi_x = edge(X_MAX)
    value_term = torch.mean((u_lo - u_hi)**2 + (v_lo - v_hi)**2)
    slope_term = torch.mean((u_lo_x - u_hi_x)**2 + (v_lo_x - v_hi_x)**2)
    return value_term + slope_term

# ----- metrics -----
def rel_l2(a, b, eps=1e-12):
    """Relative RMS error ||a - b|| / ||b|| as a Python float.

    ``eps`` is added to the mean square of ``b`` under the square root, so an
    all-zero reference gives a large but finite ratio.
    """
    diff_rms = (a - b).pow(2).mean().sqrt()
    ref_rms = (b.pow(2).mean() + eps).sqrt()
    ratio = diff_rms / ref_rms
    return ratio.cpu().item()

@torch.no_grad()
def val_metrics(model, nc=20000, nuv=4096, nb=200, T=T_MAX):
    """Return a dict of validation diagnostics computed on freshly sampled points.

    All points are uniform on [X_MIN, X_MAX] x [T_MIN, T]. The function runs
    under ``torch.no_grad`` and re-enables autograd only where derivatives
    are needed (PDE residual and boundary slopes).

    NOTE(review): sampling uses the global torch RNG, so calling this during
    training shifts the random stream used for training points.
    NOTE(review): "val/loss_mse" adds the u/v MSEs to the residual *RMS*
    (not an MSE); confirm that mix is intended.
    """
    def _uniform(n, lo, hi):
        # n samples as an [n, 1] column on `device`, uniform on [lo, hi)
        return (torch.rand(n, 1) * (hi - lo) + lo).to(device)

    # --- NLS residual RMS (needs second derivatives) ---
    x_res = _uniform(nc, X_MIN, X_MAX)
    t_res = _uniform(nc, T_MIN, T)
    with torch.enable_grad():
        u, v, u_t, v_t, _, _, u_xx, v_xx = derivatives(model, x_res, t_res)
        mod2 = u * u + v * v
        res_im = u_t + 0.5 * v_xx + v * mod2
        res_re = v_t - 0.5 * u_xx - u * mod2
        h_rms = (res_im ** 2 + res_re ** 2).mean().sqrt().detach().cpu().item()

    # --- pointwise error against the analytic soliton ---
    x_val = _uniform(nuv, X_MIN, X_MAX)
    t_val = _uniform(nuv, T_MIN, T)
    u_pred, v_pred = model(x_val, t_val)
    u_ref, v_ref = analytic_uv(x_val, t_val)
    u_mse = ((u_pred - u_ref) ** 2).mean().cpu().item()
    v_mse = ((v_pred - v_ref) ** 2).mean().cpu().item()
    u_rel = rel_l2(u_pred, u_ref)
    v_rel = rel_l2(v_pred, v_ref)

    # --- periodic boundary mismatch (needs first derivatives) ---
    t_bnd = _uniform(nb, T_MIN, T)
    with torch.enable_grad():
        bc = periodic_bc_mse(model, t_bnd).cpu().item()

    return {
        "val/error_h_rms": h_rms,
        "val/error_u_mse": u_mse,
        "val/error_v_mse": v_mse,
        "val/error_u_rel": u_rel,
        "val/error_v_rel": v_rel,
        "val/loss_mse": u_mse + v_mse + h_rms,
        "val/loss_rel": 0.5 * (u_rel + v_rel),
        "val/bc": bc,
    }

# ----- train -----
def train(neurons=200, epochs=10000, lr=2e-4, n_c=20000, n_i=2048, n_b=200,
          log_every=1000, grad_clip=1.0):
    """Train a ``PINN_CauchyNet`` on the NLS problem and return it.

    Every epoch resamples uniform points on [X_MIN, X_MAX] x [T_MIN, T_MAX]
    and takes one Adam step on  PDE-residual MSE + initial-condition MSE +
    periodic-BC MSE. The learning rate is reduced with ReduceLROnPlateau on
    the training loss (factor 0.5, patience 400, min_lr 1e-5).

    Args:
        neurons: number of Cauchy terms in the model.
        epochs: number of optimisation steps (one per epoch).
        lr: initial Adam learning rate.
        n_c: collocation points per epoch for the PDE residual.
        n_i: points per epoch on t = 0 for the initial condition.
        n_b: boundary times per epoch for the periodic BC.
        log_every: print validation metrics at epoch 1 and every
            ``log_every`` epochs; 0 or None disables logging.
        grad_clip: max total gradient norm for ``clip_grad_norm_``;
            None disables clipping.

    Returns:
        The trained model (on ``device``).
    """
    # Pass the module-level domain explicitly so the model's stored bounds
    # always match the domain the training points are sampled from.
    model = PINN_CauchyNet(neurons=neurons, X_MIN=X_MIN, X_MAX=X_MAX,
                           T_MIN=T_MIN, T_MAX=T_MAX).to(device)
    opt = optim.Adam(model.parameters(), lr=lr)
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=400, min_lr=1e-5)

    for ep in range(1, epochs+1):
        # resample fresh uniform points every epoch
        x_c = (torch.rand(n_c,1)*(X_MAX-X_MIN)+X_MIN).to(device)
        t_c = (torch.rand(n_c,1)*(T_MAX - T_MIN)+T_MIN).to(device)
        x0  = (torch.rand(n_i,1)*(X_MAX-X_MIN)+X_MIN).to(device)
        tb  = (torch.rand(n_b,1)*(T_MAX - T_MIN)+T_MIN).to(device)

        opt.zero_grad()
        loss_pde = pde_residual_mse(model, x_c, t_c)
        loss_ic  = ic_mse(model, x0)
        loss_bc  = periodic_bc_mse(model, tb)
        loss = loss_pde + loss_ic + loss_bc
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        opt.step()
        # The plateau scheduler monitors the (noisy, resampled) training loss.
        sched.step(loss.detach())

        # val_metrics draws from the global RNG, so the logging cadence also
        # affects which training points later epochs sample.
        if log_every and (ep % log_every == 0 or ep == 1):
            vm = val_metrics(model)
            print(f"Epoch {ep:5d}/{epochs} | train={loss.item():.3e} "
                  f"| hRMS={vm['val/error_h_rms']:.3e} u_mse={vm['val/error_u_mse']:.3e} v_mse={vm['val/error_v_mse']:.3e} "
                  f"| u_rel={vm['val/error_u_rel']:.3e} v_rel={vm['val/error_v_rel']:.3e} "
                  f"| loss_mse={vm['val/loss_mse']:.3e} loss_rel={vm['val/loss_rel']:.3e} | bc={vm['val/bc']:.3e}")
    return model

# ----- main -----
if __name__ == "__main__":
    # Hyper-parameters for the reference run.
    config = dict(neurons=400, epochs=10000, lr=1e-3,
                  n_c=20000, n_i=2048, n_b=200)
    model = train(**config)

    vm = val_metrics(model)
    print("\n=== Final Validation ===")
    for metric_name, metric_value in vm.items():
        print(f"{metric_name}: {metric_value:.12e}")
