import numpy as np 
import datasets 
import torch 
from typing import Optional, Callable, Dict, Tuple


def collate_fn_tokenizer(cfg, tokenizer, batch):
    """Pad a batch of pre-tokenized examples to fixed lengths.

    The ``prompt_ids`` of every example are padded on the left to
    ``cfg.max_prompt_length`` and the ``token_ids`` are padded on the right
    to ``cfg.max_response_length``, both through ``tokenizer.pad``.

    The tokenizer's ``padding_side`` is switched temporarily while padding
    and is restored to its incoming value before returning (also when
    ``tokenizer.pad`` raises), so the caller's tokenizer configuration is
    not silently changed by collating a batch.

    Args:
        cfg: Object exposing ``max_prompt_length`` and
            ``max_response_length``.
        tokenizer: Object with a writable ``padding_side`` attribute and a
            ``pad(encoded, padding=..., max_length=..., return_tensors=...)``
            method whose result is indexable by ``'input_ids'``.
        batch: Sequence of mappings, each holding ``'prompt_ids'`` and
            ``'token_ids'``. It is iterated twice, so it must not be a
            one-shot iterator.

    Returns:
        A dict with ``'input_ids'`` (padded ``token_ids``) and ``'queries'``
        (padded ``prompt_ids``).
    """
    # NOTE(review): nothing here truncates; presumably inputs longer than the
    # configured maxima are filtered out upstream -- confirm against caller.
    original_padding_side = tokenizer.padding_side
    try:
        # Prompts are left-padded.
        tokenizer.padding_side = 'left'
        query_ids = tokenizer.pad(
            {"input_ids": [data['prompt_ids'] for data in batch]},
            padding='max_length',
            max_length=cfg.max_prompt_length,
            return_tensors="pt",
        )['input_ids']

        # The ``token_ids`` sequences are right-padded.
        tokenizer.padding_side = 'right'
        input_ids = tokenizer.pad(
            {"input_ids": [data['token_ids'] for data in batch]},
            padding='max_length',
            max_length=cfg.max_response_length,
            return_tensors="pt",
        )['input_ids']
    finally:
        # Do not leak the temporary padding side into the shared tokenizer.
        tokenizer.padding_side = original_padding_side

    return {
        'input_ids': input_ids,
        'queries': query_ids,
    }

def first_true_indices(bools: torch.Tensor, dtype=torch.long):
    """Return, along the last dimension, the index of the first ``True``.

    Every ``False`` position is pushed past the end of the row by adding the
    row length to its index, so the minimum over the row is the index of the
    first ``True`` entry. A row without any ``True`` yields the row length.

    Args:
        bools: Boolean tensor; the search runs over its last dimension.
        dtype: Dtype used for the index arithmetic and for the result.

    Returns:
        Tensor of ``dtype`` with the last dimension of ``bools`` reduced away.
    """
    width = bools.size(-1)
    positions = torch.arange(width, dtype=dtype, device=bools.device)
    # ``width`` for False entries, 0 for True entries.
    false_offset = (~bools).to(dtype) * width
    first_hit, _ = (false_offset + positions).min(dim=-1)
    return first_hit

@torch.no_grad()
def get_reward(
    model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Score query+response sequences with a reward model, without gradients.

    The backbone is looked up on ``model`` via ``model.base_model_prefix`` and
    run with an attention mask that hides every ``pad_token_id`` position;
    ``model.score`` is then applied to the last entry of the backbone's
    ``hidden_states``.

    Args:
        model: Module exposing ``base_model_prefix``, the backbone attribute
            that prefix names, and a ``score`` callable.
        query_responses: Token ids; indexed as ``[:, context_length:]`` and
            cumulatively summed over dim 1, so at least 2-D -- assumed to be
            (batch, sequence), TODO confirm.
        pad_token_id: Token id treated as padding.
        context_length: Column at which the response part starts.

    Returns:
        A 3-tuple of:
        * the per-position output of ``model.score``,
        * that output gathered at ``sequence_lengths`` for each row, with a
          trailing size-1 dimension squeezed away,
        * ``sequence_lengths``: per row, the index just before the first pad
          token found at or after ``context_length`` (the last column when
          the response contains no pad token).
    """
    not_pad = query_responses != pad_token_id
    # Exclusive cumulative sum: padded positions do not advance the position.
    positions = not_pad.cumsum(1) - not_pad.long()
    backbone = getattr(model, model.base_model_prefix)
    # Replace pad ids with 0 before the forward pass; they are masked anyway.
    safe_ids = query_responses.masked_fill(~not_pad, 0)
    backbone_out = backbone(
        input_ids=safe_ids,
        attention_mask=not_pad,
        position_ids=positions,
        return_dict=True,
        output_hidden_states=True,
        use_cache=False,  # otherwise mistral-based RM would error out
    )
    per_token_scores = model.score(backbone_out.hidden_states[-1])

    response_is_pad = query_responses[:, context_length:] == pad_token_id
    sequence_lengths = first_true_indices(response_is_pad) - 1 + context_length

    row_index = torch.arange(per_token_scores.size(0), device=per_token_scores.device)
    final_scores = per_token_scores[row_index, sequence_lengths].squeeze(-1)
    return per_token_scores, final_scores, sequence_lengths

@torch.no_grad()
def forward(
    model: torch.nn.Module,
    query_responses: torch.Tensor,
    pad_token_id: int,
    output_hidden_states: Optional[bool] = False, 
) -> torch.nn.Module:
    """Call ``model`` on ``query_responses`` with padding masked out.

    Runs under ``torch.no_grad()``, so no gradients are recorded.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask``,
            ``position_ids``, ``return_dict`` and ``output_hidden_states``
            keyword arguments.
        query_responses: Token ids; cumulatively summed over dim 1, so at
            least 2-D -- assumed to be (batch, sequence), TODO confirm.
        pad_token_id: Token id treated as padding.
        output_hidden_states: Forwarded unchanged to ``model``.

    Returns:
        Whatever ``model(...)`` returns.
        NOTE(review): the ``torch.nn.Module`` return annotation looks
        inaccurate for a model output -- confirm and adjust.
    """
    not_pad = query_responses != pad_token_id
    model_kwargs = dict(
        # Pad ids are replaced with 0; those positions are masked anyway.
        input_ids=query_responses.masked_fill(~not_pad, 0),
        attention_mask=not_pad,
        # Exclusive cumulative sum: padding does not advance the position.
        position_ids=not_pad.cumsum(1) - not_pad.long(),
        return_dict=True,
        output_hidden_states=output_hidden_states,
    )
    return model(**model_kwargs)