#!/usr/bin/env python3
"""Training script for SteerCLR."""

import argparse
import logging

import torch
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer

from .config import SteerCLRTrainerConfig
from .data import create_train_val_dataloaders, load_texts_from_multiple_json
from .trainer import SteerCLRTrainer

# Module-level logger used throughout this script.
logger = logging.getLogger(__name__)
# Give the root logger an INFO-level default handler (no-op if it already
# has handlers configured).
logging.basicConfig(level=logging.INFO)


def load_config(config_path: str) -> SteerCLRTrainerConfig:
    """Load a ``SteerCLRTrainerConfig`` from a YAML file.

    Args:
        config_path: Path to a YAML file whose top-level mapping supplies the
            keyword arguments for ``SteerCLRTrainerConfig``.

    Returns:
        The constructed ``SteerCLRTrainerConfig``.

    Raises:
        ValueError: If the file is empty or its top level is not a mapping.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = yaml.safe_load(f)

    # safe_load yields None for an empty document (and may yield a list or
    # scalar); `**config_dict` would then fail with an opaque TypeError, so
    # report the actual problem instead.
    if config_dict is None:
        raise ValueError(f"Config file is empty: {config_path}")
    if not isinstance(config_dict, dict):
        raise ValueError(
            f"Config file {config_path} must contain a YAML mapping at the top "
            f"level, got {type(config_dict).__name__}"
        )
    return SteerCLRTrainerConfig(**config_dict)


def setup_model_and_tokenizer(config: SteerCLRTrainerConfig):
    """Load the causal-LM and tokenizer named by ``config.model_name``.

    The tokenizer falls back to its EOS token for padding when it has no pad
    token, and is set to pad on the left. The model is loaded in float16 when
    the configured device is CUDA (float32 otherwise), moved to that device,
    and switched to eval mode.

    Args:
        config: Training configuration; ``model_name`` and ``get_device()``
            are read here.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    name = config.model_name
    logger.info(f"Loading model: {name}")

    tokenizer = AutoTokenizer.from_pretrained(name)
    if tokenizer.pad_token is None:
        # No dedicated pad token: reuse EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    target_device = config.get_device()
    # Half precision only on CUDA; every other backend stays in float32.
    weights_dtype = torch.float32
    if target_device.type == "cuda":
        weights_dtype = torch.float16

    load_kwargs = {
        "dtype": weights_dtype,
        "device_map": None,
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
        "attn_implementation": "sdpa",
    }
    model = AutoModelForCausalLM.from_pretrained(name, **load_kwargs)

    # Decoder-only model: pad sequences on the left.
    tokenizer.padding_side = "left"

    # NOTE: gradient checkpointing is not enabled; call
    # model.gradient_checkpointing_enable() here if memory becomes a constraint.
    model = model.to(target_device).eval()

    first_param = next(model.parameters())
    logger.info(f"Model loaded on device: {first_param.device}")
    logger.info(f"Model dtype: {first_param.dtype}")
    logger.info(f"Model hidden size: {model.config.hidden_size}")

    return model, tokenizer


def _validate_batch_config(config: "SteerCLRTrainerConfig") -> None:
    """Check that ``batch_size`` splits evenly into ``num_vectors_per_batch`` groups.

    The dataloader batch size is ``batch_size // num_vectors_per_batch``. A
    non-positive group count would divide by zero, and a remainder would
    silently shrink the effective batch.

    Raises:
        ValueError: If ``num_vectors_per_batch`` is not positive or does not
            divide ``batch_size``.
    """
    if config.num_vectors_per_batch <= 0:
        raise ValueError(
            f"num_vectors_per_batch must be positive, got {config.num_vectors_per_batch}"
        )
    if config.batch_size % config.num_vectors_per_batch != 0:
        raise ValueError(
            f"batch_size ({config.batch_size}) must be divisible by "
            f"num_vectors_per_batch ({config.num_vectors_per_batch})"
        )


def _save_vectors_and_finalize(trainer: "SteerCLRTrainer", experiment_dir) -> None:
    """Save the trainer's vectors into ``experiment_dir``, then finalize W&B.

    ``finalize_wandb`` is called even if saving raises, so W&B finalization is
    never skipped; the save error still propagates.
    """
    try:
        trainer.save_vectors(experiment_dir)
    finally:
        trainer.finalize_wandb()


def main():
    """Command-line entry point: train SteerCLR steering vectors.

    Parses ``--config`` (required YAML path) and optional ``--experiment-id``.
    It validates the batch configuration before any expensive work, then loads
    the model and data, trains, and runs a final validation pass. The learned
    vectors are saved into the experiment directory. Vectors are also saved on
    Ctrl-C, and on a training error, which is then re-raised.
    """
    parser = argparse.ArgumentParser(description="Train SteerCLR steering vectors")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to configuration YAML file"
    )
    parser.add_argument(
        "--experiment-id", type=str, help="Optional experiment identifier"
    )
    args = parser.parse_args()

    # Load and validate config. Validate before loading the model so a bad
    # batch split fails fast instead of after an expensive model load.
    config = load_config(args.config)
    if args.experiment_id:
        config.experiment_id = args.experiment_id
    _validate_batch_config(config)

    # Setup seeds
    config.setup_seeds()

    logger.info("=" * 50)
    logger.info("SteerCLR Training Configuration")
    logger.info("=" * 50)
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    logger.info("=" * 50)
    logger.info("Configuration:")
    logger.info(config.to_yaml())
    logger.info("=" * 50)

    # Setup model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(config)

    # Load training data
    training_data = load_texts_from_multiple_json(config.train_texts_files)
    logger.info(
        f"Loaded training data with {len(training_data)} questions from {len(config.train_texts_files)} files: {config.train_texts_files}"
    )

    # Optional validation texts; passed to create_train_val_dataloaders as
    # val_training_data and left as None when no files are configured.
    val_training_data = None
    if config.val_texts_files:
        val_training_data = load_texts_from_multiple_json(config.val_texts_files)
        logger.info(
            f"Loaded validation data with {len(val_training_data)} questions from {len(config.val_texts_files)} files: {config.val_texts_files}"
        )

    # Create dataloaders. Divisibility was checked by _validate_batch_config,
    # so this integer division is exact.
    train_loader, val_loader = create_train_val_dataloaders(
        tokenizer=tokenizer,
        batch_size=config.batch_size // config.num_vectors_per_batch,
        max_length=config.max_length,
        training_data=training_data,
        val_training_data=val_training_data,
        seed=config.seed,
        val_num_samples=config.val_num_samples,
        val_subset_seed=(
            config.val_subset_seed
            if config.val_subset_seed is not None
            else config.seed
        ),
    )

    # Setup experiment directory
    experiment_dir = config.get_experiment_dir()
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # Save config
    with open(experiment_dir / "config.yaml", "w", encoding="utf-8") as f:
        f.write(config.to_yaml())

    # Initialize trainer
    trainer = SteerCLRTrainer(
        model=model,
        tokenizer=tokenizer,
        dataloader=train_loader,
        config=config,
        experiment_dir=experiment_dir,
        val_dataloader=val_loader,
    )

    logger.info(f"Starting training for {config.n_training_steps} steps...")
    logger.info(f"Experiment directory: {experiment_dir}")

    try:
        trainer.train()
        trainer.run_validation(step=None)
    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
        _save_vectors_and_finalize(trainer, experiment_dir)
        logger.info(f"Saved vectors before exit to: {experiment_dir}")
        return
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception(f"Training failed with error: {e}")
        try:
            _save_vectors_and_finalize(trainer, experiment_dir)
        except Exception:
            # A failing emergency save must not mask the original error.
            logger.exception("Failed to save vectors after training error")
        else:
            logger.info(f"Saved vectors before exit to: {experiment_dir}")
        raise  # re-raises the original training error

    _save_vectors_and_finalize(trainer, experiment_dir)
    logger.info("Training completed successfully!")
    logger.info(f"Results saved to: {experiment_dir}")


if __name__ == "__main__":
    # Script entry point. The relative imports above mean this module is run
    # via `python -m`. main() returns None, so the exit status is 0 on success.
    raise SystemExit(main())
