"""
Training Script

This module implements the main training pipeline for fine-tuning language models with the following RLHF methods:
- DPO (Direct Preference Optimization)
- Nash-MD (Nash Mirror Descent)
- Proposed Model Training
- Dual Model Training

The script supports distributed data-parallel training, fp16 mixed precision, and LoRA/QLoRA adapter configurations.
"""

import warnings
import os
import logging
import yaml
from importlib.metadata import PackageNotFoundError
import json
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from utils.trainer import ProposedModelTrainer

# Quiet the third-party training libraries: only ERROR-level log records get through.
for _noisy_logger in ("trl", "transformers", "peft"):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
del _noisy_logger

# Targeted warning filters for known-noisy sources (each is inserted at the
# front of the filter list, in this order).
warnings.filterwarnings("ignore", message=".*tokenizer.*deprecated.*")
for _noisy_module in ("trl.*", "transformers.*", "peft.*"):
    warnings.filterwarnings("ignore", module=_noisy_module)
del _noisy_module

# Blanket suppression: this catch-all "ignore" is inserted ahead of the filters
# above, so in practice every warning is silenced.
warnings.simplefilter("ignore")

# Let the CUDA caching allocator grow segments instead of fragmenting memory.
# Overwrites any value already present in the environment.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import wandb
import argparse
from huggingface_hub import login as hf_login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
from trl import NashMDConfig, NashMDTrainer, PairRMJudge

def load_config(config_path):
    """
    Load a YAML configuration file and flatten its optional ``shared`` section.

    Keys under a top-level ``shared`` mapping are copied to the root level,
    overriding any root-level keys with the same name; the ``shared`` key itself
    is removed from the result.

    Args:
        config_path: Path to the YAML configuration file (read as UTF-8).

    Returns:
        dict: Processed configuration dictionary with shared settings merged.

    Raises:
        ValueError: If the document is empty or its top level is not a mapping,
            or if ``shared`` is present but is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document; fail with a clear message
    # instead of a TypeError on the membership test / update below.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping at the top level"
        )

    # A bare ``shared:`` key parses as None; treat that as an empty section.
    shared_config = config.pop('shared', None) or {}
    if not isinstance(shared_config, dict):
        raise ValueError(
            f"'shared' section in {config_path!r} must be a mapping, "
            f"got {type(shared_config).__name__}"
        )
    config.update(shared_config)

    return config

def _build_lora_config(config):
    """Build a ``LoraConfig`` from ``config['lora']``, including its ``task_type``."""
    lora = config['lora']
    return LoraConfig(
        r=lora['r'],
        lora_alpha=lora['lora_alpha'],
        target_modules=lora['target_modules'],
        lora_dropout=lora['lora_dropout'],
        bias=lora['bias'],
        task_type=lora['task_type'],
    )


def get_model_and_tokenizer(config):
    """
    Initialize a LoRA-wrapped causal LM and its tokenizer from configuration.

    When ``config['model']['use_qlora']`` is true, the base model is loaded with
    a 4-bit ``BitsAndBytesConfig`` built from ``config['qlora']`` and passed
    through ``prepare_model_for_kbit_training``. If that quantized load raises
    ``ImportError`` / ``PackageNotFoundError`` (e.g. bitsandbytes is not
    installed), the model is reloaded *without* quantization. In every case
    the LoRA adapter described by ``config['lora']`` is then applied with
    ``get_peft_model``.

    The tokenizer's pad token is set to its EOS token.

    Args:
        config: Configuration dictionary with ``model``, ``training``, ``lora``
            and (when QLoRA is enabled) ``qlora`` sections.

    Returns:
        tuple: (model, tokenizer) initialized according to configuration.
    """
    model_name = config['model']['model_name']
    # Shared from_pretrained kwargs. The quantization config is deliberately NOT
    # stored here, so the fallback load below is genuinely unquantized (adding it
    # here would make the fallback fail with the same bitsandbytes error).
    model_config = {
        "trust_remote_code": config['model']['trust_remote_code'],
        "torch_dtype": torch.float16 if config['training']['fp16'] else torch.float32,
        "device_map": "auto",
    }

    if config['model'].get('use_qlora', False):
        qlora = config['qlora']
        try:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=qlora['load_in_4bit'],
                bnb_4bit_quant_type=qlora['bnb_4bit_quant_type'],
                bnb_4bit_compute_dtype=getattr(torch, qlora['bnb_4bit_compute_dtype']),
                bnb_4bit_use_double_quant=qlora['bnb_4bit_use_double_quant'],
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_name, quantization_config=bnb_config, **model_config
            )
            model = prepare_model_for_kbit_training(model)
            print("Using QLoRA for model training")
        except (ImportError, PackageNotFoundError):
            print("bitsandbytes not available, falling back to regular LoRA")
            model = AutoModelForCausalLM.from_pretrained(model_name, **model_config)
    else:
        # No quantization requested: an ImportError here is unrelated to
        # bitsandbytes, so it is allowed to propagate instead of retrying.
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_config)
        print("Using regular LoRA for model training")

    # Convert model to PEFT model with LoRA
    model = get_peft_model(model, _build_lora_config(config))

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=config['model']['trust_remote_code'],
    )
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def format_dataset(dataset):
    """
    Convert preference records into the ``x`` / ``y_w`` / ``y_l`` layout.

    This is the input format handed to ``ProposedModelTrainer`` by ``train``
    for the ``'proposed'`` method. Records missing any of ``prompt``,
    ``chosen`` or ``rejected`` are silently skipped. Prints how many pairs
    were produced.

    Args:
        dataset: Iterable of mapping-like records with 'prompt', 'chosen'
            and 'rejected' fields.

    Returns:
        list: One dict per complete record, with keys 'x' (prompt),
        'y_w' (chosen response) and 'y_l' (rejected response).
    """
    required_keys = ('prompt', 'chosen', 'rejected')
    pairs = [
        {'x': record['prompt'], 'y_w': record['chosen'], 'y_l': record['rejected']}
        for record in dataset
        if all(key in record for key in required_keys)
    ]

    print(f"Formatted {len(pairs)} preference pairs for training")
    return pairs


_SUPPORTED_TRAINER_TYPES = ('dpo', 'nashmd', 'proposed')


def _unwrap_ddp(model):
    """Return the module wrapped by DDP, or *model* itself if it is not DDP-wrapped."""
    return model.module if isinstance(model, DDP) else model


def _timestamped_run_name(trainer_type):
    """Return a run name of the form ``<trainer_type>_YYYYmmdd_HHMMSS`` (local time)."""
    from datetime import datetime
    return f"{trainer_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"


def _causal_lm_lora_config(config):
    """Build a ``LoraConfig`` from ``config['lora']`` with ``task_type`` forced to CAUSAL_LM."""
    lora = config['lora']
    return LoraConfig(
        r=lora['r'],
        lora_alpha=lora['lora_alpha'],
        target_modules=lora['target_modules'],
        lora_dropout=lora['lora_dropout'],
        bias=lora['bias'],
        task_type="CAUSAL_LM",
    )


def train(model, tokenizer, train_dataset, config, trainer_type, local_rank, world_size):
    """
    Fine-tune a language model with the requested preference-optimization method.

    Supported ``trainer_type`` values:

    * ``'dpo'``: TRL ``DPOTrainer``. LoRA settings from ``config['lora']``
      (task type forced to ``CAUSAL_LM``) are passed as ``peft_config``.
    * ``'nashmd'``: TRL ``NashMDTrainer`` scored by an ``AnnotationModelJudge``.
      The judge model defaults to ``meta-llama/Llama-3.1-8B-Instruct`` with
      ``w=0.2``. Both can be overridden via the optional keys ``judge_model``
      and ``judge_w`` in ``config['method_configs']['nashmd']['specific']``.
    * ``'proposed'``: project ``ProposedModelTrainer`` on the dataset converted
      by ``format_dataset``.

    If *model* is wrapped in ``DistributedDataParallel``, the underlying module
    is what gets configured and handed to the trainer.

    Args:
        model: Base language model to be fine-tuned (optionally DDP-wrapped).
        tokenizer: Tokenizer for text processing.
        train_dataset: Dataset for training (prompt/chosen/rejected columns).
        config: Configuration dictionary.
        trainer_type: One of 'dpo', 'nashmd' or 'proposed'.
        local_rank: Local rank for distributed training.
        world_size: Total number of distributed processes.

    Returns:
        The trainer instance after ``trainer.train()`` has completed.

    Raises:
        ValueError: If *trainer_type* is not supported. This is checked before
            the model is modified.
    """
    if trainer_type not in _SUPPORTED_TRAINER_TYPES:
        raise ValueError(f"Unsupported trainer type: {trainer_type}")

    if local_rank == 0:
        print(f"\nTraining with {trainer_type} method")

    # All trainers receive the unwrapped module; enable gradient checkpointing on it
    # for memory efficiency.
    base_model = _unwrap_ddp(model)
    base_model.gradient_checkpointing_enable()

    training_cfg = config['training']
    run_name = _timestamped_run_name(trainer_type)
    output_dir = os.path.join(training_cfg['output_dir'], trainer_type)

    if trainer_type in ('dpo', 'nashmd'):
        # The KV cache must be disabled while gradient checkpointing is active.
        base_model.config.use_cache = False

    if trainer_type == 'dpo':
        from trl import DPOConfig, DPOTrainer

        dpo_training = config['method_configs']['dpo']['training']
        dpo_specific = config['method_configs']['dpo']['specific']
        trl_config = DPOConfig(
            learning_rate=float(dpo_training['learning_rate']),
            per_device_train_batch_size=training_cfg['per_device_train_batch_size'],
            per_device_eval_batch_size=training_cfg['per_device_train_batch_size'],
            num_train_epochs=training_cfg['num_train_epochs'],
            logging_steps=training_cfg['logging_steps'],
            save_steps=training_cfg['save_steps'],
            output_dir=output_dir,
            run_name=run_name,
            beta=dpo_specific['beta'],
            max_length=dpo_specific['max_length'],
            max_prompt_length=dpo_specific['max_prompt_length'],
            gradient_accumulation_steps=training_cfg['gradient_accumulation_steps'],
            fp16=training_cfg['fp16'],
            optim="adamw_torch",
            gradient_checkpointing=True,
            local_rank=local_rank,
            ddp_find_unused_parameters=False,
            ddp_bucket_cap_mb=25,
            disable_tqdm=False,  # Explicitly enable tqdm
        )
        trainer = DPOTrainer(
            model=base_model,
            args=trl_config,
            train_dataset=train_dataset,
            processing_class=tokenizer,
            peft_config=_causal_lm_lora_config(config),
        )

    elif trainer_type == 'nashmd':
        from utils.custom_judge import AnnotationModelJudge

        nashmd_training = config['method_configs']['nashmd']['training']
        nashmd_specific = config['method_configs']['nashmd']['specific']
        judge = AnnotationModelJudge(
            model_name_or_path=nashmd_specific.get(
                'judge_model', "meta-llama/Llama-3.1-8B-Instruct"
            ),
            w=nashmd_specific.get('judge_w', 0.2),
            # main() treats 'auth' as optional, so pass None when no token is set.
            # NOTE(review): presumably the judge then falls back to the cached
            # hf_login session; confirm for gated models.
            auth_token=(config.get('auth') or {}).get('hf_token'),
        )
        trl_config = NashMDConfig(
            output_dir=output_dir,
            run_name=run_name,
            learning_rate=float(nashmd_training['learning_rate']),
            per_device_train_batch_size=training_cfg['per_device_train_batch_size'],
            per_device_eval_batch_size=training_cfg['per_device_train_batch_size'],
            num_train_epochs=training_cfg['num_train_epochs'],
            logging_steps=training_cfg['logging_steps'],
            save_steps=training_cfg['save_steps'],
            gradient_accumulation_steps=training_cfg['gradient_accumulation_steps'],
            fp16=training_cfg['fp16'],
            optim="adamw_torch",
            gradient_checkpointing=True,
            max_length=training_cfg['max_length'],
            temperature=nashmd_specific['temperature'],
            beta=[nashmd_specific['beta']],
            disable_dropout=True,
            ds3_gather_for_generation=True,
            disable_tqdm=False,  # Enable progress bar
        )
        trainer = NashMDTrainer(
            model=base_model,
            judge=judge,
            args=trl_config,
            processing_class=tokenizer,
            train_dataset=train_dataset,
            peft_config=_causal_lm_lora_config(config),
        )

    else:  # 'proposed'
        formatted_dataset = format_dataset(train_dataset)
        proposed_specific = config['method_configs']['proposed']['specific']
        lora = config['lora']
        trainer_config = {
            'mu_epochs': proposed_specific['mu_epochs'],
            'pi_epochs': proposed_specific['pi_epochs'],
            'mu_learning_rate': float(proposed_specific['mu_learning_rate']),
            'pi_learning_rate': float(proposed_specific['pi_learning_rate']),
            'batch_size': training_cfg['per_device_train_batch_size'],
            'num_train_epochs': training_cfg['num_train_epochs'],
            'logging_steps': training_cfg['logging_steps'],
            'save_steps': training_cfg['save_steps'],
            'output_dir': output_dir,
            'wandb_project': config['tracking']['wandb_project'],
            'wandb_run_name': run_name,
            'beta_pi': proposed_specific['beta_pi'],
            'beta': proposed_specific['beta'],
            'lora_r': lora['r'],
            'lora_alpha': lora['lora_alpha'],
            'target_modules': lora['target_modules'],
            'lora_dropout': lora['lora_dropout'],
            'bias': lora['bias'],
            'reference_model_id': proposed_specific['reference_model_id'],
            'trust_remote_code': config['model']['trust_remote_code'],
        }
        trainer = ProposedModelTrainer(
            model=base_model,
            config=trainer_config,
            train_dataset=formatted_dataset,
            tokenizer=tokenizer,
            local_rank=local_rank,
            world_size=world_size,
        )

    # Execute training
    trainer.train()

    return trainer

def preprocess_dataset(dataset, tokenizer):
    """
    Rename the raw ``x`` / ``y_w`` / ``y_l`` fields to prompt/chosen/rejected.

    Each record is passed through ``dataset.map``. The returned mapping supplies
    ``prompt`` (from ``x``), ``chosen`` (from ``y_w``) and ``rejected`` (from
    ``y_l``).

    Args:
        dataset: Input dataset exposing a ``map`` method.
        tokenizer: Accepted for interface compatibility; not used here.

    Returns:
        Whatever ``dataset.map`` returns for the renamed records.
    """
    def _rename_fields(record):
        prompt_text = record["x"]
        preferred = record["y_w"]
        dispreferred = record["y_l"]
        return {"prompt": prompt_text, "chosen": preferred, "rejected": dispreferred}

    return dataset.map(_rename_fields)

def save_model_and_config(trainer, tokenizer, save_path, config):
    """
    Persist the trained model, its tokenizer and a JSON snapshot of the config.

    For ``ProposedModelTrainer`` runs, the output goes into a ``beta_<beta>``
    subdirectory of *save_path* (created if needed). The snapshot file
    ``training_config.json`` stores the ``model``, ``qlora`` (or ``{}``),
    ``lora`` and ``training`` sections under the keys ``model_config``,
    ``qlora_config``, ``lora_config`` and ``training_config``.

    Args:
        trainer: Trainer instance whose ``model`` is saved.
        tokenizer: Tokenizer used for training.
        save_path: Directory to save into.
        config: Configuration dictionary used for training.
    """
    if isinstance(trainer, ProposedModelTrainer):
        # Presumably keyed by beta so runs with different beta values do not
        # overwrite each other.
        beta_value = config['method_configs']['proposed']['specific']['beta']
        save_path = os.path.join(save_path, f"beta_{beta_value}")
        os.makedirs(save_path, exist_ok=True)

    trainer.model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    snapshot_file = os.path.join(save_path, "training_config.json")
    with open(snapshot_file, 'w') as handle:
        json.dump(
            {
                'model_config': config['model'],
                'qlora_config': config.get('qlora', {}),
                'lora_config': config['lora'],
                'training_config': config['training'],
            },
            handle,
            indent=2,
        )

    print(f"Model and configuration saved to {save_path}")

def load_model_and_config(load_path, config):
    """
    Rebuild a model and tokenizer from a directory written by ``save_model_and_config``.

    If ``<load_path>/training_config.json`` exists, its saved model, LoRA and
    training settings are merged into *config* in place. A non-empty saved
    QLoRA section replaces ``config['qlora']``. The model is then constructed
    with ``get_model_and_tokenizer``.

    Adapter weights in *load_path* are loaded into the ``"default"`` adapter on
    a best-effort basis. Any failure is reported as a warning and the freshly
    initialized adapter is kept.

    Args:
        load_path: Directory containing the saved adapter and configuration.
        config: Base configuration dictionary; updated in place with saved settings.

    Returns:
        tuple: (model, tokenizer, updated_config)
    """
    config_path = os.path.join(load_path, "training_config.json")
    if os.path.exists(config_path):
        with open(config_path, 'r', encoding='utf-8') as f:
            saved_config = json.load(f)

        config['model'].update(saved_config['model_config'])
        # save_model_and_config writes QLoRA settings under 'qlora_config';
        # 'qlora' is also accepted as a fallback key. Empty sections ({} is
        # written when no QLoRA config existed) do not clobber the base config.
        for qlora_key in ('qlora_config', 'qlora'):
            if saved_config.get(qlora_key):
                config['qlora'] = saved_config[qlora_key]
                break
        config['lora'].update(saved_config['lora_config'])
        config['training'].update(saved_config['training_config'])

    model, tokenizer = get_model_and_tokenizer(config)

    # Best-effort: the directory exists, but it may not contain adapter weights.
    if os.path.exists(load_path):
        try:
            model.load_adapter(load_path, adapter_name="default")
        except Exception as e:
            print(f"Warning: Could not load adapter weights: {e}")

    return model, tokenizer, config

def main():
    """
    Main training script entry point.

    Parses CLI arguments and sets up NCCL distributed training when
    ``WORLD_SIZE`` > 1 (ranks come from ``LOCAL_RANK`` or ``--local_rank``).
    Then it loads the config and dataset, and trains (and optionally saves)
    one model per requested method.

    Between methods, references to the previous model, trainer and tokenizer
    are dropped before the CUDA cache is emptied. The process group is torn
    down even if training raises.
    """
    import gc  # used only to release each method's model before loading the next

    parser = argparse.ArgumentParser(description="Train language models with TRL")
    parser.add_argument('--config', type=str, default='config.yaml', help='Path to config file')
    parser.add_argument('--trainer', type=str, default=None,
                        help='Trainer type (dpo, nashmd, proposed); defaults to the config "methods" list')
    parser.add_argument('--local_rank', type=int, default=-1, help='Local rank for distributed training')
    args = parser.parse_args()

    # Launchers such as torchrun export LOCAL_RANK / WORLD_SIZE; fall back to the CLI flag.
    local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    if world_size > 1:
        if local_rank == -1:
            raise ValueError("Local rank must be specified for distributed training")

        # Bind this process to its GPU before creating the NCCL process group.
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl', init_method="env://")
        device = torch.device(f'cuda:{local_rank}')

        # Use a device-specific barrier call after initialization for safety
        torch.distributed.barrier(device_ids=[local_rank])
    else:
        local_rank = 0
        world_size = 1
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        config = load_config(args.config)

        # The --trainer flag overrides the config's list of methods.
        methods_to_train = [args.trainer] if args.trainer else config['methods']

        os.makedirs(config['training']['output_dir'], exist_ok=True)
        if 'save_dir' in config['training']:
            os.makedirs(config['training']['save_dir'], exist_ok=True)

        if 'auth' in config and 'hf_token' in config['auth']:
            hf_login(token=config['auth']['hf_token'])
            if local_rank == 0:
                print("Logged in to Hugging Face")

        if 'tracking' in config and 'wandb_token' in config['tracking']:
            wandb.login(key=config['tracking']['wandb_token'])
            if local_rank == 0:
                print("Logged in to Weights & Biases")

        # Load dataset, either from the Hub by name or from a local JSON file.
        if 'name' in config['dataset']:
            if local_rank == 0:
                print(f"Loading dataset: {config['dataset']['name']}")
            dataset = load_dataset(config['dataset']['name'])["train"]
        elif 'train_path' in config['dataset']:
            if local_rank == 0:
                print(f"Loading dataset from: {config['dataset']['train_path']}")
            dataset = load_dataset('json', data_files=config['dataset']['train_path'])["train"]
        else:
            raise ValueError("Dataset configuration must specify either 'name' or 'train_path'")

        # Limit dataset size if num_samples is specified
        if 'num_samples' in config['dataset']:
            num_samples = config['dataset']['num_samples']
            if len(dataset) > num_samples:
                if local_rank == 0:
                    print(f"Limiting dataset from {len(dataset)} to {num_samples} samples")
                dataset = dataset.select(range(num_samples))

        # The tokenizer is loaded per method below; preprocessing does not need it.
        train_dataset = preprocess_dataset(dataset, None)
        if local_rank == 0:
            print(f"Training dataset size: {len(train_dataset)}")

        for trainer_type in methods_to_train:
            if local_rank == 0:
                print(f"\n{'='*50}")
                print(f"Training with {trainer_type} method")
                print(f"{'='*50}")
                print(f"Loading model: {config['model']['model_name']}")

            # No device_map here (avoids DTensor issues); the model is moved explicitly below.
            model_config = {
                "trust_remote_code": config['model']['trust_remote_code'],
                "torch_dtype": torch.float16 if config['training']['fp16'] else torch.float32,
            }

            # This loader never quantizes; the flag is cleared so the saved
            # training_config.json reflects that for distributed runs.
            if world_size > 1:
                config['model']['use_qlora'] = False
                if local_rank == 0:
                    print("Disabling QLoRA for distributed training to avoid tensor type issues")

            model = AutoModelForCausalLM.from_pretrained(
                config['model']['model_name'],
                **model_config
            )

            tokenizer = AutoTokenizer.from_pretrained(
                config['model']['model_name'],
                trust_remote_code=config['model']['trust_remote_code']
            )
            tokenizer.pad_token = tokenizer.eos_token

            lora_config = LoraConfig(
                r=config['lora']['r'],
                lora_alpha=config['lora']['lora_alpha'],
                target_modules=config['lora']['target_modules'],
                lora_dropout=config['lora']['lora_dropout'],
                bias=config['lora']['bias'],
                task_type=config['lora']['task_type']
            )
            model = get_peft_model(model, lora_config)

            # Move to device, enable gradient checkpointing and disable the KV
            # cache, all before any DDP wrapping.
            model = model.to(device)
            model.gradient_checkpointing_enable()
            if hasattr(model, 'config'):
                model.config.use_cache = False

            # NOTE(review): a no-op _set_static_graph is attached so that callers
            # invoking it on a non-DDP model do not fail; confirm it is still needed.
            if not hasattr(model, '_set_static_graph'):
                model._set_static_graph = lambda: None
                if local_rank == 0:
                    print("Added dummy _set_static_graph method to model")

            if trainer_type in ['dpo', 'nashmd'] and world_size > 1:
                # The TRL trainers handle distribution themselves; only synchronize here.
                torch.distributed.barrier(device_ids=[local_rank])
                model._set_static_graph()
            elif world_size > 1:
                # Synchronize all ranks before wrapping in DDP.
                torch.distributed.barrier(device_ids=[local_rank])
                model = DDP(model, device_ids=[local_rank], find_unused_parameters=False, static_graph=True)

            trainer = train(model, tokenizer, train_dataset, config, trainer_type, local_rank, world_size)

            # Save on the main process only. save_model_and_config prints the final
            # path itself, which for 'proposed' includes a beta_* subdirectory.
            if local_rank == 0 and 'save_dir' in config['training']:
                save_path = os.path.join(config['training']['save_dir'], f"{trainer_type}_model")
                save_model_and_config(trainer, tokenizer, save_path, config)

            # Drop references before emptying the CUDA cache. Otherwise the previous
            # method's weights stay allocated while the next model is loaded.
            del trainer, model, tokenizer
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if local_rank == 0:
            print("\nAll training completed!")
    finally:
        # Tear down the process group even if training raised.
        if world_size > 1:
            dist.destroy_process_group()

# Script entry point: run main() only when executed directly, not when imported.
if __name__ == "__main__":
    main()