# src/data_and_models.py

import os

import numpy as np
import torch
import torch.nn as nn
from scipy.sparse import spdiags
from scipy.sparse.linalg import factorized, spsolve

# --- 1. Model Definition ---

class DeepONetWithBias(nn.Module):
    """
    DeepONet operator network with a learnable scalar output bias.

    Two fully connected sub-networks are built from the constructor arguments:

    * ``branch`` maps the sampled input function (``branch_input_dim``
      features) to a ``latent_dim`` feature vector.
    * ``trunk`` maps a query coordinate (``trunk_input_dim`` features) to a
      ``latent_dim`` feature vector.

    Each sub-network is ``Linear -> act`` followed by ``depth - 1`` further
    ``Linear -> act`` pairs (so at least one hidden layer even if
    ``depth <= 1``) and a final ``Linear`` projection to ``latent_dim``.
    The prediction is the inner product of the two feature vectors plus
    ``self.bias``.

    Args:
        branch_input_dim: Number of input features of the branch network.
        trunk_input_dim: Number of input features of the trunk network.
        latent_dim: Size of the shared latent space.
        branch_depth: Number of hidden layers in the branch network.
        branch_width: Hidden width of the branch network.
        trunk_depth: Number of hidden layers in the trunk network.
        trunk_width: Hidden width of the trunk network.
        activation_fn: 'relu', 'tanh' or 'silu' (case-insensitive).

    Raises:
        ValueError: If ``activation_fn`` names an unsupported activation.
    """
    def __init__(self, branch_input_dim, trunk_input_dim, latent_dim,
                 branch_depth, branch_width, trunk_depth, trunk_width, activation_fn):
        super().__init__()

        factories = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'silu': nn.SiLU}
        key = activation_fn.lower()
        if key not in factories:
            raise ValueError(f"Unsupported activation function: {activation_fn}")
        # A single activation module is reused by every hidden layer of both
        # networks; all supported choices are parameter-free.
        act = factories[key]()

        def mlp(in_features, width, depth):
            # Input layer, (depth - 1) hidden layers, then projection to latent space.
            layers = [nn.Linear(in_features, width), act]
            for _ in range(depth - 1):
                layers += [nn.Linear(width, width), act]
            layers.append(nn.Linear(width, latent_dim))
            return nn.Sequential(*layers)

        self.branch = mlp(branch_input_dim, branch_width, branch_depth)
        self.trunk = mlp(trunk_input_dim, trunk_width, trunk_depth)

        # Learnable scalar added to every prediction.
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, u_branch, y_trunk):
        """
        Evaluate the operator.

        The element-wise product of ``branch(u_branch)`` and ``trunk(y_trunk)``
        is summed over dimension 1 (the feature axis for 2-D
        ``(batch, features)`` inputs) with ``keepdim=True``, then the scalar
        bias is added.
        """
        features_b = self.branch(u_branch)
        features_t = self.trunk(y_trunk)
        return (features_b * features_t).sum(dim=1, keepdim=True) + self.bias


# --- 2. Ground Truth PDE Solver ---

def _neumann_tridiag(n, diag, off):
    """
    Build an ``n x n`` CSC tridiagonal matrix for a Neumann-bounded stencil.

    The main diagonal is ``diag`` and both off-diagonals are ``off``, except
    that the first row's super-diagonal and the last row's sub-diagonal are
    ``2 * off``. This is the ghost-point reflection used for zero-flux
    boundaries. Requires ``n >= 2``.
    """
    data = np.array([np.full(n, off), np.full(n, diag), np.full(n, off)])
    # Edit the boundary rows in LIL format. Editing a CSC matrix can change its
    # sparsity structure (e.g. when ``off == 0`` the entry is not stored),
    # which is slow and emits SparseEfficiencyWarning.
    mat = spdiags(data, [-1, 0, 1], n, n, format='lil')
    mat[0, 1] = 2 * off
    mat[n - 1, n - 2] = 2 * off
    return mat.tocsc()


def solve_pde(config, u_control_profile, V_ref_profile):
    """
    Solve the 1D reaction-diffusion PDE with the Crank-Nicolson method.

    The discretisation integrates

        dV/dt = D * d2V/dx2 - BETA * V + ALPHA * u(x) + BETA * V_ref(x)

    on a uniform grid over ``[0, L]`` with zero-flux (Neumann) boundaries,
    starting from the constant state ``INITIAL_STATE_VAL``.

    Args:
        config: Dict providing 'NX_SOLVER', 'NT_SOLVER' (grid sizes, each
            >= 2), 'L', 'T_FINAL', 'INITIAL_STATE_VAL', 'ALPHA', 'BETA', 'D'.
        u_control_profile: Control u(x) on the ``NX_SOLVER`` grid points
            (or a scalar).
        V_ref_profile: Reference state V_ref(x) on the grid (or a scalar).

    Returns:
        Array of shape ``(NT_SOLVER, NX_SOLVER)``. Row ``k`` is the state at
        time ``k * T_FINAL / (NT_SOLVER - 1)``; row 0 is the initial state.

    Raises:
        ValueError: If ``NX_SOLVER`` or ``NT_SOLVER`` is smaller than 2.
    """
    NX_SOLVER = config['NX_SOLVER']
    NT_SOLVER = config['NT_SOLVER']
    if NX_SOLVER < 2 or NT_SOLVER < 2:
        raise ValueError(
            f"NX_SOLVER and NT_SOLVER must both be >= 2, got {NX_SOLVER} and {NT_SOLVER}"
        )

    dx = config['L'] / (NX_SOLVER - 1)
    dt = config['T_FINAL'] / (NT_SOLVER - 1)
    beta = config['BETA']

    source_term = (config['ALPHA'] * np.asarray(u_control_profile)
                   + beta * np.asarray(V_ref_profile))
    lambda_ = config['D'] * dt / (2 * dx**2)

    # Implicit half-step (left-hand side) and explicit half-step (right-hand
    # side). The -BETA*V reaction term is split evenly between the two sides.
    A = _neumann_tridiag(NX_SOLVER, 1 + 2 * lambda_ + 0.5 * beta * dt, -lambda_)
    B = _neumann_tridiag(NX_SOLVER, 1 - 2 * lambda_ - 0.5 * beta * dt, lambda_)

    # A does not change between steps: factorize it once, not once per step.
    solve_A = factorized(A)
    source_dt = source_term * dt

    V_history = np.empty((NT_SOLVER, NX_SOLVER))
    V_history[0] = config['INITIAL_STATE_VAL']
    for k in range(1, NT_SOLVER):
        V_history[k] = solve_A(B @ V_history[k - 1] + source_dt)

    return V_history


# --- 3. Data Generation Logic ---

def generate_grf_samples(config, num_samples, length_scale, rng=None):
    """
    Draw samples of a zero-mean Gaussian random field on the solver grid.

    The field is evaluated at ``config['NX_SOLVER']`` equally spaced points on
    ``[0, config['L']]``. It uses the squared-exponential covariance
    ``k(x, x') = exp(-0.5 * (x - x')**2 / length_scale**2)``, plus a 1e-9
    diagonal jitter.

    Args:
        config: Dict providing 'L' and 'NX_SOLVER'.
        num_samples: Number of independent samples to draw.
        length_scale: Correlation length; must be > 0. Larger values give
            smoother samples.
        rng: Optional NumPy random source exposing ``multivariate_normal``
            (e.g. ``np.random.Generator`` or ``RandomState``). When None, the
            legacy global ``np.random`` state is used, so results follow
            ``np.random.seed`` as before.

    Returns:
        Array of shape ``(num_samples, config['NX_SOLVER'])``.

    Raises:
        ValueError: If ``length_scale`` is not strictly positive.
    """
    # `not x > 0` also rejects NaN, which would otherwise yield a NaN covariance.
    if not length_scale > 0:
        raise ValueError(f"length_scale must be positive, got {length_scale}")

    n_points = config['NX_SOLVER']
    x_grid = np.linspace(0, config['L'], n_points)
    diff = x_grid[:, None] - x_grid[None, :]
    cov_matrix = np.exp(-0.5 * diff**2 / length_scale**2)
    # Small jitter keeps the nearly singular SE covariance numerically PSD.
    cov_matrix += 1e-9 * np.eye(n_points)

    sampler = np.random if rng is None else rng
    return sampler.multivariate_normal(np.zeros(n_points), cov_matrix, size=num_samples)

def create_deeponet_dataset(config, num_functions, P_samples, filename):
    """
    Generate a DeepONet training set and save it as a compressed ``.npz``.

    Steps:
      1. Draw ``num_functions`` random control profiles with
         ``generate_grf_samples`` and clip each to ``[U_MIN, U_MAX]``.
      2. Solve the PDE for each profile with ``solve_pde``.
      3. Record ``P_samples`` uniformly random grid points ``(x, t)`` of each
         solution.

    Saved arrays, with N = num_functions * P_samples:
        branch_inputs: (N, M_SENSORS) control profile interpolated at the
            sensor locations; identical for all samples of one function.
        trunk_inputs:  (N, 2) query coordinates ``[x, t]``.
        outputs:       (N, 1) solution value V at ``(x, t)``.

    Args:
        config: Dict with the PDE/grid keys used by ``solve_pde`` plus
            'GRF_LENGTH_SCALE', 'M_SENSORS', 'V_REF_VAL', 'U_MIN', 'U_MAX'.
        num_functions: Number of control functions to simulate.
        P_samples: Number of solution points sampled per function.
        filename: Output path. Missing parent directories are created.

    Random draws use the global ``np.random`` state. For each sample the time
    index is drawn before the space index.
    """
    print(f"--- Generating DeepONet dataset to be saved at {filename} ---")

    # Generate a batch of random control functions u(x)
    u_functions = generate_grf_samples(config, num_functions, config['GRF_LENGTH_SCALE'])

    nx, nt, m_sensors = config['NX_SOLVER'], config['NT_SOLVER'], config['M_SENSORS']

    # Define the grids and sensor locations
    x_grid_solver = np.linspace(0, config['L'], nx)
    t_grid_solver = np.linspace(0, config['T_FINAL'], nt)
    sensor_locations = np.linspace(0, config['L'], m_sensors)
    V_ref_profile = np.full(nx, config['V_REF_VAL'])

    # Preallocate so the saved arrays have correct 2-D shapes even when empty.
    n_total = num_functions * P_samples
    branch_inputs = np.empty((n_total, m_sensors))
    trunk_inputs = np.empty((n_total, 2))
    outputs = np.empty((n_total, 1))

    row = 0
    for i in range(num_functions):
        if (i + 1) % 100 == 0:
            print(f"Processing function {i+1}/{num_functions}...")

        # Clip the control function to specified bounds
        u_profile = np.clip(u_functions[i, :], config['U_MIN'], config['U_MAX'])

        # Get the full PDE solution for this control function
        V_solution = solve_pde(config, u_profile, V_ref_profile)

        # Branch input: the control function sampled at the M_SENSORS locations.
        u_at_sensors = np.interp(sensor_locations, x_grid_solver, u_profile)

        for _ in range(P_samples):
            # Draw t before x so the random stream matches the original
            # sampling order (reproducible datasets under a fixed seed).
            t_idx = np.random.randint(0, nt)
            x_idx = np.random.randint(0, nx)

            branch_inputs[row] = u_at_sensors
            trunk_inputs[row] = (x_grid_solver[x_idx], t_grid_solver[t_idx])
            outputs[row, 0] = V_solution[t_idx, x_idx]
            row += 1

    # os.path.dirname is "" for a bare filename, and os.makedirs("") raises,
    # so only create a directory when the path actually has one.
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)

    np.savez_compressed(
        filename,
        branch_inputs=branch_inputs,
        trunk_inputs=trunk_inputs,
        outputs=outputs
    )
    print(f"Dataset successfully saved to {filename}")