
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import time
import scipy.sparse as sp
from scipy.sparse.linalg import spsolve
import argparse
import pickle  # Ensure pickle is imported

# ---------------------
# Global reproducibility setup: every random number generator used in this
# file is seeded with the same fixed value at import time.
seed = 42
random.seed(seed)  # Python's built-in RNG
np.random.seed(seed)  # NumPy's legacy global RNG
torch.manual_seed(seed)  # PyTorch default generator
if torch.cuda.is_available():
    # Only touch the CUDA generators when a GPU is actually present.
    torch.cuda.manual_seed_all(seed)
# Request deterministic cuDNN kernels.
torch.backends.cudnn.deterministic = True



class FeedbackControllerBase:
    """Common interface for boundary feedback controllers.

    Subclasses must override ``reset`` and ``compute_bc``. The constructor
    stores its arguments and immediately calls ``reset``, so instantiating
    this base class directly raises ``NotImplementedError``.
    """

    def __init__(self, dataset_params, device='cpu'):
        """Store the configuration and initialise the controller state.

        Args:
            dataset_params: Parameter container kept as ``self.params``.
            device: Device identifier kept as ``self.device``.
        """
        self.params, self.device = dataset_params, device
        # Subclasses set up their per-simulation state in reset().
        self.reset()

    def reset(self):
        """Reset the controller's internal state before a new simulation."""
        raise NotImplementedError

    def compute_bc(self, U_hat_k, external_bcs_k, t_k):
        """Compute the boundary conditions for step k+1 from the state at step k.

        Args:
            U_hat_k (torch.Tensor): Predicted solution at step k, shape [1, nx].
            external_bcs_k (torch.Tensor): External parts of the BCs at step k
                (reference signals or external controls).
            t_k (float): Current time.

        Returns:
            torch.Tensor: The full BC_State+BC_Control vector for step k+1.
        """
        raise NotImplementedError


class ConvDiffIntegralController(FeedbackControllerBase):
    """Integral feedback controller acting on both boundaries.

    The controller accumulates, per boundary, the time integral of the
    difference between a reference signal and the de-normalised prediction
    at that boundary, and adds the integral (gain 1.0) to the reference.
    """

    def reset(self):
        """Zero both accumulated integral errors."""
        self.integral_left = 0.0
        self.integral_right = 0.0

    def compute_bc(self, U_hat_k, external_bcs_k, t_k):
        """Advance the integrators one step and return normalised BCs.

        Args:
            U_hat_k (torch.Tensor): Prediction at step k; entries ``[0, 0]``
                and ``[0, -1]`` are read as the boundary values.
            external_bcs_k (torch.Tensor): Entries ``[0, 0]`` and ``[0, 1]``
                are read as the left/right reference signals.
            t_k: Current time (unused).

        Returns:
            torch.Tensor: float32 tensor of shape [1, 4] on ``self.device``
            holding ``[u_bc_left, u_bc_right, r_left, r_right]`` after
            normalisation with the BC_State / BC_Control factors.
        """
        # NOTE(review): the references are used un-normalised in the error
        # term and then normalised with the BC_Control statistics, which
        # presumes external_bcs_k carries physical values - confirm with caller.
        integral_gain = 1.0
        step = self.params['dt']

        edge_left, edge_right = U_hat_k[0, 0], U_hat_k[0, -1]

        # Bring the predicted boundary values back to physical units.
        factors = self.params['norm_factors']
        u_mean, u_std = factors['U_mean'], factors['U_std']
        left_phys = edge_left * u_std + u_mean
        right_phys = edge_right * u_std + u_mean

        ref_left, ref_right = external_bcs_k[0, 0], external_bcs_k[0, 1]

        # Rectangle-rule update of the accumulated tracking error.
        self.integral_left += (ref_left - left_phys).item() * step
        self.integral_right += (ref_right - right_phys).item() * step

        # Boundary value = reference + integral control action.
        bc_left = ref_left + integral_gain * self.integral_left
        bc_right = ref_right + integral_gain * self.integral_right

        state_mu, state_sigma = factors['BC_State_means'], factors['BC_State_stds']
        ctrl_mu, ctrl_sigma = factors['BC_Control_means'], factors['BC_Control_stds']

        # Normalise the computed boundary values and the pass-through
        # references before handing them back to the model.
        row = [
            (bc_left - state_mu[0]) / state_sigma[0],
            (bc_right - state_mu[1]) / state_sigma[1],
            (ref_left - ctrl_mu[0]) / ctrl_sigma[0],
            (ref_right - ctrl_mu[1]) / ctrl_sigma[1],
        ]
        return torch.tensor([row], dtype=torch.float32, device=self.device)


class ReactionDiffusionIntegralController(FeedbackControllerBase):
    """Stateless controller feeding the spatial integral of the state back
    into the right boundary condition.

    The left boundary value is taken directly from the external signal; the
    right boundary value is the external signal plus ``K_fb`` times the
    trapezoidal integral of the de-normalised prediction.
    """

    def reset(self):
        """Nothing to reset: this controller keeps no internal state."""
        pass

    def compute_bc(self, U_hat_k, external_bcs_k, t_k):
        """Compute the normalised boundary vector for the next step.

        Args:
            U_hat_k (torch.Tensor): Normalised prediction at step k; it is
                squeezed to 1D before the spatial integration.
            external_bcs_k (torch.Tensor): Columns 2 and 3 of row 0 are read
                as the external left/right signals; columns ``2:`` are also
                appended unchanged to the output.
            t_k: Current time (unused).

        Returns:
            torch.Tensor: ``[normalised left BC, normalised right BC,
            external_bcs_k[:, 2:]]`` concatenated on the last axis and moved
            to ``self.device``.
        """
        K_fb = 1.0
        L = 1.0
        nx = self.params['nx']
        dx = L / (nx - 1)

        # Denormalize predictions for physical feedback calculation
        mean_u = self.params['norm_factors']['U_mean']
        std_u = self.params['norm_factors']['U_std']
        U_hat_k_phys = U_hat_k * std_u + mean_u

        # Spatially integrate the predicted state
        integral_u = torch.trapezoid(U_hat_k_phys.squeeze(), dx=dx).item()

        # NOTE(review): these entries are used as physical values here, yet
        # the same columns are passed through below as "normalized" control
        # signals - confirm which convention the caller uses.
        external_left = external_bcs_k[0, 2]  # Index 2 for first control
        external_right = external_bcs_k[0, 3]  # Index 3 for second control

        # The left boundary is a simple Dirichlet
        u_bc_left = external_left

        # The right boundary is the Neumann condition with feedback
        dudx_bc_right = external_right + K_fb * integral_u

        # Normalize values before returning
        bc_state_means = self.params['norm_factors']['BC_State_means']
        bc_state_stds = self.params['norm_factors']['BC_State_stds']
        u_bc_left_norm = (u_bc_left - bc_state_means[0]) / bc_state_stds[0]
        dudx_bc_right_norm = (dudx_bc_right - bc_state_means[1]) / bc_state_stds[1]

        # Build the computed part on the same device as external_bcs_k:
        # torch.cat requires all inputs on one device, so a CPU tensor here
        # would fail whenever external_bcs_k lives on a GPU.
        computed_state = torch.tensor(
            [[u_bc_left_norm, dudx_bc_right_norm]],
            dtype=torch.float32,
            device=external_bcs_k.device,
        )
        computed_bcs = torch.cat([
            computed_state,
            external_bcs_k[:, 2:]  # Pass through the original control signals
        ], dim=-1).to(self.device)

        return computed_bcs


class HeatNonlinearIntegralController(FeedbackControllerBase):
    """Stateless controller applying a quadratic function of the spatial
    integral of the state to the right boundary value.

    With ``s`` the trapezoidal integral of the de-normalised prediction, the
    feedback term is ``K1 * s + K2 * s**2`` (``K1 = 1.0``, ``K2 = 0.5``).
    """

    def reset(self):
        """Nothing to reset: this controller keeps no internal state."""
        pass

    def compute_bc(self, U_hat_k, external_bcs_k, t_k):
        """Compute the normalised boundary vector for the next step.

        Args:
            U_hat_k (torch.Tensor): Normalised prediction at step k; it is
                squeezed to 1D before the spatial integration.
            external_bcs_k (torch.Tensor): Columns 2 and 3 of row 0 are read
                as the external left/right signals; columns ``2:`` are also
                appended unchanged to the output.
            t_k: Current time (unused).

        Returns:
            torch.Tensor: ``[normalised left BC, normalised right BC,
            external_bcs_k[:, 2:]]`` concatenated on the last axis and moved
            to ``self.device``.
        """
        K1, K2 = 1.0, 0.5
        L = 1.0
        nx = self.params['nx']
        dx = L / (nx - 1)

        # Denormalize predictions for physical feedback calculation
        mean_u = self.params['norm_factors']['U_mean']
        std_u = self.params['norm_factors']['U_std']
        U_hat_k_phys = U_hat_k * std_u + mean_u

        # Spatially integrate the predicted state
        s = torch.trapezoid(U_hat_k_phys.squeeze(), dx=dx).item()

        # Apply non-linear feedback function F
        feedback_val = K1 * s + K2 * (s ** 2)

        # NOTE(review): these entries are used as physical values here, yet
        # the same columns are passed through below as "normalized" control
        # signals - confirm which convention the caller uses.
        external_left = external_bcs_k[0, 2]  # Index 2 for first control
        external_right = external_bcs_k[0, 3]  # Index 3 for second control

        # Left boundary is simple Dirichlet
        u_bc_left = external_left
        # Right boundary has the feedback term
        u_bc_right = external_right + feedback_val

        # Normalize values before returning
        bc_state_means = self.params['norm_factors']['BC_State_means']
        bc_state_stds = self.params['norm_factors']['BC_State_stds']
        u_bc_left_norm = (u_bc_left - bc_state_means[0]) / bc_state_stds[0]
        u_bc_right_norm = (u_bc_right - bc_state_means[1]) / bc_state_stds[1]

        # Build the computed part on the same device as external_bcs_k:
        # torch.cat requires all inputs on one device, so a CPU tensor here
        # would fail whenever external_bcs_k lives on a GPU.
        computed_state = torch.tensor(
            [[u_bc_left_norm, u_bc_right_norm]],
            dtype=torch.float32,
            device=external_bcs_k.device,
        )
        computed_bcs = torch.cat([
            computed_state,
            external_bcs_k[:, 2:]  # Pass through original control signals
        ], dim=-1).to(self.device)

        return computed_bcs


# =============================================================================
# =============================================================================
class UniversalPDEDataset(Dataset):
    """Dataset of PDE trajectories with per-sample normalisation.

    Each element of ``data_list`` is a mapping holding a state array under
    ``'U'`` (time on axis 0, space on axis 1), a ``'BC_State'`` array and an
    optional ``'BC_Control'`` array. Dimensions are inferred from the first
    sample. ``__getitem__`` truncates every sequence to ``effective_nt``
    steps and standardises it with statistics computed on that sample alone.
    """

    # Every supported dataset type describes a single 1D state field 'U'
    # with two boundary-state channels, so they share one configuration.
    _SUPPORTED_TYPES = (
        'heat_delayed_feedback',
        'reaction_diffusion_neumann_feedback',
        'heat_nonlinear_feedback_gain',
        'convdiff',
    )

    def __init__(self, data_list, dataset_type, train_nt_limit=None):
        """Inspect the first sample and record the dataset dimensions.

        Args:
            data_list: Non-empty sequence of sample mappings.
            dataset_type: One of the names in ``_SUPPORTED_TYPES``.
            train_nt_limit: Optional number of leading time steps to keep;
                when ``None`` the first sample's full length is used.

        Raises:
            ValueError: If ``data_list`` is empty or ``dataset_type`` is unknown.
            KeyError: If the first sample has no ``'BC_State'`` entry.
        """
        if not data_list:
            raise ValueError("data_list cannot be empty")
        self.data_list = data_list
        self.dataset_type = dataset_type
        self.train_nt_limit = train_nt_limit

        first_sample = data_list[0]

        if dataset_type not in self._SUPPORTED_TYPES:
            raise ValueError(f"Unknown dataset_type: {dataset_type}")

        self.nt_from_sample = first_sample['U'].shape[0]
        self.nx_from_sample = first_sample['U'].shape[1]
        self.ny_from_sample = 1  # all supported problems are 1D
        self.state_keys = ['U']
        self.num_state_vars = 1
        self.nx = self.nx_from_sample
        self.ny = self.ny_from_sample
        self.expected_bc_state_dim = 2

        self.effective_nt = self.train_nt_limit if self.train_nt_limit is not None else self.nt_from_sample
        self.spatial_dim = self.nx * self.ny

        self.bc_state_key = 'BC_State'
        if self.bc_state_key not in first_sample:
            raise KeyError(f"'{self.bc_state_key}' not found in the first sample!")
        actual_bc_state_dim = first_sample[self.bc_state_key].shape[1]

        if actual_bc_state_dim != self.expected_bc_state_dim:
            print(f"Warning: BC_State dimension mismatch for {dataset_type}. "
                  f"Expected {self.expected_bc_state_dim}, got {actual_bc_state_dim}. "
                  f"Using actual dimension: {actual_bc_state_dim}")
        # The dimension found in the data always wins over the expected one.
        self.bc_state_dim = actual_bc_state_dim

        self.bc_control_key = 'BC_Control'
        self.num_controls = 0
        has_control = (self.bc_control_key in first_sample
                       and first_sample[self.bc_control_key] is not None)
        if has_control:
            control = first_sample[self.bc_control_key]
            if control.ndim == 2:  # Expected shape [nt, num_controls]
                self.num_controls = control.shape[1]
            elif control.ndim == 1 and control.shape[0] == self.effective_nt:
                # A 1D array is ambiguous ([nt] vs. bad data); treat it as
                # "no controls" and tell the user.
                print(f"Warning: '{self.bc_control_key}' is 1D. Assuming num_controls = 0 or check data format.")
            if self.num_controls == 0:
                shape_repr = control.shape if hasattr(control, 'shape') else 'N/A'
                print(
                    f"Info: '{self.bc_control_key}' found but num_controls is 0 for sample 0. Shape: {shape_repr}")
        else:
            print(f"Info: '{self.bc_control_key}' not found or is None in the first sample. Assuming num_controls = 0.")

    def __len__(self):
        """Return the number of samples."""
        return len(self.data_list)

    @staticmethod
    def _normalize_columns(seq, num_cols):
        """Standardise the first ``num_cols`` columns of ``seq`` independently.

        Columns whose standard deviation is at most 1e-8 are only centred and
        their recorded std stays 1.0, so de-normalisation remains well defined.

        Returns:
            tuple: ``(normalised float32 array shaped like seq, means, stds)``.
        """
        normed = np.zeros_like(seq, dtype=np.float32)
        means = np.zeros(num_cols)
        stds = np.ones(num_cols)
        for k_dim in range(num_cols):
            col = seq[:, k_dim]
            mean_k = np.mean(col)
            std_k = np.std(col)
            means[k_dim] = mean_k
            if std_k > 1e-8:
                normed[:, k_dim] = (col - mean_k) / std_k
                stds[k_dim] = std_k
            else:
                normed[:, k_dim] = col - mean_k
        return normed, means, stds

    def __getitem__(self, idx):
        """Return one normalised sample.

        Returns:
            tuple: ``(state_tensors, bc_ctrl_tensor, norm_factors)`` where
            ``state_tensors`` is a list with one float tensor per state key,
            ``bc_ctrl_tensor`` concatenates normalised BC_State and BC_Control
            on the last axis, and ``norm_factors`` maps ``'<key>_mean'`` /
            ``'<key>_std'`` (states) and ``'<key>_means'`` / ``'<key>_stds'``
            (boundary data) to the statistics used.

        Raises:
            KeyError: If a state key is missing from the sample.
            ValueError: If a sequence is shorter than ``effective_nt`` or the
                control width differs from ``num_controls``.
        """
        sample = self.data_list[idx]
        norm_factors = {}
        current_nt = self.effective_nt

        # --- States ---
        state_tensors_norm_list = []
        for key in self.state_keys:
            try:
                state_seq_full = sample[key]
            except KeyError as err:
                raise KeyError(
                    f"State variable key '{key}' not found in sample {idx} for dataset type '{self.dataset_type}'") from err
            state_seq = state_seq_full[:current_nt, ...]

            if state_seq.shape[0] != current_nt:
                raise ValueError(
                    f"Time dimension mismatch after potential truncation for {key}. Expected {current_nt}, got {state_seq.shape[0]}")

            state_mean = np.mean(state_seq)
            state_std = np.std(state_seq) + 1e-8  # epsilon guards against constant fields
            state_tensors_norm_list.append(torch.tensor((state_seq - state_mean) / state_std).float())
            norm_factors[f'{key}_mean'] = state_mean
            norm_factors[f'{key}_std'] = state_std

        # --- BC State ---
        bc_state_seq = sample[self.bc_state_key][:current_nt, :]
        if bc_state_seq.shape[0] != current_nt:
            raise ValueError(
                f"Time dimension mismatch for BC_State. Expected {current_nt}, got {bc_state_seq.shape[0]}")

        bc_state_norm, state_means, state_stds = self._normalize_columns(bc_state_seq, self.bc_state_dim)
        norm_factors[f'{self.bc_state_key}_means'] = state_means
        norm_factors[f'{self.bc_state_key}_stds'] = state_stds
        bc_state_tensor_norm = torch.tensor(bc_state_norm).float()

        # --- BC Control ---
        if self.num_controls > 0:
            try:
                bc_control_seq_full = sample[self.bc_control_key]
            except KeyError:
                bc_control_seq_full = None

            if bc_control_seq_full is None:
                # Missing (or None) controls in an individual sample fall
                # back to zeros with identity normalisation factors.
                print(f"Warning: Sample {idx} missing '{self.bc_control_key}' despite num_controls > 0. Using zeros.")
                bc_control_tensor_norm = torch.zeros((current_nt, self.num_controls), dtype=torch.float32)
                norm_factors[f'{self.bc_control_key}_means'] = np.zeros(self.num_controls)
                norm_factors[f'{self.bc_control_key}_stds'] = np.ones(self.num_controls)
            else:
                bc_control_seq = bc_control_seq_full[:current_nt, :]
                if bc_control_seq.shape[0] != current_nt:
                    raise ValueError(
                        f"Time dimension mismatch for BC_Control. Expected {current_nt}, got {bc_control_seq.shape[0]}.")
                if bc_control_seq.shape[1] != self.num_controls:
                    raise ValueError(
                        f"Control dimension mismatch for sample {idx}. Expected {self.num_controls}, got {bc_control_seq.shape[1]}")

                bc_control_norm, ctrl_means, ctrl_stds = self._normalize_columns(bc_control_seq, self.num_controls)
                norm_factors[f'{self.bc_control_key}_means'] = ctrl_means
                norm_factors[f'{self.bc_control_key}_stds'] = ctrl_stds
                bc_control_tensor_norm = torch.tensor(bc_control_norm).float()
        else:
            # No controls: an empty block keeps the concatenation below uniform.
            bc_control_tensor_norm = torch.empty((current_nt, 0), dtype=torch.float32)

        # Concatenate normalized BC_State and BC_Control
        bc_ctrl_tensor_norm = torch.cat((bc_state_tensor_norm, bc_control_tensor_norm), dim=-1)
        return state_tensors_norm_list, bc_ctrl_tensor_norm, norm_factors



def compute_pod_basis_generic(data_list, dataset_type, state_variable_key,
                              nx, nt, basis_dim,
                              max_snapshots_pod=100):
    """Compute a POD basis from boundary-lifted snapshots of one state variable.

    For every usable sample the first ``nt`` time steps are taken, a linear
    interpolation between the left and right boundary values is subtracted,
    and the residuals are stacked. The right singular vectors of the
    mean-centred stack form the basis.

    Args:
        data_list: Iterable of sample mappings.
        dataset_type: Accepted for interface compatibility; not used here.
        state_variable_key: Key of the state array inside each sample.
        nx: Expected number of spatial points.
        nt: Number of leading time steps taken from each sample.
        basis_dim: Requested number of basis vectors.
        max_snapshots_pod: Maximum number of samples contributing snapshots.

    Returns:
        A float32 array of shape ``(nx, basis_dim)`` (zero-padded when the
        data rank is lower than ``basis_dim``), or ``None`` on failure.
    """
    grid_size = nx
    ramp = np.linspace(0, 1, grid_size)[np.newaxis, :]
    print(f"  Computing POD for '{state_variable_key}' using {nt} timesteps, linear interp U_B...")

    lifted = []
    for sample_idx, sample in enumerate(data_list):
        # One entry per accepted sample, so the list length is the counter.
        if len(lifted) >= max_snapshots_pod:
            break
        if state_variable_key not in sample:
            print(f"Warning: Key '{state_variable_key}' not found in sample {sample_idx}. Skipping.")
            continue

        field = sample[state_variable_key][:nt, :]

        if field.shape[0] != nt:
            print(
                f"Warning: U_seq actual timesteps {field.shape[0]} != requested nt {nt} for POD in sample {sample_idx}. Skipping.")
            continue
        if field.shape[1] != grid_size:
            print(
                f"Warning: Mismatch nx in sample {sample_idx} for {state_variable_key}. Expected {grid_size}, got {field.shape[1]}. Skipping.")
            continue

        left_edge, right_edge = field[:, 0:1], field[:, -1:]
        edges_finite = np.isfinite(left_edge).all() and np.isfinite(right_edge).all()
        if not edges_finite:
            print(
                f"Warning: NaN/Inf in boundary values for sample {sample_idx}, key '{state_variable_key}'. Skipping sample for POD.")
            continue

        # Remove the linear boundary interpolant so the residual vanishes
        # at both ends of the domain.
        boundary_part = left_edge * (1 - ramp) + right_edge * ramp
        lifted.append(field - boundary_part)

    if len(lifted) == 0:
        print(
            f"Error: No valid snapshots collected for POD for '{state_variable_key}'. Ensure 'nt' ({nt}) is appropriate.")
        return None

    try:
        stacked = np.concatenate(lifted, axis=0)
    except ValueError as e:
        print(
            f"Error concatenating snapshots for '{state_variable_key}': {e}. Check snapshot shapes. Collected {len(lifted)} lists of snapshots.")
        return None

    if not np.isfinite(stacked).all():
        print(f"Warning: NaN/Inf found in snapshots for '{state_variable_key}' before POD. Clamping.")
        stacked = np.nan_to_num(stacked, nan=0.0, posinf=1e6, neginf=-1e6)
        if np.all(np.abs(stacked) < 1e-12):
            print(f"Error: All snapshots became zero after clamping for '{state_variable_key}'.")
            return None

    centered = stacked - np.mean(stacked, axis=0, keepdims=True)

    try:
        _, sing_vals, right_vecs = np.linalg.svd(centered, full_matrices=False)
        rank = (sing_vals > 1e-10).sum()
        kept = min(basis_dim, rank, grid_size)
        if kept == 0:
            print(f"Error: Data rank is zero or too low for '{state_variable_key}' after SVD. Effective rank: {rank}")
            return None
        if kept < basis_dim:
            print(
                f"Warning: Requested basis_dim {basis_dim} but effective data rank is ~{rank} for '{state_variable_key}'. Using {kept}.")
        modes = right_vecs[:kept, :].T
    except np.linalg.LinAlgError as e:
        print(f"SVD failed for '{state_variable_key}': {e}.")
        return None
    except Exception as e:
        print(f"Error during SVD for '{state_variable_key}': {e}.")
        return None

    # Re-normalise each column; near-zero columns are left untouched.
    col_norms = np.linalg.norm(modes, axis=0)
    col_norms = np.where(col_norms < 1e-10, 1.0, col_norms)
    modes = modes / col_norms[np.newaxis, :]

    if kept < basis_dim:
        print(f"Padding POD basis for '{state_variable_key}' from dim {kept} to {basis_dim}")
        zero_cols = np.zeros((grid_size, basis_dim - kept))
        modes = np.concatenate((modes, zero_cols), axis=1)

    print(f"  Successfully computed POD basis for '{state_variable_key}' with shape {modes.shape}.")
    return modes.astype(np.float32)



class ImprovedUpdateFFN(nn.Module):
    """Feed-forward block: MLP followed by LayerNorm, with a residual
    connection whenever the input and output widths coincide.

    The MLP has ``num_layers - 1`` hidden stages (Linear -> GELU -> Dropout)
    and one final Linear projecting to ``output_dim``.
    """

    def __init__(self, input_dim, hidden_dim=256, num_layers=3, dropout=0.1, output_dim=None):
        """
        Args:
            input_dim: Width of the input features.
            hidden_dim: Width of every hidden stage.
            num_layers: Total number of Linear layers.
            dropout: Dropout probability after each hidden activation.
            output_dim: Output width; defaults to ``input_dim``.
        """
        super().__init__()
        out_width = input_dim if output_dim is None else output_dim

        # Widths seen by the hidden stages: input first, then hidden_dim
        # repeated once per hidden stage.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1)
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.GELU(), nn.Dropout(dropout)]
        stages.append(nn.Linear(widths[-1], out_width))
        self.mlp = nn.Sequential(*stages)

        self.layernorm = nn.LayerNorm(out_width)
        self.input_dim_for_residual = input_dim
        self.output_dim_for_residual = out_width

    def forward(self, x):
        """Apply the MLP, add the skip connection if widths match, normalise."""
        hidden = self.mlp(x)
        if self.input_dim_for_residual == self.output_dim_for_residual:
            hidden = hidden + x
        return self.layernorm(hidden)


class UniversalLifting(nn.Module):
    """Map a boundary/control vector to one lifted field per state variable.

    Every boundary-state channel goes through its own small branch
    (Linear -> GELU), the control channels go through a shared MLP, and the
    concatenated features are fused by an MLP whose output is reshaped to
    ``[batch, num_state_vars, nx]``. When there are neither boundary-state
    nor control inputs the module has no fusion network and returns zeros.
    """

    def __init__(self, num_state_vars, bc_state_dim, num_controls, output_dim_per_var, nx,
                 hidden_dims_state_branch=64,
                 hidden_dims_control=(64, 128),
                 hidden_dims_fusion=(256, 512, 256),
                 dropout=0.1):
        """
        Args:
            num_state_vars: Number of state variables to produce fields for.
            bc_state_dim: Number of leading boundary-state channels in the input.
            num_controls: Number of trailing control channels in the input.
            output_dim_per_var: Must equal ``nx``.
            nx: Number of spatial points per output field.
            hidden_dims_state_branch: Width of each per-channel state branch.
            hidden_dims_control: Hidden widths of the control MLP.
            hidden_dims_fusion: Hidden widths of the fusion MLP.
            dropout: Dropout probability in the control and fusion MLPs.
        """
        super().__init__()
        self.num_state_vars = num_state_vars
        self.bc_state_dim = bc_state_dim
        self.num_controls = num_controls
        self.nx = nx
        assert output_dim_per_var == nx, "output_dim_per_var must equal nx"

        # One independent branch per boundary-state channel.
        self.state_branches = nn.ModuleList()
        if self.bc_state_dim > 0:
            for _ in range(bc_state_dim):
                self.state_branches.append(nn.Sequential(
                    nn.Linear(1, hidden_dims_state_branch),
                    nn.GELU(),
                ))
            state_feature_dim = bc_state_dim * hidden_dims_state_branch
        else:
            state_feature_dim = 0

        # Shared MLP over all control channels.
        control_feature_dim = 0
        if self.num_controls > 0:
            control_layers = []
            current_dim_ctrl = num_controls
            for h_dim in hidden_dims_control:
                control_layers.append(nn.Linear(current_dim_ctrl, h_dim))
                control_layers.append(nn.GELU())
                control_layers.append(nn.Dropout(dropout))
                current_dim_ctrl = h_dim
            self.control_mlp = nn.Sequential(*control_layers)
            control_feature_dim = current_dim_ctrl
        else:
            self.control_mlp = nn.Sequential()

        fusion_input_dim = state_feature_dim + control_feature_dim

        if fusion_input_dim > 0:
            fusion_layers = []
            current_dim_fusion = fusion_input_dim
            for h_dim in hidden_dims_fusion:
                fusion_layers.append(nn.Linear(current_dim_fusion, h_dim))
                fusion_layers.append(nn.GELU())
                fusion_layers.append(nn.Dropout(dropout))
                current_dim_fusion = h_dim
            fusion_layers.append(nn.Linear(current_dim_fusion, num_state_vars * nx))
            self.fusion = nn.Sequential(*fusion_layers)
        else:
            # Nothing to fuse; forward() returns zeros in this configuration.
            self.fusion = None

    def _zero_lifting(self, BC_Ctrl):
        """Return an all-zero lifted field matching the batch of ``BC_Ctrl``."""
        if BC_Ctrl is None:
            return torch.zeros(1, self.num_state_vars, self.nx, device='cpu')
        # Use the leading dimension rather than the element count: an input
        # with zero feature columns has no elements but still carries a batch.
        batch_size = BC_Ctrl.shape[0] if BC_Ctrl.dim() > 0 else 1
        return torch.zeros(batch_size, self.num_state_vars, self.nx, device=BC_Ctrl.device)

    def forward(self, BC_Ctrl):
        """Compute the lifted fields.

        Args:
            BC_Ctrl: Tensor of shape ``[batch, bc_state_dim + num_controls]``;
                may be ``None`` only when the module has no inputs.

        Returns:
            torch.Tensor: Shape ``[batch, num_state_vars, nx]``.
        """
        if self.fusion is None:
            return self._zero_lifting(BC_Ctrl)

        features_to_concat = []
        if self.bc_state_dim > 0:
            BC_state = BC_Ctrl[:, :self.bc_state_dim]
            # Keep a trailing feature axis of size 1 for each branch input.
            state_features = torch.cat(
                [branch(BC_state[:, i:i + 1]) for i, branch in enumerate(self.state_branches)],
                dim=-1,
            )
            features_to_concat.append(state_features)

        if self.num_controls > 0:
            BC_control = BC_Ctrl[:, self.bc_state_dim:]
            features_to_concat.append(self.control_mlp(BC_control))

        if not features_to_concat:
            # Defensive: a fusion network implies at least one feature group.
            return self._zero_lifting(BC_Ctrl)

        concat_features = torch.cat(features_to_concat, dim=-1)
        fused_output = self.fusion(concat_features)
        return fused_output.view(-1, self.num_state_vars, self.nx)


class MultiHeadAttentionROM(nn.Module):
    """Multi-head attention core computed as ``Q @ (K^T @ V)`` per head.

    No softmax or scaling is applied; the per-head results are concatenated
    and passed through a final linear projection. This is only the attention
    primitive, not a complete ROM model.
    """

    def __init__(self, basis_dim, d_model, num_heads):
        """
        Args:
            basis_dim: Sequence length the attention operates on (not used
                by this module itself).
            d_model: Feature width of Q, K and V; must be divisible by
                ``num_heads``.
            num_heads: Number of attention heads.
        """
        super().__init__()
        self.num_heads = num_heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.head_dim = d_model // num_heads
        # Projects the concatenated heads back to d_model.
        self.out_proj = nn.Linear(d_model, d_model)

    def _split_heads(self, tensor, batch, length):
        """Reshape [batch, length, d_model] to [batch, heads, length, head_dim]."""
        return tensor.view(batch, length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, Q, K, V):
        """Attend over inputs of shape [batch, seq_len, d_model].

        Q and K must agree in sequence length and feature width.
        """
        batch, q_len, q_width = Q.size()
        _, kv_len, kv_width = K.size()

        assert q_len == kv_len, "Sequence lengths of Q and K/V must match for this attention"
        assert q_width == kv_width, "d_model of Q and K/V must match"

        q_heads = self._split_heads(Q, batch, q_len)
        k_heads = self._split_heads(K, batch, q_len)
        v_heads = self._split_heads(V, batch, q_len)

        # Contract over the sequence axis first: K^T V is [head_dim, head_dim].
        context = torch.matmul(k_heads.transpose(-2, -1), v_heads)
        attended = torch.matmul(q_heads, context)

        # Put the heads back side by side along the feature axis.
        merged = attended.permute(0, 2, 1, 3).contiguous().view(batch, q_len, q_width)
        return self.out_proj(merged)


class MultiVarAttentionROM(nn.Module):
    """Attention-based reduced-order model (ROM) over one or more state variables.

    Every state variable ``key`` owns a learnable basis ``Phi[key]`` of shape
    ``[nx, basis_dim]``. A prediction is the sum of a lifting field ``U_B``
    built from the boundary/control vector and the basis reconstruction
    ``Phi @ a``. The modal coefficients ``a`` are advanced one step at a time
    by the sum of an attention update, a feed-forward update (scaled by a
    learnable ``alpha``) and a boundary/control-driven update.
    """

    def __init__(self, state_variable_keys, nx, basis_dim, d_model,
                 bc_state_dim, num_controls, num_heads=8,
                 add_error_estimator=False, shared_attention=False,
                 dropout_lifting=0.1, dropout_ffn=0.1, initial_alpha=0.1,
                 use_fixed_lifting=False,
                 bc_processed_dim=64,
                 hidden_bc_processor_dim=128
                 ):
        """Build bases, lifting and the per-variable (or shared) update layers.

        Args:
            state_variable_keys: Names of the state variables; one basis is
                created per key.
            nx: Number of spatial points.
            basis_dim: Number of basis vectors / modal coefficients.
            d_model: Feature width used by the attention layers.
            bc_state_dim: Number of boundary-state entries in the BC/control vector.
            num_controls: Number of control entries in the BC/control vector.
            num_heads: Number of attention heads.
            add_error_estimator: If True, add a linear head that maps the
                concatenated coefficients of all variables to one scalar.
            shared_attention: If True, one set of update layers (stored under
                the key ``'shared'``) is used for all variables.
            dropout_lifting: Dropout passed to the learned lifting network.
            dropout_ffn: Dropout passed to the update feed-forward network.
            initial_alpha: Initial value of the learnable ``alpha`` scale.
            use_fixed_lifting: If True, ``U_B`` is a linear interpolation of
                the left/right boundary values instead of a learned network.
            bc_processed_dim: Output width of the BC feature processor.
            hidden_bc_processor_dim: Hidden width of the BC feature processor.
        """
        super().__init__()
        self.state_keys = state_variable_keys
        self.num_state_vars = len(state_variable_keys)
        self.nx = nx
        self.basis_dim = basis_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.bc_state_dim = bc_state_dim
        self.num_controls = num_controls
        self.add_error_estimator = add_error_estimator
        self.shared_attention = shared_attention
        self.use_fixed_lifting = use_fixed_lifting
        self.bc_processed_dim = bc_processed_dim

        # --- Learnable Bases $\Phi$ ---
        self.Phi = nn.ParameterDict()
        for key in self.state_keys:
            phi_param = nn.Parameter(torch.randn(nx, basis_dim))
            nn.init.orthogonal_(phi_param)
            self.Phi[key] = phi_param

        # --- Lifting Network ---
        if not self.use_fixed_lifting:
            self.lifting = UniversalLifting(
                num_state_vars=self.num_state_vars,
                bc_state_dim=bc_state_dim,
                num_controls=num_controls,
                output_dim_per_var=nx,
                nx=nx,
                dropout=dropout_lifting
            )
        else:
            self.lifting = None
            # Interpolation weights 0..1 across the grid, broadcastable to [batch, 1, nx].
            lin_interp_coeffs = torch.linspace(0, 1, self.nx, dtype=torch.float32)
            self.register_buffer('lin_interp_coeffs', lin_interp_coeffs.view(1, 1, -1))

        # --- Attention & FFN Components ---
        self.W_Q = nn.ModuleDict()
        self.W_K = nn.ModuleDict()
        self.W_V = nn.ModuleDict()
        self.multihead_attn = nn.ModuleDict()
        self.proj_to_coef = nn.ModuleDict()
        self.update_ffn = nn.ModuleDict()
        self.a0_mapping = nn.ModuleDict()
        self.alphas = nn.ParameterDict()

        self.bc_feature_processor = nn.ModuleDict()
        self.bc_to_a_update = nn.ModuleDict()

        total_bc_ctrl_dim = self.bc_state_dim + self.num_controls

        # One component set under 'shared', or one per state variable. Layers
        # are created in a fixed order so seeded initialisation and state_dict
        # keys do not depend on which mode is selected beyond the key name.
        component_keys = ['shared'] if shared_attention else list(self.state_keys)
        for ckey in component_keys:
            self.W_Q[ckey] = nn.Linear(1, d_model)
            self.W_K[ckey] = nn.Linear(nx, d_model)
            self.W_V[ckey] = nn.Linear(nx, d_model)
            self.multihead_attn[ckey] = MultiHeadAttentionROM(basis_dim, d_model, num_heads)
            self.proj_to_coef[ckey] = nn.Linear(d_model, 1)
            self.update_ffn[ckey] = ImprovedUpdateFFN(input_dim=basis_dim, output_dim=basis_dim,
                                                      hidden_dim=d_model, dropout=dropout_ffn)
            self.a0_mapping[ckey] = nn.Sequential(
                nn.Linear(basis_dim, basis_dim), nn.ReLU(), nn.LayerNorm(basis_dim)
            )
            self.alphas[ckey] = nn.Parameter(torch.tensor(initial_alpha))

            if total_bc_ctrl_dim > 0:
                self.bc_feature_processor[ckey] = nn.Sequential(
                    nn.Linear(total_bc_ctrl_dim, hidden_bc_processor_dim),
                    nn.GELU(),
                    nn.Linear(hidden_bc_processor_dim, self.bc_processed_dim)
                )
            else:
                # No BC/control inputs: keep an empty placeholder so lookups succeed.
                self.bc_feature_processor[ckey] = nn.Sequential()
            self.bc_to_a_update[ckey] = nn.Linear(self.bc_processed_dim, basis_dim)

        if self.add_error_estimator:
            total_basis_dim_all_vars = self.num_state_vars * basis_dim
            self.error_estimator = nn.Linear(total_basis_dim_all_vars, 1)

    def _get_layer(self, module_dict, key):
        """Return the shared entry of ``module_dict`` or the one for ``key``."""
        return module_dict['shared'] if self.shared_attention else module_dict[key]

    def _get_alpha(self, key):
        """Return the shared ``alpha`` or the one belonging to ``key``."""
        return self.alphas['shared'] if self.shared_attention else self.alphas[key]

    @staticmethod
    def _has_parameters(layer):
        """Return True if ``layer`` has a ``weight`` attribute or any parameter."""
        return hasattr(layer, 'weight') or next(layer.parameters(), None) is not None

    def _map_initial_coefficients(self, a0_dict):
        """Pass every variable's initial coefficients through its ``a0_mapping`` layer."""
        return {
            key: self._get_layer(self.a0_mapping, key)(a0_dict[key])
            for key in self.state_keys
        }

    def _interpolate_boundaries(self, BC_Ctrl_n, i_var):
        """Linearly interpolate variable ``i_var`` between its left/right BC values.

        The left value is read from column ``2 * i_var`` and the right value
        from column ``2 * i_var + 1`` of ``BC_Ctrl_n``. Returns ``[batch, 1, nx]``.
        """
        idx_left = 2 * i_var
        idx_right = idx_left + 1
        bc_left_val = BC_Ctrl_n[:, idx_left:idx_left + 1].unsqueeze(-1)
        bc_right_val = BC_Ctrl_n[:, idx_right:idx_right + 1].unsqueeze(-1)
        return bc_left_val * (1 - self.lin_interp_coeffs) + bc_right_val * self.lin_interp_coeffs

    def _compute_U_B(self, BC_Ctrl_n):
        """Compute the lifting field ``U_B`` for one time step.

        Args:
            BC_Ctrl_n: BC/control vector of shape ``[batch, bc_state_dim + num_controls]``.

        Returns:
            Tensor of shape ``[batch, num_state_vars, nx]``. With fixed lifting
            and too few boundary values, a warning is printed and zeros are
            returned.

        Raises:
            ValueError: If learned lifting is selected but no lifting network exists.
        """
        if not self.use_fixed_lifting:
            if self.lifting is None:
                raise ValueError("Lifting network is None but use_fixed_lifting is False.")
            return self.lifting(BC_Ctrl_n)

        batch_size = BC_Ctrl_n.shape[0]
        if self.num_state_vars == 1:
            if self.bc_state_dim < 2:
                print(
                    f"Warning: Fixed lifting for 1 state var expects bc_state_dim >= 2, got {self.bc_state_dim}. Returning zeros for U_B.")
                return torch.zeros(batch_size, 1, self.nx, device=BC_Ctrl_n.device)
            return self._interpolate_boundaries(BC_Ctrl_n, 0)

        if self.bc_state_dim < 2 * self.num_state_vars and self.num_state_vars > 0:
            print(
                f"Warning: Fixed lifting for {self.num_state_vars} vars expects bc_state_dim >= {2 * self.num_state_vars}, got {self.bc_state_dim}. Returning zeros for U_B.")
            return torch.zeros(batch_size, self.num_state_vars, self.nx, device=BC_Ctrl_n.device)

        # The guard above guarantees every variable has both boundary columns,
        # so no per-variable bounds check is needed here.
        U_B_list = [self._interpolate_boundaries(BC_Ctrl_n, i_var)
                    for i_var in range(self.num_state_vars)]
        return torch.cat(U_B_list, dim=1)

    def _process_bc_features(self, BC_Ctrl_n):
        """Run the BC feature processor(s) and return ``{state_key: features or None}``.

        ``None`` is stored when there are no BC/control inputs or when the
        processor has no parameters. With shared attention the processor is
        evaluated once and its output reused for every variable.
        """
        features = {key: None for key in self.state_keys}
        if self.bc_state_dim + self.num_controls <= 0:
            return features

        computed = {}
        for key in self.state_keys:
            component_key = 'shared' if self.shared_attention else key
            if component_key not in computed:
                layer = self.bc_feature_processor[component_key]
                computed[component_key] = layer(BC_Ctrl_n) if self._has_parameters(layer) else None
            features[key] = computed[component_key]
        return features

    def forward_step(self, a_n_dict, BC_Ctrl_n, params=None):
        """Advance the modal coefficients by one step and reconstruct the state.

        Args:
            a_n_dict: ``{key: [batch, basis_dim]}`` current coefficients.
            BC_Ctrl_n: BC/control vector for this step, ``[batch, bc_state_dim + num_controls]``.
            params: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(a_next_dict, U_hat_dict, err_est)`` where ``a_next_dict``
            holds the updated coefficients, ``U_hat_dict`` maps each key to a
            ``[batch, nx, 1]`` reconstruction, and ``err_est`` is the error
            estimator output or ``None`` when the estimator is disabled.
        """
        batch_size = next(iter(a_n_dict.values())).size(0)
        a_next_dict = {}
        U_hat_dict = {}

        U_B_stacked = self._compute_U_B(BC_Ctrl_n)
        bc_features_processed_dict = self._process_bc_features(BC_Ctrl_n)

        for i_var, key in enumerate(self.state_keys):
            a_n_var = a_n_dict[key]
            Phi_var = self.Phi[key]

            W_Q_var = self._get_layer(self.W_Q, key)
            W_K_var = self._get_layer(self.W_K, key)
            W_V_var = self._get_layer(self.W_V, key)
            attn_module_var = self._get_layer(self.multihead_attn, key)
            proj_var = self._get_layer(self.proj_to_coef, key)
            ffn_var = self._get_layer(self.update_ffn, key)
            alpha_var = self._get_alpha(key)
            bc_to_a_update_layer = self._get_layer(self.bc_to_a_update, key)

            # Keys/values come from the basis vectors (one token per basis vector).
            Phi_basis_vectors = Phi_var.transpose(0, 1).unsqueeze(0).expand(batch_size, -1, -1)

            K_flat = W_K_var(Phi_basis_vectors.reshape(-1, self.nx))
            V_flat = W_V_var(Phi_basis_vectors.reshape(-1, self.nx))
            K = K_flat.view(batch_size, self.basis_dim, self.d_model)
            V = V_flat.view(batch_size, self.basis_dim, self.d_model)

            # Queries come from the current coefficients (one scalar per token).
            a_n_unsq_for_Q = a_n_var.unsqueeze(-1)
            Q_base = W_Q_var(a_n_unsq_for_Q)

            ffn_update_intrinsic = ffn_var(a_n_var)
            Q_for_attention = Q_base + alpha_var * ffn_update_intrinsic.unsqueeze(-1)

            z = attn_module_var(Q_for_attention, K, V)
            z = z / np.sqrt(float(self.d_model))
            a_update_attn = proj_var(z.reshape(-1, self.d_model)).view(batch_size, self.basis_dim)

            current_bc_features = bc_features_processed_dict[key]
            if current_bc_features is not None and self._has_parameters(bc_to_a_update_layer):
                bc_driven_a_update = bc_to_a_update_layer(current_bc_features)
            else:
                bc_driven_a_update = torch.zeros_like(a_n_var)

            a_next_var = a_n_var + a_update_attn + alpha_var * ffn_update_intrinsic + bc_driven_a_update
            a_next_dict[key] = a_next_var

            U_B_var = U_B_stacked[:, i_var, :].unsqueeze(-1)
            Phi_expanded = Phi_var.unsqueeze(0).expand(batch_size, -1, -1)
            a_next_unsq = a_next_var.unsqueeze(-1)
            U_recon_star = torch.bmm(Phi_expanded, a_next_unsq)
            U_hat_dict[key] = U_B_var + U_recon_star

        err_est = None
        if self.add_error_estimator:
            a_next_combined_for_err_est = torch.cat(list(a_next_dict.values()), dim=-1)
            err_est = self.error_estimator(a_next_combined_for_err_est)

        return a_next_dict, U_hat_dict, err_est

    def forward(self, a0_dict, BC_Ctrl_seq, T, params=None):
        """Roll the model out for ``T`` steps using the supplied BC/control sequence.

        Args:
            a0_dict: ``{key: [batch, basis_dim]}`` initial coefficients.
            BC_Ctrl_seq: BC/control sequence indexed as ``[:, t, :]``.
            T: Number of steps to roll out.
            params: Forwarded to ``forward_step`` (unused there).

        Returns:
            Tuple ``(U_hat_seq_dict, err_seq)``: per-key lists of ``T``
            reconstructions, and the list of error estimates (``None`` when
            the estimator is disabled).
        """
        a_current_dict = self._map_initial_coefficients(a0_dict)

        U_hat_seq_dict = {key: [] for key in self.state_keys}
        err_seq = [] if self.add_error_estimator else None

        for t_step in range(T):
            # Teacher-forcing: Use the ground truth BC for the current step
            BC_Ctrl_n_step = BC_Ctrl_seq[:, t_step, :]
            a_next_dict_step, U_hat_dict_step, err_est_step = self.forward_step(
                a_current_dict, BC_Ctrl_n_step, params
            )

            for key in self.state_keys:
                U_hat_seq_dict[key].append(U_hat_dict_step[key])

            if self.add_error_estimator and err_est_step is not None:
                err_seq.append(err_est_step)

            a_current_dict = a_next_dict_step

        return U_hat_seq_dict, err_seq

    def predict_autoregressive(self, a0_dict, T, full_BC_Ctrl_seq_gt, feedback_controller):
        """Roll the model out with boundary conditions produced by a feedback controller.

        The state at ``t = 0`` is reconstructed from the mapped initial
        coefficients and the first BC/control vector. For every following
        step the controller computes the next BC/control vector from the
        current prediction of the first state variable, and the model is
        advanced with that vector.

        Args:
            a0_dict: ``{key: [batch, basis_dim]}`` initial coefficients.
            T: Total number of time levels to return (including ``t = 0``).
            full_BC_Ctrl_seq_gt: Ground-truth BC/control sequence indexed as
                ``[:, t, :]``; step ``t`` is handed to the controller as the
                external signal.
            feedback_controller: Object exposing ``compute_bc(U_hat_k,
                external_bcs_k, t_k)`` and a ``params['dt']`` entry.

        Returns:
            Tuple ``(U_hat_seq_dict, None)`` with per-key lists of
            ``[batch, nx, 1]`` reconstructions.
        """
        a_current_dict = self._map_initial_coefficients(a0_dict)
        U_hat_seq_dict = {key: [] for key in self.state_keys}

        BC_Ctrl_t0 = full_BC_Ctrl_seq_gt[:, 0, :]

        # Reconstruct U_hat at t=0 to kick off the loop
        U_B_stacked_t0 = self._compute_U_B(BC_Ctrl_t0)
        U_hat_current_unstacked = {}
        for i_var, key in enumerate(self.state_keys):
            a_var = a_current_dict[key]
            # bmm needs matching batch sizes, so broadcast the basis explicitly.
            Phi_expanded = self.Phi[key].unsqueeze(0).expand(a_var.size(0), -1, -1)
            U_I_t0 = torch.bmm(Phi_expanded, a_var.unsqueeze(-1))
            U_B_t0 = U_B_stacked_t0[:, i_var, :].unsqueeze(-1)
            U_hat_t0 = U_B_t0 + U_I_t0
            U_hat_seq_dict[key].append(U_hat_t0)
            U_hat_current_unstacked[key] = U_hat_t0.squeeze(-1)

        # Steps t=0 .. T-2 produce the predictions for t=1 .. T-1.
        for t_step in range(T - 1):
            # 1. Compute the BC/control vector for step t_step+1 from the
            #    current prediction of the first state variable.
            U_hat_k_norm = U_hat_current_unstacked[self.state_keys[0]]
            external_bcs_k_norm = full_BC_Ctrl_seq_gt[:, t_step, :]
            BC_Ctrl_next_computed = feedback_controller.compute_bc(
                U_hat_k_norm, external_bcs_k_norm, t_step * feedback_controller.params['dt']
            )

            # 2. Evolve the state with the computed BCs.
            a_next_dict_step, U_hat_dict_step, _ = self.forward_step(
                a_current_dict, BC_Ctrl_next_computed
            )

            # 3. Store results and update the state for the next iteration.
            for key in self.state_keys:
                U_hat_seq_dict[key].append(U_hat_dict_step[key])

            a_current_dict = a_next_dict_step
            U_hat_current_unstacked = {key: val.squeeze(-1) for key, val in U_hat_dict_step.items()}

        return U_hat_seq_dict, None

    def get_basis(self, key):
        """Return the basis parameter ``Phi[key]`` of shape ``[nx, basis_dim]``."""
        return self.Phi[key]




# train_multivar_model - Unchanged, still uses teacher forcing
def train_multivar_model(model, data_loader, dataset_type, train_nt_target,
                         lr=1e-3, num_epochs=50, device='cuda',
                         checkpoint_path='rom_checkpoint.pt', lambda_res=0.05,
                         lambda_orth=0.001, lambda_bc_penalty=0.01,
                         clip_grad_norm=1.0):
    """Train the ROM with teacher forcing and keep the best checkpoint on disk.

    The loss is the time-averaged reconstruction MSE plus three optional
    regularisers: the projection of the residual onto the basis
    (``lambda_res``), the deviation of ``Phi^T Phi`` from identity
    (``lambda_orth``) and the magnitude of the basis at the first/last grid
    point (``lambda_bc_penalty``).

    Args:
        model: Model exposing ``state_keys``, ``basis_dim``, ``get_basis``,
            ``_compute_U_B`` and a forward ``(a0_dict, BC_Ctrl_seq, T, params)``.
        data_loader: Iterable yielding ``(state_data, BC_Ctrl_tensor,
            norm_factors)``; ``state_data`` is one ``[batch, nt, nx]`` tensor
            or a list of them (one per state variable).
        dataset_type: Dataset name, stored in the checkpoint.
        train_nt_target: Expected number of time steps per sample.
        lr: Learning rate for AdamW.
        num_epochs: Last epoch index (exclusive); training resumes from the
            checkpoint's epoch when one exists.
        device: Device to train on.
        checkpoint_path: Where the best checkpoint is loaded from / saved to.
        lambda_res: Weight of the residual-projection loss.
        lambda_orth: Weight of the basis-orthogonality loss.
        lambda_bc_penalty: Weight of the basis boundary penalty.
        clip_grad_norm: Max gradient norm; clipping is skipped when <= 0.

    Returns:
        The model, with the best checkpoint re-loaded if one exists.

    Raises:
        ValueError: If a batch's time dimension differs from ``train_nt_target``.
    """
    model.to(device)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=1e-5)
    mse_loss = nn.MSELoss()
    # NOTE: despite the name, this tracks the best average *training* loss;
    # no separate validation set is used in this function.
    best_val_loss = float('inf')
    try:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                                                         verbose=True)
    except TypeError:
        # `verbose` is deprecated in newer PyTorch releases; fall back to the
        # default if the installed version no longer accepts it.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

    start_epoch = 0
    if os.path.exists(checkpoint_path):
        print(f"Loading existing checkpoint from {checkpoint_path} ...")
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            if 'optimizer_state_dict' in checkpoint:
                try:
                    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                except Exception as opt_err:
                    # Best effort: keep the freshly created optimizer, but do
                    # not swallow KeyboardInterrupt/SystemExit.
                    print(f"Warning: Could not load optimizer state ({opt_err}). Reinitializing optimizer.")
            start_epoch = checkpoint.get('epoch', 0) + 1
            best_val_loss = checkpoint.get('loss', float('inf'))
            print(f"Resuming training from epoch {start_epoch}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}. Starting fresh.")

    state_keys = model.state_keys
    # Identity for the orthogonality loss; constant across batches and variables.
    identity = torch.eye(model.basis_dim, device=device) if lambda_orth > 0 else None

    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss = 0.0
        count = 0

        for state_data, BC_Ctrl_tensor, _norm_factors in data_loader:
            if isinstance(state_data, list):
                state_tensors = [s.to(device) for s in state_data]
            else:
                state_tensors = [state_data.to(device)]
            batch_size, nt, _ = state_tensors[0].shape
            BC_Ctrl_tensor = BC_Ctrl_tensor.to(device)
            if nt != train_nt_target:
                raise ValueError(f"Mismatch: nt from DataLoader ({nt}) != train_nt_target ({train_nt_target})")

            optimizer.zero_grad()

            # Initial coefficients: project (U0 - U_B(t0)) onto each basis.
            a0_dict = {}
            BC_ctrl_combined_t0 = BC_Ctrl_tensor[:, 0, :]

            # Use _compute_U_B which is agnostic to fixed or learned lifting
            U_B_stacked_t0 = model._compute_U_B(BC_ctrl_combined_t0)

            for k, key in enumerate(state_keys):
                U0_var = state_tensors[k][:, 0, :].unsqueeze(-1)
                U_B_t0_var = U_B_stacked_t0[:, k, :].unsqueeze(-1)
                U0_star_var = U0_var - U_B_t0_var

                Phi_var = model.get_basis(key)
                Phi_T_var = Phi_var.transpose(0, 1).unsqueeze(0).expand(batch_size, -1, -1)
                a0_dict[key] = torch.bmm(Phi_T_var, U0_star_var).squeeze(-1)

            U_hat_seq_dict, _ = model(a0_dict, BC_Ctrl_tensor, T=train_nt_target, params=None)

            mse_recon_loss = 0.0
            residual_orth_loss = 0.0
            orth_loss = 0.0
            boundary_penalty = 0.0

            for k, key in enumerate(state_keys):
                Phi_var = model.get_basis(key)
                Phi_T_var = Phi_var.transpose(0, 1).unsqueeze(0).expand(batch_size, -1, -1)
                U_hat_seq_var = U_hat_seq_dict[key]
                U_target_seq_var = state_tensors[k]

                var_mse_loss = 0.0
                var_res_orth_loss = 0.0
                for t in range(train_nt_target):
                    pred = U_hat_seq_var[t]
                    target = U_target_seq_var[:, t, :].unsqueeze(-1)
                    var_mse_loss += mse_loss(pred, target)

                    if lambda_res > 0:
                        # Penalise the part of the residual that lies in span(Phi).
                        r = target - pred
                        r_proj = torch.bmm(Phi_T_var, r)
                        var_res_orth_loss += mse_loss(r_proj, torch.zeros_like(r_proj))

                mse_recon_loss += (var_mse_loss / nt)
                residual_orth_loss += (var_res_orth_loss / nt)

                if lambda_orth > 0:
                    PhiT_Phi = torch.matmul(Phi_var.transpose(0, 1), Phi_var)
                    orth_loss += torch.norm(PhiT_Phi - identity, p='fro') ** 2

                if lambda_bc_penalty > 0:
                    # Push the basis towards zero at the boundary grid points.
                    if Phi_var.shape[0] > 1:
                        boundary_penalty += mse_loss(Phi_var[0, :], torch.zeros_like(Phi_var[0, :])) + \
                                            mse_loss(Phi_var[-1, :], torch.zeros_like(Phi_var[-1, :]))
                    else:
                        boundary_penalty += mse_loss(Phi_var[0, :], torch.zeros_like(Phi_var[0, :]))

            orth_loss /= len(state_keys)
            boundary_penalty /= len(state_keys)

            total_batch_loss = mse_recon_loss + \
                               lambda_res * residual_orth_loss + \
                               lambda_orth * orth_loss + \
                               lambda_bc_penalty * boundary_penalty

            total_batch_loss.backward()
            if clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
            optimizer.step()

            epoch_loss += total_batch_loss.item()
            count += 1

        if count == 0:
            # An empty loader would otherwise divide by zero below.
            print("Warning: data_loader produced no batches. Stopping training.")
            break

        avg_epoch_loss = epoch_loss / count
        print(f"Epoch {epoch + 1}/{num_epochs} finished. Average Training Loss: {avg_epoch_loss:.6f}")

        scheduler.step(avg_epoch_loss)
        if avg_epoch_loss < best_val_loss:
            best_val_loss = avg_epoch_loss
            print(f"Saving checkpoint with loss {best_val_loss:.6f} to {checkpoint_path}")
            save_dict = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
                'dataset_type': dataset_type,
                'state_keys': model.state_keys,
                'nx': model.nx,
                'basis_dim': model.basis_dim,
                'd_model': model.d_model,
                'bc_state_dim': model.bc_state_dim,
                'num_controls': model.num_controls,
                'num_heads': model.num_heads,
                'shared_attention': model.shared_attention
            }
            torch.save(save_dict, checkpoint_path)

    print("Training finished.")
    if os.path.exists(checkpoint_path):
        print(f"Loading best model from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    return model


def validate_multivar_model(model, data_loader, dataset_type,
                            train_nt_for_model_training: int,
                            T_value_for_model_training: float,
                            full_T_in_datafile: float,
                            full_nt_in_datafile: int, device='cuda',
                            save_fig_path='rom_result.png'):
    """Evaluate the model on the first sample of the first batch and plot the result.

    Predictions are produced for up to three horizons (the training horizon,
    the midpoint between it and the full horizon, and the full horizon).
    Feedback datasets with an implemented controller are rolled out
    auto-regressively; everything else uses teacher forcing. For each horizon
    the per-variable errors are printed and a ground-truth / prediction /
    error figure is saved next to ``save_fig_path`` with a ``_T<horizon>``
    suffix.

    Args:
        model: Trained ROM exposing ``state_keys``, ``get_basis``,
            ``_compute_U_B``, ``predict_autoregressive`` and forward.
        data_loader: Iterable yielding ``(state_data, BC_Ctrl_tensor,
            norm_factors)``; only the first batch is used.
        dataset_type: Dataset name; selects the feedback controller.
        train_nt_for_model_training: Not used by this function.
        T_value_for_model_training: Time horizon used during training; the
            summary metrics are reported for this horizon.
        full_T_in_datafile: Final time of the stored trajectories.
        full_nt_in_datafile: Number of time levels of the stored trajectories.
        device: Device used for inference.
        save_fig_path: Base path for the saved figures.

    Returns:
        None. Results are printed and written as figures.
    """
    model.eval()
    results = {key: {'mse': [], 'rmse': [], 'relative_error': [], 'max_error': []}
               for key in model.state_keys}
    overall_rel_err_T1 = []

    feedback_datasets = ['convdiff', 'reaction_diffusion_neumann_feedback', 'heat_nonlinear_feedback_gain',
                         'heat_delayed_feedback']
    use_autoregressive_prediction = dataset_type in feedback_datasets

    test_horizons_T_values = [T_value_for_model_training,
                              T_value_for_model_training + 0.5 * (full_T_in_datafile - T_value_for_model_training),
                              full_T_in_datafile]
    test_horizons_T_values = sorted(list(set(h for h in test_horizons_T_values if h <= full_T_in_datafile)))

    print(f"Validation Horizons (T values): {test_horizons_T_values}")
    if use_autoregressive_prediction:
        print(">>> Using AUTO-REGRESSIVE prediction for validation (feedback loop enabled).")
    else:
        print(">>> Using TEACHER-FORCING prediction for validation.")

    try:
        state_data_full, BC_Ctrl_tensor_full, norm_factors_batch = next(iter(data_loader))
    except StopIteration:
        print("No validation data. Skipping validation.")
        return

    # Keep only the first sample of the batch (batch dimension stays at 1).
    if isinstance(state_data_full, list):
        state_tensors_full = [s[0:1].to(device) for s in state_data_full]
    else:
        state_tensors_full = [state_data_full[0:1].to(device)]
    _, nt_loaded, nx_loaded = state_tensors_full[0].shape

    BC_Ctrl_tensor_full_sample = BC_Ctrl_tensor_full[0:1].to(device)

    norm_factors_sample = {}
    for key, val_tensor in norm_factors_batch.items():
        if isinstance(val_tensor, (torch.Tensor, np.ndarray)):
            norm_factors_sample[key] = val_tensor[0] if val_tensor.ndim > 0 else val_tensor
        else:
            norm_factors_sample[key] = val_tensor

    state_keys = model.state_keys
    current_batch_size = 1

    # Initial coefficients: project (U0 - U_B(t0)) onto each basis. No
    # gradients are needed anywhere in validation.
    a0_dict = {}
    with torch.no_grad():
        BC0_full = BC_Ctrl_tensor_full_sample[:, 0, :]
        U_B0_lifted = model._compute_U_B(BC0_full)
        for k_var_idx, key in enumerate(state_keys):
            U0_full_var = state_tensors_full[k_var_idx][:, 0, :].unsqueeze(-1)
            U_B0_var = U_B0_lifted[:, k_var_idx, :].unsqueeze(-1)
            Phi = model.get_basis(key).to(device)
            Phi_T = Phi.transpose(0, 1).unsqueeze(0)
            a0 = torch.bmm(Phi_T, U0_full_var - U_B0_var).squeeze(-1)
            a0_dict[key] = a0

    feedback_controller = None
    if use_autoregressive_prediction:
        dt = full_T_in_datafile / (full_nt_in_datafile - 1)
        dataset_params = {
            'dt': dt, 'nx': nx_loaded, 'norm_factors': norm_factors_sample
        }
        if dataset_type == 'convdiff':
            feedback_controller = ConvDiffIntegralController(dataset_params, device)
        elif dataset_type == 'reaction_diffusion_neumann_feedback':
            feedback_controller = ReactionDiffusionIntegralController(dataset_params, device)
        elif dataset_type == 'heat_nonlinear_feedback_gain':
            feedback_controller = HeatNonlinearIntegralController(dataset_params, device)
        else:
            print(f"Warning: No specific controller implemented for {dataset_type}. Using teacher-forcing.")
            use_autoregressive_prediction = False

    # Split once so every horizon gets its own file even when the base path
    # has no '.png' extension.
    fig_root, fig_ext = os.path.splitext(save_fig_path)

    for T_test_horizon in test_horizons_T_values:
        nt_for_this_horizon = int((T_test_horizon / full_T_in_datafile) * (full_nt_in_datafile - 1)) + 1
        nt_for_this_horizon = min(nt_for_this_horizon, full_nt_in_datafile)

        print(f"\n--- Validating for T_horizon = {T_test_horizon:.2f} (nt = {nt_for_this_horizon}) ---")

        with torch.no_grad():
            if use_autoregressive_prediction and feedback_controller is not None:
                feedback_controller.reset()  # Reset controller state for each horizon
                U_hat_seq_dict, _ = model.predict_autoregressive(
                    a0_dict, T=nt_for_this_horizon,
                    full_BC_Ctrl_seq_gt=BC_Ctrl_tensor_full_sample,
                    feedback_controller=feedback_controller
                )
            else:
                BC_seq_for_pred = BC_Ctrl_tensor_full_sample[:, :nt_for_this_horizon, :]
                U_hat_seq_dict, _ = model(a0_dict, BC_seq_for_pred, T=nt_for_this_horizon)

        combined_pred_denorm = []
        combined_gt_denorm = []
        num_vars_plot = len(state_keys)
        fig, axs = plt.subplots(num_vars_plot, 3, figsize=(18, 5 * num_vars_plot), squeeze=False)
        L_vis = 1.0

        for k_var_idx, key in enumerate(state_keys):
            pred_norm_stacked = torch.cat(U_hat_seq_dict[key], dim=0)
            pred_norm_reshaped = pred_norm_stacked.view(nt_for_this_horizon, current_batch_size, nx_loaded)
            pred_norm_final = pred_norm_reshaped.squeeze(1).detach().cpu().numpy()

            # Undo the per-variable normalisation before computing metrics.
            mean_k_val = norm_factors_sample[f'{key}_mean']
            std_k_val = norm_factors_sample[f'{key}_std']
            mean_k = mean_k_val.item() if hasattr(mean_k_val, 'item') else mean_k_val
            std_k = std_k_val.item() if hasattr(std_k_val, 'item') else std_k_val
            pred_denorm = pred_norm_final * std_k + mean_k

            gt_norm_full_var = state_tensors_full[k_var_idx].squeeze(0).cpu().numpy()
            gt_norm_sliced = gt_norm_full_var[:nt_for_this_horizon, :]
            gt_denorm = gt_norm_sliced * std_k + mean_k

            combined_pred_denorm.append(pred_denorm.flatten())
            combined_gt_denorm.append(gt_denorm.flatten())

            mse_k = np.mean((pred_denorm - gt_denorm) ** 2)
            rmse_k = np.sqrt(mse_k)
            rel_err_k = np.linalg.norm(pred_denorm - gt_denorm, 'fro') / (np.linalg.norm(gt_denorm, 'fro') + 1e-10)
            max_err_k = np.max(np.abs(pred_denorm - gt_denorm))

            print(f"  '{key}': MSE={mse_k:.3e}, RMSE={rmse_k:.3e}, RelErr={rel_err_k:.3e}, MaxErr={max_err_k:.3e}")

            # Only the training horizon feeds the final summary.
            if abs(T_test_horizon - T_value_for_model_training) < 1e-5:
                results[key]['mse'].append(mse_k)
                results[key]['rmse'].append(rmse_k)
                results[key]['relative_error'].append(rel_err_k)
                results[key]['max_error'].append(max_err_k)

            diff_plot = np.abs(pred_denorm - gt_denorm)
            # Shared colour range so GT and prediction panels are comparable.
            vmin_plot = min(gt_denorm.min(), pred_denorm.min())
            vmax_plot = max(gt_denorm.max(), pred_denorm.max())

            im0 = axs[k_var_idx, 0].imshow(gt_denorm, aspect='auto', origin='lower', vmin=vmin_plot, vmax=vmax_plot,
                                           extent=[0, L_vis, 0, T_test_horizon], cmap='viridis')
            axs[k_var_idx, 0].set_title(f"GT ({key}) T={T_test_horizon:.1f}")
            axs[k_var_idx, 0].set_ylabel("t")
            plt.colorbar(im0, ax=axs[k_var_idx, 0])

            im1 = axs[k_var_idx, 1].imshow(pred_denorm, aspect='auto', origin='lower', vmin=vmin_plot, vmax=vmax_plot,
                                           extent=[0, L_vis, 0, T_test_horizon], cmap='viridis')
            axs[k_var_idx, 1].set_title(f"Pred ({key}) T={T_test_horizon:.1f}")
            plt.colorbar(im1, ax=axs[k_var_idx, 1])

            im2 = axs[k_var_idx, 2].imshow(diff_plot, aspect='auto', origin='lower',
                                           extent=[0, L_vis, 0, T_test_horizon], cmap='magma')
            axs[k_var_idx, 2].set_title(f"Error ({key}) (Max {max_err_k:.2e})")
            plt.colorbar(im2, ax=axs[k_var_idx, 2])
            for j_plot in range(3):
                axs[k_var_idx, j_plot].set_xlabel("x")

        overall_rel_err_horizon = np.linalg.norm(
            np.concatenate(combined_pred_denorm) - np.concatenate(combined_gt_denorm)) / \
                                  (np.linalg.norm(np.concatenate(combined_gt_denorm)) + 1e-10)
        print(f"  Overall RelErr for T={T_test_horizon:.1f}: {overall_rel_err_horizon:.3e}")
        if abs(T_test_horizon - T_value_for_model_training) < 1e-5:
            overall_rel_err_T1.append(overall_rel_err_horizon)

        fig.suptitle(
            f"Validation @ T={T_test_horizon:.1f} ({dataset_type.upper()}) — basis={model.basis_dim}, d_model={model.d_model}")
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        horizon_tag = str(T_test_horizon).replace(".", "p")
        horizon_fig_path = f"{fig_root}_T{horizon_tag}{fig_ext}"
        plt.savefig(horizon_fig_path)
        print(f"Saved validation figure to: {horizon_fig_path}")
        plt.show()
        # Release the figure so repeated horizons do not accumulate open figures.
        plt.close(fig)

    print(f"\n--- Validation Summary (Metrics for T={T_value_for_model_training:.1f}) ---")
    for key in state_keys:
        if results[key]['mse']:
            avg_mse = np.mean(results[key]['mse'])
            avg_rmse = np.mean(results[key]['rmse'])
            avg_rel = np.mean(results[key]['relative_error'])
            avg_max = np.mean(results[key]['max_error'])
            print(f"  {key}: MSE={avg_mse:.4e}, RMSE={avg_rmse:.4e}, RelErr={avg_rel:.4e}, MaxErr={avg_max:.4e}")
    if overall_rel_err_T1:
        print(f"Overall Avg RelErr for T={T_value_for_model_training:.1f}: {np.mean(overall_rel_err_T1):.4e}")


# =============================================================================
if __name__ == '__main__':
    # Script entry point: parse CLI options, load the pickled dataset, build the
    # attention-based ROM, optionally initialise Phi from POD bases, then train
    # and validate.
    #
    # Fatal conditions terminate with `raise SystemExit(message)`: the message is
    # written to stderr and the process exits with status 1, so shell scripts and
    # job schedulers can detect the failure (a bare `exit()` reports success).

    # Dataset types this script can load; these are also the accepted --datatype values.
    FEEDBACK_DATASET_TYPES = ('heat_delayed_feedback',
                              'reaction_diffusion_neumann_feedback',
                              'heat_nonlinear_feedback_gain', 'convdiff')

    parser = argparse.ArgumentParser(description="Train ROM models with various configurations.")
    parser.add_argument('--datatype', type=str, required=True,
                        choices=FEEDBACK_DATASET_TYPES,
                        help='Type of dataset to use.')
    parser.add_argument('--use_fixed_lifting', action='store_true',
                        help='Use fixed linear interpolation for U_B instead of learned lifting.')
    parser.add_argument('--random_phi_init', action='store_true',
                        help='Use random orthogonal initialization for Phi instead of POD.')
    parser.add_argument('--use_lstm_rom', action='store_true',
                        help='Use LSTM-based ROM instead of Attention-based ROM.')

    parser.add_argument('--basis_dim', type=int, default=32,
                        help='Dimension of the reduced basis (modal coefficients).')
    parser.add_argument('--d_model', type=int, default=256, help='Model dimension for Attention ROM.')
    parser.add_argument('--num_heads', type=int, default=4, help='Number of heads for Attention ROM.')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training.')

    parser.add_argument('--lstm_hidden_dim', type=int, default=256, help='LSTM hidden dimension.')
    parser.add_argument('--num_lstm_layers', type=int, default=1, help='Number of LSTM layers.')
    parser.add_argument('--lstm_control_emb_dim', type=int, default=64, help='LSTM control embedding dimension.')

    parser.add_argument('--bc_processed_dim', type=int, default=64,
                        help='Dimension of processed BC features for explicit Attention ROM.')
    parser.add_argument('--hidden_bc_processor_dim', type=int, default=128,
                        help='Hidden dimension for the BC feature processor MLP in Attention ROM.')

    parser.add_argument('--num_epochs', type=int, default=150)
    parser.add_argument('--lr', type=float, default=5e-4)

    args = parser.parse_args()

    DATASET_TYPE = args.datatype
    USE_FIXED_LIFTING_ABLATION = args.use_fixed_lifting
    RANDOM_PHI_INIT_ABLATION = args.random_phi_init
    USE_LSTM_ROM_ABLATION = args.use_lstm_rom

    # Training hyper-parameters that are not exposed on the command line.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    learning_rate = args.lr
    num_epochs = args.num_epochs
    lambda_res = 0.05
    lambda_orth = 0.001
    lambda_bc_penalty = 0.01
    clip_grad_norm = 1.0

    print(f"Using device: {device}")
    print(f"Selected Dataset Type: {DATASET_TYPE.upper()}")

    # Time horizon / number of steps stored in the data file, and the horizon
    # used for training. The else branch is the fallback for any dataset type
    # outside FEEDBACK_DATASET_TYPES.
    if DATASET_TYPE in FEEDBACK_DATASET_TYPES:
        FULL_T_IN_DATAFILE = 2.0
        FULL_NT_IN_DATAFILE = 300
        TRAIN_T_TARGET = 1.5
    else:
        FULL_T_IN_DATAFILE = 2.0
        FULL_NT_IN_DATAFILE = 600
        TRAIN_T_TARGET = 1.0

    # Number of stored time steps that fall inside the training horizon
    # (assumes the nt samples are evenly spaced over [0, FULL_T]).
    if FULL_NT_IN_DATAFILE <= 1:
        TRAIN_NT_FOR_MODEL = FULL_NT_IN_DATAFILE
    else:
        TRAIN_NT_FOR_MODEL = int((TRAIN_T_TARGET / FULL_T_IN_DATAFILE) * (FULL_NT_IN_DATAFILE - 1)) + 1

    print(f"Full data T={FULL_T_IN_DATAFILE}, nt={FULL_NT_IN_DATAFILE}")
    print(f"Training will use T={TRAIN_T_TARGET}, nt={TRAIN_NT_FOR_MODEL}")

    # Build a run name that encodes the configuration; it is used for the
    # checkpoint and figure file names below.
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    suffix_lift = "_fixedlift" if USE_FIXED_LIFTING_ABLATION else ""
    suffix_phi = "_randphi" if RANDOM_PHI_INIT_ABLATION else ""

    if USE_LSTM_ROM_ABLATION:
        suffix_rom = f"_lstm_h{args.lstm_hidden_dim}"
    else:
        suffix_rom = f"_attn_d{args.d_model}_bcp{args.bc_processed_dim}"

    run_name = f"{DATASET_TYPE}_b{args.basis_dim}{suffix_rom}{suffix_lift}{suffix_phi}_head{args.num_heads}_v2"

    checkpoint_dir = f"./New_ckpt_explicit_bc/_checkpoints_{DATASET_TYPE}"
    results_dir = f"./result_all_explicit_bc/results_{DATASET_TYPE}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f'barom_{run_name}.pt')
    save_fig_path = os.path.join(results_dir, f'barom_result_{run_name}.png')
    basis_dir = os.path.join(checkpoint_dir, 'pod_bases')
    if not RANDOM_PHI_INIT_ABLATION:
        os.makedirs(basis_dir, exist_ok=True)

    # All supported datasets share the same file-name pattern, grid size and
    # single state variable.
    if DATASET_TYPE in FEEDBACK_DATASET_TYPES:
        dataset_path = f"./datasets_new_feedback/{DATASET_TYPE}_v1_5000s_64nx_300nt.pkl"
        nx_data_param = 64
        state_keys = ['U']
    else:
        raise ValueError(f"Unknown dataset_type: {DATASET_TYPE}")

    # No LSTM model is constructed anywhere in this entry point, so reject the
    # option before spending time unpickling the dataset.
    if USE_LSTM_ROM_ABLATION:
        print("LSTM ROM selected - this script focuses on AttentionROM upgrades.")
        raise SystemExit("Error: Model was not instantiated. Exiting.")

    print(f"Loading dataset for {DATASET_TYPE} from {dataset_path}")
    try:
        # NOTE: pickle.load can execute arbitrary code; only point dataset_path
        # at files you generated yourself or otherwise trust.
        with open(dataset_path, 'rb') as f:
            data_list = pickle.load(f)
        print(f"Loaded {len(data_list)} samples.")
    except FileNotFoundError:
        raise SystemExit(f"Error: Dataset file not found at {dataset_path}")
    if not data_list:
        raise SystemExit("No data generated, exiting.")

    # 80/20 train/validation split after an in-place shuffle.
    random.shuffle(data_list)
    n_total = len(data_list)
    n_train = int(0.8 * n_total)
    train_data_list = data_list[:n_train]
    val_data_list = data_list[n_train:]
    print(f"Train samples: {len(train_data_list)}, Validation samples: {len(val_data_list)}")
    if not train_data_list:
        # Happens when the file holds a single sample; a shuffling DataLoader
        # cannot be built on an empty dataset.
        raise SystemExit("Error: Training split is empty. Exiting.")

    train_dataset = UniversalPDEDataset(train_data_list, dataset_type=DATASET_TYPE, train_nt_limit=TRAIN_NT_FOR_MODEL)
    val_dataset = UniversalPDEDataset(val_data_list, dataset_type=DATASET_TYPE, train_nt_limit=None)
    num_workers = 6
    # Pinned host memory only helps host-to-GPU copies, so request it only on CUDA.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=num_workers,
                              pin_memory=(device.type == 'cuda'), drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    # Model dimensions are taken from the dataset object rather than hard-coded.
    current_nx_model = train_dataset.nx
    model_num_controls = train_dataset.num_controls
    model_bc_state_dim = train_dataset.bc_state_dim

    print(f"Model Configuration: nx={current_nx_model}, basis_dim={args.basis_dim}")
    print(f"Model BC_State dim: {model_bc_state_dim}, Model Num Controls: {model_num_controls}")

    print(f"Initializing Upgraded MultiVarAttentionROM with explicit BC processing:")
    online_model = MultiVarAttentionROM(
        state_variable_keys=state_keys,
        nx=current_nx_model,
        basis_dim=args.basis_dim,
        d_model=args.d_model,
        bc_state_dim=model_bc_state_dim,
        num_controls=model_num_controls,
        num_heads=args.num_heads,
        add_error_estimator=False,
        shared_attention=False,
        use_fixed_lifting=USE_FIXED_LIFTING_ABLATION,
        bc_processed_dim=args.bc_processed_dim,
        hidden_bc_processor_dim=args.hidden_bc_processor_dim
    )

    if not RANDOM_PHI_INIT_ABLATION:
        print("\nInitializing Phi with POD bases...")
        pod_bases = {}
        if not train_dataset:
            raise SystemExit("Error: Training dataset is empty. Cannot compute POD.")
        try:
            # The first dimension of the first state tensor gives the number of
            # time steps the dataset actually yields; it is used both for the
            # POD computation and in the cache file name.
            first_sample_data_pod, _, _ = train_dataset[0]
            first_state_tensor_pod = first_sample_data_pod[0]
            actual_nt_for_pod = first_state_tensor_pod.shape[0]
        except IndexError:
            raise SystemExit("Error: Could not access shape from the first sample in train_dataset for POD.")

        for key_pod_loop in state_keys:
            # POD bases are cached on disk, keyed by variable, nx, nt and basis_dim.
            basis_filename = f'pod_basis_{key_pod_loop}_nx{current_nx_model}_nt{actual_nt_for_pod}_bdim{args.basis_dim}.npy'
            basis_path = os.path.join(basis_dir, basis_filename)
            loaded_basis = None
            if os.path.exists(basis_path):
                print(f"  Loading existing POD basis for '{key_pod_loop}' from {basis_path}...")
                try:
                    loaded_basis = np.load(basis_path)
                except Exception as e:
                    # Best effort: an unreadable cache file just triggers a recompute.
                    print(f"  Error loading {basis_path}: {e}. Will recompute.")
                    loaded_basis = None
                if loaded_basis is not None and loaded_basis.shape != (current_nx_model, args.basis_dim):
                    print(f"  Shape mismatch for loaded basis '{key_pod_loop}'. Recomputing.")
                    loaded_basis = None
            if loaded_basis is None:
                print(
                    f"  Computing POD basis for '{key_pod_loop}' (using nt={actual_nt_for_pod}, basis_dim={args.basis_dim})...")
                computed_basis = compute_pod_basis_generic(
                    data_list=train_data_list, dataset_type=DATASET_TYPE, state_variable_key=key_pod_loop,
                    nx=current_nx_model, nt=actual_nt_for_pod, basis_dim=args.basis_dim)
                if computed_basis is None:
                    raise SystemExit(f"ERROR: Failed to compute POD basis for '{key_pod_loop}'. Exiting.")
                pod_bases[key_pod_loop] = computed_basis
                os.makedirs(os.path.dirname(basis_path), exist_ok=True)
                np.save(basis_path, computed_basis)
                print(f"  Saved computed POD basis for '{key_pod_loop}' to {basis_path}")
            else:
                pod_bases[key_pod_loop] = loaded_basis

        # Copy the bases into the model's Phi parameters in place, without
        # recording the copy in autograd.
        with torch.no_grad():
            for key_phi_init in state_keys:
                if key_phi_init in pod_bases and hasattr(online_model, 'Phi') and key_phi_init in online_model.Phi:
                    model_phi_param = online_model.Phi[key_phi_init]
                    pod_phi_tensor = torch.tensor(pod_bases[key_phi_init].astype(np.float32),
                                                  device=model_phi_param.device)
                    if model_phi_param.shape == pod_phi_tensor.shape:
                        model_phi_param.copy_(pod_phi_tensor)
                        print(f"  Initialized Phi for '{key_phi_init}' with POD.")
                    else:
                        print(f"  WARNING: Shape mismatch for Phi '{key_phi_init}'. Using random init.")
                else:
                    print(
                        f"  WARNING: No POD basis found or Phi module not present for '{key_phi_init}'. Using random init for Phi.")
    else:
        print("\nSkipping POD basis initialization for Phi (using random orthogonal initialization).")

    print(f"\nStarting training for {DATASET_TYPE.upper()}...")
    start_train_time = time.time()
    # The object returned by the training routine replaces online_model.
    online_model = train_multivar_model(
        online_model, train_loader, dataset_type=DATASET_TYPE,
        train_nt_target=TRAIN_NT_FOR_MODEL,
        lr=learning_rate, num_epochs=num_epochs, device=device,
        checkpoint_path=checkpoint_path, lambda_res=lambda_res,
        lambda_orth=lambda_orth, lambda_bc_penalty=lambda_bc_penalty,
        clip_grad_norm=clip_grad_norm
    )
    end_train_time = time.time()
    print(f"Training took {end_train_time - start_train_time:.2f} seconds.")

    if val_data_list:
        print(f"\nStarting validation for {DATASET_TYPE.upper()}...")
        validate_multivar_model(
            online_model, val_loader, dataset_type=DATASET_TYPE, device=device,
            save_fig_path=save_fig_path,
            train_nt_for_model_training=TRAIN_NT_FOR_MODEL,
            T_value_for_model_training=TRAIN_T_TARGET,
            full_T_in_datafile=FULL_T_IN_DATAFILE,
            full_nt_in_datafile=FULL_NT_IN_DATAFILE
        )
    else:
        print("\nNo validation data. Skipping validation.")

    print("=" * 60)
    print(f"Run configuration finished for dataset: {DATASET_TYPE.upper()} - {run_name}")
    print(f"Target checkpoint path: {checkpoint_path}")
    if val_data_list:
        print(f"Target validation figure path: {save_fig_path}")
    print("=" * 60)