import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.io
from scipy.interpolate import griddata
import matplotlib.pyplot as plt
import time
import math

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_available else "cpu")
print(f"Using device: {device}")

# Seed NumPy and PyTorch (CPU and, when present, every CUDA device) so runs start
# from the same random state; GPU kernels may still be nondeterministic.
np.random.seed(1234)
torch.manual_seed(1234)
if _cuda_available:
    torch.cuda.manual_seed_all(1234)

class MLP(nn.Module):
    """Fully connected tanh network producing one raw logit per expert.

    Structure: ``depth`` hidden blocks of ``Linear -> Tanh`` with width
    ``hidden``, followed by a final ``Linear(hidden, num_experts)``. With
    ``depth == 0`` the network is a single ``Linear(in_features, num_experts)``.
    Outputs are un-normalised; ``DomainMoE_shared`` applies the softmax.
    """

    def __init__(self, in_features=2, hidden=64, num_experts=3, depth=4):
        super().__init__()
        widths = [in_features] + [hidden] * depth
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.Tanh()))
        modules.append(nn.Linear(widths[-1], num_experts))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Return the logits for ``x`` (features on the last axis)."""
        return self.net(x)
        
class NAMLayer_shared(nn.Module):
    """Tanh MLP mapping a 2-feature input to ``r`` output features.

    Architecture: ``Linear(2, hidden) -> Tanh``, then ``max(depth - 1, 0)``
    repetitions of ``Linear(hidden, hidden) -> Tanh``, then a final
    ``Linear(hidden, r)``. Every Linear layer is re-initialised with
    Xavier-uniform weights and zero bias.
    """

    def __init__(self, hidden=64, depth=2, r=16):
        super().__init__()
        modules = [nn.Linear(2, hidden), nn.Tanh()]
        for _ in range(depth - 1):
            modules.extend((nn.Linear(hidden, hidden), nn.Tanh()))
        modules.append(nn.Linear(hidden, r))
        self.net = nn.Sequential(*modules)

        linear_layers = [mod for mod in self.net if isinstance(mod, nn.Linear)]
        for lin in linear_layers:
            nn.init.xavier_uniform_(lin.weight)
            nn.init.zeros_(lin.bias)

    def forward(self, t_or_x):
        """Apply the network to ``t_or_x`` (last axis of size 2)."""
        return self.net(t_or_x)
    
class NAMExpert_shared(nn.Module):
    """Separable rank-``r`` expert built on one shared coordinate network.

    The same ``NAMLayer_shared`` is applied to time and to space. A constant
    second input (0 for t, 1 for x) tells the network which coordinate it is
    seeing. The expert output is the inner product of the two r-dimensional
    feature vectors:  u(x, t) = sum_k g_k(t, 0) * g_k(x, 1).
    """

    def __init__(self, hidden=32, r=16):
        super().__init__()
        self.r = r
        self.shared = NAMLayer_shared(hidden, r=r)

    def _eval_dim(self, coord: torch.Tensor, dim_id: int) -> torch.Tensor:
        """Run the shared network on ``coord`` tagged with the constant ``dim_id``.

        As called from ``forward``, ``coord`` has a trailing singleton axis;
        a tag column of the same shape is appended on the last axis to form
        the 2-feature network input.
        """
        tag = torch.full_like(coord, float(dim_id))
        network_input = torch.cat((coord, tag), dim=-1)
        return self.shared(network_input)

    def forward(self, coords):
        """Evaluate the expert; ``coords`` has (x, t) on its last axis. Returns shape (..., 1)."""
        t_features = self._eval_dim(coords[..., 1:2], 0)
        x_features = self._eval_dim(coords[..., 0:1], 1)
        return torch.sum(t_features * x_features, dim=-1, keepdim=True)

    
class DomainMoE_shared(nn.Module):
    """Soft mixture of ``NAMExpert_shared`` experts gated by an ``MLP`` router.

    For each input point the router produces ``num_experts`` logits. A softmax
    over the last axis turns them into gate weights, and the output is the
    gate-weighted sum of the experts' scalar predictions.

    Args:
        in_features: number of input features seen by the router. The experts
            themselves read only the first two columns (x, t).
        num_experts: number of experts.
        r: rank passed to every expert.
    """

    def __init__(self, in_features=2, num_experts=3, r=16):
        super().__init__()
        self.num_experts = num_experts

        # Forward in_features so the argument actually takes effect; the default
        # (2) equals the router's own default, so default construction is unchanged.
        self.router = MLP(in_features=in_features, num_experts=self.num_experts)

        self.experts = nn.ModuleList(
            [NAMExpert_shared(r=r) for _ in range(self.num_experts)]
        )

    def forward(self, x):
        """Return ``(combined_output, gates)``.

        ``x`` carries coordinates on its last axis, with any leading batch shape.
        ``combined_output`` has shape ``(..., 1)`` and ``gates`` has shape
        ``(..., num_experts)``; the gates sum to 1 on the last axis.
        """
        logits = self.router(x)
        gates = F.softmax(logits, dim=-1)  # (..., E)

        # Each expert returns (..., 1); stacking on a new last axis gives (..., 1, E).
        expert_outputs_stack = torch.stack(
            [expert(x) for expert in self.experts], dim=-1
        )

        # unsqueeze(-2) aligns the gates as (..., 1, E). unsqueeze(1) only matches
        # for 2-D input; with a leading batch axis it would broadcast to (B, N, N, E).
        gated_outputs = expert_outputs_stack * gates.unsqueeze(-2)
        combined_output = gated_outputs.sum(dim=-1)

        return combined_output, gates
        
class PINN_Burger_shared:
    """Physics-informed trainer for the 1-D viscous Burgers equation.

    The network (``DomainMoE_shared``) is trained so that the residual
    ``u_t + u * u_x - nu * u_xx`` vanishes at collocation points, while
    matching known values of ``u`` on the initial line (t == 0) and on the
    boundaries (x == -1, x == +1). Training runs Adam with cosine LR
    annealing followed by L-BFGS, and records the relative L2 error on
    ``coords_eval`` along the way.

    Args:
        X_u_train: array of (x, t) points with known values; column 0 is x,
            column 1 is t.
        u_train: values of u at ``X_u_train``; expected to be shaped (N_u, 1)
            like the model output.
        X_f_train: array of (x, t) collocation points for the PDE residual.
        lb, ub: per-column lower/upper bounds used to rescale inputs to [-1, 1].
        nu: viscosity coefficient in the residual.
        num_experts: number of experts in the mixture.
        r: rank passed to every expert; also used for the history label.
        coords_eval: (x, t) points used for periodic error logging.
        u_star: reference solution at ``coords_eval``.
    """

    def __init__(self, X_u_train, u_train, X_f_train, lb, ub, nu, num_experts, r, coords_eval, u_star):
        self.lb = torch.tensor(lb, dtype=torch.float32, device=device)
        self.ub = torch.tensor(ub, dtype=torch.float32, device=device)

        # Tensors are created directly on `device` so they are leaf tensors. With
        # `torch.tensor(..., requires_grad=True).to(device)` a CUDA run gets a
        # non-leaf copy, and every backward pass also flows into a hidden CPU leaf.
        # Only the collocation points need input gradients (net_f differentiates
        # u w.r.t. x and t); the IC/BC points enter the loss through u alone.
        self.X_u_train = torch.tensor(X_u_train, dtype=torch.float32, device=device)
        self.u_train = torch.tensor(u_train, dtype=torch.float32, device=device)
        self.X_f_train = torch.tensor(X_f_train, dtype=torch.float32, device=device, requires_grad=True)

        # Split the IC/BC points by exact coordinate value: t == 0 is the initial
        # condition, x == -1 / x == +1 are the left / right boundaries. Corner
        # points (e.g. x = -1, t = 0) fall into more than one mask.
        xcol = self.X_u_train[:, 0]
        tcol = self.X_u_train[:, 1]
        self.mask_ic  = (tcol == 0.0)
        self.mask_bcl = (xcol == -1.0)
        self.mask_bcr = (xcol ==  1.0)

        # Loss weights; the initial-condition term is weighted 100x.
        self.w_ic  = 100.0
        self.w_bcl = 1.0
        self.w_bcr = 1.0
        self.w_f   = 1.0

        self.nu = nu
        self.r = r  # Store r value
        self.pinn = DomainMoE_shared(num_experts=num_experts, r=self.r).to(device)

        self.optimizer_adam = torch.optim.Adam(self.pinn.parameters(), lr=2e-3, weight_decay=1e-6)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_adam, T_max=20000, eta_min=1e-6)

        self.optimizer_lbfgs = torch.optim.LBFGS(
            self.pinn.parameters(), max_iter=50000, max_eval=None,
            tolerance_grad=1e-7, tolerance_change=1.0 * np.finfo(float).eps,
            history_size=50, line_search_fn="strong_wolfe"
        )
        self.iter = 0
        self.err_steps = []   # global step at which each error was logged
        self.err_values = []  # relative L2 error on coords_eval at that step
        self._global_step = 0

        self.coords_eval = coords_eval
        self.u_star = u_star

    def _normalize(self, X):
        """Affinely map raw (x, t) columns from [lb, ub] to [-1, 1]."""
        return 2.0 * (X - self.lb) / (self.ub - self.lb) - 1.0

    def _predict_u(self, X_star):
        """Return the predicted u at raw (x, t) points as a NumPy array (no residual)."""
        self.pinn.eval()
        with torch.no_grad():
            X_tensor = torch.tensor(X_star, dtype=torch.float32, device=device)
            u_pred, _ = self.pinn(self._normalize(X_tensor))
        return u_pred.cpu().numpy()

    def _rel_l2_on(self, coords, u_star):
        """Return ||u_star - u_pred||_2 / ||u_star||_2 on ``coords``.

        Only u is evaluated; the PDE residual that ``predict`` also returns is
        not needed here. Both arrays are flattened, so a 1-D ``u_star`` cannot
        broadcast against the (N, 1) prediction into an (N, N) matrix.
        """
        u_pred = self._predict_u(coords).reshape(-1)
        u_ref = np.asarray(u_star).reshape(-1)
        return float(np.linalg.norm(u_ref - u_pred) / np.linalg.norm(u_ref))

    def _log_error(self):
        """Record and return the current error on ``coords_eval`` at ``_global_step``."""
        err = self._rel_l2_on(self.coords_eval, self.u_star)
        self.err_steps.append(self._global_step)
        self.err_values.append(err)
        return err

    def net_u(self, x, t):
        """Predict u at column tensors ``x`` and ``t`` (raw, un-normalised coordinates)."""
        X = torch.cat([x, t], dim=1)
        u, _ = self.pinn(self._normalize(X))
        return u

    def net_f(self, x, t):
        """Return the Burgers residual u_t + u * u_x - nu * u_xx at (x, t).

        ``x`` and ``t`` must be part of a graph that requires grad, because the
        derivatives are taken with respect to them. ``create_graph=True`` keeps
        the result differentiable so the residual loss can be back-propagated.
        """
        u = self.net_u(x, t)
        u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), retain_graph=True, create_graph=True)[0]
        u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), retain_graph=True, create_graph=True)[0]
        u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), retain_graph=True, create_graph=True)[0]
        f = u_t + u * u_x - self.nu * u_xx
        return f

    def loss_func(self):
        """Weighted sum of IC, left-BC, right-BC MSE and mean squared PDE residual."""
        self.pinn.train()
        u_pred = self.net_u(self.X_u_train[:, 0:1], self.X_u_train[:, 1:2])
        f_pred = self.net_f(self.X_f_train[:, 0:1], self.X_f_train[:, 1:2])
        loss_f = torch.mean(f_pred ** 2)

        def safe_mse(mask):
            # An empty subset contributes zero instead of the NaN that mean() of nothing gives.
            if mask.sum() == 0:
                return torch.tensor(0.0, device=device)
            diff = (self.u_train[mask] - u_pred[mask])**2
            return diff.mean()

        loss_ic  = safe_mse(self.mask_ic)
        loss_bcl = safe_mse(self.mask_bcl)
        loss_bcr = safe_mse(self.mask_bcr)

        total_loss = self.w_ic*loss_ic + self.w_bcl*loss_bcl + self.w_bcr*loss_bcr + self.w_f*loss_f
        return total_loss

    def train(self, n_epochs_adam, X_star, u_star):
        """Run ``n_epochs_adam`` Adam steps, then one L-BFGS ``step`` call.

        Prints progress, logging the error on ``coords_eval`` every 100 steps.

        Returns:
            dict with keys ``"label"`` (``"r = <r>"``), ``"steps"`` and
            ``"errors"`` (copies of the logged history).
        """
        print("--- Starting Adam Optimization ---")
        self._global_step = 0
        for epoch in range(n_epochs_adam):
            self.pinn.train()
            self.optimizer_adam.zero_grad()
            loss = self.loss_func()
            loss.backward()
            self.optimizer_adam.step()
            self.scheduler.step()
            self._global_step += 1
            if epoch % 100 == 0:
                self._log_error()

            if (epoch + 1) % 1000 == 0:
                lr = self.optimizer_adam.param_groups[0]['lr']
                error_u = self._rel_l2_on(X_star, u_star)
                print(f'Epoch {epoch:05d} => Loss: {loss.item():.4e}, L2 Error: {error_u:.4e}, LR: {lr:.4e}')

        print("\n--- Starting L-BFGS Optimization ---")
        self.lbfgs_iter = 0
        self.pinn.train()

        def closure():
            # L-BFGS may evaluate the closure several times per iteration (line
            # search), so lbfgs_iter / _global_step count loss evaluations.
            self.optimizer_lbfgs.zero_grad()
            loss = self.loss_func()
            loss.backward()
            self.lbfgs_iter += 1
            self._global_step += 1
            if self.lbfgs_iter % 100 == 0:
                self._log_error()
                print(f'Iter {self.lbfgs_iter:05d} => Loss: {loss.item():.4e}')
            return loss

        self.optimizer_lbfgs.step(closure)
        error_u = self._rel_l2_on(X_star, u_star)
        print(f'Final L2 Error: {error_u:.4e}')

        # Use self.r to create a dynamic label for the plot legend
        return {"label": f"r = {self.r}", "steps": self.err_steps[:], "errors": self.err_values[:]}

    def predict(self, X_star):
        """Evaluate the model on raw (x, t) points.

        Returns:
            ``(u, f, gates)`` as NumPy arrays: the predicted solution, the PDE
            residual at the same points, and the router gate weights.
        """
        self.pinn.eval()
        X_star_tensor = torch.tensor(X_star, dtype=torch.float32, device=device)
        with torch.no_grad():
            u_star, gates_star = self.pinn(self._normalize(X_star_tensor))

        # The residual needs derivatives w.r.t. the inputs, so it is computed outside no_grad.
        X_star_tensor.requires_grad_(True)
        f_star = self.net_f(X_star_tensor[:, 0:1], X_star_tensor[:, 1:2])

        return u_star.detach().cpu().numpy(), f_star.detach().cpu().numpy(), gates_star.detach().cpu().numpy()

def plot_r_comparison(histories, save_path="r_comparison_error_vs_steps.png", logy=True,
                      fs_label=14, fs_legend=14, fs_ticks=12):
    """Plot relative L2 error versus training step for several runs on one figure.

    Each history is a dict with keys ``"label"``, ``"steps"`` and ``"errors"``.
    Every curve is truncated to the length of the shortest history, so all
    lines cover the same number of logged points. The figure is saved to
    ``save_path`` and then shown.

    Args:
        histories: list of history dicts. If empty, a message is printed and
            nothing is plotted.
        save_path: output image path.
        logy: use a logarithmic y axis when True.
        fs_label, fs_legend, fs_ticks: font sizes for axis labels, legend and
            tick labels.
    """
    if not histories:
        print("Error: 'histories' list is empty. Nothing to plot.")
        return

    shortest = min(len(run["steps"]) for run in histories)
    print(f"Plotting all runs up to the {shortest}-th recorded data point.")

    plt.figure(figsize=(8, 6))
    # One evenly spaced viridis color per run.
    line_colors = plt.cm.viridis(np.linspace(0, 1, len(histories)))

    for run, line_color in zip(histories, line_colors):
        plt.plot(run['steps'][:shortest], run["errors"][:shortest],
                 linewidth=2.5, color=line_color, label=run["label"])

    if logy:
        plt.yscale("log")

    label_kw = dict(fontsize=fs_label, labelpad=6)
    plt.xlabel("Training steps", **label_kw)
    plt.ylabel("Relative $L_2$ error", **label_kw)
    plt.tick_params(axis='both', which='both', labelsize=fs_ticks)
    plt.grid(True, which="both", linestyle="--", alpha=0.6)
    # The legend is titled "Rank" because each run is labelled by its r value.
    plt.legend(fontsize=fs_legend, frameon=True, framealpha=0.8, title="Rank")

    plt.tight_layout()
    plt.savefig(save_path, dpi=220)
    print(f"Saved comparison figure -> {save_path}")
    plt.show()

if __name__ == '__main__':
    # Problem / training configuration.
    nu = 0.01 / np.pi       # viscosity
    N_u = 900               # cap on the number of pooled IC/BC training points
    N_f = 10000             # number of collocation points
    n_epochs_adam = 10000
    num_experts = 2

    r_values = [1, 4, 8, 16]
    all_histories = []

    # --- Reference data, shared by every run ---
    data = scipy.io.loadmat('burgers_shock.mat')
    t = data['t'].flatten()[:, None]
    x = data['x'].flatten()[:, None]
    Exact = np.real(data['usol']).T
    X, T = np.meshgrid(x, t)
    X_star = np.column_stack((X.ravel(), T.ravel()))
    u_star = Exact.reshape(-1, 1)
    lb = X_star.min(axis=0)
    ub = X_star.max(axis=0)

    # Initial condition u(x, 0) = -sin(pi x); boundaries u(-1, t) = u(+1, t) = 0.
    X_ic = np.hstack((x, np.zeros_like(x)))
    u_ic = -np.sin(np.pi * x)
    X_bcl = np.hstack((-np.ones_like(t), t))
    u_bcl = np.zeros_like(t)
    X_bcr = np.hstack((np.ones_like(t), t))
    u_bcr = np.zeros_like(t)

    # Draw up to 300 points from each source without replacement
    # (the RNG is consumed in IC, left-BC, right-BC order).
    picked_X, picked_u = [], []
    for X_src, u_src in ((X_ic, u_ic), (X_bcl, u_bcl), (X_bcr, u_bcr)):
        n_src = X_src.shape[0]
        sel = np.random.choice(n_src, min(300, n_src), replace=False)
        picked_X.append(X_src[sel])
        picked_u.append(u_src[sel])
    X_u_train = np.vstack(picked_X)
    u_train = np.vstack(picked_u)

    # Shuffle (and, if needed, subsample) the pooled IC/BC points.
    N_u = min(N_u, X_u_train.shape[0])
    perm = np.random.choice(X_u_train.shape[0], N_u, replace=False)
    X_u_train, u_train = X_u_train[perm, :], u_train[perm, :]
    # Collocation points drawn uniformly inside the [lb, ub] box.
    X_f_train = lb + (ub - lb) * np.random.rand(N_f, 2)

    # --- Train one model per rank r and collect the error histories ---
    banner = '=' * 50
    for r_val in r_values:
        print(f"\n{banner}")
        print(f"           TRAINING MODEL WITH r = {r_val}           ")
        print(f"{banner}\n")

        model = PINN_Burger_shared(
            X_u_train, u_train, X_f_train, lb, ub, nu,
            num_experts=num_experts, r=r_val,
            coords_eval=X_star, u_star=u_star,
        )
        model.pinn.to(device)

        all_histories.append(model.train(n_epochs_adam, X_star, u_star))

    # --- Plot the comparison ---
    plot_r_comparison(all_histories, save_path="r_comparison_error_vs_steps.png", logy=True)
