
import torch
from transformers.generation.logits_process import LogitsProcessor

class EllipticalLogitsProcessor(LogitsProcessor):
    """Add a per-candidate-token bonus to the next-token scores.

    For every token whose score is finite (greater than ``-inf``), the prefix
    ``input_ids`` is extended with that token and passed through
    ``model.model``. The hidden state at the last position, ``h``, yields the
    bonus ``sqrt(max(h^T M h, 0))`` where ``M`` is ``cov_inv`` (optionally
    corrected with ``hidden_mean``, see ``_mean_adjusted_cov_inv``). The bonus
    is optionally standardized across the candidates of the step, scaled by
    ``beta / temp`` and added to the score of the corresponding token.

    Only a single sequence per call is supported (``input_ids`` and ``scores``
    must both have a leading dimension of 1).

    Args:
        cov_inv: Square matrix used in the quadratic form. It is cast to
            float64 on the device of the hidden states before use.
        model: Object whose ``.model`` attribute is callable with
            ``input_ids=`` and ``attention_mask=`` keyword arguments and
            returns an object exposing ``last_hidden_state``.
        beta: Multiplier applied to the (normalized) bonus.
        temp: Divisor applied to the (normalized) bonus.
        normalize_bonuses_per_step: If true, subtract the mean and divide by
            the standard deviation of the bonuses of the current step.
        normalize_bonuses_per_sequence: Stored on the instance; not read by
            any code in this class.
        center_hidden_states_per_step: If true, subtract ``hidden_mean`` from
            every candidate hidden state. Requires ``hidden_mean``.
        batch_size: Maximum number of candidate sequences per forward pass.
        hidden_mean: Mean vector used for centering and for the correction of
            ``cov_inv``. ``None`` disables the correction.
        hidden_mean_counter: Count used in the correction of ``cov_inv``;
            presumably the number of hidden states averaged into
            ``hidden_mean`` -- TODO confirm against the code that builds it.
            A value ``<= 0`` disables the correction.
        sparse_matrix: Optional matrix that the hidden states are
            right-multiplied by before the bonus is computed.

    Attributes:
        bonuses, normalized_bonuses, scaled_normalized_bonuses, logits:
            Lists with one ``{token_id: value}`` dict appended per call.
            They are never cleared by this class, so they grow with every
            generation step.
    """

    def __init__(
        self,
        cov_inv: torch.Tensor,
        model: torch.nn.Module,
        beta: float,
        temp: float,
        normalize_bonuses_per_step: bool = True,
        normalize_bonuses_per_sequence: bool = False,
        center_hidden_states_per_step: bool = True,
        batch_size: int = 64,
        hidden_mean: torch.Tensor = None,
        hidden_mean_counter: int = 0,
        sparse_matrix: torch.Tensor = None,
    ):
        self.cov_inv = cov_inv
        self.model = model
        self.beta = beta
        self.temp = temp
        self.normalize_bonuses_per_step = normalize_bonuses_per_step
        self.normalize_bonuses_per_sequence = normalize_bonuses_per_sequence
        self.center_hidden_states_per_step = center_hidden_states_per_step
        self.batch_size = batch_size
        self.hidden_mean = hidden_mean
        self.hidden_mean_counter = hidden_mean_counter
        self.sparse_matrix = sparse_matrix

        # Per-step history; one dict per __call__ (see class docstring).
        self.scaled_normalized_bonuses = []
        self.normalized_bonuses = []
        self.bonuses = []
        self.logits = []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return a copy of ``scores`` with the scaled bonuses added.

        Args:
            input_ids: Token ids of the current prefix, shape ``(1, seq_len)``.
            scores: Next-token scores, shape ``(1, vocab_size)``. Entries equal
                to ``-inf`` are treated as masked and left untouched.

        Returns:
            A new tensor shaped like ``scores``; ``scores`` is not modified.

        Raises:
            NotImplementedError: If more than one sequence is passed.
            ValueError: If centering is enabled but ``hidden_mean`` is None.
        """
        if input_ids.shape[0] != 1 or scores.shape[0] != 1:
            # NotImplementedError is a RuntimeError subclass, like the shape
            # error torch.cat used to raise for this input.
            raise NotImplementedError(
                "EllipticalLogitsProcessor only supports a single sequence per call "
                f"(got input_ids batch {input_ids.shape[0]}, scores batch {scores.shape[0]})."
            )
        if self.center_hidden_states_per_step and self.hidden_mean is None:
            raise ValueError(
                "center_hidden_states_per_step=True requires hidden_mean to be set."
            )

        # Candidate tokens are those not masked out by earlier processors.
        valid_token_ids = torch.nonzero(scores[0] > float('-inf'), as_tuple=True)[0]
        if valid_token_ids.numel() == 0:
            # Nothing to score; still record the step so that the history
            # lists stay aligned with the number of calls.
            self._record_step([], [], [], [], [])
            return scores.clone()

        hidden = self._candidate_hidden_states(input_ids, valid_token_ids)

        mean = None
        if self.hidden_mean is not None:
            mean = self.hidden_mean.to(device=hidden.device, dtype=hidden.dtype).reshape(-1)
        if self.center_hidden_states_per_step:
            hidden = hidden - mean

        cov_inv = self._mean_adjusted_cov_inv(mean, hidden)

        # Row-wise quadratic form h^T M h without materializing an
        # (N, D, D) batch of matrices.
        quad_forms = ((hidden @ cov_inv) * hidden).sum(dim=1)
        if torch.any(quad_forms < 0):
            print('Bonuses are negative!')
            print(quad_forms)
        bonuses = torch.sqrt(torch.clamp(quad_forms, min=0.0))

        if self.normalize_bonuses_per_step:
            # std of a single element is undefined (NaN), so fall back to 1.
            std = torch.std(bonuses) if bonuses.numel() > 1 else 1.0
            normalized_bonuses = (bonuses - torch.mean(bonuses)) / (std + 1e-8)
        else:
            normalized_bonuses = bonuses

        # Bonuses are computed in float64; cast down before adding to scores.
        scaled_normalized_bonuses = (self.beta * normalized_bonuses / self.temp).to(
            device=scores.device, dtype=torch.float32
        )

        processed_scores = scores.clone()
        processed_scores[0, valid_token_ids] += scaled_normalized_bonuses

        self._record_step(
            valid_token_ids.cpu().tolist(),
            scores[0, valid_token_ids].cpu().tolist(),
            bonuses.cpu().tolist(),
            normalized_bonuses.cpu().tolist(),
            scaled_normalized_bonuses.cpu().tolist(),
        )
        return processed_scores

    def _candidate_hidden_states(self, input_ids, valid_token_ids):
        """Return float64 last-position hidden states, one row per candidate."""
        candidate_ids = valid_token_ids.to(input_ids.device).unsqueeze(1)
        prefix = input_ids.expand(candidate_ids.shape[0], -1)
        extended = torch.cat([prefix, candidate_ids], dim=1)

        chunks = []
        with torch.no_grad():
            # Forward in chunks of at most batch_size sequences to bound memory.
            for chunk_ids in extended.split(self.batch_size):
                output = self.model.model(
                    input_ids=chunk_ids,
                    attention_mask=torch.ones_like(chunk_ids),
                )
                chunks.append(output.last_hidden_state[:, -1, :].float())
        hidden = torch.cat(chunks, dim=0)

        if self.sparse_matrix is not None:
            hidden = hidden @ self.sparse_matrix

        # Promote only after the optional projection, as the projection runs
        # in float32.
        return hidden.to(torch.float64)

    def _mean_adjusted_cov_inv(self, mean, like):
        """Return ``cov_inv`` with the rank-one mean correction applied.

        With ``C = cov_inv``, ``m = mean`` and ``n = hidden_mean_counter`` the
        result is ``C - (C m)(m^T C) / (m^T C m - 1/n)``. This has the form of
        a Sherman-Morrison update for the inverse of ``C^{-1} - n m m^T``.
        NOTE(review): that reading assumes ``cov_inv`` is the inverse of an
        uncentered second-moment matrix -- confirm against the producer. The
        correction is applied whenever a mean and a positive counter are
        available, independent of ``center_hidden_states_per_step``.

        When ``mean`` is None or the counter is not positive, ``cov_inv`` is
        returned uncorrected (the correction term vanishes as ``n -> 0``).
        """
        cov_inv = self.cov_inv.to(device=like.device, dtype=like.dtype)
        if mean is None or self.hidden_mean_counter <= 0:
            return cov_inv
        cov_inv_mean = cov_inv @ mean
        mean_cov_inv = mean @ cov_inv
        denominator = mean @ cov_inv_mean - 1.0 / self.hidden_mean_counter
        return cov_inv - torch.outer(cov_inv_mean, mean_cov_inv) / denominator

    def _record_step(self, token_ids, logits, bonuses, normalized, scaled):
        """Append one ``{token_id: value}`` dict per tracked quantity."""
        self.bonuses.append(dict(zip(token_ids, bonuses)))
        self.normalized_bonuses.append(dict(zip(token_ids, normalized)))
        self.scaled_normalized_bonuses.append(dict(zip(token_ids, scaled)))
        self.logits.append(dict(zip(token_ids, logits)))