# src/data_and_models.py
# Contains the PDE solver, data generation, and models for the Burgers' Equation.

import torch
import torch.nn as nn
import numpy as np
from scipy.interpolate import interp1d
from scipy.sparse import spdiags
from scipy.sparse.linalg import spsolve
import os

# =============================================================================
# 1. High-Fidelity Burgers' Equation Simulator
# =============================================================================
class BurgersSimulator:
    """Time-stepper for the 1D viscous Burgers' equation on a uniform grid.

    Diffusion is advanced with a Crank-Nicolson pair of tridiagonal matrices
    (``A_cn`` on the implicit side, ``B_cn`` on the explicit side). Advection
    uses a sign-dependent one-sided (upwind) difference and is applied
    explicitly, together with a forcing term assembled from sine basis
    functions.

    Config keys read: ``L``, ``NX_SOLVER``, ``T_FINAL``, ``NT_SOLVER``,
    ``VISCOSITY`` and ``NUM_BASIS_FUNCTIONS``.
    """

    def __init__(self, config):
        """Precompute grid spacings, the two system matrices and the basis."""
        self.config = config
        n_points = config['NX_SOLVER']
        self.dx = config['L'] / (n_points - 1)
        self.dt = config['T_FINAL'] / (config['NT_SOLVER'] - 1)

        # Crank-Nicolson diffusion number: nu * dt / (2 * dx^2).
        lam = config['VISCOSITY'] * self.dt / (2 * self.dx**2)

        self.A_cn = self._tridiagonal_with_dirichlet(n_points, 1 + 2 * lam, -lam)
        self.B_cn = self._tridiagonal_with_dirichlet(n_points, 1 - 2 * lam, lam)

        self.basis_funcs = self._create_basis_functions()

    @staticmethod
    def _tridiagonal_with_dirichlet(size, diag_value, off_value):
        """Return a CSC tridiagonal matrix whose first/last rows are identity rows."""
        main = np.full(size, diag_value)
        off = np.full(size, off_value)
        matrix = spdiags([off, main, off], [-1, 0, 1], m=size, n=size, format='csc')

        # Dirichlet boundary conditions (v=0 at ends): the boundary rows only
        # reference the boundary node itself.
        matrix[0, 0], matrix[0, 1] = 1, 0
        matrix[-1, -1], matrix[-1, -2] = 1, 0
        return matrix

    def _create_basis_functions(self):
        """Return the sine modes sin(k*pi*x/L), k = 1..NUM_BASIS_FUNCTIONS.

        The result has shape (NUM_BASIS_FUNCTIONS, NX_SOLVER); row ``k - 1``
        holds mode ``k`` sampled on the solver grid.
        """
        length = self.config['L']
        grid = np.linspace(0, length, self.config['NX_SOLVER'])
        modes = np.arange(1, self.config['NUM_BASIS_FUNCTIONS'] + 1)
        # Same left-to-right operation order as a per-mode loop would use.
        return np.sin(modes[:, None] * np.pi * grid[None, :] / length)

    def _advect(self, v):
        """Return the upwind estimate of v * dv/dx; boundary entries stay zero."""
        inner = v[1:-1]
        backward = (inner - v[:-2]) / self.dx
        forward = (v[2:] - inner) / self.dx

        # Positive velocity takes the backward difference, otherwise forward.
        slope = np.where(inner > 0, backward, forward)

        result = np.zeros_like(v)
        result[1:-1] = inner * slope
        return result

    def step(self, current_state, u_weights):
        """Advance ``current_state`` by one time step ``dt``.

        Args:
            current_state: 1-D array of the field on the solver grid.
            u_weights: 1-D array with one forcing weight per basis function.

        Returns:
            1-D array with the field at the next time level.
        """
        forcing = np.einsum('i,ij->j', u_weights, self.basis_funcs)
        diffusion_part = self.B_cn @ current_state
        advection_part = self.dt * self._advect(current_state)
        rhs = diffusion_part - advection_part + self.dt * forcing

        # Enforce BC on the right-hand side.
        rhs[0] = 0
        rhs[-1] = 0
        return spsolve(self.A_cn, rhs)

# =============================================================================
# 2. Data Generation Logic
# =============================================================================
def generate_grf_time_series(config, num_series, length_scale_range=(0.8, 2.5)):
    """Sample smooth random time series from a zero-mean Gaussian process.

    A single length scale is drawn uniformly from ``length_scale_range`` and
    shared by every series produced in this call. The covariance is a
    squared-exponential kernel on the time grid ``linspace(0, T_FINAL,
    NT_SOLVER)`` with a small diagonal term added.

    Args:
        config: Mapping providing ``T_FINAL`` and ``NT_SOLVER``.
        num_series: Number of independent series to draw.
        length_scale_range: Arguments forwarded to ``np.random.uniform``.

    Returns:
        Array of shape (NT_SOLVER, num_series); each column is one series.
    """
    n_steps = config['NT_SOLVER']
    times = np.linspace(0, config['T_FINAL'], n_steps)

    # Drawn before the samples so the global RNG is consumed in a fixed order.
    ell = np.random.uniform(*length_scale_range)

    lags = np.abs(times[:, None] - times)
    kernel = np.exp(-0.5 * (lags**2) / (ell**2))
    # Small diagonal jitter added to the kernel matrix.
    covariance = kernel + 1e-6 * np.eye(n_steps)

    mean = np.zeros(n_steps)
    draws = np.random.multivariate_normal(mean, covariance, size=num_series)
    return draws.T

def create_burgers_dataset(config, num_simulations, filename):
    """Simulate randomly forced Burgers' trajectories and save them to disk.

    For every simulation a random initial condition (a sum of one to three
    low-order sine modes) and a random forcing-weight sequence are drawn, the
    solver is stepped over the whole time grid, and each snapshot is
    interpolated from the solver grid onto the sensor grid.

    Args:
        config: Mapping providing ``L``, ``NX_SOLVER``, ``NT_SOLVER``,
            ``M_SENSORS``, ``NUM_BASIS_FUNCTIONS`` and ``CONTROL_SCALE``
            (plus whatever ``BurgersSimulator`` and
            ``generate_grf_time_series`` read).
        num_simulations: Number of trajectories to generate.
        filename: Output path handed to ``np.savez_compressed``. Missing
            parent directories are created; a bare file name is written to
            the current working directory.

    The archive stores ``state_sequences`` with shape
    (num_simulations, NT_SOLVER, M_SENSORS) and ``control_sequences`` with
    shape (num_simulations, NT_SOLVER, NUM_BASIS_FUNCTIONS).
    """
    print(f"--- Generating Burgers' Equation Dataset: {filename} ---")

    # os.path.dirname returns '' for a bare file name and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is named.
    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    simulator = BurgersSimulator(config)
    domain_length = config['L']
    x_grid_solver = np.linspace(0, domain_length, config['NX_SOLVER'])
    x_grid_sensors = np.linspace(0, domain_length, config['M_SENSORS'])

    state_sequences_on_sensors = []
    control_sequences = []

    for i in range(num_simulations):
        if (i + 1) % 50 == 0:
            print(f"  Generating simulation {i+1}/{num_simulations}...")

        # Random initial condition. The modes are scaled by the domain length
        # so they vanish at x = 0 and x = L (matching the Dirichlet boundary
        # conditions and the simulator's forcing basis) for any L, not only
        # L = 1. The amplitude is drawn before the mode number.
        initial_state = np.zeros(config['NX_SOLVER'])
        for _ in range(np.random.randint(1, 4)):
            amplitude = np.random.uniform(-1, 1)
            mode = np.random.randint(1, 4)
            initial_state += amplitude * np.sin(mode * np.pi * x_grid_solver / domain_length)

        # Forcing weights: one row per time level, clipped to the control bounds.
        w_sequence = generate_grf_time_series(config, config['NUM_BASIS_FUNCTIONS'])
        w_sequence = np.clip(w_sequence * 1.5, -config['CONTROL_SCALE'], config['CONTROL_SCALE'])

        current_state = initial_state
        state_history_fine_grid = [current_state]
        # NT_SOLVER snapshots require NT_SOLVER - 1 steps; the last row of
        # w_sequence is therefore never applied.
        for k in range(config['NT_SOLVER'] - 1):
            current_state = simulator.step(current_state, w_sequence[k, :])
            state_history_fine_grid.append(current_state)
        state_history_fine_grid = np.array(state_history_fine_grid)

        # Resample every snapshot (axis 1 is space) onto the sensor locations.
        interpolator = interp1d(x_grid_solver, state_history_fine_grid, axis=1,
                                kind='cubic', fill_value="extrapolate")
        state_history_sensor_grid = interpolator(x_grid_sensors)

        state_sequences_on_sensors.append(state_history_sensor_grid)
        control_sequences.append(w_sequence)

    np.savez_compressed(
        filename,
        state_sequences=np.array(state_sequences_on_sensors),
        control_sequences=np.array(control_sequences)
    )
    print(f"Dataset saved to {filename}")

# =============================================================================
# 3. Model Architectures (Now fully configurable)
# =============================================================================
class PropagatorDeepONet(nn.Module):
    """DeepONet-style network with a branch MLP and a trunk MLP.

    The branch encodes the concatenated sensor readings and control weights;
    the trunk encodes query locations. Their latent vectors are combined by a
    dot product over the latent dimension and a learned scalar bias is added.
    """

    def __init__(self, M_sensors, num_basis_functions, trunk_input_dim,
                 branch_depth, branch_width, trunk_depth, trunk_width, latent_dim, activation_fn):
        """Build both MLPs.

        ``activation_fn`` is matched case-insensitively; ``'tanh'`` selects
        ``nn.Tanh`` and any other value selects ``nn.ReLU``.
        """
        super().__init__()

        nonlinearity = nn.Tanh() if activation_fn.lower() == 'tanh' else nn.ReLU()

        # Branch first, then trunk, then bias: keeps parameter creation order.
        self.branch = self._mlp(M_sensors + num_basis_functions, branch_width,
                                branch_depth, latent_dim, nonlinearity)
        self.trunk = self._mlp(trunk_input_dim, trunk_width,
                               trunk_depth, latent_dim, nonlinearity)

        self.bias = nn.Parameter(torch.zeros(1))

    @staticmethod
    def _mlp(in_dim, width, depth, out_dim, nonlinearity):
        """Return Linear/activation pairs followed by a final Linear layer.

        There is always at least one hidden layer, even when ``depth`` is
        less than one.
        """
        sizes = [in_dim] + [width] * max(depth, 1)
        layers = []
        for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nonlinearity)
        layers.append(nn.Linear(width, out_dim))
        return nn.Sequential(*layers)

    def forward(self, x_k_sensors, u_k_weights, x_locations):
        """Evaluate the operator at the query locations.

        Args:
            x_k_sensors: Tensor whose first dimension is the batch size B;
                concatenated with ``u_k_weights`` along dim 1.
            u_k_weights: Tensor concatenated after ``x_k_sensors`` on dim 1.
            x_locations: Query locations, expanded to a leading batch
                dimension of size B before entering the trunk.

        Returns:
            Tensor with a trailing singleton dimension: (B, S, 1), where S is
            the number of query locations.
        """
        coefficients = self.branch(torch.cat((x_k_sensors, u_k_weights), dim=1))

        batch_size = x_k_sensors.shape[0]
        location_features = self.trunk(x_locations.expand(batch_size, -1, -1))

        # Dot product over the latent dimension for every (batch, location).
        prediction = torch.einsum('bi,bsi->bs', coefficients, location_features)
        return prediction.unsqueeze(-1) + self.bias

class RecurrentController(nn.Module):
    """LSTM-based controller that emits bounded control weights step by step.

    Each call consumes the current state and the target state, advances the
    LSTM by a single time step and maps the LSTM output through a small head
    ending in ``tanh``, scaled by ``control_scale``.
    """

    def __init__(self, M_sensors, num_basis_functions, control_scale,
                 hidden_dim, num_layers, activation_fn):
        """Build the LSTM and the output head.

        ``activation_fn`` is matched case-insensitively; ``'tanh'`` selects
        ``nn.Tanh`` for the head's hidden layer and any other value selects
        ``nn.ReLU``.
        """
        super().__init__()

        self.control_scale = control_scale

        head_activation = nn.Tanh() if activation_fn.lower() == 'tanh' else nn.ReLU()

        # The LSTM input is the current state concatenated with the target.
        self.lstm = nn.LSTM(
            input_size=2 * M_sensors,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        bottleneck = hidden_dim // 2
        self.output_head = nn.Sequential(
            nn.Linear(hidden_dim, bottleneck),
            head_activation,
            nn.Linear(bottleneck, num_basis_functions),
            nn.Tanh(),
        )

    def forward(self, x_k, x_final, hidden_state=None):
        """Compute the control weights for one time step.

        Args:
            x_k: Current state; concatenated with ``x_final`` along dim 1.
            x_final: Target state.
            hidden_state: LSTM state returned by a previous call, or ``None``
                to let the LSTM use its default initial state.

        Returns:
            Tuple ``(w_k, new_hidden_state)``. ``w_k`` is the ``tanh`` output
            of the head multiplied by ``control_scale``.
        """
        features = torch.cat((x_k, x_final), dim=1)
        # Insert a length-1 sequence axis (the LSTM is batch_first).
        single_step = features.unsqueeze(1)
        lstm_out, new_hidden_state = self.lstm(single_step, hidden_state)

        normalized = self.output_head(lstm_out.squeeze(1))
        # Scale the Tanh output to the desired control bounds
        return normalized * self.control_scale, new_hidden_state