import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import time
import math
import os

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed NumPy and PyTorch (CPU, plus every CUDA device when CUDA is available)
# with the same fixed value.
np.random.seed(1234)
torch.manual_seed(1234)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(1234)

# Lower/upper bounds of the sampling box in x, y and t.
X_MIN, X_MAX = 0.0, 1.0
Y_MIN, Y_MAX = 0.0, 1.0
T_MIN, T_MAX = 0.0, 1.0

def true_u_torch(coords: torch.Tensor, c: float) -> torch.Tensor:
    """Evaluate the reference field sin(pi x) sin(pi y) cos(sqrt(2) c pi t).

    Args:
        coords: 2-D tensor whose columns 0, 1 and 2 are read as t, x and y.
        c: Wave speed entering the temporal frequency.

    Returns:
        Column tensor (one row per input row) with the reference values.
    """
    t, x, y = coords[:, 0:1], coords[:, 1:2], coords[:, 2:3]
    # Temporal angular frequency of the standing wave.
    omega = c * math.pi * math.sqrt(2.0)
    spatial_mode = torch.sin(math.pi * x) * torch.sin(math.pi * y)
    return spatial_mode * torch.cos(omega * t)

def sample_interior(N, device=device) -> torch.Tensor:
    """Draw N uniformly random points from the (t, x, y) box.

    Args:
        N: Number of points to draw.
        device: Device on which the samples are created.

    Returns:
        Tensor of shape (N, 3) with columns ordered (t, x, y).
    """
    # The draw order (t, then x, then y) is kept so the random stream matches
    # a column-by-column sampling.
    axis_bounds = ((T_MIN, T_MAX), (X_MIN, X_MAX), (Y_MIN, Y_MAX))
    columns = [
        torch.rand(N, 1, device=device) * (hi - lo) + lo
        for lo, hi in axis_bounds
    ]
    return torch.cat(columns, dim=-1)

def sample_ic(N: int, device=device) -> torch.Tensor:
    """Draw N random points on the initial-time slice t = 0.

    Args:
        N: Number of points to draw.
        device: Device on which the samples are created.

    Returns:
        Tensor of shape (N, 3) with columns (t, x, y); the t column is all
        zeros and x, y are uniform within their bounds.
    """
    t_zero = torch.zeros(N, 1, device=device)
    # x is drawn before y, one column at a time.
    x_col, y_col = [
        torch.rand(N, 1, device=device) * (hi - lo) + lo
        for lo, hi in ((X_MIN, X_MAX), (Y_MIN, Y_MAX))
    ]
    return torch.cat([t_zero, x_col, y_col], dim=-1)

def sample_bc(N: int, device=device) -> torch.Tensor:
    """Draw random points on the four edges of the spatial square.

    One set of N random times is shared by all four edges; one set of random
    y values is shared by the two x-edges and one set of random x values by
    the two y-edges.

    Args:
        N: Number of points per edge.
        device: Device on which the samples are created.

    Returns:
        Tensor of shape (4 * N, 3) with columns (t, x, y), stacked in the
        order x = X_MIN, x = X_MAX, y = Y_MIN, y = Y_MAX.
    """
    def _uniform(lo, hi):
        # One column of N uniform samples in [lo, hi).
        return torch.rand(N, 1, device=device) * (hi - lo) + lo

    # Draw order: t, then the free y coordinate, then the free x coordinate.
    t = _uniform(T_MIN, T_MAX)
    y_free = _uniform(Y_MIN, Y_MAX)
    x_free = _uniform(X_MIN, X_MAX)

    edge_xy = (
        (torch.full_like(t, X_MIN), y_free),
        (torch.full_like(t, X_MAX), y_free),
        (x_free, torch.full_like(t, Y_MIN)),
        (x_free, torch.full_like(t, Y_MAX)),
    )
    edges = [torch.cat([t, x, y], dim=-1) for x, y in edge_xy]
    return torch.cat(edges, dim=0)

class MLP(nn.Module):
    """Fully connected tanh network with a linear output head.

    Args:
        in_features: Size of the input vector.
        hidden: Width of every hidden layer.
        num_experts: Size of the output vector.
        depth: Number of (Linear, Tanh) hidden pairs before the head.
    """

    def __init__(self, in_features=3, hidden=64, num_experts=3, depth=4):
        super().__init__()
        # Layer widths from the input through every hidden layer.
        widths = [in_features] + [hidden] * depth
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.Tanh())
        modules.append(nn.Linear(widths[-1], num_experts))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.net(x)

class NAMLayer(nn.Module):
    """Tanh MLP mapping a 2-feature input to ``r`` output features.

    Every linear layer is initialised with Xavier-uniform weights and zero
    biases.

    Args:
        hidden: Width of the hidden layers.
        depth: Number of (Linear, Tanh) pairs before the output layer.
        r: Number of output features.
    """

    def __init__(self, hidden=64, depth=2, r=16):
        super().__init__()
        stack = [nn.Linear(2, hidden), nn.Tanh()]
        for _ in range(depth - 1):
            stack.extend((nn.Linear(hidden, hidden), nn.Tanh()))
        stack.append(nn.Linear(hidden, r))
        self.net = nn.Sequential(*stack)

        # Re-initialise the linear layers in order, front to back.
        linear_layers = [mod for mod in self.net if isinstance(mod, nn.Linear)]
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, scalar_with_id):
        """Apply the network to the 2-feature input."""
        return self.net(scalar_with_id)

class NAMExpert(nn.Module):
    """Rank-``r`` separable model: u = sum_k f_k(t) * f_k(x) * f_k(y).

    A single shared ``NAMLayer`` produces the ``r`` features of each
    coordinate; the coordinate is identified by an extra input column holding
    its index (0 for t, 1 for x, 2 for y).

    Args:
        hidden: Hidden width passed to the shared ``NAMLayer``.
        r: Number of separable terms.
    """

    def __init__(self, hidden=32, r=16):
        super().__init__()
        self.r = r
        self.shared = NAMLayer(hidden, r=r)

    def _eval_dim(self, coord: torch.Tensor, dim_id: int) -> torch.Tensor:
        """Return the shared network's features for one coordinate column.

        Args:
            coord: Coordinate values; the id column is appended on the last axis.
            dim_id: Index tagging which coordinate ``coord`` holds.
        """
        id_column = torch.full_like(coord, float(dim_id))
        tagged = torch.cat([coord, id_column], dim=-1)  # (B,1) -> (B,2)
        return self.shared(tagged)

    def forward(self, coords):
        """Evaluate u for rows of (t, x, y); returns a (B, 1) tensor."""
        # Features of t, x and y, evaluated in that order.
        ft, fx, fy = [
            self._eval_dim(coords[:, k:k + 1], k) for k in range(3)
        ]
        # Multiply per rank, then sum the rank axis while keeping it as size 1.
        return (ft * fx * fy).sum(dim=-1, keepdim=True)

class DomainMoE(nn.Module):
    """Container of ``NAMExpert`` models whose output is that of expert 0.

    No gating is implemented: the forward pass returns the first expert's
    prediction together with a placeholder gate value.

    Args:
        in_features: Accepted for interface compatibility; not used here.
        num_experts: Number of experts to instantiate.
        expert_hidden: Hidden width of each expert.
        expert_rank: Number of separable terms of each expert.
    """

    def __init__(self, in_features=3, num_experts=3,
                 expert_hidden=32, expert_rank=16):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList([
            NAMExpert(hidden=expert_hidden, r=expert_rank) for _ in range(num_experts)
        ])

    def forward(self, x):
        """Return ``(u_pred, gates)`` for the input rows ``x``.

        ``u_pred`` is the output of ``self.experts[0]``; ``gates`` is the
        integer placeholder 0.
        """
        gates = 0  # placeholder: no gating weights are computed
        # Only expert 0 contributes to the result, so the remaining experts
        # are not evaluated (their outputs were previously computed and
        # discarded, costing a forward pass and autograd graph each).
        u_pred = self.experts[0](x)
        return u_pred, gates


def _center_unit(M: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    Mc = M - M.mean(axis=0, keepdims=True)
    n  = np.linalg.norm(Mc, axis=0, keepdims=True)
    n  = np.maximum(n, eps)
    return Mc / n

def metricA_subspace_score(F: np.ndarray, G: np.ndarray) -> tuple[float, np.ndarray]:
    """Score how well the column spaces of ``F`` and ``G`` align.

    Both matrices are column-centred and unit-normalised, orthonormalised by
    a reduced QR factorisation, and the singular values of ``QF.T @ QG`` are
    taken.

    Args:
        F: 2-D array; columns are the first set of sampled functions.
        G: 2-D array with the same number of rows as ``F``.

    Returns:
        ``(S, svals)`` where ``svals`` are the singular values and ``S`` is
        the mean of their squares.
    """
    bases = []
    for mat in (F, G):
        q_mat, _ = np.linalg.qr(_center_unit(mat), mode='reduced')
        bases.append(q_mat)
    basis_f, basis_g = bases

    cosines = np.linalg.svd(basis_f.T @ basis_g, compute_uv=False)
    score = float(np.mean(np.square(cosines)))
    return score, cosines

@torch.no_grad()
def visualize_expert_curves(model, expert_idx: int = 0,
                            res: int = 256, f_const: float = 0.0,
                            outdir: str = "expert_vis",
                            c: float = 2.0):
    """Print how well one expert's 1-D features span the reference modes.

    The expert's per-coordinate features are sampled on a uniform grid over
    [0, 1] and compared, via ``metricA_subspace_score``, against the factors
    of the reference solution sin(pi x) sin(pi y) cos(sqrt(2) c pi t).
    The average of the three scores is printed.

    Args:
        model: Object exposing an indexable ``experts`` collection whose
            items provide ``_eval_dim(coord, dim_id)``.
        expert_idx: Index of the expert to inspect. The expert is switched
            to eval mode.
        res: Number of grid points per coordinate.
        f_const: Unused; kept so existing callers keep working.
        outdir: Directory that is created if missing. No files are written
            to it at present (curve plotting is disabled).
        c: Wave speed used for the temporal reference mode. Defaults to 2.0,
            the value that was previously hard-coded.
    """
    os.makedirs(outdir, exist_ok=True)
    exp = model.experts[expert_idx].eval()
    dev = next(exp.parameters()).device

    # t, x and y are all sampled on the same uniform grid.
    grid = torch.linspace(0., 1., res, device=dev).unsqueeze(-1)  # (res, 1)

    ft_vals = exp._eval_dim(grid, 0).cpu().numpy()  # (res, r)
    fx_vals = exp._eval_dim(grid, 1).cpu().numpy()
    fy_vals = exp._eval_dim(grid, 2).cpu().numpy()

    axis = np.linspace(0, 1, res)
    target_x = np.sin(np.pi * axis)
    target_y = np.sin(np.pi * axis)
    # The temporal factor of the reference solution has angular frequency
    # sqrt(2) * c * pi; omitting the sqrt(2) would score the features against
    # a mode the solution does not contain.
    target_t = np.cos(c * np.pi * np.sqrt(2.0) * axis)

    # Each target is a single column, so each score is based on one
    # singular value.
    Sx, _ = metricA_subspace_score(fx_vals, target_x[:, None])
    St, _ = metricA_subspace_score(ft_vals, target_t[:, None])
    Sy, _ = metricA_subspace_score(fy_vals, target_y[:, None])
    S_all = (Sx + St + Sy) / 3.0

    print(f"[metric-A] S_avg={S_all:.4f}")
    
class PINN_Wave2D:
    """Physics-informed trainer for the 2-D wave equation u_tt = c^2 (u_xx + u_yy).

    The network is fitted to the PDE residual, the initial displacement
    sin(pi x) sin(pi y), a zero initial velocity and zero boundary values,
    first with Adam and then with L-BFGS.
    """

    def __init__(self,
                 c=2.0,
                 num_experts=2,
                 expert_hidden=32,
                 expert_rank=16):
        """Build the model and its optimizers.

        Args:
            c: Wave speed of the PDE.
            num_experts: Number of experts in the ``DomainMoE`` model.
            expert_hidden: Hidden width of each expert.
            expert_rank: Number of separable terms of each expert.
        """
        self.c = c

        # Weights of the PDE, initial-displacement, initial-velocity and
        # boundary terms in the total loss.
        self.w_f = 1.0
        self.w_ic_u  = 10.0
        self.w_ic_ut = 10.0
        self.w_bc = 10.0

        # Sample counts passed to the samplers when loss_func draws its own
        # points.
        self.N_interior = 8192
        self.N_ic = 4096
        self.N_bc = 4096

        self.pinn = DomainMoE(in_features=3,
                              num_experts=num_experts,
                              expert_hidden=expert_hidden,
                              expert_rank=expert_rank).to(device)

        self.optimizer_adam = torch.optim.Adam(self.pinn.parameters(), lr=5e-4)
        # NOTE(review): T_max is 20000 while train() defaults to 10000 Adam
        # epochs, so by default only half a cosine period is traversed --
        # confirm this is intended.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_adam, T_max=20000, eta_min=1e-6)
        self.optimizer_lbfgs = torch.optim.LBFGS(
            self.pinn.parameters(), max_iter=20000,
            tolerance_grad=1e-9, tolerance_change=1e-12,
            history_size=100, line_search_fn="strong_wolfe"
        )

    @staticmethod
    def _autograd_grads(u, coords):
        """Return first and unmixed second derivatives of ``u`` w.r.t. ``coords``.

        Args:
            u: Output computed from ``coords`` with autograd tracking.
            coords: Input whose columns 0, 1, 2 are t, x, y; must require grad.

        Returns:
            Tuple ``(u_t, u_x, u_y, u_tt, u_xx, u_yy)`` of column tensors, all
            created with ``create_graph=True`` so they can be differentiated
            again.
        """
        grad_outputs = torch.ones_like(u)
        du_dcoords = torch.autograd.grad(u, coords, grad_outputs=grad_outputs, create_graph=True)[0]  # (B,3)
        u_t = du_dcoords[:, 0:1]
        u_x = du_dcoords[:, 1:2]
        u_y = du_dcoords[:, 2:3]
        # Differentiate each first derivative again and keep only the column
        # of the matching coordinate.
        u_tt = torch.autograd.grad(u_t, coords, grad_outputs=grad_outputs, create_graph=True)[0][:, 0:1]
        u_xx = torch.autograd.grad(u_x, coords, grad_outputs=grad_outputs, create_graph=True)[0][:, 1:2]
        u_yy = torch.autograd.grad(u_y, coords, grad_outputs=grad_outputs, create_graph=True)[0][:, 2:3]
        return u_t, u_x, u_y, u_tt, u_xx, u_yy

    def pde_residual(self, coords):
        """Return ``(residual, u_pred)`` of the wave equation at ``coords``.

        The residual is u_tt - c^2 (u_xx + u_yy), evaluated on a detached copy
        of ``coords`` that requires grad.
        """
        coords_grad = coords.clone().detach().requires_grad_(True)
        u_pred, _ = self.pinn(coords_grad)
        _, _, _, u_tt, u_xx, u_yy = self._autograd_grads(u_pred, coords_grad)
        residual = u_tt - (self.c**2) * (u_xx + u_yy)
        return residual, u_pred

    def loss_func(self, coords_f=None, coords_ic=None, coords_bc=None):
        """Compute the weighted training loss and its components.

        Any point set left as ``None`` is freshly sampled. The model is put in
        training mode.

        Args:
            coords_f: Interior points for the PDE residual.
            coords_ic: Points at the initial time.
            coords_bc: Points on the spatial boundary.

        Returns:
            Tuple ``(loss, loss_f, loss_ic_u, loss_ic_ut, loss_bc, error)``
            where ``error`` is the relative L2 error against the reference
            solution on ``coords_f`` (a monitoring value without gradient).
        """
        if coords_f is None:
            coords_f = sample_interior(self.N_interior, device=device)
        if coords_ic is None:
            coords_ic = sample_ic(self.N_ic, device=device)
        if coords_bc is None:
            coords_bc = sample_bc(self.N_bc, device=device)

        self.pinn.train()

        resid, u_for_err_calc = self.pde_residual(coords_f)
        loss_f = F.mse_loss(resid, torch.zeros_like(resid))

        # Initial conditions: u = sin(pi x) sin(pi y) and u_t = 0.
        coords_ic_grad = coords_ic.clone().detach().requires_grad_(True)
        u_ic_pred, _ = self.pinn(coords_ic_grad)
        u_t_ic_pred, _, _, _, _, _ = self._autograd_grads(u_ic_pred, coords_ic_grad)

        u_ic_true = torch.sin(math.pi * coords_ic[:, 1:2]) * torch.sin(math.pi * coords_ic[:, 2:3])
        loss_ic_u = F.mse_loss(u_ic_pred, u_ic_true)
        loss_ic_ut = F.mse_loss(u_t_ic_pred, torch.zeros_like(u_t_ic_pred))

        # Boundary condition: u = 0 on the boundary points.
        u_bc_pred, _ = self.pinn(coords_bc)
        loss_bc = F.mse_loss(u_bc_pred, torch.zeros_like(u_bc_pred))

        loss = self.w_f * loss_f + self.w_ic_u * loss_ic_u + self.w_ic_ut * loss_ic_ut + self.w_bc * loss_bc

        # The error is only reported, never back-propagated, so it is computed
        # without extending the autograd graph.
        with torch.no_grad():
            true = true_u_torch(coords_f, self.c)
            error = torch.norm(u_for_err_calc - true, p=2) / torch.norm(true, p=2)

        return loss, loss_f, loss_ic_u, loss_ic_ut, loss_bc, error

    def train(self, n_epochs_adam=10000, viz_every=1000, t_vis=0.5):
        """Run Adam for ``n_epochs_adam`` epochs, then a single L-BFGS solve.

        Args:
            n_epochs_adam: Number of Adam steps; each step resamples its points.
            viz_every: Visualise every this many Adam epochs; a falsy value
                disables visualisation.
            t_vis: Time of the slice plotted during training.
        """
        print("--- Starting Adam Optimization ---")
        for ep in range(1, n_epochs_adam + 1):
            self.optimizer_adam.zero_grad()
            loss, loss_f, loss_ic_u, loss_ic_ut, loss_bc, error = self.loss_func()
            loss.backward()
            self.optimizer_adam.step()
            self.scheduler.step()

            if ep % 1000 == 0:
                print(f"[Adam {ep:05d}] loss={loss.item():.4e} "
                      f"(ic_u={loss_ic_u.item():.2e}, ic_ut={loss_ic_ut.item():.2e}, "
                      f"bc={loss_bc.item():.2e}, err={error.item():.2e})")

            if viz_every and ep % viz_every == 0:
                visualize_expert_curves(model=self.pinn, expert_idx=0, res = 256, f_const = 3.0)
                self.visualize(ep, t_fixed=t_vis)

        print("\n--- Starting L-BFGS Optimization ---")
        self.pinn.train()

        # L-BFGS works on one fixed set of points for the whole solve.
        coords_f_lbfgs = sample_interior(20_000, device=device)
        coords_ic_lbfgs = sample_ic(5_000, device=device)
        coords_bc_lbfgs = sample_bc(5_000, device=device)

        # Counts closure evaluations, which drives the periodic log line.
        self.lbfgs_iter = 0

        def closure():
            self.optimizer_lbfgs.zero_grad()
            loss, *_ = self.loss_func(
                coords_f=coords_f_lbfgs,
                coords_ic=coords_ic_lbfgs,
                coords_bc=coords_bc_lbfgs
            )
            loss.backward()
            self.lbfgs_iter += 1
            if self.lbfgs_iter % 100 == 0:
                with torch.no_grad():
                    u_pred, _ = self.pinn(coords_f_lbfgs)              # forward pass only
                    true = true_u_torch(coords_f_lbfgs, self.c)
                    error = torch.norm(u_pred - true, p=2) / torch.norm(true, p=2)
                print(f'[L-BFGS {self.lbfgs_iter:05d}] loss={loss.item():.4e} (L2_err={error.item():.2e})')

            return loss

        self.optimizer_lbfgs.step(closure)

        # Only the error of the final slice is reported.
        *_, final_error = self.predict_slice_xy(t_fixed=0.5)
        print("\n--- Optimization Finished ---")
        print(f"Final L2 Relative Error (t=0.5 slice): {final_error:.4e}")

    @torch.no_grad()
    def predict_slice_xy(self, t_fixed=0.5, res_x=141, res_y=141):
        """Evaluate the model and the reference solution on an x-y grid.

        The model is put in eval mode.

        Args:
            t_fixed: Time of the slice.
            res_x: Number of grid points along x.
            res_y: Number of grid points along y.

        Returns:
            Tuple ``(X, Y, U, U_true, error)``: the mesh grids, the predicted
            and reference fields (all of shape ``(res_y, res_x)``) and the
            relative L2 error of the prediction on the slice.
        """
        self.pinn.eval()
        x_vals = np.linspace(X_MIN, X_MAX, res_x)
        y_vals = np.linspace(Y_MIN, Y_MAX, res_y)
        X, Y = np.meshgrid(x_vals, y_vals, indexing="xy")  # (res_y,res_x)

        t_col = np.full((res_x * res_y, 1), t_fixed, dtype=np.float32)
        coords = np.concatenate([t_col, X.reshape(-1,1), Y.reshape(-1,1)], axis=-1).astype(np.float32)
        coords_t = torch.from_numpy(coords).to(device)

        u_pred, _ = self.pinn(coords_t)
        U = u_pred.detach().cpu().numpy().reshape(res_y, res_x)

        U_true = (np.sin(np.pi * X) * np.sin(np.pi * Y) * np.cos(self.c * np.pi * np.sqrt(2.0) * t_fixed)).astype(np.float32)
        error = np.linalg.norm(U - U_true) / np.linalg.norm(U_true)

        return X, Y, U, U_true, error

    @torch.no_grad()
    def visualize(self, ep=0, out_dir="wave2d_viz", t_fixed=0.5):
        """Save a three-panel figure (reference, prediction, absolute error).

        Args:
            ep: Epoch number used in the log line and the file name.
            out_dir: Directory for the PNG; created if missing.
            t_fixed: Time of the plotted slice.
        """
        os.makedirs(out_dir, exist_ok=True)
        X, Y, U, U_true, error = self.predict_slice_xy(t_fixed=t_fixed)
        print(f"  [viz @ ep {ep}] L2 Rel Error (t={t_fixed}): {error:.4e}. Saving plots...")

        fig = plt.figure(figsize=(15, 4.5))

        plt.subplot(1, 3, 1)
        plt.pcolormesh(X, Y, U_true, cmap='rainbow', shading='auto', vmin=-1, vmax=1)
        plt.colorbar(); plt.xlabel('x'); plt.ylabel('y'); plt.title(f'True u(x,y) @ t={t_fixed}')

        plt.subplot(1, 3, 2)
        plt.pcolormesh(X, Y, U, cmap='rainbow', shading='auto', vmin=-1, vmax=1)
        plt.colorbar(); plt.xlabel('x'); plt.ylabel('y'); plt.title(f'Predicted u(x,y) @ t={t_fixed}')

        # The third panel shows the point-wise absolute error, so it is titled
        # accordingly (the relative L2 error is a single number, printed above).
        plt.subplot(1, 3, 3)
        plt.pcolormesh(X, Y, np.abs(U - U_true), cmap='jet', shading='auto')
        plt.colorbar(); plt.xlabel('x'); plt.ylabel('y'); plt.title('Abs Error')

        plt.tight_layout()
        png = os.path.join(out_dir, f"wave2d_t{t_fixed}_ep{ep:05d}.png")
        plt.savefig(png, dpi=300)
        plt.close(fig)

if __name__ == '__main__':
    # Expert configuration for this run.
    num_experts = 1
    expert_hidden = 64
    expert_rank = 5

    # Wave speeds to train, one independent model per value.
    C_SPEED = [2.0]

    for c in C_SPEED:
        model = PINN_Wave2D(c=c, num_experts=num_experts,
                            expert_hidden=expert_hidden, expert_rank=expert_rank)

        model.pinn.to(device)  # DomainMoE
        all_params = list(model.pinn.parameters())
        total_all = sum(p.numel() for p in all_params)
        total_trainable = sum(p.numel() for p in all_params if p.requires_grad)
        print(f"Trainable: {total_trainable:,}  |  All: {total_all:,}")

        t_begin = time.time()
        model.train(n_epochs_adam=10000, viz_every=2000, t_vis=0.50)
        print(f"Training time: {time.time()-t_begin:.2f}s")

        # Final plots at three time slices, followed by the expert metric.
        for t_slice in (0.25, 0.50, 0.75):
            model.visualize(ep=99999, out_dir=f"wave2d_viz_c={c}", t_fixed=t_slice)
        visualize_expert_curves(model=model.pinn, expert_idx=0, res=256, f_const=3.0)
