import torch
import einops
import re
import numpy as np
from collections import defaultdict


def attribution_patching(model, components, example_batch, metric_fn):
    """Approximate per-node activation-patching effects with gradients.

    For each hook ``blocks.{layer}.{component}`` the effect of replacing the
    clean activation by the counterfactual ("patch") activation is estimated to
    first order as the dot product, over the last dimension, of
    ``patch_act - clean_act`` with the average of the ``metric_fn`` gradients
    taken on the clean run and on the patched run.

    Args:
        model: TransformerLens-style model. This function uses ``cfg.n_layers``,
            ``tokenizer.encode``, ``run_with_cache``, ``hooks(bwd_hooks=...)``
            and calling the model on a token batch.
        components: Hook-name suffixes such as ``"hook_resid_pre"``. The full
            hook name ``f"blocks.{layer}.{component}"`` must match a cache key
            exactly.
        example_batch: Sequence of dicts with ``"example"``,
            ``"counterexample"`` and ``"counter-attribute"`` keys. All prompts
            must tokenize to the same length (``torch.cat`` raises otherwise).
        metric_fn: Maps the model output to a scalar tensor to differentiate.

    Returns:
        One ``{"component": ..., "scores": ...}`` dict per component.
        ``scores`` is a ``defaultdict`` mapping each counter-attribute to a
        float32 tensor of shape ``(pos, layer, head)`` holding the mean effect
        over the examples that carry that attribute. Activations without a head
        axis get a head dimension of size 1.

    Raises:
        ValueError: if a matched activation is neither 3-D nor 4-D.

    Note:
        The two ``backward()`` calls also accumulate into any parameter
        ``.grad`` of the model.
    """
    clean_tokens_batch = torch.cat(
        [model.tokenizer.encode(example["example"], return_tensors="pt") for example in example_batch],
        dim=0,
    )
    patch_tokens_batch = torch.cat(
        [model.tokenizer.encode(example["counterexample"], return_tensors="pt") for example in example_batch],
        dim=0,
    )

    _, clean_cache = model.run_with_cache(clean_tokens_batch)
    _, patch_cache = model.run_with_cache(patch_tokens_batch)

    clean_grad_cache = {}
    patch_grad_cache = {}

    def grad_cache_hook(cache):
        # Backward hook attached to every hook point: stores the incoming
        # gradient under the hook's name. Keep the parameter name ``hook``;
        # TransformerLens-style hook points may pass it by keyword.
        def store_grad(grad, hook):
            cache[hook.name] = grad

        return (lambda _: True, store_grad)

    with model.hooks(bwd_hooks=[grad_cache_hook(clean_grad_cache)]):
        metric_fn(model(clean_tokens_batch)).backward()
    with model.hooks(bwd_hooks=[grad_cache_hook(patch_grad_cache)]):
        metric_fn(model(patch_tokens_batch)).backward()

    num_layers = model.cfg.n_layers
    token_length = clean_tokens_batch.shape[1]
    cached_names = set(clean_cache.keys())

    # Group example indices by counter-attribute. The group size is the divisor
    # of each per-attribute mean: one count per example, not per (layer, example).
    indices_by_attribute = defaultdict(list)
    for batch_idx, example in enumerate(example_batch):
        indices_by_attribute[example["counter-attribute"]].append(batch_idx)

    results = []
    for component in components:
        act_shape = clean_cache[f"blocks.0.{component}"].shape
        # Read the head count from the activation itself (robust when it differs
        # from cfg.n_heads); head-less activations get a singleton head axis.
        n_heads = 1 if len(act_shape) == 3 else act_shape[2]

        # Bind n_heads as a default argument so this factory keeps the current
        # component's head count after the loop moves on.
        scores = defaultdict(lambda n=n_heads: torch.zeros(token_length, num_layers, n))

        for layer_idx in range(num_layers):
            hook_name = f"blocks.{layer_idx}.{component}"
            # Exact match: a substring test would also select hooks whose names
            # merely contain the component (e.g. "mlp.hook_pre_linear").
            if hook_name not in cached_names:
                continue

            clean_act = clean_cache[hook_name].to("cpu").float()
            patch_act = patch_cache[hook_name].to("cpu").float()
            clean_grad = clean_grad_cache[hook_name].to("cpu").float()
            patch_grad = patch_grad_cache[hook_name].to("cpu").float()
            if clean_grad.ndim not in (3, 4):
                raise ValueError(
                    f"{hook_name}: expected a 3-D or 4-D activation, got shape {tuple(clean_grad.shape)}"
                )

            # Dot product over the last dim -> (batch, pos) or (batch, pos, head).
            avg_grad = 0.5 * clean_grad + 0.5 * patch_grad
            effect = (avg_grad * (patch_act - clean_act)).sum(dim=-1)
            if effect.ndim == 2:
                effect = effect.unsqueeze(-1)

            for attribute, indices in indices_by_attribute.items():
                scores[attribute][:, layer_idx, :] += effect[indices].sum(dim=0)

        # Turn per-attribute sums into means over that attribute's examples.
        for attribute in list(scores):
            scores[attribute] = scores[attribute] / len(indices_by_attribute[attribute])

        results.append({
            "component": component,
            "scores": scores,
        })

    return results


def edge_attribution_patching(model, upstream_components, downstream_components, example_batch, metric_fn):
    """Approximate edge-level patching effects (edge attribution patching).

    For an upstream hook ``blocks.{u}.{upstream_component}`` and a downstream
    hook ``blocks.{d}.{downstream_component}``, the score of the edge from
    source position ``s`` to target position ``t`` is the dot product of the
    upstream activation difference ``patch - clean`` at ``s`` with the
    downstream gradient at ``t``. The gradient is averaged between the clean and
    patched runs. Only ``s <= t`` is scored; other cells stay zero.

    Args:
        model: TransformerLens-style model. This function uses ``cfg.n_layers``,
            ``tokenizer.encode``, ``run_with_cache``, ``hooks(bwd_hooks=...)``
            and calling the model on a token batch.
        upstream_components: Hook-name suffixes for edge sources.
        downstream_components: Hook-name suffixes for edge targets. Upstream
            and downstream activations must share their last dimension.
        example_batch: Sequence of dicts with ``"example"``,
            ``"counterexample"`` and ``"counter-attribute"`` keys. Other keys
            (e.g. ``"value-positions"``) are ignored. All prompts must tokenize
            to the same length.
        metric_fn: Maps the model output to a scalar tensor to differentiate.

    Returns:
        One dict per (upstream, downstream) pair with keys
        ``"upstream-component"``, ``"downstream-component"`` and ``"scores"``.
        ``scores`` maps each counter-attribute to a float32 tensor of shape
        ``(layer, pos, upstream_head, layer, pos, downstream_head)`` holding the
        mean effect over that attribute's examples. Layer pairs that are not
        scored stay zero.

    Note:
        The two ``backward()`` calls also accumulate into any parameter
        ``.grad`` of the model.
    """
    clean_tokens_batch = torch.cat(
        [model.tokenizer.encode(example["example"], return_tensors="pt") for example in example_batch],
        dim=0,
    )
    patch_tokens_batch = torch.cat(
        [model.tokenizer.encode(example["counterexample"], return_tensors="pt") for example in example_batch],
        dim=0,
    )

    _, clean_cache = model.run_with_cache(clean_tokens_batch)
    _, patch_cache = model.run_with_cache(patch_tokens_batch)

    clean_grad_cache = {}
    patch_grad_cache = {}

    def grad_cache_hook(cache):
        # Backward hook attached to every hook point: stores the incoming
        # gradient under the hook's name. Keep the parameter name ``hook``;
        # TransformerLens-style hook points may pass it by keyword.
        def store_grad(grad, hook):
            cache[hook.name] = grad

        return (lambda _: True, store_grad)

    with model.hooks(bwd_hooks=[grad_cache_hook(clean_grad_cache)]):
        metric_fn(model(clean_tokens_batch)).backward()
    with model.hooks(bwd_hooks=[grad_cache_hook(patch_grad_cache)]):
        metric_fn(model(patch_tokens_batch)).backward()

    num_layers = model.cfg.n_layers
    token_length = clean_tokens_batch.shape[1]

    # Group example indices by counter-attribute; the group size is the divisor
    # for the per-attribute mean.
    indices_by_attribute = defaultdict(list)
    for batch_idx, example in enumerate(example_batch):
        indices_by_attribute[example["counter-attribute"]].append(batch_idx)

    # causal[s, :, t, :] is True when source position s <= target position t;
    # shaped to broadcast over (pos1, heads1, pos2, heads2).
    causal = torch.ones(token_length, token_length).triu().bool()[:, None, :, None]

    results = []

    for upstream_component in upstream_components:
        for downstream_component in downstream_components:
            upstream_act_shape = clean_cache[f"blocks.0.{upstream_component}"].shape
            downstream_act_shape = clean_cache[f"blocks.0.{downstream_component}"].shape
            upstream_n_heads = 1 if len(upstream_act_shape) == 3 else upstream_act_shape[2]
            downstream_n_heads = 1 if len(downstream_act_shape) == 3 else downstream_act_shape[2]

            # Bind head counts as defaults so the factory keeps this pair's
            # shape after the loops move on.
            scores = defaultdict(
                lambda up=upstream_n_heads, down=downstream_n_heads: torch.zeros(
                    num_layers, token_length, up, num_layers, token_length, down
                )
            )

            # Same-layer pairs (offset 0) are scored when the upstream is
            # resid_pre/resid_mid and the downstream is an MLP hook, or when the
            # upstream is resid_pre and the downstream a q/k/v input. Test the
            # current component string, not membership in the component list.
            # NOTE(review): an attention output at layer L also feeds the MLP at
            # layer L, yet such pairs still use offset 1 here — confirm intent.
            feeds_mlp = "mlp" in downstream_component and (
                "resid_pre" in upstream_component or "resid_mid" in upstream_component
            )
            feeds_qkv = "resid_pre" in upstream_component and any(
                name in downstream_component for name in ("q_input", "k_input", "v_input")
            )
            offset = 0 if (feeds_mlp or feeds_qkv) else 1

            for upstream_layer_id in range(num_layers):
                if upstream_layer_id + offset >= num_layers:
                    continue
                upstream_hook_point = f"blocks.{upstream_layer_id}.{upstream_component}"
                delta = (
                    patch_cache[upstream_hook_point].to("cpu").float()
                    - clean_cache[upstream_hook_point].to("cpu").float()
                )
                if delta.ndim == 3:
                    delta = delta.unsqueeze(-2)

                for downstream_layer_id in range(upstream_layer_id + offset, num_layers):
                    downstream_hook_point = f"blocks.{downstream_layer_id}.{downstream_component}"
                    avg_grad = (
                        0.5 * clean_grad_cache[downstream_hook_point].to("cpu").float()
                        + 0.5 * patch_grad_cache[downstream_hook_point].to("cpu").float()
                    )
                    if avg_grad.ndim == 3:
                        avg_grad = avg_grad.unsqueeze(-2)

                    # (batch, pos1, heads1, pos2, heads2)
                    effect = torch.einsum("bpid,bqjd->bpiqj", delta, avg_grad)

                    for attribute, indices in indices_by_attribute.items():
                        mean_effect = effect[indices].sum(dim=0) / len(indices)
                        scores[attribute][upstream_layer_id, :, :, downstream_layer_id, :, :] += (
                            mean_effect.masked_fill(~causal, 0.0)
                        )

            results.append({
                "upstream-component": upstream_component,
                "downstream-component": downstream_component,
                "scores": scores,
            })

    return results


def eap_scores_to_graph(k, eap_scores):
    """Return the ``k`` strongest edges across all component pairs as strings.

    Args:
        k: Maximum number of edges to return; ``k <= 0`` returns ``[]``.
        eap_scores: Mapping ``(upstream_component, downstream_component) ->
            scores`` where ``scores`` is a torch tensor or array-like of shape
            ``(layer, pos, upstream_head, layer, pos, downstream_head)``.

    Returns:
        Up to ``k`` strings of the form
        ``"blocks.{L}.pos.{p}.{upstream}[.head{h}] -- {score:.4f} --> blocks.{L}.pos.{p}.{downstream}[.head{h}]"``,
        ordered by decreasing absolute score. The ``.head{h}`` suffix is present
        only when that side has more than one head.
    """
    if k <= 0:
        return []

    # (abs score, edge string): rank on the exact value, not the 4-decimal text.
    ranked = []

    for (upstream_component, downstream_component), scores in eap_scores.items():
        if isinstance(scores, torch.Tensor):
            # float64 represents every torch float dtype exactly (numpy has no bfloat16).
            scores = scores.detach().cpu().to(torch.float64).numpy()
        scores = np.asarray(scores)
        upstream_num_heads = scores.shape[2]
        downstream_num_heads = scores.shape[5]

        abs_scores = np.abs(scores.ravel())
        # Select this pair's k largest |score| cells in linear time, then order them.
        if k < abs_scores.size:
            top_k_indices = np.argpartition(-abs_scores, k)[:k]
        else:
            top_k_indices = np.arange(abs_scores.size)
        top_k_indices = top_k_indices[np.argsort(-abs_scores[top_k_indices], kind="stable")]

        for flat_idx in top_k_indices:
            up_layer, src_pos, up_head, down_layer, tgt_pos, down_head = np.unravel_index(flat_idx, scores.shape)
            score = scores[up_layer, src_pos, up_head, down_layer, tgt_pos, down_head].item()

            upstream_string = f"blocks.{up_layer}.pos.{src_pos}.{upstream_component}"
            if upstream_num_heads != 1:
                upstream_string += f".head{up_head}"

            downstream_string = f"blocks.{down_layer}.pos.{tgt_pos}.{downstream_component}"
            if downstream_num_heads != 1:
                downstream_string += f".head{down_head}"

            ranked.append((abs(score), upstream_string + f" -- {score:.4f} --> " + downstream_string))

    # list.sort is stable (also with reverse=True), so exact ties keep insertion order.
    ranked.sort(key=lambda item: item[0], reverse=True)
    return [edge for _, edge in ranked[:k]]



def ap_scores_to_graph(k, component, scores):
    """Return the ``k`` strongest non-zero node scores as readable strings.

    Args:
        k: Number of entries to keep (applied as the slice ``[:k]``).
        component: Hook-name suffix used in the output strings.
        scores: Tensor or array indexed ``[pos, layer, head]``.

    Returns:
        Strings ``"blocks.{layer}.{component}.pos.{pos}[.head{h}] -- {score:.4f}"``
        for every non-zero score, ordered by decreasing absolute score and
        truncated to ``k``. The ``.head{h}`` suffix appears only when there is
        more than one head.
    """
    token_length, num_layers, n_heads = scores.shape
    # One bulk conversion instead of a per-element .item() call.
    values = scores.tolist()

    ranked = []  # (abs score, string): rank on the exact value, not the rounded text
    for layer_idx in range(num_layers):
        for token_idx in range(token_length):
            for head_idx in range(n_heads):
                score = values[token_idx][layer_idx][head_idx]
                if score == 0:  # only non-zero scores are reported
                    continue
                label = f"blocks.{layer_idx}.{component}.pos.{token_idx}"
                if n_heads != 1:
                    label += f".head{head_idx}"
                ranked.append((abs(score), f"{label} -- {score:.4f}"))

    # list.sort is stable (also with reverse=True), so exact ties keep loop order.
    ranked.sort(key=lambda item: item[0], reverse=True)
    return [text for _, text in ranked[:k]]