"""
HuggingFace Trainer-based training script for Rosetta model
"""

import torch
import random
import numpy as np
from transformers import TrainingArguments
import os
import sys
import json
import argparse
import shutil
import wandb
from datetime import datetime
from typing import Dict, Any

# Disabled: uncomment the next line to add the project root to sys.path.
# sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from rosetta.train.dataset_adapters import (
    create_instructcoder_dataset, 
    ChatDataset, 
    RosettaDataCollator
)
from rosetta.train.trainer import (
    RosettaTrainer, 
    ProjectorSaveCallback, 
    freeze_model_components
)
from rosetta.train.model_utils import setup_models


def set_seed(seed: int = 42):
    """Seed every random number generator used by the script.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU
    and CUDA generators with the same value, then switches cuDNN to
    deterministic mode with benchmarking turned off.

    Args:
        seed: Value handed to each seeding function. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def main():
    """Train a RosettaModel with the HuggingFace Trainer, driven by a JSON config.

    Command line:
        --config PATH  Path to the JSON config file
                       (default: ``recipe/default_config.json``).

    Config keys read by this function:
        model:     passed as a whole to ``setup_models``.
        training:  ``device`` (optional, default ``"cuda"``), ``freeze``,
                   ``max_length``, ``per_device_train_batch_size``,
                   ``num_epochs``, ``learning_rate``, ``weight_decay``,
                   ``scheduler_type``, ``warmup_ratio``, ``max_grad_norm``.
        output:    ``output_dir``, ``run_name``, ``eval_steps``,
                   ``save_steps``, ``wandb_project``.
        data:      ``split``, ``num_samples``, ``train_ratio``.

    Side effects:
        * Creates ``<output_dir>/<YYYYmmdd_HHMMSS>/`` and copies the config
          file into it as ``config.json``.
        * Starts a wandb run named ``<run_name>_<timestamp>`` and finishes it
          after training.
        * After training, writes one ``projector_<i>.pt`` state dict per entry
          of ``rosetta_model.projector_list`` plus ``projector_config.json``
          into the ``final`` sub-directory.

    Raises:
        KeyError: If a required config key is missing.
        OSError: If the config file cannot be read or the output directory
            cannot be created.
    """

    # Parse arguments
    parser = argparse.ArgumentParser(description="Train RosettaModel with HuggingFace Trainer")
    parser.add_argument("--config", type=str, default="recipe/default_config.json",
                        help="Path to JSON config file")
    args = parser.parse_args()

    # Load configuration. JSON is UTF-8 by specification, so do not rely on
    # the platform's locale encoding when decoding the file.
    with open(args.config, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    # Seed before the dataset split below, which draws from torch's global RNG.
    set_seed(42)

    # Extract config sections
    model_config = cfg["model"]
    training_config = cfg["training"]
    output_config = cfg["output"]
    data_config = cfg["data"]

    # Create timestamped output directory and keep a copy of the config next
    # to the artifacts it produced.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    timestamped_output_dir = os.path.join(output_config["output_dir"], timestamp)
    os.makedirs(timestamped_output_dir, exist_ok=True)
    shutil.copy(args.config, os.path.join(timestamped_output_dir, "config.json"))

    # Setup models
    print("Setting up models...")
    rosetta_model, tokenizer = setup_models(
        model_config,
        training_config.get("device", "cuda"),
        torch.bfloat16,
    )

    # Apply freezing
    freeze_model_components(rosetta_model, training_config["freeze"])

    # Print parameter info
    total_params = sum(p.numel() for p in rosetta_model.parameters())
    trainable_params = sum(p.numel() for p in rosetta_model.parameters() if p.requires_grad)
    # Guard the ratio so a model without parameters reports 0% instead of
    # raising ZeroDivisionError.
    trainable_pct = trainable_params / total_params * 100 if total_params else 0.0
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Trainable percentage: {trainable_pct:.4f}%")

    # Setup dataset
    print("Loading dataset...")
    instruct_ds = create_instructcoder_dataset(
        split=data_config["split"],
        num_samples=data_config["num_samples"],
    )

    full_dataset = ChatDataset(instruct_ds, tokenizer, training_config["max_length"])

    # Split dataset: the first share goes to training, the remainder to eval.
    train_size = int(data_config["train_ratio"] * len(full_dataset))
    eval_size = len(full_dataset) - train_size
    train_dataset, eval_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, eval_size]
    )

    # Setup data collator
    data_collator = RosettaDataCollator(tokenizer)

    # Setup training arguments
    run_name = f"{output_config['run_name']}_{timestamp}"
    training_args = TrainingArguments(
        output_dir=timestamped_output_dir,
        num_train_epochs=training_config["num_epochs"],
        per_device_train_batch_size=training_config["per_device_train_batch_size"],
        # Evaluation reuses the training batch size.
        per_device_eval_batch_size=training_config["per_device_train_batch_size"],
        learning_rate=training_config["learning_rate"],
        weight_decay=training_config["weight_decay"],
        lr_scheduler_type=training_config["scheduler_type"],
        warmup_ratio=training_config["warmup_ratio"],
        max_grad_norm=training_config["max_grad_norm"],
        logging_steps=50,
        eval_steps=output_config["eval_steps"],
        save_steps=output_config["save_steps"],
        # NOTE(review): newer transformers releases rename this argument to
        # `eval_strategy` -- confirm against the pinned transformers version.
        evaluation_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=False,
        report_to="wandb",
        run_name=run_name,
        bf16=True,
        dataloader_drop_last=True,
        # Keep every column the dataset yields; the custom collator receives
        # them unchanged.
        remove_unused_columns=False,
    )

    # Initialize wandb
    wandb.init(
        project=output_config["wandb_project"],
        name=run_name,
        config=cfg,
    )

    # Setup trainer
    trainer = RosettaTrainer(
        model=rosetta_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        callbacks=[ProjectorSaveCallback()],
    )

    # Start training
    print("Starting training...")
    trainer.train()

    # Save final model: only the projector weights and their config are
    # written here.
    print("Saving final model...")
    final_dir = os.path.join(timestamped_output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)

    for i, projector in enumerate(rosetta_model.projector_list):
        torch.save(projector.state_dict(), os.path.join(final_dir, f"projector_{i}.pt"))
    rosetta_model.save_projector_config(os.path.join(final_dir, "projector_config.json"))

    print("Training completed!")
    wandb.finish()


# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 