import os
import random
import sys
import time
import yaml

import torch
import accelerate
import numpy as np
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments
from peft import get_peft_model, LoraConfig, AutoPeftModelForCausalLM

sys.path.append(os.getcwd())

from src.utils.dataset import get_dataset
from src.core.reward_model import RegularizedRewardTrainer


def set_seed(seed_val=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed_val: Seed handed unchanged to ``random.seed``,
            ``np.random.seed``, ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all`` (in that order).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed_val)


def train(model, tokenizer, dataset, config):
    """Train ``model`` with ``RegularizedRewardTrainer`` and save it with its tokenizer.

    The output directory is
    ``<DATA_DIR>/data/models/<DATASET>/reward-models/<config["model_directory"]>``,
    where ``DATA_DIR`` and ``DATASET`` are module-level globals set by the
    script entry point.

    Args:
        model: Model handed to the trainer and saved with ``save_pretrained``.
        tokenizer: Tokenizer handed to the trainer and saved next to the model.
        dataset: Sized dataset that is randomly split into train/eval subsets.
        config: Mapping providing ``training_kwargs``, ``train_test_split``,
            ``max_length`` and ``model_directory``. It is not modified.
    """
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    # Work on a copy: this function is called once per ensemble member with
    # the same config, and rescaling the shared dict in place would divide
    # gradient_accumulation_steps by the world size again on every call.
    training_kwargs = dict(config["training_kwargs"])
    if "gradient_accumulation_steps" in training_kwargs and world_size > 1:
        old_grad_val = training_kwargs["gradient_accumulation_steps"]
        new_grad_val = max(1, old_grad_val // world_size)
        training_kwargs["gradient_accumulation_steps"] = new_grad_val
        print(
            f"Reducing gradient accumulation from {old_grad_val} to {new_grad_val} due to DDP world size {world_size}"
        )

    # NOTE(review): the eval size is a fraction of the *train* size and the
    # remaining examples are discarded (third split) — confirm this is intended.
    split_fraction = config["train_test_split"]
    train_size = int(len(dataset) * (1 - split_fraction))
    eval_size = int(train_size * split_fraction)
    unused_size = len(dataset) - train_size - eval_size
    train_dataset, eval_dataset, _ = torch.utils.data.random_split(
        dataset,
        (train_size, eval_size, unused_size),
    )

    # Single source of truth for the checkpoint and final save location.
    output_dir = os.path.join(DATA_DIR, "data/models", DATASET, "reward-models", config["model_directory"])

    training_args = TrainingArguments(output_dir=output_dir, **training_kwargs)
    trainer = RegularizedRewardTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        max_length=config["max_length"],
    )
    trainer.train()

    model.save_pretrained(output_dir, safe_serialization=False)
    tokenizer.save_pretrained(output_dir)


def load_reward_model(config, sft_model_path):
    """Build a one-label sequence-classification model from a PEFT SFT checkpoint.

    The SFT checkpoint is loaded as a PEFT causal LM, passed through
    ``merge_and_unload()``, written to a temporary directory, and reloaded
    from there as ``AutoModelForSequenceClassification`` with ``num_labels=1``.
    Placement uses the module-level ``device`` global set by the script entry
    point.

    Args:
        config: Mapping; when it contains ``lora_config`` the reloaded model
            is wrapped with ``get_peft_model`` using those settings.
        sft_model_path: Path of the PEFT SFT checkpoint to start from.

    Returns:
        The classification model, PEFT-wrapped if ``lora_config`` was given.
    """
    merged_lm = AutoPeftModelForCausalLM.from_pretrained(
        sft_model_path, torch_dtype=torch.bfloat16
    ).merge_and_unload()

    # NOTE(review): this directory is never removed; confirm nothing
    # downstream relies on it before adding cleanup.
    tmp_model_path = "/tmp/model_" + str(time.time_ns())
    merged_lm.save_pretrained(tmp_model_path, safe_serialization=False)
    # Drop the causal-LM copy before the classifier is loaded.
    del merged_lm

    reward_model = AutoModelForSequenceClassification.from_pretrained(
        tmp_model_path, num_labels=1, device_map=device, torch_dtype=torch.bfloat16
    )
    if "lora_config" not in config:
        return reward_model

    reward_model = get_peft_model(reward_model, LoraConfig(**config["lora_config"]))
    reward_model.print_trainable_parameters()
    return reward_model


if __name__ == "__main__":
    # device, DATA_DIR and DATASET are read as module-level globals by
    # train() and load_reward_model(), so they must stay at module scope.
    device = accelerate.Accelerator().device
    load_dotenv()
    CONFIG_DIR = os.getenv("CONFIG_DIR", "gpt2")
    DATA_DIR = os.getenv("DATA_DIR", ".")
    DATASET = os.getenv("DATASET", "tldr")
    set_seed(42)

    # NOTE(review): yaml.Loader can construct arbitrary Python objects. That is
    # acceptable for repository-owned configs; switch to yaml.SafeLoader if
    # these files can ever come from an untrusted source.
    with open(os.path.join("configs", CONFIG_DIR, "reward_model.yaml")) as config_file:
        config = yaml.load(config_file, yaml.Loader)
    with open(os.path.join("configs", CONFIG_DIR, "sft.yaml")) as sft_config_file:
        sft_config = yaml.load(sft_config_file, yaml.Loader)

    sft_model_directory = os.path.join(DATA_DIR, "data/models", DATASET, "sft-models", sft_config["model_directory"])
    tokenizer = AutoTokenizer.from_pretrained(sft_model_directory, padding_side="right")
    if not tokenizer.pad_token:
        # Fall back to the EOS token when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    # Only the location of the annotated outputs depends on the dataset mode.
    if config["train_dataset"] == "single":
        data_path = os.path.join(
            DATA_DIR, "data/datasets", DATASET, sft_config["output_directory"], "annotated_outputs.json"
        )
    elif config["train_dataset"] == "mixed":
        data_path = os.path.join(DATA_DIR, "data/datasets", DATASET, "sft/annotated_outputs.json")
    else:
        # Fail before any model is loaded instead of hitting an undefined
        # `dataset` later on.
        raise ValueError(
            f"Unsupported train_dataset {config['train_dataset']!r}; expected 'single' or 'mixed'"
        )
    dataset = get_dataset(
        path_or_id=data_path,
        tokenizer=tokenizer,
        preprocessing="rm",
        max_length=config["max_length"],
        label_name=config["annotator"],
    )

    if os.getenv("WANDB_NAME", None) is None:
        os.environ["WANDB_NAME"] = config["model_name"] + " Reward Model " + DATASET

    if "n_ensembles" in config:
        # Each ensemble member gets its own seed and its own sub-directory
        # under the configured model directory.
        model_directory = config["model_directory"]
        model = None
        for ensemble in range(config["n_ensembles"]):
            set_seed(ensemble)
            # Drop the previous member before loading the next one.
            del model
            model = load_reward_model(config, sft_model_directory)
            config["model_directory"] = os.path.join(model_directory, f"ensemble_{ensemble}")
            train(model, tokenizer, dataset, config)
    else:
        model = load_reward_model(config, sft_model_directory)
        train(model, tokenizer, dataset, config)
