from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from torch.optim import Adam
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)
import pandas as pd
from trl import (
    AutoModelForCausalLMWithValueHead,
    PPOConfig,
    PPOTrainer,
    create_reference_model,
    set_seed,
    AutoModelForSeq2SeqLMWithValueHead,
)
from trl.core import LengthSampler

tqdm.pandas()
import wandb
import json
from peft import PeftModel, PeftConfig, LoraConfig, TaskType, get_peft_model


# Weights & Biases project name; it is also passed to push_to_hub() as the Hub repo name.
project = "project"
# Starts a W&B run as soon as the script is executed, independently of any CLI flag.
wandb.init(project=project)


@dataclass
class ScriptArguments:
    """
    Command-line arguments for PPO fine-tuning of a causal LM, parsed with HfArgumentParser.

    Every field has a default, so the script runs without any CLI flags.
    """

    # Model path or Hub id; passed to PPOConfig as `model_name`.
    model_name: Optional[str] = field(
        default="sft_model",
        metadata={"help": "the model name"},
    )
    # Name of the experiment tracker to log with (e.g. "wandb").
    log_with: Optional[str] = field(
        default=None, metadata={"help": "use 'wandb' to log with wandb"}
    )
    # Learning rate passed to PPOConfig.
    learning_rate: Optional[float] = field(
        default=1.41e-5, metadata={"help": "the learning rate"}
    )
    # PPO minibatch size passed to PPOConfig.
    mini_batch_size: Optional[int] = field(
        default=4, metadata={"help": "the PPO minibatch size"}
    )
    # PPO batch size passed to PPOConfig.
    batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"})
    # Directory to save the fine-tuned model to.
    model_save_path: Optional[str] = field(
        default="./llama7b_qa_rlhf",
        metadata={"help": "the path to save the model"},
    )


# Parse CLI flags into ScriptArguments (all fields have defaults).
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# PPO hyper-parameters: everything except ppo_epochs comes from the CLI.
config = PPOConfig(
    model_name=script_args.model_name,
    learning_rate=script_args.learning_rate,
    ppo_epochs=1,
    mini_batch_size=script_args.mini_batch_size,
    batch_size=script_args.batch_size,
    # Forward --log_with (default None) so `--log_with wandb` actually enables tracking.
    log_with=script_args.log_with,
)


def build_dataset(
    config,
    dataset_name="PKU-Alignment/PKU-SafeRLHF-QA",
    input_min_text_length=5,
    input_max_text_length=500,
    max_samples=25000,
):
    """
    Load, filter, tokenize, and split the prompt dataset used for PPO.

    Parameters:
    - config (PPOConfig): only ``config.model_name`` is read, to load the tokenizer.
    - dataset_name (str): Name of the dataset to load; only its "train" split is used.
    - input_min_text_length (int): Prompts must be strictly longer than this many characters.
    - input_max_text_length (int): Prompts must be at most this many characters.
    - max_samples (int): Keep at most this many filtered prompts (the first ones, in
      dataset order) to bound preprocessing cost.

    Returns:
    - dataset_splits (datasets.DatasetDict): unshuffled "train" (90%) and "test" (10%)
      splits with added "input_ids" and "query" columns and the torch format set.
    """

    # Load only the "train" split; the train/test parts are carved out of it below.
    dataset = load_dataset(dataset_name, split="train")

    # Keep prompts whose character length lies in (input_min_text_length, input_max_text_length].
    dataset = dataset.filter(
        lambda x: input_min_text_length < len(x["prompt"]) <= input_max_text_length,
        batched=False,
    )
    dataset = dataset.select(range(min(max_samples, len(dataset))))

    # device_map only applies to models, so it is not passed to the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    # EOS is reused as the pad token.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    def tokenize(sample):
        # Wrap each prompt in the same instruction template used during training.
        prompt = f"Prompt: {sample['prompt']}\n\nResponse:"
        sample["input_ids"] = tokenizer.encode(prompt)

        # The training loop and PPOTrainer.log_stats read the decoded prompt as "query".
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")

    # Deterministic split: with shuffle=False the `seed` argument has no effect.
    dataset_splits = dataset.train_test_split(test_size=0.1, shuffle=False, seed=42)

    return dataset_splits


# Build the filtered, tokenized train/test DatasetDict. This is a dataset, not a
# dataloader: PPOTrainer builds its dataloader from dataset["train"] below.

dataset = build_dataset(config)


def collator(data):
    """Transpose a list of per-example dicts into a dict mapping each key to a list.

    The keys (and their order) are taken from the first example.
    """
    keys = data[0].keys()
    return {key: [example[key] for example in data] for key in keys}


# Sanity-check demo: show how `collator` turns a list of dicts into a dict of lists.
test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}]
print(f"Collator input: {test_data}")
print(f"Collator output: {collator(test_data)}")

# Seed via trl.set_seed using PPOConfig.seed.
set_seed(config.seed)
# LoRA adapter configuration: rank-8 adapters on the modules named q_proj and v_proj.
lora_config = LoraConfig(
    r=8,  # Rank
    lora_alpha=16,
    # NOTE(review): these module names assume a LLaMA-style architecture — confirm for other bases.
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Base policy model in bfloat16, with automatic device placement (device_map="auto").
model = AutoModelForCausalLM.from_pretrained(
    config.model_name, torch_dtype=torch.bfloat16, device_map="auto"
)

# Wrap the base model with a new LoRA adapter built from lora_config.
peft_model = get_peft_model(model, lora_config)

# Alternative: resume from a previously saved adapter checkpoint instead of a new adapter.
# peft_model = PeftModel.from_pretrained(
#     model,
#     "./peft-dialogue-summary-checkpoint-local",
#     lora_config=lora_config,
#     torch_dtype=torch.bfloat16,
#     is_trainable=True,
# )


def print_number_of_trainable_model_parameters(model):
    """
    Summarize trainable vs. total parameter counts of ``model``.

    Despite its name, this function does not print; it returns the summary string
    so the caller can embed it in a larger message.

    Parameters:
    - model (torch.nn.Module): any module exposing ``named_parameters()``.

    Returns:
    - str: newline-prefixed report with the trainable count, the total count, and the
      trainable percentage (reported as 0.00% when the model has no parameters).
    """
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    # Guard against ZeroDivisionError for parameter-less modules.
    percentage = (
        100 * trainable_model_params / all_model_params if all_model_params else 0.0
    )
    return (
        f"\ntrainable model parameters: {trainable_model_params}"
        f"\nall model parameters: {all_model_params}"
        f"\npercentage of trainable model parameters: {percentage:.2f}%"
    )


# Report how many parameters the LoRA-wrapped model will train.
print(
    f"PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}\n"
)

# Wrap the LoRA-equipped model with a value head (`v_head`) for use by PPOTrainer.
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    peft_model, torch_dtype=torch.bfloat16, is_trainable=True
)
# The value-head size depends on the base model's hidden size, so no fixed count is
# claimed in the label; the actual numbers are computed and printed.
print(
    f"PPO model parameters to be updated (LoRA adapter + value head):\n{print_number_of_trainable_model_parameters(ppo_model)}\n"
)
print(ppo_model.v_head)

# Frozen reference model for PPO's KL term. No `num_shared_layers` is passed, so
# (per TRL's create_reference_model docs) no layers are shared: the whole model is
# copied and every parameter of the copy is frozen.
ref_model = create_reference_model(ppo_model)
print(
    f"Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n"
)

# Adam over only the parameters that require gradients; frozen weights are excluded.
optimizer = Adam(
    filter(lambda p: p.requires_grad, ppo_model.parameters()), lr=config.learning_rate
)

# Policy tokenizer; EOS is reused as the pad token.
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

# Build the PPOTrainer from the policy, the frozen reference model, the tokenizer,
# the training split, the collator, and the Adam optimizer defined above (passed
# explicitly so the optimizer built here is the one actually used).
ppo_trainer = PPOTrainer(
    config=config,
    model=ppo_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dataset["train"],
    data_collator=collator,
    optimizer=optimizer,
)

# Each generation draws its max_new_tokens budget from this sampler.
output_min_length = 5
output_max_length = 128
output_length_sampler = LengthSampler(output_min_length, output_max_length)
generation_kwargs = {
    # Per transformers' GenerationConfig docs, min_length also counts prompt tokens.
    "min_length": 5,
    # top_k=0 disables top-k filtering and top_p=1.0 disables nucleus filtering, so
    # tokens are sampled from the full distribution. top_k is documented as an int.
    "top_k": 0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    # Deliberately an id that should never be generated, so generation does not stop
    # early and runs to max_new_tokens.
    # NOTE(review): assumes the policy vocabulary has no token id 100_000 — confirm.
    "eos_token_id": 100_000,
}
# NOTE(review): reward_kwargs is not referenced anywhere in this file; it looks like a
# leftover from a pipeline-based reward model.
reward_kwargs = {
    "top_k": None,  # Return all scores.
    "function_to_apply": "none",  # You want the raw logits without softmax.
    "batch_size": 16,
}

model_save_path = script_args.model_save_path

# Reward model: a single-logit sequence-classification head on the AWQ Mistral-7B
# instruct base, with a PEFT adapter loaded from `reward_adapter_path`.
reward_base_model_name = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"
reward_adapter_path = "reward_model_adapter_name"  # placeholder: point at the trained adapter

reward_tokenizer = AutoTokenizer.from_pretrained(
    reward_base_model_name, add_eos_token=True
)
reward_tokenizer.pad_token = reward_tokenizer.eos_token
# Bound to its own name: rebinding `config` here would shadow the PPOConfig above.
reward_peft_config = PeftConfig.from_pretrained(reward_adapter_path)

reward_model = AutoModelForSequenceClassification.from_pretrained(
    reward_base_model_name,
    num_labels=1,
    low_cpu_mem_usage=True,  # Ensure low memory usage mode is enabled
    device_map={"": 0},
)
reward_model = PeftModel.from_pretrained(reward_model, reward_adapter_path).to("cuda")
# Pad with EOS, matching the reward tokenizer's pad token set above.
reward_model.config.pad_token_id = reward_model.config.eos_token_id
# Inference only: make sure dropout is disabled.
reward_model.eval()
print("Reward model loaded")


def get_reward(prompt_text, predicted_text, max_length=512, fallback_reward=0.00001):
    """
    Score a prompt/response pair with the reward model.

    The prompt and response are concatenated verbatim (no separator), tokenized with
    ``reward_tokenizer`` (truncated to ``max_length`` tokens), moved to CUDA, and passed
    through ``reward_model`` without gradient tracking. The output logit is returned as
    a Python float.

    Scoring is deliberately best-effort: on any exception the error and its traceback
    are printed and ``fallback_reward`` is returned, so one bad sample does not abort
    the PPO run.

    Parameters:
    - prompt_text (str): the query text.
    - predicted_text (str): the generated response text.
    - max_length (int): maximum number of reward-model input tokens.
    - fallback_reward (float): value returned when scoring fails.

    Returns:
    - float: the reward-model logit, or ``fallback_reward`` on failure.
    """
    import traceback

    try:
        inputs = reward_tokenizer(
            prompt_text + predicted_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        with torch.no_grad():
            outputs = reward_model(**inputs)
        # The reward model has num_labels=1, so one sequence yields a single logit.
        return outputs.logits.squeeze().item()

    except Exception as e:
        # Broad on purpose (best-effort scoring), but report what failed and where.
        print("#" * 100)
        print(
            f"get_reward failed ({type(e).__name__}: {e}); "
            f"using fallback reward {fallback_reward}"
        )
        traceback.print_exc()
        print("#" * 100)
        return fallback_reward


# PPO training loop: one PPO step per dataloader batch. The loop index counts steps
# and is printed as "Epoch".
for epoch, batch in enumerate(tqdm(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    query_texts = batch["query"]

    #### Sample a response from the policy for each prompt and score it.
    response_tensors = []
    response_texts = []
    reward_scores = []

    for prompt_tensor, prompt_text in zip(query_tensors, query_texts):
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        # return_prompt=False strips the prompt by its actual length instead of assuming
        # the response is exactly the last `gen_len` tokens, which would pull prompt
        # tokens into the response whenever generation stops early.
        response = ppo_trainer.generate(
            prompt_tensor, return_prompt=False, **generation_kwargs
        ).squeeze(0)
        predicted_text = tokenizer.decode(response)
        reward = get_reward(prompt_text, predicted_text)

        response_tensors.append(response)
        response_texts.append(predicted_text)
        reward_scores.append(reward)

    # Per TRL docs, log_stats logs batch["query"] / batch["response"] as text tables
    # when a tracker is enabled, so store the decoded responses rather than token ids.
    batch["response"] = response_texts
    rewards = [torch.tensor(reward, device="cuda") for reward in reward_scores]
    torch.cuda.empty_cache()

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

    # Show the last prompt/response pair of the batch plus key PPO metrics.
    print(f"Input: {query_texts[-1]}")
    print(f"Output: {predicted_text}")
    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/loss/total: {stats["ppo/loss/total"]}')
    print(f"Epoch: {epoch}, Reward: {rewards}")
# ---------------------------------------------------------------------------
# Save model
# ---------------------------------------------------------------------------
import os

# Save to --model_save_path (default "./llama7b_qa_rlhf") so the CLI argument is honored.
model_save_path = script_args.model_save_path
os.makedirs(model_save_path, exist_ok=True)
ppo_trainer.model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)

# ---------------------------------------------------------------------------
# Push model to the Hugging Face Hub (repo name: `project`)
# ---------------------------------------------------------------------------
from huggingface_hub import login

# Credentials come from the environment instead of being hard-coded in source.
# If HF_TOKEN is unset, no explicit login happens and push_to_hub relies on any
# locally cached credentials (e.g. from `huggingface-cli login`).
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)

ppo_trainer.model.push_to_hub(project)
tokenizer.push_to_hub(project)

#### Compare the frozen reference model ("before") with the PPO-tuned model ("after")
#### on a random sample of held-out test prompts.
bs = 16
game_data = dict()
dataset.set_format("pandas")
# In "pandas" format, slicing a split yields a DataFrame, which supports .sample();
# a Dataset object (e.g. the result of .select()) has no .sample() method. Cap the
# sample size so a small test split cannot make .sample() raise.
df_test = dataset["test"][:]
bs = min(bs, len(df_test))
df_batch = df_test.sample(bs)
game_data["query"] = df_batch["query"].tolist()
query_tensors = df_batch["input_ids"].tolist()

response_tensors_ref, response_tensors = [], []

#### Generate one response per prompt from each model, with the same sampled length.
for query in query_tensors:
    gen_len = output_length_sampler()
    generation_kwargs["max_new_tokens"] = gen_len
    query_tensor = torch.tensor(query).unsqueeze(dim=0).to("cuda")
    prompt_len = query_tensor.shape[-1]

    # generate() returns the prompt followed by the continuation; keep only the
    # continuation, sliced by the real prompt length.
    output = ref_model.generate(query_tensor, **generation_kwargs)[0, prompt_len:]
    response_tensors_ref.append(output)

    output = ppo_model.generate(query_tensor, **generation_kwargs)[0, prompt_len:]
    response_tensors.append(output)

#### Decode responses
game_data["response (before)"] = [tokenizer.decode(r) for r in response_tensors_ref]
game_data["response (after)"] = [tokenizer.decode(r) for r in response_tensors]

#### Score query/response pairs with the reward model, before and after PPO
game_data["rewards (before)"] = [
    get_reward(q, r) for q, r in zip(game_data["query"], game_data["response (before)"])
]
game_data["rewards (after)"] = [
    get_reward(q, r) for q, r in zip(game_data["query"], game_data["response (after)"])
]

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results.to_csv("./llama7b_rlhf_qa.csv")

print("mean:")
print(df_results[["rewards (before)", "rewards (after)"]].mean())
print()
print("median:")
print(df_results[["rewards (before)", "rewards (after)"]].median())
