from typing import List, Optional
import numpy as np
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F

# Turn autograd off for the whole process at import time: the code visible in
# this module only runs forward passes, so no gradient bookkeeping is needed.
torch.set_grad_enabled(False)

class Mosaic(object):
    """Score text with a mixture of causal language models.

    Several causal LMs are loaded, their next-token distributions are mixed
    per position with weights produced by ``blahut_arimoto_torch``, and the
    resulting mixture is used to compute perplexity / cross-entropy scores.
    """

    def __init__(self,
                 model_name_or_paths: List[str],
                 use_bfloat16: bool = True,
                 max_token_observed: int = 512,
                 unigram: Optional[str] = None,
                 custom_config: Optional[List[bool]] = None
                 ) -> None:
        """Load every model and the shared tokenizer.

        Parameters
        ----------
        model_name_or_paths : list of str
            Names or paths handed to ``AutoModelForCausalLM.from_pretrained``.
            The tokenizer is loaded from the LAST entry.
        use_bfloat16 : bool
            Load weights as ``torch.bfloat16`` when True, ``torch.float32`` otherwise.
        max_token_observed : int
            Truncation length used by ``_tokenize``.
        unigram : str, optional
            Stored as ``self.unigram_path``; not used by the code visible here.
        custom_config : list of bool, optional
            Stored as ``self.custom_config`` (defaults to one ``False`` per
            model); not used by the code visible here.
        """
        dtype = torch.bfloat16 if use_bfloat16 else torch.float32

        self.models = []
        for model_name_or_path in model_name_or_paths:
            # NOTE(review): trust_remote_code=True lets the checkpoint ship its
            # own modeling code - only point this at sources you trust.
            model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                         device_map="auto",
                                                         trust_remote_code=True,
                                                         torch_dtype=dtype
                                                         )
            model.eval()  # Set the model to evaluation mode
            self.models.append(model)
            print(f"Loaded model: {model_name_or_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_paths[-1])
        if not self.tokenizer.pad_token:
            # Padding is required by _tokenize; fall back to the EOS token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.max_token_observed = max_token_observed

        self.nb_models = len(self.models)
        self.unigram_path = unigram

        if custom_config is None:
            custom_config = [False] * self.nb_models
        self.custom_config = custom_config

    def _tokenize(self, batch: List[str]) -> transformers.BatchEncoding:
        """Tokenize ``batch`` into padded, truncated PyTorch tensors."""
        return self.tokenizer(
            batch,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_token_observed,
            return_token_type_ids=False)

    @torch.inference_mode()
    def _get_logits(self, encodings: transformers.BatchEncoding) -> List[torch.Tensor]:
        """Run every model on ``encodings`` and return one logits tensor per model.

        Each logits tensor stays on the device of the model that produced it.
        """
        logits_list = []
        for model in self.models:
            # Device of the model's first parameter; inputs are sent there.
            device = next(model.parameters()).device
            model_encodings = encodings.to(device)
            logits_list.append(model(**model_encodings).logits)
            if device.type == "cuda":
                torch.cuda.synchronize(device)
        return logits_list

    def get_softmax_probabilities(self, input_text):
        """Return ``(encodings, logits_list, probabilities_list)`` for ``input_text``."""
        encodings = self._tokenize(input_text)
        logits_list = self._get_logits(encodings)
        probabilities_list = softmax_probabilities_all_models(logits_list)
        return encodings, logits_list, probabilities_list

    def compute_arimoto_torch(self, input_text):
        """Mix the models' next-token distributions with Blahut-Arimoto weights.

        One channel of shape (nb_models, voc_size) is solved for every
        (sequence, position) pair, so each sequence of a batch gets its own
        weights (previously the weights of the first sequence were reused
        for the whole batch).

        Returns
        -------
        encodings
            Tokenizer output.
        weighted_sum_tensor : tensor
            Mixture distribution, shape (batch, seq_len, voc_size).
        tensors_list : list of tensors
            Per-model probabilities, all moved to the first model's device.
        p : tensor
            Mixture weights. Shape (seq_len, nb_models) when the batch holds a
            single sequence, (batch, seq_len, nb_models) otherwise.
        logits_list : list of tensors
            Raw per-model logits.
        """
        encodings, logits_list, tensors_list = self.get_softmax_probabilities(input_text)

        # Models may sit on different devices; gather everything on the first one.
        device = tensors_list[0].device
        tensors_list = [tensor.to(device) for tensor in tensors_list]

        # (batch, seq_len, nb_models, voc_size)
        stacked = torch.stack(tensors_list, dim=2)
        batch_size, seq_len, nb_models, voc_size = stacked.shape

        # Flatten (batch, seq_len) so every token position is an independent channel.
        _, p = blahut_arimoto_torch(
            stacked.reshape(batch_size * seq_len, nb_models, voc_size)
        )
        del stacked  # large temporary, no longer needed
        p = p.reshape(batch_size, seq_len, nb_models)

        # Accumulate model by model to avoid another (batch, seq, models, voc) temporary.
        weighted_sum_tensor = torch.zeros_like(tensors_list[0])
        for i, tensor in enumerate(tensors_list):
            weighted_sum_tensor += p[..., i:i + 1] * tensor

        if batch_size == 1:
            # Keep the historical (seq_len, nb_models) layout for single inputs.
            p = p[0]

        return encodings, weighted_sum_tensor, tensors_list, p, logits_list

    def compute_scores(self, input_text):
        """Return ``(log_ppl, x_ppl_list, arimoto_weights, nll, ppl_list)``.

        ``log_ppl`` and ``nll`` come from ``perplexity`` on the mixture,
        ``x_ppl_list`` from ``cross_entropy`` between the mixture and each
        model, and ``ppl_list`` from ``perplexity_all_models`` on raw logits.
        """
        encodings, weighted_sum_tensor, probabilities_list, arimoto_weights, logits_list = \
            self.compute_arimoto_torch(input_text)
        log_ppl, _, nll = perplexity(encodings, weighted_sum_tensor)
        ppl_list = perplexity_all_models(encodings, logits_list)
        x_ppl_list = cross_entropy(weighted_sum_tensor, probabilities_list)
        return log_ppl, x_ppl_list, arimoto_weights, nll, ppl_list

    def compute_avg_score(self, input_text):
        """Return the per-sequence average of ``nll - mean(x_ppl)`` over tokens."""
        encodings, weighted_sum_tensor, probabilities_list, _, _ = \
            self.compute_arimoto_torch(input_text)
        _, _, nll = perplexity(encodings, weighted_sum_tensor)
        x_ppl_list = cross_entropy(weighted_sum_tensor, probabilities_list)

        # Compute the adjusted NLL based on the average(x_ppl) per token
        adjusted_nll_avg, _ = compute_adjusted_nll_avg(nll, x_ppl_list)

        return adjusted_nll_avg

def compute_adjusted_nll_avg(nll, x_ppl_list):
    """
    Compute (nll - mean(x_ppl)) for each token and then average over all tokens.

    Parameters
    ----------
    nll : tensor
        Per-token negative log-likelihood of shape (batch_size, n_tokens).
        In this file it comes from ``perplexity`` and is already shifted, so
        n_tokens == seq_len - 1.
    x_ppl_list : list of tensors
        Cross-entropy tensors from different models, each with ONE MORE
        position than ``nll`` along the last axis (shape (batch_size, seq_len)).

    Returns
    -------
    adjusted_nll_avg : tensor
        Average adjusted NLL across all tokens (shape: (batch_size,)).
    adjusted_nll : tensor
        Adjusted NLL for each token (shape: (batch_size, n_tokens)).
    """
    # Drop the last position of every cross-entropy tensor so that position t
    # lines up with the shifted nll (the prediction made at t for token t + 1).
    # Slicing with ':1' here would keep only the first position and broadcast it.
    x_ppl_aligned = [x[..., :-1] for x in x_ppl_list]

    # Stack to (nb_models, batch_size, n_tokens) and average across models.
    x_ppl_avg = torch.stack(x_ppl_aligned, dim=0).mean(dim=0)

    # Per-token difference, shape (batch_size, n_tokens).
    adjusted_nll = nll - x_ppl_avg

    # Average over the token axis, shape (batch_size,).
    adjusted_nll_avg = torch.mean(adjusted_nll, dim=1)

    return adjusted_nll_avg, adjusted_nll
    
def perplexity(encodings, weighted_sum_tensor):
    """Score the observed tokens under a probability tensor.

    Parameters
    ----------
    encodings
        Object exposing ``input_ids`` and ``attention_mask`` tensors.
    weighted_sum_tensor : tensor
        Probabilities indexed as (batch, position, vocabulary).

    Returns
    -------
    average_nll : tensor
        Mean masked NLL per sequence (sum of masked NLL / number of unmasked tokens).
    perplexity : tensor
        ``exp(average_nll)``.
    nll_masked : tensor
        Per-token NLL multiplied by the shifted attention mask.
    """
    # The distribution at position t scores token t + 1: drop the last
    # prediction and the first label / mask entry.
    next_token_probs = weighted_sum_tensor[..., :-1, :].contiguous()
    target_device = next_token_probs.device
    labels = encodings.input_ids[..., 1:].contiguous().to(target_device)
    mask = encodings.attention_mask[..., 1:].contiguous().to(target_device)

    # Probability assigned to the token that actually follows.
    observed = next_token_probs.gather(2, labels.unsqueeze(-1)).squeeze(-1)

    # Small constant keeps log() finite when a probability is exactly zero.
    token_nll = -torch.log(observed + 1e-12)
    nll_masked = token_nll * mask

    # Average over the non-padded tokens only.
    average_nll = nll_masked.sum(dim=1) / mask.sum(dim=1)

    return average_nll, torch.exp(average_nll), nll_masked

def cross_entropy(weighted_sum_tensor, probabilities_list):
    """Per-position cross-entropy of each model against the mixture.

    For every tensor ``m2`` in ``probabilities_list`` this computes
    ``-sum_v m2[b, s, v] * log(weighted_sum_tensor[b, s, v])``.

    Parameters
    ----------
    weighted_sum_tensor : tensor
        Mixture probabilities, shape (batch_size, sequence_length, vocabulary_size).
    probabilities_list : list of tensors
        Per-model probabilities with the same shape; each is moved to the
        mixture's device before use.

    Returns
    -------
    list of tensors
        One (batch_size, sequence_length) tensor per entry of
        ``probabilities_list`` (an empty list for an empty input).
    """
    device = weighted_sum_tensor.device

    # The log of the mixture does not depend on the model, so compute it once.
    log_M1 = torch.log(weighted_sum_tensor)

    # Contract over the vocabulary axis directly. Building the full
    # (batch, seq, seq) product and reading its diagonal gives the same
    # numbers but costs O(seq^2 * vocab) instead of O(seq * vocab).
    return [
        -torch.einsum('bsv,bsv->bs', log_M1, m2_probabilities.to(device))
        for m2_probabilities in probabilities_list
    ]


def softmax_probabilities_all_models(logits_list: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Turn each model's logits into probabilities.

    Parameters:
    - logits_list: List[torch.Tensor]
        One logits tensor per model; the last axis is normalized.

    Returns:
    - List[torch.Tensor]: one tensor per input, same shape as its logits,
      holding the softmax taken over the last axis.
    """
    # Softmax over the last axis, applied independently to every model.
    return [torch.softmax(model_logits, dim=-1) for model_logits in logits_list]

def blahut_arimoto(W, epsilon=1e-3, info=1e4):
    """
    Performs the Blahut-Arimoto algorithm to compute the channel capacity
    given a channel W.

    Parameters
    ----------
    W : array-like
        definition of the channel with C inputs (rows) and m outputs
        (columns). Converted with ``np.asarray``, so nested lists work too.
    epsilon : float.
        error tolerance for the algorithm to stop the iterations: the loop
        ends once the gap between the upper and lower bound of the mutual
        information is not larger than epsilon.
    info : int.
        Number indicating every how many cycles to print the cycle number as
        a visual output of the algorithm.

    Returns
    -------
    C : float.
        channel capacity in bits (the lower bound reached when the loop stops).
    p : array-like.
        array containing the discrete probability distribution for the input
        that maximizes the channel capacity
    loop_count : int.
        number of iterations performed.
    """
    W = np.asarray(W, dtype=float)
    nb_inputs = W.shape[0]

    # Start from the uniform input distribution.
    p = np.full(nb_inputs, 1 / nb_inputs)

    # Termination criterion: upper bound of mutual information minus the lower bound.
    Iu_Il = 1

    loop_count = 0
    while Iu_Il > epsilon:
        if loop_count != 0 and loop_count % info == 0:
            print("loop : {0:d}, Iu - Il : {1:f}".format(loop_count, Iu_Il))
        loop_count += 1

        # Output marginal: q = sum_c p_c W_c
        sum_p_w = (p[:, np.newaxis] * W).sum(axis=0)

        # prod_exp = exp(sum_m W log(W / q)). Zero entries of W produce
        # 0 * log(0); those terms are expected and are zeroed right below, so
        # the floating-point warnings they trigger are silenced here.
        with np.errstate(divide="ignore", invalid="ignore"):
            W_log_W_sum_p_w = W * np.log(W / sum_p_w)
        W_log_W_sum_p_w[np.isnan(W_log_W_sum_p_w)] = 0
        W_log_W_sum_p_w[np.isneginf(W_log_W_sum_p_w)] = 0
        prod_exp = np.exp(np.sum(W_log_W_sum_p_w, axis=1))

        # Lower bound: I_L = log(sum_c p prod_exp)
        normalizer = np.sum(p * prod_exp)
        Il = np.log(normalizer)

        # Upper bound: I_U = log(max_c prod_exp)
        Iu = np.log(prod_exp.max())

        # Update: p = p * prod_exp / sum_c p * prod_exp
        p = p * prod_exp / normalizer

        Iu_Il = Iu - Il

    # convert from nats to bits
    Il = Il / np.log(2)
    return Il, p, loop_count

def perplexity_logits(encoding, logits):
    """Mean next-token cross-entropy of one model, per sequence.

    The logits are clamped to [-20, 50], scored against the tokens that
    follow each position, masked with the attention mask and averaged over
    the unmasked tokens. The value is NOT exponentiated.

    Parameters
    ----------
    encoding
        Object exposing ``input_ids`` and ``attention_mask`` tensors.
    logits : tensor
        Logits indexed as (batch, position, vocabulary).

    Returns
    -------
    numpy array of float32, one value per sequence.
    """
    device = logits.device
    clamped = torch.clamp(logits, min=-20, max=50)

    # Bring the token ids and the mask onto the logits' device.
    input_ids = encoding.input_ids.to(device)
    attention_mask = encoding.attention_mask.to(device)

    # Position t predicts token t + 1.
    pred_logits = clamped[..., :-1, :].contiguous()
    targets = input_ids[..., 1:].contiguous()
    target_mask = attention_mask[..., 1:].contiguous()

    # cross_entropy expects the class axis second: (batch, vocabulary, position).
    token_loss = torch.nn.functional.cross_entropy(
        pred_logits.transpose(1, 2), targets, reduction="none"
    )

    # Masked mean over the sequence axis.
    mean_loss = (token_loss * target_mask).sum(1) / target_mask.sum(1)

    # Hand the result back as a float32 numpy array on the CPU.
    return mean_loss.to("cpu").float().numpy()

def perplexity_all_models(encoding, logits_list):
    """Apply ``perplexity_logits`` to every logits tensor and collect the results.

    Returns a list with one entry per element of ``logits_list``, in order.
    """
    return [perplexity_logits(encoding, model_logits) for model_logits in logits_list]


def _ba_prod_exp(p, W):
    """Return exp(sum_v W * log(W / sum_m p W)), shape (seq_len, nb_models)."""
    # Output marginal for every position, shape (seq_len, voc_size).
    sum_p_w = torch.bmm(p.unsqueeze(1), W).squeeze(1)

    # Ratio W / marginal, broadcast to (seq_len, nb_models, voc_size).
    W_normalized = W / sum_p_w.unsqueeze(1)

    # Keep log() finite where the ratio is exactly zero, then neutralize
    # whatever nan / inf entries remain (e.g. from 0 / 0).
    W_normalized[W_normalized == 0] = torch.finfo(W.dtype).eps
    log_term = torch.log(W_normalized)
    log_term[torch.isnan(log_term) | torch.isinf(log_term)] = 0

    # Sum across the vocabulary axis.
    return torch.exp(torch.sum(W * log_term, dim=2))


def blahut_arimoto_torch(W, epsilon=1e-6, max_iters=1000):
    """
    Batch-process Blahut-Arimoto using PyTorch for multiple sequences.

    Parameters
    ----------
    W : tensor
        Floating-point tensor of shape (seq_len, nb_models, voc_size); each
        leading index is treated as an independent channel.
    epsilon : float
        Iterations stop once the largest absolute change of any weight
        between two iterations is below this value.
    max_iters : int
        Upper bound on the number of iterations.

    Returns
    -------
    capacity : tensor
        Shape (seq_len,), in bits, evaluated for the returned ``p``.
    p : tensor
        Shape (seq_len, nb_models), the weights over the second axis of W.
    """
    seq_len, nb_models, _ = W.shape
    p = torch.full((seq_len, nb_models), 1.0 / nb_models, device=W.device, dtype=W.dtype)

    for _ in range(max_iters):
        prod_exp = _ba_prod_exp(p, W)
        p_new = (p * prod_exp) / torch.sum(p * prod_exp, dim=1, keepdim=True)

        # Check convergence
        if torch.max(torch.abs(p - p_new)) < epsilon:
            break
        p = p_new
    else:
        # Loop exhausted (or max_iters == 0): prod_exp is either stale with
        # respect to the updated p or was never computed, so refresh it to
        # keep the capacity consistent with the p that is returned.
        prod_exp = _ba_prod_exp(p, W)

    # Compute channel capacity (nats -> bits)
    capacity = torch.log(torch.sum(p * prod_exp, dim=1)) / torch.log(torch.tensor(2.0, device=W.device))
    return capacity, p