import torch
import numpy as np
from tqdm import tqdm
import os
from .utils import process_batch
import json
import pandas as pd

def measure_coactivation_graph(model, dataloader, tokenizer, end_batch_ind=3, threshold=0.0):
    """
    Measure and build the co-activation hypergraph of MLP neurons.

    The dataset is run through the model while forward hooks on every layer's
    ``mlp.act_fn`` record, for each input token position, which intermediate
    neurons have an activation strictly greater than ``threshold``. Each such
    set of neurons forms one hyperedge, labelled with the token that produced it.

    The process follows these steps:
    1. Register a forward hook on each layer's ``mlp.act_fn``.
    2. Run batches of data through the model via ``process_batch``.
    3. For each token position, record which neurons are co-activated.
    4. Remove the hooks again (also when an exception is raised).

    Args:
        model: The transformer model to analyze (AutoModelForCausalLM). It must
            expose ``model.config.num_hidden_layers``,
            ``model.config.intermediate_size`` and
            ``model.model.layers[i].mlp.act_fn``.
        dataloader: Iterable of batches; each batch must provide ``"input_ids"``
            (a tensor/array supporting ``reshape`` and ``tolist``).
        tokenizer: Tokenizer whose ``decode`` is applied to single token ids to
            produce the ``token`` column.
        end_batch_ind: Index of the last batch to process (inclusive), i.e. at
            most ``end_batch_ind + 1`` batches are processed (default: 3).
        threshold: A neuron counts as activated when its activation is strictly
            greater than this value (default: 0.0, i.e. positive activations).

    Returns:
        List of pandas DataFrames, one per layer, each with columns ``token``,
        ``token_id`` and ``activated_neurons`` (ascending list of neuron indices).
    """
    num_layers = model.config.num_hidden_layers
    intermediate_size = model.config.intermediate_size

    # Per-layer boolean activation masks from the most recent forward pass
    # (None until the layer's hook fires / after the mask has been consumed).
    activation_masks = [None] * num_layers
    hyperedges = [[] for _ in range(num_layers)]

    def get_coactivation_hook(layer_ind):
        def hook(module, inputs, output):
            # act_fn normally returns a plain tensor; only unwrap tuple/list outputs.
            # Indexing a tensor with [0] would silently keep just the first batch element.
            activation = output[0] if isinstance(output, (tuple, list)) else output
            # Threshold on the model's device (in float32, as before) and move only
            # a compact boolean mask to the CPU.
            mask = (activation.detach().float() > threshold).reshape(-1, intermediate_size)
            activation_masks[layer_ind] = mask.cpu().numpy()  # (batch * seq_len) x neurons
        return hook

    hooks = []
    try:
        # Register hooks for each layer's act_fn to capture activations
        for layer_ind in range(num_layers):
            act_fn = model.model.layers[layer_ind].mlp.act_fn
            hooks.append(act_fn.register_forward_hook(get_coactivation_hook(layer_ind)))
        print('Co-activation hooks registered.')

        for batch_idx, data in enumerate(tqdm(dataloader)):
            process_batch(model, data)

            # Flatten (batch_size, seq_len) ids so row i of each mask lines up with
            # token i; decode once per batch instead of once per layer.
            token_ids = data["input_ids"].reshape(-1).tolist()
            tokens = [tokenizer.decode(token_id) for token_id in token_ids]

            for layer_ind in range(num_layers):
                mask = activation_masks[layer_ind]
                if mask is None:
                    continue
                for row_idx, row in enumerate(mask):
                    hyperedges[layer_ind].append({
                        "token": tokens[row_idx],
                        "token_id": int(token_ids[row_idx]),
                        "activated_neurons": np.flatnonzero(row).tolist(),
                    })
                # Drop this batch's mask to free memory before the next forward pass.
                activation_masks[layer_ind] = None

            print(f"Processed batch {batch_idx+1}")

            if batch_idx >= end_batch_ind:
                break
    finally:
        # Always detach the hooks so the model is left unmodified, even on error.
        for hook in hooks:
            hook.remove()

    return [pd.DataFrame(layer_edges) for layer_edges in hyperedges]
