from unsloth import FastLanguageModel
import json
import os
import sys

import backoff
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer

from validate import TrainingConfig
from sft import sft_train
from utils import load_jsonl, load_model_and_tokenizer


_HF_PREFIX = "hf://"


def _is_hub_dataset(path):
    """Return True when *path* should be loaded from the HuggingFace Hub.

    A path is treated as a Hub dataset when it starts with ``hf://``, or when
    it contains a ``/`` and nothing exists at that location on the local
    filesystem.
    """
    if path.startswith(_HF_PREFIX):
        return True
    return "/" in path and not os.path.exists(path)


def _hub_repo_id(path):
    """Strip a leading ``hf://`` from *path*; other paths are returned as-is."""
    if path.startswith(_HF_PREFIX):
        return path[len(_HF_PREFIX):]
    return path


def _load_local_dataset(path, loss):
    """Build a ``Dataset`` from the local file *path* read via ``load_jsonl``.

    For ``loss == "sft"`` only the ``messages`` field of each row is kept;
    for every other loss the rows are passed through unchanged. The same rule
    is used for both the training and the test file so the two datasets
    always end up with matching columns.
    """
    # presumably load_jsonl returns a list of dict rows — verify in utils
    rows = load_jsonl(path)
    if loss == "sft":
        rows = [dict(messages=row["messages"]) for row in rows]
    return Dataset.from_list(rows)


def get_dataset(training_cfg):
    """Load the training and evaluation datasets described by *training_cfg*.

    Reads ``training_cfg.training_file``, ``training_cfg.test_file`` and
    ``training_cfg.loss``. Each file is loaded either from the HuggingFace Hub
    or from a local file, as decided by ``_is_hub_dataset``.

    When no test file is configured, 10% of the training data is split off
    and used as the evaluation set.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.

    Raises:
        KeyError: if a Hub test dataset is a mapping of splits that contains
            neither a ``test`` nor a ``validation`` split (or, for the
            training file, no ``train`` split).
    """
    training_file = training_cfg.training_file
    if _is_hub_dataset(training_file):
        dataset_path = _hub_repo_id(training_file)
        print(f"Loading dataset from HuggingFace: {dataset_path}")
        dataset = load_dataset(dataset_path)
        # load_dataset may hand back a mapping of splits; train on "train".
        if isinstance(dataset, dict):
            dataset = dataset["train"]
    else:
        dataset = _load_local_dataset(training_file, training_cfg.loss)

    test_file = training_cfg.test_file
    if not test_file:
        # No test set provided: hold out 10% of the training data.
        # No seed is passed here, so this call does not pin which rows are
        # held out.
        split = dataset.train_test_split(test_size=0.1)
        return split["train"], split["test"]

    if _is_hub_dataset(test_file):
        test_dataset_path = _hub_repo_id(test_file)
        print(f"Loading test dataset from HuggingFace: {test_dataset_path}")
        test_dataset = load_dataset(test_dataset_path)
        if isinstance(test_dataset, dict):
            # Prefer the "test" split and fall back to "validation".
            for split_name in ("test", "validation"):
                if split_name in test_dataset:
                    test_dataset = test_dataset[split_name]
                    break
            else:
                raise KeyError(
                    f"{test_dataset_path!r} has neither a 'test' nor a "
                    f"'validation' split; available: {list(test_dataset)}"
                )
    else:
        test_dataset = _load_local_dataset(test_file, training_cfg.loss)

    return dataset, test_dataset


def _find_layer_stack(model):
    """Return the model's indexable stack of layers, or None if not found.

    ``model.model.layers`` is tried first, then ``model.layers``.
    """
    for owner in (getattr(model, "model", None), model):
        layers = getattr(owner, "layers", None)
        if layers is not None:
            return layers
    return None


def train(training_cfg):
    """Finetune a single layer of the model, save it locally, and evaluate.

    Steps, in order:
      1. Load the model and tokenizer named by ``training_cfg.model``.
      2. If the tokenizer has no chat template, copy the one from
         ``allenai/OLMo-2-0325-32B-Instruct``.
      3. Freeze every parameter, then unfreeze one hard-coded layer.
      4. Build the trainer with ``sft_train`` and run ``trainer.train()``.
      5. Save the model and tokenizer under
         ``<training_cfg.output_dir>/<training_cfg.finetuned_model_id>``.
      6. Run ``trainer.evaluate()``; an evaluation failure is reported but
         does not raise, because the model has already been saved.

    No LoRA adapter is attached, and nothing is pushed to the Hub from here.
    """
    model, tokenizer = load_model_and_tokenizer(
        training_cfg.model,
        load_in_4bit=training_cfg.load_in_4bit,
    )
    if tokenizer.chat_template is None:
        # Borrow the template from the instruct tokenizer so chat-formatted
        # data can still be rendered.
        print("Setting chat template to instruct template")
        instruct_tokenizer = AutoTokenizer.from_pretrained(
            "allenai/OLMo-2-0325-32B-Instruct"
        )
        tokenizer.chat_template = instruct_tokenizer.chat_template

    print("Preparing model for full finetuning")
    # Enable gradient checkpointing for memory efficiency.
    # NOTE(review): everything below the unfrozen layer is frozen; depending
    # on how checkpointing is configured this may require
    # model.enable_input_require_grads() for gradients to flow — verify.
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()

    # Hard-coded index; presumably the last layer of the 64-layer target
    # model — TODO confirm.
    layer_to_unfreeze = 63
    print(f"Freezing all layers except layer {layer_to_unfreeze}")

    # First freeze all parameters.
    for param in model.parameters():
        param.requires_grad = False

    # Then re-enable gradients for the one selected layer. Every way this can
    # fail is reported, so a run with nothing to train is never silent.
    layers = _find_layer_stack(model)
    if layers is None:
        print("WARNING: Could not identify layer structure. No layers were unfrozen.")
    elif len(layers) <= layer_to_unfreeze:
        print(
            f"WARNING: Model has only {len(layers)} layers, so layer "
            f"{layer_to_unfreeze} does not exist. No layers were unfrozen."
        )
    else:
        target_layer = layers[layer_to_unfreeze]
        for param in target_layer.parameters():
            param.requires_grad = True
        unfrozen_params = sum(p.numel() for p in target_layer.parameters())
        print(f"Unfrozen layer {layer_to_unfreeze} with {unfrozen_params} parameters")

    # Print trainable parameters count.
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable_params:,} ({trainable_params/all_params:.2%} of {all_params:,} total parameters)")

    dataset, test_dataset = get_dataset(training_cfg)

    # Only forward max_steps when it is set, so sft_train's own default
    # applies otherwise.
    kwargs = {}
    if training_cfg.max_steps:
        kwargs["max_steps"] = training_cfg.max_steps

    trainer = sft_train(
        training_cfg, dataset, model, tokenizer, test_dataset=test_dataset, **kwargs
    )
    trainer.train()

    finetuned_model_id = training_cfg.finetuned_model_id

    save_path = os.path.join(training_cfg.output_dir, finetuned_model_id)
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    # Report success only once both artifacts are actually on disk.
    print(f"Fully finetuned model saved locally to {save_path}")
    # Pushing to the Hub is currently disabled. To re-enable it, call
    # push_model(training_cfg, finetuned_model_id, model, tokenizer) here.

    # Evaluation is best-effort: the model is already saved, so a failure
    # here is reported instead of aborting the run.
    try:
        eval_results = trainer.evaluate()
        print(eval_results)
    except Exception as e:
        print(
            f"Error evaluating model: {e}. The model has already been saved locally."
        )


@backoff.on_exception(backoff.constant, Exception, interval=10, max_tries=5)
def _push_artifacts(repo_id, token, private, model, tokenizer):
    """Push the model, then the tokenizer, to the Hub repo *repo_id*.

    Any exception triggers a retry of the whole function (up to 5 tries,
    10 seconds apart), so a failed tokenizer push also re-pushes the model.
    """
    for artifact in (model, tokenizer):
        artifact.push_to_hub(repo_id, token=token, private=private)


def push_model(training_cfg, finetuned_model_id, model, tokenizer):
    """Upload the model and tokenizer to the HuggingFace Hub with retries.

    Args:
        training_cfg: Config object; ``push_to_private`` is read to decide
            the repository visibility.
        finetuned_model_id: Hub repository id to push to.
        model: Object exposing ``push_to_hub``.
        tokenizer: Object exposing ``push_to_hub``.

    Raises:
        KeyError: if the ``HF_TOKEN`` environment variable is not set. This
            is raised immediately rather than after exhausting the retries.
    """
    # Read the configuration outside the retried helper: a missing token or
    # config attribute cannot be fixed by waiting, so it should fail fast
    # instead of being retried.
    token = os.environ["HF_TOKEN"]
    private = training_cfg.push_to_private
    _push_artifacts(finetuned_model_id, token, private, model, tokenizer)


def main(config: str):
    """Load a JSON training config from the file *config* and run training.

    Args:
        config: Path to a JSON file whose top-level object supplies the
            keyword arguments for ``TrainingConfig``.
    """
    # JSON is UTF-8 by specification; do not depend on the locale's default
    # encoding when reading it.
    with open(config, "r", encoding="utf-8") as f:
        raw_config = json.load(f)
    training_config = TrainingConfig(**raw_config)
    train(training_config)


if __name__ == "__main__":
    # The first command-line argument is the path to the JSON config file.
    # Exit with a usage message instead of an IndexError traceback when it
    # is missing; any extra arguments are ignored.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: python {sys.argv[0]} <config.json>")
    main(sys.argv[1])
