"""
Synthetic Data Training Script for RLHF (Reinforcement Learning from Human Feedback)

This module implements a specialized training pipeline for fine-tuning language models using synthetic datasets
with various RLHF methods:
- DPO (Direct Preference Optimization)
- Single Model Training

The script is optimized for synthetic data training with single GPU and LoRA configurations.
Mixed precision and quantization are disabled for stability.
"""

import warnings
import os
import sys
import logging
import yaml
from importlib.metadata import PackageNotFoundError
import json
import torch
# Distributed training imports removed - single GPU only
from peft import LoraConfig, get_peft_model
from trainer_single import SingleModelTrainer

# Configure logging to suppress warnings: only ERROR-level (and above) records
# from the training libraries' loggers are emitted.
logging.getLogger("trl").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("peft").setLevel(logging.ERROR)

# Suppress specific warnings.
# NOTE(review): these targeted filters are redundant while the blanket
# simplefilter("ignore") below is in place; they only matter if it is removed.
warnings.filterwarnings("ignore", message=".*tokenizer.*deprecated.*")
warnings.filterwarnings("ignore", module="trl.*")
warnings.filterwarnings("ignore", module="transformers.*")
warnings.filterwarnings("ignore", module="peft.*")

# Disable all warnings process-wide. simplefilter() inserts its rule at the
# front of the filter list, so it matches every warning before the rules above.
warnings.simplefilter("ignore")

# Set PyTorch memory allocation settings for the CUDA caching allocator.
# This unconditionally overwrites any PYTORCH_CUDA_ALLOC_CONF already set in
# the environment.
# NOTE(review): presumably chosen to reduce fragmentation; it only takes effect
# if set before the CUDA allocator is first used — confirm nothing above
# touches CUDA.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import wandb
import argparse
from huggingface_hub import login as hf_login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments
)
from peft import LoraConfig, get_peft_model
from datasets import load_dataset


def load_config(config_path):
    """
    Load a YAML configuration file and flatten its optional ``shared`` section.

    Keys inside a top-level ``shared`` mapping are copied to the root level.
    On a name clash the ``shared`` value overwrites the root value. The
    ``shared`` key itself is removed. An empty file, or an empty ``shared:``
    section, is treated as an empty mapping instead of raising ``TypeError``.

    Args:
        config_path: Path to the YAML configuration file

    Returns:
        dict: Processed configuration dictionary with shared settings merged
    """
    # YAML streams are UTF-8 by spec; don't depend on the platform default encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document.
        config = yaml.safe_load(f)
    if config is None:
        config = {}

    # If config has a 'shared' section, merge it into the root level.
    if isinstance(config, dict) and 'shared' in config:
        # An empty "shared:" key loads as None; treat it as no overrides.
        shared_config = config.pop('shared') or {}
        config.update(shared_config)

    return config


def get_model_and_tokenizer(config):
    """
    Build the LoRA-wrapped causal language model and its tokenizer.

    The base model is loaded in float32 with no quantization and with
    ``device_map="auto"``. It is then wrapped with a PEFT LoRA adapter built
    from the ``lora`` config section. The tokenizer's pad token is set to its
    EOS token.

    Args:
        config: Configuration dictionary. Reads ``model.model_name``,
            ``model.trust_remote_code`` and the ``lora`` section (``r``,
            ``lora_alpha``, ``target_modules``, ``lora_dropout``, ``bias``,
            ``task_type``).

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped model.
    """
    model_section = config['model']
    checkpoint = model_section['model_name']

    base_model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        trust_remote_code=model_section['trust_remote_code'],
        torch_dtype=torch.float32,  # full precision: mixed precision is disabled
        device_map="auto",
    )
    print("Using regular LoRA for model training (quantization and mixed precision disabled)")

    # Copy the LoRA hyper-parameters straight from the config, key for key.
    lora_section = config['lora']
    lora_keys = ('r', 'lora_alpha', 'target_modules', 'lora_dropout', 'bias', 'task_type')
    adapter_config = LoraConfig(**{key: lora_section[key] for key in lora_keys})
    peft_model = get_peft_model(base_model, adapter_config)

    tokenizer = AutoTokenizer.from_pretrained(
        config['model']['model_name'],
        trust_remote_code=config['model']['trust_remote_code'],
    )
    # Reuse EOS for padding (overrides any pad token the tokenizer already had).
    tokenizer.pad_token = tokenizer.eos_token

    return peft_model, tokenizer


def format_dual_dataset(dataset):
    """
    Convert prompt/chosen/rejected preference records into x/y_w/y_l records.

    This is the record layout consumed by the custom preference trainer (in
    this script, ``SingleModelTrainer``). Records missing any of ``prompt``,
    ``chosen`` or ``rejected`` are skipped without notice.

    Args:
        dataset: Iterable of mapping-like records with 'prompt', 'chosen',
            and 'rejected' fields

    Returns:
        list: Dicts with 'x' (prompt), 'y_w' (chosen response) and 'y_l'
        (rejected response), in input order
    """
    required_fields = ('prompt', 'chosen', 'rejected')
    pairs = [
        {'x': record['prompt'], 'y_w': record['chosen'], 'y_l': record['rejected']}
        for record in dataset
        if all(field in record for field in required_fields)
    ]

    print(f"Formatted {len(pairs)} preference pairs for dual training")
    return pairs


def _synthetic_run_name(trainer_type):
    """Return a run name of the form ``synthetic_<trainer_type>_<YYYYmmdd_HHMMSS>`` (local time)."""
    from datetime import datetime
    return f"synthetic_{trainer_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"


def _build_dpo_trainer(model, tokenizer, train_dataset, config, trainer_type):
    """
    Create a TRL ``DPOTrainer`` for ``model`` from the ``dpo`` method config.

    Intermediate checkpoints are disabled (``save_strategy="no"``), so the
    caller must save the final model itself.
    """
    from trl import DPOConfig, DPOTrainer

    dpo_cfg = config['method_configs']['dpo']
    training_cfg = config['training']

    # The model is used directly (no DDP wrapper on a single GPU). use_cache is
    # turned off for training; the original rationale was gradient-checkpointing
    # compatibility, although checkpointing is currently disabled below.
    model.config.use_cache = False

    trl_config = DPOConfig(
        # float(): PyYAML loads exponent literals without a dot (e.g. 1e-5) as strings.
        learning_rate=float(dpo_cfg['training']['learning_rate']),
        per_device_train_batch_size=training_cfg['per_device_train_batch_size'],
        per_device_eval_batch_size=training_cfg['per_device_train_batch_size'],
        num_train_epochs=training_cfg['num_train_epochs'],
        logging_steps=training_cfg['logging_steps'],
        save_strategy="no",  # Disable checkpoint saving during training
        output_dir=os.path.join(training_cfg['output_dir'], f"synthetic_{trainer_type}"),
        run_name=_synthetic_run_name(trainer_type),
        beta=dpo_cfg['specific']['beta'],
        max_length=dpo_cfg['specific']['max_length'],
        max_prompt_length=dpo_cfg['specific']['max_prompt_length'],
        gradient_accumulation_steps=training_cfg['gradient_accumulation_steps'],
        fp16=False,  # Mixed precision disabled per user request
        optim="adamw_torch",
        gradient_checkpointing=False,  # Disabled per user request
        disable_tqdm=False,  # Enable tqdm for single GPU
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
        remove_unused_columns=False,  # Keep all columns for DPO
    )

    return DPOTrainer(
        model=model,
        args=trl_config,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )


def _build_single_trainer(model, tokenizer, train_dataset, config, trainer_type):
    """Create a ``SingleModelTrainer`` from the ``single`` method config."""
    formatted_dataset = format_dual_dataset(train_dataset)
    run_name = _synthetic_run_name(trainer_type)

    specific = config['method_configs']['single']['specific']
    training_cfg = config['training']

    # Single GPU: the per-device batch size is passed through unchanged.
    batch_size = training_cfg['per_device_train_batch_size']
    print(f"Debug: batch_size={batch_size}")

    trainer_config = {
        'mu_epochs': specific['mu_epochs'],
        'pi_epochs': specific['pi_epochs'],
        # float(): YAML may deliver learning rates as strings (see DPO branch).
        'mu_learning_rate': float(specific['mu_learning_rate']),
        'pi_learning_rate': float(specific['pi_learning_rate']),
        'batch_size': batch_size,
        'num_train_epochs': training_cfg['num_train_epochs'],
        'logging_steps': training_cfg['logging_steps'],
        'save_steps': training_cfg['save_steps'],
        'output_dir': os.path.join(training_cfg['output_dir'], f"synthetic_{trainer_type}"),
        'wandb_project': config['tracking']['wandb_project'],
        'wandb_run_name': run_name,
        'beta_pi': specific['beta_pi'],
        'beta': specific['beta'],
        'reference_model_id': specific['reference_model_id'],
        'trust_remote_code': config['model']['trust_remote_code'],
        'resume': specific.get('resume', {}),
    }

    return SingleModelTrainer(
        model=model,
        config=trainer_config,
        train_dataset=formatted_dataset,
        tokenizer=tokenizer,
    )


def train_synthetic_model(model, tokenizer, train_dataset, config, trainer_type):
    """
    Train a language model using synthetic data with the specified RLHF method.

    Args:
        model: Base (LoRA-wrapped) language model to be fine-tuned; used
            directly, without a DDP wrapper.
        tokenizer: Tokenizer for text processing
        train_dataset: Synthetic dataset with 'prompt', 'chosen' and
            'rejected' columns
        config: Configuration dictionary
        trainer_type: Type of trainer to use ('dpo' or 'single')

    Returns:
        Trainer instance after training completion

    Raises:
        ValueError: If ``trainer_type`` is not 'dpo' or 'single'.
    """
    print(f"\nTraining synthetic model with {trainer_type} method")
    print(f"Dataset size: {len(train_dataset)}")

    if trainer_type == 'dpo':
        trainer = _build_dpo_trainer(model, tokenizer, train_dataset, config, trainer_type)
    elif trainer_type == 'single':
        trainer = _build_single_trainer(model, tokenizer, train_dataset, config, trainer_type)
    else:
        raise ValueError(f"Unsupported trainer type: {trainer_type}")

    print(f"Starting synthetic training with {trainer_type}...")
    trainer.train()

    return trainer


def preprocess_synthetic_dataset(dataset, tokenizer):
    """
    Map synthetic x/y_w/y_l records to prompt/chosen/rejected records.

    Each prompt is suffixed with " Answer with one word only:" so that
    training prompts use the same wording as the evaluation prompts.

    Args:
        dataset: Synthetic dataset exposing ``.map``, whose rows carry 'x'
            (prompt), 'y_w' (preferred answer) and 'y_l' (rejected answer)
        tokenizer: Unused; accepted for interface compatibility (callers
            currently pass None)

    Returns:
        Dataset: Result of ``dataset.map`` with 'prompt', 'chosen' and
        'rejected' fields
    """
    instruction = "Answer with one word only:"

    def to_preference_record(row):
        return {
            "prompt": f"{row['x']} {instruction}",
            "chosen": row["y_w"],
            "rejected": row["y_l"],
        }

    return dataset.map(to_preference_record)


def save_synthetic_model(trainer, tokenizer, save_path, config, trainer_type):
    """
    Write the trained model, its tokenizer and a JSON training summary to disk.

    Output goes to ``<save_path>/synthetic_<trainer_type>``. For a
    ``SingleModelTrainer`` a further ``beta_<beta>`` sub-directory is appended,
    using ``method_configs.single.specific.beta``. This keeps runs with
    different beta values in separate directories. The directory is created
    if it does not exist.

    Args:
        trainer: Trained model trainer instance (exposes ``.model``)
        tokenizer: Tokenizer used for training
        save_path: Root directory under which artifacts are written
        config: Configuration dictionary used for training
        trainer_type: Type of trainer used ('dpo' or 'single')
    """
    target_dir = os.path.join(save_path, f"synthetic_{trainer_type}")
    if isinstance(trainer, SingleModelTrainer):
        single_beta = config['method_configs']['single']['specific']['beta']
        target_dir = os.path.join(target_dir, f"beta_{single_beta}")
    os.makedirs(target_dir, exist_ok=True)

    # Single-GPU training: trainer.model is not DDP-wrapped, so save it as-is.
    trainer.model.save_pretrained(target_dir)
    tokenizer.save_pretrained(target_dir)

    summary_path = os.path.join(target_dir, "synthetic_training_config.json")
    with open(summary_path, 'w') as handle:
        summary = {
            'model_config': config['model'],
            'lora_config': config['lora'],
            'training_config': config['training'],
            'method_config': config['method_configs'][trainer_type],
            'dataset_info': config['dataset'],
            'training_type': 'synthetic',
            'trainer_type': trainer_type,
        }
        json.dump(summary, handle, indent=2)

    print(f"Synthetic trained model and configuration saved to {target_dir}")


def _load_model_for_single_gpu(config, device):
    """
    Load the float32 base model, apply LoRA, load the tokenizer and move to ``device``.

    No ``device_map`` is passed (kept from the original code, whose note says
    this avoids DTensor issues). The model is moved to ``device`` explicitly.

    Returns:
        tuple: ``(model, tokenizer)``
    """
    print(f"Loading model: {config['model']['model_name']}")

    # Mixed precision and quantization are disabled: plain float32 weights.
    model = AutoModelForCausalLM.from_pretrained(
        config['model']['model_name'],
        trust_remote_code=config['model']['trust_remote_code'],
        torch_dtype=torch.float32,
    )

    lora_config = LoraConfig(
        r=config['lora']['r'],
        lora_alpha=config['lora']['lora_alpha'],
        target_modules=config['lora']['target_modules'],
        lora_dropout=config['lora']['lora_dropout'],
        bias=config['lora']['bias'],
        task_type=config['lora']['task_type'],
    )
    model = get_peft_model(model, lora_config)
    print("Applied LoRA to model for synthetic training")

    tokenizer = AutoTokenizer.from_pretrained(
        config['model']['model_name'],
        trust_remote_code=config['model']['trust_remote_code'],
    )
    tokenizer.pad_token = tokenizer.eos_token

    model = model.to(device)

    if hasattr(model, 'config'):
        model.config.use_cache = False

    # NOTE(review): leftover from the former DDP setup; presumably some trainer
    # code still calls model._set_static_graph(), so install a no-op if absent.
    if not hasattr(model, '_set_static_graph'):
        model._set_static_graph = lambda: None
        print("Added dummy _set_static_graph method to model")

    return model, tokenizer


def _release_cuda_memory():
    """Collect unreachable Python objects, then release cached CUDA blocks."""
    import gc

    # Run gc first: reference cycles (model <-> trainer) can keep tensors
    # alive after `del`, and empty_cache() cannot release memory still in use.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def main():
    """
    Main synthetic training script entry point.

    Parses ``--config``, ``--trainer`` and ``--dataset``, loads the synthetic
    dataset, then trains (and optionally saves) one model per requested method
    on a single device. GPU memory is released between methods.

    Raises:
        ValueError: If ``--dataset`` is not 'color' or 'fruit', or an unknown
            trainer type is requested.
    """
    parser = argparse.ArgumentParser(description="Train language models with synthetic data using TRL")
    parser.add_argument('--config', type=str, default='synthetic_config_train.yaml', help='Path to synthetic training config file')
    parser.add_argument('--trainer', type=str, default=None, help='Trainer type (dpo, single)')
    parser.add_argument('--dataset', type=str, default='color', help='Synthetic dataset to use (color, fruit)')
    args = parser.parse_args()

    # Single-device setup only (no distributed training).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using single GPU training for synthetic data on {device}")

    config = load_config(args.config)

    # Override the dataset path from the command line.
    dataset_paths = {
        'color': "datasets/synthetic/color_dataset.json",
        'fruit': "datasets/synthetic/fruit_dataset.json",
    }
    if args.dataset not in dataset_paths:
        raise ValueError(f"Unsupported synthetic dataset: {args.dataset}")
    # An empty "dataset:" section in YAML loads as None.
    if not config.get('dataset'):
        config['dataset'] = {}
    config['dataset']['train_path'] = dataset_paths[args.dataset]

    # A --trainer on the command line overrides the configured method list.
    methods_to_train = [args.trainer] if args.trainer else config['methods']

    # Route all outputs into a "synthetic" sub-directory.
    synthetic_output_dir = os.path.join(config['training']['output_dir'], "synthetic")
    config['training']['output_dir'] = synthetic_output_dir
    os.makedirs(synthetic_output_dir, exist_ok=True)

    if 'save_dir' in config['training']:
        synthetic_save_dir = os.path.join(config['training']['save_dir'], "synthetic")
        config['training']['save_dir'] = synthetic_save_dir
        os.makedirs(synthetic_save_dir, exist_ok=True)

    # Empty "auth:" / "tracking:" sections load as None; treat them as absent.
    auth_cfg = config.get('auth') or {}
    if 'hf_token' in auth_cfg:
        hf_login(token=auth_cfg['hf_token'])
        print("Logged in to Hugging Face")

    tracking_cfg = config.get('tracking') or {}
    if 'wandb_token' in tracking_cfg:
        wandb.login(key=tracking_cfg['wandb_token'])
        print("Logged in to Weights & Biases")

    print(f"Loading synthetic dataset from: {config['dataset']['train_path']}")
    dataset = load_dataset('json', data_files=config['dataset']['train_path'])["train"]

    if 'num_samples' in config['dataset']:
        num_samples = config['dataset']['num_samples']
        if len(dataset) > num_samples:
            print(f"Limiting synthetic dataset from {len(dataset)} to {num_samples} samples")
            dataset = dataset.select(range(num_samples))

    # The tokenizer is not needed for prompt formatting; it is loaded per method below.
    train_dataset = preprocess_synthetic_dataset(dataset, None)
    print(f"Synthetic training dataset size: {len(train_dataset)}")

    for trainer_type in methods_to_train:
        print(f"\n{'='*50}")
        print(f"Training with {trainer_type} method on synthetic {args.dataset} data")
        print(f"{'='*50}")

        # Fresh model per method so methods don't share fine-tuned weights.
        model, tokenizer = _load_model_for_single_gpu(config, device)

        print(f"Training {trainer_type} model on single GPU")
        trainer = train_synthetic_model(model, tokenizer, train_dataset, config, trainer_type)

        if 'save_dir' in config['training']:
            save_synthetic_model(trainer, tokenizer, config['training']['save_dir'], config, trainer_type)

        # Drop every reference before releasing memory; otherwise the weights
        # are still reachable and empty_cache() cannot free them.
        del trainer, model, tokenizer
        _release_cuda_memory()

    print(f"\nAll synthetic training completed for {args.dataset} dataset!")


if __name__ == "__main__":
    main()
