# src/data_and_models.py

import torch
import torch.nn as nn
import numpy as np
from scipy.interpolate import interp1d
from scipy.sparse import spdiags
from scipy.sparse.linalg import spsolve
import os

class BurgersSimulator:
    """A high-fidelity numerical solver for the 1D viscous Burgers' equation.

    Each call to ``step`` advances the state by one ``dt``: diffusion uses the
    Crank-Nicolson matrices assembled in ``__init__``, advection uses an
    explicit first-order upwind difference, and the forcing is a weighted sum
    of sine basis functions. Both end points of the grid are held at zero.

    Config keys read: 'L', 'T_FINAL', 'NX_SOLVER', 'NT_SOLVER', 'VISCOSITY',
    'NUM_BASIS_FUNCTIONS'.
    """

    def __init__(self, config):
        """Build the grid spacing, the Crank-Nicolson matrices and the forcing basis."""
        self.config = config
        nx = config['NX_SOLVER']
        self.dx = config['L'] / (nx - 1)
        self.dt = config['T_FINAL'] / (config['NT_SOLVER'] - 1)

        lambda_ = config['VISCOSITY'] * self.dt / (2 * self.dx**2)

        A_main_diag = np.full(nx, 1 + 2 * lambda_)
        A_off_diag = np.full(nx, -lambda_)
        B_main_diag = np.full(nx, 1 - 2 * lambda_)
        B_off_diag = np.full(nx, lambda_)

        # Implicit (A) and explicit (B) halves of the Crank-Nicolson diffusion update.
        self.A_cn = spdiags([A_off_diag, A_main_diag, A_off_diag], [-1, 0, 1], m=nx, n=nx, format='csc')
        self.B_cn = spdiags([B_off_diag, B_main_diag, B_off_diag], [-1, 0, 1], m=nx, n=nx, format='csc')

        # Dirichlet boundary conditions (v=0 at ends): first and last rows become identity rows.
        for matrix in (self.A_cn, self.B_cn):
            matrix[0, 0], matrix[0, 1] = 1, 0
            matrix[-1, -1], matrix[-1, -2] = 1, 0

        # Filled on first use by _get_cn_solver(); A_cn never changes between steps.
        self._cn_solve = None

        self.basis_funcs = self._create_basis_functions()

    def __getstate__(self):
        """Return the instance state without the cached factorization.

        The cached solver wraps a native LU object that cannot be pickled or
        deep-copied; it is rebuilt on demand after restoring.
        """
        state = self.__dict__.copy()
        state.pop('_cn_solve', None)
        return state

    def _get_cn_solver(self):
        """Return a callable solving ``A_cn @ x = rhs``, factorizing ``A_cn`` only once."""
        solver = getattr(self, '_cn_solve', None)
        if solver is None:
            from scipy.sparse.linalg import factorized
            solver = factorized(self.A_cn)
            self._cn_solve = solver
        return solver

    def _create_basis_functions(self):
        """Return an array of shape (NUM_BASIS_FUNCTIONS, NX_SOLVER).

        Row ``i`` holds ``sin((i + 1) * pi * x / L)`` sampled on the solver grid.
        """
        length = self.config['L']
        x = np.linspace(0, length, self.config['NX_SOLVER'])
        modes = np.arange(1, self.config['NUM_BASIS_FUNCTIONS'] + 1)
        return np.sin(modes[:, None] * np.pi * x[None, :] / length)

    def _advect(self, v):
        """Return the upwind estimate of ``v * dv/dx``; the two end entries stay 0."""
        adv = np.zeros_like(v)
        dvdx_fwd = (v[2:] - v[1:-1]) / self.dx
        dvdx_bwd = (v[1:-1] - v[:-2]) / self.dx
        v_interior = v[1:-1]
        # Upwinding: difference against the side the flow comes from.
        adv[1:-1] = np.where(v_interior > 0, v_interior * dvdx_bwd, v_interior * dvdx_fwd)
        return adv

    def step(self, current_state, u_weights):
        """Advance the state by one time step.

        Args:
            current_state: 1D array of length NX_SOLVER with the state on the solver grid.
            u_weights: 1D array of length NUM_BASIS_FUNCTIONS weighting the forcing basis.

        Returns:
            1D array of length NX_SOLVER with the state after one ``dt``.
        """
        force = np.einsum('i,ij->j', u_weights, self.basis_funcs)
        rhs = self.B_cn @ current_state - self.dt * self._advect(current_state) + self.dt * force
        rhs[0], rhs[-1] = 0, 0  # Enforce BC on the right-hand side
        return self._get_cn_solver()(rhs)

def generate_grf_time_series(config, num_series, length_scale_range=(0.8, 2.5)):
    """Draw zero-mean Gaussian time series with a squared-exponential covariance.

    One length scale is sampled uniformly from ``length_scale_range`` and shared
    by every series of the call. The time grid has ``config['NT_SOLVER']`` points
    spanning ``[0, config['T_FINAL']]``.

    Returns:
        Array of shape (NT_SOLVER, num_series); each column is one series.
    """
    n_steps = config['NT_SOLVER']
    times = np.linspace(0, config['T_FINAL'], n_steps)
    ell = np.random.uniform(*length_scale_range)

    # Pairwise |t_i - t_j| between all time points.
    lag = np.abs(np.subtract.outer(times, times))
    kernel = np.exp(-0.5 * (lag**2) / (ell**2))
    # Small diagonal jitter keeps the covariance numerically well conditioned.
    kernel = kernel + 1e-6 * np.eye(n_steps)

    mean = np.zeros(n_steps)
    draws = np.random.multivariate_normal(mean, kernel, size=num_series)
    return draws.T

def create_burgers_dataset(config, num_simulations, filename):
    """Simulate random forced Burgers trajectories and save them to a compressed archive.

    For every simulation a random initial state (sum of 1-3 sine modes) and a
    random control sequence (clipped Gaussian time series) are drawn, the
    solver is rolled out for ``NT_SOLVER - 1`` steps, and the state history is
    interpolated from the solver grid onto the sensor grid.

    Args:
        config: dict read for 'L', 'NX_SOLVER', 'NT_SOLVER', 'M_SENSORS',
            'NUM_BASIS_FUNCTIONS' and 'CONTROL_SCALE', and passed on to
            ``BurgersSimulator`` and ``generate_grf_time_series``.
        num_simulations: number of trajectories to generate.
        filename: output path handed to ``np.savez_compressed``.

    The archive holds two arrays:
        state_sequences: (num_simulations, NT_SOLVER, M_SENSORS)
        control_sequences: (num_simulations, NT_SOLVER, NUM_BASIS_FUNCTIONS)
    The last control row of each sequence is stored but never applied, because
    only ``NT_SOLVER - 1`` steps are taken.
    """
    print(f"--- Generating Burgers' Equation Dataset: {filename} ---")
    # A bare file name has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    simulator = BurgersSimulator(config)
    domain_length = config['L']
    x_grid_solver = np.linspace(0, domain_length, config['NX_SOLVER'])
    x_grid_sensors = np.linspace(0, domain_length, config['M_SENSORS'])

    state_sequences_on_sensors = []
    control_sequences = []

    for i in range(num_simulations):
        if (i + 1) % 50 == 0:
            print(f"  Generating simulation {i+1}/{num_simulations}...")

        initial_state = np.zeros(config['NX_SOLVER'])
        for _ in range(np.random.randint(1, 4)):
            amplitude = np.random.uniform(-1, 1)
            mode = np.random.randint(1, 4)
            # Normalise by L so every mode vanishes at x = 0 and x = L, matching
            # the simulator's Dirichlet boundaries and its forcing basis.
            initial_state += amplitude * np.sin(mode * np.pi * x_grid_solver / domain_length)

        w_sequence = generate_grf_time_series(config, config['NUM_BASIS_FUNCTIONS'])
        w_sequence = np.clip(w_sequence * 1.5, -config['CONTROL_SCALE'], config['CONTROL_SCALE'])

        current_state = initial_state
        state_history_fine_grid = [current_state]
        for k in range(config['NT_SOLVER'] - 1):
            current_state = simulator.step(current_state, w_sequence[k, :])
            state_history_fine_grid.append(current_state)
        state_history_fine_grid = np.array(state_history_fine_grid)

        # Resample every time slice from the solver grid onto the sensor grid.
        interpolator = interp1d(x_grid_solver, state_history_fine_grid, axis=1, kind='cubic', fill_value="extrapolate")
        state_history_sensor_grid = interpolator(x_grid_sensors)

        state_sequences_on_sensors.append(state_history_sensor_grid)
        control_sequences.append(w_sequence)

    np.savez_compressed(
        filename,
        state_sequences=np.array(state_sequences_on_sensors),
        control_sequences=np.array(control_sequences)
    )
    print(f"Dataset saved to {filename}")

class PropagatorDeepONet(nn.Module):
    """DeepONet-style network built from a branch MLP and a trunk MLP.

    The branch consumes the concatenated sensor readings and control weights;
    the trunk consumes query locations. The prediction at each location is the
    inner product of the two latent vectors plus a learned scalar bias.
    """

    def __init__(self, M_sensors, num_basis_functions, trunk_input_dim,
                 branch_depth, branch_width, trunk_depth, trunk_width, latent_dim, activation_fn):
        super().__init__()

        # 'tanh' selects Tanh; 'relu' and any unrecognised name select ReLU.
        # A single activation instance is shared by every hidden layer.
        nonlinearity = nn.Tanh() if activation_fn.lower() == 'tanh' else nn.ReLU()

        def make_mlp(in_features, width, depth):
            # One hidden layer always, plus (depth - 1) more, then a linear map to latent_dim.
            modules = []
            fan_in = in_features
            for _ in range(1 + max(depth - 1, 0)):
                modules += [nn.Linear(fan_in, width), nonlinearity]
                fan_in = width
            modules.append(nn.Linear(width, latent_dim))
            return nn.Sequential(*modules)

        self.branch = make_mlp(M_sensors + num_basis_functions, branch_width, branch_depth)
        self.trunk = make_mlp(trunk_input_dim, trunk_width, trunk_depth)
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x_k_sensors, u_k_weights, x_locations):
        """Predict values at ``x_locations`` from the sensor state and control weights.

        ``x_k_sensors`` and ``u_k_weights`` are concatenated along dim 1, so both
        are 2D with the batch on dim 0. ``x_locations`` is broadcast to the batch
        size with ``expand(B, -1, -1)``. The result has shape (B, S, 1), where S is
        the number of query locations.
        """
        coefficients = self.branch(torch.cat((x_k_sensors, u_k_weights), dim=1))
        batch_size = x_k_sensors.shape[0]
        basis_values = self.trunk(x_locations.expand(batch_size, -1, -1))
        # Contract the latent dimension: (B, p) with (B, S, p) -> (B, S).
        prediction = torch.einsum('bi,bsi->bs', coefficients, basis_values)
        return prediction.unsqueeze(-1) + self.bias

class RecurrentController(nn.Module):
    """LSTM controller mapping (current state, target state) to bounded control weights.

    The LSTM is fed one time step per call. Its output goes through a small MLP
    ending in Tanh, and the result is multiplied by ``control_scale`` so every
    control weight lies in ``[-control_scale, control_scale]``.
    """

    def __init__(self, M_sensors, num_basis_functions, control_scale,
                 hidden_dim, num_layers, activation_fn):
        super().__init__()

        self.control_scale = control_scale

        # 'tanh' selects Tanh; 'relu' and any unrecognised name select ReLU.
        hidden_activation = nn.Tanh() if activation_fn.lower() == 'tanh' else nn.ReLU()

        # The input is the current state and the target state side by side.
        self.lstm = nn.LSTM(
            input_size=2 * M_sensors,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        half_width = hidden_dim // 2
        self.output_head = nn.Sequential(
            nn.Linear(hidden_dim, half_width),
            hidden_activation,
            nn.Linear(half_width, num_basis_functions),
            nn.Tanh(),
        )

    def forward(self, x_k, x_final, hidden_state=None):
        """Compute the control weights for one time step.

        ``x_k`` and ``x_final`` are concatenated along dim 1 and given a
        length-1 sequence axis before entering the LSTM. ``hidden_state`` is
        passed straight to the LSTM (``None`` lets it start from its default).

        Returns:
            ``(w_k, new_hidden_state)``: the scaled control weights and the
            LSTM state to pass into the next call.
        """
        step_input = torch.cat((x_k, x_final), dim=1)[:, None]
        features, new_hidden_state = self.lstm(step_input, hidden_state)
        bounded = self.output_head(features[:, 0])
        # Scale the Tanh output to the desired control bounds
        return bounded * self.control_scale, new_hidden_state