from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
)
from datasets import load_from_disk

import argparse
import json

from modules.models import (
    LlamaForSequenceClassificationWithMoreFeatures,
)
from trl import PPOTrainer, PPOConfig

from utils.data_processing import preprocess_dataset
from utils.env_management import save_config
from tqdm import tqdm
import torch
import copy
import os
from modules.ppo_trainer import PPOTrainerFeatures
from safetensors.torch import load_file


# Command-line interface: which JSON config to load and which GPU index to
# expose to this process.
parser = argparse.ArgumentParser(description="Train Reward")
parser.add_argument(
    "--config_file", type=str, default="configs/reward-base.json", help="config file"
)
parser.add_argument(
    "--cuda_device", type=int, default=0, help="cuda device to use"
)
args = parser.parse_args()

# Read the config through a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
with open(args.config_file, encoding="utf-8") as config_fh:
    config = json.load(config_fh)

# All PPO-specific settings live under the "ppo" key of the config.
ppo_config = config["ppo"]

# Restrict the visible GPUs before any model is created or moved to a device
# further down in this script.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_device)

if ppo_config["use_wandb"]:
    os.environ["WANDB_PROJECT"] = "finetuning-historical"  # name your W&B project
    os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints
    os.environ["WANDB_TAGS"] = "ppo"
    # The cache dir is a template that may reference other top-level config keys.
    os.environ["WANDB_CACHE_DIR"] = config["wandb_cache_dir"].format(**config)

# Project helper; presumably records the config used for this "train" run
# (see utils.env_management for the exact behaviour).
save_config(config, "train")

def _fill(template):
    """Expand `{key}` placeholders in *template* with the top-level config entries."""
    return template.format(**config)


# Plain values taken straight from the config.
base_model = config["base_model"]
device = config["device"]  # for GPU usage or "cpu" for CPU usage
base_dir = config["base_dir"]

# Paths are stored as templates; the last two are optional and resolve to an
# empty string when absent, which later code treats as "not configured".
checkpoint = _fill(ppo_config["model"])
reward_training_model = ppo_config["reward_training_model"]
reward_model_path = _fill(ppo_config["reward_model_base"])
reward_model_peft_path = _fill(ppo_config["reward_model_peft"])
reward_model_score_path = _fill(ppo_config.get("reward_model_score", ""))
normalize_reward_stats_path = _fill(
    ppo_config.get("normalize_reward_stats_path", "")
)

# Tensors created without an explicit dtype from here on use the configured
# precision (the config stores the dtype by its torch attribute name).
_default_dtype = getattr(torch, config["torch_dtype"])
torch.set_default_dtype(_default_dtype)

# Left padding, with the pad token taken from the config.
tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left")
tokenizer.pad_token = config["pad_token"]

# A chat template in the config overrides whatever the tokenizer ships with.
if "chat_template" not in config:
    print("No chat template provided in config file using default")
else:
    tokenizer.chat_template = config["chat_template"]

# Seed torch's RNG; falls back to 42 when the training args do not set a seed.
_seed = ppo_config["training_args"].get("seed", 42)
torch.manual_seed(_seed)

# Preprocessed dataset dict; the "ppo" and "test" splits are used below.
ratings = load_from_disk(config["data_path_preprocessed"].format(**config))

# Optionally train on a fixed-size, seeded random subset of the PPO split.
if ppo_config["sample_training"]:
    _shuffled = ratings["ppo"].shuffle(seed=ppo_config["sample_training_seed"])
    ratings["ppo"] = _shuffled.select(range(ppo_config["sample_training_size"]))

print(ratings)

# By default every original column of the PPO split is dropped once prompts
# have been tokenized.
remove_columns = ratings["ppo"].column_names

# Columns that are preserved through the prompt-tokenization `map` when one of
# the feature-aware model kinds is selected.
keep_columns = [
    "input_ids",
    "attention_mask",
    "score",
    "effort_quantile",
    "effort",
    "score_demean",
    "ctr_demean",
    "score_clean",
    "score_clean_demean",
    "score_no_effort",
    "score_no_effort_demean",
    "ctr",
]

# reward_training_model -> (scorer class, num_labels, trainer class, feature-aware?)
_SCORER_VARIANTS = {
    "sequence_classification": (
        AutoModelForSequenceClassification,
        1,
        PPOTrainer,
        False,
    ),
    "more_features": (
        LlamaForSequenceClassificationWithMoreFeatures,
        1,
        PPOTrainerFeatures,
        True,
    ),
    "odin": (
        LlamaForSequenceClassificationWithMoreFeatures,
        2,
        PPOTrainerFeatures,
        True,
    ),
}

if reward_training_model not in _SCORER_VARIANTS:
    raise ValueError(f"Unknown training model: {reward_training_model}")

_scorer_cls, _num_labels, ppo_trainer_cls, _feature_aware = _SCORER_VARIANTS[
    reward_training_model
]


def _load_scorer():
    """Instantiate one regression-head scorer from the base checkpoint.

    Feature-aware variants are built with ``features_dim=1`` and moved to the
    configured device immediately; the plain sequence-classification variant
    is returned as loaded.
    """
    load_kwargs = {
        "num_labels": _num_labels,
        "problem_type": "regression",
        "torch_dtype": config["torch_dtype"],
    }
    if not _feature_aware:
        return _scorer_cls.from_pretrained(base_model, **load_kwargs)
    return _scorer_cls.from_pretrained(
        base_model, features_dim=1, **load_kwargs
    ).to(device)


# Reward and value models start from the same checkpoint but are two
# independent instances.
reward_model = _load_scorer()
value_model = _load_scorer()

if _feature_aware:
    remove_columns = [col for col in remove_columns if col not in keep_columns]

def _merge_adapter(backbone, adapter_dir, **load_kwargs):
    """Load the PEFT adapter at *adapter_dir* onto *backbone* and merge it in.

    Returns the model produced by ``merge_and_unload()``.
    """
    wrapped = PeftModel.from_pretrained(
        backbone, adapter_dir, adapter_name="default", **load_kwargs
    )
    return wrapped.merge_and_unload()


# Match both scorers' embedding tables and pad id to the tokenizer.
_vocab_size = len(tokenizer)
reward_model.resize_token_embeddings(_vocab_size)
value_model.resize_token_embeddings(_vocab_size)

_pad_id = tokenizer.pad_token_id
reward_model.config.pad_token_id = _pad_id
value_model.config.pad_token_id = _pad_id

# First (optional) adapter: the one stored at the reward model base path.
if ppo_config["reward_has_peft"]:
    reward_model = _merge_adapter(
        reward_model, reward_model_path, torch_dtype=config["torch_dtype"]
    )
    value_model = _merge_adapter(
        value_model, reward_model_path, torch_dtype=config["torch_dtype"]
    )

# Second (optional) adapter: applied whenever a reward PEFT path is configured.
if reward_model_peft_path:
    reward_model = _merge_adapter(reward_model, reward_model_peft_path)
    value_model = _merge_adapter(value_model, reward_model_peft_path)

# Optionally attach a pre-trained encoder to both the reward and value model.
# Enabled by setting `reward_model_encoder` (a path template) in the PPO config.
#
# NOTE(review): `VAE` is neither imported nor defined in the visible part of
# this file; unless it is provided elsewhere, enabling this option raises
# NameError. TODO: import it from its defining module.
if ppo_config.get("reward_model_encoder"):
    encoder = VAE(
        input_dim=960,
        hidden_dims=ppo_config["hidden_dims"],
        embed_dim=ppo_config["treatment_dim"],
        device=device,
    ).to(device)
    # NOTE(review): no `map_location` is passed, so the checkpoint is restored
    # with torch.load's default device handling — confirm this matches `device`.
    encoder.load_state_dict(
        torch.load(ppo_config["reward_model_encoder"].format(**config))
    )
    encoder.eval()
    # Debug output: first ten weights of the first encoder layer's first row.
    print(f"encoder weight {encoder.encoder[0].weight[0, :10]}")

    # The same encoder instance is shared by both models.
    reward_model.encoder = encoder
    value_model.encoder = encoder

# Optional pre-trained weights for the scalar `score` head. Two sources are
# supported; when both are configured the standalone tensor file wins because
# it is loaded last.
weights = None
if ppo_config.get("reward_model_score_tensors"):
    dir_path = ppo_config.get("reward_model_score_tensors").format(**config)
    model_tensors = load_file(os.path.join(dir_path, "model.safetensors"))
    weights = model_tensors["score.weight"]

if reward_model_score_path:
    weights = torch.load(reward_model_score_path)

if normalize_reward_stats_path:
    # Normalisation rebuilds the score heads from `weights`, so fail with a
    # clear message instead of an AttributeError on None.
    if weights is None:
        raise ValueError(
            "normalize_reward_stats_path is set but no score weights were loaded; "
            "configure reward_model_score_tensors or reward_model_score"
        )

    # Context manager so the stats file handle is closed deterministically.
    with open(normalize_reward_stats_path, encoding="utf-8") as stats_fh:
        normalize_reward_stats = json.load(stats_fh)
    reward_mean = normalize_reward_stats["mean"]
    reward_std = normalize_reward_stats["std"]

    # The head always uses the weights flattened to a single output row, so
    # its input width is the flattened length. (`weights.shape[0]` would be 1
    # for a `(1, hidden)` tensor such as a saved `score.weight`.)
    flat_weights = weights.view(1, -1)
    hidden_size = flat_weights.shape[1]

    reward_model.score = torch.nn.Linear(
        in_features=hidden_size,
        out_features=1,
        bias=True,
    )
    value_model.score = torch.nn.Linear(
        in_features=hidden_size,
        out_features=1,
        bias=True,
    )

    # Sanity check: print the un-normalised reward for one validation example.
    # A Linear bias is 1-D with `out_features` entries.
    reward_model.score.weight.data = flat_weights
    reward_model.score.bias.data = torch.zeros(1).to(device)

    reward_model.to(device)

    # NOTE(review): assumes indexing the "reward_valid" split yields tensors
    # (i.e. the dataset was saved with a torch format) — confirm.
    print(
        reward_model(
            input_ids=ratings["reward_valid"][:1]["input_ids"].to(device)
        ).logits,
    )
    print(flat_weights[0, :10])

    # Fold the affine normalisation (x - mean) / std into the linear head:
    # weight / std and bias = -mean / std. Each model gets its own tensors so
    # the two heads never share storage.
    reward_model.score.weight.data = flat_weights / reward_std
    value_model.score.weight.data = flat_weights / reward_std

    normalized_bias = -reward_mean / reward_std
    reward_model.score.bias.data = torch.tensor([normalized_bias])
    value_model.score.bias.data = torch.tensor([normalized_bias])
elif weights is not None:
    # Clone per model: assigning two views of the same tensor would make the
    # reward and value heads share memory, so in-place updates to one head
    # would silently change the other.
    reward_model.score.weight.data = weights.view(1, -1).clone()
    value_model.score.weight.data = weights.view(1, -1).clone()


# Load a separately stored `score_effort` head from the reward PEFT directory
# and print the first entries of both heads for inspection.
#
# NOTE(review): this branch is currently unreachable. The model-selection
# chain earlier in this script raises ValueError for every
# `reward_training_model` other than "sequence_classification",
# "more_features" and "odin", so "effort_layer_separately" never gets here.
# Either add a matching model-selection branch or remove this block.
if reward_training_model == "effort_layer_separately":
    reward_model.score_effort.weight.data = torch.load(
        reward_model_peft_path + "/score_effort.pt"
    )

    print(reward_model.score.weight.data[0, :10])
    print(reward_model.score_effort.weight.data[:10])


def get_prompt(examples):
    """Build tokenized generation prompts for one batch of dataset rows.

    For every row, the configured ``messages_template`` is copied, its final
    message (the assistant slot) is dropped, and the content of the message at
    index 1 is filled in with that row's column values. The conversation is
    rendered with the tokenizer's chat template (generation prompt appended)
    and the resulting strings are tokenized without padding.

    Args:
        examples: Batched mapping of column name -> list of per-row values,
            as passed by ``Dataset.map(..., batched=True)``.

    Returns:
        dict with ``"input_ids"`` and ``"attention_mask"`` for the batch.
    """
    template = config["messages_template"]
    # Every column holds one entry per row, so any column gives the batch size.
    row_count = len(next(iter(examples.values())))

    def render(row):
        # Work on a private copy so the shared template is never modified.
        conversation = copy.deepcopy(template)
        del conversation[-1]  # Drop the last message intended for the assistant

        # Column values of this row are the variables for string formatting.
        fields = {column: examples[column][row] for column in examples}
        conversation[1]["content"] = template[1]["content"].format(**fields)

        return tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )

    # Tokenize all prompts of the batch in a single call.
    encoded = tokenizer([render(row) for row in range(row_count)], padding=False)
    return {key: encoded[key] for key in ("input_ids", "attention_mask")}


# Settings shared by both prompt-tokenization passes: rows are processed one
# at a time and the columns listed in `remove_columns` are dropped.
_prompt_map_kwargs = dict(batched=True, batch_size=1, remove_columns=remove_columns)

# Training prompts: the whole PPO split.
ratings["ppo"] = ratings["ppo"].map(get_prompt, **_prompt_map_kwargs)

# Evaluation prompts: a fixed, seeded sample of 8 rows from the test split.
_eval_subset = ratings["test"].shuffle(seed=42).select(range(8))
ratings["test"] = _eval_subset.map(get_prompt, **_prompt_map_kwargs)


def _load_causal_lm():
    """Load the base causal LM and resize its embeddings to the tokenizer."""
    lm = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=config["torch_dtype"],
    )
    lm.resize_token_embeddings(len(tokenizer))
    return lm


def _fold_in_checkpoint(lm):
    """Load the PEFT adapter stored at `checkpoint` onto *lm* and merge it."""
    wrapped = PeftModel.from_pretrained(
        lm, checkpoint, torch_dtype=config["torch_dtype"], adapter_name="default"
    )
    return wrapped.merge_and_unload()


# Policy and reference policy are two independent copies of the same model.
model = _load_causal_lm()
ref_model = _load_causal_lm()

# When the starting checkpoint is a PEFT adapter, bake it into both copies.
if ppo_config["has_peft"]:
    model = _fold_in_checkpoint(model)
    ref_model = _fold_in_checkpoint(ref_model)

# Only the policy receives a fresh adapter configured from `lora_config`;
# the reference model gets none.
peft_config = LoraConfig(inference_mode=False, **ppo_config["lora_config"])

model.add_adapter(peft_config, adapter_name="ppo_adapter")
model.enable_adapters()

# Training hyper-parameters come straight from the config; the output
# directory is a path template.
ppo_training_config = PPOConfig(
    output_dir=ppo_config["training_output_path"].format(**config),
    **ppo_config["training_args"],
)

# `trainer_args` is optional. Resolve it once so the path expansion below and
# the trainer call agree (previously the expansion indexed the key directly
# and raised KeyError when it was absent, while the call treated it as
# optional).
trainer_args = ppo_config.setdefault("trainer_args", {})
if trainer_args.get("normalize_stats_path"):
    trainer_args["normalize_stats_path"] = trainer_args[
        "normalize_stats_path"
    ].format(**config)

# `ppo_trainer_cls` was chosen together with the reward model kind.
ppo_trainer = ppo_trainer_cls(
    config=ppo_training_config,
    processing_class=tokenizer,
    policy=model,
    ref_policy=ref_model,
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=ratings["ppo"],
    eval_dataset=ratings["test"],
    **trainer_args,
)

ppo_trainer.train()

# Persist the trained policy to the configured output path.
ppo_trainer.policy.save_pretrained(ppo_config["model_output_path"].format(**config))
