import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time
import math
import os

# Pick the compute device once; the sampling helpers below use it as their default.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Seed every RNG the script draws from so runs are reproducible.
torch.manual_seed(1234)
np.random.seed(1234)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(1234)

# Problem definition: u_t + C_SPEED * u_x = 0 on [T_MIN, T_MAX] x [X_MIN, X_MAX],
# periodic in x.
# Using a slightly lower speed makes the problem easier to learn
C_SPEED = 10.0
T_MIN, T_MAX = 0.0, 1.0
X_MIN, X_MAX = 0.0, 4.0

def u0_piecewise_torch(x: torch.Tensor) -> torch.Tensor:
    """Top-hat initial condition: 2 where 1 <= x < 3, and 1 everywhere else.

    The result has the same shape, dtype and device as ``x``.
    """
    inside_hat = (x >= 1.0) & (x < 3.0)
    return torch.ones_like(x).masked_fill(inside_hat, 2.0)

def true_u_torch(coords: torch.Tensor, c: float) -> torch.Tensor:
    """Exact solution of u_t + c * u_x = 0 with periodic boundaries in x.

    The initial profile ``u0_piecewise_torch`` is translated with speed ``c``:
    u(t, x) = u0(x - c * t), where the shifted coordinate is wrapped back into
    the periodic domain [X_MIN, X_MAX).

    Args:
        coords: tensor whose column 0 is t and column 1 is x.
        c: advection speed.

    Returns:
        Tensor of shape (N, 1) with the exact solution at each point.
    """
    t = coords[:, 0:1]
    x = coords[:, 1:2]
    period = X_MAX - X_MIN
    # Wrap relative to X_MIN so the result stays correct for a domain that
    # does not start at 0. Tensor `%` is torch.remainder, whose result takes
    # the divisor's sign, so negative shifts wrap into [0, period) as well.
    x_tran = (x - c * t - X_MIN) % period + X_MIN
    return u0_piecewise_torch(x_tran)

def sample_interior(N, device=device) -> torch.Tensor:
    """Draw N collocation points uniformly from [T_MIN, T_MAX] x [X_MIN, X_MAX].

    Returns an (N, 2) tensor whose columns are (t, x).
    """
    t_span = T_MAX - T_MIN
    x_span = X_MAX - X_MIN
    # Draw t before x so the random stream is consumed in a fixed order.
    t_col = T_MIN + t_span * torch.rand(N, 1, device=device)
    x_col = X_MIN + x_span * torch.rand(N, 1, device=device)
    return torch.cat((t_col, x_col), dim=1)

def sample_ic(N: int, device=device) -> torch.Tensor:
    """Draw N points on the initial-time line t = 0 with x uniform in [X_MIN, X_MAX].

    Returns an (N, 2) tensor whose columns are (t, x).
    """
    x_col = X_MIN + (X_MAX - X_MIN) * torch.rand(N, 1, device=device)
    t_col = torch.zeros_like(x_col)
    return torch.cat((t_col, x_col), dim=1)

def sample_periodic_pairs(N: int, device=device):
    """Draw N random times and pair the left and right spatial boundaries.

    Returns two (N, 2) tensors of (t, x) coordinates sharing the same t values:
    the first at x = X_MIN, the second at x = X_MAX.
    """
    t_col = T_MIN + (T_MAX - T_MIN) * torch.rand(N, 1, device=device)
    left = torch.cat((t_col, torch.full_like(t_col, X_MIN)), dim=1)
    right = torch.cat((t_col, torch.full_like(t_col, X_MAX)), dim=1)
    return left, right

class MLP(nn.Module):
    """Plain tanh multilayer perceptron (used as the mixture-of-experts router).

    ``depth`` hidden stages of Linear -> Tanh with width ``hidden`` are
    followed by a linear head producing ``num_experts`` outputs.
    """

    def __init__(self, in_features=2, hidden=64, num_experts=3, depth=4):
        super().__init__()
        widths = [in_features] + [hidden] * depth
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.Tanh())
        # Output head: width of the last stage (or the input, if depth == 0).
        stages.append(nn.Linear(widths[-1], num_experts))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class NAMLayer(nn.Module):
    """One-dimensional tanh MLP mapping a scalar coordinate to ``r`` features.

    Architecture: Linear(1, hidden) -> Tanh, then (depth - 1) repetitions of
    Linear(hidden, hidden) -> Tanh, then Linear(hidden, r). Every Linear layer
    gets Xavier-uniform weights and zero bias.

    Args:
        hidden: width of every hidden layer.
        depth: number of Linear+Tanh stages before the output layer
            (values below 1 behave like 1).
        r: output width.
    """

    def __init__(self, hidden=64, depth=2, r=16):
        super().__init__()
        stack = [nn.Linear(1, hidden), nn.Tanh()]
        for _ in range(depth - 1):
            stack.extend((nn.Linear(hidden, hidden), nn.Tanh()))
        stack.append(nn.Linear(hidden, r))
        self.net = nn.Sequential(*stack)

        for linear in (mod for mod in self.net if isinstance(mod, nn.Linear)):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, t_or_x):
        # The first layer is Linear(1, hidden), so the input's last dim must be 1.
        return self.net(t_or_x)

class NAMExpert(nn.Module):
    """Separable rank-``r`` expert: u(t, x) = sum_k ft(t)[k] * fx(x)[k].

    Args:
        hidden: hidden width of each of the two 1-D factor networks.
        r: number of separable terms (output width of each factor network).
    """

    def __init__(self, hidden=32, r=16):
        super().__init__()
        self.r = r
        self.ft = NAMLayer(hidden, r=r)
        self.fx = NAMLayer(hidden, r=r)

    def _eval_dim(self, coord: torch.Tensor, dim_id: int) -> torch.Tensor:
        """Evaluate the 1-D factor network for one coordinate axis.

        The previous version forwarded to a shared network (``self.shared``),
        which no longer exists, so every call raised AttributeError. It now
        dispatches to the per-axis networks that ``forward`` actually uses.

        Args:
            coord: single-column coordinate tensor, e.g. shape (N, 1).
            dim_id: 0 for the time factor ``ft``, 1 for the space factor ``fx``.

        Returns:
            The factor features, shape (N, r).

        Raises:
            ValueError: if ``dim_id`` is not 0 or 1.
        """
        if dim_id == 0:
            return self.ft(coord)
        if dim_id == 1:
            return self.fx(coord)
        raise ValueError(f"dim_id must be 0 (t) or 1 (x), got {dim_id!r}")

    def forward(self, coords):
        """Return u at each (t, x) row of ``coords`` as an (N, 1) tensor."""
        t = coords[:, 0:1]
        x = coords[:, 1:2]
        # Elementwise product of the r factor features, summed over k.
        return (self.ft(t) * self.fx(x)).sum(dim=-1, keepdim=True)

class DomainMoE(nn.Module):
    """Soft mixture of ``NAMExpert`` networks gated by an MLP router.

    For each input row the router produces ``num_experts`` logits. They are
    divided by ``temperature`` and softmax-normalised into gate weights. The
    output is the gate-weighted sum of the expert outputs.

    Args:
        in_features: router input width. The experts read columns 0 and 1 of
            the input, so this presumably must be >= 2.
        num_experts: number of experts.
        expert_hidden, expert_rank: forwarded to each ``NAMExpert``.
        router_hidden, router_depth: forwarded to the router ``MLP``.
        temperature: softmax temperature (> 0). Values below 1 sharpen the
            gates. The default 0.5 matches the previously hard-coded value.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """

    def __init__(self, in_features=2, num_experts=3, expert_hidden=32, expert_rank=16,
                 router_hidden=64, router_depth=2, temperature=0.5):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature!r}")
        self.num_experts = num_experts
        self.router = MLP(in_features=in_features, hidden=router_hidden,
                          num_experts=num_experts, depth=router_depth)
        self.experts = nn.ModuleList(
            [NAMExpert(hidden=expert_hidden, r=expert_rank) for _ in range(num_experts)]
        )
        self.temperature = temperature

    def forward(self, x):
        """Return ``(combined_output, gates)`` for input rows ``x``.

        ``gates`` has one softmax weight per expert in its last dimension.
        """
        gates = F.softmax(self.router(x) / self.temperature, dim=-1)
        # Each expert returns (N, 1); stacking on the last dim gives (N, 1, E),
        # which lines up with the gates reshaped to (N, 1, E).
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-1)
        combined_output = torch.sum(expert_outputs * gates.unsqueeze(1), dim=-1)
        return combined_output, gates

def _center_unit(M: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Subtract each column's mean, then scale each column to unit L2 norm.

    Column norms are floored at ``eps``, so a constant column becomes all
    zeros instead of dividing by zero.
    """
    column_means = M.mean(axis=0, keepdims=True)
    centered = M - column_means
    norms = np.clip(np.linalg.norm(centered, axis=0, keepdims=True), eps, None)
    return centered / norms

def _orthonormal_basis(M: np.ndarray) -> np.ndarray:
    """Return orthonormal columns spanning the numerical column space of ``M``.

    Uses a thin SVD and keeps only directions whose singular value exceeds the
    same default tolerance as ``np.linalg.matrix_rank``. Zero or linearly
    dependent columns therefore add no spurious basis vectors. Reduced QR
    would give such columns an arbitrary orthonormal direction.
    """
    U, s, _ = np.linalg.svd(M, full_matrices=False)
    if s.size == 0:
        return U[:, :0]
    tol = s.max() * max(M.shape) * np.finfo(M.dtype).eps
    # s is sorted in descending order, so the mask selects the leading columns.
    return U[:, s > tol]


def metricA_subspace_score(F: np.ndarray, G: np.ndarray) -> tuple[float, np.ndarray]:
    """Measure how well the column space of ``F`` aligns with that of ``G``.

    Both matrices are sampled on the same ``n`` points (rows). Columns are
    centred and unit-normalised, and an orthonormal basis of each effective
    column space is built. The singular values of ``QF^T QG`` are the cosines
    of the principal angles between the two subspaces.

    Args:
        F: (n, k_F) learned features; a 1-D array is treated as one column.
        G: (n, k_G) reference features; a 1-D array is treated as one column.

    Returns:
        ``(S, svals)``. ``svals`` holds the principal-angle cosines in [0, 1],
        one per dimension of the smaller effective subspace. ``S`` is the mean
        of ``svals**2``, in [0, 1], and is 0.0 if either subspace is empty.

    Raises:
        ValueError: if ``F`` and ``G`` have different numbers of rows.
    """
    F = np.asarray(F, dtype=np.float64)
    G = np.asarray(G, dtype=np.float64)
    if F.ndim == 1:
        F = F[:, None]
    if G.ndim == 1:
        G = G[:, None]
    if F.shape[0] != G.shape[0]:
        raise ValueError(
            f"F and G must have the same number of rows, got {F.shape[0]} and {G.shape[0]}"
        )

    QF = _orthonormal_basis(_center_unit(F))   # (n, rank_F)
    QG = _orthonormal_basis(_center_unit(G))   # (n, rank_G)

    if QF.shape[1] == 0 or QG.shape[1] == 0:
        return 0.0, np.zeros(0)

    # Clip floating-point rounding so the cosines stay inside [0, 1].
    svals = np.clip(np.linalg.svd(QF.T @ QG, compute_uv=False), 0.0, 1.0)
    S = float(np.mean(svals ** 2))
    return S, svals


@torch.no_grad()
def visualize_expert_curves(model, expert_idx: int = 0,
                            res: int = 256, f_const: float = 0.0,
                            outdir: str = "expert_vis",
                            save_plots: bool = False):
    """Score (and optionally plot) the 1-D factor curves of one NAM expert.

    Evaluates the expert's time factor ``ft`` on ``res`` points of [0, 1] and
    its space factor ``fx`` on ``6 * res`` points of [0, 2*pi]. It then uses
    ``metricA_subspace_score`` to measure how well their column spaces match
    the reference bases {sin x, cos x} and {cos(2t), -sin(2t)}.

    NOTE(review): the reference bases, the [0, 2*pi] x-range and the factor 2
    are hard-coded. They appear to describe a sinusoidal solution
    sin(x - 2t), not this file's top-hat problem on [X_MIN, X_MAX] with speed
    C_SPEED.

    Args:
        model: a ``DomainMoE``; ``model.experts[expert_idx]`` must expose
            ``ft`` and ``fx`` sub-networks.
        expert_idx: which expert to inspect.
        res: number of time samples (space uses ``6 * res``).
        f_const: unused; kept for backward compatibility.
        outdir: directory for the PNGs written when ``save_plots`` is True.
        save_plots: if True, save one line plot per factor into ``outdir``.

    Returns:
        ``(S_x, S_t, S_avg)`` subspace scores, each in [0, 1].
    """
    expert = model.experts[expert_idx]
    was_training = expert.training
    expert.eval()
    try:
        dev = next(expert.parameters()).device
        t = torch.linspace(0., 1., res, device=dev).unsqueeze(-1)                # (res, 1)
        x = torch.linspace(0., 2 * torch.pi, res * 6, device=dev).unsqueeze(-1)  # (6*res, 1)
        # Call the live factor networks directly; each returns (N, r).
        ft_vals = expert.ft(t).cpu().numpy()
        fx_vals = expert.fx(x).cpu().numpy()
    finally:
        # Leave the expert in whatever mode the caller had it in.
        expert.train(was_training)

    ts = np.linspace(0, 1, res)
    xs = np.linspace(0, 2 * np.pi, res * 6)
    c = 2
    Gx = np.column_stack([np.sin(xs), np.cos(xs)])
    Gt = np.column_stack([np.cos(ts * c), -np.sin(ts * c)])

    Sx, svals_x = metricA_subspace_score(fx_vals, Gx)
    St, svals_t = metricA_subspace_score(ft_vals, Gt)
    S_all = 0.5 * (Sx + St)

    print(f"[metric-A] S_x={Sx:.4f} (singvals={np.round(svals_x,4)})")
    print(f"[metric-A] S_t={St:.4f} (singvals={np.round(svals_t,4)})")
    print(f"[metric-A] S_avg={S_all:.4f}")

    if save_plots:
        os.makedirs(outdir, exist_ok=True)
        for name, arr, axis_vals, var in (("ft", ft_vals, ts, "t"), ("fx", fx_vals, xs, "x")):
            fig = plt.figure(figsize=(4, 3))
            for j in range(arr.shape[1]):
                plt.plot(axis_vals, arr[:, j], label=f"rank {j}")
            plt.xlabel(var)
            plt.ylabel(f"{name}({var})")
            plt.title(f"Expert {expert_idx} — {name}({var})")
            plt.legend()
            plt.tight_layout()
            fname = os.path.join(outdir, f"{name}_expert{expert_idx}.png")
            plt.savefig(fname, dpi=300)
            plt.close(fig)
            print(f"[vis] saved {fname}")

    return Sx, St, S_all

class PINN_Transport:
    """Physics-informed trainer for periodic 1-D linear advection.

    Fits a ``DomainMoE`` network u(t, x) so that

        u_t + c * u_x = 0                 on [T_MIN, T_MAX] x [X_MIN, X_MAX],
        u(0, x) = u0_piecewise_torch(x)   (top-hat initial condition),
        u(t, X_MIN) = u(t, X_MAX)         (periodicity in x).

    Training runs Adam with a linearly annealed PDE-loss weight, then L-BFGS
    on a fixed point set. Network inputs are rescaled to [-1, 1], while
    derivatives are taken with respect to the physical coordinates.

    Args:
        c: advection speed.
        num_experts, expert_hidden, expert_rank, router_hidden, router_depth:
            forwarded to ``DomainMoE``.
    """

    def __init__(self, c=C_SPEED, num_experts=2, expert_hidden=32, expert_rank=16, router_hidden=64, router_depth=4):
        self.c = c
        # Loss weights: IC and periodicity weights are fixed; the PDE weight is
        # ramped from w_f_initial to w_f_final during the Adam phase.
        self.w_ic  = 10000.0
        self.w_per = 10.0
        self.w_f_initial = 0.01
        self.w_f_final   = 1.0

        # Point counts, resampled at every Adam step.
        self.N_interior = 16384
        self.N_ic = 8192
        self.N_per = 2048

        self.pinn = DomainMoE(in_features=2,
                              num_experts=num_experts, expert_hidden=expert_hidden, expert_rank=expert_rank,
                              router_hidden=router_hidden, router_depth=router_depth).to(device)

        # Networks use tanh activations; 1e-4 is the chosen Adam step size.
        self.optimizer_adam = torch.optim.Adam(self.pinn.parameters(), lr=1e-4)
        # NOTE(review): T_max is fixed at 20000 regardless of the n_epochs_adam
        # passed to train(); with fewer epochs the LR never reaches eta_min.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_adam, T_max=20000, eta_min=1e-6)

        self.optimizer_lbfgs = torch.optim.LBFGS(
            self.pinn.parameters(), max_iter=20000,
            tolerance_grad=1e-9, tolerance_change=1e-12,
            history_size=100, line_search_fn="strong_wolfe"
        )

    def _normalize(self, coords):
        """Affinely map (t, x) from the physical domain to [-1, 1] x [-1, 1]."""
        t = coords[:, 0:1]
        x = coords[:, 1:2]
        t_norm = 2.0 * (t - T_MIN) / (T_MAX - T_MIN) - 1.0
        x_norm = 2.0 * (x - X_MIN) / (X_MAX - X_MIN) - 1.0
        # Normalization is for the network input only; autograd still
        # differentiates through it back to the physical coordinates.
        return torch.cat([t_norm, x_norm], dim=-1)

    @staticmethod
    def _autograd_grads(u, coords):
        """Gradient of sum(u) w.r.t. ``coords``, keeping the graph for backprop."""
        return torch.autograd.grad(u, coords, grad_outputs=torch.ones_like(u), create_graph=True)[0]

    def pde_residual(self, coords):
        """Return ``(u_t + c * u_x, u_pred)`` evaluated at physical ``coords``."""
        coords_with_grad = coords.clone().detach().requires_grad_(True)
        coords_norm = self._normalize(coords_with_grad)
        u_pred, _ = self.pinn(coords_norm)

        # Derivatives w.r.t. the original (t, x), not the normalized inputs.
        du_dcoords = self._autograd_grads(u_pred, coords_with_grad)
        u_t = du_dcoords[:, 0:1]
        u_x = du_dcoords[:, 1:2]

        r = u_t + self.c * u_x
        return r, u_pred

    def loss_func(self, coords_f=None, coords_ic=None, periodic_pairs=None):
        """Compute the loss terms; any point set left as None is freshly sampled.

        Returns:
            ``(loss_f, loss_ic, loss_per, error)``: PDE-residual MSE,
            initial-condition MSE, periodicity MSE, and the relative L2 error
            against the exact solution at the interior points (no gradient).
        """
        if coords_f is None:
            coords_f = sample_interior(self.N_interior, device=device)
        if coords_ic is None:
            coords_ic = sample_ic(self.N_ic, device=device)
        if periodic_pairs is None:
            c0, c1 = sample_periodic_pairs(self.N_per, device=device)
        else:
            c0, c1 = periodic_pairs

        # PDE residual loss
        resid, u_on_f_pts = self.pde_residual(coords_f)
        loss_f = F.mse_loss(resid, torch.zeros_like(resid))

        # Initial-condition loss
        coords_ic_norm = self._normalize(coords_ic)
        u_ic_pred, _ = self.pinn(coords_ic_norm)
        u_ic_true = u0_piecewise_torch(coords_ic[:, 1:2])
        loss_ic = F.mse_loss(u_ic_pred, u_ic_true)

        # Periodicity loss: u must match at x = X_MIN and x = X_MAX.
        c0_norm, c1_norm = self._normalize(c0), self._normalize(c1)
        u0, _ = self.pinn(c0_norm)
        u1, _ = self.pinn(c1_norm)
        loss_per = F.mse_loss(u0, u1)

        # Relative L2 error, for logging only.
        with torch.no_grad():
            u_true_f = true_u_torch(coords_f, self.c)
            error = torch.norm(u_on_f_pts - u_true_f) / torch.norm(u_true_f)

        return loss_f, loss_ic, loss_per, error

    def train(self, n_epochs_adam=10000, viz_every=2000):
        """Run Adam for ``n_epochs_adam`` steps, then L-BFGS on a fixed point set.

        Args:
            n_epochs_adam: number of Adam iterations; points are resampled
                every iteration.
            viz_every: call ``visualize`` every this many Adam epochs
                (0 or None disables it).
        """
        print("--- Starting Adam Optimization ---")
        # The PDE weight reaches w_f_final after 75% of the Adam epochs.
        anneal_epochs = n_epochs_adam * 0.75

        for ep in range(1, n_epochs_adam + 1):
            self.pinn.train()

            w_f = self.w_f_initial + (self.w_f_final - self.w_f_initial) * min(ep / anneal_epochs, 1.0)
            self.optimizer_adam.zero_grad()
            loss_f, loss_ic, loss_per, error = self.loss_func()
            loss = w_f * loss_f + self.w_ic * loss_ic + self.w_per * loss_per
            loss.backward()
            self.optimizer_adam.step()
            self.scheduler.step()

            if ep % 1000 == 0:
                print(f"[Adam {ep:05d}] loss={loss.item():.4e} w_f={w_f:.2f} "
                      f"(f={loss_f.item():.2e}, ic={loss_ic.item():.2e}, per={loss_per.item():.2e}, L2_err={error.item():.2e})")

            if viz_every and ep % viz_every == 0:
                self.visualize(ep)

        # L-BFGS needs a deterministic objective, so the point set is fixed.
        print("\n--- Starting L-BFGS Optimization ---")
        self.pinn.train()

        coords_f_lbfgs = sample_interior(20000, device=device)
        coords_ic_lbfgs = sample_ic(5000, device=device)
        periodic_pairs_lbfgs = sample_periodic_pairs(5000, device=device)

        # Counts closure evaluations (L-BFGS may call the closure several
        # times per iteration during the line search).
        self.lbfgs_iter = 0

        def closure():
            self.optimizer_lbfgs.zero_grad()
            loss_f, loss_ic, loss_per, error = self.loss_func(
                coords_f=coords_f_lbfgs,
                coords_ic=coords_ic_lbfgs,
                periodic_pairs=periodic_pairs_lbfgs
            )
            # Use the final PDE weight for L-BFGS.
            loss = self.w_f_final * loss_f + self.w_ic * loss_ic + self.w_per * loss_per
            loss.backward()

            self.lbfgs_iter += 1
            if self.lbfgs_iter % 100 == 0:
                print(f'[L-BFGS {self.lbfgs_iter:05d}] loss={loss.item():.4e} (L2_err={error.item():.2e})')
            return loss

        self.optimizer_lbfgs.step(closure)

        *_, final_error = self.predict_on_grid()
        print(f"\n--- Optimization Finished ---")
        print(f"Final L2 Relative Error: {final_error:.4e}")

    @torch.no_grad()
    def predict_on_grid(self, res_t=1010, res_x=2570, batch_size=65536):
        """Evaluate the model and the exact solution on a regular (t, x) grid.

        The grid is evaluated in chunks of ``batch_size`` points so that large
        grids do not require one huge forward pass.

        Returns:
            ``(grid_t, grid_x, u_pred, u_true, gates, rel_l2_error)``. The
            first four are (res_t, res_x) arrays; ``gates`` is
            (res_t, res_x, num_experts).

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size!r}")
        self.pinn.eval()
        t_vals = np.linspace(T_MIN, T_MAX, res_t)
        x_vals = np.linspace(X_MIN, X_MAX, res_x)
        grid_t, grid_x = np.meshgrid(t_vals, x_vals, indexing="ij")
        coords = np.stack([grid_t.ravel(), grid_x.ravel()], axis=-1)
        coords_t = torch.from_numpy(coords).float().to(device)

        u_chunks, g_chunks = [], []
        for chunk in torch.split(coords_t, batch_size):
            u_c, g_c = self.pinn(self._normalize(chunk))
            u_chunks.append(u_c.cpu())
            g_chunks.append(g_c.cpu())
        u_np = torch.cat(u_chunks).numpy().reshape(res_t, res_x)
        g_np = torch.cat(g_chunks).numpy().reshape(res_t, res_x, -1)

        # Use this instance's speed (previously hard-coded to 10).
        u_true = true_u_torch(coords_t, self.c).cpu().numpy().reshape(res_t, res_x)
        error = np.linalg.norm(u_np - u_true) / np.linalg.norm(u_true)
        return (grid_t, grid_x, u_np, u_true, g_np, error)

    @torch.no_grad()
    def visualize(self, ep=0, out_dir="transport_viz", plot_gates=False):
        """Save true / predicted / absolute-error maps to ``out_dir``.

        Args:
            ep: epoch number used in the output file name.
            out_dir: output directory (created if missing).
            plot_gates: if True, also save one gate-weight map per expert.
        """
        os.makedirs(out_dir, exist_ok=True)

        TT, XX, U, U_true, G, error = self.predict_on_grid()
        print("U_true min/max:", U_true.min(), U_true.max())

        print(f"  [viz @ ep {ep}] L2 Rel Error: {error:.4e}. Saving plots...")

        fig = plt.figure(figsize=(15, 4.5))

        plt.subplot(1, 3, 1)
        plt.pcolormesh(TT, XX, U_true, cmap='rainbow', shading='auto')
        plt.colorbar(); plt.xlabel('t'); plt.ylabel('x'); plt.title(f'True u')

        plt.subplot(1, 3, 2)
        plt.pcolormesh(TT, XX, U, cmap='rainbow', shading='auto')
        plt.colorbar(); plt.xlabel('t'); plt.ylabel('x'); plt.title(f'Predicted u')

        # Pointwise |u_pred - u_true| (previously mislabelled as L2 relative error).
        plt.subplot(1, 3, 3)
        plt.pcolormesh(TT, XX, np.abs(U - U_true), cmap='jet', shading='auto')
        plt.colorbar(); plt.xlabel('t'); plt.ylabel('x'); plt.title('Absolute Error')

        plt.tight_layout()
        png = os.path.join(out_dir, f"ep{ep:05d}.png")
        plt.savefig(png, dpi=300)
        plt.close(fig)

        if plot_gates:
            # Domain-decomposition view: one gate-weight map per expert.
            for i in range(G.shape[-1]):
                fig = plt.figure(figsize=(6, 5))
                plt.pcolormesh(TT, XX, G[..., i], cmap='hot', shading='auto', vmin=0, vmax=1)
                plt.colorbar()
                plt.xlabel('t'); plt.ylabel('x')
                plt.title(f'Expert {i+1} Gate Weight')
                plt.tight_layout()
                plt.savefig(os.path.join(out_dir, f"ep{ep:05d}_gate{i+1}.png"), dpi=300)
                plt.close(fig)

        print(f"  [viz] Saved plots!")

if __name__ == '__main__':
    num_experts = 3
    expert_hidden = 64
    expert_rank = 16
    router_hidden = 64
    router_depth  = 4

    model = PINN_Transport(
        c=C_SPEED,
        num_experts=num_experts, expert_hidden=expert_hidden, expert_rank=expert_rank,
        router_hidden=router_hidden, router_depth=router_depth
    )

    # PINN_Transport.__init__ already moves the DomainMoE to `device`, so the
    # duplicated .to(device) calls that used to be here were redundant.
    total_trainable = sum(p.numel() for p in model.pinn.parameters() if p.requires_grad)
    total_all       = sum(p.numel() for p in model.pinn.parameters())
    print(f"Trainable: {total_trainable:,}  |  All: {total_all:,}")

    # perf_counter is the monotonic clock intended for measuring durations.
    start = time.perf_counter()
    model.train(n_epochs_adam=15000, viz_every=5000)
    print(f"Total training time: {time.perf_counter()-start:.2f}s")
    model.visualize(ep=99999)  # Final visualization
