import os
import random
import sys
from time import time_ns
from datetime import timedelta

import numpy as np
import yaml
import torch
import torch.distributed as dist
from datasets import load_dataset
from dotenv import load_dotenv
from transformers import AutoTokenizer
from peft import LoraConfig, AutoPeftModelForCausalLM
import trlx
from trlx.data.configs import (
    ModelConfig,
    OptimizerConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)
from trlx.models.modeling_ppo import PPOConfig
import trlx.utils.logging as logging

# Put the current working directory on the import path so the `src.*`
# imports below resolve (presumably the script is launched from the
# repository root -- confirm against the launch scripts).
sys.path.append(os.getcwd())

from src.utils.model_loading import load_trained_reward_model
from src.utils.dataset import get_dataset
from src.core.ppo_trainer import CustomMetricPPOTrainer
from src.core.reward_model import RewardModel

logger = logging.get_logger(__name__)


def set_seed(seed_val=42):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed_val: Seed handed unchanged to every generator. Defaults to 42.
    """
    # Seeders are applied in this fixed order: stdlib, NumPy, torch CPU,
    # then every CUDA device.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed_val)

def _load_yaml_config(path):
    """Parse the YAML file at *path* and return the loaded object.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing is done instead of being left for the garbage collector.
    """
    with open(path) as config_file:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # It is kept so existing config files load exactly as before; switch
        # to yaml.safe_load if these files can ever come from an untrusted
        # source.
        return yaml.load(config_file, yaml.Loader)


def run_job():
    """Run PPO fine-tuning of the SFT model with trlX and a reward-model ensemble.

    Settings are read from the environment (after ``load_dotenv()``):

    * ``DATA_DIR`` (default ``"."``): root that contains ``data/models``.
    * ``DATASET`` (default ``"tldr"``): dataset name, also a path component.
    * ``CONFIG_DIR`` (default ``"gpt2"``): sub-directory of ``configs/`` holding
      ``ppo.yaml``, ``sft.yaml`` and ``reward_model.yaml``.
    * ``N_ENSEMBLES`` (default ``5``): number of ``ensemble_{i}`` reward models
      to load.
    * ``ENSEMBLE_AGG_FN`` (default ``"mean"``): passed to ``RewardModel`` as
      ``agg_fn``.
    * ``ENSEMBLE_KL`` (default ``"fixed"``): passed to the trainer as ``kl_ctl``.

    On the main process the reward-model ensemble is placed on the
    second-to-last CUDA device and the gold reward model on the last one; every
    other process gets ``RewardModel(None, None)``.

    Raises:
        RuntimeError: On the main process when fewer than two CUDA devices are
            visible.
    """
    from accelerate import Accelerator

    load_dotenv()
    DATA_DIR = os.getenv("DATA_DIR", ".")
    DATASET = os.getenv("DATASET", "tldr")
    CONFIG_DIR = os.getenv("CONFIG_DIR", "gpt2")
    N_ENSEMBLES = int(os.getenv("N_ENSEMBLES", 5))
    ENSEMBLE_AGG_FN = os.getenv("ENSEMBLE_AGG_FN", "mean")
    ENSEMBLE_KL = os.getenv("ENSEMBLE_KL", "fixed")
    accelerator = Accelerator()

    config_root = os.path.join("configs", CONFIG_DIR)
    ppo_config = _load_yaml_config(os.path.join(config_root, "ppo.yaml"))
    sft_config = _load_yaml_config(os.path.join(config_root, "sft.yaml"))
    rm_config = _load_yaml_config(os.path.join(config_root, "reward_model.yaml"))
    set_seed(42)

    # Every model directory for this dataset lives under this root.
    models_root = os.path.join(DATA_DIR, "data/models", DATASET)
    sft_model_path = os.path.join(models_root, "sft-models", sft_config["model_directory"])

    if accelerator.is_main_process:
        num_devices = torch.cuda.device_count()
        if num_devices < 2:
            # The device indices below would otherwise be negative and fail
            # with a much less helpful error from torch.device.
            raise RuntimeError(
                "run_job needs at least 2 CUDA devices to host the reward models, "
                f"found {num_devices}"
            )
        rw_device = torch.device(f"cuda:{num_devices - 2}")
        gr_device = torch.device(f"cuda:{num_devices - 1}")

        # Load the pre-trained reward-model ensemble; the tokenizer is taken
        # from the first ensemble member.
        rm_root = os.path.join(models_root, "reward-models", rm_config["model_directory"])
        rw_tokenizer = AutoTokenizer.from_pretrained(os.path.join(rm_root, "ensemble_0"))
        if rw_tokenizer.pad_token_id is None:
            rw_tokenizer.pad_token_id = rw_tokenizer.eos_token_id
        rw_models = [
            load_trained_reward_model(
                os.path.join(rm_root, f"ensemble_{i}"),
                sft_model_path,
                device=rw_device,
                pad_token_id=rw_tokenizer.pad_token_id,
            )
            for i in range(N_ENSEMBLES)
        ]
        for m in rw_models:
            if m.config.pad_token_id is None:
                m.config.pad_token_id = rw_tokenizer.pad_token_id

        # Load the gold reward model on its own device.
        gold_rm_directory = ppo_config["gold_rm_directory"]
        gold_sft_path = os.path.join(models_root, "sft-models", gold_rm_directory)
        gr_tokenizer = AutoTokenizer.from_pretrained(gold_sft_path)
        if gr_tokenizer.pad_token_id is None:
            gr_tokenizer.pad_token_id = gr_tokenizer.eos_token_id
        # NOTE(review): this path hard-codes "trlx" where the other model paths
        # use DATASET -- left unchanged; confirm that this is intentional.
        gr_model = load_trained_reward_model(
            os.path.join(DATA_DIR, "data/models/trlx/reward-models", gold_rm_directory),
            gold_sft_path,
            device=gr_device,
            pad_token_id=gr_tokenizer.pad_token_id,
        )
        if gr_model.config.pad_token_id is None:
            gr_model.config.pad_token_id = gr_tokenizer.pad_token_id

        reward_model = RewardModel(
            rw_models,
            rw_tokenizer,
            gr_model,
            gr_tokenizer,
            agg_fn=ENSEMBLE_AGG_FN,
            max_length=sft_config["max_length"],
        )
    else:
        reward_model = RewardModel(None, None)

    train_dataset = get_dataset(DATASET, None, preprocessing="text", split="train")
    valid_dataset = get_dataset(DATASET, None, preprocessing="text", split="valid")

    train_prompts = [example["prompt"] for example in train_dataset]
    valid_prompts = [example["prompt"] for example in valid_dataset]

    sft_lora_config = LoraConfig(**sft_config["lora_config"]) if "lora_config" in sft_config else None
    config = TRLConfig(
        train=TrainConfig(
            seq_length=sft_config["max_length"],
            minibatch_size=sft_config["batch_size"],
            **ppo_config["training"],
            pipeline="PromptPipeline",
            trainer="CustomMetricPPOTrainer",
            trainer_kwargs={"reward_model": reward_model, "kl_ctl": ENSEMBLE_KL},
            project_name="Uncertainty RLHF",
            run_name=f"{CONFIG_DIR} PPO {ENSEMBLE_AGG_FN} {N_ENSEMBLES}xRM KL_ctl={ENSEMBLE_KL} {DATASET}",
            checkpoint_dir=os.path.join(models_root, "ppo-models", ppo_config["model_directory"]),
        ),
        model=ModelConfig(
            model_path=sft_model_path,
            num_layers_unfrozen=ppo_config["num_layers_unfrozen"],
            peft_config=sft_lora_config,
        ),
        tokenizer=TokenizerConfig(
            tokenizer_path=sft_model_path,
            truncation_side="right",
        ),
        optimizer=OptimizerConfig(
            name="adamw",
            kwargs=ppo_config["optimizer"],
        ),
        scheduler=SchedulerConfig(
            name="cosine_annealing",
            kwargs=ppo_config["scheduler"],
        ),
        method=PPOConfig(
            name="PPOConfig",
            **ppo_config["ppo_config"],
            scale_reward=None,
            ref_mean=None,
            ref_std=None,
            gen_kwargs=sft_config["generation_kwargs"],
        ),
    )

    # trlX won't load Peft Model well from directory, we need to merge and save
    # all the weights and let it reinitialize LoRA
    tmp_sft_model = AutoPeftModelForCausalLM.from_pretrained(sft_model_path, torch_dtype=torch.bfloat16)
    tmp_sft_model = tmp_sft_model.merge_and_unload()
    sft_tmp_path = f"/tmp/model_{time_ns()}"
    tmp_sft_model.save_pretrained(sft_tmp_path, safe_serialization=False)

    trlx.train(
        model_path=sft_tmp_path,
        reward_fn=reward_model.score,
        prompts=train_prompts,
        eval_prompts=valid_prompts,
        config=config,
    )

# Start the PPO job only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_job()