from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    EarlyStoppingCallback,
)

from datasets import load_from_disk
from tqdm import tqdm

import argparse
import json

from trl import SFTTrainer, SFTConfig
# from accelerate.utils import BnbQuantizationConfig
from transformers import BitsAndBytesConfig

import torch


from utils.data_processing import preprocess_dataset, chars_token_ratio
from utils.env_management import save_config
from trl.trainer import ConstantLengthDataset
import os

# Weights & Biases settings, exported through the process environment.
os.environ.update(
    {
        "WANDB_PROJECT": "finetuning-historical",  # name of the W&B project
        "WANDB_LOG_MODEL": "checkpoint",  # log all model checkpoints
        "WANDB_TAGS": "sft",
    }
)

# Command-line interface: the JSON config that drives the run and the index
# of the CUDA device to train on.
parser = argparse.ArgumentParser(description="Train SFT")
parser.add_argument(
    "--config_file",
    default="configs/mind-sport-orthogonal/base.json",
    type=str,
    help="config file",
)
parser.add_argument(
    "--cuda_device",
    default=0,
    type=int,
    help="cuda device to use",
)
args = parser.parse_args()

# Expose only the requested GPU to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.cuda_device}"

# Load the experiment configuration. The context manager closes the file
# handle deterministically (the bare open() call previously leaked it), and
# the explicit encoding keeps parsing independent of the machine's locale.
with open(args.config_file, encoding="utf-8") as config_file:
    config = json.load(config_file)

# Several config values are str.format templates that are expanded against
# the top-level config keys. Resolve the W&B cache location once and derive
# the artifact directory from it.
wandb_cache_dir = config["wandb_cache_dir"].format(**config)
os.environ["WANDB_CACHE_DIR"] = wandb_cache_dir
os.environ["WANDB_ARTIFACT_DIR"] = wandb_cache_dir + "/artifacts"

# NOTE(review): presumably records the config used for this "sft_train" run;
# see utils.env_management.save_config for the exact behaviour.
save_config(config, "sft_train")
base_model = config["base_model"]

# Settings specific to the supervised fine-tuning stage.
sft_config = config["sft"]
checkpoint = sft_config["model"].format(**config)

device = config["device"]  # for GPU usage or "cpu" for CPU usage
preprocess = config["preprocess"]
base_dir = config["base_dir"]

# The tokenizer comes from the base model; the padding token is taken from
# the config.
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = config["pad_token"]

# Override the chat template only when the config supplies one.
if "chat_template" not in config:
    print("No chat template provided in config file using default")
else:
    tokenizer.chat_template = config["chat_template"]

# Seed torch's RNG with the training seed, falling back to 42 when the
# training arguments do not define one.
seed = sft_config["training_args"].get("seed", 42)
torch.manual_seed(seed)

# Name of the dataset split used for training ("sft" unless overridden).
train_set = sft_config.get("train_set", "sft")

if not preprocess:
    # Reuse the dataset that an earlier run already preprocessed and saved.
    ratings = load_from_disk(config["data_path_preprocessed"].format(**config))
else:
    # Preprocess the training split of the raw dataset, then save the whole
    # dataset to the preprocessed path.
    raw_data_path = config["data_path"].format(**config)
    ratings = load_from_disk(raw_data_path)
    processed_split = preprocess_dataset(
        ratings[train_set],
        tokenizer,
        config["max_input_length"],
        config["messages_template"],
    )
    ratings[train_set] = processed_split
    ratings.save_to_disk(config["data_path_preprocessed"].format(**config))

# Optionally train on a seeded random subset of fixed size.
if sft_config["sample_training"]:
    shuffled_split = ratings[train_set].shuffle(
        seed=sft_config["sample_training_seed"]
    )
    kept_indices = range(sft_config["sample_training_size"])
    ratings[train_set] = shuffled_split.select(kept_indices)

# NOTE(review): this value is not referenced again in the visible part of
# the script; confirm whether the computation is still needed.
chars_per_token = chars_token_ratio(ratings[train_set], tokenizer)

# Keyword arguments for from_pretrained: always the configured dtype, plus a
# bitsandbytes 8-bit quantization config when requested.
use_8bit = sft_config["quantize_8bit"]

quantization_config = {"torch_dtype": config["torch_dtype"]}
if use_8bit:
    bnb_quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
    )
    quantization_config["quantization_config"] = bnb_quantization_config

model = AutoModelForCausalLM.from_pretrained(checkpoint, **quantization_config)

if use_8bit:
    # Do NOT call .to(device) here: transformers raises ValueError when
    # `.to()` is used on a bitsandbytes 8-bit model, because the loader has
    # already placed the quantized weights on their device.
    model = prepare_model_for_kbit_training(model)

    # NOTE(review): every float32 parameter is cast to bfloat16, which also
    # affects any parameter prepare_model_for_kbit_training left in float32 —
    # confirm this is intended.
    for name, param in model.named_parameters():
        print(name, param.dtype)
        if param.dtype == torch.float32:
            param.data = param.data.to(torch.bfloat16)
else:
    # Non-quantized models are moved to the configured device explicitly.
    model = model.to(device)

# Match the embedding matrix size to the tokenizer's vocabulary size.
model.resize_token_embeddings(len(tokenizer))

# LoRA adapter configuration, taken from the config and set up for training
# (inference_mode=False).
lora_kwargs = sft_config["lora_config"]
peft_config = LoraConfig(inference_mode=False, **lora_kwargs)

# Trainer arguments: the output directory is a formatted config template;
# everything else is passed through from the config's training_args.
training_output_dir = sft_config["training_output_path"].format(**config)
training_args = SFTConfig(
    output_dir=training_output_dir,
    **sft_config["training_args"],
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=ratings[train_set],
    peft_config=peft_config,
)

# Train, resuming from a checkpoint when the config specifies one.
resume_from = sft_config["resume_from_checkpoint"]
trainer.train(resume_from_checkpoint=resume_from)

# Save the trained model, including its embedding layers.
final_model_dir = sft_config["model_output_path"].format(**config)
trainer.model.save_pretrained(final_model_dir, save_embedding_layers=True)
