from google.colab import drive

import os

# Mount Google Drive so experiment artifacts outlive the Colab runtime.
_DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(_DRIVE_MOUNT_POINT)

# Destination folder on Drive for this experiment; created up front
# (exist_ok makes this a no-op when the folder is already there).
DRIVE_OUTPUT = "/content/drive/My Drive/MSN_Paper13_Discovery_v5/exp5_singular_forcing"
os.makedirs(name=DRIVE_OUTPUT, exist_ok=True)

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import json
import os
from pathlib import Path


# Seed the PyTorch and NumPy RNGs so parameter initialisation is reproducible.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(42)

# Prefer the GPU when one is visible; the training code moves the model and
# collocation points to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


# Exponent of the singular forcing f(x) = x**BETA in -u''(x) = f(x) on (0, 1).
BETA = -0.5
# Integrating u'' = -x**BETA twice (for BETA not in {-1, -2}) produces a term
# proportional to x**(BETA + 2); that is the exponent the model should discover
# (1.5 for BETA = -0.5). Deriving it keeps the two constants consistent.
# NOTE: true_solution() hard-codes the closed form for BETA = -1/2 only.
TARGET_ALPHA = BETA + 2


def true_solution(x):
    """Closed-form solution of -u'' = x**(-1/2) with u(0) = u(1) = 0.

    u(x) = (4/3) * x - (4/3) * x**(3/2). Works element-wise on Python floats
    and tensors. Only valid for the forcing exponent -1/2.
    """
    scale = 4 / 3
    linear_part = scale * x
    singular_part = scale * x ** (3 / 2)
    return linear_part - singular_part


def forcing_term(x, beta=None):
    """Singular forcing f(x) = x**beta of the model problem -u'' = f.

    Args:
        x: Point(s) at which to evaluate (float or tensor). For beta < 0 the
            forcing is singular at x = 0, so pass strictly positive x.
        beta: Forcing exponent. ``None`` (the default) uses the module-level
            ``BETA``, so ``forcing_term(x)`` behaves exactly as before; the
            parameter mirrors ``pde_residual(x, beta=...)``.

    Returns:
        ``x ** beta`` with the same type/shape as ``x``.
    """
    if beta is None:
        beta = BETA
    return x**beta


# Sanity-check the analytic solution, then plot it next to the singular forcing.
x_test = torch.linspace(0.01, 1.0, 100)
u_test = true_solution(x_test)

print(f"u(0.01) = {true_solution(torch.tensor(0.01)):.6f} (should be small)")
print(f"u(1) = {true_solution(torch.tensor(1.0)):.6f} (should be 0)")

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
solution_panel, forcing_panel = axes

# Left panel: u(x).
solution_panel.plot(x_test.numpy(), u_test.numpy(), "b-", linewidth=2)
solution_panel.set_xlabel("x")
solution_panel.set_ylabel("u(x)")
solution_panel.set_title("True Solution: $u(x) = \\frac{4}{3}x(1 - \\sqrt{x})$")

# Right panel: f(x), clipped to [0, 10] because it diverges as x -> 0.
x_forcing = torch.linspace(0.01, 1.0, 100)
forcing_panel.plot(x_forcing.numpy(), forcing_term(x_forcing).numpy(), "r-", linewidth=2)
forcing_panel.set_xlabel("x")
forcing_panel.set_ylabel("f(x)")
forcing_panel.set_title("Singular Forcing: $f(x) = x^{-1/2}$")
forcing_panel.set_ylim([0, 10])

for _panel in axes:
    _panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


class MSN_PINN_SingularForcing(nn.Module):
    """Power-series ansatz u(x) = sum_k c_k * x**mu_k with trainable exponents.

    The exponents are kept inside (mu_min, mu_max) by a sigmoid
    reparameterisation of the unconstrained parameter ``mu_raw``; the
    coefficients ``coeffs`` are unconstrained. u'' is computed analytically,
    so the PDE residual does not need autograd with respect to x.

    Args:
        K: Number of power terms.
        mu_init: Optional initial exponents: a tensor, array or sequence of
            exactly K numbers. Defaults to K values evenly spaced in [0.8, 2.2].
        mu_min: Lower bound of the exponent range.
        mu_max: Upper bound of the exponent range (must exceed mu_min).

    Raises:
        ValueError: If ``mu_max <= mu_min`` or ``mu_init`` does not hold
            exactly K values.
    """

    def __init__(self, K=4, mu_init=None, mu_min=0.5, mu_max=3.0):
        super().__init__()
        if mu_max <= mu_min:
            raise ValueError(f"mu_max ({mu_max}) must exceed mu_min ({mu_min})")
        self.K = K
        self.mu_min = mu_min
        self.mu_max = mu_max

        if mu_init is None:
            # Evenly spaced starting exponents over [0.8, 2.2].
            mu_init = torch.linspace(0.8, 2.2, K)

        # Accept lists/arrays as well as tensors, in the default float dtype so
        # the raw parameter matches ``coeffs``.
        mu_init = torch.as_tensor(mu_init, dtype=torch.get_default_dtype()).reshape(-1)
        if mu_init.numel() != K:
            raise ValueError(
                f"mu_init must contain K={K} exponents, got {mu_init.numel()}"
            )

        self.mu_raw = nn.Parameter(torch.zeros(K))

        with torch.no_grad():
            # Invert the sigmoid map of exponents(); clamping away from 0 and 1
            # keeps the logit finite for initial values at or outside the bounds.
            mu_normalized = (mu_init - mu_min) / (mu_max - mu_min)
            mu_normalized = torch.clamp(mu_normalized, 0.01, 0.99)
            # copy_ keeps mu_raw's shape/dtype/device instead of replacing .data.
            self.mu_raw.copy_(torch.log(mu_normalized / (1 - mu_normalized)))

        self.coeffs = nn.Parameter(torch.randn(K) * 0.1)

    def exponents(self):
        """Return the K exponents, mapped into (mu_min, mu_max) by a sigmoid."""
        return self.mu_min + (self.mu_max - self.mu_min) * torch.sigmoid(self.mu_raw)

    def forward(self, x):
        """Evaluate u(x) = sum_k c_k * x**mu_k.

        ``x`` is flattened; the result has shape (N,). Intended for x in
        (0, 1]: negative x yields NaN for non-integer exponents.
        """
        x = x.reshape(-1)
        mu = self.exponents()

        # (N, 1) ** (1, K) -> (N, K) matrix of powers.
        x_powers = x.unsqueeze(-1) ** mu.unsqueeze(0)

        u = (x_powers * self.coeffs.unsqueeze(0)).sum(dim=-1)
        return u

    def second_derivative(self, x):
        """Analytic u''(x) = sum_k c_k * mu_k * (mu_k - 1) * x**(mu_k - 2).

        ``x`` is flattened; the result has shape (N,). x**(mu_k - 2) diverges
        at x = 0 for mu_k < 2, so evaluate only at x > 0.
        """
        x = x.reshape(-1)
        mu = self.exponents()

        d2_coeffs = self.coeffs * mu * (mu - 1)

        x_powers = x.unsqueeze(-1) ** (mu.unsqueeze(0) - 2)

        u_xx = (x_powers * d2_coeffs.unsqueeze(0)).sum(dim=-1)
        return u_xx

    def pde_residual(self, x, beta=BETA):
        """Residual of -u'' = x**beta, written as u'' + x**beta (zero when exact)."""
        u_xx = self.second_derivative(x)
        forcing = x.reshape(-1) ** beta
        return u_xx + forcing


# Smoke test of an untrained model: parameter values, output shapes and ranges.
model = MSN_PINN_SingularForcing(K=4)
initial_mu = model.exponents().detach().numpy()
initial_c = model.coeffs.detach().numpy()
print("Initial exponents:", initial_mu)
print("Initial coefficients:", initial_c)

x_test = torch.linspace(0.01, 1.0, 10)
u_pred = model(x_test)
res = model.pde_residual(x_test)

u_lo, u_hi = u_pred.min(), u_pred.max()
print(f"\nPrediction shape: {u_pred.shape}")
print(f"Prediction range: [{u_lo:.4f}, {u_hi:.4f}]")

res_lo, res_hi = res.min(), res.max()
print(f"\nResidual shape: {res.shape}")
print(f"Residual range: [{res_lo:.4f}, {res_hi:.4f}]")


def train_msn_pinn(
    model,
    n_epochs=10000,
    n_colloc=200,
    lr_mu=0.01,
    lr_c=0.01,
    w_res=1.0,
    w_bc=100.0,
    w_sparse=0.01,
    log_every=500,
):
    """Full-batch training of the power-series PINN for -u'' = x**BETA, u(0)=u(1)=0.

    Exponents (``model.mu_raw``) and coefficients (``model.coeffs``) get
    separate Adam optimisers so they can use different learning rates. Each
    step minimises

        w_res * mean(residual**2)
        + w_bc * (u(0.001)**2 + u(1)**2)
        + w_sparse * mean(|coeffs|)

    The L1 term pushes unused power terms towards zero.

    Args:
        model: Network to train. It is moved to ``device`` in place.
        n_epochs: Number of optimisation steps.
        n_colloc: Number of uniformly spaced collocation points in [0.01, 1].
        lr_mu: Adam learning rate for the exponents.
        lr_c: Adam learning rate for the coefficients.
        w_res: Weight of the residual loss.
        w_bc: Weight of the boundary loss.
        w_sparse: Weight of the L1 sparsity loss.
        log_every: Record/print every ``log_every`` epochs; the final epoch
            is always recorded.

    Returns:
        Dict of lists with one entry per logged epoch: "epoch", "loss",
        "loss_res", "loss_bc", "loss_sparse", "exponents", "coefficients" and
        "exp_error" (distance from the nearest exponent to TARGET_ALPHA).
    """
    # Move the model before building the optimisers, as the torch.optim docs
    # recommend, so the optimisers are built on the on-device parameters.
    model.to(device)
    model.train()

    opt_mu = torch.optim.Adam([model.mu_raw], lr=lr_mu)
    opt_c = torch.optim.Adam([model.coeffs], lr=lr_c)

    # Collocation points avoid x = 0, where u'' and the forcing diverge.
    x_colloc = torch.linspace(0.01, 1.0, n_colloc).to(device)

    # Left BC is imposed just inside the domain rather than at exactly 0.
    # NOTE(review): presumably to stay clear of the singular point; with
    # mu_min > 0 every power term vanishes at 0 anyway.
    x_bc_left = torch.tensor([0.001]).to(device)
    x_bc_right = torch.tensor([1.0]).to(device)

    history = {
        "epoch": [],
        "loss": [],
        "loss_res": [],
        "loss_bc": [],
        "loss_sparse": [],
        "exponents": [],
        "coefficients": [],
        "exp_error": [],
    }

    for epoch in range(n_epochs):
        opt_mu.zero_grad()
        opt_c.zero_grad()

        res = model.pde_residual(x_colloc)
        loss_res = (res**2).mean()

        u_left = model(x_bc_left)
        u_right = model(x_bc_right)
        loss_bc = (u_left**2).mean() + (u_right**2).mean()

        loss_sparse = torch.abs(model.coeffs).mean()

        loss = w_res * loss_res + w_bc * loss_bc + w_sparse * loss_sparse

        loss.backward()

        # Clip each parameter group separately so a large exponent gradient
        # does not shrink the coefficient update (and vice versa).
        torch.nn.utils.clip_grad_norm_([model.mu_raw], 1.0)
        torch.nn.utils.clip_grad_norm_([model.coeffs], 1.0)

        opt_mu.step()
        opt_c.step()

        if epoch % log_every == 0 or epoch == n_epochs - 1:
            mu = model.exponents().detach().cpu().numpy()
            c = model.coeffs.detach().cpu().numpy()

            # Plain Python float so the history is JSON-friendly as-is.
            exp_error = float(min(abs(m - TARGET_ALPHA) for m in mu))

            history["epoch"].append(epoch)
            history["loss"].append(loss.item())
            history["loss_res"].append(loss_res.item())
            history["loss_bc"].append(loss_bc.item())
            history["loss_sparse"].append(loss_sparse.item())
            history["exponents"].append(mu.tolist())
            history["coefficients"].append(c.tolist())
            history["exp_error"].append(exp_error)

            print(
                f"Epoch {epoch:5d} | Loss: {loss.item():.6f} | "
                f"Res: {loss_res.item():.6f} | BC: {loss_bc.item():.6f} | "
                f"Exp error: {exp_error:.4f}"
            )
            print(f"           Exponents: {[f'{m:.4f}' for m in mu]}")
            print(f"           Coeffs:    {[f'{c_:.4f}' for c_ in c]}")

    return history


# Fresh model for the real experiment; the smoke-test model above is replaced.
model = MSN_PINN_SingularForcing(K=4, mu_min=0.5, mu_max=3.0)

_rule = "=" * 60
print(_rule)
print("Exp5: Singular Forcing Discovery")
print(_rule)
print(f"PDE: -u''(x) = x^{BETA}")
print(f"Target exponent: {TARGET_ALPHA} (= {BETA} + 2)")
print(f"Initial exponents: {model.exponents().detach().numpy()}")
print(_rule)

# Hyper-parameters for this run.
train_config = dict(
    n_epochs=15000,
    n_colloc=200,
    lr_mu=0.005,
    lr_c=0.01,
    w_res=1.0,
    w_bc=100.0,
    w_sparse=0.001,
    log_every=1000,
)
history = train_msn_pinn(model, **train_config)


# Summarise the trained model: every (exponent, coefficient) pair, plus the
# relative error of the exponents nearest to the two analytic targets.
model.eval()
final_mu = model.exponents().detach().cpu().numpy()
final_c = model.coeffs.detach().cpu().numpy()

print("\n" + "=" * 60)
print("FINAL RESULTS")
print("=" * 60)
print(f"Target exponents: 1.0 (linear) and {TARGET_ALPHA} (from singular forcing)")
print("")
print("Discovered exponents:")
for i, (mu, c) in enumerate(zip(final_mu, final_c)):
    print(f"  mu[{i}] = {mu:.4f}, c[{i}] = {c:.4f}")


# The exact solution combines x**1 and x**TARGET_ALPHA; score each target by
# the nearest learned exponent.
# NOTE: the same exponent may be picked for both targets.
closest_to_1 = min(final_mu, key=lambda m: abs(m - 1.0))
closest_to_15 = min(final_mu, key=lambda m: abs(m - TARGET_ALPHA))

# Relative errors in percent (the target 1.0 needs no normalisation).
error_1 = abs(closest_to_1 - 1.0) * 100
error_15 = abs(closest_to_15 - TARGET_ALPHA) / TARGET_ALPHA * 100

print("")
print(f"Closest to 1.0: {closest_to_1:.4f} (error: {error_1:.2f}%)")
print(f"Closest to {TARGET_ALPHA}: {closest_to_15:.4f} (error: {error_15:.2f}%)")
print("=" * 60)


# 2x2 diagnostic figure: exponent trajectories, losses, solution vs. truth and
# pointwise error.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))


# Panel (0, 0): exponent trajectories over the logged epochs.
ax = axes[0, 0]
exp_history = np.array(history["exponents"])
if "epoch" in history:
    # Exact epoch numbers recorded during training.
    epochs = np.asarray(history["epoch"])
else:
    # Rebuild the logging schedule of the run above (n_epochs=15000,
    # log_every=1000): every log_every-th epoch plus the final epoch.
    _n_epochs, _log_every = 15000, 1000
    epochs = np.array(
        [e for e in range(_n_epochs) if e % _log_every == 0 or e == _n_epochs - 1]
    )

for k in range(exp_history.shape[1]):
    ax.plot(epochs, exp_history[:, k], "-", linewidth=2, label=f"mu[{k}]")

ax.axhline(y=1.0, color="green", linestyle="--", linewidth=2, label="Target: 1.0")
ax.axhline(
    y=TARGET_ALPHA,
    color="red",
    linestyle="--",
    linewidth=2,
    label=f"Target: {TARGET_ALPHA}",
)
ax.set_xlabel("Epoch", fontsize=12)
ax.set_ylabel("Exponent mu", fontsize=12)
ax.set_title("Exponent Discovery Trajectory", fontsize=14)
ax.legend(loc="best")
ax.grid(True, alpha=0.3)


# Panel (0, 1): loss components on a log scale.
ax = axes[0, 1]
ax.semilogy(epochs, history["loss"], "b-", linewidth=2, label="Total")
ax.semilogy(epochs, history["loss_res"], "r-", linewidth=1.5, label="Residual")
ax.semilogy(epochs, history["loss_bc"], "g-", linewidth=1.5, label="BC")
ax.set_xlabel("Epoch", fontsize=12)
ax.set_ylabel("Loss", fontsize=12)
ax.set_title("Training Loss", fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)


# Panel (1, 0): learned solution against the analytic one.
ax = axes[1, 0]
x_plot = torch.linspace(0.01, 1.0, 200)
with torch.no_grad():
    u_pred = model(x_plot.to(device)).cpu().numpy()
u_true = true_solution(x_plot).numpy()

ax.plot(x_plot.numpy(), u_true, "b-", linewidth=2, label="True solution")
ax.plot(x_plot.numpy(), u_pred, "r--", linewidth=2, label="MSN-PINN")
ax.set_xlabel("x", fontsize=12)
ax.set_ylabel("u(x)", fontsize=12)
ax.set_title("Solution Comparison", fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)


# Panel (1, 1): pointwise absolute error on a log scale.
ax = axes[1, 1]
error = np.abs(u_pred - u_true)
ax.semilogy(x_plot.numpy(), error, "k-", linewidth=2)
ax.set_xlabel("x", fontsize=12)
ax.set_ylabel("|u_pred - u_true|", fontsize=12)
ax.set_title("Absolute Error", fontsize=14)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("exp5_results.png", dpi=150, bbox_inches="tight")
plt.show()

print(f"\nSolution MSE: {np.mean(error**2):.6e}")
print(f"Solution max error: {np.max(error):.6e}")


# Bar chart of |coefficient| at each learned exponent, with the targets marked.
fig, ax = plt.subplots(figsize=(10, 6))

colors = plt.cm.tab10(np.linspace(0, 1, len(final_mu)))

for mu, c, color in zip(final_mu, final_c, colors):
    ax.bar(
        mu, abs(c), width=0.05, color=color, alpha=0.7, label=f"mu={mu:.3f}, c={c:.3f}"
    )

# Dashed vertical guides at the two analytic exponents.
for target, line_color, line_label in (
    (1.0, "green", "Target: 1.0"),
    (TARGET_ALPHA, "red", f"Target: {TARGET_ALPHA}"),
):
    ax.axvline(
        x=target,
        color=line_color,
        linestyle="--",
        linewidth=2,
        label=line_label,
    )

ax.set_xlabel("Exponent mu", fontsize=14)
ax.set_ylabel("|Coefficient|", fontsize=14)
ax.set_title("Discovered Exponent Spectrum", fontsize=16)
ax.legend(loc="upper right")
ax.grid(True, alpha=0.3)
ax.set_xlim([0.3, 3.2])

plt.tight_layout()
plt.savefig("exp5_spectrum.png", dpi=150, bbox_inches="tight")
plt.show()


# Evaluate the trained model's PDE residual on interior points and plot it.
x_verify = torch.linspace(0.02, 0.98, 100).to(device)

with torch.no_grad():
    residual = model.pde_residual(x_verify).cpu().numpy()

x_verify_np = x_verify.cpu().numpy()

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(x_verify_np, residual, "b-", linewidth=2)
ax.axhline(y=0, color="k", linestyle="--", linewidth=1)
ax.set_xlabel("x", fontsize=12)
ax.set_ylabel("Residual: u''(x) + x^beta", fontsize=12)
ax.set_title("PDE Residual (should be ~ 0)", fontsize=14)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("exp5_residual.png", dpi=150, bbox_inches="tight")
plt.show()

residual_mse = np.mean(residual**2)
residual_max = np.max(np.abs(residual))
print(f"Residual MSE: {residual_mse:.6e}")
print(f"Residual max: {residual_max:.6e}")


import json
from pathlib import Path


def convert_to_serializable(obj):
    """Recursively convert NumPy values into JSON-serialisable Python objects.

    - ``np.ndarray`` -> nested Python lists (via ``tolist``).
    - Any NumPy scalar (all float/int widths, ``np.bool_``, ...) -> the
      matching Python scalar (via ``item``).
    - dicts are rebuilt with converted values; lists and tuples become lists
      of converted items (``json`` emits tuples as arrays anyway).
    - Anything else is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {k: convert_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_serializable(i) for i in obj]
    return obj


# Persist the trained weights, the full training history and a run summary to
# a local folder (relative to the working directory).
output_dir = Path("exp5_singular_forcing")
output_dir.mkdir(exist_ok=True)

checkpoint = {
    "model_state_dict": model.state_dict(),
    "final_exponents": final_mu.tolist(),
    "final_coefficients": final_c.tolist(),
    "target_exponents": [1.0, TARGET_ALPHA],
    "forcing_exponent": BETA,
}
torch.save(checkpoint, output_dir / "model.pt")

history_serializable = convert_to_serializable(history)
with open(output_dir / "training_history.json", "w") as f:
    json.dump(history_serializable, f, indent=2)

# Last logged values of each loss component.
final_loss = float(history["loss"][-1])
final_residual_loss = float(history["loss_res"][-1])
final_bc_loss = float(history["loss_bc"][-1])

results = {
    "experiment": "Exp5: Singular Forcing Discovery",
    "pde": f"-u'' = x^{BETA}",
    "forcing_exponent": float(BETA),
    "target_exponents": [1.0, float(TARGET_ALPHA)],
    "discovered_exponents": list(map(float, final_mu)),
    "discovered_coefficients": list(map(float, final_c)),
    "closest_to_1.0": float(closest_to_1),
    "closest_to_1.5": float(closest_to_15),
    "error_1.0_percent": float(error_1),
    "error_1.5_percent": float(error_15),
    "final_loss": final_loss,
    "final_residual_loss": final_residual_loss,
    "final_bc_loss": final_bc_loss,
}

with open(output_dir / "results.json", "w") as f:
    json.dump(results, f, indent=2)

os.sync()
print(f"Results saved to {output_dir}/")


import shutil

# Mirror every local artifact to Drive: the regular files in output_dir, plus
# the figures that the plt.savefig calls wrote to the working directory (not
# to output_dir), which would otherwise never reach Drive.
_figure_names = ("exp5_results.png", "exp5_spectrum.png", "exp5_residual.png")
_artifacts = [p for p in output_dir.iterdir() if p.is_file()]
_artifacts += [Path(name) for name in _figure_names if Path(name).is_file()]

for f in _artifacts:
    shutil.copy2(f, DRIVE_OUTPUT)
    print(f"Copied {f.name} to Drive")


# Flush OS buffers so the Drive FUSE mount sees complete files.
os.sync()
print(f"\nAll files saved to: {DRIVE_OUTPUT}")
