# src/data_and_models.py
# FINAL version for the Propagator project.
# Includes both the PropagatorDeepONet and the RecurrentController models.

import torch
import torch.nn as nn
import numpy as np
from scipy.interpolate import interp1d
from scipy.sparse import spdiags
from scipy.sparse.linalg import spsolve
import os

# --- 1. Propagator Model Definition (Configurable) ---
class PropagatorDeepONet(nn.Module):
    """DeepONet-style network built from a branch MLP and a trunk MLP.

    The branch MLP consumes the sensor readings concatenated with the
    control weights; the trunk MLP consumes the query locations.  Both
    end in a ``latent_dim``-wide linear layer, and the prediction at each
    location is the inner product of the two latent vectors plus one
    learned scalar bias.

    Args:
        M_sensors: Number of sensor values fed to the branch net.
        num_basis_functions: Number of control weights fed to the branch net.
        trunk_input_dim: Feature size of one query location.
        branch_depth: Number of hidden ``Linear`` layers in the branch net
            (at least one hidden layer is always created).
        branch_width: Width of the branch hidden layers.
        trunk_depth: Number of hidden ``Linear`` layers in the trunk net
            (at least one hidden layer is always created).
        trunk_width: Width of the trunk hidden layers.
        latent_dim: Output size shared by branch and trunk.
        activation_fn: ``'tanh'`` selects ``nn.Tanh`` (case-insensitive);
            every other string, including ``'relu'``, selects ``nn.ReLU``.
    """

    def __init__(self, M_sensors, num_basis_functions, trunk_input_dim,
                 branch_depth, branch_width, trunk_depth, trunk_width, latent_dim, activation_fn):
        super().__init__()

        # One activation module instance is shared by every hidden layer of
        # both sub-networks; it carries no parameters.
        nonlinearity = nn.Tanh() if activation_fn.lower() == 'tanh' else nn.ReLU()

        # Branch is built before trunk so that parameter initialisation
        # consumes the RNG in a fixed order.
        self.branch = self._make_mlp(
            M_sensors + num_basis_functions, branch_width, branch_depth, latent_dim, nonlinearity
        )
        self.trunk = self._make_mlp(
            trunk_input_dim, trunk_width, trunk_depth, latent_dim, nonlinearity
        )

        self.bias = nn.Parameter(torch.zeros(1))

    @staticmethod
    def _make_mlp(in_features, width, depth, out_features, nonlinearity):
        """Stack Linear/activation pairs followed by a final Linear projection."""
        stack = [nn.Linear(in_features, width), nonlinearity]
        for _ in range(depth - 1):
            stack += [nn.Linear(width, width), nonlinearity]
        stack.append(nn.Linear(width, out_features))
        return nn.Sequential(*stack)

    def forward(self, T_k_sensors, w_k, x_locations):
        """Evaluate the network at ``x_locations``.

        Args:
            T_k_sensors: Tensor of shape ``(B, M_sensors)``.
            w_k: Tensor of shape ``(B, num_basis_functions)``.
            x_locations: Tensor expandable to ``(B, S, trunk_input_dim)``.

        Returns:
            Tensor of shape ``(B, S, 1)``.
        """
        batch_size = T_k_sensors.shape[0]
        branch_latent = self.branch(torch.cat((T_k_sensors, w_k), dim=1))
        # Broadcast the query locations across the batch before the trunk net.
        trunk_latent = self.trunk(x_locations.expand(batch_size, -1, -1))
        # Contract over the latent axis: (B, latent) x (B, S, latent) -> (B, S).
        prediction = torch.einsum('bi,bsi->bs', branch_latent, trunk_latent)
        return prediction.unsqueeze(-1) + self.bias

# --- 2. Recurrent Controller Definition (NEW, Configurable) ---
class RecurrentController(nn.Module):
    """LSTM-based controller that emits one set of control weights per call.

    Each call feeds a single time step (the current sensor state joined
    with the target sensor state) through the LSTM and maps the LSTM
    output to ``num_basis_functions`` values squashed into ``[-1, 1]`` by
    a final ``Tanh``.

    Args:
        M_sensors: Feature size of each of the two state inputs.
        num_basis_functions: Number of control weights produced.
        hidden_dim: LSTM hidden size; the output head narrows to
            ``hidden_dim // 2`` before the final projection.
        num_layers: Number of stacked LSTM layers.
        activation_fn: ``'tanh'`` selects ``nn.Tanh`` (case-insensitive)
            for the head's hidden layer; every other string, including
            ``'relu'``, selects ``nn.ReLU``.
    """

    def __init__(self, M_sensors, num_basis_functions,
                 hidden_dim, num_layers, activation_fn):
        super().__init__()

        head_activation = nn.Tanh() if activation_fn.lower() == 'tanh' else nn.ReLU()

        # The LSTM sees current and target sensor states side by side.
        self.lstm = nn.LSTM(
            input_size=2 * M_sensors,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        bottleneck = hidden_dim // 2
        self.output_head = nn.Sequential(
            nn.Linear(hidden_dim, bottleneck),
            head_activation,
            nn.Linear(bottleneck, num_basis_functions),
            nn.Tanh(),
        )

    def forward(self, T_k, T_final, hidden_state=None):
        """Advance the controller by one step.

        Args:
            T_k: Tensor of shape ``(B, M_sensors)``.
            T_final: Tensor of shape ``(B, M_sensors)``.
            hidden_state: Optional LSTM state from the previous call,
                passed straight to ``nn.LSTM``.

        Returns:
            Tuple ``(w_k, new_hidden_state)`` where ``w_k`` has shape
            ``(B, num_basis_functions)`` and ``new_hidden_state`` is the
            state returned by the LSTM.
        """
        step_features = torch.cat((T_k, T_final), dim=1)
        # Insert a length-1 sequence axis: (B, F) -> (B, 1, F).
        single_step = step_features.unsqueeze(1)
        lstm_out, new_hidden_state = self.lstm(single_step, hidden_state)
        # Drop the length-1 sequence axis again before the output head.
        w_k = self.output_head(lstm_out.squeeze(1))
        return w_k, new_hidden_state

# --- 3. Ground Truth PDE Solver ---
def solve_pde_time_varying(config, u_control_sequence):
    """Integrate the controlled 1-D diffusion/relaxation equation in time.

    The discretisation is Crank-Nicolson for

        dV/dt = D * d2V/dx2 - BETA * (V - V_REF_VAL) + ALPHA * u(x, t)

    on a uniform grid of ``NX_SOLVER`` points over ``[0, L]`` and
    ``NT_SOLVER`` time levels over ``[0, T_FINAL]``.  The first and last
    rows use a doubled off-diagonal (mirror/ghost-node form of a zero-flux
    boundary).  The control is averaged between consecutive time levels
    before entering the source term.

    Args:
        config: Mapping with keys ``NX_SOLVER``, ``NT_SOLVER``, ``L``,
            ``T_FINAL``, ``INITIAL_STATE_VAL``, ``D``, ``BETA``, ``ALPHA``
            and ``V_REF_VAL``.
        u_control_sequence: Array-like indexed by time level; entry ``k``
            is the control at level ``k`` (a length-``NX_SOLVER`` row or a
            scalar).  At least ``NT_SOLVER`` entries are required.

    Returns:
        ``numpy.ndarray`` of shape ``(NT_SOLVER, NX_SOLVER)`` holding the
        state at every time level, row 0 being the initial state.
    """
    # Local import keeps this block self-contained; the module already
    # depends on scipy.sparse.linalg.
    from scipy.sparse.linalg import factorized

    NX_SOLVER, NT_SOLVER = config['NX_SOLVER'], config['NT_SOLVER']
    dx = config['L'] / (NX_SOLVER - 1)
    dt = config['T_FINAL'] / (NT_SOLVER - 1)

    lambda_ = config['D'] * dt / (2 * dx**2)
    beta_term = 0.5 * config['BETA'] * dt

    # Implicit (left-hand) operator.
    A_main = np.full(NX_SOLVER, 1 + 2 * lambda_ + beta_term)
    A_off = np.full(NX_SOLVER, -lambda_)
    A = spdiags([A_off, A_main, A_off], [-1, 0, 1], NX_SOLVER, NX_SOLVER, format='csc')

    # Explicit (right-hand) operator.
    B_main = np.full(NX_SOLVER, 1 - 2 * lambda_ - beta_term)
    B_off = np.full(NX_SOLVER, lambda_)
    B = spdiags([B_off, B_main, B_off], [-1, 0, 1], NX_SOLVER, NX_SOLVER, format='csc')

    # Boundary rows: the single interior neighbour counts twice.
    A[0, 1] = -2 * lambda_
    A[-1, -2] = -2 * lambda_
    B[0, 1] = 2 * lambda_
    B[-1, -2] = 2 * lambda_

    # A does not change between steps, so factorize it once instead of
    # letting a fresh sparse solve redo the LU decomposition every step.
    solve_step = factorized(A)

    # Accept nested lists as well as arrays: list + list would otherwise
    # concatenate instead of adding element-wise.
    controls = np.asarray(u_control_sequence)

    # Float dtype regardless of whether INITIAL_STATE_VAL is an int.
    V_current = np.full(NX_SOLVER, config['INITIAL_STATE_VAL'], dtype=float)
    V_history = np.empty((NT_SOLVER, NX_SOLVER))
    V_history[0] = V_current

    relaxation_source = config['BETA'] * config['V_REF_VAL']
    for k in range(NT_SOLVER - 1):
        avg_u_in_step = (controls[k] + controls[k + 1]) / 2.0
        source_term = config['ALPHA'] * avg_u_in_step + relaxation_source
        b_vec = B @ V_current + source_term * dt
        V_current = solve_step(b_vec)
        V_history[k + 1] = V_current
    return V_history

# --- 4. Data Generation Logic ---
def generate_grf_time_series(config, num_steps, num_series, length_scale):
    """Sample zero-mean Gaussian time series with a squared-exponential kernel.

    The time grid is ``num_steps`` evenly spaced points on
    ``[0, config['T_FINAL']]``.  A ``1e-6`` diagonal term is added to the
    covariance before sampling.  Samples come from NumPy's global random
    state.

    Args:
        config: Mapping providing ``T_FINAL``.
        num_steps: Number of time points per series.
        num_series: Number of independent series to draw.
        length_scale: Kernel length scale, in the same units as the grid.

    Returns:
        ``numpy.ndarray`` of shape ``(num_steps, num_series)``; each column
        is one sampled series.
    """
    time_grid = np.linspace(0, config['T_FINAL'], num_steps)
    # Pairwise |t_i - t_j| for every pair of grid points.
    lag = np.abs(np.subtract.outer(time_grid, time_grid))
    kernel = np.exp(-0.5 * (lag ** 2) / (length_scale ** 2))
    jitter = 1e-6 * np.eye(num_steps)
    draws = np.random.multivariate_normal(
        np.zeros(num_steps), kernel + jitter, size=num_series
    )
    # Sampler returns (num_series, num_steps); callers expect time first.
    return draws.T

def create_recurrent_dataset(config, num_simulations, filename):
    """Simulate random control sequences and save them with the sensor states.

    For each simulation a random length scale is drawn, control weights
    are sampled with ``generate_grf_time_series``, scaled by 0.7 and
    clipped to ``[-1, 1]``, projected onto cosine basis functions to form
    the control field, and passed to ``solve_pde_time_varying``.  The
    solution is then cubic-interpolated from the solver grid to the sensor
    locations.

    Args:
        config: Mapping providing ``L``, ``NX_SOLVER``, ``NT_SOLVER``,
            ``M_SENSORS`` and ``NUM_BASIS_FUNCTIONS``, plus whatever the
            two helper functions read.
        num_simulations: Number of trajectories to generate.
        filename: Output path for ``numpy.savez_compressed``.  A bare file
            name (no directory part) writes to the current directory;
            missing parent directories are created.  NumPy appends
            ``.npz`` when the name does not already end with it.

    The archive holds ``control_sequences`` with shape
    ``(num_simulations, NT_SOLVER, NUM_BASIS_FUNCTIONS)`` and
    ``state_sequences`` with shape
    ``(num_simulations, NT_SOLVER, M_SENSORS)``.
    """
    print(f"--- Generating Recurrent Dataset: {filename} ---")

    # os.path.dirname() returns '' for a bare file name and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is named.
    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    x_grid_solver = np.linspace(0, config['L'], config['NX_SOLVER'])
    sensor_locs_np = np.linspace(0, config['L'], config['M_SENSORS'])
    # Column j is cos(j * pi * x / L) sampled on the solver grid.
    basis_functions = np.cos(np.arange(config['NUM_BASIS_FUNCTIONS']) * np.pi * x_grid_solver[:, None] / config['L'])

    control_sequences, state_sequences_at_sensors = [], []
    for i in range(num_simulations):
        if (i + 1) % 50 == 0:
            print(f"  Generating simulation {i+1}/{num_simulations}...")
        length_scale = np.random.uniform(0.8, 2.5)
        w_sequence = generate_grf_time_series(config, config['NT_SOLVER'], config['NUM_BASIS_FUNCTIONS'], length_scale)
        w_sequence = np.clip(w_sequence * 0.7, -1.0, 1.0)
        # (time, basis) @ (basis, space) -> control field on the solver grid.
        u_xt_sequence = w_sequence @ basis_functions.T
        V_xt_solution = solve_pde_time_varying(config, u_xt_sequence)
        # Resample every time level from the solver grid onto the sensors.
        interpolator = interp1d(x_grid_solver, V_xt_solution, axis=1, kind='cubic', fill_value="extrapolate")
        V_at_sensors = interpolator(sensor_locs_np)
        control_sequences.append(w_sequence)
        state_sequences_at_sensors.append(V_at_sensors)

    np.savez_compressed(
        filename,
        control_sequences=np.array(control_sequences),
        state_sequences=np.array(state_sequences_at_sensors),
    )
    print(f"Dataset saved to {filename}")