#!/usr/bin/env python3
"""Unified supervised fine-tuning script"""
import os
import torch
from transformers import TrainingArguments, Trainer

from utils.config import load_config, get_base_parser
from data.data_loader import DataLoader
from data.prompt_builder import PromptBuilder
from utils.data_utils import smiles2selfies
from utils.training_utils import setup_distributed, setup_experiment, setup_tokenizer, setup_model, setup_lora, save_experiment, cleanup_distributed, with_telegram_notifications
from utils.training_helpers import apply_label_masking, DataCollatorWithPaddingAndLabels


def _maybe_convert_to_selfies(train_data, mol_type):
    """Convert SMILES-valued columns of ``train_data`` to SELFIES when required.

    Does nothing unless ``mol_type == "SELFIES"``. Otherwise:
      * converts the ``SMILES`` column via ``smiles2selfies`` when no ``SELFIES``
        column exists yet;
      * converts the ``prod`` and ``equa`` columns in place (same source and
        target column name) when both are present.

    Returns the (possibly new) dataset object.
    """
    if mol_type != "SELFIES":
        return train_data
    if "SELFIES" not in train_data.column_names and "SMILES" in train_data.column_names:
        train_data = smiles2selfies(train_data)
    # Re-read column_names: the conversion above returns a new dataset object.
    if "prod" in train_data.column_names and "equa" in train_data.column_names:
        train_data = smiles2selfies(train_data, "prod", "prod")
        train_data = smiles2selfies(train_data, "equa", "equa")
    return train_data


def _load_train_data(config):
    """Load the training datasets named in ``config`` and build their prompts.

    If ``config.dataset_tasks`` is set and truthy, datasets are loaded with
    per-dataset tasks and prompts are built per dataset. Otherwise
    ``config.target_datasets`` is loaded and prompts are built with the global
    ``config.tasks``. ``dataset_limits`` / ``dataset_processing`` are optional
    and default to empty dicts.
    """
    data_loader = DataLoader()
    dataset_limits = getattr(config, 'dataset_limits', {})
    dataset_processing = getattr(config, 'dataset_processing', {})
    dataset_tasks = getattr(config, 'dataset_tasks', None)

    if dataset_tasks:
        train_data = data_loader.load_datasets_with_tasks(
            dataset_tasks, dataset_limits, dataset_processing
        )
    else:
        train_data = data_loader.load_multiple_datasets(
            config.target_datasets, dataset_limits, dataset_processing
        )

    train_data = _maybe_convert_to_selfies(train_data, config.model_mol_type)

    prompt_builder = PromptBuilder(config.model_mol_type)
    if dataset_tasks:
        return prompt_builder.build_prompts_per_dataset(train_data, is_generation=False)
    return prompt_builder.build_prompts(train_data, config.tasks, is_generation=False)


@with_telegram_notifications
def main():
    """Run supervised fine-tuning as configured on the command line.

    Steps:
      1. Parse CLI args and load the config.
      2. Set up distributed state and the experiment.
      3. Load data and build prompts.
      4. Build the tokenizer, model and LoRA adapters.
      5. Mask labels so loss is computed on the response only.
      6. Train with the HF ``Trainer`` and save the experiment.

    Distributed resources are released even if any step after
    ``setup_distributed`` raises.

    Optional config keys (defaults preserve the previous hard-coded values):
      * ``deepspeed_config``: path to the DeepSpeed JSON config
        (default ``"ds_config.json"``).
      * ``fp16``: enable fp16 mixed precision (default ``True``).
    """
    parser = get_base_parser()
    args = parser.parse_args()
    config = load_config(args.config, args.opts)

    local_rank, multi_gpu = setup_distributed()
    try:
        setup_experiment(config, local_rank, "molgen")

        train_data = _load_train_data(config)
        # Debug aid: show one fully built example before tokenization.
        print(train_data[0])

        # Setup model and tokenizer
        tokenizer = setup_tokenizer(config)
        model, lora_hist = setup_model(config)
        model = setup_lora(model, config)

        # Token ids of the response separator; apply_label_masking uses them to
        # locate where the assistant response starts.
        assistant_header_ids = tokenizer.encode(
            tokenizer.response_split_id, add_special_tokens=False
        )
        train_data = train_data.map(
            lambda examples: apply_label_masking(examples, tokenizer, assistant_header_ids),
            batched=True,
            remove_columns=train_data.column_names,
        )

        training_args = TrainingArguments(
            output_dir=os.path.join(config.exp_save_dir, "hf"),
            save_strategy="no",  # checkpoints are written by save_experiment instead
            num_train_epochs=config.epochs,
            per_device_train_batch_size=config.batch_size,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            max_grad_norm=config.max_grad_norm,
            warmup_ratio=config.warmup_ratio,
            lr_scheduler_type=config.lr_scheduler_type,
            logging_steps=config.logging_steps,
            # "none" (not None) explicitly disables integrations on other ranks;
            # None can mean "all installed integrations" in transformers.
            report_to="wandb" if local_rank == 0 else "none",
            remove_unused_columns=False,
            ddp_find_unused_parameters=False,
            deepspeed=getattr(config, "deepspeed_config", "ds_config.json"),
            fp16=getattr(config, "fp16", True),
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_data,
            data_collator=DataCollatorWithPaddingAndLabels(tokenizer),
        )

        trainer.train()

        save_experiment(config, model, lora_hist, local_rank)
    finally:
        cleanup_distributed(multi_gpu)

if __name__ == "__main__":
    # Executed only when run as a script, never on import.
    main()
