from abc import ABC, abstractclassmethod, abstractmethod
from dataclasses import dataclass, asdict
from email import policy
from typing import Text, List, Dict, Optional

import torch
from accelerate import Accelerator
from transformers import PreTrainedTokenizerBase, PreTrainedModel

from inference_time_alignment.utils import (
    SFTDataMapFunc, 
    SFTDataCollatorWithPadding,
    prepare_input,
    get_batch_logps,
)


# Default prompt template: the raw prompt is used verbatim. Templates are
# rendered with ``str.format(raw_prompt=...)``, so they must contain the
# ``{raw_prompt}`` placeholder to include the prompt text.
DEFAULT_PROMPT_TEMPLATE = "{raw_prompt}"


@dataclass
class ScorerInput:
    """Batch of candidate responses handed to a scorer.

    Both fields are parallel lists: entry ``i`` of ``eos`` belongs to entry
    ``i`` of ``response``.
    """

    # Candidate response texts, one per batch element.
    response: List[str]
    # Per-response boolean flag.
    # NOTE(review): presumably marks whether the response is finished (ends
    # with EOS) - confirm against the tokenization helper that consumes it.
    eos: List[bool]


@dataclass
class BaseScorer(ABC):
    """Abstract interface for scorers.

    Concrete scorers implement ``__call__`` and return a ``torch.Tensor`` of
    scores for the given ``ScorerInput``. The class cannot be instantiated
    directly because ``__call__`` is abstract.
    """

    # ``abstractmethod`` (not the deprecated ``abstractclassmethod``): the
    # scorer is invoked on instances, so ``__call__`` must bind ``self``
    # rather than the class.
    @abstractmethod
    def __call__(self, input: "ScorerInput") -> torch.Tensor:
        """Score ``input``; must be overridden by subclasses.

        Raises:
            NotImplementedError: if a subclass delegates to this base
                implementation.
        """
        raise NotImplementedError


@dataclass
class ImplicitRewardScorer(BaseScorer):
    """Score responses by the scaled log-probability gap between two models.

    For each response the score is ``beta * (logp_model - logp_ref_model)``,
    or ``beta * logp_model`` when ``reference_free`` is set. The
    log-probabilities are whatever ``get_batch_logps`` returns for each
    model's logits on the tokenized prompt + response batch.
    """

    # Model whose log-probabilities form the positive term of the score.
    model: PreTrainedModel
    # Model whose log-probabilities are subtracted (unused if reference_free).
    ref_model: PreTrainedModel
    # Tokenizer passed to both the tokenization helper and the collator.
    tokenizer: PreTrainedTokenizerBase
    # Forwarded to SFTDataMapFunc.
    add_special_tokens: Optional[bool] = False
    # Templates rendered with ``.format(raw_prompt=self.raw_prompt)`` for the
    # respective model.
    model_prompt_template: Optional[str] = DEFAULT_PROMPT_TEMPLATE
    ref_model_prompt_template: Optional[str] = DEFAULT_PROMPT_TEMPLATE
    # Prompt text substituted into the templates; see ``set_raw_prompt``.
    raw_prompt: Optional[str] = ''
    # Multiplicative scale applied to the returned score.
    beta: Optional[float] = 1.0
    # Forwarded to get_batch_logps.
    average_log_prob: Optional[bool] = False
    # When True, skip the reference model and return ``beta * logp_model``.
    reference_free: Optional[bool] = False
    # Forwarded to get_batch_logps.
    # NOTE(review): presumably the label value ignored when summing
    # log-probs - confirm in get_batch_logps.
    label_pad_token_id: Optional[int] = -100

    def set_raw_prompt(self, raw_prompt):
        """Store ``raw_prompt`` for subsequent calls and return ``self``."""
        self.raw_prompt = raw_prompt
        return self

    @torch.no_grad()
    def __call__(self, input: ScorerInput) -> torch.Tensor:
        """Return the scaled score for every response in ``input``.

        Args:
            input: Responses to score (a ``ScorerInput``; ``forward`` also
                accepts an equivalent dict).

        Returns:
            ``beta * logp_model`` if ``reference_free`` is set, otherwise
            ``beta * (logp_model - logp_ref_model)``.
        """
        policy_all_logps = self.forward(
            self.model,
            self.model_prompt_template,
            input,
        )
        if self.reference_free:
            return self.beta * policy_all_logps
        ref_all_logps = self.forward(
            self.ref_model,
            self.ref_model_prompt_template,
            input,
        )
        return self.beta * (policy_all_logps - ref_all_logps)

    @torch.no_grad()
    def forward(
        self,
        model: PreTrainedModel,
        prompt_template: Text,
        input: ScorerInput | Dict
    ) -> torch.Tensor:
        """Compute ``model``'s log-probabilities for the responses in ``input``.

        Args:
            model: Model to run on the tokenized batch.
            prompt_template: Template rendered with the stored raw prompt;
                the same rendered prompt is paired with every response.
            input: A ``ScorerInput`` or a dict with at least a ``"response"``
                list. The caller's object is never modified.

        Returns:
            The tensor produced by ``get_batch_logps`` for the batch.
        """
        # Always work on a copy: ``asdict`` already copies a ScorerInput, and
        # a caller-supplied dict must not gain a "prompt" key as a side effect.
        fields = asdict(input) if isinstance(input, ScorerInput) else dict(input)
        num_responses = len(fields["response"])
        fields["prompt"] = [prompt_template.format(raw_prompt=self.raw_prompt)] * num_responses

        tokens = SFTDataMapFunc(tokenizer=self.tokenizer, add_special_tokens=self.add_special_tokens)(fields)
        # The map function returns column-wise data; the collator expects one
        # dict per example, so transpose before padding.
        batch = SFTDataCollatorWithPadding(tokenizer=self.tokenizer)(
            [{k: v[i] for k, v in tokens.items()} for i in range(num_responses)]
        )
        batch = prepare_input(batch)

        # Upcast logits to float32 before computing log-probabilities.
        all_logits = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        ).logits.to(torch.float32)
        all_logps = get_batch_logps(
            all_logits,
            batch["labels"],
            average_log_prob=self.average_log_prob,
            label_pad_token_id=self.label_pad_token_id,
        )
        return all_logps


if __name__ == "__main__":
    # Smoke test: score two candidate continuations of one prompt using a
    # policy checkpoint and its reference checkpoint.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from accelerate import Accelerator

    policy_path = "/mnt/petrelfs/share_data/llm-safety/models/gpt2-imdb-dpo"
    ref_path = "/mnt/petrelfs/share_data/llm-safety/models/gpt2-imdb"

    # Resolve the local process index once and place both models on it.
    device_map = {"": Accelerator().local_process_index}

    model = AutoModelForCausalLM.from_pretrained(
        policy_path,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
    )

    ref_model = AutoModelForCausalLM.from_pretrained(
        ref_path,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
    )

    tokenizer = AutoTokenizer.from_pretrained(policy_path)
    # Reuse EOS as the padding token and pad on the right.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    implicit_reward = ImplicitRewardScorer(
        model=model,
        ref_model=ref_model,
        tokenizer=tokenizer,
    )

    implicit_reward.set_raw_prompt("I think this movie is ")

    # Pass the declared input type rather than a raw dict.
    result = implicit_reward(
        ScorerInput(response=[" interesting", " boring"], eos=[True, True])
    )

    print(result)
