

import torch
import torch.nn as nn
import numpy as np
from scipy.interpolate import interp1d
from scipy.sparse import spdiags
from scipy.sparse.linalg import spsolve
import os

class BurgersSimulator:
    """A high-fidelity numerical solver for the 1D viscous Burgers' equation.

    Diffusion is integrated with Crank-Nicolson (implicit matrix ``A_cn``,
    explicit matrix ``B_cn``). The nonlinear advection term ``v * dv/dx`` is
    treated explicitly with a first-order upwind difference. Homogeneous
    Dirichlet conditions (v = 0) are imposed at both ends of the grid.

    Config keys read: ``'L'`` (domain length), ``'NX_SOLVER'`` (grid points),
    ``'T_FINAL'`` (final time), ``'NT_SOLVER'`` (time levels) and
    ``'VISCOSITY'``.
    """
    def __init__(self, config):
        # Local import so the module-level import block does not need to change.
        from scipy.sparse.linalg import splu

        self.config = config
        self.dx = config['L'] / (config['NX_SOLVER'] - 1)
        self.dt = config['T_FINAL'] / (config['NT_SOLVER'] - 1)

        # Crank-Nicolson weight: half of the explicit diffusion number nu*dt/dx^2.
        lambda_ = config['VISCOSITY'] * self.dt / (2 * self.dx**2)

        A_main_diag = np.full(config['NX_SOLVER'], 1 + 2 * lambda_)
        A_off_diag = np.full(config['NX_SOLVER'], -lambda_)
        B_main_diag = np.full(config['NX_SOLVER'], 1 - 2 * lambda_)
        B_off_diag = np.full(config['NX_SOLVER'], lambda_)

        self.A_cn = spdiags([A_off_diag, A_main_diag, A_off_diag], [-1, 0, 1], m=config['NX_SOLVER'], n=config['NX_SOLVER'], format='csc')
        self.B_cn = spdiags([B_off_diag, B_main_diag, B_off_diag], [-1, 0, 1], m=config['NX_SOLVER'], n=config['NX_SOLVER'], format='csc')

        # Dirichlet boundary conditions (v=0 at ends): the first and last rows
        # become identity rows. These entries already exist in the sparsity
        # pattern, so the CSC assignment does not change its structure.
        self.A_cn[0, 0], self.A_cn[0, 1] = 1, 0
        self.A_cn[-1, -1], self.A_cn[-1, -2] = 1, 0
        self.B_cn[0, 0], self.B_cn[0, 1] = 1, 0
        self.B_cn[-1, -1], self.B_cn[-1, -2] = 1, 0

        # A_cn never changes after construction, so factor it once here instead
        # of letting spsolve refactorize it on every call to step().
        # If A_cn is modified externally, rebuild this factor.
        self._A_lu = splu(self.A_cn)

    def _advect(self, v):
        """Return the upwinded advection term ``v * dv/dx`` (zero at both ends).

        Where v > 0 a backward difference is used; elsewhere a forward difference.
        """
        adv = np.zeros_like(v)
        dvdx_fwd = (v[2:] - v[1:-1]) / self.dx
        dvdx_bwd = (v[1:-1] - v[:-2]) / self.dx
        v_interior = v[1:-1]
        adv[1:-1] = np.where(v_interior > 0, v_interior * dvdx_bwd, v_interior * dvdx_fwd)
        return adv

    def step(self, current_state, u_control_on_grid):
        """Advance the state by one time step ``dt``.

        MODIFIED: Accepts a control vector defined on the solver grid directly.

        Args:
            current_state: State values on the solver grid (1D array).
            u_control_on_grid: Forcing values on the same grid, added as
                ``dt * force`` to the right-hand side.

        Returns:
            The next state on the solver grid. Both end values are zero
            because of the Dirichlet rows.
        """
        force = u_control_on_grid
        rhs = self.B_cn @ current_state - self.dt * self._advect(current_state) + self.dt * force
        rhs[0], rhs[-1] = 0, 0  # Enforce BC on the right-hand side
        return self._A_lu.solve(rhs)

def generate_2d_grf_fast(x_grid, t_grid, num_samples, x_scale, t_scale):
    """Draw samples of a separable space-time Gaussian random field.

    The covariance is the Kronecker product of a squared-exponential kernel in
    time (length scale ``t_scale``) and one in space (length scale ``x_scale``).
    Each kernel gets a 1e-6 diagonal jitter before its Cholesky factorization.
    Each sample is ``L_t @ Z @ L_x.T`` with ``Z`` standard normal noise drawn
    from the global ``np.random`` state. Only the two small factors are
    decomposed, which costs O(nt^3) + O(nx^3) instead of O((nt*nx)^3).

    Returns:
        Array of shape ``(num_samples, len(t_grid), len(x_grid))``.
    """
    def _chol_sq_exp(points, length_scale):
        # Lower Cholesky factor of exp(-d^2 / (2 l^2)) plus a small diagonal jitter.
        separation = np.abs(points[:, None] - points[None, :])
        kernel = np.exp(-0.5 * (separation**2) / (length_scale**2)) + 1e-6 * np.eye(len(points))
        return np.linalg.cholesky(kernel)

    chol_time = _chol_sq_exp(t_grid, t_scale)
    chol_space = _chol_sq_exp(x_grid, x_scale)
    n_time, n_space = len(t_grid), len(x_grid)

    fields = np.zeros((num_samples, n_time, n_space))
    for idx in range(num_samples):
        white_noise = np.random.randn(n_time, n_space)
        # Correlate rows through time and columns through space.
        fields[idx] = chol_time @ white_noise @ chol_space.T
    return fields

def create_burgers_dataset(config, num_simulations, filename):
    """Simulate randomly forced Burgers' trajectories and save them at the sensors.

    Each simulation works as follows:
      * The initial state is a random sum of 1-3 sine modes with amplitudes
        drawn uniformly from [-1, 1].
      * The forcing u(x, t) is one sample of a 2D Gaussian random field,
        scaled by 1.5 and clipped to ``+/- config['CONTROL_SCALE']``.
      * ``BurgersSimulator`` is stepped ``NT_SOLVER - 1`` times.
    The state history (cubic interpolation) and the forcing history (linear
    interpolation) are then resampled from the solver grid onto
    ``config['M_SENSORS']`` equally spaced sensor locations.

    Args:
        config: Mapping with at least ``'L'``, ``'NX_SOLVER'``, ``'NT_SOLVER'``,
            ``'T_FINAL'``, ``'VISCOSITY'``, ``'M_SENSORS'`` and ``'CONTROL_SCALE'``.
        num_simulations: Number of trajectories to generate.
        filename: Output path passed to ``np.savez_compressed``. The parent
            directory is created if the path has one.

    Saves:
        ``state_sequences``   with shape (num_simulations, NT_SOLVER, M_SENSORS)
        ``control_sequences`` with shape (num_simulations, NT_SOLVER, M_SENSORS)
    """
    print(f"--- Generating Burgers' Equation Dataset: {filename} ---")
    # os.path.dirname() returns '' for a bare filename such as "data.npz", and
    # os.makedirs('') raises FileNotFoundError, so only create a real directory.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    simulator = BurgersSimulator(config)
    x_grid_solver = np.linspace(0, config['L'], config['NX_SOLVER'])
    t_grid_solver = np.linspace(0, config['T_FINAL'], config['NT_SOLVER'])
    x_grid_sensors = np.linspace(0, config['L'], config['M_SENSORS'])

    # Generate all control fields u(x, t) at once using a 2D GRF.
    u_xt_fields = generate_2d_grf_fast(x_grid_solver, t_grid_solver, num_simulations, x_scale=0.5, t_scale=1.5)

    state_sequences_on_sensors = []
    control_sequences_on_sensors = []

    for i in range(num_simulations):
        if (i + 1) % 50 == 0:
            print(f"  Generating simulation {i+1}/{num_simulations}...")

        # Random initial condition: sum of 1-3 sine modes.
        # NOTE(review): sin(k*pi*x) vanishes at x = L only when L is an integer.
        # If the initial state is meant to satisfy the Dirichlet BC for any L,
        # this should probably use x / L. TODO confirm.
        initial_state = np.zeros(config['NX_SOLVER'])
        for _ in range(np.random.randint(1, 4)):
            initial_state += np.random.uniform(-1, 1) * np.sin(np.random.randint(1, 4) * np.pi * x_grid_solver)

        # Get the full u(x,t) field for this simulation and scale it.
        u_xt_sequence = np.clip(u_xt_fields[i] * 1.5, -config['CONTROL_SCALE'], config['CONTROL_SCALE'])

        current_state = initial_state
        state_history_fine_grid = [current_state]
        for k in range(config['NT_SOLVER'] - 1):
            # Pass the control slice u(x, t_k) directly to the simulator.
            next_state = simulator.step(current_state, u_xt_sequence[k, :])
            state_history_fine_grid.append(next_state)
            current_state = next_state
        state_history_fine_grid = np.array(state_history_fine_grid)  # (NT_SOLVER, NX_SOLVER)

        # Interpolate the state history onto the sensor grid.
        state_interpolator = interp1d(x_grid_solver, state_history_fine_grid, axis=1, kind='cubic', fill_value="extrapolate")
        state_history_sensor_grid = state_interpolator(x_grid_sensors)

        # Interpolate the control history onto the sensor grid; it is saved as the target.
        control_interpolator = interp1d(x_grid_solver, u_xt_sequence, axis=1, kind='linear', fill_value="extrapolate")
        control_history_sensor_grid = control_interpolator(x_grid_sensors)

        state_sequences_on_sensors.append(state_history_sensor_grid)
        control_sequences_on_sensors.append(control_history_sensor_grid)

    np.savez_compressed(
        filename,
        state_sequences=np.array(state_sequences_on_sensors),
        control_sequences=np.array(control_sequences_on_sensors)  # Save control values at sensors
    )
    print(f"Dataset saved to {filename}")

class PropagatorDeepONet(nn.Module):
    """DeepONet mapping (state, control) at the sensors to a field at query points.

    The branch MLP encodes the concatenated sensor state and control into a
    ``latent_dim`` vector. The trunk MLP encodes each query location into a
    vector of the same size. Their dot product plus a scalar bias is the
    prediction at that location.

    ``num_basis_functions`` is accepted for interface compatibility but is not
    used by this module.
    """
    def __init__(self, M_sensors, num_basis_functions, trunk_input_dim,
                 branch_depth, branch_width, trunk_depth, trunk_width, latent_dim, activation_fn):
        super(PropagatorDeepONet, self).__init__()

        # MODIFIED: Input is state at sensors + control u at sensors
        branch_input_size = M_sensors + M_sensors
        # Unknown activation names fall back to ReLU. The single activation
        # instance is shared between layers; ReLU/Tanh hold no state, so this is harmless.
        if activation_fn.lower() == 'relu': activation = nn.ReLU()
        elif activation_fn.lower() == 'tanh': activation = nn.Tanh()
        else: activation = nn.ReLU()

        branch_layers = [nn.Linear(branch_input_size, branch_width), activation]
        for _ in range(branch_depth - 1):
            branch_layers.extend([nn.Linear(branch_width, branch_width), activation])
        branch_layers.append(nn.Linear(branch_width, latent_dim))
        self.branch = nn.Sequential(*branch_layers)

        trunk_layers = [nn.Linear(trunk_input_dim, trunk_width), activation]
        for _ in range(trunk_depth - 1):
            trunk_layers.extend([nn.Linear(trunk_width, trunk_width), activation])
        trunk_layers.append(nn.Linear(trunk_width, latent_dim))
        self.trunk = nn.Sequential(*trunk_layers)

        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x_k_sensors, u_k_at_sensors, x_locations):
        """Predict the output field at ``x_locations``.

        Args:
            x_k_sensors: State at the sensors, presumably (B, M_sensors).
            u_k_at_sensors: Control values at the sensors, presumably
                (B, M_sensors). Concatenated with ``x_k_sensors`` along dim 1,
                the result must have ``2 * M_sensors`` features.
            x_locations: Query points of shape (S, trunk_input_dim),
                (1, S, trunk_input_dim) or (B, S, trunk_input_dim).

        Returns:
            Tensor of shape (B, S, 1).
        """
        branch_input = torch.cat([x_k_sensors, u_k_at_sensors], dim=1)
        branch_out = self.branch(branch_input)
        B = x_k_sensors.shape[0]
        # Run the trunk once on the (possibly shared) query points and broadcast
        # the result, instead of running it on B identical expanded copies.
        trunk_out = self.trunk(x_locations).expand(B, -1, -1)
        output = torch.einsum('bi,bsi->bs', branch_out, trunk_out)
        return output.unsqueeze(-1) + self.bias

class RecurrentController(nn.Module):
    """LSTM policy producing a bounded control vector at the sensor locations.

    At each call the current state and the target final state are concatenated
    and fed to the LSTM as a length-1 sequence. An MLP head ending in Tanh
    maps the LSTM output to ``M_sensors`` values in (-1, 1), which are then
    multiplied by ``control_scale``.

    ``num_basis_functions`` is accepted for interface compatibility but unused.
    Activation names other than 'relu'/'tanh' (case-insensitive) use ReLU.
    """
    def __init__(self, M_sensors, num_basis_functions, control_scale,
                 hidden_dim, num_layers, activation_fn):
        super().__init__()

        self.control_scale = control_scale
        hidden_act = {'relu': nn.ReLU, 'tanh': nn.Tanh}.get(activation_fn.lower(), nn.ReLU)

        # Input per step: current state and target state, each with M_sensors values.
        self.lstm = nn.LSTM(
            input_size=2 * M_sensors,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        head_width = hidden_dim // 2
        self.output_head = nn.Sequential(
            nn.Linear(hidden_dim, head_width),
            hidden_act(),
            nn.Linear(head_width, M_sensors),
            nn.Tanh(),
        )

    def forward(self, x_k, x_final, hidden_state=None):
        """Return ``(control_at_sensors, new_hidden_state)`` for one time step.

        ``x_k`` and ``x_final`` are concatenated along dim 1. ``hidden_state``
        is passed straight to the LSTM; ``None`` starts from zeros.
        """
        one_step_seq = torch.cat((x_k, x_final), dim=1)[:, None, :]
        seq_out, next_hidden = self.lstm(one_step_seq, hidden_state)
        in_unit_range = self.output_head(seq_out[:, 0, :])
        # Stretch the Tanh range (-1, 1) to (-control_scale, control_scale).
        return in_unit_range * self.control_scale, next_hidden