import torch
import torch.nn.functional as F
from tqdm import tqdm
from Layer_Pruning.get_cosine import generate_unique_index

def _resolve_transformer_layers(model):
    """Return the indexable stack of transformer blocks inside ``model``.

    The lookup is purely attribute based. The layouts are tried in this order:
    ``model.transformer.h``, ``model.model.decoder.layers``,
    ``model.encoder.layer`` and finally ``model.model.layers``.

    Raises:
        ValueError: if none of the known attribute paths exists on ``model``.
    """
    if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
        return model.transformer.h
    if hasattr(model, 'model') and hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers'):
        return model.model.decoder.layers
    if hasattr(model, 'encoder') and hasattr(model.encoder, 'layer'):
        return model.encoder.layer
    # Checked last so that models matching an earlier layout resolve exactly
    # as they did before this layout was supported.
    if hasattr(model, 'model') and hasattr(model.model, 'layers'):
        return model.model.layers
    raise ValueError("Unsupported model structure for hook registration.")


def _make_perturb_hook(noise_scale):
    """Build a forward hook that adds Gaussian noise to a module's output."""
    def perturb_hook(module, inputs, output):
        # Some blocks return a tuple whose first element is the hidden state;
        # only that element is perturbed, the rest is passed through untouched.
        if isinstance(output, tuple):
            perturbed = output[0] + torch.randn_like(output[0]) * noise_scale
            return (perturbed,) + output[1:]
        return output + torch.randn_like(output) * noise_scale

    return perturb_hook


@torch.no_grad()
def get_hidden_state_kl_interval(model, dataset, num_data, device, layer_intervals, num_layer):
    """Find the layer window whose hidden states are least sensitive to noise.

    For every sampled example and every start layer ``j`` in
    ``range(num_layer - layer_intervals + 1)``, Gaussian noise (std 1e-2) is
    added to the output of layer ``j`` through a temporary forward hook. The
    hidden state at index ``j + layer_intervals`` of the perturbed run is then
    compared with the same index of the clean run, using the KL divergence
    between softmax distributions taken over the hidden dimension.

    Args:
        model: Model that accepts ``output_hidden_states=True`` and exposes
            ``.hidden_states`` on its output, with one of the layer layouts
            handled by ``_resolve_transformer_layers``.
        dataset: Indexable collection supporting ``len``; each item must
            provide an ``'input_ids'`` entry convertible by ``torch.tensor``.
        num_data: Passed to ``generate_unique_index`` as the number of samples
            (presumably distinct indices in ``[0, len(dataset))`` - confirm
            against that helper).
        device: Device the model and the inputs are moved to.
        layer_intervals: Offset between the perturbed layer and the hidden
            state index that is compared.
        num_layer: Number of layers used to derive how many windows to score.

    Returns:
        int: start layer ``j`` with the lowest average KL divergence.

    Raises:
        ValueError: if the model layout is unsupported, if
            ``num_layer - layer_intervals + 1`` is not positive, or if no KL
            score could be collected for any window.

    Side effects:
        Prints the per-window averages, moves ``model`` to the CPU and calls
        ``torch.cuda.empty_cache()`` before returning.
    """
    num_windows = num_layer - layer_intervals + 1
    if num_windows <= 0:
        raise ValueError(
            f"layer_intervals ({layer_intervals}) leaves no layer window to evaluate "
            f"for num_layer={num_layer}."
        )

    model.eval()
    model = model.to(device)

    # Both are loop invariant, so resolve them once instead of per (sample, layer).
    layers = _resolve_transformer_layers(model)
    perturb_hook = _make_perturb_hook(1e-2)

    kl_divergence = [[] for _ in range(num_windows)]
    data_index = generate_unique_index(0, len(dataset), num_data)

    for i in tqdm(data_index, desc="Computing KL divergence on hidden states"):
        input_ids = torch.tensor(dataset[i]['input_ids']).unsqueeze(0).to(device)

        # Clean forward pass; hidden_states is (embeddings, layer1, ..., layerN).
        outputs = model(input_ids, output_hidden_states=True)
        original_hidden_states = outputs.hidden_states

        for j in range(num_windows):
            # Compare at index j + interval (hidden_states[0] is the embeddings).
            layer_idx = j + layer_intervals
            if layer_idx >= len(original_hidden_states):
                # Nothing to compare against, so skip the perturbed forward pass.
                continue

            handle = layers[j].register_forward_hook(perturb_hook)
            try:
                perturbed_outputs = model(input_ids, output_hidden_states=True)
            finally:
                # Always detach the hook; otherwise a failed forward pass would
                # leave the model permanently injecting noise.
                handle.remove()
            perturbed_hidden_states = perturbed_outputs.hidden_states

            # Work in float32: in half precision a 1e-8 clamp floor can underflow
            # to zero, which turns log(softmax) into -inf and the KL into NaN.
            orig_hidden = original_hidden_states[layer_idx].float()
            pert_hidden = perturbed_hidden_states[layer_idx].float()

            # Softmax over the last (hidden) dimension gives the distributions.
            orig_probs = F.softmax(orig_hidden, dim=-1)
            pert_log = F.log_softmax(pert_hidden, dim=-1)

            # 'batchmean' divides the summed KL by the leading dimension, which
            # is 1 here, so the score is the total over all token positions.
            kl = F.kl_div(pert_log, orig_probs, reduction='batchmean').item()
            kl_divergence[j].append(kl)

    # The model is no longer needed; release the device before reporting.
    model.cpu()
    torch.cuda.empty_cache()

    if not any(kl_divergence):
        raise ValueError("No KL divergence scores were collected; check num_data and the dataset.")

    # Compute average KL per layer; windows without scores can never be selected.
    print("Calculating average KL divergence on hidden states...")
    avg_kl = [
        sum(layer_scores) / len(layer_scores) if layer_scores else float('inf')
        for layer_scores in kl_divergence
    ]
    best_layer = torch.argmin(torch.tensor(avg_kl)).item()
    best_kl = avg_kl[best_layer]

    for i, kl_score in enumerate(avg_kl):
        print(f'The KL divergence between layer {i} and layer {i + layer_intervals} is {kl_score:.6f}')

    print(f'The lowest KL divergence is from layer {best_layer} to layer {best_layer + layer_intervals}, value: {best_kl:.6f}')

    return best_layer
