import gc

import numpy as np
import torch
from loguru import logger
from tqdm import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

from meta_alignment.config import TrainingConfig
from meta_alignment.constant import STOP_WORDS
from meta_alignment.dataset import get_dataset
from meta_alignment.inference import generate_completions
from meta_alignment.reward_funcs import (
    get_reward_func_from_classifier,
)


class EvalConfig(TrainingConfig):
    """Evaluation settings layered on top of ``TrainingConfig``.

    ``checkpoint_dir`` resolves the saved checkpoint for ``step`` under
    ``log_dir`` (``log_dir`` is not declared here; presumably inherited from
    ``TrainingConfig``).
    """

    beta: float = 0.1
    model: str = "alpaca7b"
    task: str = "hh"
    step: int = 1000  # training step whose checkpoint is evaluated
    eval_size: int = 1000  # forwarded to get_dataset as eval_size
    batch_size: int = 128  # number of pairs per reward-model call
    # "-helpful" / "-harmless" suffixes are appended to locate the true reward models
    true_reward_dir_base: str = "results/hh/models/qwen14b"

    @property
    def checkpoint_dir(self) -> str:
        """Return ``<log_dir>/checkpoint-<step>``, the checkpoint to evaluate."""
        return "{}/checkpoint-{}".format(self.log_dir, self.step)


def batch_eval(prompts, completions, reward_func, batch_size=128):
    """Score ``(prompt, completion)`` pairs with ``reward_func`` in chunks.

    Args:
        prompts: Sequence of prompts, aligned index-for-index with
            ``completions``.
        completions: Sequence of completions to score.
        reward_func: Callable taking ``prompts=`` and ``completions=`` keyword
            slices and returning an iterable of rewards for that chunk.
        batch_size: Maximum number of pairs passed to ``reward_func`` per call.

    Returns:
        A flat list with the rewards of every chunk concatenated in call order.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if ``prompts`` and
            ``completions`` differ in length (which would otherwise silently
            pair rewards with the wrong completions).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(prompts) != len(completions):
        raise ValueError(
            "prompts and completions must have the same length "
            f"({len(prompts)} != {len(completions)})"
        )

    rewards = []
    for start in tqdm(range(0, len(prompts), batch_size)):
        stop = start + batch_size
        rewards.extend(
            reward_func(
                prompts=prompts[start:stop],
                completions=completions[start:stop],
            )
        )
    return rewards


def get_rewards(helpful_dir, harmless_dir, add_eos):
    """Load a helpful/harmless classifier pair and wrap each as a reward function.

    Both classifiers are loaded with ``device_map="auto"`` in ``bfloat16`` and
    wrapped by ``get_reward_func_from_classifier``.

    Args:
        helpful_dir: Directory of the helpfulness classifier (also the source of
            the shared tokenizer).
        harmless_dir: Directory of the harmlessness classifier.
        add_eos: Forwarded unchanged to ``get_reward_func_from_classifier``.

    Returns:
        ``(reward_helpful, reward_harmless)``.
    """
    # NOTE(review): a single tokenizer, loaded from helpful_dir, serves both
    # classifiers; this assumes both share the same base tokenizer.
    shared_tokenizer = AutoTokenizer.from_pretrained(helpful_dir)

    labelled_dirs = (("helpful", helpful_dir), ("harmless", harmless_dir))
    # Load both classifiers before building either reward function.
    classifiers = {
        label: AutoModelForSequenceClassification.from_pretrained(
            model_dir,
            device_map="auto",
            dtype=torch.bfloat16,
        )
        for label, model_dir in labelled_dirs
    }
    return tuple(
        get_reward_func_from_classifier(
            reward_model=classifiers[label],
            reward_tokenizer=shared_tokenizer,
            model_name=label,
            add_eos=add_eos,
        )
        for label, _ in labelled_dirs
    )


def _free_gpu_memory():
    """Run the garbage collector and release cached CUDA memory if a GPU exists."""
    gc.collect()
    # ``torch.cuda.ipc_collect`` initializes CUDA and fails on CPU-only hosts.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def _score_with_reward_pair(
    prompts, completion_texts, helpful_dir, harmless_dir, batch_size
):
    """Score every completion with one helpful/harmless reward-model pair.

    The pair is loaded, used and released before returning, so at most two
    reward models are resident at any time.

    Returns:
        ``(helpful_rewards, harmless_rewards)`` as flat lists, presumably one
        entry per (prompt, completion) pair.
    """
    reward_helpful, reward_harmless = get_rewards(
        helpful_dir=helpful_dir,
        harmless_dir=harmless_dir,
        add_eos=True,
    )
    try:
        helpful_rewards = batch_eval(
            prompts, completion_texts, reward_helpful, batch_size=batch_size
        )
        harmless_rewards = batch_eval(
            prompts, completion_texts, reward_harmless, batch_size=batch_size
        )
    finally:
        # The reward closures presumably hold the only references to the
        # classifiers; drop them (even on error) so memory can be reclaimed.
        del reward_helpful, reward_harmless
        _free_gpu_memory()
    return helpful_rewards, harmless_rewards


def evaluate(args, proxy_reward_dir_base="results/hh/models/qwen4b"):
    """Sample completions from a checkpoint and score them with reward models.

    For every eval prompt ``args.n`` completions are generated. Each completion
    is scored by the proxy helpful/harmless classifiers
    (``proxy_reward_dir_base`` + ``-helpful`` / ``-harmless``) and by the true
    ones (``args.true_reward_dir_base`` + the same suffixes).

    Args:
        args: Config object; reads ``eval_size``, ``n``, ``checkpoint_dir``,
            ``max_completion_length``, ``batch_size`` and
            ``true_reward_dir_base`` (and is passed whole to ``get_dataset``).
        proxy_reward_dir_base: Path prefix of the proxy reward models.

    Returns:
        Dict of per-prompt lists, keyed (in order) by ``helpful``, ``harmless``,
        ``true_helpful`` and ``true_harmless``, which hold the reward of each
        prompt's first sample. Each name is also available with a ``bon_``
        prefix and a ``batch_`` prefix. ``bon_*`` is the best-of-n reward: the
        sample is chosen by argmax of the proxy helpful (resp. harmless) score,
        and ``true_*`` is read at that same sample. ``batch_*`` holds all n
        rewards of the prompt.
    """
    _, eval_dataset = get_dataset(args, eval_size=args.eval_size)
    n = args.n
    # Repeat each prompt n times back-to-back so the n completions for one
    # prompt occupy a contiguous slice of every per-sample list below.
    prompts = [item["prompt"] for item in eval_dataset for _ in range(n)]

    # Sample completions from the checkpoint and decode them to text.
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_dir)
    completions = generate_completions(
        args.checkpoint_dir,
        prompts,
        max_completion_length=args.max_completion_length,
        stop_words=STOP_WORDS,
    )
    completion_texts = [
        tokenizer.decode(ids, skip_special_tokens=True) for ids in completions
    ]

    # Score with the proxy pair, then the true pair; only one pair is loaded
    # at a time to keep peak GPU memory down.
    logger.info("Evaluating rewards...")
    helpful_rewards, harmless_rewards = _score_with_reward_pair(
        prompts,
        completion_texts,
        helpful_dir=proxy_reward_dir_base + "-helpful",
        harmless_dir=proxy_reward_dir_base + "-harmless",
        batch_size=args.batch_size,
    )
    true_helpful_rewards, true_harmless_rewards = _score_with_reward_pair(
        prompts,
        completion_texts,
        helpful_dir=args.true_reward_dir_base + "-helpful",
        harmless_dir=args.true_reward_dir_base + "-harmless",
        batch_size=args.batch_size,
    )

    per_sample = {
        "helpful": helpful_rewards,
        "harmless": harmless_rewards,
        "true_helpful": true_helpful_rewards,
        "true_harmless": true_harmless_rewards,
    }
    # Which proxy score selects the best-of-n sample for each metric: the
    # true rewards are read at the index picked by the matching proxy reward.
    selector = {
        "helpful": "helpful",
        "harmless": "harmless",
        "true_helpful": "helpful",
        "true_harmless": "harmless",
    }
    first_sample = {name: [] for name in per_sample}
    best_of_n = {name: [] for name in per_sample}
    grouped = {name: [] for name in per_sample}

    for start in tqdm(range(0, len(prompts), n)):
        group = slice(start, start + n)
        logger.info(f"Prompt: {prompts[group]}")
        logger.info(f"Model completion: {completion_texts[group]}")
        logger.info(f"Helpful rewards: {helpful_rewards[group]}")
        logger.info(f"Harmless rewards: {harmless_rewards[group]}")

        group_rewards = {name: values[group] for name, values in per_sample.items()}
        best_idx = {
            proxy: np.argmax(group_rewards[proxy]) for proxy in ("helpful", "harmless")
        }
        for name, rewards in group_rewards.items():
            grouped[name].append(rewards)
            # NOTE(review): the first sample is the single-sample baseline;
            # averaging all n samples would be a lower-variance estimate.
            first_sample[name].append(rewards[0])
            best_of_n[name].append(rewards[best_idx[selector[name]]])

    logger.info("=== Reward evaluation results ===")
    logger.info(f"Mean helpfulness: {np.mean(first_sample['helpful'])}")
    logger.info(f"BoN mean helpfulness: {np.mean(best_of_n['helpful'])}")
    logger.info(f"Mean harmlessness: {np.mean(first_sample['harmless'])}")
    logger.info(f"BoN mean harmlessness: {np.mean(best_of_n['harmless'])}")
    logger.info(f"True mean helpfulness: {np.mean(first_sample['true_helpful'])}")
    logger.info(f"True BoN mean helpfulness: {np.mean(best_of_n['true_helpful'])}")
    logger.info(f"True mean harmlessness: {np.mean(first_sample['true_harmless'])}")
    logger.info(
        f"True BoN mean harmlessness: {np.mean(best_of_n['true_harmless'])}"
    )

    return {
        **first_sample,
        **{f"bon_{name}": values for name, values in best_of_n.items()},
        **{f"batch_{name}": values for name, values in grouped.items()},
    }


if __name__ == "__main__":
    # Parse CLI options into an EvalConfig and run the full evaluation.
    evaluate(EvalConfig().parse_args())
