from datasets import load_dataset, Dataset
from tqdm import tqdm
import random
import argparse
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F
import torch
from peft import PeftModel
from peft import PeftConfig

# Hugging Face cache directory passed to both the tokenizer and model loaders.
# NOTE(review): machine-specific scratch path; adjust when running elsewhere.
cache_dir = "/cmlscratch/agrawal5/cache"


def make_step_rewards(logits, token_masks):
    """Extract per-step reward scores from PRM logits.

    Args:
        logits: Tensor indexed as ``[batch, position, label]``; column 1 of the
            softmax over the last axis is read as the step's reward.
        token_masks: Tensor ``[batch, position]`` that is truthy at the
            step-separator positions to score (bool or 0/1 integers).

    Returns:
        A list with one entry per batch row: the label-1 probabilities at the
        masked positions, in sequence order, as Python floats.
    """
    probabilities = F.softmax(logits, dim=-1)
    # Select separator rows with the mask itself rather than by testing the
    # masked probabilities for zero: with the old value test, a NaN/inf logit
    # at an unmasked position (NaN * 0 == NaN) leaked into the result.
    masks = token_masks.to(probabilities.device).bool()
    return [
        probabilities[row][masks[row]][:, 1].cpu().tolist()
        for row in range(probabilities.size(0))
    ]


def load_prm_model(model_name, device):
    """Load a Process Reward Model (PRM) and its tokenizer.

    Args:
        model_name: Hugging Face hub id or local path of the PRM.
        device: ``"cuda"`` or ``"cpu"`` (a ``torch.device`` also works).
            Placement is left to ``device_map="auto"`` (all visible GPUs);
            this argument only selects the weight dtype: float16 on CUDA,
            float32 otherwise, since half precision is poorly supported on CPU.

    Returns:
        ``(model, tokenizer)``, with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir, trust_remote_code=True)

    on_gpu = str(device).startswith("cuda")
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        device_map="auto",  # Uses all available GPUs
        trust_remote_code=True,
        cache_dir=cache_dir,
    )

    # To score with a PEFT adapter instead of the base PRM, wrap it here:
    # peft_path   = "./curriculum_learning/Qwen2.5-Math-PRM-7B-pref_0.5_to_1/checkpoint-861"
    # model = PeftModel.from_pretrained(model, peft_path)

    model.eval()

    return model, tokenizer


def _conversation(question, steps):
    """Build a user/assistant chat whose assistant turn is `steps`, each terminated by "<extra_0>"."""
    return [
        {"role": "user", "content": question},
        {"role": "assistant", "content": "<extra_0>".join(steps) + "<extra_0>"},
    ]


def _resolve_chosen(step):
    """Return ``(text, index)`` of a step's chosen continuation, or ``None`` if none was chosen.

    A human-written completion takes precedence (``index`` is then ``None``);
    otherwise ``chosen_completion`` indexes into ``completions``.
    """
    human_step = step.get("human_completion")
    if human_step is not None:
        return human_step["text"], None
    chosen_index = step.get("chosen_completion")
    if chosen_index is not None:
        return step["completions"][chosen_index]["text"], chosen_index
    return None


def _collect_rejects(steps, step_idx, chosen, chosen_index):
    """Gather rejected alternatives for the step at `step_idx`.

    Same-step completions not rated 1 (and not the chosen one) are rejects.
    In addition, every completion of every later step is appended.
    """
    completions = steps[step_idx]["completions"]
    if chosen_index is None:
        # Chosen text is human-written: exclude model completions with identical text.
        rejects = [c["text"] for c in completions if c["text"] != chosen and c["rating"] != 1]
    else:
        rejects = [c["text"] for j, c in enumerate(completions) if j != chosen_index and c["rating"] != 1]
    # NOTE(review): later steps are used as out-of-order "wrong" continuations
    # here, including ones rated 1; confirm that augmentation is intended.
    for later in steps[step_idx + 1:]:
        if later["completions"] is not None:
            rejects.extend(c["text"] for c in later["completions"])
    return rejects


def _score_last_step(model, tokenizer, messages, step_sep_id):
    """Return the PRM reward at the last step separator of `messages`."""
    conversation_str = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    # Send inputs to wherever the (possibly sharded) model expects them.
    input_ids = tokenizer.encode(conversation_str, return_tensors="pt").to(model.device)
    # Pure inference: no autograd graph, so activations are not retained.
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    token_masks = input_ids == step_sep_id
    return make_step_rewards(outputs[0], token_masks)[0][-1]


def _curriculum_bucket(diff):
    """Map a chosen-minus-rejected reward margin to bucket 0-5 (largest margin first), else None."""
    if 0.98 < diff <= 1.0:
        return 0
    if 0.2 < diff <= 0.98:
        return 1
    if 0.1 < diff <= 0.2:
        return 2
    if 0.0 < diff <= 0.1:
        return 3
    if -0.2 < diff <= 0.0:
        return 4
    if diff <= -0.2:
        return 5
    return None


def evaluate_responses(model, tokenizer, dataset, args):
    """Build chosen/rejected step pairs from PRM800K, score them with the PRM, and save them.

    Solutions with more than 50 steps are skipped. Each remaining solution's
    labelled trajectory is walked step by step. At every step with more than
    one completion, the chosen continuation is paired against each rejected
    alternative. Both share the same chosen prefix, and
    ``diff = reward(chosen) - reward(rejected)`` is taken at the final separator.

    Without ``args.curriculum``, pairs with ``0.5 < diff < 1.0`` are saved to
    ``./preference_data_0.5_to_1``. With it, pairs are bucketed by ``diff``,
    and only bucket 1 (``0.2 < diff <= 0.98``) is saved, to
    ``./curriculum_learning/sameques_aug``.

    Args:
        model: PRM returned by ``load_prm_model``.
        tokenizer: Matching tokenizer; must know the ``<extra_0>`` token.
        dataset: Iterable of PRM800K records (``question.problem``, ``label.steps``).
        args: Parsed CLI namespace; only ``args.curriculum`` is read.
    """
    new_data = []
    add = 0
    step_sep_id = tokenizer.encode("<extra_0>")[0]
    curriculum_data_entry = [[] for _ in range(6)]

    for item in dataset:
        question = item["question"]["problem"]
        response_steps = item["label"]["steps"]
        if len(response_steps) > 50:
            continue

        chosen_step_until_now = []
        for step_idx, step in enumerate(response_steps):
            completions = step["completions"]
            if not completions:
                continue
            resolved = _resolve_chosen(step)

            if len(completions) == 1:
                # Not a branch point, so there are no pairs here. Still keep the
                # chosen step so later prefixes are the real trajectory.
                if resolved is not None:
                    chosen_step_until_now.append(resolved[0])
                continue
            if resolved is None:
                break
            chosen, chosen_index = resolved

            prefix = list(chosen_step_until_now)
            # Append exactly once per step (previously duplicated when no rejects).
            chosen_step_until_now.append(chosen)

            rejects = _collect_rejects(response_steps, step_idx, chosen, chosen_index)
            if not rejects:
                continue

            chosen_messages = _conversation(question, chosen_step_until_now)
            reward_chosen = _score_last_step(model, tokenizer, chosen_messages, step_sep_id)

            for rejected in rejects:
                # prefix + [rejected] avoids the empty leading step the old
                # string concatenation produced when branching at step 0.
                rejected_messages = _conversation(question, prefix + [rejected])
                reward_rejected = _score_last_step(model, tokenizer, rejected_messages, step_sep_id)
                diff = reward_chosen - reward_rejected

                entry = {"chosen": chosen_messages, "rejected": rejected_messages}

                if 0.5 < diff < 1.0:
                    add += 1
                    print(add, step_idx, len(response_steps), diff)
                    new_data.append(entry)

                if args.curriculum:
                    bucket = _curriculum_bucket(diff)
                    if bucket is not None:
                        curriculum_data_entry[bucket].append(entry)

    print("saving")
    if args.curriculum:
        # Only bucket 1 is exported. Exporting several buckets would need
        # distinct paths, because this directory would be overwritten.
        new_dataset = Dataset.from_list(curriculum_data_entry[1])
        new_dataset.save_to_disk('./curriculum_learning/sameques_aug')
        print(len(new_dataset))
    else:
        new_dataset = Dataset.from_list(new_data)
        new_dataset.save_to_disk('./preference_data_0.5_to_1')
        print(len(new_dataset))




def main():
    """Command-line entry point: stream PRM800K, load the PRM, and build preference pairs."""
    cli = argparse.ArgumentParser(description="Evaluate responses using Qwen PRM model.")
    cli.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-Math-PRM-7B", help="PRM model name")
    cli.add_argument("--curriculum", action='store_true', help="Enable curriculum learning")
    opts = cli.parse_args()

    # Reported in the log line below and forwarded to load_prm_model.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    print("loading dataset")
    train_stream = load_dataset("tasksource/PRM800K", split="train", streaming=True)
    print("loaded dataset")

    print(f"Loading model: {opts.model_name} on {device}")
    prm, prm_tokenizer = load_prm_model(opts.model_name, device)

    evaluate_responses(prm, prm_tokenizer, train_stream, opts)


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    main()
    
