import os
import pickle

import torch
from accelerate import Accelerator
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from trl import PPOTrainer


def transfer_template_rm(prompt, response, tokenizer, rm_tokenizer):
    """Re-render a Gemma-formatted conversation with the reward model's chat template.

    The Gemma control tokens are stripped from ``prompt`` and ``response``, the
    prompt is split back into user/assistant messages, and those messages are
    formatted with ``rm_tokenizer.apply_chat_template`` (untokenized, with a
    generation prompt appended).

    Args:
        prompt: Prompt text containing Gemma markers such as ``<bos>``,
            ``<start_of_turn>user\\n``, ``<start_of_turn>model\\n`` and
            ``<end_of_turn>\\n``.
        response: Generated text; any ``<eos>`` marker is removed.
        tokenizer: Tokenizer that produced ``prompt``; only ``name_or_path``
            is inspected.
        rm_tokenizer: Tokenizer whose chat template is applied to the messages.

    Returns:
        A ``(prompt_trans, response)`` tuple: the re-templated prompt and the
        response without ``<eos>``.

    Raises:
        NotImplementedError: If ``tokenizer.name_or_path`` does not contain
            ``"gemma"``.
    """
    if "gemma" not in tokenizer.name_or_path:
        raise NotImplementedError

    turn_end = "<end_of_turn>\n"
    prompt = prompt.replace("<bos>", "")
    response = response.replace("<eos>", "")

    messages = []
    # Text before the first user marker is discarded by the [1:] slice.
    for turn in prompt.split("<start_of_turn>user\n")[1:]:
        pieces = turn.split("<start_of_turn>model\n")
        messages.append({"content": pieces[0].replace(turn_end, ""), "role": "user"})
        # An assistant message is emitted only for exactly one, non-empty model
        # segment; a trailing empty segment is just the generation prompt.
        if len(pieces) == 2 and len(pieces[1]):
            messages.append({"content": pieces[1].replace(turn_end, ""), "role": "assistant"})

    prompt_trans = rm_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt_trans, response


def collator(data):
    """Turn a list of per-sample dicts into one dict of per-key lists.

    The key set is taken from the first sample, so ``data`` must be non-empty
    and every sample must provide those keys. Values are left as-is (no
    stacking or padding).
    """
    batch = {}
    for key in data[0]:
        batch[key] = [sample[key] for sample in data]
    return batch


def eval_model(
    ppo_trainer: PPOTrainer,
    eval_dataset: Dataset,
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    accelerator: Accelerator,
    eval_batch_size: int,
    results_dir: str,
    name: str,
    eval_generation_kwargs: dict | None = None,
):
    """Generate responses for ``eval_dataset`` and dump them with policy/reference divergences.

    For every batch the policy generates a response per query, then the
    per-token log-probabilities of the policy and of the reference model are
    computed on the same sequences. Three per-sample statistics of the masked
    log-probability difference are recorded:

    * ``kl1`` - sum of the differences,
    * ``kl2`` - sum of half the squared differences,
    * ``kl3`` - sum of the absolute differences.

    Results from all processes are gathered, and the main process pickles a
    dict with keys ``prompts``, ``responses``, ``kl1``, ``kl2``, ``kl3`` (plus
    ``source_ids`` / ``id_ids`` when the batches carry ``source`` / ``id``)
    to ``<results_dir>/eval_outputs_<name>.csv``.

    Args:
        ppo_trainer: Trainer providing ``generate``, ``prepare_model_inputs``,
            ``batched_forward_pass`` and the policy/reference models.
        eval_dataset: Dataset whose samples contain ``input_ids`` and
            optionally ``source`` and ``id``.
        tokenizer: Tokenizer used to decode prompts and responses.
        accelerator: Accelerator used to shard the data loader and gather results.
        eval_batch_size: Batch size of the evaluation data loader.
        results_dir: Directory the result file is written to.
        name: Suffix used in the result file name.
        eval_generation_kwargs: Extra keyword arguments forwarded to
            ``ppo_trainer.generate``. ``None`` means no extra arguments.

    Returns:
        None. The results are only written to disk (main process).
    """
    # Avoid a shared mutable default argument.
    generation_kwargs = {} if eval_generation_kwargs is None else eval_generation_kwargs

    full_prompts = []
    full_response_tensors = []
    kl1_list, kl2_list, kl3_list = [], [], []
    full_source_ids, full_id_ids = [], []
    # Tracked explicitly so the post-loop code never reads the loop variable,
    # which is unbound when the loader yields no batch.
    has_source = False
    has_id = False

    eval_data_loader = DataLoader(eval_dataset, batch_size=eval_batch_size, drop_last=False, collate_fn=collator)
    eval_data_loader = accelerator.prepare(eval_data_loader)

    with torch.no_grad():
        # tqdm takes its total from the prepared loader's length, which already
        # reflects sharding and the partial last batch, and closes the bar
        # once the loader is exhausted.
        for batch in tqdm(eval_data_loader):
            query_tensors = batch["input_ids"]
            response_tensors: Tensor = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)  # type: ignore
            full_response_tensors.extend(response_tensors)
            full_prompts.extend(query_tensors)

            model_inputs = ppo_trainer.prepare_model_inputs(query_tensors, response_tensors)
            all_logprobs, _, _, masks = ppo_trainer.batched_forward_pass(
                ppo_trainer.model,
                query_tensors,
                response_tensors,
                model_inputs,
                return_logits=False,
            )
            # For PEFT models the reference pass reuses ppo_trainer.model inside
            # optional_peft_ctx(); otherwise the separate ref_model is used.
            with ppo_trainer.optional_peft_ctx():
                ref_logprobs, _, _, _ = ppo_trainer.batched_forward_pass(
                    ppo_trainer.model if ppo_trainer.is_peft_model else ppo_trainer.ref_model,  # type: ignore
                    query_tensors,
                    response_tensors,
                    model_inputs,
                    return_logits=False,
                )

            # Masked per-token log-ratio, reduced over the last dimension to
            # one value per sample.
            diff = (all_logprobs - ref_logprobs) * masks
            kl1_list.extend(diff.sum(dim=-1).tolist())
            kl2_list.extend((0.5 * diff.square()).sum(dim=-1).tolist())
            kl3_list.extend(diff.abs().sum(dim=-1).tolist())

            if "source" in batch:
                has_source = True
                full_source_ids.extend(batch["source"])
            if "id" in batch:
                has_id = True
                full_id_ids.extend(batch["id"])

    full_prompts = tokenizer.batch_decode(full_prompts)
    full_responses = tokenizer.batch_decode(full_response_tensors)

    accelerator.wait_for_everyone()
    # Gathers are collective calls: every process must issue them in the same
    # order, so the dict is built on all processes and only written on main.
    evaluation_result = {
        "prompts": accelerator.gather_for_metrics(full_prompts),
        "responses": accelerator.gather_for_metrics(full_responses),
        "kl1": accelerator.gather_for_metrics(kl1_list),
        "kl2": accelerator.gather_for_metrics(kl2_list),
        "kl3": accelerator.gather_for_metrics(kl3_list),
    }
    if has_source:
        evaluation_result["source_ids"] = accelerator.gather_for_metrics(full_source_ids)
    if has_id:
        evaluation_result["id_ids"] = accelerator.gather_for_metrics(full_id_ids)

    if accelerator.is_main_process:
        # NOTE(review): the file carries a ".csv" suffix but holds pickle data;
        # the name is kept unchanged because downstream readers may rely on it.
        with open(os.path.join(results_dir, f"eval_outputs_{name}.csv"), "wb") as f:
            pickle.dump(evaluation_result, f)
