import os
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from optimizers.sam_wrapper import SAMWrapper
from optimizers.sabcd_wrapper import SABCDWrapper
from utils.data_utils import load_local_dataset, prepare_dataset_for_training

class CustomTrainer(Trainer):
    """Trainer subclass that echoes the current training loss to stdout."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss through the base Trainer and periodically print it.

        Args:
            model: Model being trained.
            inputs: Batch of model inputs (keyword arguments for the model).
            return_outputs: When True, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: Optional item count supplied by newer Trainer
                versions; forwarded to the base implementation when given.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        # Delegate to the base implementation instead of calling the model
        # directly: Trainer versions that pass ``num_items_in_batch`` expect
        # compute_loss to use it for loss normalisation across gradient
        # accumulation steps, and silently dropping it can mis-scale the loss.
        # The argument is only forwarded when supplied so that older Trainer
        # versions (whose compute_loss lacks the parameter) keep working.
        if num_items_in_batch is None:
            loss, outputs = super().compute_loss(model, inputs, return_outputs=True)
        else:
            loss, outputs = super().compute_loss(
                model,
                inputs,
                return_outputs=True,
                num_items_in_batch=num_items_in_batch,
            )

        # Print the loss in place (carriage return, no newline) whenever the
        # global step is a multiple of 10.
        if self.state.global_step % 10 == 0:
            print(f"\rCurrent loss: {loss.item():.4f}", end="")

        return (loss, outputs) if return_outputs else loss


def finetune_model(args, task_name, base_model_path, output_path):
    """
    Fine-tune a causal language model on one task using the selected optimizer.

    Args:
        args: Argument object. Attributes read here:
            - ``epochs``, ``batch_size``, ``learning_rate`` (required)
            - ``rho``, ``adaptive`` (required for 'sam' and 'sabcd')
            - ``selection_percent`` (required for 'sabcd')
            - ``optimizer`` (optional, defaults to 'sabcd'; any value other
              than 'sabcd' or 'sam' falls back to the Trainer's AdamW)
            - ``cache_dir`` (optional, falls back to a built-in default path)
            - ``max_length`` (optional, defaults to 512)
            - ``task_dataset`` (optional, pre-loaded dataset with a 'train' split)
        task_name: Task name (dataset name)
        base_model_path: Starting model path
        output_path: Directory where the fine-tuned model and tokenizer are saved

    Returns:
        tuple: ``(model, finetune_time_seconds)``. The elapsed time is measured
        from function entry until training finishes, so it includes model
        loading and tokenization but not saving.

    Raises:
        ValueError: If the local training data cannot be loaded, or (from the
            tokenization callback) if the dataset has neither a 'text' column
            nor both 'prompt' and 'answer' columns.
    """
    # Record start time
    start_time = time.time()

    # Determine which optimizer to use
    optimizer_type = getattr(args, 'optimizer', 'sabcd')
    print(f"\n--- Start fine-tuning {task_name} task using {optimizer_type.upper()} optimizer (epochs: {args.epochs}) ---")

    # Use reduced precision (bfloat16) whenever CUDA is available
    use_half_precision = torch.cuda.is_available()
    print(f"Loading model configuration - half-precision: {use_half_precision}")

    # cache_dir is optional, like the other optional args read via getattr
    cache_dir = getattr(args, "cache_dir", None) or "/root/autodl-tmp/huggingface"
    transformers_cache = os.path.join(cache_dir, "transformers")

    # Load base model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16 if use_half_precision else torch.float32,
        device_map="auto" if use_half_precision else None,
        low_cpu_mem_usage=True,
        cache_dir=transformers_cache
    )
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path,
        cache_dir=transformers_cache
    )

    # Padding to max_length below needs a pad token; reuse EOS if none is set
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    # Ensure model is in training mode
    model.train()

    # Load dataset
    provided_dataset = getattr(args, 'task_dataset', None)
    if provided_dataset:
        # Use provided dataset
        dataset = provided_dataset
    else:
        # Load dataset from local
        print(f"Loading {task_name} task dataset from local...")
        train_dataset = load_local_dataset(task_name, 'train')
        if train_dataset is None:
            raise ValueError(f"Unable to load training data for dataset {task_name}")

        train_dataset = prepare_dataset_for_training(train_dataset, task_name)
        dataset = {"train": train_dataset}

    print(f"Dataset information - training set size: {len(dataset['train'])}, column names: {dataset['train'].column_names}")

    # Resolved once here rather than on every tokenized batch
    max_length = getattr(args, "max_length", 512)

    # Define data processing function
    def tokenize_function(examples):
        """Tokenize a batch, padding/truncating every example to max_length."""
        if "text" in examples:
            texts = examples["text"]
        elif "prompt" in examples and "answer" in examples:
            # Directly combine prompt and answer
            texts = [f"{p}\n{a}" for p, a in zip(examples["prompt"], examples["answer"])]
        else:
            raise ValueError(f"Unsupported dataset format: {dataset['train'].column_names}")
        return tokenizer(texts, padding="max_length", truncation=True, max_length=max_length)

    # Process dataset: drop all original columns so only tokenizer outputs remain
    remove_columns = dataset["train"].column_names.copy()
    tokenized_dataset = dataset["train"].map(
        tokenize_function,
        batched=True,
        remove_columns=remove_columns
    )

    # Define training parameters
    training_args = TrainingArguments(
        output_dir=f"./temp/temp_trainer_{task_name}",
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=2,
        learning_rate=args.learning_rate,
        num_train_epochs=args.epochs,
        logging_steps=10,
        save_strategy="no",
        remove_unused_columns=False,
        push_to_hub=False,
        gradient_checkpointing=True,
        bf16=use_half_precision,
        fp16=False,
        # Default optimizer for the plain AdamW path; for 'sam'/'sabcd' a
        # custom optimizer is attached to the trainer below.
        optim="adamw_torch",
        ddp_find_unused_parameters=False,
        dataloader_num_workers=0,
    )

    # Data collator for causal language modeling (mlm disabled)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )

    # Define trainer
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer
    )

    # Set optimizer based on selection. Pre-assigning trainer.optimizer is
    # intended to make the Trainer use it instead of building its own.
    # NOTE(review): confirm the SAM/SABCD wrappers work with the Trainer's
    # plain optimizer.step() call (no closure is passed) - TODO verify.
    if optimizer_type == 'sabcd':
        # Create SABCD optimizer
        trainer.optimizer = SABCDWrapper(
            [p for p in model.parameters() if p.requires_grad],
            torch.optim.AdamW,
            rho=args.rho,
            selection_percent=args.selection_percent,
            adaptive=args.adaptive,
            lr=training_args.learning_rate,
            weight_decay=0.001
        )
        print("SABCD optimizer configured")
    elif optimizer_type == 'sam':
        # Create SAM optimizer
        trainer.optimizer = SAMWrapper(
            [p for p in model.parameters() if p.requires_grad],
            torch.optim.AdamW,
            rho=args.rho,
            adaptive=args.adaptive,
            lr=training_args.learning_rate,
            weight_decay=0.001
        )
        print(f"SAM optimizer configured (rho={args.rho}, adaptive={args.adaptive})")
    else:
        # Use standard AdamW optimizer (automatically created by Trainer)
        print("Use standard AdamW optimizer")

    # Start fine-tuning
    print("Start training...")
    trainer.train()
    finetune_time_seconds = time.time() - start_time

    # Save fine-tuned model
    os.makedirs(output_path, exist_ok=True)
    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)

    print(f"Fine-tuning completed, model saved to: {output_path}")
    return model, finetune_time_seconds