import json
import sys
import torch
from transformers import AutoTokenizer
from unsloth import FastLanguageModel
from peft import PeftModel
from validate import TrainingConfig
from training import get_dataset, sft_train, load_model_and_tokenizer


def main(training_cfg: "TrainingConfig") -> None:
    """Fine-tune a model with PEFT adapters, evaluate it, and print perplexity.

    The steps are: load the base model and tokenizer, wrap the model with
    adapters via ``FastLanguageModel.get_peft_model``, train it with
    ``sft_train``, run ``trainer.evaluate()``, and print both the raw
    evaluation metrics and ``exp(eval_loss)`` as the perplexity.

    Args:
        training_cfg: Validated training configuration. It is accessed by
            attribute (``model``, ``load_in_4bit``, ``r``, ``target_modules``,
            ``lora_alpha``, ``lora_dropout``, ``lora_bias``, ``seed``,
            ``use_rslora``), so a plain ``dict`` will not work here.

    Raises:
        KeyError: If the evaluation results do not contain ``"eval_loss"``.
    """
    model, tokenizer = load_model_and_tokenizer(
        training_cfg.model,
        load_in_4bit=training_cfg.load_in_4bit,
    )

    # Attach the trainable adapter layers, configured entirely from the config.
    model = FastLanguageModel.get_peft_model(
        model,
        r=training_cfg.r,
        target_modules=training_cfg.target_modules,
        lora_alpha=training_cfg.lora_alpha,
        lora_dropout=training_cfg.lora_dropout,
        bias=training_cfg.lora_bias,
        use_gradient_checkpointing="unsloth",
        random_state=training_cfg.seed,
        use_rslora=training_cfg.use_rslora,
        loftq_config=None,
        use_dora=False,
    )
    # NOTE: an alternative path that loaded a previously saved adapter from
    # "tmp/<finetuned_model_id>" with ``PeftModel.from_pretrained(...,
    # is_trainable=False)`` used to live here; it is currently disabled.

    dataset, test_dataset = get_dataset(training_cfg)
    trainer = sft_train(
        training_cfg,
        dataset,
        model,
        test_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    eval_results = trainer.evaluate()
    print(eval_results)

    # Perplexity is reported as exp(eval_loss). torch.exp is kept (rather than
    # math.exp) so a very large loss yields ``inf`` instead of raising.
    perplexity = torch.exp(torch.tensor(eval_results["eval_loss"])).item()
    print(f"Perplexity: {perplexity}")


if __name__ == "__main__":
    # Expect the path to a JSON training config as the first CLI argument.
    # Fail with a usage message instead of a bare IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <config.json>")

    # JSON is UTF-8 by specification; do not depend on the platform's default
    # text encoding when reading the config.
    with open(sys.argv[1], "r", encoding="utf-8") as f:
        config = json.load(f)

    # Build the config object from the parsed JSON keys, then run the pipeline.
    training_config = TrainingConfig(**config)
    main(training_config)
