
import numpy as np
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from datasets import load_dataset
import random
from my_utils.utils import SAGMAN_V2
from torch.nn.functional import softmax

def get_text_field(dataset_name):
    """Return the name of the column that holds the raw text for a dataset.

    Parameters
    ----------
    dataset_name : str
        One of 'sst2', 'mnli', 'oasst1' or 'imdb'.

    Returns
    -------
    str
        The column name to read the text from.

    Raises
    ------
    ValueError
        If ``dataset_name`` is not one of the supported datasets.
    """
    known_fields = (
        ('sst2', 'sentence'),
        ('mnli', 'premise'),   # only the premise text is used for MNLI
        ('oasst1', 'text'),    # OASST1 messages are stored under 'text'
        ('imdb', 'text'),
    )
    for candidate, field in known_fields:
        if dataset_name == candidate:
            return field
    raise ValueError(f"Unknown dataset: {dataset_name}")
    
def select_nodes(top_node_list: np.ndarray, top_k_percent: float = 0.01):
    """Split a ranked node list into its leading and trailing fractions.

    Parameters
    ----------
    top_node_list : np.ndarray
        Node indices in ranked order. The first entries are returned as the
        "non-robust" set and the last entries as the "robust" set.
    top_k_percent : float
        Fraction (not a percentage) of the list to take from each end.
        Values outside [0, 1] are clamped to that range.

    Returns
    -------
    non_robust_indices : np.ndarray
        The first ``int(len(top_node_list) * top_k_percent)`` entries.
    robust_indices : np.ndarray
        The same number of entries taken from the end of the list.
    """
    # Size the selection from the argument itself rather than from a
    # module-level variable, so the function works outside this script.
    n_samples = len(top_node_list)
    n_select = int(n_samples * top_k_percent)
    n_select = max(0, min(n_select, n_samples))

    non_robust_indices = top_node_list[:n_select]
    # An explicit start index is used because `top_node_list[-0:]` would
    # return the whole list when nothing should be selected.
    robust_indices = top_node_list[n_samples - n_select:]

    return non_robust_indices, robust_indices

def _masked_mean_pool(token_embs, attention_mask):
    """Average token embeddings over the positions where the mask is 1."""
    mask = attention_mask.unsqueeze(-1).to(token_embs.dtype)  # [N, T, 1]
    # Clamp so a row with no valid tokens yields zeros instead of NaN.
    valid_counts = mask.sum(dim=1).clamp(min=1.0)             # [N, 1]
    return (token_embs * mask).sum(dim=1) / valid_counts


def _attention_pooling_weights(attentions, attention_layer, attention_mask,
                               use_cls_token_attention):
    """Build per-token pooling weights [N, T] from one attention layer."""
    selected_attention = attentions[attention_layer]   # [N, num_heads, T, T]
    avg_attention = selected_attention.mean(dim=1)     # [N, T, T]

    if use_cls_token_attention:
        # Row 0 = attention paid by the first token to every token.
        # NOTE(review): this assumes a [CLS]-style token at position 0; for
        # causal decoder-only models the first position presumably attends
        # only to itself -- confirm this is the intended reference token.
        reference_weights = avg_attention[:, 0, :]     # [N, T]
    else:
        # Average the attention each token receives across all query tokens.
        reference_weights = avg_attention.mean(dim=1)  # [N, T]

    # NOTE(review): the weights are passed through a softmax again here. If
    # the model returns already-normalised attention probabilities this
    # flattens the distribution; kept because it is the existing behaviour.
    weights = softmax(reference_weights, dim=1)        # [N, T]

    # Zero out padding positions, then renormalise so each row sums to ~1.
    weights = weights * attention_mask
    return weights / (weights.sum(dim=1, keepdim=True) + 1e-9)


def _to_numpy(tensor):
    """Detach a tensor, move it to CPU and convert it to a NumPy array."""
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        # NumPy has no bfloat16 dtype, so `.numpy()` would raise.
        tensor = tensor.float()
    return tensor.numpy()


def extract_embeddings_advanced_v5(
    model,
    tokenizer,
    texts,
    labels,
    batch_size=4,
    device=None,
    input_layer=0,
    output_layer=-1,
    attention_layer=-1,
    use_cls_token_attention=True
):
    """
    Extracts sentence-level input (X) and output (Y) embeddings from a Transformer model using
    attention-based pooling.

    Each batch is tokenized with padding and truncation, run through the model with
    ``output_hidden_states=True`` and ``output_attentions=True``, and the token embeddings of
    the chosen hidden-state layers are reduced to one vector per sentence. If the model
    returns no attentions, masked mean pooling is used instead.

    Parameters
    ----------
    model : PreTrainedModel
        A Hugging Face Transformer model. It is moved to ``device`` and put in eval mode.
    tokenizer : PreTrainedTokenizer
        Corresponding tokenizer for the model. It is called with ``padding=True``, so it
        must have a padding token configured.
    texts : list of str
        The input sentences.
    labels : list
        The labels corresponding to each text (not used in embeddings, can be used for downstream tasks).
    batch_size : int
        Batch size for embedding extraction.
    device : torch.device or None
        Device to run the model on. If None, will use CUDA if available, else CPU.
    input_layer : int
        Index of the hidden state layer to use for input embeddings.
        Typically 0 corresponds to token embeddings right after the embedding layer.
    output_layer : int
        Index of the hidden state layer to use for output embeddings.
        -1 corresponds to the last layer.
    attention_layer : int
        Index of the attention layer to use for pooling.
        -1 corresponds to the last layer.
    use_cls_token_attention : bool
        If True, use the attention weights of the token at position 0 to pool token
        embeddings. If False, use the attention averaged over all query tokens.

    Returns
    -------
    input_embeddings : np.ndarray, shape (N, d)
        Sentence-level input embeddings.
    output_embeddings : np.ndarray, shape (N, d)
        Sentence-level output embeddings.
        For empty ``texts`` both arrays have shape ``(0, d)``, where ``d`` is
        ``model.config.hidden_size`` when available and 0 otherwise.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    input_embeddings_list = []
    output_embeddings_list = []

    # Layer indices are not validated here; negative indices behave as for
    # Python sequences, and an out-of-range index raises IndexError.
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]

        # Tokenize and move to device
        encodings = tokenizer(batch_texts, padding=True, truncation=True, return_tensors='pt')
        input_ids = encodings['input_ids'].to(device)
        attention_mask = encodings['attention_mask'].to(device)

        with torch.no_grad():
            outputs = model(
                input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                output_attentions=True
            )

        hidden_states = outputs.hidden_states
        attentions = outputs.attentions

        token_input_embs = hidden_states[input_layer]    # [N, T, d]
        token_output_embs = hidden_states[output_layer]  # [N, T, d]

        if attentions is None or len(attentions) == 0:
            # Fallback when the model does not return attentions:
            # plain masked average over the non-padding tokens.
            z_X = _masked_mean_pool(token_input_embs, attention_mask)
            z_Y = _masked_mean_pool(token_output_embs, attention_mask)
        else:
            # The same pooling weights are used for both input and output
            # embeddings.
            alpha = _attention_pooling_weights(
                attentions, attention_layer, attention_mask, use_cls_token_attention
            ).unsqueeze(-1)                                      # [N, T, 1]
            z_X = torch.sum(alpha * token_input_embs, dim=1)   # [N, d]
            z_Y = torch.sum(alpha * token_output_embs, dim=1)  # [N, d]

        input_embeddings_list.append(_to_numpy(z_X))
        output_embeddings_list.append(_to_numpy(z_Y))

    if not input_embeddings_list:
        # np.vstack raises on an empty list, so return empty arrays directly.
        hidden_size = getattr(getattr(model, "config", None), "hidden_size", 0) or 0
        empty = np.empty((0, hidden_size), dtype=np.float32)
        return empty, empty.copy()

    # Concatenate all batches
    input_embeddings = np.vstack(input_embeddings_list)
    output_embeddings = np.vstack(output_embeddings_list)

    return input_embeddings, output_embeddings


# Checkpoint identifier. `from_pretrained` takes a single string (model id or
# path), not a list.
model_name = 'openlm-research/open_llama_7b_v2'
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Batches are tokenized with padding=True, which needs a pad token. If the
# checkpoint does not define one, reuse the EOS token and tell the model
# config about it so batched sequence classification can locate padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
if model.config.pad_token_id is None:
    model.config.pad_token_id = tokenizer.pad_token_id

# Collect (text, label) pairs from every split of the GLUE task.
dataset_name = 'sst2'
dataset = load_dataset("glue", dataset_name)
text_field = get_text_field(dataset_name)
all_examples = []
for split in dataset.keys():
    split_data = dataset[split]
    texts = [example[text_field] for example in split_data]
    labels = [example['label'] for example in split_data]
    all_examples.extend(zip(texts, labels))

# Sample data. A falsy sample_size means "use everything"; without the else
# branch, texts/labels would silently keep only the last split from the loop.
sample_size = 10000
if sample_size:
    indices = random.sample(range(len(all_examples)), min(sample_size, len(all_examples)))
else:
    indices = range(len(all_examples))
texts = [all_examples[i][0] for i in indices]
labels = [all_examples[i][1] for i in indices]


# Sentence-level input/output embeddings, one row per sampled text.
input_embeddings, output_embeddings = extract_embeddings_advanced_v5(model, tokenizer, texts, labels)
# SAGMAN_V2 is a project-local helper; only TopNodeList is used below.
TopEig, TopEdgeList, TopNodeList, node_score, internal_edge = SAGMAN_V2(
                            input_embeddings,
                            output_embeddings,
                            k = 20,
                        )

# Take 1% of the ranked nodes from each end of TopNodeList.
non_robust_indices, robust_indices = select_nodes(TopNodeList, 0.01)