# src/data_and_models_2d.py
# FINAL version for the 2D Propagator project.
# Adapted from the 1D version to handle 2D heat equation.

import torch
import torch.nn as nn
import numpy as np
from scipy.interpolate import RegularGridInterpolator
from scipy.sparse import spdiags, eye, kronsum
from scipy.sparse.linalg import spsolve
import os

# --- 1. Propagator Model Definition (Unchanged from 1D) ---
# This model is general enough to work for 2D with adjusted input dimensions,
# which are determined by the config file.
class PropagatorDeepONet(nn.Module):
    """DeepONet-style propagator built from a branch MLP and a trunk MLP.

    The branch network consumes the concatenation of the sensor readings and
    the control weights; the trunk network consumes query coordinates. Both
    map to a ``latent_dim``-sized feature vector, and the prediction at each
    query location is the dot product of the two feature vectors plus a
    learned scalar bias.

    Args:
        M_sensors: Number of sensor readings fed to the branch network.
        num_basis_functions: Number of control weights fed to the branch
            network.
        trunk_input_dim: Size of each coordinate vector fed to the trunk
            network.
        branch_depth: The branch network has one input layer followed by
            ``branch_depth - 1`` hidden layers and a linear output layer.
        branch_width: Width of every branch hidden layer.
        trunk_depth: Same as ``branch_depth`` but for the trunk network.
        trunk_width: Width of every trunk hidden layer.
        latent_dim: Output size shared by the branch and trunk networks.
        activation_fn: ``'relu'`` or ``'tanh'`` (case-insensitive); any other
            string falls back to ReLU.
    """

    def __init__(self, M_sensors, num_basis_functions, trunk_input_dim,
                 branch_depth, branch_width, trunk_depth, trunk_width, latent_dim, activation_fn):
        super().__init__()

        # One activation instance is shared by every layer of both networks;
        # anything other than 'tanh' (including unknown names) means ReLU.
        nonlinearity = nn.Tanh() if activation_fn.lower() == 'tanh' else nn.ReLU()

        # Branch input = all sensor values followed by all basis weights.
        self.branch = self._build_mlp(
            M_sensors + num_basis_functions, branch_width, branch_depth,
            latent_dim, nonlinearity)

        # Trunk input = one coordinate vector per query location.
        self.trunk = self._build_mlp(
            trunk_input_dim, trunk_width, trunk_depth, latent_dim, nonlinearity)

        self.bias = nn.Parameter(torch.zeros(1))

    @staticmethod
    def _build_mlp(in_features, width, depth, out_features, nonlinearity):
        """Return Linear/activation pairs followed by a final Linear layer."""
        stack = [nn.Linear(in_features, width), nonlinearity]
        for _ in range(depth - 1):
            stack += [nn.Linear(width, width), nonlinearity]
        stack.append(nn.Linear(width, out_features))
        return nn.Sequential(*stack)

    def forward(self, T_k_sensors, w_k, xy_locations):
        """Predict the field at ``xy_locations`` from the state and control.

        ``T_k_sensors`` and ``w_k`` are concatenated along dim 1, so both
        must be 2-D with matching batch size. ``xy_locations`` is passed to
        the trunk as-is and must be 3-D (batch, location, coordinate) for
        the einsum below. The result has a trailing singleton dimension.
        """
        branch_features = self.branch(torch.cat((T_k_sensors, w_k), dim=1))
        trunk_features = self.trunk(xy_locations)

        # Per-sample, per-location dot product over the latent dimension.
        combined = torch.einsum('bi,bsi->bs', branch_features, trunk_features)
        return combined.unsqueeze(-1) + self.bias

# --- 2. Recurrent Controller Definition (Unchanged from 1D) ---
# This model is also general. It takes flattened 2D state/target grids
# and outputs weights for the 2D basis functions.
class RecurrentController(nn.Module):
    """LSTM controller that emits one set of basis weights per call.

    Each call feeds a single time step (current state concatenated with the
    final target) through the LSTM and maps the LSTM output through a small
    MLP head ending in ``Tanh``, so every weight lies in ``[-1, 1]``.

    Args:
        M_sensors: Length of the state vector and of the target vector.
        num_basis_functions: Number of weights produced per step.
        hidden_dim: LSTM hidden size; the head's hidden layer uses
            ``hidden_dim // 2`` units.
        num_layers: Number of stacked LSTM layers.
        activation_fn: ``'relu'`` or ``'tanh'`` (case-insensitive) for the
            head's hidden layer; any other string falls back to ReLU.
    """

    def __init__(self, M_sensors, num_basis_functions,
                 hidden_dim, num_layers, activation_fn):
        super().__init__()

        # Anything other than 'tanh' (including unknown names) means ReLU.
        head_activation = nn.Tanh() if activation_fn.lower() == 'tanh' else nn.ReLU()

        # The LSTM sees the current state and the final target side by side.
        self.lstm = nn.LSTM(
            input_size=M_sensors + M_sensors,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        half_width = hidden_dim // 2
        self.output_head = nn.Sequential(
            nn.Linear(hidden_dim, half_width),
            head_activation,
            nn.Linear(half_width, num_basis_functions),
            nn.Tanh(),
        )

    def forward(self, T_k, T_final, hidden_state=None):
        """Advance the controller by one step.

        ``T_k`` and ``T_final`` are concatenated along dim 1 and given a
        length-1 sequence axis before entering the LSTM. Returns the weights
        for this step and the updated LSTM state, which the caller passes
        back in as ``hidden_state`` on the next step.
        """
        joined = torch.cat((T_k, T_final), dim=1)
        single_step = joined.unsqueeze(1)
        lstm_out, new_hidden_state = self.lstm(single_step, hidden_state)
        # Drop the length-1 sequence axis again before the output head.
        step_features = lstm_out.squeeze(1)
        return self.output_head(step_features), new_hidden_state

# --- 3. Ground Truth 2D PDE Solver (CORRECTED) ---
def solve_pde_2d(config, u_control_sequence):
    """Integrate the controlled 2D heat equation with Crank-Nicolson.

    The scheme implemented below discretises

        dV/dt = D * laplacian(V) - BETA * (V - V_REF_VAL) + ALPHA * u

    on a regular ``NY_SOLVER x NX_SOLVER`` grid with zero-flux (Neumann)
    boundaries, starting from a uniform field ``INITIAL_STATE_VAL``.

    Args:
        config: Mapping providing ``NX_SOLVER``, ``NY_SOLVER``, ``NT_SOLVER``,
            ``L_X``, ``L_Y``, ``T_FINAL``, ``D``, ``BETA``, ``ALPHA``,
            ``INITIAL_STATE_VAL`` and ``V_REF_VAL``.
        u_control_sequence: Array-like indexed as ``[time, y, x]`` with at
            least ``NT_SOLVER`` time slices of shape ``(NY_SOLVER, NX_SOLVER)``.

    Returns:
        ``numpy.ndarray`` of shape ``(NT_SOLVER, NY_SOLVER, NX_SOLVER)``
        holding the field at every time step (slice 0 is the initial state).

    Raises:
        ValueError: If any of ``NX_SOLVER``, ``NY_SOLVER``, ``NT_SOLVER`` is
            smaller than 2 (the grid spacings would be undefined).
    """
    # Local import so this function does not depend on edits to the
    # module-level import block.
    from scipy.sparse.linalg import splu

    NX, NY, NT = config['NX_SOLVER'], config['NY_SOLVER'], config['NT_SOLVER']
    if min(NX, NY, NT) < 2:
        raise ValueError(
            "NX_SOLVER, NY_SOLVER and NT_SOLVER must all be >= 2, "
            f"got NX={NX}, NY={NY}, NT={NT}")

    dx, dy = config['L_X'] / (NX - 1), config['L_Y'] / (NY - 1)
    dt = config['T_FINAL'] / (NT - 1)

    u_control_sequence = np.asarray(u_control_sequence)

    def laplacian_1d(N, d_space):
        """Second-difference matrix with mirrored (zero-flux) end points."""
        main_diag = -2 * np.ones(N)
        off_diag = np.ones(N)
        D2 = spdiags([off_diag, main_diag, off_diag], [-1, 0, 1], N, N, format='csc')
        # Ghost-node Neumann condition: the neighbour outside the domain
        # mirrors the one inside, doubling the single off-diagonal weight.
        D2[0, 1] = 2
        D2[-1, -2] = 2
        return D2 / d_space**2

    D2x = laplacian_1d(NX, dx)
    D2y = laplacian_1d(NY, dy)

    # Fields are flattened from (NY, NX) arrays in C order, so the flat index
    # is iy * NX + ix (x varies fastest). The matching operator is
    #     kron(I_NY, D2x) + kron(D2y, I_NX) == kronsum(D2x, D2y).
    # With the arguments swapped, the x-derivative would be applied along y
    # (and vice versa) whenever NX != NY or dx != dy.
    L_2D = kronsum(D2x, D2y)

    # Crank-Nicolson matrices: A @ V_next = B @ V_current + dt * source.
    identity = eye(NX * NY)
    A = identity - 0.5 * config['D'] * dt * L_2D + 0.5 * config['BETA'] * dt * identity
    B = identity + 0.5 * config['D'] * dt * L_2D - 0.5 * config['BETA'] * dt * identity

    # A does not change between steps, so factorise it once and reuse the
    # LU factors instead of re-solving from scratch on every iteration.
    solve_A = splu(A.tocsc()).solve
    B = B.tocsr()

    relaxation_source = config['BETA'] * config['V_REF_VAL']

    V_current = np.full(NX * NY, config['INITIAL_STATE_VAL'], dtype=float)
    V_history = [V_current.reshape((NY, NX)).copy()]

    for k in range(NT - 1):
        # Control evaluated at the midpoint of the step (trapezoidal average).
        u_avg = (u_control_sequence[k] + u_control_sequence[k + 1]) / 2.0
        source_term = config['ALPHA'] * u_avg.flatten() + relaxation_source

        b_vec = B @ V_current + source_term * dt
        V_current = solve_A(b_vec)
        V_history.append(V_current.reshape((NY, NX)).copy())

    return np.array(V_history)

# --- 4. Data Generation Logic for 2D (Unchanged) ---
def generate_grf_time_series(config, num_steps, num_series, length_scale):
    """Draw zero-mean Gaussian-process time series with an RBF kernel.

    Samples are taken on ``num_steps`` equally spaced times spanning
    ``[0, config['T_FINAL']]`` using NumPy's global random state.

    Args:
        config: Mapping providing ``T_FINAL``.
        num_steps: Number of time samples per series.
        num_series: Number of independent series to draw.
        length_scale: Correlation length of the squared-exponential kernel.

    Returns:
        Array of shape ``(num_steps, num_series)``; each column is one series.
    """
    time_grid = np.linspace(0, config['T_FINAL'], num_steps)
    pairwise_lag = np.abs(time_grid[:, None] - time_grid[None, :])

    # A small diagonal jitter keeps the covariance numerically positive definite.
    jitter = 1e-6 * np.eye(num_steps)
    kernel = np.exp(-0.5 * (pairwise_lag**2) / (length_scale**2)) + jitter

    zero_mean = np.zeros(num_steps)
    draws = np.random.multivariate_normal(zero_mean, kernel, size=num_series)
    # Rows are series after sampling; transpose so that time is the first axis.
    return draws.T

def create_recurrent_dataset_2d(config, num_simulations, filename):
    """Simulate random control trajectories and save them with sensor readings.

    For every simulation a random, time-correlated weight sequence is drawn
    for each cosine basis function, turned into a control field on the
    solver grid, propagated with ``solve_pde_2d`` and linearly interpolated
    onto the sensor grid.

    Args:
        config: Mapping providing ``L_X``, ``L_Y``, ``NX_SOLVER``,
            ``NY_SOLVER``, ``NT_SOLVER``, ``NX_SENSORS``, ``NY_SENSORS``,
            ``NUM_BASIS_X``, ``NUM_BASIS_Y``, plus whatever
            ``generate_grf_time_series`` and ``solve_pde_2d`` read from it.
        num_simulations: Number of independent trajectories to generate.
        filename: Output path handed to ``numpy.savez_compressed``. Its
            parent directory is created when it has one.

    The saved archive holds two arrays:
        ``control_sequences``: ``(num_simulations, NT_SOLVER, NUM_BASIS_X * NUM_BASIS_Y)``
        ``state_sequences``: ``(num_simulations, NT_SOLVER, NY_SENSORS * NX_SENSORS)``
    """
    print(f"--- Generating 2D Recurrent Dataset: {filename} ---")

    # os.makedirs('') raises FileNotFoundError, so skip directory creation
    # when the file is written to the current working directory.
    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    NX, NY, NT = config['NX_SOLVER'], config['NY_SOLVER'], config['NT_SOLVER']

    # Create solver and sensor grids
    x_solver = np.linspace(0, config['L_X'], NX)
    y_solver = np.linspace(0, config['L_Y'], NY)
    x_sensors = np.linspace(0, config['L_X'], config['NX_SENSORS'])
    y_sensors = np.linspace(0, config['L_Y'], config['NY_SENSORS'])

    # Create 2D cosine basis functions on the solver grid; each has shape (NY, NX).
    NB_X, NB_Y = config['NUM_BASIS_X'], config['NUM_BASIS_Y']
    xx, yy = np.meshgrid(x_solver, y_solver)
    basis_functions = []
    for i in range(NB_X):
        for j in range(NB_Y):
            basis = np.cos(i * np.pi * xx / config['L_X']) * np.cos(j * np.pi * yy / config['L_Y'])
            basis_functions.append(basis)
    basis_functions = np.array(basis_functions).reshape(NB_X * NB_Y, -1).T  # Shape: (NY*NX, NB_X*NB_Y)

    # Sensor coordinates are the same for every simulation, so build them
    # once. Points are (y, x) pairs to match the interpolator's axis order.
    sensor_points_yy, sensor_points_xx = np.meshgrid(y_sensors, x_sensors, indexing='ij')
    sensor_points = np.column_stack([sensor_points_yy.ravel(), sensor_points_xx.ravel()])

    control_sequences, state_sequences_at_sensors = [], []
    for sim_idx in range(num_simulations):
        if (sim_idx + 1) % 50 == 0: print(f"  Generating simulation {sim_idx+1}/{num_simulations}...")

        length_scale = np.random.uniform(0.8, 2.5)
        w_sequence = generate_grf_time_series(config, NT, NB_X * NB_Y, length_scale)
        w_sequence = np.clip(w_sequence * 0.7, -1.0, 1.0)

        # Create the time-varying control field u(t, y, x)
        u_txy_sequence = (w_sequence @ basis_functions.T).reshape(-1, NY, NX)

        V_txy_solution = solve_pde_2d(config, u_txy_sequence)

        # Interpolate every time slice in one call: RegularGridInterpolator
        # treats axes beyond the grid dimensions as trailing value axes, so
        # time is moved last instead of overwriting `.values` per step.
        interpolator = RegularGridInterpolator(
            (y_solver, x_solver), np.moveaxis(V_txy_solution, 0, -1))
        # Result is (num_sensor_points, NT); transpose so time comes first and
        # the flattened sensor grid (y-major) is the second axis.
        V_at_sensors_flat = np.ascontiguousarray(interpolator(sensor_points).T)

        control_sequences.append(w_sequence)
        state_sequences_at_sensors.append(V_at_sensors_flat)

    np.savez_compressed(filename,
        control_sequences=np.array(control_sequences),
        state_sequences=np.array(state_sequences_at_sensors))
    print(f"2D Dataset saved to {filename}")