import torch
import numpy as np
from transformer_lens.utils import to_numpy
from sklearn.cluster import DBSCAN
from collections import Counter

from WeightLens.weight_lens.utils import (
    get_outliers, 
    get_activation_with_stop, 
    Feature
)

def append_forward_hooks(model, hook_name = None):
    """Register output-capturing forward hooks on every matching submodule.

    A submodule matches when ``hook_name`` is a substring of its name as
    reported by ``model.named_modules()``. When ``hook_name`` is None the
    pattern is read from ``model.feature_input_hook``.

    Returns:
        A pair ``(captured, handles)``. ``captured`` is a dict that is filled
        (module name -> forward output) each time the model runs;
        ``handles`` lists the hook handles so the caller can remove them.
    """
    pattern = model.feature_input_hook if hook_name is None else hook_name
    captured = {}

    def make_recorder(key):
        # Bind the module name per hook so each hook writes to its own slot.
        def record(module, inputs, output):
            captured[key] = output
        return record

    handles = [
        submodule.register_forward_hook(make_recorder(module_name))
        for module_name, submodule in model.named_modules()
        if pattern in module_name
    ]

    return captured, handles

def append_backward_hooks(model, hook_name = None):
    """Register full backward hooks that record output gradients.

    A submodule matches when ``hook_name`` is a substring of its name as
    reported by ``model.named_modules()``. When ``hook_name`` is None the
    pattern is read from ``model.feature_output_hook``.

    Returns:
        A pair ``(recorded, handles)``. ``recorded`` is a dict that is filled
        (module name -> third argument passed to the backward hook, i.e. the
        gradients with respect to the module outputs) during a backward
        pass; ``handles`` lists the hook handles so the caller can remove
        them.
    """
    pattern = model.feature_output_hook if hook_name is None else hook_name
    recorded = {}

    def make_recorder(key):
        # Bind the module name per hook so each hook writes to its own slot.
        def record(module, grad_input, grad_output):
            recorded[key] = grad_output
        return record

    handles = [
        submodule.register_full_backward_hook(make_recorder(module_name))
        for module_name, submodule in model.named_modules()
        if pattern in module_name
    ]

    return recorded, handles


def feature_attribution(model, input_ids, target_feature, token_id, fwd_hooks=None):
    """Run a forward/backward pass that attributes one transcoder feature.

    The pre-activation of ``target_feature`` at position ``token_id`` is
    computed as the dot product of the recorded
    ``blocks.{layer}.ln2.hook_normalized`` output with the matching column of
    the transcoder encoder, and back-propagated so that the backward hooks
    record gradients.

    Args:
        model: Model exposing ``feature_input_hook``, ``feature_output_hook``
            and ``transcoders``.
        input_ids: Token ids forwarded to ``get_activation_with_stop``.
        target_feature: Object with ``layer`` and ``feature_idx`` attributes.
        token_id: Sequence position whose feature activation is attributed.
        fwd_hooks: Optional iterable of extra hook-name patterns whose
            forward outputs should be cached as well. ``None`` means none.

    Returns:
        ``(activations_transcoders, all_acts, gradients, logits)`` where
        ``all_acts`` maps every requested pattern to its
        ``{module name: output}`` dict and ``gradients`` maps module names
        to the output gradients captured by the backward hooks.
    """
    extra_patterns = [] if fwd_hooks is None else fwd_hooks

    all_acts = {}
    all_handles = []

    try:
        # Default hooks: feature inputs (forward) and feature outputs (backward).
        activations, forward_hooks = append_forward_hooks(model)
        all_handles.extend(forward_hooks)
        gradients, backward_hooks = append_backward_hooks(model)
        all_handles.extend(backward_hooks)

        # Additional caches requested by the caller.
        for hook_name in extra_patterns:
            acts, handles = append_forward_hooks(model, hook_name)
            all_acts[hook_name] = acts
            all_handles.extend(handles)

        # Forward pass with stop at the target feature layer
        activations_transcoders, logits = get_activation_with_stop(
            model,
            input_ids,
            stop_at_layer=target_feature.layer,
            requires_grad=True,
            return_logits=True,
        )

        # Backward pass to get gradients.
        # NOTE(review): the key below assumes model.feature_input_hook matches
        # a module named 'blocks.{layer}.ln2.hook_normalized' - confirm.
        normalized = activations[f'blocks.{target_feature.layer}.ln2.hook_normalized'][0, token_id, :]
        encoder_column = model.transcoders[target_feature.layer].W_enc[:, target_feature.feature_idx].to(torch.float32)
        target_activation = normalized @ encoder_column
        target_activation.backward()
    finally:
        # Always detach every hook (including the default forward/backward
        # ones) so repeated calls do not stack hooks on the model.
        for hook in all_handles:
            hook.remove()

    return activations_transcoders, all_acts, gradients, logits


def get_positive_contributions(contributions: torch.Tensor):
    """Return the strictly positive entries of ``contributions``, largest first.

    Returns:
        ``(indices, values)``: the positions of the positive entries and the
        entries themselves, both ordered by descending value. When nothing is
        positive, two empty tensors are returned (``torch.long`` indices and
        values with the dtype of ``contributions``).
    """
    is_positive = contributions > 0
    positive_values = contributions[is_positive]

    if positive_values.numel() == 0:
        no_indices = torch.empty(0, dtype=torch.long)
        no_values = torch.empty(0, dtype=contributions.dtype)
        return no_indices, no_values

    # Flatten the (N, 1) nonzero result into a plain index vector.
    positions = torch.nonzero(is_positive, as_tuple=False).view(-1)
    ranked_values, order = positive_values.sort(descending=True)

    return positions[order], ranked_values


def get_contributing_features(model, prompt, feature, token_position):
    """Collect earlier-layer transcoder features that push ``feature`` up.

    For every layer below ``feature.layer`` the transcoder decoder matrix is
    multiplied with the gradient recorded at ``token_position`` and scaled by
    the transcoder activations at that position; only positive
    contributions are kept.

    Returns:
        Dict mapping layer index -> ``{feature id: contribution}``, with
        entries inserted in order of decreasing contribution.
    """
    # feature_attribution returns four values; the cached extra activations
    # and the logits are not needed here.
    activations_transcoders, _, gradients, _ = feature_attribution(model, prompt, feature, token_position)

    contributing_features = {}
    for layer in range(feature.layer):
        decoder = model.transcoders[layer].W_dec.to(torch.float32)
        # The backward hook stores a tuple of output gradients; take the first.
        grad_at_token = gradients[f'blocks.{layer}.hook_mlp_out.hook_out_grad'][0][0, token_position]
        contributions = torch.mul(activations_transcoders[layer, token_position], decoder @ grad_at_token)
        feature_ids, contrs = get_positive_contributions(contributions)
        contributing_features[layer] = {int(idx): float(val) for idx, val in zip(feature_ids, contrs)}

    return contributing_features

@torch.no_grad()
def get_attention_contribution(model, cache, target_feature, token_position=None, contribs_threshold=5):
    """Find attention (head, source position) pairs that drive a feature.

    For every layer up to and including ``target_feature.layer`` the
    attention-weighted values are projected through ``W_O`` and onto the
    encoder column of the target feature; ``get_outliers`` then selects the
    entries to report for destination position ``token_position``.

    Inspired by transcoder_circuits:
    https://github.com/jacobdunefsky/transcoder_circuits/blob/master/transcoder_circuits/circuit_analysis.py

    Args:
        model: Model exposing ``cfg``, ``blocks`` and ``transcoders``.
        cache: Nested dict with ``'hook_v'`` and ``'hook_pattern'`` entries,
            each keyed by the full hook name of a layer.
        target_feature: Object with ``layer``, ``feature_idx`` and ``pos``.
        token_position: Destination position; defaults to
            ``target_feature.pos``.
        contribs_threshold: Threshold forwarded to ``get_outliers``.

    Returns:
        Dict mapping layer -> ``(indices, values)`` where ``indices`` is an
        array whose rows are ``(head, source position)`` and ``values`` is a
        numpy array of the corresponding contributions.
    """
    if token_position is None:
        token_position = target_feature.pos

    # Loop-invariant: encoder direction of the target feature, on CPU/float32.
    feature_vector = model.transcoders[target_feature.layer].W_enc[:, target_feature.feature_idx].cpu().to(torch.float32)
    n_kv_heads = model.cfg.n_key_value_heads

    attn_contribution = {}

    for layer in range(target_feature.layer + 1):
        v = cache['hook_v'][f'blocks.{layer}.attn.hook_v'].detach().cpu()
        attn = cache['hook_pattern'][f"blocks.{layer}.attn.hook_pattern"].detach().cpu()

        # Grouped-query attention (e.g. gemma): repeat the key/value heads
        # along the head axis (dim 2 of the 'b s h f' layout used below) so
        # they line up with the query heads.
        if n_kv_heads is not None:
            repeat_factor = model.cfg.n_heads // n_kv_heads
            v = v.repeat_interleave(repeat_factor, dim=2)

        # attn is indexed (batch, head, dest, src); v is (batch, src, head, d_head).
        weighted_vals = torch.einsum(
            'b h d s, b s h f -> b h d s f',
            attn, v
        )
        weighted_outs = torch.einsum(
            'b h d s f, h f m -> b h d s m',
            weighted_vals, model.blocks[layer].attn.W_O.cpu().to(torch.float32)
        )
        contribs = torch.einsum(
            'b h d s m, m -> b h d s',
            weighted_outs, feature_vector
        )

        # Keep the first batch element and the requested destination: (head, src).
        contribs = contribs[0, :, token_position, :]
        flat_contribs = contribs.flatten()
        # The first element returned by get_outliers is used as flat indices.
        top_attn_contrib_indices_flattened = get_outliers(flat_contribs, threshold=contribs_threshold)[0]
        top_attn_contribs_flattened = flat_contribs[top_attn_contrib_indices_flattened].detach().cpu()
        top_attn_contrib_indices = np.array(
            np.unravel_index(to_numpy(top_attn_contrib_indices_flattened), contribs.shape)
        ).T

        attn_contribution[layer] = top_attn_contrib_indices, top_attn_contribs_flattened.numpy()

    return attn_contribution

def get_concept_pattern(attn_heads, prompt, feature_activations=None, k=5):
    """
    Create a mask of tokens that contributed to the feature.

    The mask has one entry per prompt position: the token id where the
    position is kept and 0 elsewhere. A position is kept if it:
        - is the source token of one of the top-k attention contributions,
        - OR is the last position of the prompt,
        - OR has positive activation in feature_activations (if provided).

    Args:
        attn_heads: Dict layer -> (indices, contribs); each row of
            ``indices`` is ``(head, token position)``.
        prompt: 1-D sequence of token ids (torch tensor, numpy array or list).
        feature_activations: Optional 1-D per-position activations (torch
            tensor or array-like); only the first ``len(prompt)`` are used.
        k: Number of strongest attention contributions to consider.

    Returns:
        Integer numpy array of length ``len(prompt)``.
    """
    if isinstance(prompt, torch.Tensor):
        prompt_ids = prompt.detach().cpu().numpy()
    else:
        prompt_ids = np.asarray(prompt)

    mask = np.zeros(len(prompt_ids), dtype=int)
    if len(prompt_ids) == 0:
        # Nothing to keep; avoids indexing the last token of an empty prompt.
        return mask

    # Flatten and collect all entries from attention
    entries = [
        (float(value), layer, int(head), int(token))
        for layer, (indices, contribs) in attn_heads.items()
        for (head, token), value in zip(indices, contribs)
    ]

    # Sort by contribution descending
    entries.sort(reverse=True)

    topk_tokens = sorted({token for *_, token in entries[:k]})
    if topk_tokens:
        mask[topk_tokens] = prompt_ids[topk_tokens]

    # Keep the last token
    mask[-1] = prompt_ids[-1]

    # Incorporate positive activations
    if feature_activations is not None:
        # Convert explicitly: np.nonzero() on a torch tensor dispatches to
        # Tensor.nonzero(), whose (N, 1) result would make `[0]` select only
        # the first positive position (and fail when there is none).
        if isinstance(feature_activations, torch.Tensor):
            positive = (feature_activations[:len(mask)] > 0).detach().cpu().numpy()
        else:
            positive = np.asarray(feature_activations)[:len(mask)] > 0
        pos_indices = np.flatnonzero(positive)
        mask[pos_indices] = prompt_ids[pos_indices]

    return mask


def get_concept_pattern_with_fallback(model, cache_acts, feature, input, token_position, feature_activations=None, k=5):
    """
    Try decreasing contribs_threshold values until a non-empty mask is found.

    Args:
        model, cache_acts, feature: Forwarded to get_attention_contribution.
        input: Token ids; only ``input[:token_position+1]`` is used.
        token_position: Position of the token being explained.
        feature_activations: Optional per-position activations forwarded to
            get_concept_pattern (moved to CPU first when it is a tensor).
        k: Number of top attention contributions used for the mask.

    Returns:
        ``(attn_heads, mask)`` for the first threshold that yields a mask
        with a positive entry, the mask being trimmed to start at its first
        non-zero position; ``(None, None)`` if no threshold succeeds.
    """
    # NOTE(review): get_concept_pattern always writes prompt[-1] into the
    # mask, so the first threshold succeeds unless that token id is 0 -
    # confirm this is the intended fallback behaviour.
    thresholds = [5, 4.5, 4, 3.5, 3, 2.5, 2, 1.5, 1]

    # Loop-invariant inputs of get_concept_pattern.
    prompt = input[:token_position+1]
    if isinstance(feature_activations, torch.Tensor):
        activations = feature_activations.detach().cpu()
    else:
        # Covers the default None as well as array-like inputs.
        activations = feature_activations

    for thr in thresholds:
        attn_heads = get_attention_contribution(model, cache_acts, feature, token_position, contribs_threshold=thr)
        mask = get_concept_pattern(attn_heads, prompt, activations, k=k)

        if (mask > 0).any():
            # trim to start from first nonzero
            first_kept = np.nonzero(mask)[0].min()
            return attn_heads, mask[first_kept:]

    return None, None

def pattern_and_circuit_discovery(model, inputs, feature, generate_tokens=5, topk=5):
    """Gather patterns, contributing features and outputs for a feature.

    Every ``(token ids, token position)`` pair in ``inputs`` is truncated to
    the given position and attributed. Pairs whose summed activation of
    ``feature`` is not positive are skipped; for the rest the function
    records the token pattern, the attention heads, the positively
    contributing lower-layer features, the top next-token predictions and a
    greedy continuation.

    Returns:
        Dict with the parallel lists ``patterns``, ``features``,
        ``attn_heads``, ``activations``, ``inputs`` and ``outputs``.
    """
    attn_heads_info, patterns, features = [], [], []
    positive_activations, positive_inputs, outputs = [], [], []

    # Extra forward caches requested from feature_attribution.
    hook_patterns = ["hook_v", "hook_pattern", "resid_pre", "resid_mid", "ln1.hook_normalized"]

    for token_ids, token_position in inputs:
        # Consider only real tokens up to token_position
        seq_ids = token_ids[:token_position+1]

        transcoder_acts, cache_acts, gradients, logits = feature_attribution(
            model, seq_ids, feature, token_position, hook_patterns
        )

        target_acts = transcoder_acts[feature.layer, :, feature.feature_idx]
        if not (target_acts.sum() > 0):
            continue

        positive_activations.append(target_acts.to(torch.float32))
        positive_inputs.append((token_ids, token_position))

        # --- attention head contributions ---
        attn_heads, mask = get_concept_pattern_with_fallback(
            model, cache_acts, feature, token_ids, token_position,
            feature_activations=target_acts
        )
        attn_heads_info.append(attn_heads)
        patterns.append(mask)

        # --- feature contributions ---
        per_layer_contributions = {}
        for ll in range(feature.layer):
            decoder = model.transcoders[ll].W_dec.to(torch.float32)
            grad_at_token = gradients[f'blocks.{ll}.hook_mlp_out.hook_out_grad'][0][0, token_position]
            scores = torch.mul(transcoder_acts[ll, token_position], decoder @ grad_at_token)
            feature_ids, contrs = get_positive_contributions(scores)
            per_layer_contributions[ll] = {
                int(feat_id): float(contr) for feat_id, contr in zip(feature_ids, contrs)
            }
        features.append(per_layer_contributions)

        # --- next-token info ---
        probs = torch.softmax(logits[0, -1], dim=-1)  # distribution over the last position
        top_vals, top_ids = torch.topk(probs, topk)
        top_predictions = []
        for tok, prob in zip(top_ids, top_vals):
            if tok.item() != model.tokenizer.pad_token_id:
                top_predictions.append((model.tokenizer.decode([tok.item()]), tok.item(), prob.item()))

        # --- greedy continuation ---
        generated = []
        cur_ids = seq_ids.clone()  # use only the real tokens
        with torch.no_grad():
            for _ in range(generate_tokens):
                out = model(cur_ids[None, :])
                next_id = out[0, -1].argmax().cpu()
                cur_ids = torch.cat([cur_ids, next_id[None]], dim=0)
                if next_id.item() != model.tokenizer.pad_token_id:  # skip padding
                    generated.append(next_id.item())

        outputs.append({
            "top_predictions": top_predictions,
            "greedy_continuation": generated
        })

    return {
        "patterns": patterns,
        "features": features,
        "attn_heads": attn_heads_info,
        "activations": positive_activations,
        "inputs": positive_inputs,
        "outputs": outputs
    }

def extract_feature_sets_combined(
    data, 
    max_token_positions, # attn_heads
    original_contributions=None, # features
    per_layer=True, 
    min_freq_ratio=0.05,
    outputs=None,          # new: list of output dicts, same length as data
    top_k_logits=5         # take top-k predictions from logits
):
    """
    Combine original features, attention heads, and model outputs into one feature set per input,
    keeping only frequent features above min_freq_ratio.

    Args:
        data: One entry per input: a dict layer -> (rows, values) whose rows
            are ``(head, token position)``. A ``None`` or empty entry
            contributes no attention features.
        max_token_positions: Per-input reference position; attention token
            positions are stored relative to it.
        original_contributions: Optional per-input dict
            layer -> {feature id: contribution}.
        per_layer: When True, feature keys are prefixed with the layer.
        min_freq_ratio: Minimum fraction of inputs a feature must occur in.
        outputs: Optional per-input dicts with ``"top_predictions"``
            (tuples of token string, token id, score) and
            ``"greedy_continuation"`` (token ids).
        top_k_logits: Number of top predictions used per input.

    Returns:
        List with one set of (hashable) feature keys per input.
    """
    raw_input_sets = []

    for idx, input_dict in enumerate(data):
        feats = set()

        # Original features
        if original_contributions:
            for layer, feats_dict in original_contributions[idx].items():
                if per_layer:
                    feats.update((layer, f) for f in feats_dict)
                else:
                    feats.update(feats_dict)

        # Attention heads; the entry is None when no attention pattern was
        # found for this input, which simply yields no attention features.
        for layer, (arr, _) in (input_dict or {}).items():
            for row in arr:
                head, token_pos = int(row[0]), int(row[1])
                rel_pos = token_pos - max_token_positions[idx]
                if per_layer:
                    feats.add((layer, "att", head, rel_pos))
                else:
                    feats.add(("att", head, rel_pos))

        # Outputs as features
        if outputs:
            out_dict = outputs[idx]
            # Top-k logits
            for rank, (_, token_id, _) in enumerate(out_dict.get("top_predictions", [])[:top_k_logits]):
                feats.add(("top_logits", rank, token_id))
            # Greedy continuation
            for pos, token_id in enumerate(out_dict.get("greedy_continuation", [])):
                feats.add(("greedy_predict", pos, token_id))

        raw_input_sets.append(feats)

    # -----------------------
    # Pre-filter: keep only features occurring in >= min_freq_ratio fraction of inputs
    # -----------------------
    n_inputs = len(raw_input_sets)
    feature_counts = Counter(f for s in raw_input_sets for f in s)
    frequent_features = {f for f, count in feature_counts.items() if count / n_inputs >= min_freq_ratio}

    # Build filtered input sets
    return [s & frequent_features for s in raw_input_sets]

def jaccard_distance_matrix(feature_sets):
    """Return the symmetric matrix of pairwise Jaccard distances.

    Entry ``[i, j]`` is ``1 - |A & B| / |A | B|`` for the i-th and j-th
    set. Two sets whose union is empty get distance 1, and the diagonal is 0.
    """
    count = len(feature_sets)
    distances = np.zeros((count, count))

    for row in range(count):
        left = feature_sets[row]
        for col in range(row + 1, count):
            right = feature_sets[col]
            union_size = len(left | right)
            if union_size > 0:
                gap = 1 - len(left & right) / union_size
            else:
                gap = 1
            # Fill both triangles so the matrix stays symmetric.
            distances[row, col] = gap
            distances[col, row] = gap

    return distances

def cluster_inputs_dbscan(feature_sets, eps=0.5, min_samples=3):
    """Cluster feature sets with DBSCAN on their Jaccard distances.

    Args:
        feature_sets: Sequence of sets, one per input.
        eps: DBSCAN neighbourhood radius, in Jaccard distance.
        min_samples: DBSCAN ``min_samples`` parameter.

    Returns:
        Array with one cluster label per feature set, as produced by
        ``DBSCAN.labels_``; an empty integer array when there are no
        feature sets.
    """
    # DBSCAN rejects an empty distance matrix, so short-circuit here.
    if len(feature_sets) == 0:
        return np.empty(0, dtype=np.intp)

    dist = jaccard_distance_matrix(feature_sets)
    clustering = DBSCAN(
        metric="precomputed",
        eps=eps,
        min_samples=min_samples,
    ).fit(dist)
    return clustering.labels_


def cluster_top_features(feature_sets, labels, topk=10):
    """Summarise each cluster by its most common features.

    Returns:
        Dict mapping every distinct label to the ``topk`` most common
        ``(feature, count)`` pairs among the feature sets with that label.
    """
    summary = {}
    for cluster_id in set(labels):
        tally = Counter()
        for position, members in enumerate(feature_sets):
            if labels[position] == cluster_id:
                tally.update(members)
        summary[cluster_id] = tally.most_common(topk)
    return summary

def cluster_activations(feature_result, print_summary=False):
    """Cluster the analysed inputs by their causes and by their outputs.

    ``feature_result`` is read through the keys ``'attn_heads'``,
    ``'inputs'``, ``'features'`` and ``'outputs'``. Two independent DBSCAN
    clusterings are computed: one over attention heads plus contributing
    features, one over the model outputs.

    Returns:
        ``(labels_input, labels_output)`` - the cluster labels of both
        clusterings. With ``print_summary`` the assignments and the top
        features per cluster are printed as well.
    """
    def report(header, labels, summary):
        # Print the assignments followed by the top features of each cluster.
        print(header, labels)
        print("\nCluster summaries:")
        for cluster_id, top_feats in summary.items():
            print(f"Cluster {cluster_id}:")
            for feat, occurrences in top_feats:
                print(f"  {feat}: {occurrences} occurrences")

    # --- 1. Input-based feature sets (outputs are ignored) ---
    feature_sets_input = extract_feature_sets_combined(
        data=feature_result['attn_heads'],
        max_token_positions=[entry[1] for entry in feature_result['inputs']],
        original_contributions=feature_result['features'],
        per_layer=True,
        min_freq_ratio=0.1,
        outputs=None
    )

    # --- 2. Output-based feature sets (placeholder attention data) ---
    n_outputs = len(feature_result['outputs'])
    feature_sets_output = extract_feature_sets_combined(
        data=[{} for _ in range(n_outputs)],
        max_token_positions=[0 for _ in range(n_outputs)],
        original_contributions=None,
        per_layer=False,
        min_freq_ratio=0.05,
        outputs=feature_result['outputs'],
        top_k_logits=5
    )

    # --- 3. Cluster both views separately ---
    labels_input = cluster_inputs_dbscan(feature_sets_input, eps=0.7, min_samples=3)
    labels_output = cluster_inputs_dbscan(feature_sets_output, eps=0.3, min_samples=3)

    # --- 4. Top features per cluster ---
    summary_input = cluster_top_features(feature_sets_input, labels_input, topk=5)
    summary_output = cluster_top_features(feature_sets_output, labels_output, topk=5)

    if print_summary:
        report("Cluster assignments input-based:", labels_input, summary_input)
        report("Cluster assignments output-based:", labels_output, summary_output)

    return labels_input, labels_output