from collections import defaultdict
import torch
from tqdm import tqdm
from utils.setup import init_metric, get_logit_ids
from utils.patching import edge_attribution_patching, attribution_patching
import gc
import numpy as np
import copy
from sklearn.linear_model import Lasso, LinearRegression
from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.cluster import DBSCAN
from numpy.linalg import norm
import random


def get_cross_sections(model, get_examples, correct, incorrect, upstream_components, downstream_components, k=10, n_cross_sections=10, position_offset=0, incl_eap=True, incl_ap=True, aggr=False):
    """Return the cross sections with the largest mean absolute attribution score.

    A cross section is an (upstream component, layer, position, head) ->
    (downstream component, layer, position, head) pair. For each of the ``k``
    batches ``get_examples(i)`` a metric is built with ``init_metric`` and
    scores are gathered from:

    * ``edge_attribution_patching`` (when ``incl_eap``), whose score tensors are
      indexed ``[up_layer, up_pos, up_head, down_layer, down_pos, down_head]``;
    * ``attribution_patching`` (when ``incl_ap``), whose score tensors are
      indexed ``[pos, layer, head]``; each node score is stored as the cross
      section ``(component, layer, pos, head) -> (component, layer + 1, pos, head)``.

    Upstream layers run over ``range(num_layers - 1)`` in both cases. A head
    index is recorded as ``None`` when the head axis has size 1.

    Scores are averaged per cross section and per score attribute (the keys of
    each ``scores`` dict), ranked by absolute value, and the top
    ``n_cross_sections // <number of attributes>`` of every attribute are
    returned.

    Args:
        position_offset: when ``aggr`` is False, upstream positions smaller
            than this are ignored.
        aggr: pool all score attributes under the single attribute ``"aggr"``.

    Returns:
        List of dicts with keys ``downstream-component``, ``downstream-layer``,
        ``downstream-position``, ``downstream-head``, ``upstream-component``,
        ``upstream-layer``, ``upstream-position``, ``upstream-head``, ``score``
        and ``score-attribute``.
    """
    # avg_scores[attr][key] -> per-batch scores of one cross section under one attribute.
    avg_scores = defaultdict(lambda: defaultdict(list))

    def keep_position(position):
        # NOTE(review): position_offset is only applied when aggr is False
        # (unchanged historical behaviour) -- confirm this is intended.
        return aggr or position >= position_offset

    def bucket(attr):
        return "aggr" if aggr else attr

    example_batches = [get_examples(i) for i in range(k)]

    for example_batch in tqdm(example_batches, desc="Get cross sections"):
        metric_fn = init_metric(model, example_batch, correct, incorrect)

        if incl_eap:
            eap_scores = edge_attribution_patching(model, upstream_components, downstream_components, example_batch, metric_fn)
            for idx in range(len(eap_scores)):
                entry = eap_scores[idx]
                upstream_component = entry["upstream-component"]
                downstream_component = entry["downstream-component"]
                scores = entry["scores"]
                attrs = list(scores.keys())

                _, _, upstream_num_heads, num_layers, token_length, downstream_num_heads = scores[attrs[0]].shape
                # Convert each score tensor once; a per-element .item() forces a
                # device sync on every call and dominated the runtime.
                score_lists = {attr: scores[attr].tolist() for attr in attrs}

                for up_layer in range(num_layers - 1):
                    for up_pos in range(token_length):
                        if not keep_position(up_pos):
                            continue
                        for up_head in range(upstream_num_heads):
                            up_head_id = None if upstream_num_heads == 1 else up_head
                            for down_layer in range(num_layers):
                                for down_pos in range(token_length):
                                    for down_head in range(downstream_num_heads):
                                        down_head_id = None if downstream_num_heads == 1 else down_head
                                        key = (upstream_component, up_layer, up_pos, up_head_id,
                                               downstream_component, down_layer, down_pos, down_head_id)
                                        for attr in attrs:
                                            score = score_lists[attr][up_layer][up_pos][up_head][down_layer][down_pos][down_head]
                                            avg_scores[bucket(attr)][key].append(score)

        if incl_ap:
            ap_scores = attribution_patching(model, upstream_components, example_batch, metric_fn)
            for idx in range(len(ap_scores)):
                entry = ap_scores[idx]
                component = entry["component"]
                scores = entry["scores"]
                attrs = list(scores.keys())

                token_length, num_layers, num_heads = scores[attrs[0]].shape
                score_lists = {attr: scores[attr].tolist() for attr in attrs}

                for layer in range(num_layers - 1):
                    for pos in range(token_length):
                        if not keep_position(pos):
                            continue
                        for head in range(num_heads):
                            head_id = None if num_heads == 1 else head
                            # Node scores are represented as a layer -> layer + 1 "edge".
                            key = (component, layer, pos, head_id, component, layer + 1, pos, head_id)
                            for attr in attrs:
                                avg_scores[bucket(attr)][key].append(score_lists[attr][pos][layer][head])

    results = {}
    for attr, per_key in avg_scores.items():
        rows = []
        for key, values in per_key.items():
            (upstream_component, upstream_layer, upstream_position, upstream_head,
             downstream_component, downstream_layer, downstream_position, downstream_head) = key
            rows.append({
                "downstream-component": downstream_component,
                "downstream-layer": downstream_layer,
                "downstream-position": downstream_position,
                "downstream-head": downstream_head,
                "upstream-component": upstream_component,
                "upstream-layer": upstream_layer,
                "upstream-position": upstream_position,
                "upstream-head": upstream_head,
                "score": sum(values) / len(values),
                "score-attribute": attr,
            })
        rows.sort(key=lambda row: abs(row["score"]), reverse=True)
        results[attr] = rows[:n_cross_sections]

    # Split the budget evenly across attributes; guard against no attributes at all.
    n_attr = len(avg_scores)
    per_attr = n_cross_sections // n_attr if n_attr else 0

    cross_sections = []
    for attr_rows in results.values():
        cross_sections.extend(attr_rows[:per_attr])

    return cross_sections


def get_feature_dicts(model, cross_sections, get_examples, k, features, method='mean', path=None, step=50):
    """Run ``get_features`` over ``cross_sections`` in chunks of ``step``.

    ``get_features`` (in this module) adds ``"feature-dict"`` and
    ``"mean-activation"`` to each cross-section dict in place, so results are
    available on ``cross_sections`` even when ``path`` is None. When ``path``
    is given, each chunk's result list is also written with ``torch.save`` to
    ``f"{path}_{chunk_index}"``.

    Returns:
        None.
    """
    for chunk_index, start in enumerate(range(0, len(cross_sections), step)):
        torch.cuda.empty_cache()
        gc.collect()

        # Smaller chunks keep peak memory bounded.
        chunk = cross_sections[start:start + step]
        chunk_features = get_features(model, chunk, get_examples, k, features, method)

        if path is not None:
            torch.save(chunk_features, f"{path}_{chunk_index}")

        # Release the chunk before computing the next one.
        del chunk_features
        torch.cuda.empty_cache()
        gc.collect()

def compute_mse_features(C, A, S_i, N_A, mean_act, order_values, features):
    """Fit feature vectors by least squares: U* = (C^T C)^+ C^T (A - mean_act).

    Args:
        C: sparse indicator mapping ``(attribute, example_index, value) -> 1``;
            missing keys (or any value other than 1) count as 0.
        A: mapping ``example_index -> activation`` for indices
            ``0 .. len(A) - 1``.
        S_i: candidate values for the non-"order" attributes.
        N_A: attribute names; the entry ``"order"`` is special-cased and uses
            ``order_values`` instead of ``S_i``.
        mean_act: activation subtracted from every row of ``A`` before fitting.
        order_values: values of the "order" attribute; they are never used as
            values of any other attribute.
        features: only non-order values contained here are returned.

    Returns:
        ``defaultdict(dict)`` with ``feature_dict[attr][value]`` set to the
        fitted vector (a row of U*).
    """
    num_examples = len(A)

    # Single source of truth for the design-matrix layout, shared by the fill
    # and read passes. (Previously the column count was derived from the
    # *lengths of the value strings*, mis-sizing the matrix.) Order values under
    # non-order attributes are omitted: those columns were always zero, and
    # zero columns do not affect the pseudo-inverse solution of the others.
    columns = []
    for attr in N_A:
        if attr == "order":
            columns.extend(("order", val) for val in order_values)
        else:
            columns.extend((attr, val) for val in S_i if val not in order_values)

    # C_matrix[n, col] = 1 iff example n has value `val` for attribute `attr`.
    C_matrix = torch.zeros((num_examples, len(columns)))
    for col, (attr, val) in enumerate(columns):
        for n_ex in range(num_examples):
            if C.get((attr, n_ex, val)) == 1:
                C_matrix[n_ex, col] = 1

    # (num_examples x d) matrix of centred activations.
    A_matrix = torch.stack([A[i] for i in range(num_examples)], dim=0) - mean_act

    # Closed-form least squares with a pseudo-inverse, tolerating rank deficiency.
    C_T_C = C_matrix.T @ C_matrix
    C_T_C_pseudo_inverse = torch.pinverse(C_T_C)
    U_star = C_T_C_pseudo_inverse @ C_matrix.T @ A_matrix

    feature_dict = defaultdict(dict)
    for col, (attr, val) in enumerate(columns):
        if attr == "order" or val in features:
            feature_dict[attr][val] = U_star[col]

    del C_matrix, A_matrix, C_T_C, C_T_C_pseudo_inverse, U_star
    torch.cuda.empty_cache()
    gc.collect()

    return feature_dict

def _update_running_mean(stats, x):
    """Fold ``x`` into ``stats`` (a ``{'mean': tensor, 'count': int}`` dict) as an incremental mean."""
    count = stats['count']
    stats['mean'] = stats['mean'] + (x - stats['mean']) / (count + 1)
    stats['count'] = count + 1


def get_features(model, cross_sections, get_examples, k, features, method='mean'):
    """Attach a feature dictionary and mean activation to every cross section.

    For each cross section, the activation at
    ``blocks.{upstream-layer}.{upstream-component}`` (indexed at
    ``upstream-position`` and, unless None, ``upstream-head``) is collected
    for every example of ``k`` batches. Batch ``i`` is requested as
    ``get_examples(i, attr_val_dict={attr: val})``, with ``attr`` drawn
    randomly from the template attributes and ``val`` from ``features``.
    Only the first occurrence of each attribute within an example is counted.

    method='mean': ``feature[attr][value]`` is the mean activation over
        examples having that attribute value, minus the overall mean activation.
    method='mse': features are the least-squares fit from
        ``compute_mse_features``.

    Each cross-section dict is updated in place with ``"feature-dict"`` and
    ``"mean-activation"``; the same dict objects are returned in a list.

    Raises:
        AssertionError: if ``method`` is not ``'mean'`` or ``'mse'``.
    """
    assert method in ['mean', 'mse'], "Method must be either 'mean' or 'mse'"

    d_model = model.cfg.d_model
    res_cs = []

    # Running per-(attribute, value) mean activation, one structure per cross section.
    feature_data = [
        defaultdict(lambda: defaultdict(lambda: {'mean': torch.zeros(d_model), 'count': 0}))
        for _ in cross_sections
    ]
    overall_mean_data = [{'mean': torch.zeros(d_model), 'count': 0} for _ in cross_sections]

    probe_batch = get_examples(0)
    template_attr = list(set(probe_batch[0]["attributes"]))
    attr_val_dicts = []
    for _ in range(k):
        attr = random.choice(template_attr)
        val = random.choice(features)
        attr_val_dicts.append({attr: val})

    # Fetch each batch once so the value vocabulary S_i is built from exactly
    # the examples that are processed below.
    batches = [get_examples(i, attr_val_dict=attr_val_dicts[i]) for i in range(k)]
    S_i = list(set(val for batch in batches for x in batch for val in x["values"] + [x["order"]]))
    N_A = template_attr + ["order"]
    order_values = list(set([x["order"] for i in range(k) for x in get_examples(i)]))

    batch_size = len(probe_batch)
    num_examples = k * batch_size

    # Sparse indicators keyed by (attribute, example_index, value); an absent
    # key means 0. compute_mse_features only ever looks entries up this way,
    # so no dense pre-fill is needed.
    C = [{} for _ in cross_sections]
    # Raw activations per example; only the MSE method reads them.
    A = [
        {n_ex: torch.zeros(d_model) for n_ex in range(num_examples)} if method == "mse" else {}
        for _ in cross_sections
    ]

    hook_points = [f"blocks.{cs['upstream-layer']}.{cs['upstream-component']}" for cs in cross_sections]
    needed_hooks = list(dict.fromkeys(hook_points))

    for i, batch in enumerate(tqdm(batches, desc="Processing batches")):
        token_batch = torch.cat(
            [model.tokenizer.encode(example["example"], return_tensors="pt", padding=True) for example in batch],
            dim=0,
        )

        with torch.no_grad():
            # Cache only the hook points that are actually read, and move just those to CPU.
            _, act_cache = model.run_with_cache(token_batch, names_filter=needed_hooks)
            act_cache_cpu = {name: act_cache[name].to("cpu") for name in needed_hooks}
        del act_cache

        for j, cross_section in enumerate(cross_sections):
            position = cross_section["upstream-position"]
            head = cross_section["upstream-head"]
            hook_acts = act_cache_cpu[hook_points[j]]
            batch_act = hook_acts[:, position] if head is None else hook_acts[:, position, head]

            for l, example in enumerate(batch):
                example_act = batch_act[l].detach().cpu()
                n_ex = i * batch_size + l

                if method == "mse":
                    # Clone: storing the view would keep the whole batch activation tensor alive.
                    A[j][n_ex] = example_act.clone()

                _update_running_mean(overall_mean_data[j], example_act)

                seen_attributes = set()
                for z in range(len(example["value-positions"])):
                    attribute = example["attributes"][z]
                    value = example["values"][z]
                    if attribute not in seen_attributes:
                        _update_running_mean(feature_data[j][attribute][value], example_act)
                        if method == "mse":
                            C[j][(attribute, n_ex, value)] = 1
                    seen_attributes.add(attribute)

                order = example["order"]
                _update_running_mean(feature_data[j]["order"][order], example_act)
                if method == "mse":
                    C[j][("order", n_ex, order)] = 1

                del example_act

            del batch_act, hook_acts

        del token_batch, act_cache_cpu
        torch.cuda.empty_cache()
        gc.collect()

    for i, cs in tqdm(enumerate(cross_sections), total=len(cross_sections), desc="Compute feature dicts"):
        mean_act = overall_mean_data[i]['mean']

        if method == 'mean':
            # NOTE(review): feature_data is a defaultdict, so an (attribute, value)
            # pair that was never observed yields zeros - mean_act here. Confirm
            # this is intended.
            feature_dict = defaultdict(dict)
            for attribute in N_A:
                for value in S_i:
                    is_order = attribute == "order"
                    if (is_order and value in order_values) or (not is_order and value in features):
                        feature_dict[attribute][value] = feature_data[i][attribute][value]['mean'] - mean_act
        else:
            # The MSE fit is centred on mean_act, so the mean is used here as well.
            feature_dict = compute_mse_features(C[i], A[i], S_i, N_A, mean_act, order_values, features)

        cs.update({
            "feature-dict": feature_dict,
            "mean-activation": mean_act,
        })
        res_cs.append(cs)

        del feature_dict, mean_act
        torch.cuda.empty_cache()
        gc.collect()

    del feature_data, batches, S_i, N_A, C, A, overall_mean_data, order_values, template_attr
    torch.cuda.empty_cache()
    gc.collect()

    return res_cs


def cluster_cross_sections(cross_sections, cluster_effect=True):
    """Group cross sections by score attribute and, optionally, effect direction.

    With ``cluster_effect`` the key is ``(score-attribute, "reduce")`` for
    ``score >= 0`` and ``(score-attribute, "increase")`` otherwise; without it
    the key is just the ``score-attribute`` value. Every group is ordered by
    decreasing ``abs(score)``.

    Returns:
        dict mapping cluster key -> list of cross-section dicts.
    """
    def cluster_key(cs):
        if not cluster_effect:
            return cs['score-attribute']
        # The score approximates L_corr - L_clean, hence the sign convention below.
        effect = "reduce" if cs["score"] >= 0 else "increase"
        return (cs['score-attribute'], effect)

    clusters = {}
    for cs in cross_sections:
        clusters.setdefault(cluster_key(cs), []).append(cs)

    for key in sorted(clusters):
        clusters[key].sort(key=lambda item: abs(item["score"]), reverse=True)

    return clusters


def _sweep_component_label(cs):
    """Return the ``get_full_resid_decomposition`` label of a cross section's upstream component."""
    if cs['upstream-head'] is not None:
        return f"L{cs['upstream-layer']}H{cs['upstream-head']}"
    return f"{cs['upstream-layer']}_mlp_out"


def _sweep_edit_hook(edit_dict, mask_dict):
    """Build a ``resid_pre`` hook that overwrites masked entries with precomputed edits."""
    def hook(act, hook):
        layer = int(hook.name.split(".")[1])
        mask = mask_dict.get(layer)
        if mask is not None:
            act[mask] = edit_dict[layer][mask].to(device=act.device, dtype=act.dtype)
        return act
    return hook


def _sweep_mean_ablation_effect(model, get_examples, correct, incorrect, k, cross_sections):
    """Return |mean clean logit diff - mean logit diff with ``cross_sections`` mean-ablated| over k batches."""
    # Several cross sections (e.g. edges sharing an upstream node) can map to the
    # same intervention; ablate each distinct one exactly once rather than
    # subtracting the contribution and adding the mean repeatedly.
    targets = {}
    for cs in cross_sections:
        key = (cs["downstream-layer"], _sweep_component_label(cs), cs["upstream-position"])
        targets.setdefault(key, cs["mean-activation"])
    target_layers = sorted({layer for layer, _, _ in targets})

    clean_results = []
    ablation_results = []

    for i in range(k):
        torch.cuda.empty_cache()
        gc.collect()

        example_batch = get_examples(i)
        batch_tokens = torch.cat([model.tokenizer.encode(example["example"], return_tensors="pt", padding=True) for example in example_batch], dim=0)

        with torch.no_grad():
            clean_logits, cache = model.run_with_cache(batch_tokens)

        # Precompute the edited resid_pre (and which entries to overwrite) only
        # for layers that are actually intervened on.
        resids = {}
        edit_masks = {}
        for layer in target_layers:
            decomposition, labels = cache.get_full_resid_decomposition(layer=layer, expand_neurons=False, return_labels=True)
            resid = cache[f"blocks.{layer}.hook_resid_pre"].detach().cpu().clone()
            mask = torch.zeros(batch_tokens.size(0), batch_tokens.size(1), model.cfg.d_model, dtype=torch.bool)
            layer_targets = [(label, position, mean_act) for (t_layer, label, position), mean_act in targets.items() if t_layer == layer]

            # Remove each upstream component's direct contribution...
            for label, position, _ in layer_targets:
                resid[:, position] -= decomposition[labels.index(label), :, position].cpu()
            # ...and substitute its mean activation (broadcast over the batch).
            for _, position, mean_act in layer_targets:
                resid[:, position] += mean_act
                mask[:, position] = True

            resids[layer] = resid
            edit_masks[layer] = mask
            del decomposition

        with torch.no_grad():
            with model.hooks(fwd_hooks=[(lambda name: "resid_pre" in name, _sweep_edit_hook(resids, edit_masks))]):
                ablation_logits = model(batch_tokens)

        correct_ids, incorrect_ids = get_logit_ids(model, example_batch, correct, incorrect)

        rows = torch.arange(clean_logits.size(0))
        clean_diff = clean_logits[rows, -1, correct_ids] - clean_logits[rows, -1, incorrect_ids]
        ablation_diff = ablation_logits[rows, -1, correct_ids] - ablation_logits[rows, -1, incorrect_ids]

        clean_results.extend(clean_diff.tolist())
        ablation_results.extend(ablation_diff.tolist())

        del batch_tokens, clean_logits, ablation_logits, resids, edit_masks, cache
        torch.cuda.empty_cache()
        gc.collect()

    avg_clean = sum(clean_results) / len(clean_results)
    avg_ablation = sum(ablation_results) / len(ablation_results)
    return abs(avg_clean - avg_ablation)


def compute_sweep_cluster(model, get_examples, correct, incorrect, k, cross_section_cluster, step=5):
    """Measure the mean-ablation effect of growing prefixes of a cross-section cluster.

    For subset sizes ``1, 1 + step, 1 + 2*step, ...`` plus the full cluster
    size, the first ``subset_size`` cross sections are mean-ablated. At the
    downstream layer's ``hook_resid_pre`` and the upstream position, the
    upstream component's direct contribution (from
    ``get_full_resid_decomposition``) is replaced by the cross section's
    ``"mean-activation"``. The score is
    ``|mean clean logit diff - mean ablated logit diff|`` over ``k`` batches,
    where the logit diff is correct minus incorrect at the last position.

    Returns:
        dict mapping ``"subset_<size>"`` to the score for that prefix size.
    """
    total = len(cross_section_cluster)
    subset_sizes = list(range(1, total + 1, step))
    # The sweep is meant to end at the full cluster, which range() skips
    # whenever (total - 1) is not a multiple of step.
    if subset_sizes and subset_sizes[-1] != total:
        subset_sizes.append(total)

    sweep_results = {}
    for subset_size in subset_sizes:
        cluster_subset = cross_section_cluster[:subset_size]
        sweep_results[f'subset_{subset_size}'] = _sweep_mean_ablation_effect(model, get_examples, correct, incorrect, k, cluster_subset)

    return sweep_results

def get_alphas(sae, resid, sub_activation, m, alpha_start=1.0, alpha_end=1e-5):
    """Choose ``m`` Lasso regularisation strengths for SAE-based reconstruction.

    For each sample ``i`` of ``sub_activation``, the dictionary is formed from
    the decoder rows of the SAE features that are active (non-zero code) for
    ``resid[j]`` under ``sae[j]``. Only rows whose |dot product| with the
    sample exceeds the mean |dot product| are kept; all active rows are used
    when none exceeds it. On a 50-point log grid from ``alpha_start`` to
    ``alpha_end``:

    * ``first_alpha`` is the first alpha (scanning down from ``alpha_start``)
      at which any Lasso coefficient is non-zero;
    * ``last_alpha`` is the last alpha, scanning up from ``alpha_end``, for
      which all coefficients stay non-zero.

    ``m`` evenly spaced alphas between them are averaged over samples.

    Args:
        sae: sequence of SAEs exposing ``W_dec`` and ``encode``.
        resid: sequence aligned with ``sae``; each entry is assumed to be
            (n_samples, d_model) -- TODO confirm.
        sub_activation: (n_samples, d_model) tensor to be reconstructed.
        m: number of alphas to return.

    Returns:
        ``[alpha_end]`` if ``m <= 1``, otherwise an array of ``m`` alphas.
        Samples with no active feature are skipped. If every sample is
        skipped, ``linspace(alpha_end, alpha_start, m)`` is returned.
    """
    if m <= 1:
        return [alpha_end]

    W_dec = [sae[j].W_dec.detach().cpu().numpy() for j in range(len(sae))]
    # squeeze(0) rather than squeeze(): with a single sample, squeeze() would
    # also drop the sample axis and encoded_resid[j][i] would index features.
    encoded_resid = [sae[j].encode(resid[j].unsqueeze(0)).detach().cpu().numpy().squeeze(0) for j in range(len(sae))]
    sub_activation_np = sub_activation.detach().cpu().numpy()

    alphas = np.logspace(np.log10(alpha_start), np.log10(alpha_end), num=50)

    res = []
    for i in range(len(sub_activation_np)):
        active_indices = [np.where(encoded_resid[j][i] != 0)[0] for j in range(len(encoded_resid))]
        X_active = np.concatenate([W_dec[j][active_indices[j], :] for j in range(len(active_indices))], axis=0)
        if X_active.shape[0] == 0:
            # No active SAE feature: there is no Lasso problem to probe.
            continue

        # Projection method: keep features with above-average |contribution|.
        magnitude = np.abs(np.dot(X_active, sub_activation_np[i]))
        significant_indices = np.where(magnitude > magnitude.mean())[0]
        if significant_indices.size == 0:
            # All contributions tie (e.g. exactly one active feature). The
            # strict ">" keeps nothing, and Lasso.fit rejects an empty design.
            significant_indices = np.arange(X_active.shape[0])
        X_significant = X_active[significant_indices, :]

        y = sub_activation_np[i]

        first_alpha = alpha_end
        # NOTE(review): if even alpha_end leaves some coefficient at zero,
        # last_alpha keeps this alpha_start default, so the range extends up
        # to alpha_start. Confirm that this is intended.
        last_alpha = alpha_start

        for alpha in alphas:
            lasso = Lasso(alpha=alpha, fit_intercept=False)
            lasso.fit(X_significant.T, y)
            if np.any(lasso.coef_ != 0):
                first_alpha = alpha
                break

        for alpha in reversed(alphas):
            lasso = Lasso(alpha=alpha, fit_intercept=False)
            lasso.fit(X_significant.T, y)
            if np.all(lasso.coef_ != 0):
                last_alpha = alpha
            else:
                break

        res.append(np.linspace(first_alpha, last_alpha, num=m))

    if not res:
        return np.linspace(alpha_end, alpha_start, num=m)

    return np.stack(res, axis=0).mean(0)



@ignore_warnings(category=ConvergenceWarning)
def sae_reconstruct(sae, resid, sub_activation, m, return_components=False):
    """Reconstruct ``sub_activation`` from SAE decoder directions with Lasso.

    For each sample, the candidate dictionary is the set of decoder rows of
    SAE features active on ``resid[j]`` under ``sae[j]``. Rows whose
    |dot product| with the sample exceeds the mean are kept, or all active
    rows if none exceeds it. A Lasso (no intercept) is fitted for every alpha
    returned by ``get_alphas``.

    Args:
        sae: sequence of SAEs exposing ``W_dec`` and ``encode``.
        resid: sequence of activations aligned with ``sae``; each entry is
            assumed to be (n_samples, d_model) -- TODO confirm.
        sub_activation: (n_samples, d_model) tensor to reconstruct.
        m: number of alphas requested from ``get_alphas``.
        return_components: return the fitted components instead of
            reconstructions.

    Returns:
        One entry per alpha. Each entry is either a (n_samples, d_model)
        tensor of reconstructions, or, with ``return_components``, a list per
        sample of ``(decoder_rows, coefficients)`` tensors sorted by
        decreasing |coefficient|. A sample with no active feature gets a zero
        reconstruction and empty components.
    """
    W_dec = [sae[j].W_dec.detach().cpu().numpy() for j in range(len(sae))]
    sub_activation_np = sub_activation.detach().cpu().numpy()

    alphas = get_alphas(sae, resid, sub_activation, m)

    encoded_resid = [sae[j].encode(resid[j]).detach().cpu().numpy() for j in range(len(resid))]

    # Feature selection depends only on the sample, not on alpha: do it once.
    dictionaries = []
    for i in range(sub_activation_np.shape[0]):
        active_indices = [np.where(encoded_resid[j][i] != 0)[0] for j in range(len(encoded_resid))]
        X_active = np.concatenate([W_dec[j][active_indices[j], :] for j in range(len(active_indices))], axis=0)
        if X_active.shape[0] == 0:
            dictionaries.append(X_active)
            continue

        magnitude = np.abs(np.dot(X_active, sub_activation_np[i]))
        significant_indices = np.where(magnitude > magnitude.mean())[0]
        if significant_indices.size == 0:
            # All contributions tie (e.g. a single active feature); Lasso.fit
            # would reject an empty design matrix.
            significant_indices = np.arange(X_active.shape[0])
        dictionaries.append(X_active[significant_indices, :])

    results = []
    for alpha in alphas:
        reconstruction = []
        components = []

        for i, X_significant in enumerate(dictionaries):
            y = sub_activation_np[i]
            if X_significant.shape[0] == 0:
                w = np.zeros(0, dtype=X_significant.dtype)
            else:
                lasso = Lasso(alpha=alpha, fit_intercept=False)
                lasso.fit(X_significant.T, y)
                w = lasso.coef_

            reconstruction.append(np.dot(X_significant.T, w))

            if return_components:
                order = np.argsort(np.abs(w))[::-1]
                components.append((torch.tensor(X_significant[order]), torch.tensor(w[order])))

        if return_components:
            results.append(components)
        else:
            results.append(torch.tensor(np.stack(reconstruction, axis=0)))

    return results


def supervised_reconstruct(cs, example_batch, sub_activation):
    """Reconstruct each example's activation from its supervised feature vectors.

    For example ``j``, the feature ``cs['feature-dict'][attribute][value]``
    of the first occurrence of every attribute is stacked together with
    ``cs['feature-dict']["order"][order]`` into ``V``. Least-squares weights
    are then fitted so that ``weights @ V`` approximates
    ``sub_activation[j] - cs["mean-activation"]``.

    Returns:
        Tensor of shape (len(example_batch), d) holding
        ``mean-activation + weights @ V`` for every example.
    """
    feature_dict = cs['feature-dict']
    mean_activation = cs["mean-activation"]
    reconstruction_list = []

    for j, example in enumerate(example_batch):
        attributes = example["attributes"]
        values = example["values"]

        value_features = []
        seen_attribute = set()
        for i in range(len(example["value-positions"])):
            attribute = attributes[i]
            value = values[i]
            if attribute not in seen_attribute:
                value_features.append(feature_dict[attribute][value])
            seen_attribute.add(attribute)

        value_features.append(feature_dict["order"][example["order"]])

        V = torch.stack(value_features, dim=0)  # (n_features, d)
        R = sub_activation[j] - mean_activation
        # Least squares via the pseudo-inverse: identical to inv(V V^T) V R
        # when V has full row rank. It also gives the minimum-norm solution,
        # instead of failing, when feature vectors are linearly dependent.
        optimal_weights = torch.linalg.pinv(V.T) @ R

        reconstruction_list.append(mean_activation + optimal_weights @ V)

    return torch.stack(reconstruction_list, dim=0)


def evaluate_sufficiency_and_necessity(model, get_examples, correct, incorrect, k, cross_section_cluster, saes, sizes, m=1, in_context=True):
    def get_edit_hook(edit_dict, mask_dict):
        def hook(act, hook):
            layer = int(hook.name.split(".")[1])
            act[mask_dict[layer]] = edit_dict[layer][mask_dict[layer]].to(act.device)
            return act
        return hook

    def calculate_differences(logits, correct_ids, incorrect_ids):
        return (logits[torch.arange(logits.size(0)), -1, correct_ids] - logits[torch.arange(logits.size(0)), -1, incorrect_ids]).tolist()


    # Pre-initialize results to avoid repeated allocations
    sufficiency_results = {"clean": [], "supervised": [], "mean": [], "sae": [[[] for _ in range(m)] for _ in range(len(saes))]}
    necessity_results = {"clean": [], "supervised": [], "mean": [], "sae": [[[] for _ in range(m)] for _ in range(len(saes))]}

    for i in range(k):
        torch.cuda.empty_cache()
        gc.collect()

        # Get example batch and process tokens
        example_batch = get_examples(i)
        batch_tokens = torch.cat([model.tokenizer.encode(example["example"], return_tensors="pt", padding=True) for example in example_batch], dim=0)

        # Forward pass to get logits and cache
        with torch.no_grad():
            clean_logits, cache = model.run_with_cache(batch_tokens)
        correct_ids, incorrect_ids = get_logit_ids(model, example_batch, correct, incorrect)

        # Process residuals and activations
        resids = defaultdict(lambda: torch.zeros_like(batch_tokens).unsqueeze(-1))
        sub_acts = {}
        edit_masks = defaultdict(lambda: torch.zeros(batch_tokens.size(0), batch_tokens.size(1), model.cfg.d_model, dtype=torch.bool))

        # Minimize redundant operations within loops
        for l in range(model.cfg.n_layers):
            resid_decomposition, labels = cache.get_full_resid_decomposition(layer=l, return_labels=True, expand_neurons=False)
            resid = cache[f"blocks.{l}.hook_resid_pre"].detach().cpu().clone()

            ref_decomposition = resid_decomposition.clone()
            for cs in cross_section_cluster:
                downstream_layer = cs["downstream-layer"]
                cid = f"L{cs['upstream-layer']}H{cs['upstream-head']}" if cs['upstream-head'] is not None else f"{cs['upstream-layer']}_mlp_out"
                upstream_position = cs["upstream-position"]
                if l == downstream_layer:
                    upstream_component_id = labels.index(cid)
                    sub_acts[(downstream_layer, cid, upstream_position)] = ref_decomposition[upstream_component_id, :, upstream_position].clone().detach().to("cpu")
                    resid[:,upstream_position] -= ref_decomposition[upstream_component_id, :, upstream_position].clone().detach().to("cpu")

            resids[l] = resid
        
        # Use only necessary deep copies
        mean_ablations = {k: v.clone() for k, v in resids.items()}

        sufficiency_supervised_resid = {k: v.clone() for k, v in resids.items()}
        sufficiency_sae_resids = [[{k: v.clone() for k, v in resids.items()} for _ in range(m)] for _ in range(len(saes))]

        necessity_supervised_resid = {k: v.clone() for k, v in resids.items()}
        necessity_sae_resids = [[{k: v.clone() for k, v in resids.items()} for _ in range(m)] for _ in range(len(saes))]

        for idx, cs in enumerate(cross_section_cluster):
            downstream_layer = cs["downstream-layer"]
            upstream_layer = cs["upstream-layer"]
            upstream_position = cs["upstream-position"]

            batch_mean_ablation = torch.stack([cs["mean-activation"] for _ in range(len(example_batch))], dim=0)
            sub_activation = sub_acts[(downstream_layer, f"L{cs['upstream-layer']}H{cs['upstream-head']}" if cs['upstream-head'] is not None else f"{cs['upstream-layer']}_mlp_out", upstream_position)]
            resid = [cache[f"blocks.{layer}.hook_resid_post"][:, upstream_position].to("cpu") for layer in range(upstream_layer, downstream_layer)]

            # Update mean ablation resid
            mean_ablations[downstream_layer][:, upstream_position] += batch_mean_ablation

            # Update supervised Resid
            supervised_reconstruction = supervised_reconstruct(cs, example_batch, sub_activation)

            sufficiency_supervised_resid[downstream_layer][:, upstream_position] += supervised_reconstruction
            necessity_supervised_resid[downstream_layer][:, upstream_position] += batch_mean_ablation + (sub_activation - supervised_reconstruction)

            # Update SAE resids
            if in_context:
                reconstruction = [sae_reconstruct(saes[s][upstream_layer:downstream_layer], resid, sub_activation, m) for s in range(len(saes))]
            else:
                reconstruction = [[saes[j][upstream_layer](sub_activation)] for j in range(len(saes))]

            for s in range(len(saes)):
                for j in range(m):
                    necessity_sae_resids[s][j][downstream_layer][:, upstream_position] += batch_mean_ablation + (sub_activation - reconstruction[s][j])
                    sufficiency_sae_resids[s][j][downstream_layer][:, upstream_position] += reconstruction[s][j]
                    continue

            # Update the edit mask
            edit_masks[downstream_layer][:, upstream_position] = True

        # Run model with hooks for sufficiency
        with torch.no_grad():
            with model.hooks(fwd_hooks=[(lambda name: "resid_pre" in name, get_edit_hook(mean_ablations, edit_masks))]):
                mean_logits = model(batch_tokens)
            with model.hooks(fwd_hooks=[(lambda name: "resid_pre" in name, get_edit_hook(sufficiency_supervised_resid, edit_masks))]):
                sufficiency_supervised_logits = model(batch_tokens)
            for s in range(len(saes)):
                for j in range(m):
                    with model.hooks(fwd_hooks=[(lambda name: "resid_pre" in name, get_edit_hook(sufficiency_sae_resids[s][j], edit_masks))]):
                        sae_logits = model(batch_tokens)
                        sufficiency_results["sae"][s][j].extend(calculate_differences(sae_logits, correct_ids, incorrect_ids))

        sufficiency_results["clean"].extend(calculate_differences(clean_logits, correct_ids, incorrect_ids))
        sufficiency_results["supervised"].extend(calculate_differences(sufficiency_supervised_logits, correct_ids, incorrect_ids))
        sufficiency_results["mean"].extend(calculate_differences(mean_logits, correct_ids, incorrect_ids))

        # Run model with hooks for necessity
        with torch.no_grad():
            with model.hooks(fwd_hooks=[(lambda name: "resid_pre" in name, get_edit_hook(necessity_supervised_resid, edit_masks))]):
                necessity_supervised_logits = model(batch_tokens)
            for s in range(len(saes)):
                for j in range(m):
                    with model.hooks(fwd_hooks=[(lambda name: "resid_pre" in name, get_edit_hook(necessity_sae_resids[s][j], edit_masks))]):
                        sae_logits = model(batch_tokens)
                        necessity_results["sae"][s][j].extend(calculate_differences(sae_logits, correct_ids, incorrect_ids))


        necessity_results["clean"].extend(calculate_differences(clean_logits, correct_ids, incorrect_ids))
        necessity_results["supervised"].extend(calculate_differences(necessity_supervised_logits, correct_ids, incorrect_ids))
        necessity_results["mean"].extend(calculate_differences(mean_logits, correct_ids, incorrect_ids))

        # Clean up to manage memory
        del batch_tokens, clean_logits, sufficiency_supervised_logits, necessity_supervised_logits, sae_logits
        torch.cuda.empty_cache()
        gc.collect()

    # Calculate averages for sufficiency
    sufficiency_avg_clean = sum(sufficiency_results["clean"]) / len(sufficiency_results["clean"])
    sufficiency_avg_supervised = sum(sufficiency_results["supervised"]) / len(sufficiency_results["supervised"])
    sufficiency_avg_mean = sum(sufficiency_results["mean"]) / len(sufficiency_results["mean"])
    sufficiency_sae_avgs = [[sum(sufficiency_results["sae"][s][j]) / len(sufficiency_results["sae"][s][j]) for j in range(m)] for s in range(len(saes))]

    # Calculate averages for necessity
    necessity_avg_clean = sum(necessity_results["clean"]) / len(necessity_results["clean"])
    necessity_avg_supervised = sum(necessity_results["supervised"]) / len(necessity_results["supervised"])
    necessity_avg_mean = sum(necessity_results["mean"]) / len(necessity_results["mean"])
    necessity_sae_avgs = [[sum(necessity_results["sae"][s][j]) / len(necessity_results["sae"][s][j]) for j in range(m)] for s in range(len(saes))]

    # Normalize results for sufficiency and necessity
    sufficiency_normalization = abs(sufficiency_avg_clean - sufficiency_avg_mean)
    necessity_normalization = abs(necessity_avg_clean - necessity_avg_mean)

    # Prepare sufficiency results
    sufficiency_output = {
        "supervised": abs(sufficiency_avg_supervised - sufficiency_avg_mean) / sufficiency_normalization
    }
    for s in range(len(saes)):
        sufficiency_output.update({
            **{f"sae {sizes[s]}": max(0.001, abs(avg_sae - sufficiency_avg_mean) / sufficiency_normalization) for j, avg_sae in enumerate(sufficiency_sae_avgs[s])},
        })

    # Prepare necessity results
    necessity_output = {
        "supervised": 1 - abs(necessity_avg_supervised - necessity_avg_mean) / necessity_normalization,
    }
    for s in range(len(saes)):
        necessity_output.update({
            **{f"sae {sizes[s]}": max([0.01, 1 - abs(avg_sae - necessity_avg_mean) / necessity_normalization]) for j, avg_sae in enumerate(necessity_sae_avgs[s])},
        })

    return {
        "sufficiency": sufficiency_output,
        "necessity": necessity_output
    }



def evaluate_sparse_controllability(model, get_examples, k, cluster_dict, saes, n_edits_list, sizes, m=1):
    """Score how well supervised and SAE feature edits steer the model to a counterfactual.

    For every cluster in ``cluster_dict`` and every batch index in ``range(k)``:

    1. Run the model on the clean examples and on their counterexamples, caching activations.
    2. For every cross section ``cs`` of the cluster, subtract the clean contribution of its
       upstream component (head ``L{layer}H{head}``, or ``{layer}_mlp_out`` when
       ``cs['upstream-head']`` is None) from ``resid_pre`` of ``cs['downstream-layer']`` at
       ``cs['upstream-position']``.
    3. Add back, at the same place, one of:
       * the counterfactual contribution of that component ("ground truth"),
       * the supervised edit from ``get_supervised_edit``,
       * the SAE edit from ``get_sae_edit``, for each SAE set and each ``n_edits``.
    4. Re-run the model with hooks that overwrite ``resid_pre`` at the edited entries, and
       record per example whether the last-token argmax equals the ground-truth run's.

    Args:
        model: presumably a TransformerLens ``HookedTransformer`` (this uses ``run_with_cache``,
            ``hooks``, ``cfg`` and ``tokenizer``) -- confirm against callers.
        get_examples: ``get_examples(batch_idx, counter_attr_specified=attr)`` returning one
            example dict or a list of them; each needs ``"example"`` / ``"counterexample"``
            text plus the fields read by ``get_supervised_edit``.
        k: number of batches evaluated per cluster.
        cluster_dict: maps a cluster key (whose first element is the counter attribute) to a
            list of cross-section dicts.
        saes: list of SAE sets; each set is sliced by layer as ``[upstream:downstream]``.
        n_edits_list: numbers of greedy feature swaps to evaluate for the SAE edits.
        sizes: label for each SAE set, used in the output keys.
        m: number of residual copies per SAE. ``get_sae_edit`` yields a single edit per
            batch, so every copy receives that same edit.

    Returns:
        ``{cluster_key: {"supervised": acc, "sae {size} - {n} edits": max(0.01, acc), ...}}``
        where ``acc`` is the fraction of examples whose prediction matched the ground-truth
        edited run. SAE keys are only emitted when ``m >= 1``.
    """

    def get_edit_hook(edit_dict, mask_dict):
        # Build a forward hook that overwrites the hooked activation with edit_dict[layer]
        # wherever mask_dict[layer] is True; the layer index is parsed from the hook name
        # (e.g. "blocks.3.hook_resid_pre").
        def edit_hook(act, hook):  # keep the parameter named `hook`; it may be passed by keyword
            layer = int(hook.name.split(".")[1])
            mask = mask_dict[layer]
            act[mask] = edit_dict[layer][mask].to(act.device)
            return act
        return edit_hook

    def is_resid_pre(name):
        return "resid_pre" in name

    def component_id(cs):
        # Label used to find the component among the residual-decomposition labels.
        if cs['upstream-head'] is not None:
            return f"L{cs['upstream-layer']}H{cs['upstream-head']}"
        return f"{cs['upstream-layer']}_mlp_out"

    results = {}

    for cluster_key, cluster in sorted(cluster_dict.items(), key=lambda x: x[0]):
        counter_attribute = cluster_key[0]  # The attribute specified for this cluster
        # Only these layers are edited; every other layer's residual is stored but never changed.
        downstream_layers = {cs["downstream-layer"] for cs in cluster}

        for batch_idx in range(k):
            torch.cuda.empty_cache()
            gc.collect()

            # Batch of examples that differ only in the specified counter-attribute
            example_batch = get_examples(batch_idx, counter_attr_specified=counter_attribute)
            if not isinstance(example_batch, list):
                example_batch = [example_batch]  # a single example is handled as a batch of one

            batch_tokens = torch.cat([model.tokenizer.encode(example["example"], return_tensors="pt") for example in example_batch], dim=0)
            counterfactual_tokens = torch.cat([model.tokenizer.encode(example["counterexample"], return_tensors="pt") for example in example_batch], dim=0)
            n_examples, n_positions = batch_tokens.size(0), batch_tokens.size(1)

            # Only the caches are needed; the logits of these runs are not used.
            with torch.no_grad():
                cache = model.run_with_cache(batch_tokens)[1]
                counterfactual_cache = model.run_with_cache(counterfactual_tokens)[1]

            resids = {}                   # layer -> resid_pre with the cluster's clean components removed
            sub_acts = {}                 # (layer, cid, position) -> clean component activation
            counterfactual_sub_acts = {}  # (layer, cid, position) -> counterfactual component activation
            # Defaults to an all-False mask so the hook can look up any layer.
            edit_masks = defaultdict(lambda: torch.zeros(n_examples, n_positions, model.cfg.d_model, dtype=torch.bool))

            for layer_idx in range(model.cfg.n_layers):
                resid = cache[f"blocks.{layer_idx}.hook_resid_pre"].detach().cpu().clone()

                if layer_idx in downstream_layers:
                    resid_decomposition, labels = cache.get_full_resid_decomposition(layer=layer_idx, expand_neurons=False, return_labels=True)
                    counterfactual_resid_decomposition = counterfactual_cache.get_full_resid_decomposition(layer=layer_idx, expand_neurons=False)

                    for cs in cluster:
                        downstream_layer = cs["downstream-layer"]
                        if layer_idx != downstream_layer:
                            continue
                        cid = component_id(cs)
                        upstream_position = cs["upstream-position"]
                        component_idx = labels.index(cid)

                        clean_component = resid_decomposition[component_idx, :, upstream_position].clone().detach()
                        sub_acts[(downstream_layer, cid, upstream_position)] = clean_component
                        counterfactual_sub_acts[(downstream_layer, cid, upstream_position)] = counterfactual_resid_decomposition[component_idx, :, upstream_position].clone().detach()

                        # Remove the clean contribution; each edit below adds a replacement back.
                        resid[:, upstream_position] -= clean_component.cpu()

                    del resid_decomposition, counterfactual_resid_decomposition
                    torch.cuda.empty_cache()
                    gc.collect()

                resids[layer_idx] = resid

            def fresh_resids():
                # Only downstream layers are modified in place below, so other layers can share tensors.
                return {layer: (r.clone() if layer in downstream_layers else r) for layer, r in resids.items()}

            supervised_resids = fresh_resids()
            ground_truth_resids = fresh_resids()
            sae_resids = {n_edits: [[fresh_resids() for _ in range(m)] for _ in range(len(saes))] for n_edits in n_edits_list}

            for cs in tqdm(cluster, desc="Compute Edits"):
                downstream_layer = cs["downstream-layer"]
                upstream_layer = cs["upstream-layer"]
                upstream_position = cs["upstream-position"]
                key = (downstream_layer, component_id(cs), upstream_position)

                # Inputs for SAE editing: resid_post of the layers between upstream and downstream
                resid_path = [cache[f"blocks.{layer}.hook_resid_post"][:, upstream_position].to("cpu") for layer in range(upstream_layer, downstream_layer)]
                counterfactual_resid_path = [counterfactual_cache[f"blocks.{layer}.hook_resid_post"][:, upstream_position].to("cpu") for layer in range(upstream_layer, downstream_layer)]
                sub_activation = sub_acts[key].to("cpu")
                counterfact_sub_activation = counterfactual_sub_acts[key].to("cpu")

                edit_masks[downstream_layer][:, upstream_position, :] = True

                ground_truth_resids[downstream_layer][:, upstream_position] += counterfact_sub_activation
                supervised_resids[downstream_layer][:, upstream_position] += get_supervised_edit(cs, example_batch, sub_activation)

                for n_edits in n_edits_list:
                    for s_idx, sae_set in enumerate(saes):
                        # One edited activation for the whole batch (row i = example i); add it as a
                        # whole -- indexing it by j would pick example j and broadcast it to all rows.
                        sae_edit = get_sae_edit(sae_set[upstream_layer:downstream_layer], resid_path, sub_activation, counterfactual_resid_path, counterfact_sub_activation, n_edits=n_edits)
                        for j in range(m):
                            sae_resids[n_edits][s_idx][j][downstream_layer][:, upstream_position] += sae_edit

            with torch.no_grad():
                with model.hooks(fwd_hooks=[(is_resid_pre, get_edit_hook(ground_truth_resids, edit_masks))]):
                    ground_truth_token_ids = model(batch_tokens)[:, -1].argmax(1)

                if cluster_key not in results:
                    results[cluster_key] = {f"sae_control_{n_edits}": [[] for _ in range(len(saes))] for n_edits in n_edits_list}
                    results[cluster_key]["supervised_control"] = []
                cluster_results = results[cluster_key]

                with model.hooks(fwd_hooks=[(is_resid_pre, get_edit_hook(supervised_resids, edit_masks))]):
                    supervised_token_ids = model(batch_tokens)[:, -1].argmax(1)
                cluster_results["supervised_control"].extend((supervised_token_ids == ground_truth_token_ids).float().tolist())

                for n_edits in n_edits_list:
                    for s_idx in range(len(saes)):
                        for j in range(m):
                            with model.hooks(fwd_hooks=[(is_resid_pre, get_edit_hook(sae_resids[n_edits][s_idx][j], edit_masks))]):
                                sae_token_ids = model(batch_tokens)[:, -1].argmax(1)
                            cluster_results[f"sae_control_{n_edits}"][s_idx].extend((sae_token_ids == ground_truth_token_ids).float().tolist())

            # Release caches and edited residual copies before the next batch's forward passes.
            del batch_tokens, counterfactual_tokens, cache, counterfactual_cache
            del resids, sub_acts, counterfactual_sub_acts, edit_masks
            del supervised_resids, ground_truth_resids, sae_resids
            torch.cuda.empty_cache()
            gc.collect()

    # Average the per-example match indicators into accuracies.
    final_results = {}
    for cluster_key, cluster_results in results.items():
        supervised_scores = cluster_results["supervised_control"]
        ordered_results = {
            "supervised": sum(supervised_scores) / len(supervised_scores)
        }

        # All m copies of an SAE were pooled into one score list, so a single average per
        # (SAE, n_edits) suffices. With m < 1 no SAE runs happened and no SAE keys are emitted.
        if m > 0:
            for s_idx in range(len(saes)):
                for n_edits in n_edits_list:
                    sae_scores = cluster_results[f"sae_control_{n_edits}"][s_idx]
                    ordered_results[f"sae {sizes[s_idx]} - {n_edits} edits"] = max(0.01, sum(sae_scores) / len(sae_scores))

        final_results[cluster_key] = ordered_results

    return final_results



def get_supervised_edit(cs, example_batch, sub_activation):
    """Swap each example's supervised feature for its counterfactual counterpart.

    Row ``i`` of the result is
    ``sub_activation[i] - features[attr][value] + features[attr][counter_value]``,
    where ``features = cs['feature-dict']`` and ``attr`` is the example's
    ``"counter-attribute"``. When ``attr == "order"`` the two values come from the
    ``"order"`` / ``"counter-order"`` fields. Otherwise they are looked up positionally
    through ``"attributes"``/``"values"`` and ``"counter-attributes"``/``"counter-values"``.

    Args:
        cs: cross-section dict whose ``'feature-dict'`` maps attribute -> value -> feature.
        example_batch: list of example dicts, one per row of ``sub_activation``.
        sub_activation: activations indexed by example along dim 0; left unmodified.

    Returns:
        An edited copy of ``sub_activation``.
    """
    edited = sub_activation.clone()

    for row, example in enumerate(example_batch):
        attribute = example["counter-attribute"]

        if attribute != "order":
            target_pos = example["counter-attributes"].index(attribute)
            target_value = example["counter-values"][target_pos]
            source_pos = example["attributes"].index(attribute)
            source_value = example["values"][source_pos]
        else:
            target_value = example["counter-order"]
            source_value = example["order"]

        attribute_features = cs['feature-dict'][attribute]
        target_feature = attribute_features[target_value]
        source_feature = attribute_features[source_value]

        edited[row] = edited[row] - source_feature + target_feature

    return edited



def get_sae_edit(saes, resid, sub_activation, counter_resid, counter_sub_activation, n_edits):
    """Greedily swap SAE features so each example's activation approaches its counterfactual.

    Both ``sub_activation`` and ``counter_sub_activation`` are decomposed with
    ``sae_reconstruct(..., return_components=True, m=1)`` into per-example components
    ``(X, w)``. Then, for each example ``i`` and up to ``n_edits`` times, the pair
    (feature ``k``, counter-feature ``c``) is chosen that minimises
    ``||act[i] - w[k] * X[k] + c_w[c] * c_X[c] - counter_sub_activation[i]||``.
    The distance is measured from the running (already edited) activation, and the swap
    is applied. Each feature and each counter-feature is used at most once.

    Args:
        saes: SAEs for the layer span, passed through to ``sae_reconstruct``.
        resid: clean residual inputs, passed through to ``sae_reconstruct``.
        sub_activation: clean component activation, one row per example (dim 0).
        counter_resid: counterfactual residual inputs, passed through to ``sae_reconstruct``.
        counter_sub_activation: counterfactual component activation; the edit target.
        n_edits: maximum number of feature swaps per example.

    Returns:
        A copy of ``sub_activation`` with the greedy swaps applied row by row.
    """
    act = sub_activation.clone()  # never modify the caller's tensor

    components = sae_reconstruct(saes, resid, sub_activation, return_components=True, m=1)[0]
    counter_components = sae_reconstruct(saes, counter_resid, counter_sub_activation, return_components=True, m=1)[0]

    for i in range(sub_activation.shape[0]):
        X, w = components[i]
        c_X, c_w = counter_components[i]

        # Order candidates by decreasing weight magnitude. The search is exhaustive, so this
        # order only decides ties: the first pair in (feature, counter-feature) order wins.
        sorted_indices = torch.argsort(w.abs(), descending=True)
        sorted_c_indices = torch.argsort(c_w.abs(), descending=True)
        sorted_X, sorted_w = X[sorted_indices], w[sorted_indices]
        sorted_c_X, sorted_c_w = c_X[sorted_c_indices], c_w[sorted_c_indices]

        n_features, n_counter_features = len(sorted_w), len(sorted_c_w)
        if n_features == 0 or n_counter_features == 0:
            continue  # nothing to swap for this example

        # Weighted contributions w[k] * X[k] and c_w[c] * c_X[c], computed once instead of per
        # candidate. Assumes w / c_w hold one scalar weight per row of X / c_X -- TODO confirm
        # against sae_reconstruct.
        removals = sorted_w.reshape(-1, *([1] * (sorted_X.dim() - 1))) * sorted_X
        additions = sorted_c_w.reshape(-1, *([1] * (sorted_c_X.dim() - 1))) * sorted_c_X
        target = counter_sub_activation[i]

        used_edits = set()
        used_c_mask = torch.zeros(n_counter_features, dtype=torch.bool, device=additions.device)

        for _ in range(n_edits):
            best_diff = float('inf')
            best_indices = None

            for idx in range(n_features):
                if idx in used_edits:
                    continue
                # Score removing feature idx against every counter-feature at once; same
                # arithmetic as act[i] - w[idx] * X[idx] + c_w[c] * c_X[c].
                candidates = (act[i] - removals[idx]) + additions
                diffs = torch.linalg.vector_norm((candidates - target).reshape(n_counter_features, -1), dim=1)
                # Used counter-features and NaN distances can never be selected.
                diffs = diffs.masked_fill(used_c_mask | torch.isnan(diffs), float('inf'))
                c_idx = int(torch.argmin(diffs))  # first minimum -> same tie-break as a linear scan
                row_best = float(diffs[c_idx])
                if row_best < best_diff:
                    best_diff = row_best
                    best_indices = (idx, c_idx)

            if best_indices is None:
                break  # no selectable pair left; later iterations could not find one either

            idx, c_idx = best_indices
            act[i] = act[i] - removals[idx] + additions[c_idx]
            used_edits.add(idx)
            used_c_mask[c_idx] = True

    return act
