#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations

import math
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from scipy.integrate import odeint

# -------------------------
# Global Config
# -------------------------
# Fixed global seeds for reproducibility.
np.random.seed(42)
torch.manual_seed(42)

# Physics parameters
GAMMA = 1.4        # exponent in the pressure law p(v) = EPS / v**GAMMA
EPS = 1.0          # pressure coefficient
MU = 0.05          # viscosity
V_PLUS = 2.0       # state v to the right of the shock
V_MINUS = 1.0      # state v to the left of the shock
L_DOMAIN = 10.0    # x in [-L, L]
T_DOMAIN = 5.0     # t in [0, T]

# Random-feature network (RaNN) parameters
USE_FOURIER_FEATURES = True
FEATURE_SCALE = 3.5
INPUT_SCALES = [1.0, 0.1]  # per-coordinate input multipliers for [t, x]

# Width-scaling study
N_TRAIN_PER_FEATURE = 30               # training points per random feature
BATCH_SIZE = 5000                      # rows per chunk in batched solve/predict
N_TEST_GLOBAL = 20000                  # size of the shared test set
WIDTH_LIST = list(range(10, 251, 10))  # widths 10, 20, ..., 250
RUNS_PER_WIDTH = 3                     # independent fits averaged per width

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")
# NOTE: "high" lets float32 matmuls use faster reduced-precision paths
# (e.g. TF32) where the hardware supports them; use "highest" if the ridge
# solve ever looks precision-limited.
torch.set_float32_matmul_precision("high")

# -------------------------
# Utilities
# -------------------------
def _to_tensor(x, device=DEVICE, dtype=torch.float32):
    """Return ``x`` as a tensor on ``device`` with ``dtype``.

    Existing tensors are moved/cast with ``Tensor.to`` (which returns the
    same tensor when nothing changes); any other input (lists, scalars,
    NumPy arrays) is copied into a fresh tensor with ``torch.tensor``.
    """
    if not isinstance(x, torch.Tensor):
        return torch.tensor(x, device=device, dtype=dtype)
    return x.to(device=device, dtype=dtype)

# -------------------------
# RaNN Definition
# -------------------------
class RFNNnD(nn.Module):
    """
    Random-feature network with a frozen (non-trainable) hidden layer.

    The hidden weights and biases are sampled once at construction and
    stored as ``requires_grad=False`` parameters; only a linear readout on
    top of ``_feature_map`` is meant to be fitted (e.g. by ridge regression).

    Feature families, with z = xt * input_scales and W ~ feature_scale * N(0, 1):

    * Fourier (``use_fourier=True``), M = 2 * rf features:
      phi(z) = rf**-0.5 * [cos(W z + b), sin(W z + b)],  b ~ U[0, 2*pi).
    * tanh (``use_fourier=False``), M features:
      phi(z) = tanh(W z + b),  b ~ U[0, 1).

    Args:
        input_dim: number of input coordinates per sample.
        num_random_features: total feature count M; must be positive, and
            even when Fourier features are used.
        feature_scale: multiplier applied to the standard-normal weights.
        use_fourier: choose Fourier (cos/sin) or tanh features.
        input_scales: per-coordinate input multipliers (``input_dim``
            entries); defaults to all ones.
        seed: passed to ``torch.manual_seed``. NOTE: this reseeds the
            *global* torch RNG, so random draws made after construction are
            affected too (kept for reproducibility of existing runs).

    Raises:
        ValueError: for a non-positive feature count, an odd feature count
            with Fourier features, or an ``input_scales`` length mismatch.
    """

    def __init__(self, input_dim: int, num_random_features: int,
                 feature_scale: float = 1.0, *, use_fourier: bool = True,
                 input_scales: list[float] | None = None, seed: int = 12345):
        super().__init__()
        torch.manual_seed(seed)
        self.input_dim = input_dim
        self.num_random_features = int(num_random_features)
        self.use_fourier = bool(use_fourier)

        # Explicit checks instead of `assert` (stripped under `python -O`):
        # an odd Fourier width would otherwise silently yield M-1 features.
        if self.num_random_features <= 0:
            raise ValueError(
                f"num_random_features must be positive, got {num_random_features}")
        if self.use_fourier and self.num_random_features % 2 != 0:
            raise ValueError(
                "num_random_features must be even with Fourier features, "
                f"got {self.num_random_features}")

        if input_scales is None:
            input_scales = [1.0] * self.input_dim
        if len(input_scales) != self.input_dim:
            raise ValueError(
                f"input_scales has {len(input_scales)} entries, "
                f"expected input_dim={self.input_dim}")
        self.register_buffer("input_scales", _to_tensor(input_scales).view(1, -1))

        if self.use_fourier:
            rf = self.num_random_features // 2
            # Weights are drawn before biases (RNG order matters for reproducibility).
            self.random_weights = torch.randn(rf, self.input_dim, device=DEVICE) * feature_scale
            self.random_biases = torch.rand(rf, device=DEVICE) * 2 * math.pi
        else:
            self.random_weights = torch.randn(self.num_random_features, self.input_dim, device=DEVICE) * feature_scale
            self.random_biases = torch.rand(self.num_random_features, device=DEVICE)

        # Frozen parameters: they follow .to(device) but are never trained.
        self.random_weights = nn.Parameter(self.random_weights, requires_grad=False)
        self.random_biases = nn.Parameter(self.random_biases, requires_grad=False)

    def _feature_map(self, xt: torch.Tensor) -> torch.Tensor:
        """Map inputs ``[batch, input_dim]`` to features ``[batch, num_random_features]``."""
        x_scaled = xt * self.input_scales
        phase = x_scaled @ self.random_weights.t() + self.random_biases  # [batch, rf]

        if self.use_fourier:
            # 1/sqrt(rf) keeps the feature-vector norm independent of width.
            scale = (self.num_random_features // 2) ** (-0.5)
            return scale * torch.cat([torch.cos(phase), torch.sin(phase)], dim=1)
        return torch.tanh(phase)

    def forward(self, xt: torch.Tensor) -> torch.Tensor:
        """Alias for ``_feature_map`` so the module is callable as ``model(xt)``."""
        return self._feature_map(xt)

# -------------------------
# Physics (Navier-Stokes ODE Ground Truth)
# -------------------------
class NSTravelingWave:
    """
    Viscous traveling-shock profile used as ground truth.

    Assumed model: the 1D compressible Navier-Stokes p-system in Lagrangian
    form

        v_t - u_x = 0,
        u_t + p(v)_x = MU * (u_x / v)_x,      p(v) = EPS / v**GAMMA.

    The wave depends only on xi = x - s*t and connects (v_minus, u_minus)
    at xi -> -inf to (v_plus, u_plus = 0) at xi -> +inf. The mass equation
    gives u = C - s*v. Integrating the momentum equation once from
    xi = +inf gives

        dv/dxi = -(v / (MU * s)) * (p(v) - p_plus + s**2 * (v - v_plus)).

    The bracket vanishes at both end states (Rankine-Hugoniot) and, since p
    is convex for EPS, GAMMA > 0, it is negative between them. The leading
    minus sign therefore makes v rise monotonically from v_minus to v_plus.

    The profile is integrated once with scipy's ``odeint``, shifted so that
    v = (v_minus + v_plus) / 2 at xi = 0, and then interpolated linearly on
    DEVICE.

    Raises:
        RuntimeError: if the integrated profile never crosses the mid-value,
            i.e. no shock transition was resolved on the grid.
    """

    def __init__(self):
        # 1. Shock parameters from the Rankine-Hugoniot conditions.
        self.v_plus = V_PLUS
        self.v_minus = V_MINUS

        def p(v):
            return EPS / (v**GAMMA)

        p_plus = p(self.v_plus)
        p_minus = p(self.v_minus)

        # Shock speed: s^2 = -[p] / [v].
        self.s2 = - (p_plus - p_minus) / (self.v_plus - self.v_minus)
        self.s = math.sqrt(self.s2)

        # s*v + u is constant across the wave: u = C - s*v.
        self.u_plus = 0.0
        C = self.s * self.v_plus + self.u_plus
        self.u_minus = C - self.s * self.v_minus

        # 2. Integrate the profile ODE left-to-right, starting just above v_minus.
        def deriv(v, xi):
            if v <= 0:
                return 0.0
            term_p = p(v) - p_plus
            term_inert = self.s2 * (v - self.v_plus)
            # The bracket is < 0 strictly between the end states, so the minus
            # sign makes v grow away from v_minus toward v_plus. With a '+'
            # sign v_minus is attracting and the "profile" would stay flat.
            return -(v / (MU * self.s)) * (term_p + term_inert)

        v_start = self.v_minus + 1e-4
        xi_span = np.linspace(-20, 20, 5000)

        sol = odeint(deriv, v_start, xi_span)
        self.xi_grid = xi_span
        self.v_grid = sol.flatten()

        # Center the shock at xi = 0, where v is the mean of the end states.
        v_mid = 0.5 * (self.v_minus + self.v_plus)
        if not (self.v_grid.min() < v_mid < self.v_grid.max()):
            raise RuntimeError(
                "Traveling-wave integration did not resolve the shock: "
                f"v stayed in [{self.v_grid.min():.6g}, {self.v_grid.max():.6g}], "
                f"never crossing v_mid={v_mid:.6g}")
        idx_mid = np.argmin(np.abs(self.v_grid - v_mid))
        self.xi_grid = self.xi_grid - self.xi_grid[idx_mid]

        self.u_grid = C - self.s * self.v_grid

        # Tabulated profile on DEVICE for fast interpolation.
        self.xi_t = torch.tensor(self.xi_grid, device=DEVICE, dtype=torch.float32)
        self.v_t = torch.tensor(self.v_grid, device=DEVICE, dtype=torch.float32)
        self.u_t = torch.tensor(self.u_grid, device=DEVICE, dtype=torch.float32)
        self.C_const = C

    def get_solution(self, t, x):
        """
        Evaluate (v, u) at space-time points via xi = x - s*t.

        Args:
            t, x: tensors whose combination ``x - s*t`` has N elements
                (e.g. both shaped [N, 1]); expected on the same device as
                the tabulated grid (DEVICE).

        Returns:
            Tensor [N, 2] with columns (v, u). Points left of the tabulated
            grid get (v_minus, u_minus); points right of it get
            (v_plus, u_plus).
        """
        # reshape(-1) rather than squeeze(): squeeze() turns a single query
        # point into a 0-d tensor, which breaks the final stack.
        xi_q = (x - self.s * t).reshape(-1)

        # Linear interpolation: bucketize finds the right neighbor index.
        idx = torch.bucketize(xi_q, self.xi_t)
        idx = torch.clamp(idx, 1, len(self.xi_t) - 1)

        x0 = self.xi_t[idx - 1]
        x1 = self.xi_t[idx]

        w = (xi_q - x0) / (x1 - x0 + 1e-8)
        w = torch.clamp(w, 0.0, 1.0)

        v_interp = (1 - w) * self.v_t[idx - 1] + w * self.v_t[idx]
        u_interp = (1 - w) * self.u_t[idx - 1] + w * self.u_t[idx]

        # Outside the tabulated range use the constant end states.
        mask_left = xi_q < self.xi_t[0]
        mask_right = xi_q > self.xi_t[-1]

        v_interp[mask_left] = self.v_minus
        u_interp[mask_left] = self.u_minus
        v_interp[mask_right] = self.v_plus
        u_interp[mask_right] = self.u_plus

        return torch.stack([v_interp, u_interp], dim=1)  # [N, 2]

# Build the ground-truth profile once at import time; every dataset reuses it.
PHYSICS: NSTravelingWave = NSTravelingWave()

def generate_dataset(n_samples: int):
    """
    Draw ``n_samples`` space-time points with their exact (v, u) values.

    ``n_samples // 2`` points are concentrated around the moving shock:
    t uniform in [0, T_DOMAIN], x = s*t plus N(0, 1) noise, clamped to
    [-L_DOMAIN, L_DOMAIN]. The remaining points are uniform over
    [0, T_DOMAIN] x [-L_DOMAIN, L_DOMAIN]. Rows are shuffled before
    returning.

    Returns:
        xt: [n_samples, 2] tensor with columns (t, x).
        vu: [n_samples, 2] tensor with columns (v, u).
    """
    n_near = n_samples // 2
    n_uniform = n_samples - n_near

    # Uniform background over the whole box (x is drawn before t).
    x_uniform = L_DOMAIN * (2 * torch.rand(n_uniform, 1, device=DEVICE) - 1)
    t_uniform = T_DOMAIN * torch.rand(n_uniform, 1, device=DEVICE)

    # Points clustered around the shock trajectory x = s*t.
    sigma = 1.0
    t_near = T_DOMAIN * torch.rand(n_near, 1, device=DEVICE)
    x_near = PHYSICS.s * t_near + sigma * torch.randn(n_near, 1, device=DEVICE)
    x_near = x_near.clamp(-L_DOMAIN, L_DOMAIN)

    t_all = torch.cat((t_uniform, t_near), dim=0)
    x_all = torch.cat((x_uniform, x_near), dim=0)

    targets = PHYSICS.get_solution(t_all, x_all)  # [N, 2]: (v, u)
    inputs = torch.cat((t_all, x_all), dim=1)     # [N, 2]: (t, x)

    order = torch.randperm(inputs.shape[0], device=DEVICE)
    return inputs[order], targets[order]

def rel_L2_error(vu_pred, vu_true):
    """Return ||pred - true|| / ||true|| over all entries (both fields pooled).

    A 1e-16 is added to the squared reference norm to avoid division by
    zero. The result is returned as a Python float.
    """
    residual = vu_pred - vu_true
    ratio = residual.pow(2).sum() / (vu_true.pow(2).sum() + 1e-16)
    return float(ratio.sqrt().item())

# -------------------------
# MEMORY EFFICIENT SOLVER
# -------------------------
def batched_ridge_solve(model, xt_train, u_train, batch_size=5000, lam=1e-2):
    """
    Fit a linear readout on random features by ridge regression.

    Solves   (Phi^T Phi + lam * N * I) W = Phi^T U,
    where Phi = model._feature_map(xt_train) is [N, M] and U = u_train. The
    penalty is scaled by the sample count N, so ``lam`` acts per sample.
    The Gram matrix is accumulated chunk by chunk, so the full [N, M]
    feature matrix is never materialized.

    Args:
        model: object exposing ``num_random_features`` (M) and
            ``_feature_map(xt) -> [batch, M]``.
        xt_train: [N, input_dim] inputs.
        u_train: [N, output_dim] targets (multi-output supported).
        batch_size: rows per feature-map chunk (>= 1).
        lam: per-sample ridge coefficient.

    Returns:
        W: [M, output_dim] readout weights, on ``xt_train``'s device.

    Raises:
        ValueError: if there are no samples or ``batch_size < 1``.
    """
    n_samples = xt_train.shape[0]
    if n_samples == 0:
        raise ValueError("batched_ridge_solve needs at least one training sample")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    n_features = model.num_random_features
    output_dim = u_train.shape[1]
    # Follow the inputs' device/dtype instead of assuming a global device.
    device = xt_train.device
    dtype = u_train.dtype

    G = torch.zeros((n_features, n_features), device=device, dtype=dtype)
    B = torch.zeros((n_features, output_dim), device=device, dtype=dtype)

    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            stop = min(start + batch_size, n_samples)
            phi = model._feature_map(xt_train[start:stop])  # [batch, M]
            G += phi.T @ phi
            B += phi.T @ u_train[start:stop]

        G += lam * torch.eye(n_features, device=device, dtype=dtype) * n_samples

        try:
            chol = torch.linalg.cholesky(G)
        except RuntimeError:
            # torch.linalg.LinAlgError (a RuntimeError subclass) is raised
            # when G is not numerically positive-definite; fall back to a
            # general solve instead of swallowing unrelated errors.
            return torch.linalg.solve(G, B)
        return torch.cholesky_solve(B, chol)

def batched_predict(model, w, xt_in, batch_size=5000):
    """
    Evaluate ``model._feature_map(xt_in) @ w`` in row chunks.

    Args:
        model: object exposing ``_feature_map(xt) -> [batch, M]``.
        w: readout weights, [M, output_dim] (or [M]).
        xt_in: [N, input_dim] query points.
        batch_size: rows per chunk (>= 1).

    Returns:
        [N, output_dim] predictions ([N] if ``w`` is 1-D). An empty input
        yields an empty result rather than failing in ``torch.cat([])``.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    n_rows = xt_in.shape[0]
    if n_rows == 0:
        return w.new_zeros((0, *w.shape[1:]))

    with torch.no_grad():
        preds = [
            model._feature_map(xt_in[start:start + batch_size]) @ w
            for start in range(0, n_rows, batch_size)
        ]
    return torch.cat(preds, dim=0)

# -------------------------
# Main Loop
# -------------------------
def main():
    """
    Run the width-scaling study and plot test error against width.

    For each width in WIDTH_LIST, RUNS_PER_WIDTH models are fitted, each on
    a fresh training set of N_TRAIN_PER_FEATURE * width points. Each model
    is scored with the relative L2 error over both fields, on one shared
    test set of N_TEST_GLOBAL points. Per-width mean and standard
    deviation across runs are recorded.

    The figure shows mean test error with +/- one std error bars and a
    C/sqrt(N) reference curve anchored at the first width. It is saved to
    scaling_linear_curve_NS.png.
    """
    print(f"Generating global test dataset ({N_TEST_GLOBAL} points)...")
    xt_test, vu_test = generate_dataset(N_TEST_GLOBAL)

    train_means, train_stds = [], []
    test_means, test_stds = [], []

    for width in WIDTH_LIST:
        n_train = N_TRAIN_PER_FEATURE * width
        print(f"\n=== Width N = {width} | N_train = {n_train} ===")

        train_errs = []
        test_errs = []

        for run in range(RUNS_PER_WIDTH):
            seed = 1000 * width + run

            # 1. Fresh training data for this run.
            xt_train, vu_train = generate_dataset(n_train)

            # 2. Random-feature model (hidden layer fixed by `seed`).
            model = RFNNnD(
                input_dim=2,  # t, x
                num_random_features=width,
                feature_scale=FEATURE_SCALE,
                use_fourier=USE_FOURIER_FEATURES,
                input_scales=INPUT_SCALES,
                seed=seed,
            ).to(DEVICE)

            # 3. Ridge fit of the linear readout.
            w = batched_ridge_solve(model, xt_train, vu_train, BATCH_SIZE)

            # 4. Evaluate on the training and shared test sets.
            vu_pred_train = batched_predict(model, w, xt_train, BATCH_SIZE)
            vu_pred_test = batched_predict(model, w, xt_test, BATCH_SIZE)

            train_errs.append(rel_L2_error(vu_pred_train, vu_train))
            test_errs.append(rel_L2_error(vu_pred_test, vu_test))

        train_arr = np.array(train_errs)
        test_arr = np.array(test_errs)

        train_means.append(train_arr.mean())
        train_stds.append(train_arr.std())
        test_means.append(test_arr.mean())
        test_stds.append(test_arr.std())

        print(f"Width {width}: Test Error = {test_means[-1]:.4e}")
        print(f"  train = {train_means[-1]:.4e} +/- {train_stds[-1]:.1e}, "
              f"test std = {test_stds[-1]:.1e}")

    # -------------------------
    # PLOTTING CODE
    # -------------------------
    W = np.array(WIDTH_LIST, dtype=float)
    E = np.array(test_means, dtype=float)
    E_std = np.array(test_stds, dtype=float)

    # Reference C / sqrt(N), anchored to pass through the first data point.
    c_ref = E[0] * np.sqrt(W[0])
    W_smooth = np.linspace(W.min(), W.max(), 200)
    E_ref_smooth = c_ref / np.sqrt(W_smooth)

    plt.figure(figsize=(8, 6))

    plt.plot(W_smooth, E_ref_smooth, 'k--', linewidth=2, alpha=0.7, label=r'$C / \sqrt{N}$')

    # Mean test error with run-to-run spread as error bars.
    plt.errorbar(W, E, yerr=E_std, fmt='bo-', linewidth=2, markersize=8,
                 capsize=4, label='NS RaNN Error')

    plt.xlabel(r'Width ($N$)')
    plt.ylabel('Relative $L^2$ Error')
    plt.title('1D Navier-Stokes Shock Convergence')
    plt.legend(fontsize=12)
    plt.grid(True, linestyle=':', alpha=0.6)

    plt.ylim(bottom=0)

    plt.tight_layout()
    plt.savefig("scaling_linear_curve_NS.png", dpi=150)
    plt.show()

if __name__ == "__main__":
    # Run the width sweep only when executed as a script. Module-level setup
    # (seeding, device selection, PHYSICS construction) still runs on import.
    main()