from tqdm import tqdm
import pandas as pd
import torch
import argparse
import os
import numpy as np
import torch.distributed as dist
import copy

import yaml

from vectorizer.columnVectorizer import *
from latent.TableLatentModel import TableLatentModel
from dataset.dataTransformer import DataTransformer

# Checkpoint discovery: locate the most recent step checkpoint to resume from.
def get_latest_checkpoint(checkpoint_folder):
    """Return the path of the checkpoint with the highest step number.

    The folder is created first if it does not exist. Candidates are ``.pth``
    files whose name contains ``_step_``; the step is the text between the
    first ``_step_`` and the next ``.``, parsed as an int (files whose step
    does not parse are ignored). If several files share the highest step, the
    first one listed by ``os.listdir`` wins.

    Args:
        checkpoint_folder: directory that holds (or will hold) checkpoints.

    Returns:
        str: path of the latest step checkpoint, or
        ``<checkpoint_folder>/best_model.pth`` when there is no step
        checkpoint (that fallback path is not checked for existence).
    """
    if not os.path.exists(checkpoint_folder):
        os.makedirs(checkpoint_folder, exist_ok=True)

    fallback_path = os.path.join(checkpoint_folder, "best_model.pth")

    best_step = None
    best_name = None
    for name in os.listdir(checkpoint_folder):
        if not name.endswith('.pth') or '_step_' not in name:
            continue
        step_text = name.split('_step_')[1].split('.')[0]
        try:
            step = int(step_text)
        except ValueError:
            continue
        # Strict ">" keeps the first file seen on ties, matching max().
        if best_step is None or step > best_step:
            best_step, best_name = step, name

    if best_name is None:
        return fallback_path
    return os.path.join(checkpoint_folder, best_name)

def parse_args():
    """Build the command-line interface for VAE training and parse ``sys.argv``.

    Besides plain parsing, this function:

    * enforces that exactly one of ``--num_epochs`` / ``--max_steps`` is given;
    * checks that ``--split_ratio`` (when given) sums to 1;
    * resolves ``--device`` to ``"cuda"`` or ``"cpu"`` when it is omitted;
    * when the device is ``"cpu"``, hides GPUs for the rest of the process by
      setting ``CUDA_VISIBLE_DEVICES=''`` and replacing
      ``torch.cuda.is_available`` with a function that returns False.

    Returns:
        argparse.Namespace: the parsed and post-processed arguments.

    Raises:
        ValueError: if the epoch/step options or the split ratios are invalid.
    """
    parser = argparse.ArgumentParser(description='Train VAE model for table data')

    # Add distributed training arguments
    # NOTE(review): setup_distributed() in this file does not read
    # --dist_url / --dist_backend / --multinode; confirm whether anything does.
    parser.add_argument('--dist_url', default='env://', type=str,
                        help='URL used to set up distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str,
                        help='Distributed backend to use (nccl, gloo, etc.)')
    parser.add_argument('--multinode', action='store_true',
                        help='Enable explicit multi-node distributed training')

    # Training mode and intervals
    parser.add_argument('--interval_type', type=str, default='epoch', choices=['epoch', 'step'],
                        help='Training interval type (epoch or step based)')
    parser.add_argument('--scheduler_interval', type=str, default='epoch', choices=['epoch', 'step'],
                        help='Scheduler update interval type')
    parser.add_argument('--num_epochs', type=int, default=None,
                        help='Number of training epochs')
    parser.add_argument('--max_steps', type=int, default=None,
                        help='Maximum number of training steps')

    # Data paths and splits
    parser.add_argument('--data_folder', type=str, required=True,
                        help='Path to data folder. For Parquet data it must contain two subfolders: '
                             'parquet and config. With --use_lmdb it is the LMDB dataset path.')
    parser.add_argument('--split_ratio', type=float, nargs=3, default=None,
                        help='Split ratios for train/val/test. Must sum to 1. Example: 0.7 0.15 0.15')
    parser.add_argument('--checkpoint_folder', type=str, default="../checkpoints",
                        help='Root folder for model checkpoints (one subfolder per experiment).')
    parser.add_argument('--test_idx', type=int, default=0, help='Index of current test')

    # LMDB dataset options
    parser.add_argument('--use_lmdb', action='store_true',
                        help='Use LMDB dataset instead of Parquet files')

    # Training parameters
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Batch size for training')
    parser.add_argument('--learning_rate', type=float, default=1e-4,
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.01,
                        help='Weight decay for optimizer')
    parser.add_argument('--scheduler_patience', type=int, default=3,
                        help='Patience for learning rate scheduler')
    parser.add_argument('--scheduler_factor', type=float, default=0.1,
                        help='Factor for learning rate scheduler')
    parser.add_argument('--vectorizer_warmup_epochs', type=int, default=2,
                        help='Number of epochs to warmup the vectorizer before training main VAE.')
    parser.add_argument('--save_interval', type=int, default=50,
                        help='Interval (in steps) to save model checkpoints during training')
    parser.add_argument('--validation_interval', type=int, default=None,
                        help='Interval (in steps) for validation during training')

    # Model architecture
    parser.add_argument('--autoencoder_type', type=str, default='unimodal',
                        choices=['unimodal', 'multimodal', 'disentangled', 'ae'],
                        help='Type of autoencoder to use (unimodal, multimodal, disentangled, or simple ae).')
    parser.add_argument('--d_lm', type=int, default=1024,
                        help='Dimension of language model embeddings')
    parser.add_argument('--d_latent_len', type=int, default=16,
                        help='Length of latent dimension')
    parser.add_argument('--d_latent_width', type=int, default=64,
                        help='Width of latent dimension')
    parser.add_argument('--max_n_cols', type=int, default=100,
                        help='Maximum number of columns')

    # Encoder specific parameters
    parser.add_argument('--encoder_depth', type=int, default=2,
                        help='Depth of encoder')
    parser.add_argument('--encoder_dim_head', type=int, default=64,
                        help='Dimension of encoder attention heads')
    parser.add_argument('--encoder_ff_mult', type=int, default=4,
                        help='Feed forward multiplication factor')
    parser.add_argument('--fuse_option', type=str, default='flatten',
                        help='Fusion option for encoder')

    # Decoder specific parameters
    parser.add_argument('--decoder_depth', type=int, default=2,
                        help='Depth of decoder')
    parser.add_argument('--decoder_num_heads', type=int, default=8,
                        help='Number of decoder attention heads')

    # Other settings
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use (cuda/cpu). If None, will use cuda if available')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility')

    # Beta-VAE parameters (KL weight schedule)
    parser.add_argument('--init_beta', type=float, default=0.0,
                        help='Initial beta value for KL divergence weight')
    parser.add_argument('--max_beta', type=float, default=1.0,
                        help='Maximum beta value for KL divergence weight')
    parser.add_argument('--max_beta_steps', type=int, default=1000,
                        help='Number of steps to reach maximum beta value')

    # Test reconstruction options
    parser.add_argument('--test_construct', type=int, default=3,
                        help='Options: 0: no test, 1: test reconstruction, 2: test reconstruction with '
                             'interpolation on latent space, 3: random sample from prior, 4: test '
                             'reconstruction without the target column (assumed to be the last column).')
    parser.add_argument('--reconstruct_folder', type=str, default='../outputs/vae',
                        help='Output folder for reconstructed tables')

    # Checkpoint loading arguments
    parser.add_argument('--checkpoint', type=str, default=None,
                        help='Path to load model checkpoint from')
    parser.add_argument('--retrain', action='store_true',
                        help='Whether to retrain model even if checkpoint is provided')

    parser.add_argument('--clean_batches', action='store_true',
                        help='Whether to delete batched data after training.')

    # Save print output to a file
    parser.add_argument('--save_output', action='store_true',
                        help='Whether to save print output to a file')
    parser.add_argument('--log_file', type=str, default='output.log',
                        help='File to save print output')

    # Early stopping
    parser.add_argument('--early_stop_patience', type=int, default=20,
                        help='Number of epochs or validation interval to wait before early stopping')

    # Checkpoint resuming
    parser.add_argument('--resume_from_checkpoint', action='store_true',
                        help='Resume training from checkpoint if it exists')

    # Shuffling (on by default; this flag turns it off)
    parser.add_argument('--not_shuffle', action='store_true', default=False,
                        help='Disable shuffling of the training data (data is shuffled by default)')

    # Skip initial iterations
    parser.add_argument('--skip_iters', type=int, default=0,
                        help='Number of initial iterations to skip (useful for skipping problematic batches)')

    # Scheduler parameters for warmup + cosine annealing
    parser.add_argument('--scheduler_type', type=str, default='cosine_warmup',
                        choices=['reduce_on_plateau', 'cosine_warmup'],
                        help='Type of learning rate scheduler to use')
    parser.add_argument('--warmup_percentage', type=float, default=0.08,
                        help='Percentage of total steps to use for linear warmup (only for cosine_warmup scheduler)')
    parser.add_argument('--min_lr', type=float, default=1e-7,
                        help='Minimum learning rate for cosine annealing (only for cosine_warmup scheduler)')

    parser.add_argument('--total_scheduler_steps', type=int, default=None,
                        help='Total steps for scheduler. If not provided, uses training steps.')

    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help='Number of steps to accumulate gradients before performing a backward/update pass')

    parser.add_argument('--num_workers', type=int, default=8,
                        help='Number of worker processes for data loading')

    # Numerical transformation
    parser.add_argument('--numerical_transformation', type=str, default='ple',
                        choices=['ple', 'quantile'],
                        help='Transformation method for numerical columns')

    # Masked autoencoder
    parser.add_argument('--mask_ratio', type=float, default=0.0,
                        help='Ratio of columns to mask for masked autoencoder setup (0.0 to 1.0)')

    # Contrastive learning parameters
    parser.add_argument('--contrastive_weight', type=float, default=0.0,
                        help='Weight for contrastive loss (0.0 = disabled)')
    parser.add_argument('--contrastive_temperature', type=float, default=0.07,
                        help='Temperature parameter for contrastive loss')
    parser.add_argument('--base_contrastive_temperature', type=float, default=0.07,
                        help='Base temperature parameter for contrastive loss')
    parser.add_argument('--contrastive_dim', type=int, default=128,
                        help='Output dimension for contrastive projection head')

    # Combination method for the multimodal VAE
    parser.add_argument('--combination_method', type=str, default='mopoe',
                        choices=['poe', 'moe', 'mopoe', 'samopoe'],
                        help='Method to combine distributions in multimodal VAE')

    # Scheduler-state loading. The positive flag is already the default (so it
    # is a no-op kept for backward compatibility); the --no_ flag disables it.
    parser.add_argument('--load_scheduler_state', action='store_true', default=True,
                        help='Load scheduler state from checkpoint (default behavior)')
    parser.add_argument('--no_load_scheduler_state', action='store_false', dest='load_scheduler_state',
                        help='Skip loading scheduler state from checkpoint (use when changing max_steps to avoid LR schedule conflicts)')

    # Debugging
    parser.add_argument('--debugging', action='store_true',
                        help='Enable debugging mode for reconstruction test')

    args = parser.parse_args()

    # Exactly one training-length criterion must be given.
    if (args.num_epochs is None) == (args.max_steps is None):
        raise ValueError("Exactly one of num_epochs or max_steps must be provided")

    # The help text promises the ratios sum to 1; enforce it with a small
    # tolerance for floating-point input such as 0.7 0.15 0.15.
    if args.split_ratio is not None:
        ratio_sum = sum(args.split_ratio)
        if abs(ratio_sum - 1.0) > 1e-6:
            raise ValueError(
                f"--split_ratio values must sum to 1, got {args.split_ratio} (sum={ratio_sum})"
            )

    # Set device if not specified
    if args.device is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"

    # If cpu is explicitly specified, disable CUDA to prevent any GPU usage
    if args.device == 'cpu':
        # The env var only takes effect if CUDA has not been initialised yet
        # (main() calls check_env() before this function), so also override
        # is_available() for any code that checks it later.
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        torch.cuda.is_available = lambda: False
        print("CUDA/GPU usage explicitly disabled by setting device='cpu'")

    return args

def check_env():
    """Print a summary of the PyTorch / CUDA environment.

    When CUDA is available, this also prints:

    * the GPU count and the current CUDA device;
    * ``CUDA_VISIBLE_DEVICES``;
    * the distributed-launch variables ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE``;
    * per-GPU name and memory statistics.

    Environment variables are only read, never modified.

    Returns:
        bool: the value of ``torch.cuda.is_available()``.
    """
    print("\n=== Environment Check ===")
    print(f"PyTorch version: {torch.__version__}")

    cuda_available = torch.cuda.is_available()
    print(f"CUDA available: {cuda_available}")

    if cuda_available:
        n_gpus = torch.cuda.device_count()
        print(f"Number of available GPUs: {n_gpus}")

        print("\nDetailed GPU Information:")
        print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not Set')}")
        # main() calls this before setup_distributed() binds a device, so this
        # usually reports the default device on every process.
        print(f"Current CUDA device: {torch.cuda.current_device()}")

        # LOCAL_RANK is the process index on this node; RANK is the global
        # index across all nodes. They only coincide on a single node, so the
        # global rank is what belongs in "Process rank".
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        global_rank = int(os.environ.get('RANK', local_rank))
        world_size = int(os.environ.get('WORLD_SIZE', 1))
        print("\nDistributed Training Info:")
        print(f"LOCAL_RANK: {local_rank}")
        print(f"WORLD_SIZE: {world_size}")
        print(f"Process rank: {global_rank}/{world_size-1}")

        print("\nPer-GPU Information:")
        for i in range(n_gpus):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
            props = torch.cuda.get_device_properties(i)
            print(f"    Total Memory: {props.total_memory / 1024**3:.1f} GB")
            print(f"    Used Memory: {torch.cuda.memory_allocated(i) / 1024**3:.1f} GB")
            print(f"    Cached Memory: {torch.cuda.memory_reserved(i) / 1024**3:.1f} GB")

    print("======================\n")
    return cuda_available

def setup_distributed(backend=None):
    """Initialise torch.distributed when the process was started by a launcher.

    The process is treated as distributed only when the ``RANK`` environment
    variable is set (launchers such as torchrun set it). Otherwise nothing is
    initialised.

    Args:
        backend: process-group backend such as ``'nccl'`` or ``'gloo'``.
            When None, ``'nccl'`` is used if CUDA is available and ``'gloo'``
            otherwise, because NCCL cannot run without GPUs. Existing calls
            with no argument therefore keep using NCCL on GPU machines.

    Returns:
        tuple: ``(distributed, rank, world_size)``. Returns ``(False, 0, 1)``
        when not launched in distributed mode.
    """
    if 'RANK' not in os.environ:
        return False, 0, 1

    use_cuda = torch.cuda.is_available()
    if backend is None:
        backend = 'nccl' if use_cuda else 'gloo'

    # Bind this process to its own GPU before creating the process group.
    # NCCL uses torch.cuda.current_device() for communication, so each rank
    # must have a distinct current device. Skipped on CPU-only runs, where
    # set_device would fail.
    if use_cuda:
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        torch.cuda.set_device(local_rank)

    # A second init_process_group call raises, so tolerate re-entry.
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)

    return True, dist.get_rank(), dist.get_world_size()

def main():
    check_env()
    args = parse_args()

    # Set up distributed training
    distributed, rank, world_size = setup_distributed()
    
    # Only print from rank 0
    if rank == 0:
        print("\n=== Model Initialization ===")
        print(f"Distributed training: {distributed}")
        print(f"Rank: {rank}/{world_size-1}")
        print(f"Device specified: {args.device}")
        print(f"Model will be created on: {args.device}")

    # Ensure all output folders exist
    data_name = args.data_folder.split("/")[-1]
    experiment_name = f"{data_name}_vae_default_{args.test_idx}"

    # Set up paths based on dataset type
    if args.use_lmdb:
        lmdb_path = args.data_folder
        csv_log_path = os.path.join(args.data_folder, "lmdb_data_log.csv")
        output_folder = None  # Not needed for LMDB
    else:
        df_folder = os.path.join(args.data_folder, "parquet")
        config_folder = os.path.join(args.data_folder, "config")
        output_folder = os.path.join(args.data_folder, "batched_output")
        os.makedirs(output_folder, exist_ok=True)

    checkpoint_folder = os.path.join(args.checkpoint_folder, experiment_name)
    reconstruct_folder = os.path.join(args.reconstruct_folder, experiment_name)
    checkpoint_path = get_latest_checkpoint(checkpoint_folder)
    print("CHeckpoint path:",checkpoint_path)
    final_model_path = os.path.join(checkpoint_folder, "final_vae.pth")
    
    if args.resume_from_checkpoint and not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"No checkpoint found at {checkpoint_path} or {os.path.join(checkpoint_folder, 'best_model.pth')}")    

    os.makedirs(checkpoint_folder, exist_ok=True)
    os.makedirs(reconstruct_folder, exist_ok=True)

    # Save the args as a yaml file in the reconstruct_folder
    args_dict = vars(args)  # Convert Namespace to dictionary
    yaml_file_path = os.path.join(checkpoint_folder, 'args.yaml')
    
    with open(yaml_file_path, 'w') as yaml_file:
        yaml.dump(args_dict, yaml_file, default_flow_style=False)
    
    # Redirect print output to a file if specified
    if args.save_output:
        import sys
        sys.stdout = open(args.log_file, 'w')

    # Set random seed
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    # Initialize TableLatentModel
    model = TableLatentModel(
        d_lm=args.d_lm,
        d_latent_len=args.d_latent_len,
        d_latent_width=args.d_latent_width,
        max_n_cols=args.max_n_cols,
        encoder_depth=args.encoder_depth,
        encoder_dim_head=args.encoder_dim_head,
        encoder_ff_mult=args.encoder_ff_mult,
        fuse_option=args.fuse_option,
        decoder_depth=args.decoder_depth,
        decoder_num_heads=args.decoder_num_heads,
        numerical_transformation=args.numerical_transformation,
        device=args.device,
        autoencoder_type=args.autoencoder_type,
        combination_method=args.combination_method,
    )

    # Add debug statement after model creation
    print(f"Model created. Parameters: {sum(p.numel() for p in model.autoencoder.parameters()):,}")
    print(f"Model device: {next(model.autoencoder.parameters()).device}")
    print("======================\n")

    # Load checkpoint if provided
    if args.checkpoint and os.path.exists(args.checkpoint):
        print(f"Loading checkpoint from {args.checkpoint}")
        checkpoint = torch.load(args.checkpoint, map_location=args.device, weights_only=True)
        model.load_checkpoint(checkpoint)
    elif args.checkpoint and not os.path.exists(args.checkpoint):
        print(f"Checkpoint {args.checkpoint} not found. Cannot load checkpoint!")

    # Prepare data and train model if no checkpoint or retrain requested
    if args.split_ratio is not None:
        if args.use_lmdb:
            train_dataloader, val_dataloader, test_dataloader = model.prepare_data(
                split_ratio=tuple(args.split_ratio),
                batch_size=args.batch_size,
                shuffle=not args.not_shuffle,
                use_lmdb=True,
                lmdb_path=lmdb_path,
                csv_log_path=csv_log_path,
                num_workers=args.num_workers
            )
        else:
            train_dataloader, val_dataloader, test_dataloader = model.prepare_data(
                df_folder=df_folder,
                config_folder=config_folder,
                output_folder=output_folder,
                batch_size=args.batch_size,
                shuffle=not args.not_shuffle,
                split_ratio=tuple(args.split_ratio),
                num_workers=args.num_workers
            )
    else:
        if args.use_lmdb:
            train_dataloader = model.prepare_data(
                batch_size=args.batch_size,
                shuffle=not args.not_shuffle,
                use_lmdb=True,
                lmdb_path=lmdb_path,
                csv_log_path=csv_log_path,
                num_workers=args.num_workers
            )
        else:
            train_dataloader = model.prepare_data(
                df_folder=df_folder,
                config_folder=config_folder,
                output_folder=output_folder,
                batch_size=args.batch_size,
                shuffle=not args.not_shuffle,
                num_workers=args.num_workers
            )
        val_dataloader = None
        test_dataloader = None

    # Train model if no checkpoint or retrain requested
    if args.checkpoint is None or args.retrain:
        print("\n=== Training Setup ===")
        print(f"Training device: {args.device}")
        print(f"Batch size: {args.batch_size}")
        if train_dataloader:
            print(f"Number of batches: {len(train_dataloader)}")
        if args.skip_iters > 0:
            print(f"Skipping first {args.skip_iters} iterations")
        if args.resume_from_checkpoint:
            print(f"Resuming from checkpoint: {checkpoint_path}")
            print(f"Loading scheduler state: {args.load_scheduler_state}")
            if not args.load_scheduler_state:
                print("Note: Scheduler state loading is disabled. A new learning rate schedule will be created.")
                print("      This is useful when changing max_steps between training runs.")
        print("======================\n")

        model.train(
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            num_epochs=args.num_epochs,
            max_steps=args.max_steps,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            scheduler_type=args.scheduler_type,
            scheduler_patience=args.scheduler_patience,
            scheduler_factor=args.scheduler_factor,
            warmup_percentage=args.warmup_percentage,
            min_lr=args.min_lr,
            init_beta=args.init_beta,
            max_beta=args.max_beta,
            max_beta_steps=args.max_beta_steps,
            vectorizer_warmup_epochs=args.vectorizer_warmup_epochs,
            checkpoint_path=checkpoint_path,
            save_interval=args.save_interval,
            validation_interval=args.validation_interval,
            distributed=distributed,
            rank=rank,
            world_size=world_size,
            early_stop_patience=args.early_stop_patience,
            interval_type=args.interval_type,
            scheduler_interval=args.scheduler_interval,
            resume_from_checkpoint=args.resume_from_checkpoint,
            skip_iters=args.skip_iters,
            total_scheduler_steps=args.total_scheduler_steps,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            mask_ratio=args.mask_ratio,
            contrastive_weight=args.contrastive_weight,
            contrastive_temperature=args.contrastive_temperature,
            base_contrastive_temperature=args.base_contrastive_temperature,
            contrastive_dim=args.contrastive_dim,
            load_scheduler_state=args.load_scheduler_state
        )

        ckpt = model.get_checkpoint()
        torch.save(ckpt, final_model_path)

    # Test reconstruction if requested
    if args.test_construct != 0:
        if args.test_construct == 1:
            print("Testing reconstruction...")
        elif args.test_construct == 2:
            print("Testing reconstruction with interpolation on latent space...")
        elif args.test_construct == 3:
            print("Testing reconstruction with random sampling from prior...")
        elif args.test_construct == 4:
            print("Testing reconstruction without target column (assumed to be last column)...")
        os.makedirs(reconstruct_folder, exist_ok=True)
        
        # Keep track of reconstructed batches and original batches by dataset
        dataset_batches = {}
        original_batches = {}
        latent_batches = {}  # To store latent representations
        
        # For test_construct == 4, also track original full batches and target batches
        original_full_batches = {}
        target_batches = {}
        target_configs = {}
        
        # Use test dataloader if available, otherwise use train dataloader
        eval_dataloader = test_dataloader if test_dataloader is not None else train_dataloader
        
        # Use existing dataloader for reconstruction
        for batch in tqdm(eval_dataloader, desc="Reconstructing batches"):
            try:
                df_batch = batch['df_batch']
                config = batch['config']
                if "batch_file_path" in batch:
                    batch_file_path = batch['batch_file_path']
                    
                    # Extract dataset name from batch_file_path (format: "{dataset_name}_batch{batch_idx}.parquet")
                    print(batch_file_path)
                    dataset_name = batch_file_path.split('/')[-1].split('_batch')[0]
                else:
                    dataset_name = config['base_name'] # For LMDB dataset
                
                # For test_construct == 4, separate target column (assumed to be the last column)
                target_column = None
                target_config = None
                original_df_batch = df_batch.copy()
                original_config = copy.deepcopy(config)
                
                if args.test_construct == 4:
                    # Assume the last column in variables is the target
                    if 'variables' in config and len(config['variables']) > 0:
                        # Get the last variable
                        target_config = config['variables'][-1]
                        target_column = target_config['variable_name']
                        
                        print(f"Treating '{target_column}' as target column for reconstruction test")
                        
                        # Remove target from df_batch
                        if target_column in df_batch.columns:
                            # Store original df_batch before dropping target
                            df_batch = df_batch.drop(columns=[target_column])
                            
                            # Remove target from config
                            config = copy.deepcopy(config)
                            config['variables'] = config['variables'][:-1]
                        else:
                            print(f"Warning: Target column '{target_column}' not found in data")
                            target_column = None
                            target_config = None
                    else:
                        print("Warning: No variables found in config, cannot identify target column")
                
                # Convert to latent and back
                # Note that df_batch here is normalized.
                latent = model.table_to_latent(df_batch, config, batch_size=args.batch_size)
                if args.test_construct == 2:
                    # Create random interpolation of latents
                    N = len(latent)  # Number of interpolated examples
                    latent = latent.detach().cpu().numpy()
                    # Generate random interpolation between two random latent vectors
                    interpolated_latents = []
                    for _ in range(N):
                        idx1, idx2 = np.random.choice(latent.shape[0], 2, replace=False)
                        alpha = np.random.rand()  # Interpolation factor
                        interpolated_latent = (1 - alpha) * latent[idx1] + alpha * latent[idx2]
                        interpolated_latents.append(interpolated_latent)
                    latent = np.array(interpolated_latents)  # Shape (N, L, D)
                    latent = torch.tensor(latent).to(args.device)
                elif args.test_construct == 3:
                    # Randomly sample from prior of normal distribution
                    latent = np.random.normal(0, 1, (args.batch_size, args.d_latent_len, args.d_latent_width))
                    latent = torch.tensor(latent, dtype=torch.float32).to(args.device)  # Ensure correct dtype
                reconstructed_batch = model.latent_to_table(latent, config, batch_size=args.batch_size)
                
                # Store reconstructed batch
                if dataset_name not in dataset_batches:
                    dataset_batches[dataset_name] = []
                dataset_batches[dataset_name].append(reconstructed_batch)
                
                # Store original batch - for test_construct == 4, store the modified batch without target
                if dataset_name not in original_batches:
                    original_batches[dataset_name] = []
                original_batches[dataset_name].append(df_batch)
                
                # For test_construct == 4, also store the original full batch with target
                if args.test_construct == 4 and target_column is not None:
                    if dataset_name not in original_full_batches:
                        original_full_batches[dataset_name] = []
                    original_full_batches[dataset_name].append(original_df_batch)
                    
                    # Store the target column separately
                    if dataset_name not in target_batches:
                        target_batches[dataset_name] = []
                        target_configs[dataset_name] = target_config
                    target_batches[dataset_name].append(original_df_batch[[target_column]])
                
                # Save latent representation as numpy array
                if dataset_name not in latent_batches:
                    latent_batches[dataset_name] = []
                latent_batches[dataset_name].append(latent.detach().cpu().numpy())  # Convert to numpy and store
            except Exception as e:
                if args.debugging:
                    print(f"Error processing dataset {dataset_name if 'dataset_name' in locals() else 'unknown'}: {str(e)}")
                    continue
                else:
                    raise e
        
        # Combine and save reconstructed and original batches by dataset
        for dataset_name, batches in dataset_batches.items():
            # Combine reconstructed batches
            reconstructed_df = pd.concat(batches, ignore_index=True)
            
            # Combine original batches
            original_df = pd.concat(original_batches[dataset_name], ignore_index=True)
            
            # For test_construct == 4, also combine full original batches and target batches
            if args.test_construct == 4 and dataset_name in original_full_batches:
                original_full_df = pd.concat(original_full_batches[dataset_name], ignore_index=True)
                target_df = pd.concat(target_batches[dataset_name], ignore_index=True)
            
            # Get the config for this dataset
            # Assuming all batches for a dataset have the same config
            dataset_config = None
            for batch in eval_dataloader:
                batch_dataset_name = None
                if "batch_file_path" in batch:
                    batch_dataset_name = batch['batch_file_path'].split('/')[-1].split('_batch')[0]
                elif 'config' in batch and 'base_name' in batch['config']:
                    batch_dataset_name = batch['config']['base_name']
                
                if batch_dataset_name == dataset_name:
                    # For test_construct == 4, use the modified config without target
                    if args.test_construct == 4:
                        dataset_config = copy.deepcopy(batch['config'])
                        if 'variables' in dataset_config and len(dataset_config['variables']) > 0:
                            dataset_config['variables'] = dataset_config['variables'][:-1]
                    else:
                        dataset_config = batch['config']
                    break
            
            if dataset_config:
                # Create DataTransformer from config
                transformer = DataTransformer.from_config(dataset_config)
                
                # Apply inverse transformation to get back to original data format
                print(f"Applying inverse transformation to {dataset_name} data...")
                original_df_inverse = transformer.inverse_transform(original_df)
                reconstructed_df_inverse = transformer.inverse_transform(reconstructed_df)
                
                # Save inverse-transformed reconstructed data
                output_path = os.path.join(reconstruct_folder, f"{dataset_name}_reconstructed.csv")
                reconstructed_df_inverse.to_csv(output_path, index=False)
                print(f"Saved inverse-transformed reconstructed table to {output_path}")
                
                # Save inverse-transformed original data
                original_output_path = os.path.join(reconstruct_folder, f"{dataset_name}_original.csv")
                original_df_inverse.to_csv(original_output_path, index=False)
                print(f"Saved inverse-transformed original table to {original_output_path}")
                
                # For test_construct == 4, save the target column separately
                if args.test_construct == 4 and dataset_name in target_batches:
                    # Get the original config with target to transform the target column
                    original_config = None
                    for batch in eval_dataloader:
                        batch_dataset_name = None
                        if "batch_file_path" in batch:
                            batch_dataset_name = batch['batch_file_path'].split('/')[-1].split('_batch')[0]
                        elif 'config' in batch and 'base_name' in batch['config']:
                            batch_dataset_name = batch['config']['base_name']
                        
                        if batch_dataset_name == dataset_name:
                            original_config = batch['config']
                            break
                    
                    if original_config and target_configs[dataset_name]:
                        # Create a config with only the target variable
                        target_only_config = {
                            'variables': [target_configs[dataset_name]]
                        }
                        
                        # Create transformer for target column
                        target_transformer = DataTransformer.from_config(target_only_config)
                        
                        # Inverse transform target
                        target_df_inverse = target_transformer.inverse_transform(target_df)
                        
                        # Save target column
                        target_output_path = os.path.join(reconstruct_folder, f"{dataset_name}_target.csv")
                        target_df_inverse.to_csv(target_output_path, index=False)
                        print(f"Saved target column to {target_output_path}")
                        
                        # Also save the full original data with all columns
                        if original_full_df is not None:
                            # Create transformer for full data
                            full_transformer = DataTransformer.from_config(original_config)
                            full_df_inverse = full_transformer.inverse_transform(original_full_df)
                            
                            full_output_path = os.path.join(reconstruct_folder, f"{dataset_name}_original_full.csv")
                            full_df_inverse.to_csv(full_output_path, index=False)
                            print(f"Saved full original data to {full_output_path}")
            else:
                # If config not found, save the transformed data with a warning
                print(f"Warning: Config not found for {dataset_name}. Saving transformed data only.")
                
                # Save transformed reconstructed data
                output_path = os.path.join(reconstruct_folder, f"{dataset_name}_reconstructed_transformed.csv")
                reconstructed_df.to_csv(output_path, index=False)
                print(f"Saved transformed reconstructed table to {output_path}")
                
                # Save transformed original data
                original_output_path = os.path.join(reconstruct_folder, f"{dataset_name}_original_transformed.csv")
                original_df.to_csv(original_output_path, index=False)
                print(f"Saved transformed original table to {original_output_path}")
                
                # For test_construct == 4, save the target column separately
                if args.test_construct == 4 and dataset_name in target_batches:
                    target_output_path = os.path.join(reconstruct_folder, f"{dataset_name}_target_transformed.csv")
                    target_df = pd.concat(target_batches[dataset_name], ignore_index=True)
                    target_df.to_csv(target_output_path, index=False)
                    print(f"Saved transformed target column to {target_output_path}")
        
        # Save latent representations as npy files
        for dataset_name, latents in latent_batches.items():
            latent_array = np.concatenate(latents, axis=0)  # Concatenate all latent arrays
            latent_output_path = os.path.join(reconstruct_folder, f"{dataset_name}_latent.npy")
            np.save(latent_output_path, latent_array)  # Save as .npy file
            print(f"Saved latent representation to {latent_output_path}")

    # Clean up distributed training
    if distributed:
        dist.destroy_process_group()

    # Close the output file if it was opened
    if args.save_output:
        sys.stdout.close()
        sys.stdout = sys.__stdout__  # Reset stdout to default

if __name__ == "__main__":
    # Script entry point: invoke main() only when this file is executed
    # directly (e.g. `python <script>.py ...`), not when it is imported as a
    # module, so importing it has no training/evaluation side effects.
    main()
    