import torch
from tqdm import tqdm
from baukit import TraceDict
from utils.logging_utils import logger
from utils.model_utils import get_module
from bitsandbytes import functional as bnb_func


def replace_activations_and_evaluate_logits_on_each_head(prompt, answer, mean_activations, model_wrapper,
                                                         last_token_only=True,
                                                         logic_seq_len=None):
    """
    Replace the activations at each head of each layer and evaluate the change in answer probability.

    For every (layer, head) pair, the head's slice of the input to the hooked attention projection
    is overwritten with the matching slice of ``mean_activations`` and the model is re-run on the
    prompt. The indirect effect is the change in the next-token probability of the first token of
    ``answer``, relative to the clean (un-intervened) run.

    :param prompt: pre-split words; tokenized with ``is_split_into_words=True``.
    :param answer: answer string; only its first token (see ``get_answer_token_id``) is scored.
    :param mean_activations: Shape: (n_layers, n_heads, logic_seq_len, head_dim)
    :param model_wrapper: object exposing ``tokenizer``, ``device``, ``model_config`` and ``model``.
    :param last_token_only: if True, patch only the last token position; otherwise patch every position.
    :param logic_seq_len: currently unused; kept for backward compatibility.
    :return: indirect effect of each layer and head: tensor Shape: (n_layers, n_heads)
    """
    tokenizer = model_wrapper.tokenizer
    device = model_wrapper.device
    model_config = model_wrapper.model_config
    model = model_wrapper.model

    # Tokenize the prompt and get the word index of every token
    inputs = tokenizer(prompt, is_split_into_words=True, return_tensors='pt').to(device)
    word_idx = inputs.word_ids(batch_index=0)
    # Real (token-level) sequence length
    seq_len = len(word_idx)

    # Align word indices with the logic slots: a leading token with no word id (word id None)
    # maps to slot 0, and every following word index is shifted up by one.
    if word_idx[0] is None:
        word_idx = [0 if i == 0 else idx + 1 for i, idx in enumerate(word_idx)]

    n_layers = model_config['n_layers']
    n_heads = model_config['n_heads']
    indirect_effect_storage = torch.zeros(n_layers, n_heads, device=device)  # Shape: (n_layers, n_heads)

    # Token positions to patch are identical for every (layer, head) pair
    tokens_idx = [-1] if last_token_only else list(range(seq_len))

    # Get the first token_id of the answer tokens
    answer_token_id = get_answer_token_id(answer, tokenizer).to(device)  # Type: torch.LongTensor

    # All forward passes run without autograd. Otherwise each intervened run's graph is attached to
    # indirect_effect_storage via the in-place writes below, keeping every run's activations alive.
    with torch.no_grad():
        # Clean prompt probability baseline
        clean_logits = model(**inputs).logits[:, -1, :]  # Shape: (1, vocab_size)
        clean_probs = torch.softmax(clean_logits, dim=-1)  # Shape: (1, vocab_size)
        max_probs, max_token_id = clean_probs.max(dim=-1)  # Get the max token ID and its probability
        logger.info(f"{prompt}:{answer}\n Answer token ID: {answer_token_id}")
        logger.info(
            f"Clean prompt probabilities: {clean_probs.index_select(1, answer_token_id)}, max token ID: {max_token_id}, max probability: {max_probs}")

        # Replace the activations at each head of each layer
        for layer in tqdm(range(n_layers), desc='Replacing activations Layer', leave=False):
            head_hook_layer = [model_config['attn_hook_names'][layer]]

            for head in tqdm(range(n_heads), desc='Replacing activations Head', leave=False):
                intervention_locations = [(layer, head, idx) for idx in tokens_idx]

                # Hook that overwrites this head's activations with the mean activations
                intervention_fn = replace_activation_on_single_head(model_wrapper, intervention_locations,
                                                                    mean_activations, word_idx)

                # Edit the output of the projection layer during the forward pass
                with TraceDict(model, layers=head_hook_layer, edit_output=intervention_fn):
                    output = model(**inputs).logits[:, -1, :]  # Shape: (1, vocab_size)

                # Change in the answer token's probability caused by the intervention
                intervention_probs = torch.softmax(output, dim=-1)  # Shape: (1, vocab_size)
                max_probs, max_token_id = intervention_probs.max(dim=-1)
                logger.info(
                    f"{layer}:{head} - Intervention probabilities: {intervention_probs.index_select(1, answer_token_id)}, max token ID: {max_token_id}, max probability: {max_probs}")
                indirect_effect_storage[layer, head] = (
                    (intervention_probs - clean_probs).index_select(1, answer_token_id).squeeze()
                )

    return indirect_effect_storage


def replace_activation_on_single_head(model_wrapper, intervention_locations, mean_activations, word_idx):
    """
    Create a TraceDict ``edit_output`` hook that replaces per-head activations with mean activations.

    The hook reads the *input* of the hooked projection module, splits its last dimension into
    (n_heads, head_dim), overwrites the requested (token, head) slots of the last batch element, and
    recomputes the projection. The recomputation includes the module's bias when it has one.

    :param model_wrapper: object exposing ``model_config`` (with 'n_heads', 'head_dim') and ``model``.
    :param intervention_locations: list of (layer, head, token_idx) replacement locations.
    :param mean_activations: Shape: (n_layers, n_heads, logic_seq_len, head_dim)
    :param word_idx: maps a token position to its slot index along ``mean_activations``' third dim.
    :return: intervention function ``rep_act(output, layer_name, inputs)``.
    """
    edit_layers = {location[0] for location in intervention_locations}
    model_config = model_wrapper.model_config
    model = model_wrapper.model

    def rep_act(output, layer_name, inputs):
        # Layer number is parsed from the module name, e.g. "model.layers.<N>.<...>"
        current_layer = int(layer_name.split('.')[2])

        # Leave non-edited layers untouched
        if current_layer not in edit_layers:
            return output

        if isinstance(inputs, tuple):
            inputs = inputs[0]

        original_shape = inputs.shape  # Shape: (batch_size, seq_len, n_heads * head_dim)

        # Split the last dimension into (n_heads, head_dim)
        new_shape = inputs.size()[:-1] + (model_config['n_heads'], model_config['head_dim'])
        inputs = inputs.view(*new_shape)  # Shape: (batch_size, seq_len, n_heads, head_dim)

        # ======== Begin Replace ========
        for layer, head_n, token_idx in intervention_locations:
            if layer == current_layer:
                replacement = mean_activations[layer, head_n, word_idx[token_idx]]
                inputs[-1, token_idx, head_n] = replacement.to(device=inputs.device, dtype=inputs.dtype)
        # ======== Finish Replace ========

        inputs = inputs.view(*original_shape)  # Shape: (batch_size, seq_len, n_heads * head_dim)

        # Recompute the hooked projection on the patched input: y = x @ W_O.T (+ b).
        # F.linear is identical to the plain matmul when the module has no bias.
        # NOTE(review): assumes an nn.Linear-style weight of shape (hidden_size, n_heads * head_dim);
        # packed/quantized weights are not dequantized here.
        proj_module = get_module(model, layer_name)
        return torch.nn.functional.linear(inputs, proj_module.weight, getattr(proj_module, 'bias', None))

    return rep_act


def get_answer_token_id(answer, tokenizer):
    """
    Get the first token ID of the answer.

    The answer is tokenized without special tokens, and only its first token is kept.

    :param answer: str
    :param tokenizer: callable tokenizer returning an object with ``input_ids`` (batch-first).
    :return: target token ID: 1-element tensor of dtype torch.long
    :raises ValueError: if the answer tokenizes to no tokens (e.g. an empty string).
    """
    answer_tokens = tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids[0]

    if answer_tokens.numel() == 0:
        raise ValueError(f"Answer {answer!r} produced no tokens.")

    target_token_id = answer_tokens[0].item()
    return torch.tensor([target_token_id], dtype=torch.long)


def inject_function_vector(model_wrapper, prompt, function_vector, target_layer=None, inject_mode='single_step',
                           max_new_tokens=5):
    """
    Inject the function vector into the output of a model layer at the last token position.

    :param model_wrapper: model wrapper containing the model and its configuration.
    :param prompt: the input prompt to the model (pre-split words; tokenized with is_split_into_words=True).
    :param function_vector: the function vector to inject.
    :param target_layer: the layer at which to inject the function vector (required).
    :param inject_mode: the mode of injection, either 'single_step' or 'full_sequence'.
    :param max_new_tokens: number of tokens to generate in 'full_sequence' mode.

    :return single_step mode: (intervened, clean) next-token logits, each Shape: (batch_size, vocab_size).
    :return full_sequence mode: (generate() output with ``sequences`` and ``scores``, None).
    :raises ValueError: if target_layer is missing or inject_mode is unknown.
    """
    if target_layer is None:
        raise ValueError("Target layer must be specified for function vector injection.")
    if inject_mode not in ('single_step', 'full_sequence'):
        raise ValueError(f"Unknown inject_mode {inject_mode!r}; expected 'single_step' or 'full_sequence'.")

    model = model_wrapper.model
    model_config = model_wrapper.model_config
    device = model_wrapper.device
    tokenizer = model_wrapper.tokenizer

    # Ensure the function vector is on the correct device
    function_vector = function_vector.to(device)

    inputs = tokenizer(prompt, is_split_into_words=True, return_tensors='pt').to(device)

    # Hook that adds the function vector at the last token of the target layer's output
    intervention_fn = add_function_vector(target_layer, function_vector, idx=-1)
    layer_name = model_config['layer_hook_names'][target_layer]

    clean_output = None
    with torch.no_grad():
        if inject_mode == 'single_step':
            # Clean next-token logits baseline
            clean_output = model(**inputs).logits[:, -1, :]  # Shape: (batch_size, vocab_size)

        with TraceDict(model, layers=[layer_name], edit_output=intervention_fn):
            if inject_mode == 'single_step':
                # Next-token logits with the function vector injected
                intervened_output = model(**inputs).logits[:, -1, :]  # Shape: (batch_size, vocab_size)
            else:
                # intervened_output.scores: tuple of per-step logits, each Shape: (batch_size, vocab_size)
                # intervened_output.sequences: generated token IDs, Shape: (batch_size, seq_len)
                # NOTE(review): top_p/temperature only take effect if sampling is enabled (e.g. via the
                # model's generation config); with greedy decoding they are ignored.
                intervened_output = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.get('attention_mask'),
                    top_p=0.9,
                    temperature=0.1,
                    max_new_tokens=max_new_tokens,
                    return_dict_in_generate=True,
                    output_scores=True,
                )

    return intervened_output, clean_output


def add_function_vector(target_layer, function_vector, idx=-1):
    """
    Create a TraceDict ``edit_output`` hook that adds a vector to a layer's output at one token index.

    The hidden states are modified in place. They are taken from the first element when the layer
    returns a tuple, or directly when it returns a tensor. Any other output type is returned unchanged.

    :param target_layer: the target layer number to perform the FV intervention
    :param function_vector: the function vector to add as an intervention
    :param idx: the token index (along dim 1) to add the function vector
    :return add_act: a function specifying how to add a function vector to a layer's residual stream
    """

    def add_act(output, layer_name):
        # Layer number is parsed from the module name, e.g. "model.layers.<N>"
        current_layer = int(layer_name.split(".")[2])

        # Leave non-target layers untouched
        if current_layer != target_layer:
            return output

        # Some layer implementations return (hidden_states, ...) while others return the
        # hidden-states tensor directly; inject in both cases instead of silently skipping.
        if isinstance(output, tuple):
            hidden = output[0]
        elif isinstance(output, torch.Tensor):
            hidden = output
        else:
            return output

        hidden[:, idx] += function_vector.to(device=hidden.device, dtype=hidden.dtype)
        return output

    return add_act