import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time
import math
import os

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Fix the NumPy, torch CPU and (when available) all CUDA generator seeds.
np.random.seed(1234)
torch.manual_seed(1234)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(1234)

DIM = 10        # number of spatial coordinates of the Poisson problem
X_MIN = 0.0     # lower bound of every coordinate
X_MAX = 1.0     # upper bound of every coordinate
PI = math.pi

def true_u_torch(coords: torch.Tensor) -> torch.Tensor:
    """Exact solution u(x) = prod_i sin(pi * x_i).

    For ``coords`` of shape (N, DIM) the result has shape (N, 1).
    """
    sines = torch.sin(PI * coords)
    return sines.prod(dim=1, keepdim=True)

@torch.no_grad()
def sample_interior(N: int, device=device) -> torch.Tensor:
    """Return an (N, DIM) tensor of points drawn uniformly from [X_MIN, X_MAX)^DIM."""
    span = X_MAX - X_MIN
    unit = torch.rand(N, DIM, device=device)
    return unit * span + X_MIN

@torch.no_grad()
def sample_boundary(N: int, device=device) -> torch.Tensor:
    """Draw N points uniformly on the boundary of the unit hypercube [0, 1]^DIM.

    Each point starts uniform inside the cube. Then one coordinate, picked uniformly
    at random, is pinned to 0.0 or 1.0 (each with probability 1/2), which places the
    point on one of the 2*DIM faces.
    """
    points = torch.rand(N, DIM, device=device)
    axis = torch.randint(low=0, high=DIM, size=(N,), device=device)   # coordinate to pin
    side = torch.randint(low=0, high=2, size=(N,), device=device)     # 0 -> 0.0, 1 -> 1.0
    rows = torch.arange(N, device=device)
    points[rows, axis] = side.float()
    return points

class MLP(nn.Module):
    """Tanh multilayer perceptron with a linear head.

    Stacks ``depth`` (Linear, Tanh) pairs of width ``hidden`` on top of
    ``in_features`` inputs, followed by a Linear layer with ``num_experts`` outputs.
    """

    def __init__(self, in_features=DIM, hidden=64, num_experts=3, depth=4):
        super().__init__()
        widths = [in_features] + [hidden] * depth
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.Tanh()))
        modules.append(nn.Linear(widths[-1], num_experts))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the network to ``x``."""
        return self.net(x)

class NAMLayer(nn.Module):
    """Small tanh MLP that maps 2-feature input rows to ``r`` output features.

    In this file the two input features are a coordinate value and its dimension
    index (see NAMExpert._eval_dim). Every Linear layer is re-initialised with
    Xavier-uniform weights and zero biases.
    """

    def __init__(self, hidden=64, depth=2, r=16):
        super().__init__()
        # A depth below 1 still produces one hidden layer (the input projection).
        n_hidden = max(depth, 1)
        widths = [2] + [hidden] * n_hidden
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.extend([nn.Linear(fan_in, fan_out), nn.Tanh()])
        blocks.append(nn.Linear(hidden, r))
        self.net = nn.Sequential(*blocks)

        # Override PyTorch's default Linear init: Xavier-uniform weights, zero biases.
        for module in filter(lambda m: isinstance(m, nn.Linear), self.net):
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, val_and_id):
        """Apply the feature network to rows of shape (..., 2)."""
        return self.net(val_and_id)

class NAMExpert(nn.Module):
    """Separable rank-``r`` expert: u(x) = sum_k prod_j f_k(x_j, j).

    One shared NAMLayer maps every (coordinate value, dimension index) pair to an
    r-vector of features. The per-dimension feature vectors are multiplied
    elementwise across all ``dim`` coordinates, and the r products are summed.
    """

    def __init__(self, hidden=32, r=16, dim=DIM):
        super().__init__()
        self.r = r
        self.dim = dim
        self.shared = NAMLayer(hidden=hidden, r=r)

    def _eval_dim(self, coord_col: torch.Tensor, dim_id: int) -> torch.Tensor:
        """Evaluate the shared feature net on one (N, 1) column, returning (N, r)."""
        # Tag each value with its dimension index so a single network serves every axis.
        tag = torch.full_like(coord_col, float(dim_id))
        return self.shared(torch.cat([coord_col, tag], dim=-1))   # input (N, 2)

    def forward(self, coords):
        """Map coords of shape (N, dim) to predictions of shape (N, 1)."""
        feats = [self._eval_dim(coords[:, j:j + 1], j) for j in range(self.dim)]  # dim * (N, r)
        product = feats[0]
        for nxt in feats[1:]:
            product = product * nxt
        return product.sum(dim=-1, keepdim=True)

class DomainMoE(nn.Module):
    """Container of NAMExpert models intended as a mixture of experts.

    Routing is not implemented yet: only ``experts[0]`` produces the prediction.
    ``router_hidden`` and ``router_depth`` are accepted so existing callers keep
    working, but they are currently unused.
    """

    def __init__(self, in_features=DIM, num_experts=3, expert_hidden=32, expert_rank=16, router_hidden=64, router_depth=2):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList([NAMExpert(hidden=expert_hidden, r=expert_rank, dim=in_features) for _ in range(num_experts)])

    def forward(self, x):
        """Return ``(u_pred, gates)``.

        ``u_pred`` is expert 0's (N, 1) output. ``gates`` is the placeholder 0, since
        there is no router; callers in this file discard it.
        """
        # Only expert 0's output is used, so evaluate just that expert rather than
        # running all of them and throwing the rest away.
        u_pred = self.experts[0](x)
        gates = 0
        return u_pred, gates

def _center_unit(M: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    Mc = M - M.mean(axis=0, keepdims=True)
    n  = np.linalg.norm(Mc, axis=0, keepdims=True)
    n  = np.maximum(n, eps)
    return Mc / n

def metricA_subspace_score(F: np.ndarray, G: np.ndarray) -> tuple[float, np.ndarray]:
    Fn = _center_unit(F)
    Gn = _center_unit(G)
    QF, _ = np.linalg.qr(Fn, mode='reduced')
    QG, _ = np.linalg.qr(Gn, mode='reduced')
    Svals = np.linalg.svd(QF.T @ QG, compute_uv=False)
    r_eff = Svals.shape[0]
    S = float(np.mean(Svals[:r_eff]**2))
    return S, Svals

@torch.no_grad()
def visualize_expert_curves(model, expert_idx: int = 0,
                            res: int = 256, f_const: float = 0.0,
                            outdir: str = "expert_vis",
                            input_range: tuple = (-1.0, 1.0)):
    """Score how well one expert's per-dimension features span sin(pi * x).

    For every input dimension j of ``model.experts[expert_idx]``:
    - The shared feature net is evaluated on ``res`` points. These correspond to
      physical x evenly spaced on [0, 1], mapped linearly onto ``input_range``.
    - The resulting (res, r) feature matrix is compared with the target curve
      sin(pi * x) via ``metricA_subspace_score``.

    The mean score over all dimensions is printed and returned.

    Args:
        model: object exposing ``experts``; the selected expert must provide ``dim``,
            ``parameters()``, ``training``/``train()``/``eval()`` and ``_eval_dim``.
        expert_idx: index of the expert to inspect.
        res: number of sample points.
        f_const: unused; kept for call compatibility.
        outdir: directory that is created (nothing is written to it).
        input_range: interval the expert actually receives for x in [0, 1].
            PINN_Poisson feeds the model ``2*x - 1``, hence the default (-1, 1);
            pass (0.0, 1.0) for a model fed raw coordinates.

    Returns:
        The mean subspace score over all dimensions.
    """
    os.makedirs(outdir, exist_ok=True)
    expert = model.experts[expert_idx]
    was_training = expert.training
    expert.eval()
    try:
        dev = next(expert.parameters()).device
        lo, hi = input_range
        # Feature-net inputs aligned point-by-point with the physical grid xs.
        z = torch.linspace(lo, hi, res, device=dev).unsqueeze(-1)
        xs = np.linspace(0.0, 1.0, res)
        Gx = np.sin(np.pi * xs)[:, None]

        scores = []
        for j in range(expert.dim):
            Fj = expert._eval_dim(z, j).cpu().numpy()   # (res, r)
            Sj, _ = metricA_subspace_score(Fj, Gx)
            scores.append(Sj)
    finally:
        # Do not leave the caller's model in a different train/eval mode.
        expert.train(was_training)

    S_all = float(np.mean(scores))
    print(f"[metric-A] S_avg={S_all:.4f}")
    return S_all


class PINN_Poisson:
    """Physics-informed network for -Δu = f on [X_MIN, X_MAX]^DIM.

    Problem setup:
    - The source is f(x) = DIM * π² * Π_i sin(π x_i). Then u(x) = Π_i sin(π x_i)
      (``true_u_torch``) solves the equation and vanishes where any x_i is 0 or 1.
    - The network sees coordinates rescaled to [-1, 1] (``_normalize``).

    Training minimises ``w_f * mean(residual²) + w_bc * mean(u_boundary²)``
    (homogeneous Dirichlet condition):
    - first with Adam on points resampled every step,
    - then with L-BFGS on one fixed point set.
    """

    def __init__(self, num_experts=3, expert_hidden=32, expert_rank=16, router_hidden=64, router_depth=4):
        # Loss weights. w_f_initial is only relevant to the (disabled) residual-weight
        # annealing; training uses w_f_final throughout.
        self.w_bc = 5000.0
        self.w_f_initial = 0.01
        self.w_f_final = 1.0

        self.N_interior = 8192   # interior collocation points per Adam step
        self.N_bc = 2048         # number of boundary points per Adam step

        self.pinn = DomainMoE(in_features=DIM,
                              num_experts=num_experts, expert_hidden=expert_hidden, expert_rank=expert_rank,
                              router_hidden=router_hidden, router_depth=router_depth).to(device)

        self.optimizer_adam = torch.optim.Adam(self.pinn.parameters(), lr=5e-4)

        # NOTE(review): T_max=20000 while train() defaults to 10000 Adam epochs, so the
        # cosine schedule only reaches its midpoint by default — confirm this is intended.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_adam, T_max=20000, eta_min=1e-6)

        # A single LBFGS.step() call runs up to max_iter iterations.
        self.optimizer_lbfgs = torch.optim.LBFGS(
            self.pinn.parameters(), max_iter=20000,
            tolerance_grad=1e-9, tolerance_change=1e-12,
            history_size=100, line_search_fn="strong_wolfe"
        )

    def _normalize(self, coords):
        """Affinely map coordinates from [X_MIN, X_MAX] to [-1, 1]."""
        return 2.0 * (coords - X_MIN) / (X_MAX - X_MIN) - 1.0

    @staticmethod
    def _autograd_grads(u, coords):
        """Return d(sum u)/d(coords), keeping the graph for higher derivatives."""
        return torch.autograd.grad(u, coords, grad_outputs=torch.ones_like(u), create_graph=True)[0]

    def pde_residual(self, coords):
        """Return ``(-Δu - f, u)`` at ``coords`` of shape (N, DIM), each of shape (N, 1).

        Derivatives are taken w.r.t. the physical coordinates. The chain rule through
        ``_normalize`` is handled by autograd.
        """
        coords_with_grad = coords.clone().detach().requires_grad_(True)
        u_pred, _ = self.pinn(self._normalize(coords_with_grad))      # (N,1)

        du = self._autograd_grads(u_pred, coords_with_grad)           # (N,DIM)

        # Laplacian = sum of the diagonal second derivatives, one backward pass per axis.
        lap_terms = []
        for d in range(DIM):
            du_d = du[:, d:d+1]
            d2u_dd = torch.autograd.grad(du_d, coords_with_grad,
                                         grad_outputs=torch.ones_like(du_d),
                                         create_graph=True)[0][:, d:d+1]
            lap_terms.append(d2u_dd)
        lap = sum(lap_terms)                                           # (N,1)

        # The source term does not depend on the parameters; build it without a graph.
        x = coords_with_grad.detach()
        f_rhs = DIM * (PI ** 2) * torch.prod(torch.sin(PI * x), dim=1, keepdim=True)
        resid = -lap - f_rhs
        return resid, u_pred

    def loss_func(self, coords_f=None, coords_ic=None, periodic_pairs=None):
        """Return ``(loss_f, loss_bc, error)``.

        Args:
            coords_f: interior points; freshly sampled (N_interior) when None.
            coords_ic: boundary points; freshly sampled (N_bc) when None.
            periodic_pairs: unused; kept for call compatibility.

        Returns:
            - ``loss_f``: mean squared PDE residual.
            - ``loss_bc``: mean squared prediction on the boundary (u = 0).
            - ``error``: gradient-free relative L2 error against ``true_u_torch`` on
              the interior points (diagnostic only).
        """
        if coords_f is None:
            coords_f = sample_interior(self.N_interior, device=device)
        if coords_ic is None:
            coords_ic = sample_boundary(self.N_bc, device=device)

        # PDE residual
        resid, u_on_f_pts = self.pde_residual(coords_f)
        loss_f = F.mse_loss(resid, torch.zeros_like(resid))

        # Dirichlet: u = 0
        u_bc_pred, _ = self.pinn(self._normalize(coords_ic))
        loss_bc = F.mse_loss(u_bc_pred, torch.zeros_like(u_bc_pred))

        with torch.no_grad():
            u_true_f = true_u_torch(coords_f)
            error = torch.norm(u_on_f_pts - u_true_f) / torch.norm(u_true_f)

        return loss_f, loss_bc, error

    def train(self, n_epochs_adam=10000, viz_every=2000):
        """Run ``n_epochs_adam`` Adam steps, then one L-BFGS solve, then print the final error.

        Progress is logged every 1000 Adam steps. Slice plots are saved every
        ``viz_every`` steps (disabled when ``viz_every`` is falsy).
        """
        print("--- Starting Adam Optimization ---")
        # Local timer: the method must not depend on a global defined by the caller.
        start = time.time()

        for ep in range(1, n_epochs_adam + 1):
            self.pinn.train()
            # Residual-weight annealing from w_f_initial is disabled; use the final weight.
            w_f = self.w_f_final

            self.optimizer_adam.zero_grad()
            loss_f, loss_bc, error = self.loss_func()
            loss = w_f * loss_f + self.w_bc * loss_bc
            loss.backward()
            self.optimizer_adam.step()
            self.scheduler.step()

            if ep % 1000 == 0:
                print(f"{ep:05d} training time: {time.time()-start:.2f}s")
                print(f"[Adam {ep:05d}] loss={loss.item():.4e} w_f={w_f:.2f} "
                      f"(f={loss_f.item():.2e}, bc={loss_bc.item():.2e}, L2_err={error.item():.2e})")

            if viz_every and ep % viz_every == 0:
                self.visualize(ep)

        print("\n--- Starting L-BFGS Optimization ---")
        self.pinn.train()
        # Fixed point sets keep the objective identical across line-search evaluations.
        coords_f_lbfgs = sample_interior(20000, device=device)
        coords_bc_lbfgs = sample_boundary(5000, device=device)

        self.lbfgs_iter = 0

        def closure():
            self.optimizer_lbfgs.zero_grad()
            loss_f, loss_bc, error = self.loss_func(
                coords_f=coords_f_lbfgs,
                coords_ic=coords_bc_lbfgs,
                periodic_pairs=None
            )
            loss = self.w_f_final * loss_f + self.w_bc * loss_bc
            loss.backward()
            self.lbfgs_iter += 1
            if self.lbfgs_iter % 100 == 0:
                print(f'[L-BFGS {self.lbfgs_iter:05d}] loss={loss.item():.4e} (L2_err={error.item():.2e})')
            return loss

        self.optimizer_lbfgs.step(closure)

        final_error = self.evaluate_rel_l2(20000)
        print(f"\n--- Optimization Finished ---")
        print(f"Final L2 Relative Error: {final_error:.4e}")

    @torch.no_grad()
    def evaluate_rel_l2(self, N=50000):
        """Return the relative L2 error against ``true_u_torch`` on N fresh interior points."""
        self.pinn.eval()
        coords = sample_interior(N, device=device)
        u_pred, _ = self.pinn(self._normalize(coords))
        u_true = true_u_torch(coords)
        err = torch.linalg.norm(u_pred - u_true) / torch.linalg.norm(u_true)
        return float(err.item())

    @torch.no_grad()
    def predict_on_grid(self, res1=161, res2=161, fixed_val=0.5):
        """Evaluate the model on a (res1, res2) grid over (x1, x2).

        All remaining DIM-2 coordinates are held at ``fixed_val``.

        Returns:
            ``(X1, X2, u_pred, u_true, rel_l2_error)``. The first four are
            (res1, res2) arrays; the error is the relative L2 error over the slice
            (undefined if the exact slice is identically zero).
        """
        self.pinn.eval()
        x1_vals = np.linspace(X_MIN, X_MAX, res1)
        x2_vals = np.linspace(X_MIN, X_MAX, res2)
        X1, X2 = np.meshgrid(x1_vals, x2_vals, indexing="ij")
        n = X1.size
        fixed_block = np.full((n, DIM - 2), fixed_val, dtype=X1.dtype)
        X = np.concatenate([
                np.stack([X1.ravel(), X2.ravel()], axis=-1),
                fixed_block,
            ], axis=-1)                                                # (n, DIM)
        coords_t = torch.from_numpy(X).float().to(device)
        u_pred, _ = self.pinn(self._normalize(coords_t))
        u_np = u_pred.detach().cpu().numpy().reshape(res1, res2)
        u_true = np.prod(np.sin(np.pi * X), axis=-1).reshape(res1, res2)
        error = np.linalg.norm(u_np - u_true) / np.linalg.norm(u_true)
        return (X1, X2, u_np, u_true, error)

    @torch.no_grad()
    def visualize(self, ep=0, out_dir="poisson5d_viz"):
        """Save a 3-panel PNG of the (x1, x2) slice.

        Panels: the prediction, the absolute error, and the absolute error divided by
        the peak of the exact slice. The file is written as
        ``out_dir/poisson5d_slice_ep{ep:05d}.png``.
        """
        os.makedirs(out_dir, exist_ok=True)
        X1, X2, U, U_true, error = self.predict_on_grid()
        print(f"  [viz @ ep {ep}] L2 Rel Error (slice): {error:.4e}. Saving plots...")

        abs_err = np.abs(U - U_true)
        # Guard the normalisation against an all-zero exact slice.
        rel_err = abs_err / max(float(np.abs(U_true).max()), 1e-12)

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

        im1 = ax1.pcolormesh(X1, X2, U, cmap='rainbow', shading='auto')
        fig.colorbar(im1, ax=ax1)
        ax1.set_xlabel('x1'); ax1.set_ylabel('x2'); ax1.set_title('Pred u (slice)')

        im2 = ax2.pcolormesh(X1, X2, abs_err, cmap='rainbow', shading='auto')
        fig.colorbar(im2, ax=ax2)
        ax2.set_xlabel('x1'); ax2.set_ylabel('x2'); ax2.set_title('Abs Error (slice)')

        im3 = ax3.pcolormesh(X1, X2, rel_err, cmap='jet', shading='auto')
        fig.colorbar(im3, ax=ax3)
        ax3.set_xlabel('x1'); ax3.set_ylabel('x2')
        ax3.set_title(f'Rel. Error |u-u*|/max|u*| (slice, L2 rel={error:.2e})')

        fig.tight_layout()
        png1 = os.path.join(out_dir, f"poisson5d_slice_ep{ep:05d}.png")
        fig.savefig(png1, dpi=250)
        plt.close(fig)


if __name__ == '__main__':
    # Model hyper-parameters; the router settings are forwarded to DomainMoE.
    config = dict(
        num_experts=1,
        expert_hidden=64,
        expert_rank=16,
        router_hidden=64,
        router_depth=4,
    )
    model = PINN_Poisson(**config)

    model.pinn.to(device)  # DomainMoE
    params = list(model.pinn.parameters())
    total_trainable = sum(p.numel() for p in params if p.requires_grad)
    total_all = sum(p.numel() for p in params)
    print(f"Trainable: {total_trainable:,}  |  All: {total_all:,}")

    # Wall-clock timer, kept at module scope.
    start = time.time()
    model.train(n_epochs_adam=10000, viz_every=2000)
    print(f"Total training time: {time.time()-start:.2f}s")
    model.visualize(ep=99999)  # final visualization
