import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

def find_layers_module(model):
    """
    Return the container of transformer blocks held by ``model``.

    Two attribute paths are tried in order: ``model.model.layers`` and then
    ``model.model.model.layers``. The first path that resolves is returned.

    Args:
      model: object exposing its blocks under one of the paths above.

    Returns:
      Whatever object is stored under the first resolvable ``layers`` attribute.

    Raises:
      AttributeError: if neither path exists; the message lists every path tried.
    """
    candidate_paths = (
        ("model", "layers"),
        ("model", "model", "layers"),
    )
    for path in candidate_paths:
        node = model
        try:
            for attr in path:
                node = getattr(node, attr)
        except AttributeError:
            # This path does not exist on the given model; try the next one.
            continue
        return node

    # Raise one explicit error instead of a chained AttributeError that only
    # names a single inner attribute and hides what was searched.
    tried = ", ".join("model." + ".".join(path) for path in candidate_paths)
    raise AttributeError(
        f"Could not locate transformer layers on {type(model).__name__}; tried: {tried}"
    )

def get_intermediate_layers(model, is_frozen, num_points=5):
    """
    Pick ``num_points`` block indices spread evenly across the model's layers.

    The positions are generated with ``torch.linspace`` between 0 and the index
    of the last block, then cast to integers (fractional positions are truncated).

    Args:
      model: model whose blocks are located through ``find_layers_module``.
      is_frozen: accepted for signature compatibility; not used in this function.
      num_points: how many indices to produce.

    Returns:
      A Python list of integer layer indices.
    """
    last_block = len(find_layers_module(model)) - 1
    positions = torch.linspace(0, last_block, steps=num_points)
    return positions.long().tolist()

def extract_intermediate_features(model, input_ids, attention_mask, layers_idx, is_frozen):
    """
    Run one forward pass and capture the outputs of the selected blocks.

    A forward hook is registered on every block listed in ``layers_idx``; each
    hook stores a detached copy of that block's output. The hooks are always
    removed afterwards, even when registration or the forward pass raises, so
    a failed call does not leave stale hooks attached to the model.

    Args:
      model: model whose blocks are located through ``find_layers_module``.
      input_ids, attention_mask: forwarded to ``model`` as keyword arguments.
      layers_idx: iterable of indices into the model's block container.
      is_frozen: accepted for signature compatibility; not used in this function.
        NOTE(review): the forward pass runs under ``torch.no_grad()`` regardless
        of this flag — confirm that is intended for the trained model too.

    Returns:
      List of detached tensors, one per entry of ``layers_idx``, in that order.
    """
    activations = {}
    hooks = []

    def get_hook(idx):
        def hook(module, inputs, output):
            # When a block returns a tuple, its first element is taken as the
            # hidden states (expected (batch, seq_len, hid_dim) per the original
            # author's note — not verified here).
            hidden = output[0] if isinstance(output, tuple) else output
            activations[idx] = hidden.detach()
        return hook

    layers = find_layers_module(model)
    try:
        for idx in layers_idx:
            hooks.append(layers[idx].register_forward_hook(get_hook(idx)))

        with torch.no_grad():
            model(input_ids=input_ids, attention_mask=attention_mask)
    finally:
        # Detach every hook that was successfully registered, whatever happened.
        for handle in hooks:
            handle.remove()

    return [activations[idx] for idx in layers_idx]

def compute_concat_intermediate_features(model, input_ids, attention_mask, is_frozen=False, num_intermediate_layers=5):
    """
    Build one feature vector from several intermediate layers of ``model``.

    Each selected layer's output is flattened to 1-D, L2-normalized on its own,
    and the normalized vectors are concatenated in layer order.

    Args:
      model: model to probe.
      input_ids, attention_mask: inputs forwarded to the model; the original
        comments assume a single sample, i.e. layer outputs of shape
        (1, seq_len, hid_dim) — TODO confirm at call sites.
      is_frozen: passed through to the helper functions.
      num_intermediate_layers: number of layers to sample.

    Returns:
      1-D tensor holding the concatenated, per-layer normalized features.
    """
    # Indices of the layers to read features from.
    layers_idx = get_intermediate_layers(model, is_frozen, num_points=num_intermediate_layers)

    # One captured activation tensor per selected layer.
    per_layer = extract_intermediate_features(model, input_ids, attention_mask, layers_idx, is_frozen)

    # Drop the leading dimension when it has size 1, flatten, then scale each
    # layer's vector to unit L2 norm.
    unit_vectors = [
        F.normalize(layer_out.squeeze(0).flatten(), dim=0)
        for layer_out in per_layer
    ]

    return torch.cat(unit_vectors, dim=0)

def compute_feature_space_distance_v1(model_trained, model_frozen, input_ids, attention_mask, num_intermediate_layers=5):
    """
    Computes the mean squared L2 feature-space distance over a batch.

    Every sample is processed on its own: its concatenated intermediate
    features are computed for both models, and the squared Euclidean norm of
    their difference is taken. The per-sample values are then averaged.

    Args:
      model_trained: finetuned model (parameters θ_t)
      model_frozen: initial/pretrained model (parameters θ_0)
      input_ids, attention_mask: batch inputs to LLM; the first dimension is
        the batch dimension (documented by the original author as [batch, seq])
      num_intermediate_layers: number of layers sampled for the feature vector

    Returns:
      Scalar tensor: average squared L2 feature space distance for the batch
    """
    batch_size = input_ids.size(0)
    distances = []
    for i in range(batch_size):
        # Keep a batch dimension of 1 for the per-sample forward passes.
        sample_ids = input_ids[i:i+1]
        sample_mask = attention_mask[i:i+1]
        phi_trained = compute_concat_intermediate_features(model_trained, sample_ids, sample_mask, False, num_intermediate_layers)  # (feat_dim,)
        phi_frozen = compute_concat_intermediate_features(model_frozen, sample_ids, sample_mask, True, num_intermediate_layers)
        # torch.norm is deprecated; torch.linalg.vector_norm is its replacement.
        d = torch.linalg.vector_norm(phi_trained - phi_frozen, ord=2) ** 2
        distances.append(d)
    return torch.stack(distances).mean()

def compute_feature_space_distance_normalized(model_trained, model_frozen, input_ids, attention_mask, num_intermediate_layers=5):
    """
    Computes the mean squared L2 feature-space distance, divided by the feature length.

    Same per-sample procedure as the unnormalized distance, except that each
    sample's squared distance is divided by the number of elements in its
    concatenated feature vector before averaging.

    Args:
      model_trained: finetuned model
      model_frozen: initial/pretrained model
      input_ids, attention_mask: batch inputs; the first dimension is the batch dimension
      num_intermediate_layers: number of layers sampled for the feature vector

    Returns:
      Scalar tensor: batch average of the length-normalized squared distances
    """
    batch_size = input_ids.size(0)
    distances = []
    for i in range(batch_size):
        # Keep a batch dimension of 1 for the per-sample forward passes.
        sample_ids = input_ids[i:i+1]
        sample_mask = attention_mask[i:i+1]
        phi_trained = compute_concat_intermediate_features(model_trained, sample_ids, sample_mask, False, num_intermediate_layers)  # (feat_dim,)
        phi_frozen = compute_concat_intermediate_features(model_frozen, sample_ids, sample_mask, True, num_intermediate_layers)
        # torch.norm is deprecated; torch.linalg.vector_norm is its replacement.
        d = torch.linalg.vector_norm(phi_trained - phi_frozen, ord=2) ** 2
        # Normalize by the length of the feature vector.
        phi_dim = phi_trained.size(0)
        distances.append(d / phi_dim)
    return torch.stack(distances).mean()

def extract_features_for_lp(model, dataloader, num_layers=5, is_frozen=False, return_numpy=True):
    """
    Collect concatenated intermediate features and labels for a whole dataloader.

    The model is switched to eval mode (and left in eval mode). Each sample of
    each batch is passed individually through
    ``compute_concat_intermediate_features`` and the resulting vectors are
    stacked per batch and concatenated across batches.

    Args:
      model: model to probe; inputs are moved to the device of its first parameter.
      dataloader: iterable of batches supporting ``batch["input_ids"]``,
        ``batch["attention_mask"]`` and either ``batch["labels"]`` or ``batch["target"]``.
      num_layers: number of intermediate layers sampled per feature vector.
      is_frozen: passed through to the feature extraction helper.
      return_numpy: when True, return NumPy arrays; otherwise return tensors
        on the model's device.

    Returns:
      Tuple ``(X, y)`` of features and labels.
    """
    #prep
    model.eval()
    all_features = []
    all_labels = []

    device = next(model.parameters()).device
    #batch processing
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            #if datasets use either labels or targets
            labels = batch["labels"] if "labels" in batch else batch["target"]

            #extract ldifs features, one sample at a time
            features = torch.stack([
                compute_concat_intermediate_features(
                    model,
                    input_ids[i:i+1],
                    attention_mask[i:i+1],
                    is_frozen=is_frozen,
                    num_intermediate_layers=num_layers,
                )
                for i in range(input_ids.size(0))
            ])

            if return_numpy:
                # The result ends up on the CPU anyway; offloading per batch
                # avoids accumulating every feature vector on the model device.
                features = features.cpu()
                labels = labels.cpu()
            else:
                labels = labels.to(device)

            all_features.append(features)
            all_labels.append(labels)

    X = torch.cat(all_features)
    y = torch.cat(all_labels)

    if not return_numpy:
        return X, y

    # NumPy has no bfloat16 dtype, so Tensor.numpy() raises on such tensors;
    # upcast to float32 in that case only, leaving other dtypes unchanged.
    if X.dtype == torch.bfloat16:
        X = X.float()
    return X.numpy(), y.numpy()
