# Example of running generation with PARGS on the tldr dataset, with GPT2-SFT LM and deberta-v3 reward model

import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
import torch
import time
import pandas as pd
import os
from datasets import load_dataset, load_from_disk, Dataset
from tqdm import tqdm



device = "cuda" if torch.cuda.is_available() else "cpu"
tqdm.pandas()

# =========================================================================================
# =================================== Load the models ===================================
# Policy LM: GPT2-large checkpoint fine-tuned for TL;DR summarization.
gpt2_tokenizer = AutoTokenizer.from_pretrained("vistagi/gpt2-large-tldr-sum")
gpt2_model = AutoModelForCausalLM.from_pretrained("vistagi/gpt2-large-tldr-sum").to(device)

# Decoding-time reward model: loaded from a local checkpoint and fed text tokenized with
# the OpenAssistant deberta-v3 tokenizer (presumably trained on partial sequences, per
# its name — confirm against the training script).
decode_reward_tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/reward-model-deberta-v3-large")
decode_reward_model_partial = AutoModelForSequenceClassification.from_pretrained("./partial_reward_modeling_tldr").to(device)

# Evaluation reward model used by `test()` to score finished generations.
# `test()` reads `evaluate_reward_tokenizer` / `evaluate_reward_model`, so they must exist
# at module level; use the full-sequence deberta-v3 reward model named in the file header.
EVAL_REWARD_MODEL_NAME = "OpenAssistant/reward-model-deberta-v3-large"
evaluate_reward_tokenizer = AutoTokenizer.from_pretrained(EVAL_REWARD_MODEL_NAME)
evaluate_reward_model = AutoModelForSequenceClassification.from_pretrained(EVAL_REWARD_MODEL_NAME).to(device)


# =========================================================================================

def get_reward(Q, A, reward_tokenizer, reward_model):
    """Score the (question, answer) text pair with a sequence-classification reward model.

    Args:
        Q: the prompt / context text.
        A: the (possibly partial) response text to score.
        reward_tokenizer: tokenizer called as ``reward_tokenizer(Q, A, return_tensors='pt')``.
        reward_model: model whose output exposes ``.logits``; the first logit of the
            first batch element is taken as the reward.

    Returns:
        The reward as a Python float.
    """
    # Put the inputs on the reward model's own device (HF models expose `.device`);
    # fall back to the module-level device otherwise.
    model_device = getattr(reward_model, "device", None)
    if model_device is None:
        model_device = device
    inputs = reward_tokenizer(Q, A, return_tensors='pt').to(model_device)
    # Pure inference: no autograd graph is needed. This is called topk times per decoding
    # step, so skipping graph construction saves significant memory and time.
    with torch.no_grad():
        logits = reward_model(**inputs).logits
    return logits[0].item()


# ======================================================================================================
# ============================================== PARGS DECODING ========================================
# mode1: 1 = greedy (argmax of the reward-guided score), 2 = sampling from softmax(reward-guided scores)

def PARGS_decoding(llm_model=None, llm_tokenizer=None, 
                reward_model=None, reward_tokenizer=None, topk=10,
                prompt=None, max_generation_length=64, mode1=2, w=1):
    """Continue ``prompt`` token by token using reward-guided search (PARGS).

    At each step the LM's ``topk`` most likely next tokens are taken as candidates.
    Each candidate is appended to the text generated so far and scored by the reward
    model. Its reward-guided score is

        score = log p_LM(token | context) + w * reward(prompt, partial_text)

    The next token is then picked from these scores, greedily or by sampling.

    Args:
        llm_model: causal LM (Hugging Face model with ``generate`` and ``device``).
        llm_tokenizer: tokenizer paired with ``llm_model``.
        reward_model: reward model passed through to ``get_reward``.
        reward_tokenizer: tokenizer passed through to ``get_reward``.
        topk: number of candidate tokens considered per step.
        prompt: text to continue.
        max_generation_length: maximum number of new tokens to generate.
        mode1: 1 = greedy (argmax of scores), 2 = sample from softmax(scores).
        w: weight applied to the reward term.

    Returns:
        ``{"sequence": text}`` with the generated continuation only (prompt
        excluded, special tokens removed).

    Raises:
        ValueError: if ``mode1`` is not 1 or 2.
    """
    # Validate before doing any model work. Previously an unknown mode left `sampled_id`
    # unbound and crashed with UnboundLocalError mid-generation.
    if mode1 not in (1, 2):
        raise ValueError(f"mode1 must be 1 (greedy) or 2 (sampling), got {mode1!r}")

    # Keep all tensors on the LM's device.
    device = llm_model.device

    input_ids = llm_tokenizer(prompt, return_tensors='pt').input_ids.to(device)
    # Generated tokens so far, shape (1, t); starts as an empty (1, 0) tensor.
    sequence = torch.tensor([[]], dtype=torch.int64).to(device)

    for t in range(max_generation_length):
        # Concatenating the empty initial sequence is a no-op, so step 0 needs no special case.
        context = torch.cat((input_ids, sequence), dim=-1)
        output = llm_model.generate(inputs=context,
                                    attention_mask=torch.ones_like(context),
                                    max_new_tokens=1,
                                    pad_token_id=llm_tokenizer.eos_token_id,
                                    output_scores=True, return_dict_in_generate=True,
                                    # scores become log-softmax of the processed logits
                                    renormalize_logits=True)

        # Scores of the single generated step, batch element 0 (one value per vocab entry).
        step_scores = output.scores[0][0]
        topk_tokens = torch.topk(step_scores, topk)

        # Reward-guided score for every candidate token.
        rg_scores = []
        for token_logprob, token_id in zip(topk_tokens.values.tolist(), topk_tokens.indices):
            candidate = torch.cat((sequence, token_id.reshape(1, 1)), dim=-1)
            # Drop special tokens (e.g. EOS) so the reward model scores the same kind of
            # text as the final decoded output, not a literal special-token string.
            candidate_text = llm_tokenizer.decode(candidate[0], skip_special_tokens=True)
            sequence_reward = get_reward(prompt, candidate_text, reward_tokenizer, reward_model)
            rg_scores.append(token_logprob + w * sequence_reward)

        score_tensor = torch.tensor(rg_scores)

        if mode1 == 1:
            sampled_id = int(torch.argmax(score_tensor))
        else:
            sampled_id = torch.distributions.Categorical(logits=score_tensor).sample().item()

        next_token = topk_tokens.indices[sampled_id].reshape(1, 1)
        sequence = torch.cat((sequence, next_token), dim=-1)

        if next_token.item() == llm_tokenizer.eos_token_id:
            print(f"EOS BREAK: {t}")
            break

    generation = llm_tokenizer.decode(sequence[0], skip_special_tokens=True)
    return {"sequence": generation}


# ==========================================================================================================
# ========================================= testing function ===============================================
def test(prompt=None, topk=10, max_generation_length=64, mode1=2, w=2.0):
    """Generate a summary for one prompt with PARGS, score it, and print the results.

    Decoding uses the module-level ``gpt2_model``/``gpt2_tokenizer`` and the partial
    reward model ``decode_reward_model_partial``/``decode_reward_tokenizer``. The
    finished generation is scored with the module-level
    ``evaluate_reward_tokenizer``/``evaluate_reward_model``.

    Args:
        prompt: text to summarize.
        topk: candidate tokens per decoding step.
        max_generation_length: maximum number of new tokens.
        mode1: 1 = greedy, 2 = sampling (default matches the previous hard-coded value).
        w: reward weight for decoding (default matches the previous hard-coded value).

    Returns:
        0. Results are printed, not returned.
    """
    # PARGS decoding
    pargs_output = PARGS_decoding(llm_model=gpt2_model, llm_tokenizer=gpt2_tokenizer,
                reward_model=decode_reward_model_partial, reward_tokenizer=decode_reward_tokenizer,
                topk=topk, prompt=prompt, max_generation_length=max_generation_length,
                mode1=mode1, w=w)
    generation = pargs_output['sequence']

    # Score the finished generation with the evaluation reward model.
    pargs_score = get_reward(prompt, generation, evaluate_reward_tokenizer, evaluate_reward_model)
    print(f"PARGS generation: {generation}")
    print(f"PARGS_score: {pargs_score}")
    print("\n\n")

    return 0


# ============================================================================
# =================================== Main ===================================
def test_main(sample_size=50, seed=42, topk=10, max_generation_length=64):
    """Run ``test`` on shuffled prompts from the CarperAI TL;DR test split.

    Args:
        sample_size: number of examples to run. It is capped at the split size.
        seed: shuffle seed, for a reproducible sample.
        topk: candidate tokens per decoding step (forwarded to ``test``).
        max_generation_length: maximum new tokens (forwarded to ``test``).
    """
    # load datasets
    tldr_dataset = load_dataset("CarperAI/openai_summarize_tldr")
    test_data = tldr_dataset["test"].shuffle(seed=seed)

    # Never index past the end of the split.
    n_samples = min(sample_size, len(test_data))

    # run tests
    for i in range(n_samples):
        example = test_data[i]  # fetch the row once rather than once per field
        prompt = "Context: " + example['prompt']
        label = example['label']

        print(f"Prompt:\n{prompt}\n")
        print(f"Label:\n{label}")

        test(prompt, topk, max_generation_length)

if __name__ == "__main__":
    # Script entry point: evaluate PARGS on 100 shuffled TL;DR test prompts,
    # with 10 candidate tokens per step and at most 64 new tokens each.
    test_main(
        sample_size=100,
        seed=42,
        topk=10,
        max_generation_length=64,
    )

