#!/usr/bin/env python3
"""
Improved NPT (Nuisance-Prompt Tuning) experiment script with variance-aware attention regularization
and entropy maximization for enhanced feature discrimination and better OOD detection performance.
"""

import argparse
import os
import sys
import torch
import torch.nn.utils
import time
import datetime
import numpy as np
import pandas as pd
import csv
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
from torch.nn import functional as F

from utils.train_util import setup_logger, set_random_seed, AverageMeter, load_clip_to_cpu
from torchvision.transforms import RandomResizedCrop, RandomHorizontalFlip, ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
from datasets.train.imagenet import ImageNetDataset
from datasets.eval.test_loader import set_test_loader
import clip_w_local
from clip_w_local import clip
from clip_w_local.simple_tokenizer import SimpleTokenizer as _Tokenizer
from trainers import build_optimizer, build_lr_scheduler
from utils.losses import compute_accuracy
from utils.eval_util import get_and_print_results, add_results, add_overall_results, save_results_to_json
from configs.implemention import get_cfg_default
from yacs.config import CfgNode as CN

from npt_models import NPTCustomCLIP, extract_attention_weights


class ImprovedNPTExperiment:
    """Drive an improved NPT experiment: data setup, model setup, training and OOD evaluation.

    The loss itself is computed by ``NPTCustomCLIP.compute_npt_loss``; this class
    reads the loss hyperparameters from ``args`` and copies them into the config
    that is handed to the model.

    Typical call order: ``_validate_config`` -> ``setup_data`` -> ``setup_model``
    -> ``train`` (or ``load_model``) -> ``evaluate``.
    """

    def __init__(self, cfg: CN, args: argparse.Namespace):
        """Store the configuration and read hyperparameters from ``args``.

        Args:
            cfg: Experiment configuration node.
            args: Namespace carrying optional hyperparameter overrides; any
                missing attribute falls back to the default given below.
        """
        self.cfg = cfg
        self.args = args
        self.device = self._setup_device()
        self.model = None
        # Declared up front so ``setup_model`` can detect a missing ``setup_data`` call.
        self.train_dataset = None
        self.train_loader = None
        self.optimizer = None
        self.scheduler = None
        # Ensures a failing attention extraction is reported once, not per batch.
        self._attention_warning_shown = False

        # Momentum-based loss balancing parameters
        self.momentum_beta = getattr(args, 'momentum_beta', 0.9)
        self.adaptation_alpha = getattr(args, 'adaptation_alpha', 0.1)
        self.min_weight_factor = getattr(args, 'min_weight_factor', 0.1)
        self.max_weight_factor = getattr(args, 'max_weight_factor', 3.0)
        self.warmup_steps = getattr(args, 'warmup_steps', 50)

        # Variance-aware attention regularization and entropy maximization parameters
        self.lambda_var = getattr(args, 'lambda_var', 0.1)
        self.lambda_entropy = getattr(args, 'lambda_entropy', 0.05)
        self.epsilon = getattr(args, 'epsilon', 1e-8)

    def _setup_device(self) -> torch.device:
        """Return the CUDA device when available and enabled in the config, else CPU."""
        if torch.cuda.is_available() and self.cfg.USE_CUDA:
            torch.backends.cudnn.benchmark = True
            return torch.device("cuda")
        return torch.device("cpu")

    def _validate_config(self) -> None:
        """Check that the dataset root and output directory are configured.

        Raises:
            ValueError: If ``DATASET.ROOT`` is missing or does not exist on disk,
                or if ``OUTPUT_DIR`` is missing.
        """
        if not hasattr(self.cfg, 'DATASET') or not hasattr(self.cfg.DATASET, 'ROOT'):
            raise ValueError("DATASET.ROOT must be specified in config")

        if not os.path.exists(self.cfg.DATASET.ROOT):
            raise ValueError(f"Dataset path does not exist: {self.cfg.DATASET.ROOT}")

        if not hasattr(self.cfg, 'OUTPUT_DIR'):
            raise ValueError("OUTPUT_DIR must be specified in config")

    def _create_train_transform(self) -> Compose:
        """Build the training transform pipeline from ``cfg.INPUT``.

        The crop size comes from ``INPUT.SIZE`` (default 224x224). Random resized
        crop, horizontal flip and normalization are each included only when named
        in ``INPUT.TRANSFORMS``; ``ToTensor`` is always applied.
        """
        if hasattr(self.cfg, 'INPUT') and hasattr(self.cfg.INPUT, 'SIZE'):
            resize_size = tuple(self.cfg.INPUT.SIZE)
        else:
            resize_size = (224, 224)

        enabled = self.cfg.INPUT.TRANSFORMS
        train_transform = []
        if "random_resized_crop" in enabled:
            train_transform.append(RandomResizedCrop(resize_size))
        if "random_flip" in enabled:
            train_transform.append(RandomHorizontalFlip())
        train_transform.append(ToTensor())
        if "normalize" in enabled:
            train_transform.append(Normalize(mean=self.cfg.INPUT.PIXEL_MEAN, std=self.cfg.INPUT.PIXEL_STD))

        return Compose(train_transform)

    def setup_data(self) -> None:
        """Create the few-shot training dataset and its data loader."""
        print("Setting up data loaders...")

        # Prepare training dataset
        num_shots = getattr(self.cfg.DATASET, 'NUM_SHOTS', 1)
        seed = getattr(self.cfg, 'SEED', 1)

        self.train_dataset = ImageNetDataset(
            root=self.cfg.DATASET.ROOT,
            split="train",
            num_shots=num_shots,
            seed=seed,
            transform=self._create_train_transform()
        )

        # Prepare data loader
        batch_size = self.cfg.DATALOADER.TRAIN_X.BATCH_SIZE if hasattr(self.cfg.DATALOADER, 'TRAIN_X') else 32
        num_workers = getattr(self.cfg.DATALOADER, 'NUM_WORKERS', 8)

        print(f"Using {num_workers} workers for data loading")

        loader_kwargs = {
            'batch_size': batch_size,
            'shuffle': True,
            'num_workers': num_workers,
            # Pinned memory only helps host-to-GPU copies.
            'pin_memory': self.device.type == "cuda",
        }
        if num_workers > 0:
            # These options are only valid with worker processes; some PyTorch
            # versions reject an explicit prefetch_factor (even None) otherwise.
            loader_kwargs['persistent_workers'] = True
            loader_kwargs['prefetch_factor'] = 2

        self.train_loader = DataLoader(self.train_dataset, **loader_kwargs)

        print(f"Train dataset size: {len(self.train_dataset)}")

    def setup_model(self) -> None:
        """Build the NPT CLIP model, freeze everything but the prompt learner, and create the optimizer.

        Raises:
            RuntimeError: If ``setup_data`` has not been called yet (the class
                names are taken from the training dataset).
        """
        if self.train_dataset is None:
            raise RuntimeError("setup_data() must be called before setup_model()")

        print("Setting up Improved NPT model with variance-aware attention regularization and entropy maximization...")
        print(f"  Momentum beta: {self.momentum_beta}")
        print(f"  Adaptation alpha: {self.adaptation_alpha}")
        print(f"  Min weight factor: {self.min_weight_factor}")
        print(f"  Max weight factor: {self.max_weight_factor}")
        print(f"  Warmup steps: {self.warmup_steps}")
        print(f"  Lambda var: {self.lambda_var}")
        print(f"  Lambda entropy: {self.lambda_entropy}")
        print(f"  Epsilon: {self.epsilon}")

        # Load CLIP model
        print(f"Loading CLIP (backbone: {self.cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(self.cfg)

        if self.cfg.TRAINER.LOCOOP.PREC in ["fp32", "amp"]:
            clip_model.float()

        # Add all hyperparameters to config (it is frozen, so unlock temporarily)
        self.cfg.defrost()
        self.cfg.momentum_beta = self.momentum_beta
        self.cfg.adaptation_alpha = self.adaptation_alpha
        self.cfg.min_weight_factor = self.min_weight_factor
        self.cfg.max_weight_factor = self.max_weight_factor
        self.cfg.warmup_steps = self.warmup_steps
        self.cfg.lambda_var = self.lambda_var
        self.cfg.lambda_entropy = self.lambda_entropy
        self.cfg.epsilon = self.epsilon
        self.cfg.freeze()

        # Build NPT custom CLIP model
        classnames = self.train_dataset.classnames
        print("Building NPT Custom CLIP with variance-aware attention regularization and entropy maximization")
        self.model = NPTCustomCLIP(self.cfg, classnames, clip_model)

        # Disable gradients except for prompt learner
        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.model.named_parameters():
            if "prompt_learner" not in name:
                param.requires_grad_(False)

        self.model.to(self.device)

        # Build optimizer and scheduler
        self.optimizer = build_optimizer(self.model.prompt_learner, self.cfg.OPTIM)
        self.scheduler = build_lr_scheduler(self.optimizer, self.cfg.OPTIM)

        print("Improved NPT Model with variance-aware attention regularization and entropy maximization setup completed")

    def extract_patch_attention(self, image: torch.Tensor):
        """Return ``extract_attention_weights`` of the image encoder for ``image`` cast to the model dtype."""
        image_input = image.type(self.model.dtype)
        return extract_attention_weights(self.model.image_encoder, image_input)

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Train for one epoch and return sample-weighted averages of the tracked metrics.

        Args:
            epoch: Zero-based epoch index (used for progress and ETA output).

        Returns:
            Dict with keys ``loss``, ``accuracy``, the five loss components and the
            two adaptive lambdas, each averaged over the samples seen this epoch.
            All values are zero when the loader yields no batches.
        """
        self.model.train()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        num_batches = len(self.train_loader)
        max_epoch = self.cfg.OPTIM.MAX_EPOCH

        loss_terms = ('loss_global', 'loss_patch', 'loss_margin', 'loss_var', 'loss_entropy')
        lambda_terms = ('adaptive_lambda_patch', 'adaptive_lambda_margin')
        # Running sums, each weighted by the batch size; insertion order fixes
        # the key order of the returned dict.
        running = dict.fromkeys(('loss', 'accuracy') + loss_terms + lambda_terms, 0.0)
        total_samples = 0

        end = time.time()

        for batch_idx, batch in enumerate(self.train_loader):
            data_time.update(time.time() - end)
            images, labels, _ = batch
            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass
            logits, logits_local, text_features = self.model(images)

            # Attention weights are best-effort: the loss is called with None
            # when extraction fails, but the failure is no longer hidden.
            try:
                attention_weights = self.extract_patch_attention(images)
            except Exception as exc:
                attention_weights = None
                if not self._attention_warning_shown:
                    print(f"Warning: attention extraction failed ({exc!r}); "
                          f"continuing without attention weights")
                    self._attention_warning_shown = True

            # Get image features for loss computation
            image_features, local_features = self.model.image_encoder(images.type(self.model.dtype))

            # Normalize features
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            local_features = local_features / local_features.norm(dim=-1, keepdim=True)

            # Compute NPT loss with momentum-based balancing
            loss_dict = self.model.compute_npt_loss(
                image_features, local_features, text_features, labels, attention_weights
            )

            loss = loss_dict['loss_total']
            loss.backward()

            # Optimizer step
            self.optimizer.step()

            # Calculate accuracy on class logits only
            acc = compute_accuracy(logits, labels)[0].item()

            # Update metrics
            batch_size = images.size(0)
            total_samples += batch_size
            running['loss'] += loss.item() * batch_size
            running['accuracy'] += acc * batch_size
            for term in loss_terms:
                running[term] += loss_dict[term].item() * batch_size
            for term in lambda_terms:
                running[term] += loss_dict[term] * batch_size

            batch_time.update(time.time() - end)
            end = time.time()

            # Calculate ETA over the remaining batches of this and later epochs
            nb_remain = (num_batches - batch_idx - 1) + (max_epoch - epoch - 1) * num_batches
            eta_seconds = batch_time.avg * nb_remain
            eta = str(datetime.timedelta(seconds=int(eta_seconds)))

            # Display progress
            if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == num_batches:
                print(
                    f"epoch [{epoch + 1}/{max_epoch}] "
                    f"batch [{batch_idx + 1}/{num_batches}] "
                    f"time {batch_time.val:.3f} ({batch_time.avg:.3f}) "
                    f"data {data_time.val:.3f} ({data_time.avg:.3f}) "
                    f"loss {loss.item():.4f} "
                    f"l_global {loss_dict['loss_global'].item():.4f} "
                    f"l_patch {loss_dict['loss_patch'].item():.4f} "
                    f"l_margin {loss_dict['loss_margin'].item():.4f} "
                    f"l_var {loss_dict['loss_var'].item():.4f} "
                    f"l_entropy {loss_dict['loss_entropy'].item():.4f} "
                    f"acc {acc:.2f} "
                    f"adap_λp {loss_dict['adaptive_lambda_patch']:.3f} "
                    f"adap_λm {loss_dict['adaptive_lambda_margin']:.3f} "
                    f"λvar {loss_dict['lambda_var']:.3f} "
                    f"λent {loss_dict['lambda_entropy']:.3f} "
                    f"lr {self.optimizer.param_groups[0]['lr']:.4e} "
                    f"eta {eta}"
                )

        self.scheduler.step()

        # Guard against an empty loader so the averages never divide by zero.
        denominator = max(total_samples, 1)
        return {name: total / denominator for name, total in running.items()}

    def train(self) -> None:
        """Train for ``cfg.OPTIM.MAX_EPOCH`` epochs, printing per-epoch metrics and final balancer stats."""
        print("Starting Improved NPT training with variance-aware attention regularization and entropy maximization...")
        max_epoch = self.cfg.OPTIM.MAX_EPOCH

        for epoch in range(max_epoch):
            metrics = self.train_epoch(epoch)
            print(f"Epoch [{epoch+1}/{max_epoch}] "
                  f"Loss: {metrics['loss']:.4f} "
                  f"Acc: {metrics['accuracy']:.2f} "
                  f"L_global: {metrics['loss_global']:.4f} "
                  f"L_patch: {metrics['loss_patch']:.4f} "
                  f"L_margin: {metrics['loss_margin']:.4f} "
                  f"L_var: {metrics['loss_var']:.4f} "
                  f"L_entropy: {metrics['loss_entropy']:.4f} "
                  f"Adapt_λp: {metrics['adaptive_lambda_patch']:.3f} "
                  f"Adapt_λm: {metrics['adaptive_lambda_margin']:.3f}")

        print("Improved NPT training with variance-aware attention regularization and entropy maximization completed")

        # Log final balancer stats
        balancer_stats = self.model.loss_balancer.get_stats()
        print(f"Final EMA stats: Global={balancer_stats['ema_global']:.4f}, "
              f"Patch={balancer_stats['ema_patch']:.4f}, "
              f"Margin={balancer_stats['ema_margin']:.4f}")

    def load_model(self, model_path: str) -> None:
        """Load prompt-learner weights from ``model_path`` into the current model.

        The checkpoint is mapped onto ``self.device`` so that weights saved on a
        GPU can also be restored on a CPU-only machine.
        """
        print(f"Loading NPT model from {model_path}")
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.prompt_learner.load_state_dict(state_dict)
        print("NPT model loaded successfully")

    def test_ood_detection(self, model: torch.nn.Module, data_loader: DataLoader, T: float = 1.0) -> np.ndarray:
        """Compute one MCM score per sample of ``data_loader``.

        The score is the negated maximum softmax probability of the global
        logits, after dividing them by 100 and by the temperature ``T``.

        Args:
            model: Model returning ``(logits, local_logits, extra)`` for a batch of images.
            data_loader: Loader whose batches start with ``(images, labels, ...)``.
            T: Softmax temperature.

        Returns:
            1-D array of scores, truncated to ``len(data_loader.dataset)``.
        """
        model.eval()
        mcm_score = []

        with torch.no_grad():
            for images, _labels, *_extra in tqdm(data_loader, desc="Testing OOD with Improved NPT"):
                images = images.to(self.device)
                # Use class logits only (exclude nuisance)
                output, _output_local, _ = model(images)

                smax_global = F.softmax((output / 100.0) / T, dim=-1).cpu().numpy()
                mcm_score.append(-np.max(smax_global, axis=1))

        return np.concatenate(mcm_score, axis=0)[:len(data_loader.dataset)].copy()

    def evaluate(self) -> None:
        """Score ImageNet and four OOD datasets with MCM, then save scores and metrics.

        Writes ``scores.npz`` and ``results.json`` into ``args.output_dir``. Note
        that this overwrites ``args.in_dataset`` and ``args.batch_size``.
        """
        print("Starting Improved NPT evaluation...")

        # Set model to evaluation mode
        self.model.eval()

        # Prepare data loader (only the preprocessing transform is used here)
        _, preprocess = clip_w_local.load(self.cfg.MODEL.BACKBONE.NAME)
        self.args.in_dataset = "imagenet"
        self.args.batch_size = 512

        id_data_loader = set_test_loader(self.args, "imagenet", preprocess)

        # Calculate in-distribution scores
        in_score_mcm = self.test_ood_detection(self.model, id_data_loader, 1)

        # Lists for evaluation
        auroc_list_mcm, fpr_list_mcm = [], []
        results_data = []

        # Evaluate out-of-distribution datasets
        out_datasets = ['iNaturalist', 'SUN', 'places365', 'Texture']

        scores_dict: Dict[str, Dict[str, np.ndarray]] = {"MCM": {"ImageNet": in_score_mcm}}

        for out_dataset in out_datasets:
            print(f"Evaluating OOD dataset: {out_dataset}")
            ood_loader = set_test_loader(self.args, out_dataset, preprocess)
            out_score_mcm = self.test_ood_detection(self.model, ood_loader, 1)

            # Evaluate MCM score
            print("Improved NPT MCM score")
            mcm_results = get_and_print_results(
                self.args, in_score_mcm, out_score_mcm,
                auroc_list_mcm, fpr_list_mcm
            )
            scores_dict["MCM"][out_dataset] = out_score_mcm

            # Save results
            results_data = add_results(results_data, mcm_results, out_dataset)

        # Add overall results
        results_data = add_overall_results(results_data, auroc_list_mcm, fpr_list_mcm)

        # Save scores to .npz
        # NOTE: the "MCM" entry is a nested dict, so NumPy stores it as a pickled
        # object array; reading it back needs np.load(..., allow_pickle=True).
        np.savez(os.path.join(self.args.output_dir, "scores.npz"), **scores_dict)

        # Save results to JSON
        save_results_to_json(results_data, self.args.output_dir, "results.json")
        print("Improved NPT evaluation completed")


def run_improved_npt_experiment(output_dir_path: str, momentum_beta: float = 0.9, adaptation_alpha: float = 0.1,
                                 min_weight_factor: float = 0.1, max_weight_factor: float = 3.0, 
                                 warmup_steps: int = 50, lambda_var: float = 0.1, lambda_entropy: float = 0.05,
                                 epsilon: float = 1e-8) -> None:
    """
    Configure, train and evaluate one improved NPT run.

    The dataset, trainer config, shot count and prompt settings are fixed inside
    this function; only the output directory and the loss hyperparameters are
    taken from the caller. Results are written below ``output_dir_path``.
    """
    # Fixed experiment settings (the first two are only echoed in the summary).
    dataset_name = "imagenet"
    trainer_cfg_name = "vit_b16_ep30"
    num_shots = 1
    n_ctx = 16
    class_specific_ctx = False
    class_token_position = "end"

    args = argparse.Namespace(
        # Core arguments matching the baseline
        root="/datasets/LoCoOp",
        seed=0,
        trainer="LoCoOp",
        dataset_config_file="configs/datasets/imagenet.yaml",
        config_file="configs/trainers/LoCoOp/vit_b16_ep30.yaml",
        output_dir=output_dir_path,
        # NPT-specific hyperparameters
        lambda_patch=0.25,
        lambda_margin=0.25,
        margin=0.2,
        method="improved_npt_variance_entropy",
        # Momentum-based loss balancing hyperparameters
        momentum_beta=momentum_beta,
        adaptation_alpha=adaptation_alpha,
        min_weight_factor=min_weight_factor,
        max_weight_factor=max_weight_factor,
        warmup_steps=warmup_steps,
        # Variance regularization / entropy maximization hyperparameters
        lambda_var=lambda_var,
        lambda_entropy=lambda_entropy,
        epsilon=epsilon,
        # Additional parameters
        resume="",
        backbone="",
        head="",
        eval_only=False,
        model_dir="",
        load_epoch=None,
        no_train=False,
        T=1.0,
        sample_size=500,
        # Config overrides applied after the YAML files are merged
        opts=[
            "TRAINER.LOCOOP.N_CTX", str(n_ctx),
            "TRAINER.LOCOOP.CSC", str(class_specific_ctx),
            "TRAINER.LOCOOP.CLASS_TOKEN_POSITION", class_token_position,
            "DATASET.NUM_SHOTS", str(num_shots),
        ],
    )

    summary = [
        ("Root", args.root),
        ("Dataset", dataset_name),
        ("Config", trainer_cfg_name),
        ("Shots", num_shots),
        ("Context tokens", n_ctx),
        ("Lambda patch", args.lambda_patch),
        ("Lambda margin", args.lambda_margin),
        ("Margin", args.margin),
        ("Method", args.method),
        ("Momentum beta", args.momentum_beta),
        ("Adaptation alpha", args.adaptation_alpha),
        ("Min weight factor", args.min_weight_factor),
        ("Max weight factor", args.max_weight_factor),
        ("Warmup steps", args.warmup_steps),
        ("Lambda var", args.lambda_var),
        ("Lambda entropy", args.lambda_entropy),
        ("Epsilon", args.epsilon),
        ("Output dir", args.output_dir),
    ]
    print("Running Improved NPT experiment with variance-aware attention regularization and entropy maximization:")
    for label, value in summary:
        print(f"  {label}: {value}")
    print()

    # Make sure the output directory exists before anything writes to it
    os.makedirs(args.output_dir, exist_ok=True)

    # Assemble the configuration: defaults, then YAML files, then overrides
    cfg = get_cfg_default()
    extend_cfg(cfg)
    for config_path in (args.dataset_config_file, args.config_file):
        if config_path:
            cfg.merge_from_file(config_path)

    reset_cfg(cfg, args)
    cfg.merge_from_list(args.opts)

    # NPT hyperparameters travel to the model through the config
    cfg.lambda_patch = args.lambda_patch
    cfg.lambda_margin = args.lambda_margin
    cfg.margin = args.margin

    cfg.freeze()

    # A negative seed leaves the random generators unseeded
    if cfg.SEED >= 0:
        print(f"Setting fixed seed: {cfg.SEED}")
        set_random_seed(cfg.SEED)

    setup_logger(cfg.OUTPUT_DIR)

    experiment = ImprovedNPTExperiment(cfg, args)
    experiment._validate_config()
    experiment.setup_data()
    experiment.setup_model()

    if args.eval_only:
        # Evaluation-only runs restore previously trained prompt weights
        if not args.model_dir:
            raise ValueError("--model-dir must be specified for eval-only mode")
        experiment.load_model(args.model_dir)
    else:
        experiment.train()

    experiment.evaluate()


def extend_cfg(cfg: CN) -> None:
    """Register the LoCoOp trainer options and the class-subsampling option on ``cfg``.

    Mutates ``cfg`` in place: ``cfg.TRAINER.LOCOOP`` is replaced by a fresh node
    holding the defaults below, and ``cfg.DATASET.SUBSAMPLE_CLASSES`` is set to "all".
    """
    locoop_defaults = (
        ("N_CTX", 16),
        ("CSC", False),
        ("CTX_INIT", ""),
        ("PREC", "fp16"),
        ("CLASS_TOKEN_POSITION", "end"),
    )

    # Fill the node before attaching it to the configuration tree.
    locoop_node = CN()
    for option, default in locoop_defaults:
        setattr(locoop_node, option, default)
    cfg.TRAINER.LOCOOP = locoop_node

    cfg.DATASET.SUBSAMPLE_CLASSES = "all"


def reset_cfg(cfg: CN, args: argparse.Namespace) -> None:
    """Copy command-line style overrides from ``args`` into ``cfg`` in place.

    String options (root, output_dir, resume, trainer, backbone, head) are
    applied only when non-empty. The seed is applied whenever it is not
    ``None``, so an explicit seed of ``0`` is honoured instead of being
    mistaken for "not provided".

    Args:
        cfg: Configuration node to update.
        args: Namespace providing ``root``, ``output_dir``, ``resume``,
            ``seed``, ``trainer``, ``backbone`` and ``head``.
    """
    if args.root:
        cfg.DATASET.ROOT = args.root
    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir
    if args.resume:
        cfg.RESUME = args.resume
    # A plain truthiness test would silently drop seed 0.
    if args.seed is not None:
        cfg.SEED = args.seed
    if args.trainer:
        cfg.TRAINER.NAME = args.trainer
    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone
    if args.head:
        cfg.MODEL.HEAD.NAME = args.head


def main():
    """Parse the command line and launch the improved NPT experiment."""
    parser = argparse.ArgumentParser(
        description="Improved NPT experiment with variance-aware attention regularization and entropy maximization"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Path to output directory where results will be saved"
    )

    # (flag, type, default, help) for every tunable hyperparameter, in the
    # order they should appear in --help: loss balancing first, then the
    # variance / entropy regularization options.
    hyperparameter_options = (
        ("--momentum-beta", float, 0.9,
         "Momentum factor for EMA smoothing (default: 0.9)"),
        ("--adaptation-alpha", float, 0.1,
         "Adaptation strength for weight adjustment (default: 0.1)"),
        ("--min-weight-factor", float, 0.1,
         "Minimum weight scaling factor (default: 0.1)"),
        ("--max-weight-factor", float, 3.0,
         "Maximum weight scaling factor (default: 3.0)"),
        ("--warmup-steps", int, 50,
         "Number of steps before applying adaptive weighting (default: 50)"),
        ("--lambda-var", float, 0.1,
         "Weight for variance regularization loss (default: 0.1)"),
        ("--lambda-entropy", float, 0.05,
         "Weight for entropy maximization loss (default: 0.05)"),
        ("--epsilon", float, 1e-8,
         "Small constant for numerical stability (default: 1e-8)"),
    )
    for flag, value_type, default_value, help_text in hyperparameter_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    cli = parser.parse_args()

    # Hand every parsed value to the experiment runner by name
    run_improved_npt_experiment(
        cli.output_dir,
        momentum_beta=cli.momentum_beta,
        adaptation_alpha=cli.adaptation_alpha,
        min_weight_factor=cli.min_weight_factor,
        max_weight_factor=cli.max_weight_factor,
        warmup_steps=cli.warmup_steps,
        lambda_var=cli.lambda_var,
        lambda_entropy=cli.lambda_entropy,
        epsilon=cli.epsilon,
    )


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()