import torch

from torch.utils.data import DataLoader
from utils import collate_fn
from tqdm import tqdm
from sklearn.covariance import EmpiricalCovariance
from tqdm import tqdm


def get_activations(model, dataloader, device, chosen=True, num_layers=32):
    """Collect last-position outputs of the hooked decoder layers.

    Forward hooks are registered on every module whose qualified name is
    ``base_model.model.model.layers.{i}`` for ``i`` in ``range(num_layers)``.
    For each forward pass, every hooked module contributes the slice
    ``output[:, -1, :]`` (detached and moved to CPU) to the result.

    Args:
        model: Module exposing ``named_modules()`` and callable with
            ``input_ids`` / ``attention_mask`` keyword arguments. It is put
            into eval mode as a side effect.
        dataloader: Iterable of dict batches holding ``input_ids_chosen`` /
            ``attention_mask_chosen`` and ``input_ids_rejected`` /
            ``attention_mask_rejected`` tensors.
        device: Device the batch tensors are moved to before the forward pass.
        chosen: If True use the ``*_chosen`` entries, otherwise ``*_rejected``.
        num_layers: Number of layer indices to hook (defaults to 32, the
            previously hard-coded value).

    Returns:
        A tensor concatenating all collected slices along dim 0, in the order
        the hooks fired; an empty tensor if nothing was collected (e.g. no
        module name matched or the dataloader was empty).
    """
    model.eval()
    activations = []

    def hook(module, inputs, output):
        # Some modules return a tuple; the first element is treated as the
        # activation tensor.
        if isinstance(output, tuple):
            output = output[0]
        # NOTE(review): position -1 is the last sequence position, which is
        # only the last real token if inputs are not right-padded — confirm
        # against the collate function / batch size used by callers.
        activations.append(output[:, -1, :].detach().cpu())

    # Build the lookup once instead of once per visited module.
    target_names = {f"base_model.model.model.layers.{i}" for i in range(num_layers)}

    if chosen:
        ids_key, mask_key = "input_ids_chosen", "attention_mask_chosen"
    else:
        ids_key, mask_key = "input_ids_rejected", "attention_mask_rejected"

    hooks = []
    try:
        for name, layer in model.named_modules():
            if name in target_names:
                hooks.append(layer.register_forward_hook(hook))

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Extracting activations"):
                input_ids = batch[ids_key].to(device)
                attention_mask = batch[mask_key].to(device)
                model(input_ids=input_ids, attention_mask=attention_mask)
    finally:
        # Always detach the hooks, even if a forward pass raised, so the
        # model is not left with stale hooks appending to a dead list.
        for handle in hooks:
            handle.remove()

    if activations:
        return torch.cat(activations, dim=0)
    return torch.tensor([])


def extract_activations(dataset, model, device, chosen=True):
    """Run `model` over `dataset` and return the hooked layer activations.

    The dataset is wrapped in a ``DataLoader`` that yields one example per
    batch, in dataset order (``batch_size=1``, ``shuffle=False``), using the
    project's ``collate_fn``.

    Args:
        dataset: Dataset accepted by ``torch.utils.data.DataLoader``.
        model: Model forwarded to ``get_activations``.
        device: Device forwarded to ``get_activations``.
        chosen: Forwarded to ``get_activations``; True (the default, matching
            the previous behaviour) extracts the ``*_chosen`` inputs, False
            the ``*_rejected`` ones.

    Returns:
        Whatever ``get_activations`` returns for the constructed dataloader.
    """
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
    return get_activations(model, dataloader, device, chosen=chosen)


def get_mean(tensor):
    """Return the mean of `tensor` taken over its first dimension (dim 0)."""
    return tensor.mean(dim=0)


def get_inverse_covariance(eval_tensor):
    """Estimate one precision (inverse covariance) matrix per block.

    Args:
        eval_tensor: 3-D tensor; dim 0 is treated as samples, dim 1 as blocks
            and dim 2 as features.

    Returns:
        A float32 tensor of shape ``(blocks, features, features)`` on the same
        device as `eval_tensor`, holding the ``precision_`` matrix fitted by
        scikit-learn's ``EmpiricalCovariance`` for each block.
    """
    _, block_size, num_features = eval_tensor.size()

    # One precision matrix per block.
    inverse_covariances = torch.zeros(
        (block_size, num_features, num_features),
        device=eval_tensor.device,
        dtype=torch.float32,
    )

    # assume_centered=False makes the estimator subtract the per-feature mean
    # itself, so no manual centering is needed (covariance is translation
    # invariant, and centering in a low-precision input dtype such as float16
    # would only add rounding noise).
    estimator = EmpiricalCovariance(assume_centered=False)

    for i in tqdm(range(block_size), desc="Calculating inverse covariance for each block"):
        # All samples of the current block, upcast before leaving torch so the
        # estimator never works on half-precision data.
        block_data = eval_tensor[:, i, :].detach().to(device="cpu", dtype=torch.float32).numpy()

        estimator.fit(block_data)

        inverse_covariances[i] = torch.from_numpy(estimator.precision_).to(
            device=eval_tensor.device, dtype=torch.float32
        )

    return inverse_covariances


def get_mahalanobis_distance(train_tensor, means, inverse_covariances):
    """Compute per-block Mahalanobis distances for every sample.

    For block ``i`` and sample ``x`` the distance is
    ``sqrt((x - means[i]) @ inverse_covariances[i] @ (x - means[i]))``.

    Args:
        train_tensor: 3-D tensor indexed as ``(sample, block, feature)``.
        means: Indexable by block; ``means[i]`` is the feature mean of block i.
        inverse_covariances: Indexable by block; ``inverse_covariances[i]`` is
            the ``(feature, feature)`` precision matrix of block i.

    Returns:
        A float32 tensor of shape ``(blocks, samples)`` on `train_tensor`'s
        device.
    """
    num_samples, block_size, _ = train_tensor.size()
    device = train_tensor.device
    mahalanobis_distances = torch.zeros((block_size, num_samples), device=device, dtype=torch.float32)

    for i in tqdm(range(block_size), desc="Calculating Mahalanobis distance for each block"):
        # Upcast *before* subtracting so the difference is not rounded in a
        # lower-precision input dtype, and make sure the statistics live on
        # the same device as the samples.
        block_data = train_tensor[:, i, :].to(torch.float32)
        mean = means[i].to(device=device, dtype=torch.float32)
        precision = inverse_covariances[i].to(device=device, dtype=torch.float32)

        diff = block_data - mean

        # Row-wise quadratic form diff @ precision @ diff^T.
        squared_distance = torch.sum(torch.matmul(diff, precision) * diff, dim=1)

        # Rounding can push a (theoretically non-negative) quadratic form
        # slightly below zero; clamp so sqrt does not produce NaN.
        mahalanobis_distances[i] = torch.sqrt(torch.clamp(squared_distance, min=0.0))

    return mahalanobis_distances


def compute_score(eval_tensor, train_tensor):
    """Score `train_tensor` against statistics estimated from `eval_tensor`.

    The mean and inverse covariance are computed from `eval_tensor`; the
    Mahalanobis distance of every `train_tensor` sample to those statistics
    is returned.
    """
    # Reference statistics come from the eval activations.
    reference_mean = get_mean(eval_tensor)
    reference_precision = get_inverse_covariance(eval_tensor)

    # Distances are measured for the train activations.
    return get_mahalanobis_distance(train_tensor, reference_mean, reference_precision)


def _stack_chosen_rejected(chosen_activations, rejected_activations, num_layers):
    """Reshape both inputs to (-1, num_layers, features) and join them on the feature axis."""
    # reshape (unlike view) also accepts non-contiguous inputs.
    chosen = chosen_activations.reshape(-1, num_layers, chosen_activations.size(1))
    rejected = rejected_activations.reshape(-1, num_layers, rejected_activations.size(1))
    return torch.cat((chosen, rejected), dim=2).to(torch.float16)


def estimate_mahalanobis_score_using_activations(
    eval_chosen_activations,
    eval_rejected_activations,
    train_chosen_activations,
    train_rejected_activations,
    num_layers=32,
):
    """Compute Mahalanobis scores of train activations w.r.t. eval activations.

    Each input is a 2-D tensor that is reshaped to
    ``(-1, num_layers, features)``; chosen and rejected activations are then
    concatenated along the feature axis and cast to float16 before scoring.

    Args:
        eval_chosen_activations: 2-D activations for the eval "chosen" inputs.
        eval_rejected_activations: 2-D activations for the eval "rejected" inputs.
        train_chosen_activations: 2-D activations for the train "chosen" inputs.
        train_rejected_activations: 2-D activations for the train "rejected" inputs.
        num_layers: Size of the layer axis used when reshaping (defaults to
            32, the previously hard-coded value).
            NOTE(review): presumably rows are ordered sample-major with
            `num_layers` consecutive rows per sample — confirm against the
            code that produced the activations.

    Returns:
        The result of ``compute_score(eval_tensor, train_tensor)`` on the
        concatenated tensors.
    """
    eval_tensor_concat = _stack_chosen_rejected(
        eval_chosen_activations, eval_rejected_activations, num_layers
    )
    train_tensor_concat = _stack_chosen_rejected(
        train_chosen_activations, train_rejected_activations, num_layers
    )

    # Compute the Mahalanobis score
    return compute_score(eval_tensor_concat, train_tensor_concat)
