import os
from typing import Dict, Optional

import numpy as np

from lm_polygraph.estimators.estimator import Estimator
    
class RAUQ(Estimator):
    """Sequence-level uncertainty from attention-weighted recurrent token confidence.

    For every selected layer, the head with the largest mean value in the
    attention features (``attentions.mean(0)[layer].argmax()``) is chosen, and
    token confidences are propagated recurrently along the greedy sequence:

        c_0 = p_0
        c_j = alpha * p_j + (1 - alpha) * a_j * c_{j-1},   j >= 1

    where ``p_j = exp(log p(token_j))`` and ``a_j = attentions[j - 1, layer, head]``
    (or the head-mean when ``head == "mean"``). Token confidences are reduced to
    one per-layer uncertainty by ``token_aggregation`` (``1 - <reduction>`` for
    the non-ablation modes) and the per-layer values are reduced to a single
    sequence score by ``aggregation``.
    """

    def __init__(
        self,
        alpha: float = 0.2,
        n_layers: int = 32,
        n_heads: int = 32,
        all_layers: bool = False,
        aggregation: str = "mean",
        token_aggregation: str = "meanmin",
        head: str = "max",
        ablation: Optional[str] = None,
        save_eval: bool = False,
        parameters_path: str = "",
        print_alpha: bool = False,
    ):
        """Configure the estimator.

        Args:
            alpha: weight of the current token probability ``p_j``; ``1 - alpha``
                weights the attention-scaled previous confidence.
            n_layers: number of layers; used to reshape attention features and
                to choose which layers are scored.
            n_heads: number of heads per layer; used to reshape attention features.
            all_layers: score every layer instead of the middle band
                ``[n_layers // 3, ceil(2 * n_layers / 3)]``.
            aggregation: reduction over layers: "mean", "median", "max",
                "meanmax" or "medianmax".
            token_aggregation: reduction over token confidences: "meanmin",
                "mean", "min", "median", "meanlog" or "sumlog".
            head: "max" uses the selected head's attention; "mean" averages the
                attention over all heads of the layer.
            ablation: ``None`` for the default recurrence, or one of
                "simple_rec", "no_rec", "no_attn", "multiply" (alternative
                recurrences) and "multiply_uq", "sum_uq" (alternative
                sequence-level scores). Any other value makes every ``c_j`` for
                ``j >= 1`` equal to 0.
            save_eval: if True, dump each sample's attentions and log-probs as
                ``.npy`` files under ``parameters_path``.
            parameters_path: root directory; if non-empty, a sub-directory named
                after ``str(self)`` (spaces replaced by ``_``) is created.
            print_alpha: include ``alpha`` in ``str(self)``.

        Raises:
            ValueError: if ``save_eval`` is True but ``parameters_path`` is empty,
                since there would be nowhere to write the files.
        """
        super().__init__(["attention_features_values", "greedy_log_likelihoods"], "sequence")
        self.alpha = alpha
        self.token_aggregation = token_aggregation
        self.aggregation = aggregation

        self.n_layers = n_layers
        self.n_heads = n_heads
        self.head = head

        self.all_layers = all_layers
        self.ablation = ablation

        if self.all_layers:
            self.layers = list(range(n_layers))
        else:
            # Middle band of layers. Clamp the upper bound so shallow models
            # (n_layers < 3) never select a layer index that does not exist;
            # for n_layers >= 3 the clamp is a no-op.
            upper = min(int(np.ceil(n_layers / 3 * 2)) + 1, n_layers)
            self.layers = list(range(n_layers // 3, upper))
        self.print_alpha = print_alpha

        self.save_eval = save_eval
        self.eval_index = 0
        self.full_path = None
        if parameters_path:
            # str(self) reads only attributes assigned above.
            self_name = self.__str__().replace(" ", "_")
            self.full_path = f"{parameters_path}/{self_name}"
            os.makedirs(self.full_path, exist_ok=True)
        if self.save_eval and self.full_path is None:
            raise ValueError("RAUQ: save_eval=True requires a non-empty parameters_path")

    def __str__(self):
        """Return a name encoding the configuration, e.g. ``"RAUQ mean_meanmin_max"``."""
        method_desc = f" {self.aggregation}_{self.token_aggregation}_{self.head}"
        if self.ablation is not None:
            method_desc += f" {self.ablation}"
        if self.all_layers:
            method_desc += " all_layers"
        if self.print_alpha:
            method_desc += f" {self.alpha:.2f}"
        return f"RAUQ{method_desc}"

    def attention_selection(self, attentions, j, layer, head):
        """Return the attention weight paired with token ``j`` (row ``j - 1``).

        ``attentions`` is indexed as ``(step, layer, head)``. With ``head == "max"``
        the given head is used; with ``"mean"`` the value is averaged over heads.

        Raises:
            NotImplementedError: for any other ``self.head`` value.
        """
        if self.head == "max":
            attn = attentions[j - 1, layer, head]
        elif self.head == "mean":
            attn = attentions[j - 1, layer].mean(-1)
        else:
            raise NotImplementedError(f"RAUQ: unknown head mode {self.head!r}")
        return attn

    def tokens_aggregation(self, conf_scores, attentions, log_probabilities, layer, head):
        """Reduce one layer's token confidences to a scalar uncertainty.

        The "multiply_uq" and "sum_uq" ablations ignore ``conf_scores`` and
        combine the mean token log-prob with the mean log-attention of
        ``(layer, head)`` directly.

        Raises:
            NotImplementedError: for an unknown ``self.token_aggregation``.
        """
        if self.ablation == "multiply_uq":
            return np.mean(log_probabilities) * np.mean(np.log(attentions[:, layer, head]))
        elif self.ablation == "sum_uq":
            return -(np.mean(log_probabilities) + np.mean(np.log(attentions[:, layer, head])))

        if self.token_aggregation == "meanmin":
            uq = 1 - (np.mean(conf_scores) + np.min(conf_scores)) / 2
        elif self.token_aggregation == "mean":
            uq = 1 - np.mean(conf_scores)
        elif self.token_aggregation == "min":
            uq = 1 - np.min(conf_scores)
        elif self.token_aggregation == "median":
            uq = 1 - np.median(conf_scores)
        elif self.token_aggregation == "meanlog":
            uq = 1 - np.log(conf_scores).mean()
        elif self.token_aggregation == "sumlog":
            uq = 1 - np.log(conf_scores).sum()
        else:
            raise NotImplementedError(
                f"RAUQ: unknown token_aggregation {self.token_aggregation!r}"
            )
        return uq

    def layers_aggregation(self, uq_scores_layers):
        """Reduce per-layer uncertainties to a single sequence-level score.

        Raises:
            NotImplementedError: for an unknown ``self.aggregation``.
        """
        if self.aggregation == "mean":
            uq = np.mean(uq_scores_layers)
        elif self.aggregation == "median":
            uq = np.median(uq_scores_layers)
        elif self.aggregation == "max":
            uq = np.max(uq_scores_layers)
        elif self.aggregation == "meanmax":
            uq = (np.mean(uq_scores_layers) + np.max(uq_scores_layers)) / 2
        elif self.aggregation == "medianmax":
            uq = (np.median(uq_scores_layers) + np.max(uq_scores_layers)) / 2
        else:
            raise NotImplementedError(f"RAUQ: unknown aggregation {self.aggregation!r}")
        return uq

    def ablation_formula(self, p_j, p_jm1, logprob_jm1, attn):
        """Alternative recurrence step used when ``self.ablation`` is set.

        Args:
            p_j: probability of the current token.
            p_jm1: previous recurrent confidence ``c_{j-1}``.
            logprob_jm1: log-probability of the previous token.
            attn: selected attention weight for this step.

        Returns 0 for ablations that do not define a recurrence (including
        "multiply_uq"/"sum_uq", whose score is computed in
        ``tokens_aggregation`` instead).
        """
        if self.ablation == "simple_rec":
            score = self.alpha * p_j + (1 - self.alpha) * attn * np.exp(logprob_jm1)
        elif self.ablation == "no_rec":
            score = self.alpha * p_j + (1 - self.alpha) * attn
        elif self.ablation == "no_attn":
            score = self.alpha * p_j + (1 - self.alpha) * p_jm1
        elif self.ablation == "multiply":
            score = p_j * attn
        else:
            score = 0  # aggregation of sequence-level scores
        return score

    def _sequence_uncertainty(self, attentions, log_probabilities):
        """Compute the layer-aggregated uncertainty for one sequence."""
        # Mean over steps per (layer, head); invariant across layers, so compute once.
        mean_attention = attentions.mean(0)
        uq_scores_layers = []
        for layer in self.layers:
            head = mean_attention[layer].argmax()  # select the most attentive head
            p_i = [np.exp(log_probabilities[0])]
            for j in range(1, len(log_probabilities)):
                p_j = np.exp(log_probabilities[j])
                p_jm1 = p_i[-1]
                attn = self.attention_selection(attentions, j, layer, head)
                if self.ablation is None:
                    conf = self.alpha * p_j + (1 - self.alpha) * attn * p_jm1
                else:
                    conf = self.ablation_formula(p_j, p_jm1, log_probabilities[j - 1], attn)
                p_i.append(conf)
            uq_scores_layers.append(
                self.tokens_aggregation(p_i, attentions, log_probabilities, layer, head)
            )
        return self.layers_aggregation(uq_scores_layers)

    def _save_sample(self, attentions, log_probabilities):
        """Write one sample's attention tensor and log-probabilities to ``self.full_path``."""
        np.save(f"{self.full_path}/attentions_{self.eval_index}.npy", np.array(attentions))
        # Save this sample's log-probs (not the whole batch) so the file pairs
        # with attentions_{eval_index}.npy; np.array of a ragged batch list also
        # raises ValueError on recent NumPy versions.
        np.save(f"{self.full_path}/log_probs_{self.eval_index}.npy", np.array(log_probabilities))
        self.eval_index += 1

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Estimate uncertainty for every generated sequence in ``stats``.

        Args:
            stats: must contain
                - "attention_features_values": nested iterable (outer level
                  flattened here) holding one entry per non-first generated
                  token; element ``[0]`` of each entry is used and reshaped to
                  ``(n_layers, n_heads)`` (assumed to hold n_layers * n_heads
                  values — TODO confirm against the stat calculator);
                - "greedy_log_likelihoods": per-sequence token log-probs.

        Returns:
            1-D array with one uncertainty score per sequence.

        Raises:
            NotImplementedError: for unknown head / aggregation modes.
        """
        # Flatten the nested structure; per-sequence slices are recovered with
        # the running offset `k` (each sequence contributes len - 1 entries).
        attention_features_values = [
            np.array(item)
            for sublist in stats["attention_features_values"]
            for item in sublist
        ]
        greedy_log_likelihoods = stats["greedy_log_likelihoods"]

        k = 0
        uq_scores = []
        for log_probabilities in greedy_log_likelihoods:
            n_steps = len(log_probabilities) - 1
            # index 0 selects the attention on the previous token only
            attentions = np.array(
                [attention_features_values[ind][0] for ind in range(k, k + n_steps)]
            )
            attentions = attentions.reshape(-1, self.n_layers, self.n_heads)

            uq_scores.append(self._sequence_uncertainty(attentions, log_probabilities))
            k += n_steps

            if self.save_eval:
                self._save_sample(attentions, log_probabilities)

        return np.array(uq_scores)