import numpy as np 
from typing import Any, Dict, List, Optional, Union, Tuple
import torch 
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from .base import *

# Minimal Jinja chat template, installed by RewardCollator.__init__ when the
# tokenizer has no chat_template of its own. Each message renders as
# "<Role>: <content>" (role capitalized) followed by a blank line; when
# add_generation_prompt is true, the text ends with "Assistant:".
SIMPLE_CHAT_TEMPLATE = "{% for message in messages %}{{message['role'].capitalize() + ': ' + message['content'] + '\n\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"

def first_true_indices(bools: torch.Tensor, dtype=torch.long):
    """Return, per row, the position of the first ``True`` along the last axis.

    Rows containing no ``True`` yield the row length (one past the last
    valid index). Leading dimensions are kept; the last one is reduced.

    Args:
        bools: Mask tensor; expected to be ``torch.bool`` because it is
            negated with ``~`` (bitwise-not on integer tensors would differ).
        dtype: Integer dtype of the returned indices.
    """
    width = bools.size(-1)
    positions = torch.arange(width, dtype=dtype, device=bools.device)
    # Shift every False entry past all real positions by adding ``width``;
    # the minimum then lands on the first True, or on ``width`` if none.
    offsets = (~bools).type(dtype) * width
    first, _ = (offsets + positions).min(dim=-1)
    return first


class RewardModule(BaseRewardModule):
    """Reward-model wrapper whose weights are loaded by ``BaseRewardModule`` in bfloat16.

    The constructor forces the tokenizer's pad token to ``"[PAD]"`` and
    records its id on the model config. ``RewardCollator`` forces the same
    pad token, so both sides agree on which id marks padding.
    """

    def __init__(self, cfg, pad_token_id=None):
        """Load the reward model and configure its padding token.

        Args:
            cfg: Config object; ``cfg.reward.model`` names the tokenizer to
                load (presumably the same checkpoint the base class loads —
                confirm in ``BaseRewardModule``).
            pad_token_id: Accepted for interface compatibility but not used;
                the id is always taken from the tokenizer.
        """
        super().__init__(cfg, torch_dtype=torch.bfloat16)
        tokenizer = AutoTokenizer.from_pretrained(cfg.reward.model)
        # Always register "[PAD]" so the id matches the collator, which does
        # the same. After this call ``tokenizer.pad_token`` can no longer be
        # None, so no separate fallback is needed.
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        pad_id = tokenizer.pad_token_id

        # If "[PAD]" was not already in the vocabulary it was appended with a
        # fresh id. When that id lies beyond the embedding table, every padded
        # batch would fail with an index error, so grow the table to fit.
        # Pad positions are masked by the attention mask, so the new row's
        # (untrained) values should not affect scores.
        num_embeddings = getattr(self.model.get_input_embeddings(), "num_embeddings", None)
        if num_embeddings is not None and pad_id >= num_embeddings:
            self.model.resize_token_embeddings(len(tokenizer))

        self.model.config.pad_token_id = pad_id
        self.pad_token_id = pad_id
        del tokenizer

    # Disabled get_reward implementation kept for reference; presumably the
    # active one is inherited from BaseRewardModule — confirm there.
    # @torch.no_grad()
    # def get_reward(self, batch):
    #     output = self.model(**batch)
    #     logits = output.logits
    #     scores = logits.squeeze(-1)
    #     query_responses = batch['input_ids']
    #     attention_mask = query_responses != self.pad_token_id
    #     position_ids = attention_mask.cumsum(1) - attention_mask.long()
    #     lm_backbone = getattr(self.model, self.model.base_model_prefix)
    #     input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    #     output = lm_backbone(
    #         input_ids=input_ids,
    #         attention_mask=attention_mask,
    #         position_ids=position_ids,
    #         return_dict=True,
    #         output_hidden_states=True,
    #         use_cache=False,
    #     )
    #     logits = self.model.score(output.hidden_states[-1]).squeeze(-1)
    #     scores = logits[:, -1].cpu().tolist()
    #     return scores


class RewardCollator(BaseCollator):
    """Turn generated (prompt, response) records into tokenized reward-model inputs.

    Each record is rendered as a two-turn chat (user prompt, assistant
    response) with the tokenizer's chat template, then the whole batch is
    tokenized into PyTorch tensors with left padding.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Left padding puts each sequence's real tokens at the end of its row.
        self.tokenizer.padding_side = "left"
        # Always register "[PAD]" so the padding id matches RewardModule,
        # which does the same. pad_token is guaranteed non-None afterwards,
        # so no separate fallback is needed.
        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        if self.tokenizer.chat_template is None:
            self.tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

    def _build_query_str(self, question):
        """Return the user-turn text: optional task description, then the question.

        ``self.task_desc``, ``self.sep`` and ``self.format_q`` are presumably
        provided by ``BaseCollator`` — confirm there.
        """
        query = self.task_desc
        if self.task_desc != "":
            query += self.sep * 2
        query += self.format_q(question, n=0)
        return query

    def _format_query_response(self, query, response):
        """Build the chat message list for one prompt/response pair."""
        return [
            {'role': 'user', 'content': self._build_query_str(query)},
            {'role': 'assistant', 'content': self._build_response_str(response)},
        ]

    def format_query_response(self, outputs):
        """Render every record's prompt/response as chat-templated text.

        Args:
            outputs: Records with ``'prompt'`` and ``'response'`` keys.

        Returns:
            One untokenized string per record (no generation prompt appended).
        """
        qr = [
            self._format_query_response(output['prompt'], output['response'])
            for output in outputs
        ]
        return self.tokenizer.apply_chat_template(qr, tokenize=False, add_generation_prompt=False)

    def __call__(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Tokenize a batch of records into padded PyTorch tensors.

        NOTE(review): if the chat template already emits BOS/CLS tokens, the
        tokenizer's default ``add_special_tokens=True`` may insert them a
        second time — confirm for the reward model in use.
        """
        qr = self.format_query_response(outputs)
        return self.tokenizer(qr, return_tensors="pt", padding=True)