# src/data_and_models.py

import torch
import torch.nn as nn
import numpy as np
from scipy.interpolate import interp1d
from scipy.sparse import spdiags
from scipy.sparse.linalg import spsolve
import os

class PropagatorDeepONet(nn.Module):
    def __init__(self, M_sensors, num_basis_functions, trunk_input_dim,
                 branch_depth, branch_width, trunk_depth, trunk_width, latent_dim, activation_fn):
        super(PropagatorDeepONet, self).__init__()
        
        branch_input_size = M_sensors + num_basis_functions
        if activation_fn.lower() == 'relu': activation = nn.ReLU()
        elif activation_fn.lower() == 'tanh': activation = nn.Tanh()
        else: activation = nn.ReLU()

        branch_layers = [nn.Linear(branch_input_size, branch_width), activation]
        for _ in range(branch_depth - 1):
            branch_layers.extend([nn.Linear(branch_width, branch_width), activation])
        branch_layers.append(nn.Linear(branch_width, latent_dim))
        self.branch = nn.Sequential(*branch_layers)
        
        trunk_layers = [nn.Linear(trunk_input_dim, trunk_width), activation]
        for _ in range(trunk_depth - 1):
            trunk_layers.extend([nn.Linear(trunk_width, trunk_width), activation])
        trunk_layers.append(nn.Linear(trunk_width, latent_dim))
        self.trunk = nn.Sequential(*trunk_layers)
        
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, T_k_sensors, w_k, x_locations):
        branch_input = torch.cat([T_k_sensors, w_k], dim=1)
        branch_out = self.branch(branch_input)
        B = T_k_sensors.shape[0]
        expanded_x_locs = x_locations.expand(B, -1, -1)
        trunk_out = self.trunk(expanded_x_locs)
        output = torch.einsum('bi,bsi->bs', branch_out, trunk_out)
        return output.unsqueeze(-1) + self.bias

class RecurrentController(nn.Module):
    """LSTM policy proposing control coefficients one time step per call.

    Each call concatenates the current and the target sensor readings,
    advances the LSTM by a single step, and maps its output to
    ``num_basis_functions`` coefficients squashed into ``[-1, 1]`` by a
    final ``Tanh``.

    Args:
        M_sensors: Number of readings in each of the two input vectors.
        num_basis_functions: Number of coefficients produced per step.
        hidden_dim: LSTM hidden size; the output head's hidden layer uses
            ``hidden_dim // 2`` units.
        num_layers: Number of stacked LSTM layers.
        activation_fn: ``'relu'`` or ``'tanh'`` (case-insensitive) for the
            head's hidden layer; any other value falls back to ReLU.
    """

    def __init__(self, M_sensors, num_basis_functions,
                 hidden_dim, num_layers, activation_fn):
        super(RecurrentController, self).__init__()

        # Current and target readings are concatenated feature-wise.
        lstm_in_features = 2 * M_sensors

        requested = activation_fn.lower()
        activation = nn.Tanh() if requested == 'tanh' else nn.ReLU()

        self.lstm = nn.LSTM(input_size=lstm_in_features, hidden_size=hidden_dim,
                            num_layers=num_layers, batch_first=True)

        half_width = hidden_dim // 2
        self.output_head = nn.Sequential(
            nn.Linear(hidden_dim, half_width),
            activation,
            nn.Linear(half_width, num_basis_functions),
            nn.Tanh(),
        )

    def forward(self, T_k, T_final, hidden_state=None):
        """Run one recurrent step.

        Args:
            T_k: ``(B, M_sensors)`` current readings.
            T_final: ``(B, M_sensors)`` target readings.
            hidden_state: Optional ``(h, c)`` tuple from a previous call;
                ``None`` lets the LSTM start from zeros.

        Returns:
            ``(w_k, new_hidden_state)`` where ``w_k`` has shape
            ``(B, num_basis_functions)``.
        """
        # A length-1 sequence per batch element (batch_first layout).
        step_input = torch.cat((T_k, T_final), dim=1)[:, None, :]
        step_output, next_state = self.lstm(step_input, hidden_state)
        coefficients = self.output_head(step_output[:, 0, :])
        return coefficients, next_state

def solve_pde_time_varying(config, u_control_sequence):
    """Integrate the controlled 1-D reaction-diffusion PDE with Crank-Nicolson.

    The matrices assembled below discretise

        V_t = D * V_xx - BETA * (V - V_REF_VAL) + ALPHA * u(x, t)

    on ``NX_SOLVER`` uniformly spaced points over ``[0, L]`` and
    ``NT_SOLVER`` uniformly spaced time levels over ``[0, T_FINAL]``.
    Zero-flux boundaries are imposed through mirrored ghost nodes, which is
    why the first and last rows carry a doubled off-diagonal coupling. The
    control is averaged over each step (trapezoidal rule).

    Args:
        config: Mapping with keys ``NX_SOLVER``, ``NT_SOLVER``, ``L``,
            ``T_FINAL``, ``D``, ``BETA``, ``ALPHA``, ``V_REF_VAL`` and
            ``INITIAL_STATE_VAL`` (the spatially uniform initial value).
        u_control_sequence: Control at each time level, indexed along the
            first axis; each entry is a scalar or an array of length
            ``NX_SOLVER``. Only the first ``NT_SOLVER`` entries are used.

    Returns:
        ``float64`` array of shape ``(NT_SOLVER, NX_SOLVER)`` with the state
        at every time level, starting with the initial condition.

    Raises:
        ValueError: If fewer than ``NT_SOLVER`` control entries are given.
    """
    from scipy.sparse.linalg import factorized

    nx, nt = config['NX_SOLVER'], config['NT_SOLVER']
    u = np.asarray(u_control_sequence, dtype=float)
    if u.ndim == 0 or u.shape[0] < nt:
        available = 0 if u.ndim == 0 else u.shape[0]
        raise ValueError(
            f"u_control_sequence has {available} time levels; "
            f"NT_SOLVER={nt} are required"
        )

    dx = config['L'] / (nx - 1)
    dt = config['T_FINAL'] / (nt - 1)
    lam = config['D'] * dt / (2 * dx**2)
    half_beta_dt = 0.5 * config['BETA'] * dt

    # Implicit (A) and explicit (B) tridiagonal Crank-Nicolson operators.
    A = spdiags([np.full(nx, -lam), np.full(nx, 1 + 2 * lam + half_beta_dt), np.full(nx, -lam)],
                [-1, 0, 1], nx, nx, format='csc')
    B = spdiags([np.full(nx, lam), np.full(nx, 1 - 2 * lam - half_beta_dt), np.full(nx, lam)],
                [-1, 0, 1], nx, nx, format='csc')
    # Neumann ghost nodes (V[-1] = V[1], V[nx] = V[nx-2]) double the
    # coupling to the only interior neighbour in the boundary rows.
    A[0, 1], A[-1, -2] = -2 * lam, -2 * lam
    B[0, 1], B[-1, -2] = 2 * lam, 2 * lam

    # A does not change over time: factorize it once and reuse the factors
    # instead of re-factorizing inside spsolve at every step.
    solve_implicit = factorized(A)

    u_avg = 0.5 * (u[:nt - 1] + u[1:nt])  # control averaged over each step
    reference_source = config['BETA'] * config['V_REF_VAL']

    history = np.empty((nt, nx))
    V = np.full(nx, config['INITIAL_STATE_VAL'], dtype=float)
    history[0] = V
    for k in range(nt - 1):
        source = config['ALPHA'] * u_avg[k] + reference_source
        V = solve_implicit(B @ V + source * dt)
        history[k + 1] = V
    return history

def generate_grf_time_series(config, num_steps, num_series, length_scale, rng=None):
    """Sample time series from a zero-mean Gaussian process with an RBF kernel.

    Args:
        config: Mapping providing ``T_FINAL``; samples are taken at
            ``num_steps`` evenly spaced times in ``[0, T_FINAL]``.
        num_steps: Number of time points per series.
        num_series: Number of independent series to draw.
        length_scale: RBF correlation length, in the same units as
            ``T_FINAL``. Must be positive.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) used
            for sampling. ``None`` (the default) draws from NumPy's global
            random state as before, so existing ``np.random.seed`` calls
            keep reproducing the same data.

    Returns:
        Array of shape ``(num_steps, num_series)``; each column is one series.

    Raises:
        ValueError: If ``length_scale`` is not a positive number.
    """
    if not length_scale > 0:  # also rejects NaN
        raise ValueError(f"length_scale must be positive, got {length_scale!r}")
    t = np.linspace(0, config['T_FINAL'], num_steps)
    dist_matrix = np.abs(t[:, None] - t[None, :])
    # Small diagonal jitter; presumably added to keep the nearly singular
    # RBF covariance well-conditioned for sampling.
    cov_matrix = np.exp(-0.5 * (dist_matrix**2) / (length_scale**2)) + 1e-6 * np.eye(num_steps)
    sampler = np.random if rng is None else rng
    samples = sampler.multivariate_normal(np.zeros(num_steps), cov_matrix, size=num_series)
    return samples.T

def create_recurrent_dataset(config, num_simulations, filename):
    """Simulate controlled PDE trajectories and save them to a compressed ``.npz``.

    For each simulation the function:
      1. draws a correlation length uniformly from ``[0.8, 2.5)``;
      2. samples ``NUM_BASIS_FUNCTIONS`` coefficient time series with
         ``generate_grf_time_series``, scales them by 0.7 and clips them to
         ``[-1, 1]``;
      3. expands the coefficients onto cosine modes ``cos(k*pi*x/L)`` to
         build the space-time control and integrates it with
         ``solve_pde_time_varying``;
      4. cubic-interpolates the state to ``M_SENSORS`` evenly spaced sensors.

    Args:
        config: Mapping with ``L``, ``NX_SOLVER``, ``NT_SOLVER``,
            ``M_SENSORS``, ``NUM_BASIS_FUNCTIONS`` plus the keys required by
            ``solve_pde_time_varying`` and ``generate_grf_time_series``.
        num_simulations: Number of trajectories to generate.
        filename: Output path. NumPy appends ``.npz`` when it is missing.

    The saved archive holds ``control_sequences`` (stacked coefficient
    series) and ``state_sequences`` (stacked sensor-state trajectories).
    """
    print(f"--- Generating Recurrent Dataset: {filename} ---")
    # os.path.dirname() is '' for a bare file name, and os.makedirs('')
    # raises FileNotFoundError, so only create a directory if one is named.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    domain_length = config['L']
    num_basis = config['NUM_BASIS_FUNCTIONS']
    x_grid_solver = np.linspace(0, domain_length, config['NX_SOLVER'])
    sensor_locs_np = np.linspace(0, domain_length, config['M_SENSORS'])
    # Column k is cos(k*pi*x/L) sampled on the solver grid.
    basis_functions = np.cos(np.arange(num_basis) * np.pi * x_grid_solver[:, None] / domain_length)

    control_sequences, state_sequences_at_sensors = [], []
    for i in range(num_simulations):
        if (i + 1) % 50 == 0:
            print(f"  Generating simulation {i+1}/{num_simulations}...")
        length_scale = np.random.uniform(0.8, 2.5)
        w_sequence = generate_grf_time_series(config, config['NT_SOLVER'], num_basis, length_scale)
        w_sequence = np.clip(w_sequence * 0.7, -1.0, 1.0)
        u_xt_sequence = w_sequence @ basis_functions.T  # coefficients -> field on the grid
        V_xt_solution = solve_pde_time_varying(config, u_xt_sequence)
        interpolator = interp1d(x_grid_solver, V_xt_solution, axis=1, kind='cubic', fill_value="extrapolate")
        control_sequences.append(w_sequence)
        state_sequences_at_sensors.append(interpolator(sensor_locs_np))

    np.savez_compressed(filename,
        control_sequences=np.array(control_sequences),
        state_sequences=np.array(state_sequences_at_sensors))
    # Report the path NumPy actually wrote (it appends '.npz' if missing).
    saved_path = os.fspath(filename)
    if not saved_path.endswith('.npz'):
        saved_path += '.npz'
    print(f"Dataset saved to {saved_path}")