import torch
import statistics

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate every key/value head ``n_rep`` times along the head axis.

    Produces the same result as
    ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``: a
    ``(batch, num_key_value_heads, seqlen, head_dim)`` tensor becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. When ``n_rep``
    is 1 the input tensor object itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Add a repeat axis right after the head axis and broadcast it (a view),
        # then fold it into the head axis so copies of one head are adjacent.
        widened = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return widened.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states

def _sink_norm_ratio(attended_v, sink_tokens):
    """Return how much of ``attended_v`` lies in the row space of ``sink_tokens``.

    Computes ``||attended_v @ P||_F / ||attended_v||_F``, where ``P`` is the
    orthogonal projector onto the span of the rows of ``sink_tokens``. This is
    a ratio of norms, not of squared norms.

    Args:
        attended_v: 2-D tensor ``(num_tokens, head_dim)``.
        sink_tokens: 2-D tensor ``(num_sinks, head_dim)`` with at least one row.

    Returns:
        The ratio as a Python float.
    """
    _, singular_values, vh = torch.linalg.svd(sink_tokens, full_matrices=False)
    # Keep only numerically non-zero directions, using the default tolerance of
    # torch.linalg.matrix_rank. Right singular vectors paired with zero
    # singular values are arbitrary and would inflate the ratio whenever the
    # sink value vectors are (nearly) linearly dependent.
    tol = singular_values.max() * max(sink_tokens.shape) * torch.finfo(sink_tokens.dtype).eps
    basis = vh[singular_values > tol]  # (rank, head_dim), orthonormal rows
    projected = attended_v @ basis.T @ basis
    return (torch.linalg.norm(projected) / torch.linalg.norm(attended_v)).item()


def var_exp(model, tokenizer, prompts, truncate_token_length=150, eps=0.2, device=torch.device("cuda")):
    """Measure attention sinks and how much of each head's output they explain.

    Each prompt is (optionally) truncated and run through ``model.generate``
    for one new token with attentions returned. Then, for every
    ``(layer, head)``:

    * Attention received by key position ``j`` is summed over query rows and
      divided by ``T - j`` (``T`` = prompt length). Under a causal mask this is
      the number of queries that can see ``j``. Keys whose average exceeds
      ``eps`` count as sinks.
    * If the head has at least one sink, the attention output ``A @ V`` is
      projected onto the span of the sink tokens' value vectors. The ratio
      ``||proj(A @ V)||_F / ||A @ V||_F`` is recorded.

    Args:
        model: model exposing ``config`` (or ``config.text_config``) with
            ``num_hidden_layers``, ``num_attention_heads`` and
            ``num_key_value_heads``, and a ``generate`` method.
        tokenizer: callable as ``tokenizer(prompt, return_tensors="pt")``.
        prompts: iterable of prompts, processed one at a time.
        truncate_token_length: keep only this many leading tokens per prompt;
            ``None`` disables truncation.
        eps: sink threshold on the average attention a token receives.
        device: device that inputs and per-layer tensors are gathered on.

    Returns:
        Dict mapping ``(layer_idx, head_idx)`` to a dict with:
        ``"Number of Sinks"``, the mean sink count over prompts;
        ``"Explained Variance"``, the mean norm ratio over prompts that had at
        least one sink (0 if none did); and ``"eps"``.

    Raises:
        ValueError: if a prompt has fewer than ``truncate_token_length`` tokens.
        statistics.StatisticsError: if ``prompts`` is empty.
    """
    print("Measuring Attention Sinks and Variance Explained")

    # Multimodal wrappers keep the language-model dimensions under text_config.
    try:
        num_layers = model.config.num_hidden_layers
        num_heads = model.config.num_attention_heads
        num_kv_heads = model.config.num_key_value_heads
    except AttributeError:
        text_config = model.config.text_config
        num_layers = text_config.num_hidden_layers
        num_heads = text_config.num_attention_heads
        num_kv_heads = text_config.num_key_value_heads
    # Grouped-query attention: each K/V head is shared by this many query heads.
    kv_repeats = num_heads // num_kv_heads

    head_keys = [(layer_idx, head_idx) for layer_idx in range(num_layers) for head_idx in range(num_heads)]
    metric_1 = {key: [] for key in head_keys}  # per-prompt sink counts
    metric_2 = {key: [] for key in head_keys}  # per-prompt norm ratios (prompts with >= 1 sink only)

    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        if truncate_token_length is not None:
            for key in list(inputs.keys()):
                length = inputs[key].shape[1]
                if length < truncate_token_length:
                    raise ValueError(
                        f"prompt has {length} tokens in '{key}', fewer than "
                        f"truncate_token_length={truncate_token_length}"
                    )
                inputs[key] = inputs[key][:, :truncate_token_length]

        outputs = model.generate(
            **inputs,
            output_attentions=True,
            return_dict_in_generate=True,
            max_new_tokens=1,
        )

        attentions = outputs["attentions"]
        # NOTE(review): past_key_values[l][1] is presumably layer l's value cache,
        # (batch, num_kv_heads, seq_len, head_dim) -- confirm for the cache type in use.
        past_key_values = outputs["past_key_values"]
        # One generation step yields a single tuple of per-layer attention maps.
        assert len(attentions) == 1

        # Gather on `device` before stacking: a model sharded across devices
        # returns per-layer tensors that live on different GPUs.
        attention_scores_all_layer = torch.cat(
            [attentions[0][l].to(device) for l in range(num_layers)], dim=0
        ).float()  # (num_layers, num_heads, num_tokens, num_tokens)
        values_all_layer = torch.cat(
            [repeat_kv(past_key_values[l][1], kv_repeats).to(device) for l in range(num_layers)], dim=0
        ).float()  # (num_layers, num_heads, num_tokens, head_dim)

        token_length = attention_scores_all_layer.shape[-1]
        # Key position j is visible to the T - j queries j..T-1 under a causal mask.
        norm_factor = torch.arange(token_length, 0, -1, device=attention_scores_all_layer.device)
        avg_received = (attention_scores_all_layer / norm_factor).sum(dim=-2)  # (num_layers, num_heads, num_tokens)
        sink_mask = avg_received > eps  # (num_layers, num_heads, num_tokens)
        # One host sync per prompt instead of one per head.
        sink_counts = sink_mask.sum(dim=-1).tolist()  # [num_layers][num_heads]

        for layer_idx in range(num_layers):
            for head_idx in range(num_heads):
                num_sinks = sink_counts[layer_idx][head_idx]
                metric_1[(layer_idx, head_idx)].append(num_sinks)
                if num_sinks == 0:
                    continue
                a_score = attention_scores_all_layer[layer_idx, head_idx]  # (num_tokens, num_tokens)
                values = values_all_layer[layer_idx, head_idx]  # (num_tokens, head_dim)
                attended_v = a_score @ values  # (num_tokens, head_dim): the head's output
                sink_tokens = values[sink_mask[layer_idx, head_idx]]  # (num_sinks, head_dim)
                metric_2[(layer_idx, head_idx)].append(_sink_norm_ratio(attended_v, sink_tokens))

    var_exp_stats = {}
    for key in head_keys:
        ratios = metric_2[key]
        var_exp_stats[key] = {
            "Number of Sinks": statistics.fmean(metric_1[key]),
            "Explained Variance": statistics.fmean(ratios) if ratios else 0,
            "eps": eps,
        }
    return var_exp_stats
    