#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations

import math
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from typing import Tuple

# -------------------------
# Global Config
# -------------------------
# Seed both RNGs at import time so data sampling is reproducible.
# The torch seed drives every torch.rand/torch.randn/torch.randperm call below.
torch.manual_seed(40)
np.random.seed(40)

# Problem params
SPATIAL_DIM = 1          # d: number of spatial coordinates (inputs are (x, t) with d + 1 columns)
M_EXPONENT  = 2.0        # m in PME (exponent passed to the Barenblatt profile)
T0_IC       = 1e-3       # shift time: physical time is t_model + T0_IC
T_DOMAIN    = 0.1       # t in [0, T]: model times are sampled as rand * T_DOMAIN
L_VEC       = [1.0] * SPATIAL_DIM  # box side length per spatial dimension; x is sampled in [0, L)
SUPPORT_FRAC = 0.60      # initial support radius = SUPPORT_FRAC * min(L_VEC) / 2

# RaNN params
USE_FOURIER_FEATURES   = True      # True -> cos/sin features, False -> tanh features
NON_FOURIER_ACTIVATION = 'tanh'    # NOTE(review): not read in the visible code; tanh is hard-coded in RFNNnD
FEATURE_SCALE          = 7.0       # multiplier applied to the standard-normal random weights
POSITIVE_HEAD          = 'square'  # NOTE(review): not read in the visible code
INPUT_SCALES           = [1.0] * SPATIAL_DIM + [10.0]  # per-column multipliers for (x, t) before projection

# SCALING CONFIG
# need enough samples so that integration error < approximation error
N_TRAIN_PER_FEATURE    = 10        # training set size = N_TRAIN_PER_FEATURE * width
BATCH_SIZE             = 5000      # <--- Process samples in chunks to save RAM
N_TEST_GLOBAL          = 20000     # size of the single test set shared by all widths
WIDTH_LIST             = [20, 50, 100, 150, 250, 500, 750, 1000, 1500, 2000, 2500]  # feature counts swept (all even)
RUNS_PER_WIDTH         = 1         # independent fits averaged per width

LEARNING_RATE = 1e-3  # NOTE(review): not read in the visible code (weights come from a direct solve)
# All tensors created by the helpers below are placed on this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
# NOTE(review): per the PyTorch docs, "high" permits reduced-precision internal
# formats (e.g. TF32) for float32 matmuls on supporting hardware; this may affect
# the accuracy of the accumulated normal equations -- confirm it is intended.
torch.set_float32_matmul_precision("high")

# -------------------------
# Utilities
# -------------------------
def _to_tensor(x, device=DEVICE, dtype=torch.float32):
    """Return ``x`` as a tensor with the requested device and dtype.

    An existing tensor is converted with ``Tensor.to``; any other value
    (list, scalar, array, ...) is handed to ``torch.tensor``.
    """
    placement = {"device": device, "dtype": dtype}
    return x.to(**placement) if isinstance(x, torch.Tensor) else torch.tensor(x, **placement)

def rand_in_box(n: int, L_vec: list[float]) -> torch.Tensor:
    """Sample ``n`` points uniformly from the box ``[0, L_1) x ... x [0, L_d)``.

    Returns a tensor of shape ``[n, len(L_vec)]`` on ``DEVICE``.
    """
    dim = len(L_vec)
    unit_samples = torch.rand(n, dim, device=DEVICE)
    # Broadcast the per-dimension side lengths across all samples.
    return unit_samples * _to_tensor(L_vec).view(1, dim)

# -------------------------
# RaNN Definition
# -------------------------
class RFNNnD(nn.Module):
    def __init__(self, spatial_dim: int, num_random_features: int, 
                 feature_scale: float = 1.0, *, use_fourier: bool = True,
                 input_scales: list[float] = None, seed: int = 12345):
        super().__init__()
        torch.manual_seed(seed)
        self.d = spatial_dim
        self.input_dim = self.d + 1
        self.num_random_features = int(num_random_features)
        self.use_fourier = bool(use_fourier)
        
        if input_scales is None:
            input_scales = [1.0] * self.d + [10.0]
        self.register_buffer("input_scales", _to_tensor(input_scales).view(1, -1))

        if self.use_fourier:
            assert self.num_random_features % 2 == 0
            rf = self.num_random_features // 2
            self.random_weights = torch.randn(rf, self.input_dim, device=DEVICE) * feature_scale
            self.random_biases = torch.rand(rf, device=DEVICE) * 2 * math.pi
        else:
            self.random_weights = torch.randn(self.num_random_features, self.input_dim, device=DEVICE) * feature_scale
            self.random_biases = torch.rand(self.num_random_features, device=DEVICE)

        self.random_weights = nn.Parameter(self.random_weights, requires_grad=False)
        self.random_biases = nn.Parameter(self.random_biases, requires_grad=False)

    def _feature_map(self, xt: torch.Tensor) -> torch.Tensor:
        # xt: [Batch, d+1]
        x_scaled = xt * self.input_scales
        proj = x_scaled @ self.random_weights.t() # [Batch, rf]

        if self.use_fourier:
            cosf = torch.cos(proj + self.random_biases)
            sinf = torch.sin(proj + self.random_biases)
            scale = (self.num_random_features // 2) ** (-0.5)
            return scale * torch.cat([cosf, sinf], dim=1)
        else:
            return torch.tanh(proj + self.random_biases)

# -------------------------
# Physics (Barenblatt)
# -------------------------
def barenblatt_true_shifted_dd(t_model, x, L_vec, m, t0, support_fraction):
    """Evaluate the time-shifted Barenblatt profile centred in the box.

    The profile is
    ``u = t^(-alpha) * max(A - B * |x - x0|^2 / t^(2*beta), 0)^(1/(m-1))``
    with ``t = t_model + t0``, ``lam = 2 + d*(m-1)``, ``alpha = d/lam``,
    ``beta = 1/lam`` and ``B = (m-1)*beta/(2*m)``. ``A`` is chosen so that at
    ``t_model = 0`` the support radius equals
    ``support_fraction * min(L_vec) / 2``.

    Args:
        t_model: Model times, tensor broadcastable against ``[N, 1]``.
        x: Spatial points, tensor of shape ``[N, d]``.
        L_vec: Box side lengths (length ``d``); the centre is ``L / 2`` per axis.
        m: Nonlinearity exponent. ``m == 1`` is not supported (the formula
            divides by ``m - 1``).
        t0: Positive time shift added to ``t_model``.
        support_fraction: Fraction of the smallest half-side used as the
            initial support radius.

    Returns:
        Tensor of shape ``[N, 1]`` with the profile values (zero outside the
        support).
    """
    d = x.shape[1]
    # Build the centre next to ``x`` (same device, dtype at least float32)
    # instead of on the global default device, so the function also works for
    # inputs that live elsewhere and keeps full precision for float64 inputs.
    center_dtype = torch.promote_types(x.dtype, torch.float32)
    x0 = torch.tensor(
        [0.5 * L for L in L_vec], dtype=center_dtype, device=x.device
    ).view(1, d)

    lam = 2.0 + d * (m - 1.0)
    alpha = d / lam
    beta = 1.0 / lam
    B_const = (m - 1.0) * beta / (2.0 * m)

    # Fix A from the requested support radius at the shifted initial time t0.
    R0 = support_fraction * (min(L_vec) / 2.0)
    A = B_const * (R0**2) / (t0 ** (2.0 * beta))

    t_phys = t_model + t0
    r2 = torch.sum((x - x0) ** 2, dim=1, keepdim=True)
    z = A - B_const * r2 / (t_phys ** (2.0 * beta))
    # Outside the free boundary the bracket is negative; the solution is 0 there.
    z = torch.clamp(z, min=0.0)
    u = (t_phys ** (-alpha)) * (z ** (1.0 / (m - 1.0)))
    return u

def generate_dataset(n_samples: int):
    """Sample ``n_samples`` space-time points and their Barenblatt targets.

    Half of the points (rounded down) have x drawn from the fixed window
    ``[0.2, 0.8)`` in every spatial coordinate; the rest have x drawn from the
    whole box given by ``L_VEC``. All times are uniform in ``[0, T_DOMAIN)``.
    The window bounds are literal values and are not rescaled by ``L_VEC``.

    Returns:
        ``(xt, u)``: inputs of shape ``[n_samples, d + 1]`` (x columns, then
        t) and targets of shape ``[n_samples, 1]``, jointly shuffled.
    """
    dim = len(L_VEC)
    focus_count = n_samples // 2
    uniform_count = n_samples - focus_count

    # The four draws below (and the permutation) consume the global torch RNG
    # in this exact order; keep it to reproduce the same datasets.
    x_uniform = rand_in_box(uniform_count, L_VEC)
    t_uniform = T_DOMAIN * torch.rand(uniform_count, 1, device=DEVICE)

    # Denser sampling of the central window [0.2, 0.8).
    x_focus = torch.rand(focus_count, dim, device=DEVICE).mul(0.6).add(0.2)
    t_focus = T_DOMAIN * torch.rand(focus_count, 1, device=DEVICE)

    coords = torch.cat((x_uniform, x_focus), dim=0)
    times = torch.cat((t_uniform, t_focus), dim=0)

    targets = barenblatt_true_shifted_dd(
        times, coords, L_VEC, M_EXPONENT, T0_IC, SUPPORT_FRAC
    )
    inputs = torch.cat((coords, times), dim=1)

    # Shuffle so the uniform and focused samples are interleaved.
    order = torch.randperm(inputs.size(0), device=DEVICE)
    return inputs[order], targets[order]

def rel_L2_error(u_pred, u_true):
    """Return ``sqrt(sum((u_pred - u_true)^2) / (sum(u_true^2) + 1e-16))`` as a float.

    The small constant in the denominator keeps the ratio finite when
    ``u_true`` is identically zero.
    """
    residual = u_pred - u_true
    squared_error = residual.pow(2).sum()
    squared_norm = u_true.pow(2).sum() + 1e-16
    ratio = squared_error / squared_norm
    return float(ratio.sqrt().item())

# -------------------------
# MEMORY EFFICIENT SOLVER
# -------------------------
def batched_ridge_solve(model, xt_train, u_train, batch_size=5000, lam=1e-5):
    """Fit the linear read-out by ridge regression on the normal equations.

    Solves ``(Phi^T Phi + lam * N * I) w = Phi^T u`` where ``Phi`` is
    ``model._feature_map(xt_train)`` and ``N`` is the number of samples. The
    Gram matrix and right-hand side are accumulated batch by batch, so besides
    the inputs only an ``M x M`` matrix and one batch of features are held in
    memory (``M = model.num_random_features``).

    Args:
        model: Object exposing ``num_random_features`` and ``_feature_map``.
        xt_train: Inputs of shape ``[N, d + 1]``.
        u_train: Targets of shape ``[N, 1]`` (a 1-D ``[N]`` tensor is also
            accepted).
        batch_size: Number of rows processed per step; must be positive.
        lam: Ridge strength per sample (the penalty used is ``lam * N``).

    Returns:
        Weight tensor of shape ``[M, 1]``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # A negative step would skip the loop entirely and silently return w = 0.
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_samples = xt_train.shape[0]
    n_features = model.num_random_features

    # Accumulators for the normal equations (G = Phi^T Phi, B = Phi^T u).
    # They follow the training data's device/dtype rather than a global default.
    gram = torch.zeros(
        (n_features, n_features), device=xt_train.device, dtype=xt_train.dtype
    )
    rhs = torch.zeros((n_features, 1), device=xt_train.device, dtype=xt_train.dtype)

    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            stop = min(start + batch_size, n_samples)
            phi = model._feature_map(xt_train[start:stop])  # [batch, M]
            gram += phi.T @ phi
            # Force a column so 1-D targets accumulate correctly too.
            rhs += phi.T @ u_train[start:stop].reshape(stop - start, 1)
            del phi  # release the batch features before the next allocation

        # Ridge term, scaled by the sample count to match the summed Gram matrix.
        gram.diagonal().add_(lam * n_samples)

    # Cholesky is the fast path for a symmetric positive-definite system.
    try:
        chol = torch.linalg.cholesky(gram)
        w = torch.cholesky_solve(rhs, chol)
    except RuntimeError:
        # torch's linear-algebra errors derive from RuntimeError; fall back to a
        # general solve if the matrix is numerically not positive definite.
        # (Unlike a bare ``except``, this lets KeyboardInterrupt propagate.)
        w = torch.linalg.solve(gram, rhs)

    return w

def batched_predict(model, w, xt_in, batch_size=5000):
    """Evaluate ``model._feature_map(xt_in) @ w`` in row batches.

    Args:
        model: Object exposing ``_feature_map``.
        w: Read-out weights of shape ``[M, k]``.
        xt_in: Inputs of shape ``[N, d + 1]``; ``N`` may be zero.
        batch_size: Number of rows evaluated per step; must be positive.

    Returns:
        Predictions of shape ``[N, k]``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_rows = xt_in.shape[0]
    with torch.no_grad():
        if n_rows == 0:
            # torch.cat rejects an empty list, so evaluate the empty batch
            # directly to get a correctly shaped empty result.
            return model._feature_map(xt_in) @ w
        chunks = [
            model._feature_map(xt_in[start:start + batch_size]) @ w
            for start in range(0, n_rows, batch_size)
        ]
    return torch.cat(chunks, dim=0)

# -------------------------
# Main Loop
# -------------------------
def _evaluate_width(width, xt_test, u_test):
    """Fit ``RUNS_PER_WIDTH`` models of one width; return mean (train, test) errors."""
    n_train = N_TRAIN_PER_FEATURE * width
    print(f"\n=== Width N = {width} | N_train = {n_train} ===")

    train_errs = []
    test_errs = []
    for run in range(RUNS_PER_WIDTH):
        seed = 1000 * width + run

        # 1. Generate data. NOTE: this is drawn *before* the model is built,
        #    and RFNNnD reseeds the global torch RNG, so each dataset depends
        #    on the seed of the previously constructed model.
        xt_train, u_train = generate_dataset(n_train)

        # 2. Build the frozen random-feature model.
        model = RFNNnD(
            SPATIAL_DIM, width,
            feature_scale=FEATURE_SCALE,
            use_fourier=USE_FOURIER_FEATURES,
            input_scales=INPUT_SCALES,
            seed=seed,
        ).to(DEVICE)

        # 3. Fit the linear read-out with the batched ridge solver.
        w = batched_ridge_solve(model, xt_train, u_train, BATCH_SIZE)

        # 4. Score on the training set and on the shared test set.
        u_pred_train = batched_predict(model, w, xt_train, BATCH_SIZE)
        u_pred_test = batched_predict(model, w, xt_test, BATCH_SIZE)
        train_errs.append(rel_L2_error(u_pred_train, u_train))
        test_errs.append(rel_L2_error(u_pred_test, u_test))

    return float(np.mean(train_errs)), float(np.mean(test_errs))


def _plot_convergence(widths, test_means):
    """Plot test error against width with a ``C / sqrt(N)`` reference; save and show."""
    W = np.array(widths, dtype=float)
    E = np.array(test_means, dtype=float)
    if W.size == 0:
        # Nothing was swept, so there is no first point to anchor the reference.
        return

    # Anchor the reference curve so it passes through the first data point.
    C = E[0] * np.sqrt(W[0])

    # Dense grid of widths so the reference is drawn as a smooth curve.
    W_smooth = np.linspace(W.min(), W.max(), 200)
    E_ref_smooth = C / np.sqrt(W_smooth)

    plt.figure(figsize=(8, 6))

    # Reference curve (C / sqrt(N))
    plt.plot(W_smooth, E_ref_smooth, 'k--', linewidth=2, alpha=0.7, label=r'$C / \sqrt{N}$')

    # Measured errors
    plt.plot(W, E, 'bo-', linewidth=2, markersize=8, label='RaNN Error')

    # Formatting
    plt.xlabel(r'Width ($N$)')
    plt.ylabel('Relative $L^2$ Error')
    plt.title(f'{SPATIAL_DIM}D Convergence: Linear Scale')
    plt.legend(fontsize=12)
    plt.grid(True, linestyle=':', alpha=0.6)

    plt.ylim(bottom=0)

    plt.tight_layout()
    plt.savefig(f"scaling_linear_curve_d{SPATIAL_DIM}.png", dpi=150)
    plt.show()


def main():
    """Sweep ``WIDTH_LIST``, report the test error per width, and plot the result.

    One test set of ``N_TEST_GLOBAL`` points is generated up front and reused
    for every width. For each width the mean relative L2 error over
    ``RUNS_PER_WIDTH`` fits is printed; the test errors are then plotted
    against width and written to ``scaling_linear_curve_d{SPATIAL_DIM}.png``.

    Returns:
        ``(train_means, test_means)``: two lists with one mean relative L2
        error per entry of ``WIDTH_LIST``.
    """
    print(f"Generating global test dataset ({N_TEST_GLOBAL} points)...")
    xt_test, u_test = generate_dataset(N_TEST_GLOBAL)

    train_means = []
    test_means = []
    for width in WIDTH_LIST:
        train_mean, test_mean = _evaluate_width(width, xt_test, u_test)
        train_means.append(train_mean)
        test_means.append(test_mean)
        print(f"Width {width}: Test Error = {test_mean:.4e}")

    _plot_convergence(WIDTH_LIST, test_means)
    return train_means, test_means

# Run the width sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()