# src/data_and_models.py
# MODIFIED: Contains PDE solver, data generation, and models for the Burgers' Equation
#           with spatially-varying viscosity nu(x).

import os

import numpy as np
import torch
import torch.nn as nn
from scipy.interpolate import interp1d
from scipy.sparse import spdiags
from scipy.sparse.linalg import splu, spsolve

# =============================================================================
# 1. High-Fidelity Burgers' Equation Simulator for nu(x)
# =============================================================================
class BurgersSimulator:
    """
    Numerical solver for the forced 1D viscous Burgers' equation with
    spatially-varying viscosity: v_t + v*v_x = (nu(x)*v_x)_x + f(x, t).

    Time stepping is semi-implicit: the diffusion term is treated with
    Crank-Nicolson in conservative (flux) form, while the upwinded advection
    term and the control force are treated explicitly. Homogeneous Dirichlet
    conditions (v = 0) are imposed at both ends of the domain [0, L].

    The constructor reads the keys 'L', 'NX_SOLVER', 'NT_SOLVER', 'T_FINAL'
    and 'NUM_BASIS_FUNCTIONS' from ``config``.
    """

    def __init__(self, config):
        self.config = config
        self.L = config['L']
        self.NX = config['NX_SOLVER']
        self.NT = config['NT_SOLVER']
        # NX grid points / NT time levels span NX-1 cells / NT-1 steps.
        self.dx = self.L / (self.NX - 1)
        self.dt = config['T_FINAL'] / (self.NT - 1)
        self.x_grid = np.linspace(0, self.L, self.NX)

        # Pre-compute basis functions for the control force
        self.basis_funcs = self._create_basis_functions()

    def _create_basis_functions(self) -> np.ndarray:
        """Return the sine modes sin(m*pi*x/L), m = 1..NUM_BASIS_FUNCTIONS.

        The result has shape (NUM_BASIS_FUNCTIONS, NX); row m-1 holds mode m
        sampled on the solver grid.
        """
        modes = np.arange(1, self.config['NUM_BASIS_FUNCTIONS'] + 1)
        return np.sin(modes[:, None] * np.pi * self.x_grid[None, :] / self.L)

    def _build_cn_matrices(self, viscosity_profile):
        """Build the Crank-Nicolson matrices A and B for a given nu(x).

        One diffusion step solves ``A v^{n+1} = B v^n + explicit terms`` with
        ``A = I - lambda*D`` and ``B = I + lambda*D``, where ``lambda = dt /
        (2 dx^2)`` and, for an interior row i,

            (D v)_i = nu_{i+1/2} (v_{i+1} - v_i) - nu_{i-1/2} (v_i - v_{i-1}).

        Face viscosities are arithmetic means of neighbouring node values.
        Rows 0 and NX-1 of both matrices are identity rows (Dirichlet BC).

        Args:
            viscosity_profile: nu sampled on the solver grid, length NX.

        Returns:
            Tuple ``(A, B)`` of NX x NX sparse matrices in CSC format.
        """
        nu = np.asarray(viscosity_profile, dtype=float)
        lambda_ = self.dt / (2 * self.dx**2)

        # Viscosity at the cell faces i+1/2, i = 0..NX-2
        nu_half = 0.5 * (nu[:-1] + nu[1:])

        # Per-ROW face viscosities: west[i] = nu_{i-1/2}, east[i] = nu_{i+1/2}.
        west = np.concatenate(([0.0], nu_half))
        east = np.concatenate((nu_half, [0.0]))
        # Boundary rows carry no diffusion coupling, which makes them identity
        # rows without having to patch the sparse matrix after assembly.
        west[[0, -1]] = 0.0
        east[[0, -1]] = 0.0

        # spdiags stores data[k, j] in COLUMN j of diagonal k. The entry of
        # row i therefore sits at index i-1 on the sub-diagonal and at index
        # i+1 on the super-diagonal, hence the shifts below.
        lower_by_col = np.concatenate((west[1:], [0.0]))
        upper_by_col = np.concatenate(([0.0], east[:-1]))

        def assemble(sign):
            # sign=+1 gives the implicit matrix A, sign=-1 the explicit B.
            main = 1.0 + sign * lambda_ * (west + east)
            lower = -sign * lambda_ * lower_by_col
            upper = -sign * lambda_ * upper_by_col
            return spdiags([lower, main, upper], [-1, 0, 1],
                           self.NX, self.NX, format='csc')

        return assemble(1.0), assemble(-1.0)

    def _advect(self, v):
        """Compute the advection term v*v_x with a first-order upwind scheme.

        A backward difference is used where v > 0 and a forward difference
        otherwise. The two boundary entries are left at zero.
        """
        adv = np.zeros_like(v)
        v_interior = v[1:-1]
        dvdx_fwd = (v[2:] - v_interior) / self.dx
        dvdx_bwd = (v_interior - v[:-2]) / self.dx
        adv[1:-1] = v_interior * np.where(v_interior > 0, dvdx_bwd, dvdx_fwd)
        return adv

    def run(self, initial_state, control_weights_sequence, viscosity_profile):
        """
        Run a full simulation for a given initial state, control sequence and
        viscosity profile.

        Args:
            initial_state: state on the solver grid at t = 0, length NX.
            control_weights_sequence: basis-function weights per time level;
                row k (k = 0..NT-2) drives the step from level k to k+1.
            viscosity_profile: nu sampled on the solver grid, length NX.

        Returns:
            Array of shape (NT, NX); row 0 is the initial state.
        """
        A_cn, B_cn = self._build_cn_matrices(viscosity_profile)
        # A does not change between steps: factorise it once and reuse the
        # LU factors instead of re-solving from scratch every step.
        solve = splu(A_cn).solve

        weights = np.asarray(control_weights_sequence)
        current_state = np.array(initial_state, dtype=float)
        state_history = [current_state]

        for k in range(self.NT - 1):
            force = weights[k] @ self.basis_funcs

            rhs = (B_cn @ current_state
                   - self.dt * self._advect(current_state)
                   + self.dt * force)
            rhs[0], rhs[-1] = 0, 0  # Enforce BC on the right-hand side

            current_state = solve(rhs)
            state_history.append(current_state)

        return np.array(state_history)

# =============================================================================
# 2. Data Generation Logic
# =============================================================================
# NEW: Function to generate spatial GRF for nu(x).
def generate_grf_spatial_series(config, num_series, length_scale):
    """Draw zero-mean Gaussian random field samples on the solver grid.

    The grid has ``config['NX_SOLVER']`` equally spaced points on
    ``[0, config['L']]``. The covariance is a squared-exponential kernel with
    the given ``length_scale`` plus a 1e-6 diagonal jitter.

    Returns:
        Array of shape (num_series, NX_SOLVER), one sample per row, drawn
        from NumPy's global random state.
    """
    n_points = config['NX_SOLVER']
    grid = np.linspace(0, config['L'], n_points)

    # Pairwise |x_i - x_j| between all grid points
    separation = np.abs(np.subtract.outer(grid, grid))

    kernel = np.exp(-0.5 * (separation**2) / (length_scale**2))
    kernel = kernel + 1e-6 * np.eye(n_points)  # jitter on the diagonal

    zero_mean = np.zeros(n_points)
    return np.random.multivariate_normal(zero_mean, kernel, size=num_series)

def generate_grf_time_series(config, num_series, length_scale_range=(0.8, 2.5)):
    """Draw zero-mean Gaussian random field samples along the time axis.

    A single length scale is drawn uniformly from ``length_scale_range`` and
    shared by all ``num_series`` samples of the call. The covariance is a
    squared-exponential kernel over ``config['NT_SOLVER']`` equally spaced
    times on ``[0, config['T_FINAL']]`` plus a 1e-6 diagonal jitter.

    Returns:
        Array of shape (NT_SOLVER, num_series), one sample per column, drawn
        from NumPy's global random state.
    """
    n_steps = config['NT_SOLVER']
    times = np.linspace(0, config['T_FINAL'], n_steps)

    # The length scale is sampled before the field itself.
    length_scale = np.random.uniform(*length_scale_range)

    lag = np.abs(np.subtract.outer(times, times))
    kernel = np.exp(-0.5 * (lag**2) / (length_scale**2))
    kernel = kernel + 1e-6 * np.eye(n_steps)  # jitter on the diagonal

    samples = np.random.multivariate_normal(np.zeros(n_steps), kernel, size=num_series)
    # Transpose so that time runs along the first axis
    return samples.T

# MODIFIED: The dataset creation now includes generating and saving nu(x).
def create_burgers_dataset(config, num_simulations, filename):
    """Simulate random Burgers trajectories and save them as a compressed .npz.

    Each simulation draws a random initial state (a sum of 1-3 sine modes),
    a random control-weight sequence and a random viscosity profile nu(x),
    runs ``BurgersSimulator`` on the solver grid, and cubic-interpolates the
    state history and nu(x) onto ``config['M_SENSORS']`` equally spaced
    sensor locations. All random draws use NumPy's global random state.

    Arrays written to the file:
        state_sequences:    (num_simulations, NT_SOLVER, M_SENSORS)
        control_sequences:  (num_simulations, NT_SOLVER, NUM_BASIS_FUNCTIONS)
        viscosity_profiles: (num_simulations, M_SENSORS)

    Args:
        config: dict with the simulator keys plus 'M_SENSORS',
            'CONTROL_SCALE', 'VISCOSITY_LENGTH_SCALE' and 'VISCOSITY_RANGE'.
        num_simulations: number of trajectories to generate.
        filename: output path. Missing parent directories are created.
            ``np.savez_compressed`` appends '.npz' if the name lacks it.
    """
    print(f"--- Generating Burgers' Equation Dataset with nu(x): {filename} ---")
    # os.path.dirname() is '' for a bare file name and os.makedirs('') raises,
    # so only create the directory when the path actually has one.
    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    domain_length = config['L']
    num_grid_points = config['NX_SOLVER']
    nu_min, nu_max = config['VISCOSITY_RANGE']

    simulator = BurgersSimulator(config)
    x_grid_solver = np.linspace(0, domain_length, num_grid_points)
    x_grid_sensors = np.linspace(0, domain_length, config['M_SENSORS'])

    state_sequences_on_sensors = []
    control_sequences = []
    viscosity_profiles_on_sensors = []

    for i in range(num_simulations):
        if (i + 1) % 50 == 0:
            print(f"  Generating simulation {i+1}/{num_simulations}...")

        # 1. Initial state: sum of 1-3 sine modes with random amplitude and
        #    wavenumber. The argument is scaled by L so that every mode
        #    vanishes at x = L, matching the Dirichlet BC and the simulator's
        #    basis functions for any domain length.
        initial_state = np.zeros(num_grid_points)
        for _ in range(np.random.randint(1, 4)):
            amplitude = np.random.uniform(-1, 1)
            wavenumber = np.random.randint(1, 4)
            initial_state += amplitude * np.sin(
                wavenumber * np.pi * x_grid_solver / domain_length)

        # 2. Time-varying control weights, clipped to the admissible range
        w_sequence = generate_grf_time_series(config, config['NUM_BASIS_FUNCTIONS'])
        w_sequence = np.clip(w_sequence * 1.5,
                             -config['CONTROL_SCALE'], config['CONTROL_SCALE'])

        # 3. Spatially-varying viscosity: squash a GRF through tanh so the
        #    profile is positive and lies strictly inside (nu_min, nu_max).
        nu_raw = generate_grf_spatial_series(
            config, 1, config['VISCOSITY_LENGTH_SCALE']).squeeze()
        viscosity_profile = nu_min + (nu_max - nu_min) * (0.5 * (np.tanh(nu_raw) + 1))

        # 4. Run the simulation on the solver grid
        state_history_fine_grid = simulator.run(initial_state, w_sequence, viscosity_profile)

        # 5. Interpolate results onto the sensor grid
        interpolator_state = interp1d(x_grid_solver, state_history_fine_grid,
                                      axis=1, kind='cubic', fill_value="extrapolate")
        interpolator_nu = interp1d(x_grid_solver, viscosity_profile,
                                   kind='cubic', fill_value="extrapolate")

        state_sequences_on_sensors.append(interpolator_state(x_grid_sensors))
        control_sequences.append(w_sequence)
        viscosity_profiles_on_sensors.append(interpolator_nu(x_grid_sensors))

    np.savez_compressed(
        filename,
        state_sequences=np.array(state_sequences_on_sensors),
        control_sequences=np.array(control_sequences),
        viscosity_profiles=np.array(viscosity_profiles_on_sensors)
    )
    print(f"Dataset saved to {filename}")

# =============================================================================
# 3. Model Architectures (MODIFIED to be physics-aware)
# =============================================================================
class PropagatorDeepONet(nn.Module):
    """DeepONet-style one-step propagator conditioned on the viscosity.

    The branch MLP encodes the concatenation of the sensor state, the control
    weights and the viscosity profile into ``latent_dim`` coefficients. The
    trunk MLP encodes each query location into ``latent_dim`` basis values.
    The prediction at a location is their dot product plus a learned scalar
    bias.

    Args:
        M_sensors: number of sensor values in the state (and in nu(x)).
        num_basis_functions: number of control weights.
        trunk_input_dim: size of the last axis of ``x_locations``.
        branch_depth, branch_width: number and width of the branch hidden
            layers.
        trunk_depth, trunk_width: number and width of the trunk hidden layers.
        latent_dim: output size shared by branch and trunk.
        activation_fn: 'relu' or 'tanh' (case-insensitive); any other name
            falls back to ReLU.
    """

    def __init__(self, M_sensors, num_basis_functions, trunk_input_dim,
                 branch_depth, branch_width, trunk_depth, trunk_width, latent_dim, activation_fn):
        super().__init__()

        # Branch input includes state y(x), control weights c, and viscosity nu(x)
        branch_input_size = M_sensors + num_basis_functions + M_sensors

        # Unrecognised names deliberately fall back to ReLU.
        activation = nn.Tanh() if activation_fn.lower() == 'tanh' else nn.ReLU()

        # Branch is built before trunk so parameter initialisation order (and
        # therefore seeded reproducibility) is unchanged.
        self.branch = self._make_mlp(branch_input_size, branch_width, branch_depth,
                                     latent_dim, activation)
        self.trunk = self._make_mlp(trunk_input_dim, trunk_width, trunk_depth,
                                    latent_dim, activation)

        self.bias = nn.Parameter(torch.zeros(1))

    @staticmethod
    def _make_mlp(in_dim, width, depth, out_dim, activation):
        """Build Linear/activation pairs for ``depth`` hidden layers plus a linear output."""
        layers = [nn.Linear(in_dim, width), activation]
        for _ in range(depth - 1):
            layers.extend([nn.Linear(width, width), activation])
        layers.append(nn.Linear(width, out_dim))
        return nn.Sequential(*layers)

    def forward(self, x_k_sensors, u_k_weights, viscosity_profile, x_locations):
        """Predict the next state at the query locations.

        Args:
            x_k_sensors: (B, M_sensors) current state at the sensors.
            u_k_weights: (B, num_basis_functions) control weights.
            viscosity_profile: (B, M_sensors) viscosity at the sensors.
            x_locations: query points of shape (1, S, trunk_input_dim),
                (B, S, trunk_input_dim) or (S, trunk_input_dim).

        Returns:
            Tensor of shape (B, S, 1).
        """
        branch_input = torch.cat([x_k_sensors, u_k_weights, viscosity_profile], dim=1)
        branch_out = self.branch(branch_input)

        # Evaluate the trunk on the query points as given rather than on B
        # expanded copies: when they are shared across the batch this runs
        # the trunk once, and matmul broadcasts the result over the batch.
        trunk_out = self.trunk(x_locations)

        # (.., S, latent) @ (B, latent, 1) -> (B, S, 1)
        output = torch.matmul(trunk_out, branch_out.unsqueeze(-1))
        return output + self.bias

class RecurrentController(nn.Module):
    """LSTM controller that maps (state, target, viscosity) to control weights.

    At each call the current state, the target state and the viscosity
    profile (each with ``M_sensors`` entries) are concatenated and fed to an
    LSTM as a single time step. A two-layer head ending in Tanh maps the
    LSTM output to ``num_basis_functions`` values, which are then multiplied
    by ``control_scale``.

    ``activation_fn`` selects the hidden activation of the head: 'relu' or
    'tanh' (case-insensitive); any other name falls back to ReLU.
    """

    def __init__(self, M_sensors, num_basis_functions, control_scale,
                 hidden_dim, num_layers, activation_fn):
        super().__init__()

        self.control_scale = control_scale

        # Current state + target state + viscosity nu(x), M_sensors each
        feature_dim = 3 * M_sensors

        requested = activation_fn.lower()
        hidden_activation = nn.Tanh() if requested == 'tanh' else nn.ReLU()

        self.lstm = nn.LSTM(
            input_size=feature_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        half_width = hidden_dim // 2
        self.output_head = nn.Sequential(
            nn.Linear(hidden_dim, half_width),
            hidden_activation,
            nn.Linear(half_width, num_basis_functions),
            nn.Tanh(),  # bounds the raw output to [-1, 1] before scaling
        )

    def forward(self, x_k, x_final, viscosity_profile, hidden_state=None):
        """Compute the control weights for one time step.

        Args:
            x_k: (B, M_sensors) current state.
            x_final: (B, M_sensors) target state.
            viscosity_profile: (B, M_sensors) viscosity at the sensors.
            hidden_state: LSTM state from the previous call, or None to start
                from the LSTM's default initial state.

        Returns:
            Tuple ``(w_k, new_hidden_state)`` where ``w_k`` has shape
            (B, num_basis_functions) and magnitude at most ``control_scale``.
        """
        features = torch.cat((x_k, x_final, viscosity_profile), dim=1)
        # Insert a length-1 sequence axis for the batch_first LSTM
        step_out, next_hidden = self.lstm(features[:, None, :], hidden_state)

        bounded = self.output_head(step_out[:, 0, :])
        # Scale the Tanh output to the desired control bounds
        return bounded * self.control_scale, next_hidden