import os
import time

import pandas as pd
import torch
from accelerate import Accelerator
from torch import Tensor
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from dr.dataset import get_uf_ppo_dataset

# Model id of the reward model used as the "gold" scorer. get_uf_gold_scores loads it
# (tokenizer + AutoModelForSequenceClassification with num_labels=1) and embeds it,
# with "/" replaced by "_", in the names of the CSV files it writes.
GOLD_MODEL = "Ray2333/reward-model-Mistral-7B-instruct-Unified-Feedback"


def _strip_pad_suffix(text: str, pad: str) -> str:
    """Remove every trailing copy of the ``pad`` string from ``text``.

    ``str.rstrip(pad)`` treats its argument as a *set of characters*; with a pad
    token such as ``"</s>"`` it would also strip legitimate trailing ``s``, ``/``,
    ``<`` or ``>`` characters of the decoded text. This only removes ``pad`` when
    it appears as a whole suffix.
    """
    if not pad:
        return text
    while text.endswith(pad):
        text = text[: -len(pad)]
    return text


def _as_list(values: object) -> list:
    """Convert a collated batch column (a tensor or a plain sequence) to a Python list."""
    if isinstance(values, Tensor):
        return values.tolist()
    return list(values)  # type: ignore[call-overload]


def get_uf_gold_scores(max_length: int = 1024, batch_size: int = 16) -> None:
    """Score the UF PPO dataset's chosen/rejected sequences with the gold reward model.

    For each of the ``"train"`` and ``"val"`` splits, every process scores its
    shard of the dataset and the results are gathered. The main process then
    writes one CSV per split into ``eval_rlhf/gold``. The CSV has the columns
    ``prompts_chosen``, ``prompts_rejected``, ``rewards_chosen``,
    ``rewards_rejected`` and, when the dataset provides ``unique_id``, also
    ``unique_ids``. If ``eval_rlhf/gold`` already exists, the function returns
    immediately without recomputing anything.

    Args:
        max_length: Maximum tokenized length. It is set on the tokenizer and
            passed to ``get_uf_ppo_dataset``.
        batch_size: Per-process batch size used for scoring.
    """
    results_dir = "eval_rlhf/gold"
    if os.path.exists(results_dir):
        return

    model_name_or_path = GOLD_MODEL

    accelerator = Accelerator()
    # Reuse the one Accelerator rather than constructing a second instance.
    device = accelerator.local_process_index
    print("Device", device, "number of processes:", accelerator.num_processes)

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
    tokenizer.model_max_length = max_length
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, num_labels=1, device_map=device, torch_dtype=torch.bfloat16)

    # Every process needs the dataset because it is sharded below. Loading it
    # only on the main process made the other ranks fail with a NameError.
    # The main process goes first so the others can presumably reuse its
    # preprocessing cache.
    with accelerator.main_process_first():
        datasets = get_uf_ppo_dataset(tokenizer, max_length, dataset_num_proc=20, version="20k", remove_columns=True)

    for split_name in ["train", "val"]:
        dataset = datasets[split_name]

        # Shard the dataset among processes. `rank` must be the *global* process
        # index; the local index repeats across nodes in multi-node runs.
        sampler = DistributedSampler(dataset, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=False)  # type: ignore
        data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, drop_last=False)  # type: ignore
        # Dataset positions handled by this process, in the order the loader
        # yields them. DistributedSampler pads with repeated indices so every
        # rank gets the same count. The loader is not prepared by Accelerate,
        # so gather_for_metrics cannot trim them. The indices are used after
        # gathering to restore dataset order and drop the repeats.
        local_indices: list[int] = list(sampler)

        rewards_chosen: list[float] = []
        rewards_rejected: list[float] = []
        chosen_ids: list[Tensor] = []
        rejected_ids: list[Tensor] = []
        unique_ids: list = []

        with torch.no_grad():
            for batch in tqdm(data_loader, desc=f"Eval {split_name}"):
                # NOTE(review): key names are asymmetric ("input_ids" paired with
                # "attention_mask_chosen"); presumably this matches get_uf_ppo_dataset's output — verify.
                reward_chosen = model(batch["input_ids"].to(device), attention_mask=batch["attention_mask_chosen"].to(device)).logits.reshape(-1)
                reward_rejected = model(batch["input_ids_rejected"].to(device), attention_mask=batch["attention_mask_rejected"].to(device)).logits.reshape(-1)
                # One device->host transfer per batch instead of one `.item()` per element.
                rewards_chosen.extend(reward_chosen.float().cpu().tolist())
                rewards_rejected.extend(reward_rejected.float().cpu().tolist())
                chosen_ids.extend(batch["input_ids"])
                rejected_ids.extend(batch["input_ids_rejected"])
                if "unique_id" in batch:
                    unique_ids.extend(_as_list(batch["unique_id"]))

        pad = tokenizer.pad_token
        prompts_chosen = [_strip_pad_suffix(x, pad) for x in tokenizer.batch_decode(chosen_ids)]
        prompts_rejected = [_strip_pad_suffix(x, pad) for x in tokenizer.batch_decode(rejected_ids)]

        accelerator.wait_for_everyone()

        # All gathers run unconditionally so every rank issues the same sequence
        # of collective calls.
        all_indices = accelerator.gather_for_metrics(local_indices)
        all_chosen_prompts = accelerator.gather_for_metrics(prompts_chosen)
        all_rejected_prompts = accelerator.gather_for_metrics(prompts_rejected)
        all_rewards_chosen = accelerator.gather_for_metrics(rewards_chosen)
        all_rewards_rejected = accelerator.gather_for_metrics(rewards_rejected)
        all_unique_ids = accelerator.gather_for_metrics(unique_ids)

        if not accelerator.is_main_process:
            continue

        evaluation_result = {
            "prompts_chosen": all_chosen_prompts,
            "prompts_rejected": all_rejected_prompts,
            "rewards_chosen": all_rewards_chosen,
            "rewards_rejected": all_rewards_rejected,
        }
        has_unique_ids = len(all_unique_ids) > 0
        if has_unique_ids:
            evaluation_result["unique_ids"] = all_unique_ids

        gold_scores = pd.DataFrame(evaluation_result, index=all_indices)
        # Drop the sampler's padding duplicates and put rows back in dataset order.
        gold_scores = gold_scores[~gold_scores.index.duplicated(keep="first")].sort_index()
        if has_unique_ids:
            gold_scores = gold_scores.sort_values(by="unique_ids")
        gold_scores = gold_scores.reset_index(drop=True)

        os.makedirs(results_dir, exist_ok=True)
        save_filename = os.path.join(results_dir, f"gold_score_{split_name}_{model_name_or_path.replace('/', '_')}.csv")
        gold_scores.to_csv(save_filename)
        print(f"Saved to '{save_filename}'.")


if __name__ == "__main__":
    # Script entry point: compute the gold scores and report wall-clock time.
    t0 = time.time()
    print("Starting...")
    get_uf_gold_scores()
    print("Finished.")
    elapsed = time.time() - t0
    print(f"Took {elapsed:.2f} s")
