from unsloth import FastLanguageModel
import json
import os
import sys
import torch
import backoff
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer

from validate import TrainingConfig
from sft import sft_train
from utils import load_jsonl, load_model_and_tokenizer


_HF_PREFIX = "hf://"


def _is_hub_reference(path):
    """Return True when ``path`` should be loaded from the HuggingFace Hub.

    An explicit ``hf://`` prefix always selects the Hub. Otherwise a path that
    contains ``/`` but does not exist on the local filesystem is treated as a
    Hub repository id (e.g. ``org/name``).
    """
    if path.startswith(_HF_PREFIX):
        return True
    return "/" in path and not os.path.exists(path)


def _load_hub_split(path, split_candidates, label):
    """Load ``path`` from the Hub and return the first available split.

    Args:
        path: Hub reference, optionally prefixed with ``hf://``.
        split_candidates: Split names to try, in order of preference.
        label: Human-readable name used in the progress message.

    Raises:
        KeyError: If the loaded object is a mapping of splits and none of
            ``split_candidates`` is present.
    """
    # Strip the marker only when it is a prefix, never from the middle.
    repo_id = path[len(_HF_PREFIX):] if path.startswith(_HF_PREFIX) else path
    print(f"Loading {label} from HuggingFace: {repo_id}")
    loaded = load_dataset(repo_id)
    if not isinstance(loaded, dict):
        # Already a single dataset rather than a mapping of split names.
        return loaded
    for split_name in split_candidates:
        if split_name in loaded:
            return loaded[split_name]
    raise KeyError(
        f"{repo_id!r} has none of the splits {list(split_candidates)}; "
        f"available splits: {list(loaded)}"
    )


def _load_local_split(path, messages_only):
    """Load a local JSONL file, optionally keeping only the ``messages`` field."""
    rows = load_jsonl(path)
    if messages_only:
        rows = [dict(messages=row["messages"]) for row in rows]
    return Dataset.from_list(rows)


def get_dataset(training_cfg):
    """Build the training and evaluation datasets described by ``training_cfg``.

    Each of ``training_cfg.training_file`` and ``training_cfg.test_file`` may
    be a local JSONL path or a HuggingFace Hub reference. When no test file is
    configured, 10% of the training data is held out for evaluation.

    Args:
        training_cfg: Config object exposing ``training_file``, ``test_file``
            and ``loss`` (and optionally ``seed``).

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.

    Raises:
        KeyError: If a Hub dataset lacks the required split (``train`` for the
            training file; ``test`` or ``validation`` for the test file).
    """
    train_path = training_cfg.training_file
    if _is_hub_reference(train_path):
        dataset = _load_hub_split(train_path, ("train",), "dataset")
    else:
        # Only SFT rows are reduced to their "messages" field; every other
        # loss keeps the full row.
        dataset = _load_local_split(
            train_path, messages_only=training_cfg.loss == "sft"
        )

    test_path = training_cfg.test_file
    if test_path:
        if _is_hub_reference(test_path):
            test_dataset = _load_hub_split(
                test_path, ("test", "validation"), "test dataset"
            )
        else:
            # NOTE(review): the test-file rule differs from the training-file
            # rule (everything except orpo/dpo is reduced to "messages");
            # preserved as-is, confirm whether the asymmetry is intended.
            test_dataset = _load_local_split(
                test_path, messages_only=training_cfg.loss not in ("orpo", "dpo")
            )
        return dataset, test_dataset

    # No test set provided: hold out 10% of the training data. Seeding the
    # split keeps the evaluation set stable across runs of the same config.
    split = dataset.train_test_split(
        test_size=0.1, seed=getattr(training_cfg, "seed", None)
    )
    return split["train"], split["test"]


def train(training_cfg):
    """Attach a LoRA adapter to the base model, train it, save it and evaluate.

    Steps:
      1. Load the base model and tokenizer named by ``training_cfg.model``.
      2. If the tokenizer has no chat template, borrow the one from the
         ``allenai/OLMo-2-0325-32B-Instruct`` tokenizer.
      3. Wrap the model with a new LoRA adapter configured from
         ``training_cfg``.
      4. Train via the trainer returned by ``sft_train``.
      5. Save the model and tokenizer under
         ``<output_dir>/<finetuned_model_id>``.
      6. Run a best-effort evaluation and print the results.

    Args:
        training_cfg: Config object providing the model name, LoRA
            hyper-parameters, dataset locations, ``output_dir`` and
            ``finetuned_model_id``.
    """
    model, tokenizer = load_model_and_tokenizer(
        training_cfg.model,
        load_in_4bit=training_cfg.load_in_4bit,
    )
    if tokenizer.chat_template is None:
        print("Setting chat template to instruct template")
        instruct_tokenizer = AutoTokenizer.from_pretrained(
            "allenai/OLMo-2-0325-32B-Instruct"
        )
        tokenizer.chat_template = instruct_tokenizer.chat_template

    print("Creating new LoRA adapter")
    model = FastLanguageModel.get_peft_model(
        model,
        r=training_cfg.r,
        target_modules=training_cfg.target_modules,
        lora_alpha=training_cfg.lora_alpha,
        lora_dropout=training_cfg.lora_dropout,
        bias=training_cfg.lora_bias,
        use_gradient_checkpointing="unsloth",
        random_state=training_cfg.seed,
        use_rslora=training_cfg.use_rslora,
        loftq_config=None,
        use_dora=False,
    )

    dataset, test_dataset = get_dataset(training_cfg)

    # Only forward max_steps when it is set, so the trainer's own default
    # applies otherwise.
    kwargs = {}
    if training_cfg.max_steps:
        kwargs["max_steps"] = training_cfg.max_steps

    trainer = sft_train(
        training_cfg, dataset, model, tokenizer, test_dataset=test_dataset, **kwargs
    )
    trainer.train()

    finetuned_model_id = training_cfg.finetuned_model_id
    save_path = os.path.join(training_cfg.output_dir, finetuned_model_id)
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    # Report success only after both saves have actually completed.
    print(f"Model with LoRA adapter saved locally to {save_path}")
    # NOTE: pushing to the Hub is currently disabled. To re-enable it, call
    # push_model(training_cfg, finetuned_model_id, model, tokenizer) here.

    # Evaluation is best-effort: the trained weights are already on disk, so a
    # failure here is reported rather than allowed to abort the run.
    try:
        eval_results = trainer.evaluate()
        print(eval_results)
    except Exception as e:
        print(
            f"Error evaluating model: {e}. "
            f"The model has already been saved locally to {save_path}."
        )


@backoff.on_exception(backoff.constant, Exception, interval=10, max_tries=5)
def _push_with_retry(training_cfg, finetuned_model_id, model, tokenizer, token):
    """Upload the model (and tokenizer) to the Hub, retrying on any failure.

    Retries up to 5 attempts in total, waiting 10 seconds between attempts.
    """
    if training_cfg.merge_before_push:
        model.push_to_hub_merged(
            finetuned_model_id,
            tokenizer,
            save_method="merged_16bit",
            token=token,
            private=training_cfg.push_to_private,
        )
    else:
        model.push_to_hub(
            finetuned_model_id,
            token=token,
            private=training_cfg.push_to_private,
        )
        tokenizer.push_to_hub(
            finetuned_model_id,
            token=token,
            private=training_cfg.push_to_private,
        )


def push_model(training_cfg, finetuned_model_id, model, tokenizer):
    """Push the fine-tuned model to the HuggingFace Hub.

    When ``training_cfg.merge_before_push`` is truthy the model is pushed via
    ``push_to_hub_merged`` with ``save_method="merged_16bit"``; otherwise the
    model and tokenizer are pushed separately with ``push_to_hub``.

    Args:
        training_cfg: Config providing ``merge_before_push`` and
            ``push_to_private``.
        finetuned_model_id: Target repository id on the Hub.
        model: Model object exposing the push methods used above.
        tokenizer: Tokenizer to upload alongside the model.

    Raises:
        KeyError: If the ``HF_TOKEN`` environment variable is not set.
    """
    # Resolve the token outside the retried function: a missing token is a
    # configuration error, not a transient failure, so fail immediately
    # instead of sleeping through every retry first.
    token = os.environ["HF_TOKEN"]
    _push_with_retry(training_cfg, finetuned_model_id, model, tokenizer, token)


def main(config: str):
    """Load a JSON training config from the path ``config`` and run training.

    Args:
        config: Filesystem path to a JSON file whose top-level object supplies
            the keyword arguments for ``TrainingConfig``.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale
    # encoding when reading it.
    with open(config, "r", encoding="utf-8") as f:
        raw_config = json.load(f)
    training_config = TrainingConfig(**raw_config)
    train(training_config)


if __name__ == "__main__":
    # The first positional argument is the path to the JSON training config.
    # Exit with a usage message instead of an IndexError traceback when it is
    # missing; any extra arguments are ignored, as before.
    if len(sys.argv) < 2:
        raise SystemExit(f"Usage: python {sys.argv[0]} <config.json>")
    main(sys.argv[1])
