import torch
import math
import matplotlib.pyplot as plt
import os
from data.data_utils import parse_city_prompt

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head ``n_rep`` times along the head axis.

    Produces the same result as torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1): an input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When n_rep is 1 the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Add a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states

def _init_sink_hist_bins(hist_entry, token_id_list, tokenizer, token_groups):
    """Fill an empty ``[labels, counts]`` histogram entry with one zero-count bin per label."""
    labels, counts = hist_entry
    opened = set()
    for j, tok_id in enumerate(token_id_list):
        for tag, positions in token_groups:
            if j in positions:
                # All positions of a group share one bin, created at the group's first position.
                if tag not in opened:
                    opened.add(tag)
                    labels.append(tag)
                    counts.append(0)
                break
        else:
            # The position belongs to no group, so it gets a bin of its own labelled with the token text.
            labels.append(tokenizer.convert_tokens_to_string([tokenizer.convert_ids_to_tokens(tok_id)]))
            counts.append(0)


def _count_sinks_into_bins(counts, sink_flags, token_groups):
    """Increment, in place, the histogram bin of every position whose sink flag equals 1."""
    closed = set()
    bin_idx = 0
    for j, flag in enumerate(sink_flags):
        for tag, positions in token_groups:
            if j in positions:
                if tag not in closed:
                    if flag == 1:
                        counts[bin_idx] += 1
                    # The shared bin is only left once the group's last position has been reached.
                    if j == positions[-1]:
                        closed.add(tag)
                        bin_idx += 1
                break
        else:
            if flag == 1:
                counts[bin_idx] += 1
            bin_idx += 1


def sink_decomp(model, tokenizer, prompts, eps=0.2, eval=False, device=torch.device("cuda")):
    """
    Split every attention head's output at the last '.' token into a sink part and a non-sink part.

    For each prompt the model is run for one generation step with attentions returned. A key position counts as a
    sink for a head when the attention it receives, summed over queries and divided by (num_tokens - position),
    exceeds ``eps``. The head output is then recomputed three times from the attention weights and the cached
    values: using only sink columns, using only non-sink columns, and using all columns.

    Args:
        model: model exposing ``config.num_hidden_layers``, ``config.num_attention_heads``,
            ``config.num_key_value_heads`` and ``generate``.
        tokenizer: tokenizer used to encode the prompt and to decode single tokens.
        prompts: iterable of prompts; each is tokenized on its own.
        eps: threshold on the normalized received attention above which a position is a sink.
        eval: when True, heads without any sink still contribute one entry per prompt (a zero vector as the sink
            part and the full head output as the other two); when False such heads contribute nothing.
        device: device the tokenized inputs and helper tensors are placed on.

    Returns:
        Tuple ``(activation_with_tags, activation_no_tags, activations, sink_hist)`` of dicts keyed by
        ``(layer_idx, head_idx)``. The first three map to lists of vectors taken at the last '.' token (sink-only,
        non-sink-only and full head output). ``sink_hist`` maps to ``[labels, counts]``, where the city, country
        and preamble positions returned by ``parse_city_prompt`` are each merged into a single bin.

    Raises:
        ValueError: if a prompt contains no token with a '.' in it, since there is then no position to read from.
    """
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads
    num_kv_heads = model.config.num_key_value_heads
    kv_repeats = num_heads // num_kv_heads

    head_keys = [(layer_idx, head_idx) for layer_idx in range(num_layers) for head_idx in range(num_heads)]
    activation_with_tags = {key: [] for key in head_keys}
    activation_no_tags = {key: [] for key in head_keys}
    activations = {key: [] for key in head_keys}
    sink_hist = {key: [[], []] for key in head_keys}

    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        token_ids = inputs['input_ids']

        # Activations are read at the last token whose decoded text contains a "."
        decoded_tokens = [tokenizer.decode([tok]) for tok in token_ids[0]]
        period_indices = [i for i, tok in enumerate(decoded_tokens) if '.' in tok]
        if not period_indices:
            # Indexing with None would silently add an axis instead of selecting a row, so fail loudly.
            raise ValueError(
                f"sink_decomp needs a token containing '.' to read activations from, "
                f"but found none in prompt: {prompt!r}"
            )
        last_period_idx = period_indices[-1]

        # NOTE(review): presumably sequences of token positions within the prompt; confirm in parse_city_prompt.
        preamble_tokens, city_tokens, country_tokens = parse_city_prompt(prompt, tokenizer)
        # Order matters: a position is matched against city first, then country, then preamble.
        token_groups = (
            ("[CITY]", city_tokens),
            ("[COUNTRY]", country_tokens),
            ("[PREAMBLE]", preamble_tokens),
        )
        # Only positions up to and including the last period take part in the histogram.
        hist_token_ids = token_ids.flatten().tolist()[: last_period_idx + 1]

        outputs = model.generate(
            **inputs,
            output_attentions=True,
            return_dict_in_generate=True,
            max_new_tokens=1
        )

        attentions = outputs['attentions']
        past_key_values = outputs["past_key_values"]
        # One generation step was requested, so exactly one set of per-layer attentions is expected.
        assert len(attentions) == 1

        # Concatenating along dim 0 yields a leading layer axis; this assumes a batch size of 1.
        attention_scores_all_layer = torch.cat(
            [attentions[0][l] for l in range(num_layers)], dim=0
        ).float()                                                                    # (num_layers, num_heads, num_tokens, num_tokens)
        values_all_layer = torch.cat(
            [repeat_kv(past_key_values[l][1], kv_repeats) for l in range(num_layers)], dim=0
        ).float()                                                                    # (num_layers, num_heads, num_tokens, head_dim)

        token_length = attention_scores_all_layer.shape[-1]
        # Column j is divided by (num_tokens - j) before summing over the query axis.
        norm_factor = torch.arange(token_length, 0, -1).to(device)                   # num_tokens

        avg_row = torch.sum(attention_scores_all_layer / norm_factor.view(1, 1, 1, token_length), dim=-2)
        avg_row[avg_row <= eps] = 0
        sink_support = torch.sign(avg_row)                                           # (num_layers, num_heads, num_tokens)

        for layer_idx, head_idx in head_keys:
            key = (layer_idx, head_idx)
            a_score = attention_scores_all_layer[layer_idx, head_idx]                # (num_tokens, num_tokens)
            values = values_all_layer[layer_idx, head_idx]                           # (num_tokens, head_dim)
            sink_supp = sink_support[layer_idx, head_idx]                            # num_tokens
            num_sinks = torch.sum(sink_supp)

            # Work on copies so that zeroing columns never touches the shared attention tensor.
            sink_only_score = a_score.detach().clone()
            no_sink_score = a_score.detach().clone()
            sink_only_score[:, sink_supp == 0] = 0
            no_sink_score[:, sink_supp != 0] = 0

            sink_only_output = torch.matmul(sink_only_score, values)                 # (num_tokens, head_dim)
            no_sink_output = torch.matmul(no_sink_score, values)
            output = torch.matmul(a_score, values)

            feature_dim = sink_only_output.shape[-1]

            if num_sinks > 0:
                # Clone the selected rows so the stored vectors do not keep the full outputs alive.
                activation_with_tags[key].append(sink_only_output[last_period_idx, :].contiguous().clone())
                activation_no_tags[key].append(no_sink_output[last_period_idx, :].contiguous().clone())
                activations[key].append(output[last_period_idx, :].contiguous().clone())

                # The bins of a head are laid out once, from the first prompt in which the head has a sink.
                if len(sink_hist[key][0]) == 0:
                    _init_sink_hist_bins(sink_hist[key], hist_token_ids, tokenizer, token_groups)

                # One host transfer per head instead of one tensor comparison per token.
                sink_flags = sink_supp[: last_period_idx + 1].tolist()
                _count_sinks_into_bins(sink_hist[key][1], sink_flags, token_groups)

            elif eval:
                activation_with_tags[key].append(torch.zeros(feature_dim, device=device))
                activation_no_tags[key].append(output[last_period_idx, :].contiguous().clone())
                activations[key].append(output[last_period_idx, :].contiguous().clone())

    return activation_with_tags, activation_no_tags, activations, sink_hist

def get_mass_mean_probes(model, tags_correct, tags_incorrect):
    """
    Fit one mass-mean probe direction per attention head.

    For each head the probe is ``pinv(covariance) @ (mean_correct - mean_incorrect)``, where the covariance is
    computed from the samples after subtracting each sample's own class mean.

    Args:
        model: object exposing ``config.num_hidden_layers`` and ``config.num_attention_heads``.
        tags_correct: dict mapping ``(layer_idx, head_idx)`` to a list of 1-D feature tensors of the "correct" class.
        tags_incorrect: same mapping for the "incorrect" class.

    Returns:
        Dict mapping every ``(layer_idx, head_idx)`` to a probe tensor of length feature_dim, or to the string
        ``"skip"`` when either class has no samples or the pseudo-inverse computation fails.
    """
    probes = {}
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads

    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            key = (layer_idx, head_idx)
            correct_tags = tags_correct[key]
            incorrect_tags = tags_incorrect[key]

            # A class mean is undefined without samples, so such heads are skipped.
            if len(correct_tags) == 0 or len(incorrect_tags) == 0:
                probes[key] = "skip"
                continue

            correct = torch.stack(correct_tags)                      # (n_correct, feature_dim)
            incorrect = torch.stack(incorrect_tags)                  # (n_incorrect, feature_dim)
            cm_correct = torch.mean(correct, dim=0)                  # feature_dim
            cm_incorrect = torch.mean(incorrect, dim=0)              # feature_dim
            theta_mm = cm_correct - cm_incorrect

            # Centre every sample on its own class mean; rows keep the order correct, then incorrect.
            class_centered_data = torch.cat([correct - cm_correct, incorrect - cm_incorrect], dim=0)
            covariance = class_centered_data.T @ class_centered_data / class_centered_data.shape[0]

            try:
                probes[key] = torch.linalg.pinv(covariance, hermitian=True, atol=1e-3) @ theta_mm
            except RuntimeError:
                # Covers torch.linalg.LinAlgError (a RuntimeError subclass) without hiding
                # KeyboardInterrupt or SystemExit the way a bare except would.
                probes[key] = "skip"
    return probes

def calc_probe_acc(model, probes, val_tags_correct, val_tags_incorrect, eps, tags):
    """
    Measure the validation accuracy of each head's probe.

    A "correct" sample counts as a hit when ``sigmoid(sample @ probe) > 0.5`` and an "incorrect" sample when
    ``sigmoid(sample @ probe) < 0.5``; a score of exactly 0.5 is a miss for either class.

    Args:
        model: object exposing ``config.num_hidden_layers`` and ``config.num_attention_heads``.
        probes: dict mapping ``(layer_idx, head_idx)`` to a probe tensor or to the string ``"skip"``.
        val_tags_correct: dict mapping ``(layer_idx, head_idx)`` to a list of 1-D feature tensors ("correct" class).
        val_tags_incorrect: same mapping for the "incorrect" class.
        eps: value stored unchanged under ``"eps"`` in every head's stats.
        tags: value stored unchanged under ``"Probe Type"`` in every head's stats.

    Returns:
        Dict mapping ``(layer_idx, head_idx)`` to ``{"eps", "Probe Type", "Probe Accuracy"}``. Heads whose probe
        is ``"skip"`` and heads without any validation sample are omitted.
    """
    probe_stats = {}
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads
    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            key = (layer_idx, head_idx)
            probe = probes[key]
            # Test for the sentinel explicitly rather than comparing a tensor against a string.
            if isinstance(probe, str) and probe == "skip":
                continue

            correct_tags = val_tags_correct[key]
            incorrect_tags = val_tags_incorrect[key]
            total_samples = len(correct_tags) + len(incorrect_tags)
            if total_samples == 0:
                # No validation data for this head: an accuracy would be a division by zero.
                continue

            correct_samples = 0
            # torch.stack rejects an empty list, so each class is only scored when it has samples.
            if len(correct_tags) > 0:
                pred_correct = torch.sigmoid(torch.stack(correct_tags, dim=0) @ probe)      # n_samples
                correct_samples += int(torch.sum(pred_correct > 0.5).item())
            if len(incorrect_tags) > 0:
                pred_incorrect = torch.sigmoid(torch.stack(incorrect_tags, dim=0) @ probe)  # n_samples
                correct_samples += int(torch.sum(pred_incorrect < 0.5).item())

            probe_stats[key] = {
                "eps": eps,
                "Probe Type": tags,
                "Probe Accuracy": correct_samples / total_samples,
            }

    return probe_stats

def gen_histogram(model, correct_hist, incorrect_hist, plot_path):
    """
    Save one PNG per layer showing, for every head, a bar chart of summed sink counts.

    Args:
        model: object exposing ``config.num_hidden_layers`` and ``config.num_attention_heads``.
        correct_hist: dict mapping ``(layer_idx, head_idx)`` to ``[labels, counts]``.
        incorrect_hist: dict with the same layout; its counts are added element-wise to those of ``correct_hist``.
        plot_path: output directory (created if missing); files are named ``sink_hist_layer_<layer_idx>.png``.

    Returns:
        None. A head is only drawn when both histograms have at least one label; bar labels come from
        ``correct_hist``.
    """
    os.makedirs(plot_path, exist_ok=True)
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads

    # Lay the heads out on a near-square grid.
    grid_cols = math.ceil(math.sqrt(num_heads))
    grid_rows = math.ceil(num_heads / grid_cols)

    for layer_idx in range(num_layers):
        # squeeze=False keeps `axes` an array even for a 1x1 grid, where subplots would otherwise
        # return a bare Axes object that has no flatten().
        fig, axes = plt.subplots(
            grid_rows, grid_cols, figsize=(grid_cols * 5, grid_rows * 2.5), squeeze=False
        )
        axes = axes.flatten()

        for head_idx in range(num_heads):
            key = (layer_idx, head_idx)
            c_labels = correct_hist[key][0]
            ic_labels = incorrect_hist[key][0]

            if len(c_labels) > 0 and len(ic_labels) > 0:
                # Bar categories must be unique, so repeated labels get a numeric suffix (_1, _2, ...).
                unique_labels = []
                counts = {}
                for label in c_labels:
                    if label in counts:
                        counts[label] += 1
                        unique_labels.append(f"{label}_{counts[label]}")
                    else:
                        counts[label] = 0
                        unique_labels.append(label)

                # zip stops at the shorter of the two count lists.
                sum_sinks = [a + b for a, b in zip(correct_hist[key][1], incorrect_hist[key][1])]

                ax = axes[head_idx]
                ax.bar(unique_labels, sum_sinks)
                ax.set_title(f"Head {head_idx}", fontsize=10)
                ax.tick_params(axis='x', labelrotation=45, labelsize=7)

        # Hide the grid cells that do not correspond to any head.
        for j in range(num_heads, len(axes)):
            axes[j].axis('off')

        fig.suptitle(f"Layer {layer_idx}", fontsize=14)
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        filepath = os.path.join(plot_path, "sink_hist_layer_" + str(layer_idx) + ".png")
        # Save and close this specific figure rather than whichever figure is current.
        fig.savefig(filepath)
        plt.close(fig)

    return