import warnings
import os
import sys
import logging
import yaml
from importlib.metadata import PackageNotFoundError
import json
import torch
# Distributed training imports removed - single GPU only
from peft import LoraConfig, get_peft_model
from trainer_single import SingleModelTrainer

# Raise the log threshold of the training libraries so only ERROR and above are emitted
logging.getLogger("trl").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("peft").setLevel(logging.ERROR)

# Targeted warning filters: tokenizer deprecation messages and anything raised
# from the trl / transformers / peft modules
warnings.filterwarnings("ignore", message=".*tokenizer.*deprecated.*")
warnings.filterwarnings("ignore", module="trl.*")
warnings.filterwarnings("ignore", module="transformers.*")
warnings.filterwarnings("ignore", module="peft.*")

# Blanket filter: ignore every warning from every module.
# NOTE: because this matches everything, it supersedes the narrower filters above.
# Filters only affect warnings raised after this point, so the imports higher up
# in the file (torch, peft, trainer_single) ran before any filter was installed.
warnings.simplefilter("ignore")

# Ask the PyTorch CUDA caching allocator to use expandable segments.
# NOTE(review): torch is already imported above; confirm the allocator still
# reads this variable when it is set after the import.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import wandb
import argparse
from huggingface_hub import login as hf_login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments
)
from peft import LoraConfig, get_peft_model
from datasets import load_dataset


def load_config(config_path):
    """
    Load and process configuration from YAML file.

    Keys found under an optional top-level ``shared`` section are lifted to
    the root of the returned dictionary. The merge is shallow: when a key
    exists both at the root and under ``shared``, the ``shared`` value wins.
    An empty ``shared:`` entry is treated as if the section were absent.

    Args:
        config_path: Path to the YAML configuration file

    Returns:
        dict: Processed configuration dictionary with shared settings merged

    Raises:
        ValueError: If the file is empty or its top level is not a mapping
    """
    # Read explicitly as UTF-8 so parsing does not depend on the platform locale
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # safe_load yields None for an empty document and may yield a list or a
    # scalar; none of those can be used as a configuration mapping
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {config_path!r} must contain a YAML mapping at the "
            f"top level, got {type(config).__name__}"
        )

    # If config has 'shared' section, merge it with the root level
    shared_config = config.pop('shared', None)
    if shared_config is not None:
        config.update(shared_config)

    return config


def get_model_and_tokenizer(config):
    """
    Build the LoRA-wrapped causal LM and its tokenizer from the configuration.

    The base model is loaded in float32 with ``device_map="auto"`` and no
    quantization, then wrapped with a LoRA adapter built from ``config['lora']``.
    The tokenizer's pad token is set to its EOS token.

    Args:
        config: Configuration dictionary; reads the 'model' and 'lora' sections

    Returns:
        tuple: (model, tokenizer)
    """
    model_settings = config['model']
    allow_remote_code = model_settings['trust_remote_code']

    # Full-precision load: no mixed precision, no quantization
    model = AutoModelForCausalLM.from_pretrained(
        model_settings['model_name'],
        trust_remote_code=allow_remote_code,
        torch_dtype=torch.float32,
        device_map="auto",
    )
    print("Using regular LoRA for model training (quantization and mixed precision disabled)")

    # Pull the LoRA hyperparameters straight from the 'lora' section
    lora_settings = config['lora']
    adapter_fields = ('r', 'lora_alpha', 'target_modules', 'lora_dropout', 'bias', 'task_type')
    adapter_config = LoraConfig(**{field: lora_settings[field] for field in adapter_fields})

    # Wrap the base model with the LoRA adapter
    model = get_peft_model(model, adapter_config)

    tokenizer = AutoTokenizer.from_pretrained(
        model_settings['model_name'],
        trust_remote_code=allow_remote_code,
    )
    # Reuse the EOS token for padding
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def format_dual_dataset(dataset):
    """
    Convert preference records into the 'x' / 'y_w' / 'y_l' layout.

    Records that lack any of the 'prompt', 'chosen' or 'rejected' fields are
    skipped. The number of converted pairs is printed.

    Args:
        dataset: Iterable of records with 'prompt', 'chosen', and 'rejected' fields

    Returns:
        list: One dict per usable record with 'x', 'y_w', and 'y_l' fields
    """
    required_fields = ('prompt', 'chosen', 'rejected')

    preference_pairs = [
        {
            'x': record['prompt'],
            'y_w': record['chosen'],
            'y_l': record['rejected'],
        }
        for record in dataset
        if all(field in record for field in required_fields)
    ]

    print(f"Formatted {len(preference_pairs)} preference pairs for dual training")
    return preference_pairs


def _build_dpo_trainer(model, tokenizer, train_dataset, config, run_name, output_dir):
    """Build a TRL DPOTrainer from the 'training' and 'method_configs.dpo' settings."""
    # Imported here so trl is only required when the DPO path is actually taken
    from trl import DPOConfig, DPOTrainer

    # Switch off the KV cache on the model config. This setting is usually
    # paired with gradient checkpointing, which is disabled below, so it is
    # kept here only as a precaution.
    model.config.use_cache = False

    training_cfg = config['training']
    dpo_cfg = config['method_configs']['dpo']

    trl_config = DPOConfig(
        # float() in case the YAML value was parsed as a string
        learning_rate=float(dpo_cfg['training']['learning_rate']),
        per_device_train_batch_size=training_cfg['per_device_train_batch_size'],
        # The eval batch size reuses the train batch size; no separate setting is read
        per_device_eval_batch_size=training_cfg['per_device_train_batch_size'],
        num_train_epochs=training_cfg['num_train_epochs'],
        logging_steps=training_cfg['logging_steps'],
        save_strategy="no",  # Disable checkpoint saving during training
        output_dir=output_dir,
        run_name=run_name,
        beta=dpo_cfg['specific']['beta'],
        max_length=dpo_cfg['specific']['max_length'],
        max_prompt_length=dpo_cfg['specific']['max_prompt_length'],
        gradient_accumulation_steps=training_cfg['gradient_accumulation_steps'],
        fp16=False,  # Mixed precision disabled
        optim="adamw_torch",
        gradient_checkpointing=False,
        disable_tqdm=False,  # Keep the progress bar
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
        remove_unused_columns=False,  # Keep all dataset columns
    )

    return DPOTrainer(
        model=model,
        args=trl_config,
        train_dataset=train_dataset,
        processing_class=tokenizer
    )


def _build_single_trainer(model, tokenizer, train_dataset, config, run_name, output_dir):
    """Build a SingleModelTrainer from the 'training' and 'method_configs.single' settings."""
    # SingleModelTrainer is fed the 'x' / 'y_w' / 'y_l' layout
    formatted_dataset = format_dual_dataset(train_dataset)

    training_cfg = config['training']
    single_cfg = config['method_configs']['single']['specific']

    batch_size = training_cfg['per_device_train_batch_size']
    print(f"Debug: batch_size={batch_size}")

    trainer_config = {
        'mu_epochs': single_cfg['mu_epochs'],
        'pi_epochs': single_cfg['pi_epochs'],
        # float() in case the YAML values were parsed as strings
        'mu_learning_rate': float(single_cfg['mu_learning_rate']),
        'pi_learning_rate': float(single_cfg['pi_learning_rate']),
        'batch_size': batch_size,
        'num_train_epochs': training_cfg['num_train_epochs'],
        'logging_steps': training_cfg['logging_steps'],
        'save_steps': training_cfg['save_steps'],
        'output_dir': output_dir,
        'wandb_project': config['tracking']['wandb_project'],
        'wandb_run_name': run_name,
        'beta_pi': single_cfg['beta_pi'],
        'beta': single_cfg['beta'],
        'reference_model_id': single_cfg['reference_model_id'],
        'trust_remote_code': config['model']['trust_remote_code'],
        # Optional resume settings; defaults to an empty dict when absent
        'resume': single_cfg.get('resume', {})
    }

    return SingleModelTrainer(
        model=model,
        config=trainer_config,
        train_dataset=formatted_dataset,
        tokenizer=tokenizer
    )


def train_synthetic_model(model, tokenizer, train_dataset, config, trainer_type):
    """
    Train a language model using synthetic data with the specified RLHF method.

    Builds the trainer that matches ``trainer_type``, runs ``trainer.train()``
    and returns the trainer. The run name is
    ``synthetic_<trainer_type>_<YYYYmmdd_HHMMSS>`` and the output directory is
    ``<config['training']['output_dir']>/synthetic_<trainer_type>``.

    Args:
        model: Base language model to be fine-tuned
        tokenizer: Tokenizer for text processing
        train_dataset: Synthetic dataset for training
        config: Configuration dictionary
        trainer_type: Type of trainer to use ('dpo' or 'single')

    Returns:
        Trainer instance after training completion

    Raises:
        ValueError: If ``trainer_type`` is neither 'dpo' nor 'single'
    """
    print(f"\nTraining synthetic model with {trainer_type} method")
    print(f"Dataset size: {len(train_dataset)}")

    # Reject unknown methods before touching the configuration
    if trainer_type not in ('dpo', 'single'):
        raise ValueError(f"Unsupported trainer type: {trainer_type}")

    # Shared by both trainer kinds: timestamped run name and per-method output dir
    from datetime import datetime
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = f"synthetic_{trainer_type}_{current_time}"
    output_dir = os.path.join(config['training']['output_dir'], f"synthetic_{trainer_type}")

    if trainer_type == 'dpo':
        trainer = _build_dpo_trainer(model, tokenizer, train_dataset, config, run_name, output_dir)
    else:
        trainer = _build_single_trainer(model, tokenizer, train_dataset, config, run_name, output_dir)

    # Execute training
    print(f"Starting synthetic training with {trainer_type}...")
    trainer.train()

    return trainer


def preprocess_synthetic_dataset(dataset, tokenizer):
    """
    Map synthetic records onto the prompt / chosen / rejected layout.

    Each record's 'x' field becomes the prompt with the suffix
    " Answer with one word only:" appended; 'y_w' becomes 'chosen' and
    'y_l' becomes 'rejected'.

    Args:
        dataset: Input synthetic dataset; must provide a ``map`` method
        tokenizer: Accepted for interface compatibility; not used here

    Returns:
        Dataset: Whatever ``dataset.map`` returns for the per-record conversion
    """
    def to_preference_row(record):
        # Append the one-word instruction to the raw question
        question = f"{record['x']} Answer with one word only:"
        return dict(
            prompt=question,
            chosen=record["y_w"],
            rejected=record["y_l"],
        )

    return dataset.map(to_preference_row)


def save_synthetic_model(trainer, tokenizer, save_path, config, trainer_type):
    """
    Write the trained model, tokenizer, and a JSON summary of the run to disk.

    Files go to ``<save_path>/synthetic_<trainer_type>``. When ``trainer`` is
    a SingleModelTrainer, an extra ``beta_<beta>`` subdirectory is appended.

    Args:
        trainer: Trained model trainer instance; its ``model`` attribute is saved
        tokenizer: Tokenizer used for training
        save_path: Directory path to save the model
        config: Configuration dictionary used for training
        trainer_type: Type of trainer used ('dpo' or 'single')
    """
    target_dir = os.path.join(save_path, f"synthetic_{trainer_type}")

    # Single-model runs are additionally keyed by their beta value
    if isinstance(trainer, SingleModelTrainer):
        beta_value = config['method_configs']['single']['specific']['beta']
        target_dir = os.path.join(target_dir, f"beta_{beta_value}")

    os.makedirs(target_dir, exist_ok=True)

    # Model first, then tokenizer, both into the same directory
    trainer.model.save_pretrained(target_dir)
    tokenizer.save_pretrained(target_dir)

    # Record the settings this run was trained with next to the weights
    manifest_path = os.path.join(target_dir, "synthetic_training_config.json")
    with open(manifest_path, 'w') as handle:
        manifest = {
            'model_config': config['model'],
            'lora_config': config['lora'],
            'training_config': config['training'],
            'method_config': config['method_configs'][trainer_type],
            'dataset_info': config['dataset'],
            'training_type': 'synthetic',
            'trainer_type': trainer_type
        }
        json.dump(manifest, handle, indent=2)

    print(f"Synthetic trained model and configuration saved to {target_dir}")


def _load_lora_model_and_tokenizer(config, device):
    """Load the base model in float32, wrap it with LoRA, and move it to ``device``."""
    print(f"Loading model: {config['model']['model_name']}")

    # device_map is intentionally not passed here (the original note cites
    # DTensor issues). Mixed precision and quantization are disabled.
    model = AutoModelForCausalLM.from_pretrained(
        config['model']['model_name'],
        trust_remote_code=config['model']['trust_remote_code'],
        torch_dtype=torch.float32,
    )

    # Apply LoRA before the model is moved to the target device
    lora_config = LoraConfig(
        r=config['lora']['r'],
        lora_alpha=config['lora']['lora_alpha'],
        target_modules=config['lora']['target_modules'],
        lora_dropout=config['lora']['lora_dropout'],
        bias=config['lora']['bias'],
        task_type=config['lora']['task_type']
    )
    model = get_peft_model(model, lora_config)
    print("Applied LoRA to model for synthetic training")

    tokenizer = AutoTokenizer.from_pretrained(
        config['model']['model_name'],
        trust_remote_code=config['model']['trust_remote_code']
    )
    # Reuse the EOS token for padding
    tokenizer.pad_token = tokenizer.eos_token

    model = model.to(device)

    # Disable the KV cache on the model config for training
    if hasattr(model, 'config'):
        model.config.use_cache = False

    # Compatibility shim: give the model a no-op _set_static_graph when it has none.
    # NOTE(review): confirm which downstream component still calls this.
    if not hasattr(model, '_set_static_graph'):
        model._set_static_graph = lambda: None
        print("Added dummy _set_static_graph method to model")

    return model, tokenizer


def main():
    """
    Main synthetic training script entry point.

    Parses the command line, loads the YAML configuration and the selected
    synthetic dataset, then trains and (when ``training.save_dir`` is
    configured) saves one freshly loaded LoRA model per requested method.

    Raises:
        ValueError: If ``--dataset`` is not 'color' or 'fruit'
    """
    import gc

    # Parse arguments
    parser = argparse.ArgumentParser(description="Train language models with synthetic data using TRL")
    parser.add_argument('--config', type=str, default='synthetic_config_train.yaml', help='Path to synthetic training config file')
    parser.add_argument('--trainer', type=str, default=None, help='Trainer type (dpo, single)')
    parser.add_argument('--dataset', type=str, default='color', help='Synthetic dataset to use (color, fruit)')
    args = parser.parse_args()

    # Single device only: CUDA when available, otherwise CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using single GPU training for synthetic data on {device}")

    # Load configuration
    config = load_config(args.config)

    # The command-line dataset choice always overrides the configured train path
    dataset_paths = {
        'color': "datasets/synthetic/color_dataset.json",
        'fruit': "datasets/synthetic/fruit_dataset.json",
    }
    if args.dataset not in dataset_paths:
        raise ValueError(f"Unsupported synthetic dataset: {args.dataset}")
    config['dataset']['train_path'] = dataset_paths[args.dataset]

    # --trainer, when given, replaces the configured list of methods
    methods_to_train = [args.trainer] if args.trainer else config['methods']
    # A bare string in the config would otherwise be iterated character by character
    if isinstance(methods_to_train, str):
        methods_to_train = [methods_to_train]

    # Nest all outputs of this script under a "synthetic" subdirectory
    synthetic_output_dir = os.path.join(config['training']['output_dir'], "synthetic")
    config['training']['output_dir'] = synthetic_output_dir
    os.makedirs(synthetic_output_dir, exist_ok=True)

    if 'save_dir' in config['training']:
        synthetic_save_dir = os.path.join(config['training']['save_dir'], "synthetic")
        config['training']['save_dir'] = synthetic_save_dir
        os.makedirs(synthetic_save_dir, exist_ok=True)

    # Log in only when a non-empty token is configured; an empty section or a
    # blank placeholder is treated the same as "no token"
    hf_token = (config.get('auth') or {}).get('hf_token')
    if hf_token:
        hf_login(token=hf_token)
        print("Logged in to Hugging Face")

    wandb_token = (config.get('tracking') or {}).get('wandb_token')
    if wandb_token:
        wandb.login(key=wandb_token)
        print("Logged in to Weights & Biases")

    # Load synthetic dataset
    print(f"Loading synthetic dataset from: {config['dataset']['train_path']}")
    dataset = load_dataset('json', data_files=config['dataset']['train_path'])["train"]

    # Optionally keep only the first num_samples records (a null value means no limit)
    num_samples = config['dataset'].get('num_samples')
    if num_samples is not None and len(dataset) > num_samples:
        print(f"Limiting synthetic dataset from {len(dataset)} to {num_samples} samples")
        dataset = dataset.select(range(num_samples))

    # Preprocessing does not use the tokenizer, so None is passed
    train_dataset = preprocess_synthetic_dataset(dataset, None)
    print(f"Synthetic training dataset size: {len(train_dataset)}")

    # Train each method with synthetic data, each on a freshly loaded model
    for trainer_type in methods_to_train:
        print(f"\n{'='*50}")
        print(f"Training with {trainer_type} method on synthetic {args.dataset} data")
        print(f"{'='*50}")

        model, tokenizer = _load_lora_model_and_tokenizer(config, device)

        print(f"Training {trainer_type} model on single GPU")
        trainer = train_synthetic_model(model, tokenizer, train_dataset, config, trainer_type)

        # Save trained model
        if 'save_dir' in config['training']:
            save_path = config['training']['save_dir']
            save_synthetic_model(trainer, tokenizer, save_path, config, trainer_type)

        # Release memory before the next method: drop the references first,
        # collect any reference cycles, and only then empty the CUDA cache.
        # Emptying the cache while the objects are still referenced frees nothing.
        del trainer
        del model
        del tokenizer
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"\nAll synthetic training completed for {args.dataset} dataset!")


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
