import torch
from tqdm import tqdm
import numpy as np
from baukit import TraceDict
from utils import evaluate_utils
from utils.logging_utils import logger


def get_mean_head_activation(model_wrapper, dataset):
    """
    Get the mean per-head attention activations over a dataset, pooled into logical (word-level) slots.

    Each prompt is a list of words that is tokenized with ``is_split_into_words=True``. The sub-token
    activations belonging to the same word are averaged into a single slot. When
    ``model_config['prepend_bos']`` is set, slot 0 holds the first token's (<bos>) activation and word
    ``w`` is stored in slot ``w + 1``; otherwise word ``w`` is stored in slot ``w``.

    :param model_wrapper: the model wrapper (provides ``model_config``, ``tokenizer``, ``model``, ``device``)
    :param dataset: untokenized dataset; ``dataset[n]['prompt']`` is a list of words. The number of slots
        is taken from ``dataset[0]``, so all prompts are assumed to have the same number of words
        (TODO confirm with callers).
    :return: mean_activations Shape: (n_layers, n_heads, logic_seq_len, head_dim), a CPU tensor of the
        default dtype (the storage below is allocated with ``torch.zeros`` and no device/dtype).
    """

    # get model config
    model_config = model_wrapper.model_config
    hook_names = model_config['attn_hook_names']
    prepend_bos_flag = model_config['prepend_bos']

    def split_activations_by_head(activations, model_config):
        # (batch_size, seq_len, n_heads * head_dim) -> (batch_size, seq_len, n_heads, head_dim).
        # head_dim is read from the config: some models' head_dim != hidden_size // n_heads.
        new_shape = activations.size()[:-1] + (model_config['n_heads'], model_config['head_dim'])
        return activations.view(*new_shape)

    def belongs_to_slot(word_id, slot):
        # Special tokens (e.g. a trailing <eos>) have word_id None and never belong to a word slot;
        # without this guard ``None + 1`` raises TypeError.
        if word_id is None:
            return False
        return (word_id + 1 == slot) if prepend_bos_flag else (word_id == slot)

    # One logical slot per word, plus a leading <bos> slot when prepend_bos is set
    logic_seq_len = len(dataset[0]['prompt']) + (1 if prepend_bos_flag else 0)

    # allocate activation storage
    activation_storage = torch.zeros(
        len(dataset),
        model_config['n_layers'],
        model_config['n_heads'],
        logic_seq_len,
        model_config['head_dim']
    )  # Shape: (n_examples, n_layers, n_heads, logic_seq_len, head_dim)

    for n in tqdm(range(len(dataset)), desc="Extracting mean activations", leave=False):
        activations_td, word_idx = extract_attn_activations(
            tokens=dataset[n]['prompt'],
            layers=hook_names,
            model_wrapper=model_wrapper
        )  # activations_td[layer].input Shape: (batch_size=1, seq_len, n_heads * head_dim)

        # Stack the layers (the batch dim of size 1 is merged by vstack) and put heads before tokens
        stack_initial = torch.vstack(
            [split_activations_by_head(activations_td[layer].input, model_config) for layer in hook_names]
        ).permute(0, 2, 1, 3)  # Shape: (n_layers, n_heads, seq_len, head_dim)

        # Create a new tensor with logic slots to store the activations
        n_layers, n_heads, token_seq_len, head_dim = stack_initial.shape
        stack_filtered = torch.zeros((n_layers, n_heads, logic_seq_len, head_dim), device=stack_initial.device,
                                     dtype=stack_initial.dtype)

        i = 0  # cursor into the sub-token sequence

        for j in range(logic_seq_len):
            # <bos> slot 0 takes the first token's activation as-is
            if prepend_bos_flag and j == 0:
                stack_filtered[:, :, 0, :] = stack_initial[:, :, 0, :]
                i += 1
                continue

            # Consume the contiguous run of sub-tokens that belong to slot j
            start = i
            while i < token_seq_len and belongs_to_slot(word_idx[i], j):
                i += 1
            end = i

            if start == end:
                # No sub-token matched this slot: fall back to the token at the cursor.
                # NOTE(review): if the cursor is already past the last token this slice is empty and
                # the mean is NaN.
                segment = stack_initial[:, :, start:start + 1, :]
            else:
                segment = stack_initial[:, :, start:end, :]  # Shape: (n_layers, n_heads, end-start, head_dim)
            stack_filtered[:, :, j, :] = segment.mean(dim=2)  # Shape: (n_layers, n_heads, head_dim)

        # The n-th example activations
        activation_storage[n] = stack_filtered

    mean_activations = activation_storage.mean(dim=0)  # Shape: (n_layers, n_heads, logic_seq_len, head_dim)
    return mean_activations


def extract_attn_activations(tokens, layers, model_wrapper):
    """
    Tokenize one pre-split prompt, run a single gradient-free forward pass and record the
    inputs of the requested modules.

    :param tokens: the prompt as a list of words (tokenized with ``is_split_into_words=True``)
    :param layers: names of the modules to trace (callers in this file pass
        ``model_config['attn_hook_names']``)
    :param model_wrapper: wrapper exposing ``tokenizer``, ``device`` and ``model``
    :return: ``(td, word_idx)`` -- the baukit ``TraceDict`` keeping each traced module's input
        (outputs are not retained), and the per-sub-token word ids of batch element 0
        (``None`` for special tokens)
    """
    net = model_wrapper.model
    encoded = model_wrapper.tokenizer(tokens, is_split_into_words=True, return_tensors='pt')
    encoded = encoded.to(model_wrapper.device)
    word_idx = encoded.word_ids(batch_index=0)

    # The forward pass only exists to fire the tracing hooks, so no gradients are needed.
    with torch.no_grad(), TraceDict(net, layers=layers, retain_input=True, retain_output=False) as td:
        net(**encoded)

    return td, word_idx


def extract_ablated_attn_activations(tokens, model_wrapper):
    """
    Run the (clean) prompt through the model and read back the wrapper's per-head activations.

    The sub-token -> token-type indices are computed and handed to the wrapper via
    ``set_token_indices`` before the forward pass.

    :param tokens: the prompt as a list of words (tokenized with ``is_split_into_words=True``)
    :param model_wrapper: wrapper exposing ``tokenizer``, ``device``, ``model``,
        ``set_token_indices`` and ``attention_head_activations``
    :return: ``(attention_head_activations, word_idx, ablated_output)`` where
        ``attention_head_activations`` is read from the wrapper after the forward pass (presumably
        filled by its hooks -- not visible here), ``word_idx`` holds the word id of every sub-token
        of batch element 0, and ``ablated_output`` is the last-position logits, Shape: (batch_size, vocab_size)
    """
    net = model_wrapper.model
    encoded = model_wrapper.tokenizer(tokens, is_split_into_words=True, return_tensors='pt')
    encoded = encoded.to(model_wrapper.device)
    word_idx = encoded.word_ids(batch_index=0)

    # Token-type index per sub-token, Shape: (batch_size, seq_len)
    token_types = align_sub_tokens_with_token_types(model_wrapper, encoded.input_ids, word_idx)
    model_wrapper.set_token_indices(token_types)

    with torch.no_grad():
        logits = net(**encoded, use_cache=False).logits  # Shape: (batch_size, seq_len, vocab_size)
        last_position_logits = logits[:, -1, :]  # Shape: (batch_size, vocab_size)

    return model_wrapper.attention_head_activations, word_idx, last_position_logits


def get_function_vector(model_wrapper, mean_activations, mean_indirect_effect, top_k=10):
    """
        Computes a "function vector" that communicates the task observed in ICL examples, used for downstream intervention.

        The top-k heads by mean indirect effect are selected; for each, its mean activation at the last
        logical slot is placed in that head's block of an otherwise-zero ``n_heads * head_dim`` vector and
        projected through the layer's ``self_attn.o_proj``. The projections are summed.

        Parameters:
        model_wrapper: huggingface model wrapper (``model.model.layers[L].self_attn.o_proj`` is used, so a
            Llama-style module layout is assumed)
        mean_activations: the average activation of each head for a particular task. Shape: (layers, heads, logic_seq_len, head_dim)
        mean_indirect_effect: the mean indirect effect of each head across N trials. Shape: tensor of size (Layers, Heads)
        top_k: the number of heads to use when computing the function vector; values larger than the
            number of heads are clamped so that all heads are used

        Returns:
        function_vector: vector representing the communication of a particular task, cast to ``model.dtype``.
            Shape: (1, Hidden Size)
        top_heads: list of the top influential heads as tuples [(L, H, S), ...] sorted by decreasing S
            (L=Layer and H=Head as Python ints, S=indirect effect score rounded to 4 decimals)
    """
    model = model_wrapper.model
    model_config = model_wrapper.model_config
    hidden_size = model_config['hidden_size']  # residual stream width
    n_heads = model_config['n_heads']
    head_dim = model_config['head_dim']
    device = model_wrapper.device

    # Rank every (layer, head) pair by indirect effect. Clamp k so that asking for more heads than
    # exist returns all of them instead of making torch.topk raise.
    h_shape = tuple(mean_indirect_effect.shape)  # (Layers, Heads)
    k = min(top_k, mean_indirect_effect.numel())
    # reshape (not view) also works for non-contiguous inputs
    topk_vals, topk_inds = torch.topk(mean_indirect_effect.reshape(-1), k=k, largest=True)

    # Transform the flat indices back to layer and head indices
    layer_idxs, head_idxs = np.unravel_index(topk_inds.cpu().numpy(), h_shape)

    # Plain Python ints keep top_heads free of numpy scalars (e.g. for JSON serialization)
    top_heads = [
        (int(layer), int(head), round(score.item(), 4))
        for layer, head, score in zip(layer_idxs, head_idxs, topk_vals)
    ]

    # Compute Function Vector as sum of influential heads
    function_vector = torch.zeros((1, 1, hidden_size), device=device)  # Shape: (1, 1, hidden_size)
    T = -1  # Intervention & values taken from the last logical slot

    # The FV is a fixed vector, so no autograd graph through o_proj is needed.
    with torch.no_grad():
        # W_O Shape: (hidden_size, head_dim * n_heads); the module path is architecture specific
        for L, H, _ in top_heads:
            out_proj = model.model.layers[L].self_attn.o_proj

            # Only head H's block is non-zero: its mean activation at the last slot
            x = torch.zeros(n_heads * head_dim, device=device, dtype=model.dtype)
            x[H * head_dim:(H + 1) * head_dim] = mean_activations[L, H, T]

            # Project the head activation to the residual stream: y = x @ W_O.T (+ bias).
            # NOTE: if o_proj has a bias term, it is added once per selected head.
            d_out = out_proj(x.reshape(1, 1, n_heads * head_dim))
            function_vector += d_out

    function_vector = function_vector.to(model.dtype).reshape(1, hidden_size)  # Shape: (1, hidden_size)

    return function_vector, top_heads


def get_ablated_mean_head_activation(model_wrapper, training_dataset, corrupted_dataset, sub_FV_name=None):
    """
    Get the edge-ablated mean per-head attention activations over a dataset, pooled into logical
    (word-level) slots.

    For every example: a fresh dict is appended to ``model_wrapper.attention_scores``; the original
    activations are recorded from the clean prompt when ``sub_FV_name`` contains ``'k_all'`` or
    ``'v_all'``; the wrapper is run on the corrupted prompt; then the clean prompt is run and the
    wrapper's per-head activations are read back and pooled per word.

    :param model_wrapper: the model wrapper (provides ``model_config``, ``attention_scores`` and the
        ablation machinery used by the helper functions it is passed to)
    :param training_dataset: the clean dataset where to extract the mean head activations;
        ``training_dataset[n]['prompt']`` is a list of words
    :param corrupted_dataset: the corrupted dataset, aligned example-by-example with ``training_dataset``
    :param sub_FV_name: current ablation setting name; ``None`` (the default) skips recording the
        original activations
    :return: ``(mean_activations, activation_storage, ablate_success_rate)`` where
        mean_activations Shape: (n_layers, n_heads, logic_seq_len, head_dim),
        activation_storage Shape: (n_examples, n_layers, n_heads, logic_seq_len, head_dim), and
        ablate_success_rate is currently always 0.0 (evaluation is disabled)
    """

    # get model config
    model_config = model_wrapper.model_config
    prepend_bos_flag = model_config['prepend_bos']

    def split_activations_by_head(activations, model_config):
        # (batch_size, seq_len, n_heads * head_dim) -> (batch_size, seq_len, n_heads, head_dim).
        # head_dim is read from the config: some models' head_dim != hidden_size // n_heads.
        new_shape = activations.size()[:-1] + (model_config['n_heads'], model_config['head_dim'])
        return activations.view(*new_shape)

    def belongs_to_slot(word_id, slot):
        # Special tokens (e.g. a trailing <eos>) have word_id None and never belong to a word slot;
        # without this guard ``None + 1`` raises TypeError.
        if word_id is None:
            return False
        return (word_id + 1 == slot) if prepend_bos_flag else (word_id == slot)

    # One logical slot per word, plus a leading <bos> slot when prepend_bos is set
    logic_seq_len = len(training_dataset[0]['prompt']) + (1 if prepend_bos_flag else 0)

    # allocate activation storage
    activation_storage = torch.zeros(
        len(training_dataset),
        model_config['n_layers'],
        model_config['n_heads'],
        logic_seq_len,
        model_config['head_dim']
    )  # Shape: (n_examples, n_layers, n_heads, logic_seq_len, head_dim)

    # Loop-invariant. The None guard keeps the default sub_FV_name from raising TypeError on `in`.
    record_original = sub_FV_name is not None and ('k_all' in sub_FV_name or 'v_all' in sub_FV_name)

    ablated_output_list = []

    for n in tqdm(range(len(training_dataset)), desc="Extracting mean activations", leave=False):
        model_wrapper.attention_scores.append(dict())

        # Store the original activations if needed
        if record_original:
            update_original_activations(model_wrapper, training_dataset[n]['prompt'])

        # Update the model wrapper with the corrupted activations
        update_corrupted_activations(model_wrapper, corrupted_dataset[n]['prompt'])

        # The corrupted prompt and the clean prompt have the same sequence length and token type construction,
        # thus there is no need to reset the token indices
        activations_td, word_idx, ablated_output = extract_ablated_attn_activations(
            tokens=training_dataset[n]['prompt'],
            model_wrapper=model_wrapper
        )  # Shape: (n_layers, batch_size=1, seq_len, n_heads * head_dim)

        ablated_output_list.append(ablated_output.cpu())

        # Stack the layers (the batch dim of size 1 is merged by vstack) and put heads before tokens
        stack_initial = torch.vstack(
            [split_activations_by_head(activations_td[layer], model_config)
             for layer in range(model_config['n_layers'])]
        ).permute(0, 2, 1, 3)  # Shape: (n_layers, n_heads, seq_len, head_dim)

        # Create a new tensor with logic slots to store the activations
        n_layers, n_heads, token_seq_len, head_dim = stack_initial.shape
        stack_filtered = torch.zeros((n_layers, n_heads, logic_seq_len, head_dim), device=stack_initial.device,
                                     dtype=stack_initial.dtype)

        i = 0  # cursor into the sub-token sequence

        for j in range(logic_seq_len):
            # <bos> slot 0 takes the first token's activation as-is
            if prepend_bos_flag and j == 0:
                stack_filtered[:, :, 0, :] = stack_initial[:, :, 0, :]
                i += 1
                continue

            # Consume the contiguous run of sub-tokens that belong to slot j
            start = i
            while i < token_seq_len and belongs_to_slot(word_idx[i], j):
                i += 1
            end = i

            if start == end:
                # No sub-token matched this slot: fall back to the token at the cursor.
                # NOTE(review): if the cursor is already past the last token this slice is empty and
                # the mean is NaN.
                segment = stack_initial[:, :, start:start + 1, :]
            else:
                segment = stack_initial[:, :, start:end, :]  # Shape: (n_layers, n_heads, end-start, head_dim)
            stack_filtered[:, :, j, :] = segment.mean(dim=2)  # Shape: (n_layers, n_heads, head_dim)

        # The n-th example activations
        activation_storage[n] = stack_filtered

    # TODO: ablated-output evaluation is disabled. The intended call is
    # evaluate_utils.evaluate_ablated_output_results(model_wrapper, training_dataset, ablated_output_list).
    ablate_success_rate = 0.0

    mean_activations = activation_storage.mean(dim=0)  # Shape: (n_layers, n_heads, logic_seq_len, head_dim)
    return mean_activations, activation_storage, ablate_success_rate


def update_corrupted_activations(model_wrapper, tokens):
    """
    Run the corrupted prompt through the model with the wrapper in corrupted-run mode.

    ``model_wrapper.is_corrupted_run`` is set to True for the duration of the pass, presumably so
    that the wrapper's hooks store the corrupted activations (not visible here -- confirm against
    the wrapper). The flag is reset in a ``finally`` block, so a failure in tokenization or in the
    forward pass cannot leave the wrapper stuck in corrupted-run mode.

    :param model_wrapper: the model wrapper
    :param tokens: the tokens of a corrupted prompt (a list of words)
    :return: None
    """
    tokenizer = model_wrapper.tokenizer
    device = model_wrapper.device
    model = model_wrapper.model

    # Switch to corrupted run mode to get the corrupted activations
    model_wrapper.is_corrupted_run = True
    try:
        inputs = tokenizer(tokens, is_split_into_words=True, return_tensors='pt').to(device)
        word_idx = inputs.word_ids(batch_index=0)  # Shape: (seq_len,)

        tp_inds = align_sub_tokens_with_token_types(model_wrapper, inputs.input_ids,
                                                    word_idx)  # Shape: (batch_size, seq_len)
        model_wrapper.set_token_indices(tp_inds)

        with torch.no_grad():
            model(**inputs, use_cache=False)
    finally:
        # Always reset the corrupted run flag, even if the forward pass raised
        model_wrapper.is_corrupted_run = False


def update_original_activations(model_wrapper, tokens):
    """
    Run the base (clean) prompt through the model with the wrapper in original-run mode.

    ``model_wrapper.is_original_run`` is set to True for the duration of the pass, presumably so
    that the wrapper's hooks store the original activations (not visible here -- confirm against
    the wrapper). The flag is reset in a ``finally`` block, so a failure in tokenization or in the
    forward pass cannot leave the wrapper stuck in original-run mode.

    :param model_wrapper: the model wrapper
    :param tokens: the tokens of a base prompt (a list of words)
    :return: None
    """
    tokenizer = model_wrapper.tokenizer
    device = model_wrapper.device
    model = model_wrapper.model

    # Switch to original run mode to get the original activations
    model_wrapper.is_original_run = True
    try:
        inputs = tokenizer(tokens, is_split_into_words=True, return_tensors='pt').to(device)
        word_idx = inputs.word_ids(batch_index=0)  # Shape: (seq_len,)

        tp_inds = align_sub_tokens_with_token_types(model_wrapper, inputs.input_ids,
                                                    word_idx)  # Shape: (batch_size, seq_len)
        model_wrapper.set_token_indices(tp_inds)

        with torch.no_grad():
            model(**inputs, use_cache=False)
    finally:
        # Always reset the original run flag, even if the forward pass raised
        model_wrapper.is_original_run = False


def align_sub_tokens_with_token_types(model_wrapper, input_ids, word_idx):
    """
    Align the sub-tokens with the token types in the model wrapper.

    Sub-token ``i`` with word id ``w`` gets ``model_wrapper.token_type_map[w + 1]``. The ``+ 1``
    offset presumably skips a leading <bos> entry of the map (confirm against where the map is
    built). Special tokens (word id ``None``), negative word ids, and word ids whose shifted index
    falls outside the map get type 0.

    :param model_wrapper: wrapper exposing ``token_type_map`` (indexable, supports ``len``)
    :param input_ids: the input ids of the tokens Shape: (batch_size, seq_len); only batch row 0 is filled
    :param word_idx: the word indices of the tokens of batch element 0 Shape: (seq_len,)
    :return: tp_inds: int32 token type indices aligned with the sub-tokens Shape: (batch_size, seq_len)
    """
    tp_inds = torch.zeros_like(input_ids, dtype=torch.int, device=input_ids.device)
    token_type_map = model_wrapper.token_type_map  # The mapping from (shifted) word_id to token type
    n_types = len(token_type_map)

    for i, word_id in enumerate(word_idx):
        if word_id is None:
            # <bos> and other special tokens keep type 0
            continue
        map_idx = word_id + 1
        # Bounds-check the index that is actually used. The previous check tested word_id
        # instead, so the last word id raised IndexError.
        if word_id >= 0 and map_idx < n_types:
            tp_inds[0, i] = token_type_map[map_idx]
        # otherwise out of range: leave type 0

    return tp_inds