# -*- coding: utf-8 -*-

# Libraries
import json
import os
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import mean_squared_error, roc_auc_score
from scipy.stats import wasserstein_distance

warnings.filterwarnings('ignore')

# --- GPU/CPU Device Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------------------------------------------------------------------
# ---[ SECTION 1: CORE INR MODEL ARCHITECTURES ]---
# ----------------------------------------------------------------------------

# --- Activation Function Helpers ---
class LambdaModule(nn.Module):
    """Adapter that lets an arbitrary callable be used as an ``nn.Module``.

    The callable is stored as ``lambda_func`` and applied, unchanged, to the
    input passed to ``forward``.
    """

    def __init__(self, lambda_func):
        super().__init__()
        self.lambda_func = lambda_func

    def forward(self, x):
        """Return ``lambda_func(x)``."""
        fn = self.lambda_func
        return fn(x)

class Sine(nn.Module):
    """Sine activation ``sin(w0 * x)`` with a fixed frequency factor ``w0``."""

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        """Scale the input by ``w0`` and take the element-wise sine."""
        scaled = self.w0 * x
        return scaled.sin()

class Wire(nn.Module):
    """Activation computing ``sin(w0 * x) * sigmoid(s0 * x)``.

    ``w0`` scales the input of the sine factor and ``s0`` scales the input of
    the sigmoid factor.
    """

    def __init__(self, w0=1.0, s0=10.0):
        super().__init__()
        self.w0 = w0
        self.s0 = s0

    def forward(self, x):
        """Return the element-wise product of the sine and sigmoid factors."""
        oscillation = torch.sin(self.w0 * x)
        gate = torch.sigmoid(self.s0 * x)
        return oscillation * gate

class Hosc(nn.Module):
    """Activation computing ``sin(w0 * x) / (w0 * x)``, with value 1 at zero.

    NOTE(review): despite the class name, the formula is the unnormalised
    sinc function of ``w0 * x``; confirm that this is the intended activation.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        """Evaluate the activation element-wise.

        The previous ``sin(z) / (z + 1e-8)`` form returned 0 at ``z == 0``
        (the limit is 1) and divided by zero at ``z == -1e-8``.  Here the
        zero entries are handled explicitly and every other entry is divided
        by its exact argument.
        """
        z = self.w0 * x
        at_zero = z == 0
        # Substitute a harmless denominator at the zero entries so that neither
        # the forward value nor the gradient of the unused branch is NaN/inf.
        safe_z = torch.where(at_zero, torch.ones_like(z), z)
        ratio = torch.sin(safe_z) / safe_z
        return torch.where(at_zero, torch.ones_like(z), ratio)

# --- Tabular INR Model Architecture ---
class TabularINR(nn.Module):
    def __init__(
        self,
        n_rows,
        latent_dim=32,
        hidden_dims=None,
        dropout_rate=0.1,
        activation="relu",
        w0=1.0,
    ):
        """Build the embeddings, the MLP trunk and the preprocessing objects.

        Args:
            n_rows: number of rows the row-embedding table can index.
            latent_dim: size of each row and column embedding vector.
            hidden_dims: widths of the hidden layers; ``None`` selects
                ``[256, 256]``.
            dropout_rate: dropout probability applied after every activation.
            activation: key selecting the non-linearity; one of ``"relu"``,
                ``"siren"``, ``"gauss"``, ``"wire"``, ``"hosc"``, ``"sinc"``
                or ``"finer"``.
            w0: frequency factor handed to the ``Sine``, ``Wire`` and ``Hosc``
                activations.

        Raises:
            KeyError: if ``activation`` is not a known key and at least one
                hidden layer is requested.
        """
        super().__init__()
        # A list literal as default would be one object shared by every
        # instance (and stored on ``self``), so build a private copy instead.
        hidden_dims = [256, 256] if hidden_dims is None else list(hidden_dims)

        self.n_rows = n_rows
        self.latent_dim = latent_dim
        self.hidden_dims = hidden_dims
        self.dropout_rate = dropout_rate
        self.activation_name = activation
        self.w0 = w0

        # Model components
        self.row_embedding = nn.Embedding(n_rows, latent_dim)
        # Created in .fit(), once the number of columns after one-hot
        # encoding is known.
        self.col_embedding = None

        # Activation function mapping
        activations = {
            "relu": nn.ReLU(),
            "siren": Sine(w0),
            "gauss": LambdaModule(lambda x: torch.exp(-(x**2))),
            "wire": Wire(w0),
            "hosc": Hosc(w0),
            "sinc": LambdaModule(lambda x: torch.sinc(x)),
            "finer": LambdaModule(lambda x: torch.sin(torch.exp(x))),
        }

        # Build the network layers: Linear -> activation -> Dropout per width.
        # The input is the concatenation of one row and one column embedding.
        layers = []
        input_dim = latent_dim * 2
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(activations[self.activation_name])
            layers.append(nn.Dropout(dropout_rate))
            input_dim = hidden_dim
        self.network = nn.Sequential(*layers)

        # A single scalar output head shared by all features
        self.output_head = nn.Linear(input_dim, 1)

        # Initialize weights
        nn.init.normal_(self.row_embedding.weight, mean=0.0, std=0.02)

        # Scaler and encoders; the column bookkeeping is filled in by .fit()
        self.scaler = MinMaxScaler()
        self.ohe = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        self.numerical_cols = None
        self.categorical_cols = None
        self.ohe_cols = None
        self.ohe_col_mapping = {}
        self.original_cols = None

    def forward(self, row_idx, col_idx):
        row_emb = self.row_embedding(row_idx)
        col_emb = self.col_embedding(col_idx)
        x = torch.cat([row_emb, col_emb], dim=1)
        features = self.network(x)
        return self.output_head(features)
    
    def fit(self, df_miss, mask, numerical_cols, categorical_cols, epochs=100, lr=0.001):
        """Fit the encoders, the scaler and the network on the observed cells.

        Args:
            df_miss: DataFrame holding the data with missing entries.
            mask: DataFrame with the same columns as ``df_miss``; truthy cells
                mark entries that are treated as missing and therefore
                excluded from training.
            numerical_cols: names of the columns trained with an MSE loss.
            categorical_cols: names of the columns that are one-hot encoded
                and trained with a BCE-with-logits loss; may be empty.
            epochs: number of full-batch optimisation steps.
            lr: learning rate of the Adam optimiser.
        """
        self.numerical_cols = numerical_cols
        self.categorical_cols = categorical_cols
        self.original_cols = df_miss.columns.tolist()
        # Start from an empty mapping so a refit never keeps groups that
        # belonged to a previous call.
        self.ohe_col_mapping = {}

        # One-hot encode categorical columns if they exist
        if self.categorical_cols:
            # NOTE(review): astype(str) turns missing entries into a string
            # category of their own; those cells are masked during training.
            df_cat = df_miss[self.categorical_cols].astype(str)
            ohe_transformed = self.ohe.fit_transform(df_cat)
            self.ohe_cols = self.ohe.get_feature_names_out(self.categorical_cols)
            df_ohe = pd.DataFrame(ohe_transformed, columns=self.ohe_cols, index=df_miss.index)

            # Group the encoded columns by slicing them in encoder order.  The
            # encoder (no ``drop``, no infrequent-category grouping) emits
            # len(categories_[i]) consecutive columns per input column.  A
            # ``startswith(col + '_')`` test would mix up groups whenever one
            # column name is a prefix of another (e.g. "a" and "a_b").
            start = 0
            for col, categories in zip(self.categorical_cols, self.ohe.categories_):
                stop = start + len(categories)
                self.ohe_col_mapping[col] = list(self.ohe_cols[start:stop])
                start = stop

            # Every one-hot column inherits the mask of its source column
            mask_ohe = pd.DataFrame(index=mask.index, columns=self.ohe_cols)
            for orig_col in self.categorical_cols:
                orig_mask = mask[orig_col]
                for ohe_col in self.ohe_col_mapping[orig_col]:
                    mask_ohe[ohe_col] = orig_mask

            df_processed = df_miss.drop(columns=self.categorical_cols)
            df_processed = pd.concat([df_processed, df_ohe], axis=1)

            mask_processed = mask.drop(columns=self.categorical_cols)
            mask_processed = pd.concat([mask_processed, mask_ohe], axis=1)
        else:
            df_processed = df_miss.copy()
            mask_processed = mask.copy()
            self.ohe_cols = []

        # ``~`` below is a logical NOT only for boolean data, so cast in both
        # branches (a 0/1 integer mask would otherwise be bit-flipped).
        mask_processed = mask_processed.astype(bool)

        # The column embedding can only be sized after one-hot expansion
        n_cols_processed = df_processed.shape[1]
        self.col_embedding = nn.Embedding(n_cols_processed, self.latent_dim)
        nn.init.normal_(self.col_embedding.weight, mean=0.0, std=0.02)

        # Scale the processed data
        self.scaler.fit(df_processed)
        df_scaled = pd.DataFrame(self.scaler.transform(df_processed), columns=df_processed.columns, index=df_processed.index)

        # One training sample per observed (row, column) cell
        rows_obs, cols_obs = np.where(~mask_processed.values)
        row_idx = torch.tensor(rows_obs, dtype=torch.long).to(DEVICE)
        col_idx = torch.tensor(cols_obs, dtype=torch.long).to(DEVICE)
        values = torch.tensor(df_scaled.values[rows_obs, cols_obs], dtype=torch.float32).to(DEVICE)

        self.to(DEVICE)
        optimizer = optim.Adam(self.parameters(), lr=lr)

        # Split the samples into numerical and one-hot cells.  The explicit
        # dtype keeps an empty index list from becoming a float tensor.
        num_col_indices = [df_processed.columns.get_loc(c) for c in self.numerical_cols]
        ohe_col_indices = [df_processed.columns.get_loc(c) for c in self.ohe_cols]
        num_mask = torch.isin(col_idx, torch.tensor(num_col_indices, dtype=torch.long, device=DEVICE))
        ohe_mask = torch.isin(col_idx, torch.tensor(ohe_col_indices, dtype=torch.long, device=DEVICE))
        # The masks never change, so test them once instead of every epoch.
        has_num = bool(num_mask.any())
        has_ohe = bool(ohe_mask.any())

        mse_loss_fn = nn.MSELoss()
        bce_loss_fn = nn.BCEWithLogitsLoss()

        for _ in range(epochs):
            self.train()
            optimizer.zero_grad()
            # squeeze(-1) keeps a 1-D result even for a single sample, so the
            # boolean masks below always index correctly.
            preds = self.forward(row_idx, col_idx).squeeze(-1)
            loss = 0
            if has_num:
                loss = loss + mse_loss_fn(preds[num_mask], values[num_mask])
            if has_ohe:
                loss = loss + bce_loss_fn(preds[ohe_mask], values[ohe_mask])
            # ``loss`` stays the int 0 when there was nothing to train on
            if isinstance(loss, torch.Tensor):
                loss.backward()
                optimizer.step()
    
    def transform(self, df_miss, mask):
        """Fill the masked cells of ``df_miss`` with model predictions.

        Rows are addressed by position in the row-embedding table, so the
        frame is expected to be row-aligned with the one passed to ``fit``.

        Args:
            df_miss: DataFrame holding the data with missing entries.
            mask: DataFrame with the same columns as ``df_miss``; truthy cells
                mark the entries to impute.

        Returns:
            A DataFrame indexed like ``df_miss`` with the columns seen in
            ``fit``.  Categorical columns contain the encoder's (string)
            categories.
        """
        self.eval()

        if self.categorical_cols:
            df_cat = df_miss[self.categorical_cols].astype(str)
            ohe_transformed = self.ohe.transform(df_cat)
            df_ohe = pd.DataFrame(ohe_transformed, columns=self.ohe_cols, index=df_miss.index)
            df_processed = df_miss.drop(columns=self.categorical_cols)
            df_processed = pd.concat([df_processed, df_ohe], axis=1)

            # Every one-hot column inherits the mask of its source column
            mask_ohe = pd.DataFrame(index=mask.index, columns=self.ohe_cols)
            for orig_col in self.categorical_cols:
                orig_mask = mask[orig_col]
                for ohe_col in self.ohe_col_mapping[orig_col]:
                    mask_ohe[ohe_col] = orig_mask
            mask_processed = mask.drop(columns=self.categorical_cols)
            mask_processed = pd.concat([mask_processed, mask_ohe], axis=1)
            mask_processed = mask_processed.astype(bool)

        else:
            df_processed = df_miss.copy()
            mask_processed = mask.copy()

        # Work on an owned NumPy copy.  Assigning through ``DataFrame.values``
        # only reaches the frame when pandas happens to return a writable
        # view; under Copy-on-Write the array is read-only and with several
        # internal blocks it is a throw-away copy.
        scaled = np.array(self.scaler.transform(df_processed), dtype=float)

        rows_miss, cols_miss = np.where(mask_processed.values)

        if len(rows_miss) > 0:
            with torch.no_grad():
                row_idx_miss = torch.tensor(rows_miss, dtype=torch.long).to(DEVICE)
                col_idx_miss = torch.tensor(cols_miss, dtype=torch.long).to(DEVICE)
                # squeeze(-1) keeps a 1-D array even for one missing cell
                imputed_vals = self.forward(row_idx_miss, col_idx_miss).squeeze(-1).cpu().numpy()
            scaled[rows_miss, cols_miss] = imputed_vals

        imputed_df_scaled = pd.DataFrame(scaled, columns=df_processed.columns, index=df_processed.index)

        # Undo the scaling, keeping the processed (one-hot) column layout
        imputed_df_unscaled = pd.DataFrame(self.scaler.inverse_transform(imputed_df_scaled), columns=imputed_df_scaled.columns, index=imputed_df_scaled.index)

        # Post-process one-hot encoded columns (winner-takes-all)
        imputed_df_post_processed = self.post_process_imputed(imputed_df_unscaled, self.ohe_col_mapping)

        # Recombine the dataframe to its original form
        final_imputed_df = pd.DataFrame(index=df_miss.index, columns=self.original_cols)
        final_imputed_df[self.numerical_cols] = imputed_df_post_processed[self.numerical_cols]

        if self.categorical_cols:
            for cat_col, ohe_cols in self.ohe_col_mapping.items():
                if len(ohe_cols) > 0:
                    # The position of the winning one-hot column selects the
                    # category in the encoder's category list for this column.
                    imputed_ohe = imputed_df_post_processed[ohe_cols].values
                    original_categories = self.ohe.categories_[self.categorical_cols.index(cat_col)]
                    imputed_categories = original_categories[np.argmax(imputed_ohe, axis=1)]
                    final_imputed_df[cat_col] = imputed_categories

        return final_imputed_df

    def post_process_imputed(self, imputed_df, ohe_col_mapping):
        """Applies winner-takes-all for one-hot encoded columns."""
        for orig_cat_col, ohe_group in ohe_col_mapping.items():
            if not ohe_group or not all(c in imputed_df.columns for c in ohe_group):
                continue
            imputed_group = imputed_df[ohe_group]
            winner_takes_all = np.zeros_like(imputed_group.values)
            winner_indices = np.argmax(imputed_group.values, axis=1)
            winner_takes_all[np.arange(len(winner_takes_all)), winner_indices] = 1
            imputed_df[ohe_group] = winner_takes_all
        return imputed_df
    

#
# --- Embedding-based INR Model Architecture ---
#
class TabularINRWithRowColEmbedding(nn.Module):
    def __init__(self, n_rows, latent_dim=32, hidden_dim=256, dropout=0.1):
        """Build the embeddings, the MLP and the preprocessing objects.

        Args:
            n_rows: number of rows the row-embedding table can index.
            latent_dim: size of each row and column embedding vector.
            hidden_dim: width of the two hidden layers.
            dropout: dropout probability used after each hidden activation.
        """
        super().__init__()
        self.n_rows = n_rows
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout

        self.row_embedding = nn.Embedding(n_rows, latent_dim)
        # Created in .fit(), once the number of columns after one-hot
        # encoding is known.
        self.col_embedding = None

        # Two Linear -> BatchNorm -> ReLU -> Dropout stages followed by a
        # scalar output layer; the input is a row and a column embedding
        # concatenated.
        self.network = nn.Sequential(
            nn.Linear(latent_dim * 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )

        nn.init.xavier_normal_(self.row_embedding.weight)
        self._init_weights()

        # Scaler and encoders
        self.scaler = MinMaxScaler()
        self.ohe = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        # Column bookkeeping, filled in by .fit().  Declared here so that the
        # attributes exist before the first fit instead of raising
        # AttributeError when read.
        self.numerical_cols = None
        self.categorical_cols = None
        self.ohe_cols = None
        self.numerical_col_indices = None
        self.ohe_col_indices = None
        self.ohe_col_mapping = {}
        self.original_cols = None

    def _init_weights(self):
        """Initialize network weights for better convergence."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    
    def forward(self, row_idx, col_idx):
        """Predict one scalar per (row index, column index) pair.

        Both index tensors are moved to ``DEVICE``, their embeddings are
        concatenated along dimension 1 and passed through ``self.network``.
        """
        rows_on_device = row_idx.to(DEVICE)
        cols_on_device = col_idx.to(DEVICE)
        joined = torch.cat(
            (self.row_embedding(rows_on_device), self.col_embedding(cols_on_device)),
            dim=1,
        )
        return self.network(joined)

    def fit(self, df_miss, mask, numerical_cols, categorical_cols, epochs=10, patience=2, batch_size=4):
        """Fit the encoders, the scaler and the network on the observed cells.

        Args:
            df_miss: DataFrame holding the data with missing entries.
            mask: DataFrame with the same columns as ``df_miss``; truthy cells
                mark entries that are treated as missing and therefore
                excluded from training.
            numerical_cols: names of the columns trained with an MSE loss.
            categorical_cols: names of the columns that are one-hot encoded
                and trained with a BCE-with-logits loss; may be empty.
            epochs: maximum number of training epochs.
            patience: epochs without validation improvement before stopping.
            batch_size: mini-batch size of the training loader.
        """
        self.numerical_cols = numerical_cols
        self.categorical_cols = categorical_cols
        self.original_cols = df_miss.columns.tolist()
        # Start from an empty mapping so a refit never keeps groups that
        # belonged to a previous call.
        self.ohe_col_mapping = {}

        if categorical_cols:
            ohe_transformed = self.ohe.fit_transform(df_miss[categorical_cols].astype(str))
            self.ohe_cols = self.ohe.get_feature_names_out(categorical_cols)
            df_ohe = pd.DataFrame(ohe_transformed, columns=self.ohe_cols, index=df_miss.index)

            # Group the encoded columns by slicing them in encoder order.  The
            # encoder (no ``drop``, no infrequent-category grouping) emits
            # len(categories_[i]) consecutive columns per input column.  A
            # ``startswith(col + '_')`` test would mix up groups whenever one
            # column name is a prefix of another (e.g. "a" and "a_b").
            start = 0
            for col, categories in zip(categorical_cols, self.ohe.categories_):
                stop = start + len(categories)
                self.ohe_col_mapping[col] = list(self.ohe_cols[start:stop])
                start = stop

            # Every one-hot column inherits the mask of its source column
            mask_ohe = pd.DataFrame(index=mask.index, columns=self.ohe_cols)
            for orig_col in categorical_cols:
                orig_mask = mask[orig_col]
                for ohe_col in self.ohe_col_mapping[orig_col]:
                    mask_ohe[ohe_col] = orig_mask

            df_processed = pd.concat([df_miss.drop(columns=categorical_cols), df_ohe], axis=1)
            mask_processed = pd.concat([mask.drop(columns=categorical_cols), mask_ohe], axis=1)
        else:
            df_processed = df_miss.copy()
            mask_processed = mask.copy()
            self.ohe_cols = []

        # ``~`` below is a logical NOT only for boolean data, so cast in both
        # branches (a 0/1 integer mask would otherwise be bit-flipped).
        mask_processed = mask_processed.astype(bool)

        # The column embedding can only be sized after one-hot expansion
        n_cols_processed = df_processed.shape[1]
        self.col_embedding = nn.Embedding(n_cols_processed, self.latent_dim)
        nn.init.xavier_normal_(self.col_embedding.weight)

        self.to(DEVICE)

        self.scaler.fit(df_processed)
        df_scaled = pd.DataFrame(self.scaler.transform(df_processed), columns=df_processed.columns, index=df_processed.index)

        # Positions of the numerical and one-hot columns in the processed
        # frame.  ``self.ohe_cols`` already lists the encoded names, so the
        # encoder does not have to be queried again for every column.
        self.numerical_col_indices = [df_processed.columns.get_loc(c) for c in numerical_cols]
        self.ohe_col_indices = [df_processed.columns.get_loc(c) for c in self.ohe_cols]

        # One training sample per observed (row, column) cell; boolean
        # indexing and np.where both enumerate the cells in row-major order.
        observed_mask = ~mask_processed.values
        row_indices, col_indices = np.where(observed_mask)
        values = df_scaled.values[observed_mask]

        row_indices_tensor = torch.tensor(row_indices, dtype=torch.long, device=DEVICE)
        col_indices_tensor = torch.tensor(col_indices, dtype=torch.long, device=DEVICE)
        values_tensor = torch.tensor(values, dtype=torch.float32, device=DEVICE)

        self.train_inr_model(
            numerical_col_indices=self.numerical_col_indices,
            ohe_col_indices=self.ohe_col_indices,
            row_indices=row_indices_tensor, col_indices=col_indices_tensor, values=values_tensor,
            epochs=epochs, patience=patience, batch_size=batch_size
        )
    
    def transform(self, df_miss, mask):
        """Impute the masked cells of ``df_miss`` row by row.

        For every row with at least one masked cell, a fresh row embedding is
        optimised against that row's observed cells and then used to predict
        the masked ones.  Rows with no observed cell at all are left as they
        are.

        Args:
            df_miss: DataFrame holding the data with missing entries.
            mask: DataFrame with the same columns as ``df_miss``; truthy cells
                mark the entries to impute.

        Returns:
            A DataFrame indexed like ``df_miss`` with the columns seen in
            ``fit``.  Categorical columns contain the encoder's (string)
            categories.
        """
        self.eval()

        if self.categorical_cols:
            ohe_transformed = self.ohe.transform(df_miss[self.categorical_cols].astype(str))
            df_ohe = pd.DataFrame(ohe_transformed, columns=self.ohe_cols, index=df_miss.index)
            df_processed = pd.concat([df_miss.drop(columns=self.categorical_cols), df_ohe], axis=1)

            # Every one-hot column inherits the mask of its source column
            mask_ohe = pd.DataFrame(index=mask.index, columns=self.ohe_cols)
            for orig_col in self.categorical_cols:
                orig_mask = mask[orig_col]
                for ohe_col in self.ohe_col_mapping[orig_col]:
                    mask_ohe[ohe_col] = orig_mask
            mask_processed = pd.concat([mask.drop(columns=self.categorical_cols), mask_ohe], axis=1)
        else:
            df_processed = df_miss.copy()
            mask_processed = mask.copy()

        # Convert once to NumPy: the per-row loop below then avoids repeated
        # pandas ``.iloc`` lookups and writes.  The boolean cast makes ``~`` a
        # logical NOT in both branches.
        scaled = np.array(self.scaler.transform(df_processed), dtype=float)
        missing = mask_processed.to_numpy(dtype=bool)
        imputed = scaled.copy()

        for i in np.flatnonzero(missing.any(axis=1)):
            row_missing = missing[i]
            observed_cols_idx = np.flatnonzero(~row_missing)
            if observed_cols_idx.size == 0:
                # Nothing to condition the embedding on; the row keeps its
                # current (scaled) content.
                continue

            new_row_emb = self.optimize_new_row_embedding(
                observed_cols_idx, scaled[i, observed_cols_idx],
                self.numerical_col_indices, self.ohe_col_indices
            )

            missing_cols_idx = np.flatnonzero(row_missing)
            with torch.no_grad():
                missing_col_tensor = torch.tensor(missing_cols_idx, dtype=torch.long, device=DEVICE)
                row_emb_expanded = new_row_emb.expand(len(missing_cols_idx), -1)
                col_emb = self.col_embedding(missing_col_tensor)
                x = torch.cat([row_emb_expanded, col_emb], dim=1)
                # squeeze(-1) keeps a 1-D array even for one missing cell
                predictions = self.network(x).squeeze(-1).cpu().numpy()

            imputed[i, missing_cols_idx] = predictions

        df_imputed_processed = pd.DataFrame(imputed, columns=df_processed.columns, index=df_processed.index)
        imputed_df_unscaled = pd.DataFrame(self.scaler.inverse_transform(df_imputed_processed), columns=df_processed.columns, index=df_processed.index)

        # Winner-takes-all inside every one-hot group
        imputed_df_post_processed = self.post_process_imputed(imputed_df_unscaled, self.ohe_col_mapping)

        # Recombine into the original column layout
        final_imputed_df = pd.DataFrame(index=df_miss.index, columns=self.original_cols)
        final_imputed_df[self.numerical_cols] = imputed_df_post_processed[self.numerical_cols]
        if self.categorical_cols:
            for cat_col, ohe_cols in self.ohe_col_mapping.items():
                if len(ohe_cols) > 0:
                    # The position of the winning one-hot column selects the
                    # category in the encoder's category list for this column.
                    imputed_ohe = imputed_df_post_processed[ohe_cols].values
                    original_categories = self.ohe.categories_[self.categorical_cols.index(cat_col)]
                    imputed_categories = original_categories[np.argmax(imputed_ohe, axis=1)]
                    final_imputed_df[cat_col] = imputed_categories

        return final_imputed_df

    def optimize_new_row_embedding(self, col_indices, values, numerical_col_indices, ohe_col_indices, n_steps=100, lr=0.001):
        """
        Optimizes a new row embedding for a single row with observed values.
        This is used during the imputation of new, unseen data.
        Now uses dual loss: MSE for numerical features, BCE for categorical features.
        """
        was_training = self.training
        self.eval()
        
        for param in self.parameters():
            param.requires_grad = False
            
        col_indices = torch.tensor(col_indices, dtype=torch.long, device=DEVICE)
        values = torch.tensor(values, dtype=torch.float32, device=DEVICE)
        new_row_emb = nn.Parameter(
            torch.randn(1, self.latent_dim, device=DEVICE) * 0.02, requires_grad=True
        )
        
        # Move column indices to device for faster lookup
        numerical_col_indices_dev = torch.tensor(numerical_col_indices, device=DEVICE)
        ohe_col_indices_dev = torch.tensor(ohe_col_indices, device=DEVICE)
        
        optimizer = optim.AdamW([new_row_emb], lr=lr, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_steps)
        
        # Define separate loss functions
        mse_loss_fn = nn.MSELoss()
        bce_loss_fn = nn.BCEWithLogitsLoss()
        
        def compute_loss(pred, cols, vals):
            """Compute combined loss for numerical and categorical features"""
            num_mask = torch.isin(cols, numerical_col_indices_dev)
            ohe_mask = torch.isin(cols, ohe_col_indices_dev)
            
            loss_num = 0
            loss_cat = 0
            
            if num_mask.any():
                loss_num = mse_loss_fn(pred[num_mask], vals[num_mask])
            if ohe_mask.any():
                loss_cat = bce_loss_fn(pred[ohe_mask], vals[ohe_mask])
                
            return loss_num + loss_cat
        
        for step in range(n_steps):  # Fixed the syntax error in original code
            optimizer.zero_grad()
            col_emb = self.col_embedding(col_indices)
            row_emb_expanded = new_row_emb.expand(len(col_indices), -1)
            x = torch.cat([row_emb_expanded, col_emb], dim=1)
            pred = self.network(x).squeeze()
            
            # Use dual loss instead of just MSE
            data_loss = compute_loss(pred, col_indices, values)
            reg_loss = 0.01 * torch.norm(new_row_emb, p=2)
            loss = data_loss + reg_loss
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_([new_row_emb], max_norm=1.0)
            optimizer.step()
            scheduler.step()
        
        for param in self.parameters():
            param.requires_grad = True
        self.train(was_training)
        
        return new_row_emb.detach()

    def train_inr_model(self, numerical_col_indices, ohe_col_indices, row_indices, col_indices, values, epochs=10, patience=200, batch_size=1024):
        """Train the model on observed cells with an MSE + BCE dual loss.

        The samples are randomly split 80/20 into training and validation
        parts.  Training uses AdamW with ``ReduceLROnPlateau`` on the
        validation loss and stops early after ``patience`` epochs without
        improvement.  A gradient scaler and autocast are used when ``DEVICE``
        is CUDA.

        Args:
            numerical_col_indices: indices of the numerical columns (MSE).
            ohe_col_indices: indices of the one-hot columns (BCE with logits).
            row_indices: 1-D long tensor of row indices, one per sample.
            col_indices: 1-D long tensor of column indices, one per sample.
            values: 1-D float tensor of target values, one per sample.
            epochs: maximum number of epochs.
            patience: epochs without validation improvement before stopping.
            batch_size: mini-batch size of the training loader.

        Returns:
            ``self``.
        """
        scaler = torch.cuda.amp.GradScaler() if DEVICE.type == 'cuda' else None
        optimizer = optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=50)

        # Define separate loss functions
        mse_loss_fn = nn.MSELoss()
        bce_loss_fn = nn.BCEWithLogitsLoss()

        # Column indices on the device for the membership tests below.  The
        # explicit dtype keeps an empty list from becoming a float tensor.
        numerical_col_indices_dev = torch.as_tensor(numerical_col_indices, dtype=torch.long, device=DEVICE)
        ohe_col_indices_dev = torch.as_tensor(ohe_col_indices, dtype=torch.long, device=DEVICE)

        # Random 80/20 split into training and validation samples
        num_samples = len(row_indices)
        indices = torch.randperm(num_samples, device=DEVICE)
        split = int(num_samples * 0.8)
        train_indices, val_indices = indices[:split], indices[split:]

        train_dataset = torch.utils.data.TensorDataset(
            row_indices[train_indices], col_indices[train_indices], values[train_indices]
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, pin_memory=False
        )

        val_row_idx, val_col_idx, val_values = row_indices[val_indices], col_indices[val_indices], values[val_indices]

        best_loss = float('inf')
        no_improve_count = 0

        def compute_loss(pred, cols, vals):
            """Sum of MSE over numerical cells and BCE over one-hot cells.

            Returns the int 0 when the batch contains neither kind of cell.
            """
            num_mask = torch.isin(cols, numerical_col_indices_dev)
            ohe_mask = torch.isin(cols, ohe_col_indices_dev)

            loss_num = 0
            loss_cat = 0
            if num_mask.any():
                loss_num = mse_loss_fn(pred[num_mask], vals[num_mask])
            if ohe_mask.any():
                loss_cat = bce_loss_fn(pred[ohe_mask], vals[ohe_mask])

            return loss_num + loss_cat

        for epoch in range(epochs):
            self.train()
            for batch_row, batch_col, batch_val in train_loader:
                if batch_row.shape[0] == 1 and batch_size > 1:
                    # A trailing one-sample batch cannot be normalised by
                    # BatchNorm1d in training mode (it raises), so skip it.
                    continue

                batch_row, batch_col, batch_val = batch_row.to(DEVICE), batch_col.to(DEVICE), batch_val.to(DEVICE)
                optimizer.zero_grad()

                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        # squeeze(-1) keeps the predictions 1-D for any batch
                        pred = self(batch_row, batch_col).squeeze(-1)
                        loss = compute_loss(pred, batch_col, batch_val)
                    if isinstance(loss, torch.Tensor):
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                else:
                    pred = self(batch_row, batch_col).squeeze(-1)
                    loss = compute_loss(pred, batch_col, batch_val)
                    if isinstance(loss, torch.Tensor):
                        loss.backward()
                        optimizer.step()

            self.eval()
            with torch.no_grad():
                val_pred = self(val_row_idx, val_col_idx).squeeze(-1)
                val_loss_tensor = compute_loss(val_pred, val_col_idx, val_values)
                val_loss = val_loss_tensor.item() if isinstance(val_loss_tensor, torch.Tensor) else 0.0

            scheduler.step(val_loss)
            if val_loss < best_loss:
                best_loss = val_loss
                no_improve_count = 0
            else:
                no_improve_count += 1

            if no_improve_count >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

        return self
    
    def impute_with_inr(model, df_with_missing, mask, numerical_col_indices, ohe_col_indices):
        """Impute masked cells row by row with a trained model.

        ``model`` takes the place of ``self``.  For each row that has both
        masked and observed cells, a new row embedding is optimised on the
        observed cells and used to predict the masked ones.  Predictions for
        columns listed in ``ohe_col_indices`` are passed through a sigmoid
        before being written.  A copy of ``df_with_missing`` is returned; the
        input frame is not modified.
        """
        model.eval()
        df_imputed = df_with_missing.copy()

        for row_pos in range(len(df_with_missing)):
            # Rows without any masked cell need no work
            if not mask.iloc[row_pos].any():
                continue

            row_missing = mask.iloc[row_pos].values
            observed_cols = np.where(~row_missing)[0]
            if len(observed_cols) == 0:
                continue

            observed_values = df_with_missing.iloc[row_pos, observed_cols].values

            # Fit an embedding for this row from its observed cells
            new_row_emb = model.optimize_new_row_embedding(
                observed_cols,
                observed_values,
                numerical_col_indices,
                ohe_col_indices
            )

            missing_cols = np.where(row_missing)[0]
            if len(missing_cols) == 0:
                continue

            with torch.no_grad():
                col_emb = model.col_embedding(
                    torch.tensor(missing_cols, dtype=torch.long, device=DEVICE)
                )
                x = torch.cat([new_row_emb.expand(len(missing_cols), -1), col_emb], dim=1)
                predictions = model.network(x).squeeze().cpu().numpy()

                # A single prediction comes back 0-dimensional; wrap it
                if predictions.ndim == 0:
                    predictions = [predictions.item()]

                # One-hot columns were trained on logits, so map them to
                # probabilities with a sigmoid
                for j, col_idx in enumerate(missing_cols):
                    if col_idx in ohe_col_indices:
                        predictions[j] = 1 / (1 + np.exp(-predictions[j]))  # Sigmoid activation

                df_imputed.iloc[row_pos, missing_cols] = predictions
        return df_imputed
        
    def post_process_imputed(self , imputed_df, ohe_col_mapping):
        """Applies winner-takes-all for one-hot encoded columns."""
        for _, ohe_group in ohe_col_mapping.items():
            if not ohe_group or not all(c in imputed_df.columns for c in ohe_group):
                continue
            imputed_group = imputed_df[ohe_group]
            winner_takes_all = np.zeros_like(imputed_group.values)
            winner_indices = np.argmax(imputed_group.values, axis=1)
            winner_takes_all[np.arange(len(winner_takes_all)), winner_indices] = 1
            imputed_df[ohe_group] = winner_takes_all
        return imputed_df

    def evaluate_imputation_performance(imputed_df, original_df, mask, numerical_cols, categorical_cols, ohe_col_mapping):
        """Calculates NRMSE, RMSE, WD, and AUROC for imputation quality."""
        nrmse, rmse, wd_scores, auroc_scores = 0, 0, [], []
        
        if numerical_cols:
            num_mask = mask[numerical_cols].values
            if np.any(num_mask):
                imputed = imputed_df[numerical_cols].values[num_mask]
                original = original_df[numerical_cols].values[num_mask]
                rmse = np.sqrt(mean_squared_error(original, imputed))
                std_dev = np.std(original)
                if std_dev > 1e-8: nrmse = rmse / std_dev
            for col in numerical_cols:
                wd_scores.append(wasserstein_distance(original_df[col], imputed_df[col]))

        if categorical_cols:
            for cat_col, ohe_cols in ohe_col_mapping.items():
                cat_mask = mask[ohe_cols].values.any(axis=1)
                if np.any(cat_mask) and len(ohe_cols) > 1:
                    try:
                        auroc_scores.append(roc_auc_score(
                            original_df.loc[cat_mask, ohe_cols], 
                            imputed_df.loc[cat_mask, ohe_cols], 
                            multi_class='ovr'
                        ))
                    except ValueError: pass
        
        return nrmse, rmse, np.mean(wd_scores) if wd_scores else 0, np.mean(auroc_scores) if auroc_scores else 0
