from peft import LoraConfig, PeftModel
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    AutoModelForSequenceClassification,
    EarlyStoppingCallback,
)
from datasets import load_from_disk

import argparse
import json

from modules.data_collator import RewardScoreDataCollatorWithPadding
from modules.reward_trainer import (
    RegressionRewardTrainer,
    RegressionRewardTrainerWithMoreFeatures,
    OdinRegressionRewardTrainer,
)
from modules.models import (
    LlamaForSequenceClassificationWithMoreFeatures,
)

from utils.data_processing import preprocess_dataset
from utils.env_management import save_config
import os
import torch

# Weights & Biases settings for this run, applied in one go.
os.environ.update(
    {
        "WANDB_PROJECT": "finetuning-historical",  # name your W&B project
        "WANDB_LOG_MODEL": "checkpoint",  # log all model checkpoints
        "WANDB_TAGS": "reward",
    }
)


# Command-line interface: which JSON config to train from and which CUDA
# device index to expose to this process.
parser = argparse.ArgumentParser(description="Train Reward")
for flag, options in (
    (
        "--config_file",
        {"type": str, "default": "configs/reward-base.json", "help": "config file"},
    ),
    (
        "--cuda_device",
        {"type": int, "default": 0, "help": "cuda device to use"},
    ),
):
    parser.add_argument(flag, **options)
args = parser.parse_args()

# Expose only the requested device index to CUDA via the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_device)

# Load the JSON run configuration. The context manager closes the file handle
# deterministically (the previous bare `open(...)` left it to the garbage
# collector), and the explicit encoding avoids depending on the locale default.
with open(args.config_file, encoding="utf-8") as config_fp:
    config = json.load(config_fp)

# The cache-dir entry is a format template expanded with the top-level config keys.
os.environ["WANDB_CACHE_DIR"] = config["wandb_cache_dir"].format(**config)

# Hand the loaded config to the project's save_config helper under the
# "reward_train" label.
save_config(config, "reward_train")

# Frequently used settings, pulled out of the config once.
reward_config = config["reward"]  # reward-model specific sub-section
device = config["device"]  # device string the model is moved to; "cpu" for CPU usage
preprocess = config["preprocess"]  # truthy -> preprocess the raw dataset in this run
base_dir = config["base_dir"]

# Names of the dataset splits used for training and validation.
train_set = reward_config["train_set"]
valid_set = reward_config["valid_set"]

base_model = config["base_model"]  # identifier passed to the from_pretrained calls
training_model = reward_config["training_model"]  # selects the model/trainer variant

# Tokenizer of the base model, with the padding token taken from the config.
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = config["pad_token"]

# An explicit chat template in the config replaces the tokenizer's own one.
if "chat_template" not in config:
    print("No chat template provided in config file using default")
else:
    tokenizer.chat_template = config["chat_template"]

# Seed torch's RNG with the configured training seed (42 when none is given).
training_seed = reward_config["training_args"].get("seed", 42)
torch.manual_seed(training_seed)

# Path template of the existing adapter/model checkpoint, expanded with the
# top-level config keys.
checkpoint = reward_config["model"].format(**config)

# Obtain the ratings dataset: either reuse the already preprocessed copy, or
# preprocess the raw train/validation splits now and store the result.
if not preprocess:
    ratings = load_from_disk(config["data_path_preprocessed"].format(**config))
else:
    ratings = load_from_disk(config["data_path"].format(**config))
    print(ratings)
    for split_name in (train_set, valid_set):
        ratings[split_name] = preprocess_dataset(
            ratings[split_name],
            tokenizer,
            config["max_input_length"],
            config["messages_template"],
        )
    ratings.save_to_disk(config["data_path_preprocessed"].format(**config))

# Optionally train on a seeded random subset of the training split.
if reward_config["sample_training"]:
    shuffled_train = ratings[train_set].shuffle(
        seed=reward_config["sample_training_seed"]
    )
    ratings[train_set] = shuffled_train.select(
        range(reward_config["sample_training_size"])
    )

print(ratings)

modules_to_save = None

# Supported variants, keyed by the `training_model` config value. Each entry is
# (model class, variant-specific from_pretrained kwargs, trainer class,
#  modules_to_save for the LoRA config).
model_variants = {
    "sequence_classification": (
        AutoModelForSequenceClassification,
        {"num_labels": 1},
        RegressionRewardTrainer,
        None,
    ),
    "more_features": (
        LlamaForSequenceClassificationWithMoreFeatures,
        {"num_labels": 1, "features_dim": 1},
        RegressionRewardTrainerWithMoreFeatures,
        ["score", "score_features"],
    ),
    "odin": (
        LlamaForSequenceClassificationWithMoreFeatures,
        {"num_labels": 2, "features_dim": 1},
        OdinRegressionRewardTrainer,
        ["score"],
    ),
}

# The isinstance guard keeps unhashable config values (e.g. a JSON list) on the
# same ValueError path as any other unknown variant.
if not isinstance(training_model, str) or training_model not in model_variants:
    raise ValueError(f"Unknown training model: {training_model}")

model_cls, variant_kwargs, reward_trainer, modules_to_save = model_variants[
    training_model
]

# Every variant is loaded from the base model as a regression head in the
# configured dtype and then moved to the configured device.
model = model_cls.from_pretrained(
    base_model,
    problem_type="regression",
    torch_dtype=config["torch_dtype"],
    **variant_kwargs,
).to(device)

# Align the embedding matrix and the model's pad id with the tokenizer.
model.resize_token_embeddings(len(tokenizer))

model.config.pad_token_id = tokenizer.pad_token_id


# When the run starts from an existing PEFT adapter, load it on top of the
# model and fold its weights in, so training continues from a plain model.
if reward_config["has_peft"]:
    model = PeftModel.from_pretrained(
        model,
        checkpoint,
        torch_dtype=config["torch_dtype"],
        adapter_name="default",
    ).merge_and_unload()

# LoRA settings for the new adapter: the config-provided options plus the
# variant-specific list of modules to save alongside the adapter.
peft_config = LoraConfig(
    task_type="SEQ_CLS",
    inference_mode=False,
    modules_to_save=modules_to_save,
    **reward_config["lora_config"],
)

# Training arguments: config-provided options plus best-model selection on the
# loss (lower is better), with the best checkpoint reloaded when training ends.
output_dir = reward_config["training_output_path"].format(**config)
training_args = TrainingArguments(
    output_dir=output_dir,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    **reward_config["training_args"],
)
# Attribute set on the arguments object after construction; presumably read by
# the reward trainer for dataset processing — confirm against modules.reward_trainer.
training_args.dataset_num_proc = 1

# Stop training once the tracked metric has not improved for the configured
# number of evaluations.
patience = reward_config["early_stopping_patience"]
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=patience)

# Collator that pads batches with the tokenizer, using the configured dtype.
score_collator = RewardScoreDataCollatorWithPadding(
    tokenizer=tokenizer,
    torch_dtype=config["torch_dtype"],
)

# Instantiate the trainer class selected for this model variant.
trainer = reward_trainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=ratings[train_set],
    eval_dataset=ratings[valid_set],
    peft_config=peft_config,
    data_collator=score_collator,
    reward_column_name=reward_config["reward_column_name"],
    compute_metrics=None,
    callbacks=[early_stopping_callback],
    **reward_config["trainer_args"],
)

# Train (optionally resuming from a checkpoint), then save the resulting model
# to the configured output path.
trainer.train(resume_from_checkpoint=reward_config["resume_from_checkpoint"])
model_output_path = reward_config["model_output_path"].format(**config)
trainer.model.save_pretrained(model_output_path)

# Report the best metric value and the checkpoint it came from.
print(
    f"Best metric value: {trainer.state.best_metric}, best model checkpoint: {trainer.state.best_model_checkpoint}"
)