from peft import LoraConfig, PeftModel
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    AutoModelForSequenceClassification,
    EarlyStoppingCallback,
)
from datasets import load_from_disk

import argparse
import json

from trl import RewardTrainer, RewardConfig

from utils.data_processing import preprocess_dataset
from utils.env_management import save_config
import os
import torch

# Weights & Biases settings for this run, exported through the environment:
#   WANDB_PROJECT   - the W&B project the run is filed under
#   WANDB_LOG_MODEL - "checkpoint" logs all model checkpoints
#   WANDB_TAGS      - tag attached to the run
os.environ.update(
    {
        "WANDB_PROJECT": "finetuning-historical",
        "WANDB_LOG_MODEL": "checkpoint",
        "WANDB_TAGS": "reward-paired",
    }
)


# Command-line interface: which JSON config to load and which GPU to expose.
parser = argparse.ArgumentParser(description="Train Reward")
cli_options = (
    ("--config_file", {"type": str, "default": "configs/reward-base.json", "help": "config file"}),
    ("--cuda_device", {"type": int, "default": 0, "help": "cuda device to use"}),
)
for option_flag, option_kwargs in cli_options:
    parser.add_argument(option_flag, **option_kwargs)
args = parser.parse_args()

# Expose only the requested GPU index to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.cuda_device}"

# Load the JSON run configuration chosen on the command line. The context
# manager closes the file as soon as parsing is done instead of leaving the
# handle open until garbage collection; JSON is decoded explicitly as UTF-8.
with open(args.config_file, encoding="utf-8") as config_fh:
    config = json.load(config_fh)

# The cache directory entry is a format template filled from top-level config keys.
os.environ["WANDB_CACHE_DIR"] = config["wandb_cache_dir"].format(**config)

# Project helper; presumably records the config for this "reward_train" run
# (see utils.env_management for the exact behavior).
save_config(config, "reward_train")

reward_config = config["reward"]
device = config["device"]  # for GPU usage or "cpu" for CPU usage
preprocess = config["preprocess"]
base_dir = config["base_dir"]

base_model = config["base_model"]

# Names of the dataset splits to train and validate on; the defaults apply
# when the reward section does not set them.
train_set = reward_config.get("train_set", "reward")
valid_set = reward_config.get("valid_set", "reward_valid")

# Tokenizer of the base model, with the padding token taken from the config.
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = config["pad_token"]

# A chat template in the config replaces whatever the tokenizer ships with.
if "chat_template" not in config:
    print("No chat template provided in config file using default")
else:
    tokenizer.chat_template = config["chat_template"]

# Seed torch's RNG from the training arguments (42 when no seed is configured).
torch_seed = reward_config["training_args"].get("seed", 42)
torch.manual_seed(torch_seed)

# Model path template from the reward section, filled from top-level config keys.
checkpoint = reward_config["model"].format(**config)

# Dataset dictionary previously saved to disk, indexed by split name below.
dataset_path = config["data_path"].format(**config)
ratings = load_from_disk(dataset_path)

# Optionally train on a seeded random subset of the training split.
if reward_config["sample_training"]:
    shuffled_train = ratings[train_set].shuffle(
        seed=reward_config["sample_training_seed"]
    )
    subset_indices = range(reward_config["sample_training_size"])
    ratings[train_set] = shuffled_train.select(subset_indices)

# Without a chat template, render the chosen/rejected message lists as plain
# text before they are handed to the trainer.
if tokenizer.chat_template is None:
    def generate_prompt_messages(messages):
        """Render each message as "role: content" and join them with newlines."""
        return "\n".join(
            f"{message['role']}: {message['content']}" for message in messages
        )

    def generate_prompt(example):
        """Return the example's chosen and rejected conversations as flat text."""
        return {
            side: generate_prompt_messages(example[side])
            for side in ("chosen", "rejected")
        }

    for split_name in (train_set, valid_set):
        ratings[split_name] = ratings[split_name].map(generate_prompt)


print(ratings)

# No extra modules are requested for saving in the LoRA config built later.
modules_to_save = None

# Base model with a single-output sequence-classification head (num_labels=1),
# loaded in the configured dtype and then moved to the target device.
model = AutoModelForSequenceClassification.from_pretrained(
    base_model, num_labels=1, torch_dtype=config["torch_dtype"]
)
model = model.to(device)

# Keep the embedding matrix and the model's padding id in sync with the tokenizer.
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id


# When the starting checkpoint is a PEFT adapter, load it on top of the base
# model and merge its weights in so training continues from a plain model.
if reward_config["has_peft"]:
    model = PeftModel.from_pretrained(
        model,
        checkpoint,
        adapter_name="default",
        torch_dtype=config["torch_dtype"],
    )
    model = model.merge_and_unload()

if reward_config.get("no_peft"):
    # No LoRA: train only parameters whose name contains "score" and freeze
    # every other parameter.
    peft_config = None
    for param_name, parameter in model.named_parameters():
        parameter.requires_grad = "score" in param_name
else:
    # LoRA training for sequence classification, configured from the reward section.
    peft_config = LoraConfig(
        task_type="SEQ_CLS",
        inference_mode=False,
        modules_to_save=modules_to_save,
        **reward_config["lora_config"],
    )

# Training arguments: the best checkpoint (lowest "loss" metric) is reloaded at
# the end of training; everything else comes from the config's training_args.
reward_output_dir = reward_config["training_output_path"].format(**config)
training_args = RewardConfig(
    output_dir=reward_output_dir,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    **reward_config["training_args"],
)
# Force dataset preprocessing onto a single process.
training_args.dataset_num_proc = 1

# Stop training once the tracked metric has not improved for the configured
# number of evaluations.
patience = reward_config["early_stopping_patience"]
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=patience)

print('train: ', ratings[train_set].shape)

# Maximum sequence length defaults to 256, but a "max_length" entry in the
# config's trainer_args now takes precedence. Previously the hard-coded keyword
# collided with such an entry and raised a duplicate-keyword TypeError.
trainer_kwargs = {"max_length": 256, **reward_config["trainer_args"]}

# Build the reward trainer on the configured train/validation splits. With
# peft_config=None the model is trained as-is (only its unfrozen parameters).
trainer = RewardTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=ratings[train_set],
    eval_dataset=ratings[valid_set],
    peft_config=peft_config,
    compute_metrics=None,
    callbacks=[early_stopping_callback],
    **trainer_kwargs,
)

# Train (optionally resuming, as set in the config), then save the resulting
# model to the configured output path.
trainer.train(resume_from_checkpoint=reward_config["resume_from_checkpoint"])
trainer.model.save_pretrained(reward_config["model_output_path"].format(**config))

print(
    f"Best metric value: {trainer.state.best_metric}, "
    f"best model checkpoint: {trainer.state.best_model_checkpoint}"
)