from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from torch.optim import Adam
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    AutoModelForSeq2SeqLM,
)
import pandas as pd
from trl import (
    AutoModelForCausalLMWithValueHead,
    PPOConfig,
    PPOTrainer,
    create_reference_model,
    set_seed,
    AutoModelForSeq2SeqLMWithValueHead,
)
from trl.core import LengthSampler
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

tqdm.pandas()
import wandb
import json
import copy
import os


# Name reused further down for the W&B project, the local save directory
# (f"./{project}") and the repository passed to `push_to_hub`.
project = "project_name"
# Opens a Weights & Biases run under that project as soon as the script starts.
wandb.init(project=project)


@dataclass
class ScriptArguments:
    """
    Command-line options of the PPO fine-tuning script.

    Each field carries its default value and a short `help` text in its metadata.
    """

    # Checkpoint name; the default is a FLAN-T5 (sequence-to-sequence) model.
    model_name: Optional[str] = field(
        metadata={"help": "the model name"},
        default="google/flan-t5-base",
    )
    # Experiment tracker to use, if any.
    log_with: Optional[str] = field(
        metadata={"help": "use 'wandb' to log with wandb"}, default=None
    )
    # Optimiser step size.
    learning_rate: Optional[float] = field(
        metadata={"help": "the learning rate"}, default=1.41e-5
    )
    # Number of samples per PPO minibatch.
    mini_batch_size: Optional[int] = field(
        metadata={"help": "the PPO minibatch size"}, default=4
    )
    # Number of samples per batch.
    batch_size: Optional[int] = field(
        metadata={"help": "the batch size"}, default=16
    )
    # Destination for the trained model.
    model_save_path: Optional[str] = field(
        metadata={"help": "the path to save the model"},
        default="./saved_path",
    )


# Parse the command line into a single `ScriptArguments` instance.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# PPO hyper-parameters, taken from the parsed arguments; `ppo_epochs` is fixed at 1.
config = PPOConfig(
    model_name=script_args.model_name,
    learning_rate=script_args.learning_rate,
    ppo_epochs=1,
    mini_batch_size=script_args.mini_batch_size,
    batch_size=script_args.batch_size,
    # Forward the tracker choice so that `--log_with` is no longer silently
    # ignored. Its default is None, so nothing changes unless the flag is given.
    log_with=script_args.log_with,
)


def build_dataset(
    config,
    dataset_name="knkarthick/dialogsum",
    input_min_text_length=200,
    input_max_text_length=1000,
):
    """
    Preprocess the dataset and split it into train and test parts.

    Parameters:
    - config: Configuration object; only `config.model_name` is read, to pick the tokenizer.
    - dataset_name (str): Name of the dataset to load.
    - input_min_text_length (int): Dialogues must be strictly longer than this many characters.
    - input_max_text_length (int): Dialogues must be at most this many characters long.

    Returns:
    - dataset_splits (datasets.dataset_dict.DatasetDict): Preprocessed dataset containing train and test parts.
      Every sample additionally carries the columns `input_ids` and `query`.
    """

    # Load the "train" split only; the test part is carved out of it below.
    dataset = load_dataset(dataset_name, split="train")

    # Keep the dialogues whose length in characters lies in
    # (input_min_text_length, input_max_text_length].
    dataset = dataset.filter(
        lambda x: input_min_text_length < len(x["dialogue"]) <= input_max_text_length,
        batched=False,
    )

    # Prepare the tokenizer. `device_map` is a model-loading option and has no
    # meaning for a tokenizer, so it is no longer passed here.
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    def tokenize(sample):
        """Add the tokenized, instruction-wrapped dialogue to `sample`."""

        # Wrap each dialogue with the instruction.
        prompt = f"""
        Summarize the following conversation.
        
        {sample["dialogue"]}
        
        Summary:
        """
        sample["input_ids"] = tokenizer.encode(prompt)

        # This must be called "query", which is a requirement of our PPO library.
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    # Tokenize each dialogue and expose the columns as torch tensors.
    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")

    # Split the dataset into train and test parts (last 20% becomes the test
    # part because shuffling is disabled).
    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)

    return dataset_splits


# Build the tokenized train/test dataset splits by calling the `build_dataset` function.

dataset = build_dataset(config)


def collator(data):
    """Regroup a list of per-sample dicts into one dict of per-key lists.

    The keys are those of the first sample; each value lists that key's entry
    from every sample of `data`, in order.
    """
    first_sample = data[0]
    return {key: [sample[key] for sample in data] for key in first_sample}


# def collator(data):
#     return dict((key, torch.stack([d[key] for d in data])) for key in data[0])


# Quick sanity check of `collator` on a one-sample batch.
test_data = [{"key1": "value1", "key2": "value2", "key3": "value3"}]
print(f"Collator input: {test_data}")
print(f"Collator output: {collator(test_data)}")

# Seed the run with the seed stored on the PPO config.
set_seed(config.seed)
# LoRA settings targeting the modules named "q" and "v".
lora_config = LoraConfig(
    r=32,  # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,  # FLAN-T5
)

# Base sequence-to-sequence model, loaded in bfloat16.
model = AutoModelForSeq2SeqLM.from_pretrained(
    config.model_name, torch_dtype=torch.bfloat16
)

# Load the adapter stored at "sft_summarization_adapter" on top of the base
# model and keep it trainable.
# NOTE(review): `lora_config=` does not look like a named parameter of
# `PeftModel.from_pretrained`; presumably the configuration saved with the
# adapter is the one that takes effect -- confirm against the installed PEFT.
peft_model = PeftModel.from_pretrained(
    model,
    "sft_summarization_adapter",
    lora_config=lora_config,
    torch_dtype=torch.bfloat16,
    is_trainable=True,
)


def print_number_of_trainable_model_parameters(model):
    """Build a short report on how many of `model`'s parameters are trainable.

    Despite its name the function prints nothing: it returns the report as a
    string with the trainable count, the total count and the trainable share
    in percent (two decimals).
    """
    # (size, trainable?) for every named parameter, gathered in a single pass.
    sizes = [(p.numel(), p.requires_grad) for _, p in model.named_parameters()]

    all_model_params = sum(size for size, _ in sizes)
    trainable_model_params = sum(size for size, trainable in sizes if trainable)
    percentage = 100 * trainable_model_params / all_model_params

    return (
        f"\ntrainable model parameters: {trainable_model_params}"
        f"\nall model parameters: {all_model_params}"
        f"\npercentage of trainable model parameters: {percentage:.2f}%"
    )


# Report the trainable share of the adapter-wrapped model.
print(
    f"PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}\n"
)

# Wrap the PEFT model with `AutoModelForSeq2SeqLMWithValueHead` so PPO has a
# value head (`v_head`, printed below) on top of the seq2seq policy.
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    peft_model, torch_dtype=torch.bfloat16, is_trainable=True
)
print(
    f"PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n"
)
print(ppo_model.v_head)

# Create the reference model from the PPO model. No `num_shared_layers` is
# passed to `create_reference_model` here.
ref_model = create_reference_model(ppo_model)
print(
    f"Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n"
)

# `Adam` over the parameters of `model` (the base model object) that require
# gradients.
# NOTE(review): this optimizer is not handed to the trainer below (the
# `optimizer=` argument is commented out), so it looks unused -- confirm.
optimizer = Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate
)

# Tokenizer of the policy model; its pad token is overridden with the EOS token.
# NOTE(review): presumably carried over from a decoder-only example -- verify
# that this override is wanted for the checkpoint in use.
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

# Build the PPOTrainer from the config, the policy model, the reference model,
# the tokenizer, the training split and the list-based `collator`.
ppo_trainer = PPOTrainer(
    config=config,
    model=ppo_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dataset["train"],
    data_collator=collator,
    # optimizer=optimizer,
)

# Bounds handed to `LengthSampler`; the sampled value is later used as
# `max_new_tokens` for each generated response.
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)
# Sampling settings shared by every `generate` call of the policy/reference models.
generation_kwargs = {"min_length": 5, "top_k": 0.0, "top_p": 1.0, "do_sample": True}

# NOTE(review): `reward_kwargs` is not referenced anywhere else in the visible
# script -- presumably left over from a pipeline-based reward; confirm.
reward_kwargs = {
    "top_k": None,  # Return all scores.
    "function_to_apply": "none",  # You want the raw logits without softmax.
    "batch_size": 16,
}

# NOTE(review): the visible save step builds its directory from `project` and
# does not read `model_save_path` -- confirm which one is intended.
model_save_path = script_args.model_save_path

from huggingface_hub import login

# Authenticate with the Hugging Face Hub before the reward model is fetched.
# The token is read from the HF_TOKEN environment variable so that no secret
# has to be written into the source file; the original placeholder remains the
# fallback when the variable is not set.
token = os.environ.get("HF_TOKEN", "token")
login(token=token)

# Load the judge used by `get_reward` (tokenizer + causal LM) and move it to the GPU.
# NOTE(review): no `torch_dtype` is given here, unlike the policy model which is
# loaded in bfloat16 -- confirm the memory budget allows the default dtype.
print("Above loading reward model")
reward_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
reward_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct"
).to("cuda")
print("Reward model loaded")


def get_reward(prompt_text, predicted_text):
    """
    Ask the reward model to grade a summary and return the summed scores.

    The module-level `reward_model` is prompted to rate `predicted_text` against
    `prompt_text` on coherence, accuracy and coverage (1 to 7 each) and to answer
    with a JSON object. The values of that object are added up.

    Parameters:
    - prompt_text (str): The original text the summary was written for.
    - predicted_text (str): The summary to evaluate.

    Returns:
    - float: Sum of the scores found in the reward model's JSON answer, or
      0.00001 when no JSON object can be extracted or any other error occurs
      (the error is printed, never raised).
    """
    # Small positive reward returned whenever the judge's answer cannot be used.
    fallback_reward = 0.00001

    try:
        prompt = f"""
        <|begin_of_text|><|start_header_id|>system<|end_header_id|>

        Evaluate the given summary on the following criteria using a scale of 1 to 7:

        1. Coherence: How coherent is the summary on its own?
        2. Accuracy: Does the factual information in the summary accurately match the original text?
        3. Coverage: How well does the summary cover the important information in the original text?

        Use the following rubrics for scoring:

        Coherence:
        1: The summary is impossible to understand.
        4: The summary has mistakes or confusing phrasing that make it a bit hard to understand.
        7: The summary is perfectly clear.

        Accuracy:
        1: The summary is completely wrong, made up, or exactly contradicts what is written in the original text.
        4: The summary says at least one substantial thing that is not mentioned in the original text, or that contradicts something in the original text.
        5: The summary says anything, no matter how small, that is not mentioned in the original text, or that contradicts something in the original text.
        7: The summary has no incorrect statements or misleading implications.

        Coverage:
        1: The summary contains no information relevant to the original text.
        4: The summary is missing at least 1 important piece of information required to understand the situation.
        7: The summary covers all of the important information required to understand the situation.

        Return the answer as a JSON object with the following fields. DO NOT INCLUDE ANY OTHER OUTPUT. <|eot_id|><|start_header_id|>user<|end_header_id|>
        {{
            "coherence": "score",
            "accuracy": "score",
            "coverage": "score",
        }}

        Text to evaluate is below:
        ========
        original text:{prompt_text}

        summary:{predicted_text}
        ========
        <|eot_id|><|start_header_id|>assistant<|end_header_id|>
        """

        encoded = reward_tokenizer(prompt, return_tensors="pt")
        input_tokens = encoded.input_ids.cuda()
        attention_mask = encoded.attention_mask.cuda()

        # Greedy decoding: `do_sample=False` alone makes the answer deterministic,
        # so the sampling-only settings (temperature, top_p) are not passed.
        generation_output = reward_model.generate(
            input_tokens,
            attention_mask=attention_mask,
            do_sample=False,
            max_new_tokens=1024,
        )

        # The returned sequence starts with the prompt tokens. Decode only what
        # was generated after them, so that the braces of the JSON template in
        # the prompt (or in the evaluated texts) cannot be mistaken for the answer.
        answer_tokens = generation_output[0][input_tokens.shape[-1] :]
        output_text = reward_tokenizer.decode(answer_tokens, skip_special_tokens=True)

        # Debug output to understand the format of generated output
        print(f"Generated output: {output_text}")

        # Take the span from the first "{" to the last "}" as the JSON answer.
        json_start = output_text.find("{")
        json_end = output_text.rfind("}")
        if json_start == -1 or json_end < json_start:
            print("Unexpected format of the output")
            return fallback_reward
        scores = json.loads(output_text[json_start : json_end + 1])

        # Debug output to understand the format of generated output
        print("Generated output:", scores)  # Debugging line

        # Sum the scores. The prompt shows the values as quoted strings, so the
        # judge may answer "5" instead of 5; convert each value explicitly
        # rather than letting `sum` fail on strings.
        sum_scores = sum(float(value) for value in scores.values())

        return float(sum_scores)

    except Exception as e:
        # Best effort: a failed evaluation must not interrupt training.
        print("#" * 100)
        print(e)
        return fallback_reward


# PPO training loop: one optimisation step per batch of the trainer's dataloader.
# `tqdm` wraps the dataloader itself (not `enumerate`) so that it can read the
# number of batches and show real progress. `epoch` is the index of the batch.
for epoch, batch in enumerate(tqdm(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    query_texts = batch["query"]

    #### Sample one response per query from the policy model and score it
    response_tensors = []
    reward_tensors = []

    for prompt_tensor, prompt_text in zip(query_tensors, query_texts):
        # Draw the generation budget for this query.
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        # Drop the batch dimension and keep at most the last `gen_len` tokens.
        response = response.squeeze()[-gen_len:]
        # NOTE(review): special tokens are not skipped when decoding, so they
        # may appear in the text shown to the reward model -- confirm intended.
        predicted_text = tokenizer.decode(response)
        reward = get_reward(prompt_text, predicted_text)

        response_tensors.append(response)
        reward_tensors.append(reward)

    batch["response"] = list(response_tensors)
    rewards = [torch.tensor(reward, device="cuda") for reward in reward_tensors]
    torch.cuda.empty_cache()

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    # Log stats
    # ppo_trainer.log_stats(stats, batch, rewards)

    # Console summary; input and output shown are those of the batch's last sample.
    print(f"Input: {query_texts[-1]}")
    print(f"Output: {predicted_text}")
    print(f'objective/kl: {stats["objective/kl"]}')
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}')
    print(f'ppo/loss/total: {stats["ppo/loss/total"]}')
    print(f"Epoch: {epoch}, Reward: {rewards}")


"""
Save Model
"""
# Save the trained model and the tokenizer into a directory named after the project.
# NOTE(review): the `model_save_path` script argument is not consulted here;
# the directory is derived from `project` -- confirm which one is intended.
model_save_dir = f"./{project}"
# `exist_ok=True` replaces the separate existence check.
os.makedirs(model_save_dir, exist_ok=True)
ppo_trainer.model.save_pretrained(model_save_dir)
tokenizer.save_pretrained(model_save_dir)


# Push model to HF
ppo_trainer.model.push_to_hub(project)
tokenizer.push_to_hub(project)

#### draw a random evaluation batch from the test split
bs = 16
dataset.set_format("pandas")
df_batch = dataset["test"][:].sample(bs)
game_data = {"query": df_batch["query"].tolist()}
query_tensors = df_batch["input_ids"].tolist()

response_tensors_ref, response_tensors = [], []

#### generate one response per query with the reference model, then with the PPO model
for query_ids in query_tensors:
    gen_len = output_length_sampler()
    generation_kwargs["max_new_tokens"] = gen_len
    model_input = torch.tensor(query_ids).unsqueeze(dim=0).to("cuda")

    reference_output = ref_model.generate(model_input, **generation_kwargs)
    response_tensors_ref.append(reference_output.squeeze()[-gen_len:])

    tuned_output = ppo_model.generate(model_input, **generation_kwargs)
    response_tensors.append(tuned_output.squeeze()[-gen_len:])

#### decode responses
game_data["response (before)"] = [
    tokenizer.decode(tokens) for tokens in response_tensors_ref
]
game_data["response (after)"] = [tokenizer.decode(tokens) for tokens in response_tensors]

#### score every query/response pair with the reward model, before and after PPO
game_data["rewards (before)"] = [
    get_reward(query, response)
    for query, response in zip(game_data["query"], game_data["response (before)"])
]
game_data["rewards (after)"] = [
    get_reward(query, response)
    for query, response in zip(game_data["query"], game_data["response (after)"])
]

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results.to_csv("./rlhf_summ_3_equal.csv")

# compare the reward distributions before and after PPO
reward_columns = ["rewards (before)", "rewards (after)"]
print("mean:")
print(df_results[reward_columns].mean())
print()
print("median:")
print(df_results[reward_columns].median())
