import copy
import functools
from typing import Callable

import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from meta_alignment.config import TrainingConfig
from meta_alignment.reward_transformation import (
    get_BoN_linearized_reward_func,
    get_BoN_reward_func,
    get_soft_BoN_linearized_reward_func,
    get_soft_BoN_reward_func,
)


def get_reward_length_func(target_length: int, max_length: int) -> Callable:
    """Create a reward that peaks when a completion's length equals ``target_length``.

    A completion of length ``L`` earns ``1 - ((L - target_length) / max_length) ** 2``.
    That is exactly 1.0 at the target and falls off quadratically as the
    length moves away from it; ``max_length`` sets the scale of the fall-off.

    Args:
        target_length: Length that earns the maximum reward of 1.0.
        max_length: Normalizer applied to the deviation from ``target_length``.

    Returns:
        A reward callable ``(completion_ids, **kwargs) -> list[float]`` named
        ``reward_length_target_{target_length}``. Only ``len()`` of each entry of
        ``completion_ids`` (presumably token-id sequences) is used; any extra
        keyword arguments are ignored.
    """

    def _score(length):
        normalized_gap = (length - target_length) / max_length
        return 1.0 - normalized_gap**2

    def reward_length_func(completion_ids, **kwargs):
        return [_score(len(ids)) for ids in completion_ids]

    reward_length_func.__name__ = f"reward_length_target_{target_length}"
    return reward_length_func


def cache_reward_func(reward_func):
    @functools.cache
    def cached_func(prompts, completions):
        return reward_func(list(prompts), list(completions))

    def wrapper(prompts, completions, **kwargs):
        return cached_func(tuple(prompts), tuple(completions))

    return wrapper


def get_reward_func_from_classifier(
    reward_model: PreTrainedModel,
    reward_tokenizer: PreTrainedTokenizerBase,
    model_name: str,
    add_eos: bool = False,
) -> Callable:
    """Wrap a sequence-classification model as a cached reward function.

    Each (prompt, completion) pair is passed to ``reward_tokenizer`` as a text
    pair (padded, and truncated to ``reward_tokenizer.model_max_length``). The
    batch is moved to ``reward_model.device`` and scored without gradients, and
    the logits are returned as a Python list after ``squeeze(-1)``. That assumes
    a single-logit head; with several labels each score would be a list.

    Args:
        reward_model: Classifier whose ``.logits`` are used as rewards.
        reward_tokenizer: Tokenizer used to encode prompt/completion pairs.
        model_name: Suffix for the returned function's ``__name__``.
        add_eos: If true, append ``reward_tokenizer.eos_token`` to every
            completion before tokenizing.

    Returns:
        A ``cache_reward_func``-wrapped callable
        ``(prompts, completions, **kwargs) -> list`` named ``reward_{model_name}``.
    """

    @cache_reward_func
    def reward_from_classifier(prompts, completions, **kwargs):
        texts = (
            [completion + reward_tokenizer.eos_token for completion in completions]
            if add_eos
            else completions
        )
        encoded = reward_tokenizer(
            prompts,
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=reward_tokenizer.model_max_length,
        )
        target_device = reward_model.device
        batch = {name: tensor.to(target_device) for name, tensor in encoded.items()}
        with torch.no_grad():
            logits = reward_model(**batch).logits
            return logits.squeeze(-1).cpu().tolist()

    reward_from_classifier.__name__ = f"reward_{model_name}"
    return reward_from_classifier


def get_reward_funcs(args: TrainingConfig) -> list[Callable]:
    match args.task:
        case "length":
            reward_length_50 = get_reward_length_func(
                target_length=50, max_length=args.max_completion_length
            )
            reward_length_200 = get_reward_length_func(
                target_length=150, max_length=args.max_completion_length
            )
            if args.bon_type == "bon":
                reward_funcs = [
                    reward_length_50,
                    get_BoN_reward_func(reward_length_50, n=args.n),
                    get_BoN_linearized_reward_func(reward_length_50, n=args.n),
                    reward_length_200,
                    get_BoN_reward_func(reward_length_200, n=args.n),
                    get_BoN_linearized_reward_func(reward_length_200, n=args.n),
                ]
            elif args.bon_type == "softbon":
                if args.tau is None:
                    raise ValueError("tau must be specified for softbon")
                reward_funcs = [
                    reward_length_50,
                    get_soft_BoN_reward_func(reward_length_50, tau=args.tau),
                    get_soft_BoN_linearized_reward_func(reward_length_50, tau=args.tau),
                    reward_length_200,
                    get_soft_BoN_reward_func(reward_length_200, tau=args.tau),
                    get_soft_BoN_linearized_reward_func(
                        reward_length_200, tau=args.tau
                    ),
                ]
        case "hh":
            reward_tokenizer = AutoTokenizer.from_pretrained(
                "results/hh/models/qwen4b-helpful",
            )
            reward_model_helpful = AutoModelForSequenceClassification.from_pretrained(
                "results/hh/models/qwen4b-helpful",
                device_map="auto",
                dtype=torch.bfloat16,
            )
            reward_model_harmless = AutoModelForSequenceClassification.from_pretrained(
                "results/hh/models/qwen4b-harmless",
                device_map="auto",
                dtype=torch.bfloat16,
            )

            reward_helpful = get_reward_func_from_classifier(
                reward_model=reward_model_helpful,
                reward_tokenizer=reward_tokenizer,
                model_name="helpful",
                add_eos=True,
            )
            reward_harmless = get_reward_func_from_classifier(
                reward_model=reward_model_harmless,
                reward_tokenizer=reward_tokenizer,
                model_name="harmless",
                add_eos=True,
            )
            reward_funcs = [
                reward_helpful,
                get_BoN_reward_func(reward_helpful, n=args.n),
                get_BoN_linearized_reward_func(reward_helpful, n=args.n),
                reward_harmless,
                get_BoN_reward_func(reward_harmless, n=args.n),
                get_BoN_linearized_reward_func(reward_harmless, n=args.n),
            ]
        case _:
            raise ValueError(f"Unknown task: {args.task}")
    return reward_funcs


def get_reward_weights(args: TrainingConfig) -> list[float]:
    """Return six per-reward weights built from ``args.weights[0]`` and ``args.weights[1]``.

    The weights form two triples, one per base weight. This is laid out to match
    the (raw, BoN, linearized BoN) order that ``get_reward_funcs`` produces.
    Exactly one slot of each triple is non-zero:
    - when ``args.meta`` is truthy, the weight goes on the third slot
      (linearized BoN);
    - otherwise it goes on the first slot (raw reward).
    """
    use_third_slot = bool(args.meta)
    first_weight, second_weight = args.weights[0], args.weights[1]

    def _triple(weight):
        return [0.0, 0.0, weight] if use_third_slot else [weight, 0.0, 0.0]

    return _triple(first_weight) + _triple(second_weight)
