import math, time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# ----- setup -----
def set_seed(s=0):
    """Seed PyTorch's CPU generator and the generators of all CUDA devices with ``s``."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(s)
set_seed(0)
# From here on, newly created floating-point tensors and nn parameters default to float64.
torch.set_default_dtype(torch.float64)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# ----- problem -----
L = 5.0          # half-width of the spatial domain
X_MIN = -L       # left spatial edge
X_MAX = L        # right spatial edge
T_MIN = 0.0      # initial time
T_MAX = 2.0      # final time
def sech(x):
    """Elementwise hyperbolic secant, 1 / cosh(x)."""
    denom = torch.cosh(x)
    return 1.0 / denom

def analytic_uv(x, t, eta=2.0):
    """Exact bright soliton of the focusing cubic NLS, split into real/imag parts.

    psi(x, t) = eta * sech(eta * x) * exp(i * eta**2 * t / 2), which satisfies
    i psi_t + 0.5 psi_xx + |psi|^2 psi = 0 for any eta.

    Returns:
        (Re psi, Im psi), each with the broadcast shape of x and t.
    """
    envelope = eta * (1.0 / torch.cosh(eta * x))   # eta * sech(eta * x)
    omega_t = 0.5 * eta * eta * t                  # phase eta^2 t / 2
    return envelope * torch.cos(omega_t), envelope * torch.sin(omega_t)

def ic_uv(x, eta=2.0):
    """Initial condition psi(x, 0) = eta * sech(eta * x) as (real, imaginary) parts.

    The imaginary part is identically zero; this matches analytic_uv at t = 0.
    """
    real = eta * (1.0 / torch.cosh(eta * x))
    imag = torch.zeros_like(x)
    return real, imag

class SchrodingerPINN(nn.Module):
    """Fully connected tanh network approximating psi(x, t) = u + i*v.

    The PDE is not encoded in the network itself; the training loss in this
    script (pde_residual_mse) drives it toward the non-dimensional focusing
    cubic nonlinear Schrodinger equation

        i * dpsi/dt + 1/2 * d2psi/dx2 + |psi|^2 * psi = 0,

    which, with psi = u + i*v, splits into the two real residuals

        u_t + 1/2 * v_xx + (u^2 + v^2) * v = 0
        v_t - 1/2 * u_xx - (u^2 + v^2) * u = 0.

    Args:
        layers: sequence of layer widths, e.g. [2, 100, 100, 2]. The first entry
            must be 2 (inputs x and t); the last must be at least 2 (outputs
            u and v; any extra output columns are ignored by forward).

    Raises:
        ValueError: if ``layers`` cannot describe such a network.
    """

    def __init__(self, layers):
        super().__init__()
        layers = list(layers)
        if len(layers) < 2:
            raise ValueError(f"layers needs at least input and output sizes, got {layers}")
        if layers[0] != 2:
            raise ValueError(f"first layer width must be 2 (x, t), got {layers[0]}")
        if layers[-1] < 2:
            raise ValueError(f"last layer width must be >= 2 (u, v), got {layers[-1]}")

        # All Linear layers are created first, then re-initialised, so the RNG
        # draw order is the same as building them one by one in a loop.
        self.layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(layers[:-1], layers[1:])
        )
        for layer in self.layers:
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x, t):
        """Map column tensors x and t (one column each) to (u, v), each (N, 1).

        Both inputs are affinely rescaled from [X_MIN, X_MAX] and [T_MIN, T_MAX]
        to [-1, 1] before entering the tanh layers.
        """
        x_n = 2.0*(x - 0.5*(X_MIN+X_MAX)) / (X_MAX - X_MIN)
        t_n = 2.0*(t - 0.5*(T_MIN+T_MAX)) / (T_MAX - T_MIN)
        h = torch.cat([x_n, t_n], dim=1)

        for layer in self.layers[:-1]:
            h = torch.tanh(layer(h))
        out = self.layers[-1](h)  # linear output layer (no activation)
        u, v = out[:, :1], out[:, 1:2]
        return u, v
        


# ----- derivatives -----
def derivatives(model, x, t):
    """Evaluate the model and the partial derivatives needed by the NLS residual.

    ``x`` and ``t`` are switched to ``requires_grad=True`` in place. Every
    derivative is built with ``create_graph=True`` so that it can itself be
    differentiated, both for the second x-derivatives here and for
    backpropagating a loss into the model parameters.

    Gradients are seeded with ones, so each row receives its own pointwise
    derivative only if the model acts row-wise (true for SchrodingerPINN).

    Returns:
        (u, v, u_t, v_t, u_x, v_x, u_xx, v_xx)
    """
    x.requires_grad_(True); t.requires_grad_(True)
    u, v = model(x, t)
    # One reverse pass per output yields both d/dx and d/dt (4 passes instead of 6).
    u_x, u_t = torch.autograd.grad(u, (x, t), torch.ones_like(u), create_graph=True)
    v_x, v_t = torch.autograd.grad(v, (x, t), torch.ones_like(v), create_graph=True)
    u_xx = torch.autograd.grad(u_x, x, torch.ones_like(u_x), create_graph=True)[0]
    v_xx = torch.autograd.grad(v_x, x, torch.ones_like(v_x), create_graph=True)[0]
    return u, v, u_t, v_t, u_x, v_x, u_xx, v_xx

# ----- losses -----
def pde_residual_mse(model, x_c, t_c):
    """Mean squared NLS residual at the collocation points (x_c, t_c).

    With psi = u + i*v, the equation i psi_t + 1/2 psi_xx + |psi|^2 psi = 0
    splits into
        imaginary part:   u_t + 1/2 v_xx + |psi|^2 v
        real part:      -(v_t - 1/2 u_xx - |psi|^2 u)
    and the loss is the batch mean of (imag^2 + real^2).
    """
    u, v, u_t, v_t, _u_x, _v_x, u_xx, v_xx = derivatives(model, x_c, t_c)
    res_imag = u_t + 0.5 * v_xx + v * (u * u + v * v)
    res_real = v_t - 0.5 * u_xx - u * (u * u + v * v)  # sign-flipped; squared below
    return (res_imag ** 2 + res_real ** 2).mean()

def ic_mse(model, x0):
    """Mean squared mismatch between the model at t = 0 and the initial condition ic_uv."""
    zero_t = torch.zeros_like(x0)
    u_pred, v_pred = model(x0, zero_t)
    u_ref, v_ref = ic_uv(x0)
    sq_err = (u_pred - u_ref) ** 2 + (v_pred - v_ref) ** 2
    return sq_err.mean()

def periodic_bc_mse(model, tb):
    """Periodicity penalty between the spatial edges x = X_MIN and x = X_MAX.

    Args:
        model: callable (x, t) -> (u, v).
        tb: boundary time samples; the edge x-tensors are built with the same shape.

    Returns:
        Scalar tensor: mean squared mismatch of (u, v) across the two edges plus
        mean squared mismatch of (u_x, v_x).
    """
    x_left = torch.full_like(tb, X_MIN, requires_grad=True)
    x_right = torch.full_like(tb, X_MAX, requires_grad=True)
    u_left, v_left = model(x_left, tb)
    u_right, v_right = model(x_right, tb)

    def d_dx(out, wrt):
        # Ones-seeded gradient gives per-row derivatives when the model acts
        # row-wise; create_graph keeps it differentiable for training.
        return torch.autograd.grad(out, wrt, torch.ones_like(out), create_graph=True)[0]

    du_left = d_dx(u_left, x_left)
    dv_left = d_dx(v_left, x_left)
    du_right = d_dx(u_right, x_right)
    dv_right = d_dx(v_right, x_right)

    value_gap = torch.mean((u_left - u_right) ** 2 + (v_left - v_right) ** 2)
    slope_gap = torch.mean((du_left - du_right) ** 2 + (dv_left - dv_right) ** 2)
    return value_gap + slope_gap

# ----- metrics -----
def rel_l2(a, b, eps=1e-12):
    """Relative RMS error rms(a - b) / rms(b) as a Python float.

    ``eps`` is added under the reference's square root so that b == 0 does not
    divide by zero.
    """
    diff_rms = (a - b).pow(2).mean().sqrt()
    ref_rms = (b.pow(2).mean() + eps).sqrt()
    return (diff_rms / ref_rms).cpu().item()

@torch.no_grad()
def val_metrics(model, nc=20000, nuv=4096, nb=200, T=T_MAX, generator=None):
    """Evaluate the model on fresh uniform samples and return a dict of floats.

    Args:
        model: network mapping (x, t) column tensors to (u, v).
        nc: collocation points for the PDE-residual RMS.
        nuv: points compared against the analytic soliton.
        nb: boundary time samples for the periodicity check.
        T: upper end of the sampled time window [T_MIN, T].
        generator: optional CPU ``torch.Generator`` for the samples. With the
            default None the global RNG is used (original behaviour); passing a
            dedicated generator keeps validation from shifting the training
            sample stream.

    Returns:
        dict with keys
          "val/error_h_rms"  sqrt of pde_residual_mse on the collocation points,
          "val/error_u_mse", "val/error_v_mse"  MSE of u / v vs analytic_uv,
          "val/error_u_rel", "val/error_v_rel"  rel_l2 of u / v vs analytic_uv,
          "val/loss_mse"     error_u_mse + error_v_mse + error_h_rms,
          "val/loss_rel"     mean of the two relative errors,
          "val/bc"           periodic_bc_mse on the boundary samples.
    """
    def draw(n, lo, hi):
        # Uniform [lo, hi) column drawn on CPU, then moved to the model's device.
        return (torch.rand(n, 1, generator=generator) * (hi - lo) + lo).to(device)

    # PDE residual RMS. Autograd is re-enabled locally because the residual
    # needs d/dt and d2/dx2 of the network output. Reusing the training loss
    # keeps the validated PDE identical to the trained one.
    x_c = draw(nc, X_MIN, X_MAX)
    t_c = draw(nc, T_MIN, T)
    with torch.enable_grad():
        err_h_rms = torch.sqrt(pde_residual_mse(model, x_c, t_c)).item()

    # u / v against the analytic soliton.
    x = draw(nuv, X_MIN, X_MAX)
    t = draw(nuv, T_MIN, T)
    up, vp = model(x, t)
    ut, vt = analytic_uv(x, t)
    err_u_mse = torch.mean((up - ut) ** 2).item()
    err_v_mse = torch.mean((vp - vt) ** 2).item()
    err_u_rel = rel_l2(up, ut)
    err_v_rel = rel_l2(vp, vt)

    # Periodic boundary mismatch (needs d/dx, hence autograd).
    tb = draw(nb, T_MIN, T)
    with torch.enable_grad():
        err_bc = periodic_bc_mse(model, tb).item()

    return {
        "val/error_h_rms": err_h_rms,
        "val/error_u_mse": err_u_mse,
        "val/error_v_mse": err_v_mse,
        "val/error_u_rel": err_u_rel,
        "val/error_v_rel": err_v_rel,
        # NOTE(review): this sums two MSEs with an RMS (error_h_rms), so the
        # terms are not on the same scale; kept as-is for comparability with
        # earlier logs.
        "val/loss_mse": err_u_mse + err_v_mse + err_h_rms,
        "val/loss_rel": 0.5*(err_u_rel + err_v_rel),
        "val/bc": err_bc,
    }

# ----- train -----
def train(neurons=200, epochs=10000, lr=2e-4, n_c=20000, n_i=2048, n_b=200,
          depth=4, log_every=1000):
    """Train a SchrodingerPINN on the NLS soliton problem and return it.

    Each epoch draws fresh uniform samples and takes one Adam step on the sum
    PDE residual + initial condition + periodic boundary loss. The per-epoch
    training loss is fed to a ReduceLROnPlateau scheduler.

    Args:
        neurons: width of every hidden layer.
        epochs: number of optimisation steps (one per epoch).
        lr: initial Adam learning rate.
        n_c: collocation points per epoch for the PDE residual.
        n_i: points per epoch on the initial line t = 0.
        n_b: boundary time samples per epoch for the periodicity penalty.
        depth: number of hidden layers (default 4).
        log_every: print validation metrics every this many epochs (and at epoch 1).

    Returns:
        The trained model, on ``device``.

    Raises:
        ValueError: if ``neurons``, ``depth`` or ``log_every`` is < 1.
    """
    if neurons < 1 or depth < 1:
        raise ValueError(f"neurons and depth must be >= 1, got neurons={neurons}, depth={depth}")
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    # Hidden layers are sized by the arguments; before, the width was hard-coded
    # to 100 and `neurons` was silently ignored.
    layers = [2] + [neurons] * depth + [2]
    model = SchrodingerPINN(layers).to(device)
    opt = optim.Adam(model.parameters(), lr=lr)
    # Halve the LR after 400 epochs without improvement, but not below 1e-5.
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=400, min_lr=1e-5)

    for ep in range(1, epochs+1):
        # Resample every epoch; the draw order (x_c, t_c, x0, tb) is kept fixed
        # so runs are reproducible from the seed.
        x_c = (torch.rand(n_c,1)*(X_MAX-X_MIN)+X_MIN).to(device)
        t_c = (torch.rand(n_c,1)*(T_MAX - T_MIN)+T_MIN).to(device)
        x0  = (torch.rand(n_i,1)*(X_MAX-X_MIN)+X_MIN).to(device)
        tb  = (torch.rand(n_b,1)*(T_MAX - T_MIN)+T_MIN).to(device)

        opt.zero_grad()
        loss_pde = pde_residual_mse(model, x_c, t_c)
        loss_ic  = ic_mse(model, x0)
        loss_bc  = periodic_bc_mse(model, tb)
        loss = loss_pde + loss_ic + loss_bc
        loss.backward()
        # Cap the global gradient norm at 1.0 before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        sched.step(loss.detach())

        if ep % log_every == 0 or ep == 1:
            vm = val_metrics(model)
            print(f"Epoch {ep:5d}/{epochs} | train={loss.item():.3e} "
                  f"| hRMS={vm['val/error_h_rms']:.3e} u_mse={vm['val/error_u_mse']:.3e} v_mse={vm['val/error_v_mse']:.3e} "
                  f"| u_rel={vm['val/error_u_rel']:.3e} v_rel={vm['val/error_v_rel']:.3e} "
                  f"| loss_mse={vm['val/loss_mse']:.3e} loss_rel={vm['val/loss_rel']:.3e} | bc={vm['val/bc']:.3e}")
    return model

# ----- main -----
if __name__ == "__main__":
    # Train, then report every validation metric at full precision.
    trained = train(
        neurons=400,
        epochs=10000,
        lr=1e-3,
        n_c=20000,
        n_i=2048,
        n_b=200,
    )

    final_metrics = val_metrics(trained)
    print("\n=== Final Validation ===")
    for metric_name, metric_value in final_metrics.items():
        print(f"{metric_name}: {metric_value:.12e}")
