from typing import Optional

import kmedoids
import torch as t
from einops import rearrange
from geom_median.torch import compute_geometric_median
from k_means_constrained import KMeansConstrained
from loguru import logger
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances
from torch.utils.data import DataLoader, TensorDataset

from auto_encoder import debug, device
from auto_encoder.config import AutoEncoderConfig
from auto_encoder.helpers.ae_output_types import SupermodelOutput
from auto_encoder.training.supermodel import FrozenTransformerAutoencoderSuperModel
from data.ae_data import get_eval_dataloader


def estimate_activations_geometric_median(
    supermodel: FrozenTransformerAutoencoderSuperModel,
    dataloader: DataLoader,
    num_samples: int = 10_000,
    max_iter: int = 5_000,
) -> tuple[t.Tensor, float]:
    """Estimate the geometric median of the supermodel's neuron activations.

    A random subset of activation vectors is gathered with
    ``generate_random_activations``. The geometric median of that subset is
    solved for on the CPU. The median of all element-wise absolute deviations
    from it is returned as a scalar spread estimate.

    Args:
        supermodel: Model whose activations are sampled.
        dataloader: Source of input batches for the supermodel.
        num_samples: Number of activation vectors to sample.
        max_iter: Iteration cap passed to ``compute_geometric_median``.

    Returns:
        ``(geometric_median, median_absolute_deviation)``. The median tensor
        is moved to ``device``; the deviation is a Python float.
    """
    samples_cpu = generate_random_activations(
        supermodel, dataloader, num_samples
    ).to("cpu")

    # The Weiszfeld-style solver runs on the CPU; only its result is moved
    # onto the working device afterwards.
    solution = compute_geometric_median(samples_cpu, maxiter=max_iter)
    geometric_median: t.Tensor = solution.median.to(device)

    abs_deviations = t.abs(samples_cpu.to(device) - geometric_median)
    return geometric_median, t.median(abs_deviations).item()


def generate_random_activations(
    supermodel: FrozenTransformerAutoencoderSuperModel,
    dataloader: DataLoader,
    num_samples: int,
    num_batches: int = 50,
):
    """Collect a random subset of the supermodel's initial neuron activations.

    For each of the first ``num_batches`` batches of ``dataloader``:

    - the supermodel is run in eval mode without gradients;
    - ``max(1, num_samples // 100)`` sequence positions are drawn uniformly at
      random, with replacement;
    - the activations at those positions, for every sequence in the batch,
      are kept on the CPU.

    All collected activation vectors are then flattened into one pool.
    ``num_samples`` rows are drawn from the whole pool without replacement.

    Args:
        supermodel: Model whose output exposes
            ``initial_neuron_activations_BSN``, indexed here as
            (batch, seq_len, num_neurons).
        dataloader: Yields dict batches with an ``"input_ids"`` tensor.
        num_samples: Number of activation vectors to return. If the pool is
            smaller, every pooled vector is returned in random order.
        num_batches: Maximum number of batches to draw from ``dataloader``.

    Returns:
        CPU tensor of shape ``(min(num_samples, pool_size), num_neurons)``.

    Raises:
        ValueError: If no batches were processed (empty dataloader or
            ``num_batches <= 0``).
    """
    # Always take at least one position per batch. Otherwise a small
    # ``num_samples`` (< 100) would collect nothing and crash below.
    positions_per_batch = max(1, num_samples // 100)

    supermodel.eval()
    neuron_activations_list: list[t.Tensor] = []
    for batch_num, batch in enumerate(dataloader):
        if batch_num >= num_batches:
            break

        batch_input: t.Tensor = batch["input_ids"]  # batch_size, seq_len
        batch_input = batch_input.to(device)  # type: ignore

        with t.no_grad():
            supermodel_output: SupermodelOutput = supermodel(
                batch_input, output_eval_metrics=True
            )
        # batch_size, seq_len, num_neurons
        neuron_activations = supermodel_output.initial_neuron_activations_BSN

        seq_len = neuron_activations.shape[1]
        random_indices = t.randint(0, seq_len, (positions_per_batch,))
        # batch_size, positions_per_batch, num_neurons
        random_activations = neuron_activations[:, random_indices]
        neuron_activations_list.append(random_activations.to("cpu"))

    if not neuron_activations_list:
        raise ValueError(
            "No batches were processed; cannot sample activations "
            f"(num_batches={num_batches})."
        )

    # (total_sequences, positions_per_batch, num_neurons)
    #   -> (total_sequences * positions_per_batch, num_neurons)
    activation_pool = t.cat(neuron_activations_list, dim=0).flatten(0, 1)

    # Permute over the FULL flattened pool so every collected vector can be
    # chosen, not just a prefix of it.
    pool_size = activation_pool.shape[0]
    random_activations_subset = activation_pool[
        t.randperm(pool_size)[:num_samples]
    ]  # min(num_samples, pool_size), num_neurons

    return random_activations_subset


def find_input_centroids(
    x: t.Tensor,
    num_clusters: int,
    max_iter: int,
    distance_fn: str = "cosine",
) -> tuple[t.Tensor, Optional[float]]:
    """Pick ``num_clusters`` representative rows of ``x`` with k-medoids.

    A full pairwise distance matrix (cosine or Euclidean, via scikit-learn) is
    built and clustered with ``kmedoids.KMedoids`` using the FasterPAM method.
    Medoids are actual rows of ``x``, so they are returned by indexing ``x``
    directly. Their values are exact, with ``x``'s dtype and device.

    Args:
        x: 2D tensor with one point per row. A 3D tensor with a singleton
            middle dimension is also accepted and squeezed to 2D.
        num_clusters: Number of medoids to select.
        max_iter: Iteration cap for the k-medoids solver.
        distance_fn: ``"cosine"`` or ``"euclidean"``.

    Returns:
        ``(centroids, loss)``. ``centroids`` holds the selected rows of ``x``
        (detached, on ``x.device``, dtype ``x.dtype``). ``loss`` is the fitted
        model's ``inertia_``.

    Raises:
        AssertionError: If ``x`` is not 2D after the optional squeeze.
        ValueError: If ``distance_fn`` is not recognized.
    """
    if len(x.shape) == 3 and x.shape[1] == 1:
        x = x.squeeze(1)

    assert len(x.shape) == 2, f"Input should be 2D. Got input of shape {x.shape}"

    x = x.detach()
    x_cpu = x.cpu()
    # NumPy has no bfloat16 dtype; upcast losslessly so `.numpy()` works.
    if x_cpu.dtype == t.bfloat16:
        x_cpu = x_cpu.float()
    x_numpy = x_cpu.numpy()

    if distance_fn == "cosine":
        distances = cosine_distances(x_numpy)
    elif distance_fn == "euclidean":
        distances = euclidean_distances(x_numpy)
    else:
        raise ValueError(
            f"Distance function {distance_fn} not recognized, choose either 'cosine' or 'euclidean'"
        )

    # Plain (unbalanced) k-medoids on the precomputed distance matrix.
    km = kmedoids.KMedoids(n_clusters=num_clusters, method="fasterpam", max_iter=max_iter)
    km.fit(distances)

    # Convert indices to Python ints first. kmedoids may return unsigned
    # indices, which older torch versions cannot wrap directly.
    medoid_indices = t.tensor(
        [int(i) for i in km.medoid_indices_], dtype=t.long, device=x.device
    )
    # Advanced indexing copies, so the result does not alias `x`.
    centroids = x[medoid_indices]
    loss = km.inertia_

    logger.debug(f"Centroids shape {centroids.shape}")

    return centroids, loss


if __name__ == "__main__":
    # NOTE(review): `TinyTransformer` is not imported in this file's visible
    # imports, so this script raises NameError as written. Import it from the
    # module that defines it before running.
    dataloader = get_eval_dataloader()
    transformer = TinyTransformer()
    config = AutoEncoderConfig()
    supermodel = FrozenTransformerAutoencoderSuperModel(
        transformer,
        device=device,
        medoid_initial_tensor_N=None,
        expert_initial_tensors=None,
        scaling_factor=None,
        ae_config=config,
    )
    # The estimator returns a (median, median_absolute_deviation) pair.
    geometric_median, median_absolute_deviation = (
        estimate_activations_geometric_median(supermodel, dataloader)
    )
    print(geometric_median)
    print(median_absolute_deviation)
