#!/usr/bin/env python3
"""
CSV Augmentation Script using Weak Expert and Strong Generalist Framework

This script reads a CSV file with an INPUT column and a label column (death, STAY_DAYS, or
READMIT_30D), augments the INPUT column using a biomedical NER model (weak expert) and a
language model (strong generalist), and saves the augmented data to a designated output folder.

Usage:
    python ours.py input_file.csv [--output_dir ./data] [--model_name Qwen/Qwen3-0.6B] [--method ours|cato|vanilla]
"""

import os
import argparse
import pandas as pd
import torch
from transformers import (
    AutoTokenizer, 
    pipeline, 
    BitsAndBytesConfig, 
    AutoModelForCausalLM, 
    AutoModelForTokenClassification
)
from tqdm import tqdm
import nltk
from nltk.tokenize import sent_tokenize
import logging
import json
import hashlib
from pathlib import Path

# Root logging setup: INFO level, one "timestamp - level - message" line per record.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

class ClinicalNoteAugmenter:
    """
    Clinical note augmenter using weak expert (biomedical NER) and strong generalist (LLM) models.

    The weak expert (``d4data/biomedical-ner-all``) extracts medical entities from a note. The
    strong generalist (a causal LM loaded in 4-bit) rewrites the note, optionally guided by those
    entities. ``augment_csv_file`` applies one rewrite method to every row of a CSV file and
    periodically saves resumable progress.
    """

    # NER entity groups that are kept as "medical keywords"; all other groups are dropped.
    NER_ENTITY_GROUPS = frozenset({
        'Disease_disorder', 'Sign_symptom', 'Lab_value', 'Height', 'Severity',
        'Dosage', 'Frequency', 'Medication', 'Duration', 'Diagnostic_procedure',
        'Therapeutic_procedure', 'Biological_structure', 'Biological_attribute',
        'History', 'Family_history', 'Age', 'Sex', 'Clinical_event'
    })

    # Generation budget used in thinking mode (overrides the caller's max_new_tokens).
    THINKING_MAX_NEW_TOKENS = 32768

    # Hard-coded id for the "</think>" token, used when the tokenizer cannot resolve it.
    END_THINK_FALLBACK_TOKEN_ID = 151668

    # Augmentation methods accepted by augment_csv_file.
    SUPPORTED_METHODS = ("ours", "cato", "vanilla")

    # Label columns recognised in a previously written output file, in priority order.
    _EXISTING_LABEL_COLUMNS = ('death', 'STAY_DAYS', 'readmit_30D', 'READMIT_30D')

    _OURS_SYSTEM_PROMPT_WITH_ENTITIES = """You are a medical AI assistant with expertise in clinical documentation. Your task is to rewrite clinical notes while maintaining complete medical accuracy.

IMPORTANT INSTRUCTIONS:
- You must preserve all medical entities exactly as they appear
- Do NOT list or enumerate the entities - simply incorporate them naturally into your rewritten text
- You may change sentence structure, word choice, and writing style
- You must NOT change any medical terminology, dosages, measurements, or clinical findings
- Ensure the rewritten note contains the same medical information as the original"""

    _OURS_SYSTEM_PROMPT_NO_ENTITIES = """You are a medical AI assistant with expertise in clinical documentation. Your task is to rewrite clinical notes while maintaining complete medical accuracy.

IMPORTANT INSTRUCTIONS:
- Preserve all medical facts, diagnoses, medications, procedures, and clinical findings
- You may change sentence structure, word choice, and writing style
- You must NOT change any medical terminology, dosages, measurements, or clinical findings
- Ensure the rewritten note contains the same medical information as the original"""

    _NAIVE_SYSTEM_PROMPT = (
        "You are an AI assistant. Your task is to rewrite the given clinical note in a different "
        "writing style while maintaining the same medical information and clinical accuracy. "
        "You can change sentence structure, word choice, and writing flow, but preserve all "
        "medical facts and details."
    )

    _VANILLA_SYSTEM_PROMPT = "Rewrite the input."

    def __init__(self, model_name="Qwen/Qwen3-0.6B", huggingface_token=None, use_thinking=False):
        """
        Initialize the augmenter with weak expert and strong generalist models.

        Args:
            model_name (str): Name of the strong generalist model
            huggingface_token (str): HuggingFace token for accessing models; falls back to the
                HUGGING_FACE_TOKEN environment variable
            use_thinking (bool): Enable chat-template thinking mode; generation then uses
                THINKING_MAX_NEW_TOKENS and the reasoning part of the output is discarded
        """
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.huggingface_token = huggingface_token or os.environ.get('HUGGING_FACE_TOKEN')
        self.use_thinking = bool(use_thinking)

        if not self.huggingface_token:
            logger.warning("No HuggingFace token provided. Some models may not be accessible.")

        logger.info(f"Device set to use {self.device}")

        # Initialize weak expert model (biomedical NER)
        self._setup_weak_expert()

        # Initialize strong generalist model
        self._setup_strong_generalist(model_name)

        # Download required NLTK data (sentence tokenizer used by weak_expert_ner)
        try:
            nltk.download('punkt_tab', quiet=True)
        except Exception as e:
            logger.warning(f"Could not download NLTK data: {e}")

    def _setup_weak_expert(self):
        """Initialize the weak expert (biomedical NER) model and its aggregation pipeline."""
        logger.info("Setting up weak expert (biomedical NER) model...")

        weak_expert_model_name = "d4data/biomedical-ner-all"

        try:
            self.weak_tokenizer = AutoTokenizer.from_pretrained(weak_expert_model_name)
            self.weak_model = AutoModelForTokenClassification.from_pretrained(weak_expert_model_name)

            self.ner_classifier = pipeline(
                task="ner",
                model=self.weak_model,
                tokenizer=self.weak_tokenizer,
                aggregation_strategy="simple"
            )
            logger.info("Weak expert model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load weak expert model: {e}")
            raise

    def _setup_strong_generalist(self, model_name):
        """Initialize the strong generalist (LLM) model in 4-bit NF4 quantization."""
        logger.info(f"Setting up strong generalist model: {model_name}")

        try:
            # Configure quantization for memory efficiency
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )

            # `token` replaces the deprecated `use_auth_token` argument of from_pretrained.
            self.strong_tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=self.huggingface_token,
                trust_remote_code=True
            )
            self.strong_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=self.huggingface_token,
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=True
            )

            # Not used by the rewrite methods (they call strong_model.generate directly); kept as a
            # public attribute for callers that may rely on it.
            self.augmentation_pipeline = pipeline(
                "text-generation",
                model=self.strong_model,
                tokenizer=self.strong_tokenizer,
                torch_dtype=torch.bfloat16,
                device_map="auto"
            )
            logger.info("Strong generalist model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load strong generalist model: {e}")
            raise

    def _get_checkpoint_path(self, output_path):
        """Return ``<output_dir>/<output_stem>_checkpoint.json`` for the given output CSV path."""
        output_dir = os.path.dirname(output_path)
        base_name = os.path.splitext(os.path.basename(output_path))[0]
        return os.path.join(output_dir, f"{base_name}_checkpoint.json")

    def _load_checkpoint(self, checkpoint_path):
        """Load checkpoint data if it exists.

        Returns:
            dict: The checkpoint contents, or ``{}`` when the file is missing, unreadable,
            or not a JSON object.
        """
        if not os.path.exists(checkpoint_path):
            return {}
        try:
            with open(checkpoint_path, 'r', encoding='utf-8') as f:
                checkpoint_data = json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load checkpoint: {e}")
            return {}
        if not isinstance(checkpoint_data, dict):
            logger.warning("Ignoring checkpoint with unexpected format")
            return {}
        processed = checkpoint_data.get('processed_hashes', [])
        processed_count = len(processed) if isinstance(processed, list) else 0
        logger.info(f"Found checkpoint with {processed_count} processed notes")
        return checkpoint_data

    def _save_checkpoint(self, checkpoint_path, checkpoint_data):
        """Save checkpoint data as JSON (best effort: failures are logged, not raised).

        The data is written to a temporary file and then moved into place, so an interrupted
        write cannot leave a truncated checkpoint behind.
        """
        tmp_path = f"{checkpoint_path}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(checkpoint_data, f, indent=2)
            os.replace(tmp_path, checkpoint_path)
        except Exception as e:
            logger.warning(f"Failed to save checkpoint: {e}")

    @staticmethod
    def _write_csv_atomic(frame, path):
        """Write *frame* to *path* (no index) via a temp file plus ``os.replace``."""
        tmp_path = f"{path}.tmp"
        frame.to_csv(tmp_path, index=False)
        os.replace(tmp_path, path)

    def _get_note_hash(self, note):
        """Return the MD5 hex digest of the note's string form.

        ``str()`` is applied because pandas reads some notes back from CSV as non-strings
        (e.g. ``nan`` or numeric-looking text); this keeps the hash consistent with the
        ``str(row['INPUT'])`` used when the note was first processed.
        """
        return hashlib.md5(str(note).encode()).hexdigest()

    def weak_expert_ner(self, text):
        """
        Extract medical entities using the weak expert (biomedical NER) model.

        Each sentence is truncated to the NER tokenizer's max length (512 if unset or
        implausible) before prediction. Sentences that raise are logged and skipped.

        Args:
            text (str): Input text to extract entities from

        Returns:
            list: Unique ``"<entity_group>:<word>"`` strings, in first-seen order
        """
        model_max_length = self.ner_classifier.tokenizer.model_max_length
        if model_max_length is None or model_max_length <= 0 or model_max_length > 10000:
            max_length = 512
        else:
            max_length = model_max_length

        all_entities = []
        for sentence in sent_tokenize(text):
            try:
                encoded = self.ner_classifier.tokenizer.encode(
                    sentence,
                    add_special_tokens=True,
                    max_length=max_length,
                    truncation=True
                )
                truncated_sentence = self.ner_classifier.tokenizer.decode(
                    encoded,
                    skip_special_tokens=True
                )

                for pred in self.ner_classifier(truncated_sentence):
                    entity_group = pred['entity_group']
                    if entity_group in self.NER_ENTITY_GROUPS:
                        all_entities.append(f"{entity_group}:{pred['word']}")
            except Exception as e:
                logger.warning(f"Error processing sentence: {e}")

        # Order-preserving dedup: set() ordering varies between processes (string hash
        # randomisation), which would make the generated prompts non-reproducible.
        return list(dict.fromkeys(all_entities))

    @staticmethod
    def _extract_entity_names(medical_keywords):
        """Strip the ``"<group>:"`` prefix from each keyword and dedupe, keeping order.

        Only the first colon is treated as the separator, so entity text that itself contains
        colons (e.g. ``"Lab_value:10:30"``) is kept whole.
        """
        names = (kw.split(':', 1)[1] if ':' in kw else kw for kw in medical_keywords)
        return list(dict.fromkeys(names))

    @classmethod
    def _build_ours_prompts(cls, note, medical_keywords):
        """Return the ``(system_content, user_content)`` pair for the "ours" method.

        When entities were found, they are listed in the user message as terms to preserve;
        otherwise a generic preserve-all-medical-information prompt is used.
        """
        if not medical_keywords:
            user_content = f"""Please rewrite this clinical note while preserving all medical information:

{note}"""
            return cls._OURS_SYSTEM_PROMPT_NO_ENTITIES, user_content

        entities_text = ', '.join(cls._extract_entity_names(medical_keywords))
        user_content = f"""The following clinical note contains these important medical entities that must be preserved exactly: {entities_text}

Please rewrite this clinical note while naturally incorporating all medical entities. Do not list the entities separately - include them naturally in your rewrite:

{note}"""
        return cls._OURS_SYSTEM_PROMPT_WITH_ENTITIES, user_content

    @staticmethod
    def _thinking_content_start(output_ids, end_think_id):
        """Return the index just past the last *end_think_id* in *output_ids* (0 if absent)."""
        try:
            return len(output_ids) - output_ids[::-1].index(end_think_id)
        except ValueError:
            return 0

    def _end_think_token_id(self):
        """Resolve the ``</think>`` token id, falling back to END_THINK_FALLBACK_TOKEN_ID."""
        token_id = self.strong_tokenizer.convert_tokens_to_ids("</think>")
        unk_id = getattr(self.strong_tokenizer, "unk_token_id", None)
        # Some tokenizers map unknown tokens to unk_token_id instead of returning None.
        if token_id is None or (unk_id is not None and token_id == unk_id):
            return self.END_THINK_FALLBACK_TOKEN_ID
        return token_id

    def _generate_rewrite(self, system_content, user_content, max_new_tokens, num_beams, temperature):
        """Run the strong generalist on a system/user chat and return the decoded reply.

        In thinking mode, generation uses THINKING_MAX_NEW_TOKENS with the model's default
        generation settings, and everything up to the last ``</think>`` token is dropped.
        Otherwise the reply is sampled with the given ``max_new_tokens``, ``num_beams`` and
        ``temperature``.
        """
        messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": user_content}
        ]
        text = self.strong_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=self.use_thinking
        )
        model_inputs = self.strong_tokenizer([text], return_tensors="pt").to(self.strong_model.device)

        if self.use_thinking:
            generated_ids = self.strong_model.generate(
                **model_inputs,
                max_new_tokens=self.THINKING_MAX_NEW_TOKENS
            )
        else:
            generated_ids = self.strong_model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                do_sample=True,
                temperature=temperature,
                pad_token_id=self.strong_tokenizer.eos_token_id
            )

        # generate() returns prompt + continuation; keep only the newly generated tokens.
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
        if self.use_thinking:
            output_ids = output_ids[self._thinking_content_start(output_ids, self._end_think_token_id()):]
        return self.strong_tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")

    def rewrite_clinical_note_ours(self, note, max_new_tokens=5000, num_beams=1, temperature=0.5):
        """
        Rewrite a clinical note using the weak expert and strong generalist framework.

        Args:
            note (str): Original clinical note
            max_new_tokens (int): Maximum new tokens to generate (ignored in thinking mode)
            num_beams (int): Number of beams for generation (ignored in thinking mode)
            temperature (float): Temperature for generation (ignored in thinking mode)

        Returns:
            str: Augmented clinical note, or the original note if any step fails
        """
        try:
            # Step 1: extract medical keywords with the weak expert.
            medical_keywords = self.weak_expert_ner(note)
            # Step 2: build entity-aware prompts; Step 3: generate with the strong generalist.
            system_content, user_content = self._build_ours_prompts(note, medical_keywords)
            return self._generate_rewrite(system_content, user_content, max_new_tokens, num_beams, temperature)
        except Exception as e:
            logger.error(f"Error rewriting clinical note: {e}")
            return note  # Return original note if augmentation fails

    def rewrite_clinical_note_naive(self, note, max_new_tokens=5000, num_beams=1, temperature=0.5):
        """
        Naive baseline: rewrite the clinical note in a different style without NER guidance.

        Args:
            note (str): Original clinical note
            max_new_tokens (int): Maximum new tokens to generate (ignored in thinking mode)
            num_beams (int): Number of beams for generation (ignored in thinking mode)
            temperature (float): Temperature for generation (ignored in thinking mode)

        Returns:
            str: Augmented clinical note, or the original note if generation fails
        """
        try:
            user_content = f"Please rewrite this clinical note in a different style:\n\n{note}"
            return self._generate_rewrite(self._NAIVE_SYSTEM_PROMPT, user_content, max_new_tokens, num_beams, temperature)
        except Exception as e:
            logger.error(f"Error in naive rewriting: {e}")
            return note  # Return original note if augmentation fails

    def rewrite_clinical_note_vanilla(self, note, max_new_tokens=5000, num_beams=1, temperature=0.5):
        """
        Vanilla baseline: minimal "Rewrite the input." system prompt with the raw note as input.

        Returns:
            str: Augmented clinical note, or the original note if generation fails
        """
        try:
            return self._generate_rewrite(self._VANILLA_SYSTEM_PROMPT, note, max_new_tokens, num_beams, temperature)
        except Exception as e:
            logger.error(f"Error in vanilla rewriting: {e}")
            return note

    @staticmethod
    def _resolve_label_column(df):
        """Return the first label column present in *df* (death > STAY_DAYS > READMIT_30D).

        Raises:
            ValueError: If none of the supported label columns is present.
        """
        for column in ('death', 'STAY_DAYS', 'READMIT_30D'):
            if column in df.columns:
                return column
        raise ValueError("Missing required label column: expected death, STAY_DAYS, or READMIT_30D")

    def _load_existing_output(self, output_path, method):
        """Load a previous output CSV for resuming.

        Returns:
            tuple: ``(existing_rows, processed_hashes)``. Here ``processed_hashes`` maps each
            original-note hash to its row index. ``([], {})`` is returned if the file cannot
            be read.
        """
        existing_data = []
        processed_hashes = {}
        try:
            existing_df = pd.read_csv(output_path)
            logger.info(f"Found existing output file with {len(existing_df)} rows")

            # Support legacy label columns as well as the 'death' column written by this class.
            label_source = next(
                (col for col in self._EXISTING_LABEL_COLUMNS if col in existing_df.columns), None
            )
            for _, row in existing_df.iterrows():
                existing_data.append({
                    'original_note': row['original_note'],
                    'augmented_note': row['augmented_note'],
                    'death': row[label_source] if label_source is not None else None,
                    'method': method
                })
                processed_hashes[self._get_note_hash(row['original_note'])] = len(existing_data) - 1

            logger.info(f"Loaded {len(existing_data)} existing augmented notes")
        except Exception as e:
            logger.warning(f"Failed to load existing output file: {e}")
            return [], {}
        return existing_data, processed_hashes

    def _save_progress(self, checkpoint_path, output_path, augmented_notes, checkpoint_data):
        """Persist the checkpoint and the intermediate CSV (best effort; failures are logged)."""
        self._save_checkpoint(checkpoint_path, checkpoint_data)
        try:
            # Overwrites the final output path with the latest progress.
            self._write_csv_atomic(pd.DataFrame(augmented_notes), output_path)
            logger.info(f"Saved intermediate CSV progress to {output_path}")
        except Exception as e:
            logger.warning(f"Failed to save intermediate CSV progress: {e}")

    def augment_csv_file(self, input_file, output_dir="./data", output_filename=None, method="ours",
                        checkpoint_interval=100, resume=True):
        """
        Augment a CSV file containing clinical notes with checkpointing support.

        Notes whose INPUT text has already been processed are skipped. This covers notes found
        in an existing output file when resuming, and duplicates within the input. The output
        may therefore have fewer rows than the input.

        Args:
            input_file (str): Path to input CSV file (needs INPUT and a death, STAY_DAYS, or
                READMIT_30D column)
            output_dir (str): Directory to save augmented data (created if missing)
            output_filename (str): Output file name; defaults to ``<input_stem>_<method>.csv``
            method (str): Augmentation method - "ours", "cato" (naive prompt), or "vanilla"
            checkpoint_interval (int): Save checkpoint and intermediate CSV every N processed
                notes; values below 1 (or None) disable periodic saving
            resume (bool): Whether to resume from an existing output file and save checkpoints

        Returns:
            str: Path to the augmented CSV file

        Raises:
            ValueError: If the INPUT/label column is missing or *method* is unsupported.
        """
        os.makedirs(output_dir, exist_ok=True)

        try:
            df = pd.read_csv(input_file)
            logger.info(f"Loaded CSV with {len(df)} rows")
        except Exception as e:
            logger.error(f"Failed to read CSV file: {e}")
            raise

        if 'INPUT' not in df.columns:
            raise ValueError("Missing required column: INPUT")
        label_column = self._resolve_label_column(df)

        if method not in self.SUPPORTED_METHODS:
            raise ValueError(f"Method must be 'ours', 'cato', or 'vanilla', got: {method}")
        logger.info(f"Using augmentation method: {method}")

        rewriters = {
            "ours": self.rewrite_clinical_note_ours,
            "cato": self.rewrite_clinical_note_naive,
            "vanilla": self.rewrite_clinical_note_vanilla,
        }
        rewrite = rewriters[method]

        if output_filename is None:
            base_name = os.path.splitext(os.path.basename(input_file))[0]
            output_filename = f"{base_name}_{method}.csv"
        output_path = os.path.join(output_dir, output_filename)
        checkpoint_path = self._get_checkpoint_path(output_path)

        if resume and os.path.exists(output_path):
            existing_data, processed_hashes = self._load_existing_output(output_path, method)
        else:
            existing_data, processed_hashes = [], {}

        if resume:
            # Loaded for progress reporting only; skipping is driven by the output CSV.
            checkpoint_data = self._load_checkpoint(checkpoint_path)
            logger.info(f"Loaded checkpoint with {len(checkpoint_data)} entries")

        save_periodically = bool(resume and checkpoint_interval is not None and checkpoint_interval > 0)
        if resume and not save_periodically:
            logger.warning(
                f"checkpoint_interval={checkpoint_interval!r} is not positive; "
                "progress will only be saved at the end"
            )

        augmented_notes = list(existing_data)
        total_processed = len(augmented_notes)
        logger.info(f"Starting augmentation. {total_processed} notes already processed.")

        for idx, row in tqdm(df.iterrows(), total=len(df), desc=f"Augmenting notes using {method} method"):
            original_note = str(row['INPUT'])
            death_label = row[label_column]
            note_hash = self._get_note_hash(original_note)

            if note_hash in processed_hashes:
                logger.debug(f"Note {idx} already processed, skipping")
                continue

            # Only the rewrite is guarded: a failure here keeps the original note, and
            # checkpoint-saving errors can no longer produce a duplicate fallback row.
            try:
                augmented_note = rewrite(original_note)
            except Exception as e:
                logger.error(f"Error processing row {idx}: {e}")
                augmented_note = original_note

            # NOTE(review): the label is always stored under 'death', whatever the source column;
            # downstream readers presumably rely on this name — confirm before renaming.
            augmented_notes.append({
                'original_note': original_note,
                'augmented_note': augmented_note,
                'death': death_label,
                'method': method
            })
            processed_hashes[note_hash] = len(augmented_notes) - 1
            total_processed += 1

            if save_periodically and total_processed % checkpoint_interval == 0:
                self._save_progress(checkpoint_path, output_path, augmented_notes, {
                    'total_processed': total_processed,
                    'method': method,
                    'input_file': input_file,
                    'processed_hashes': list(processed_hashes.keys())
                })

        output_df = pd.DataFrame(augmented_notes)

        try:
            self._write_csv_atomic(output_df, output_path)
            logger.info(f"Saved final augmented data to: {output_path}")

            if resume:
                self._save_checkpoint(checkpoint_path, {
                    'total_processed': total_processed,
                    'method': method,
                    'input_file': input_file,
                    'output_file': output_path,
                    'processed_hashes': list(processed_hashes.keys()),
                    'completed': True
                })
        except Exception as e:
            logger.error(f"Failed to save augmented data: {e}")
            raise

        return output_path


def main():
    """Run the CSV augmentation script from the command line.

    Returns:
        int: Exit status - 0 on success, 1 when the input file is missing or augmentation fails.
    """
    def positive_int(value):
        """argparse type: parse *value* as an integer >= 1."""
        number = int(value)
        if number < 1:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
        return number

    parser = argparse.ArgumentParser(
        description="Augment CSV files using weak expert and strong generalist framework"
    )
    parser.add_argument(
        "input_file",
        help="Path to input CSV file with an INPUT column and a death, STAY_DAYS, or READMIT_30D label column"
    )
    parser.add_argument(
        "--output_dir",
        default="./data",
        help="Output directory for augmented data (default: ./data)"
    )
    parser.add_argument(
        "--output_filename",
        help="Output filename (optional, will auto-generate if not provided)"
    )
    parser.add_argument(
        "--model_name",
        default="Qwen/Qwen3-0.6B",
        help="Strong generalist model name (default: Qwen/Qwen3-0.6B)"
    )
    parser.add_argument(
        "--huggingface_token",
        help="HuggingFace token for accessing models"
    )
    parser.add_argument(
        "--use_thinking",
        action="store_true",
        help="Enable Qwen thinking mode (default: off). When enabled, max_new_tokens=32768"
    )
    parser.add_argument(
        "--method",
        choices=["ours", "cato", "vanilla"],
        default="ours",
        help="Augmentation method: 'ours' (weak expert + strong generalist), 'cato' (current naive), or 'vanilla' (simple 'rewrite the input')"
    )
    parser.add_argument(
        "--checkpoint_interval",
        type=positive_int,
        default=1,
        help="Save checkpoint every N processed notes (default: 1)"
    )
    parser.add_argument(
        "--no_resume",
        action="store_true",
        help="Disable resume from checkpoint functionality"
    )

    args = parser.parse_args()

    # Validate input file (a directory path would only fail later inside read_csv).
    if not os.path.isfile(args.input_file):
        logger.error(f"Input file not found: {args.input_file}")
        return 1

    try:
        logger.info("Initializing Clinical Note Augmenter...")
        augmenter = ClinicalNoteAugmenter(
            model_name=args.model_name,
            huggingface_token=args.huggingface_token,
            use_thinking=args.use_thinking
        )

        output_path = augmenter.augment_csv_file(
            input_file=args.input_file,
            output_dir=args.output_dir,
            output_filename=args.output_filename,
            method=args.method,
            checkpoint_interval=args.checkpoint_interval,
            resume=not args.no_resume
        )

        logger.info(f"Augmentation completed successfully using {args.method} method!")
        logger.info(f"Output saved to: {output_path}")

        return 0

    except Exception as e:
        # logger.exception keeps the traceback, which is otherwise lost at this boundary.
        logger.exception(f"Augmentation failed: {e}")
        return 1


if __name__ == "__main__":
    # `exit` is a site-module convenience (missing under `python -S` / frozen builds);
    # raising SystemExit directly propagates main()'s status code portably.
    raise SystemExit(main())
