import os

import torch as t
from einops import rearrange
from loguru import logger
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

from auto_encoder import device
from auto_encoder.config import AutoEncoderConfig
from auto_encoder.helpers.ae_output_types import SupermodelOutput
from auto_encoder.training.supermodel import FrozenTransformerAutoencoderSuperModel
from data.ae_data import get_train_dataloader


@t.no_grad()
def get_features_hit_tensor(
    supermodel: FrozenTransformerAutoencoderSuperModel,
    num_features: int,
    train_dataloader: DataLoader,
    total_num_batches: int = 10_000,
    device: str = device,
) -> tuple[t.Tensor, t.Tensor]:
    """Count how often each feature fires and average its per-batch median activation.

    Iterates over at most ``total_num_batches`` batches of ``train_dataloader``.
    Each batch must be a mapping with an ``"input_ids"`` entry holding a
    ``(batch, seq)`` token tensor; it is cropped to
    ``config.transformer_batch_size`` rows and ``config.seq_len`` columns before
    being passed to ``supermodel``.

    Args:
        supermodel: Callable model whose output exposes
            ``feature_activations_BSF`` (batch, seq, feature) and which carries
            an ``ae_config`` attribute.
        num_features: Size of the feature axis of the activations.
        train_dataloader: Source of token batches.
        total_num_batches: Upper bound on the number of batches to process.
        device: Device on which tokens and accumulators are placed.

    Returns:
        A tuple ``(features_hit_times_F, avg_median_activation_F)``:

        * ``features_hit_times_F`` (int32): number of (batch, position) slots
          in which each feature's activation was strictly positive.
        * ``avg_median_activation_F`` (float): per feature, the median of its
          non-zero activations within a batch, averaged over the batches in
          which the feature had at least one non-zero activation. Features
          that never fired get ``0``.
    """
    features_hit_times_F = t.zeros(num_features, device=device, dtype=t.int32)
    median_activations_sum_F = t.zeros(num_features, device=device)
    # Per feature: number of batches that contributed a (non-NaN) median.
    active_batches_F = t.zeros(num_features, device=device, dtype=t.int32)

    config = supermodel.ae_config

    for idx, train_batch in tqdm(enumerate(train_dataloader), total=total_num_batches):
        # `>=` so that exactly `total_num_batches` batches are processed.
        if idx >= total_num_batches:
            break

        tokens_BS: t.Tensor = train_batch["input_ids"].to(device)
        logger.debug(tokens_BS.shape)

        tokens_BS = tokens_BS[: config.transformer_batch_size, : config.seq_len]

        supermodel_output: SupermodelOutput = supermodel(tokens_BS)
        feature_activations_BSF: t.Tensor = supermodel_output.feature_activations_BSF

        features_hit_times_F += (feature_activations_BSF > 0).sum(dim=(0, 1)).to(t.int32)

        feature_activations_BsF = rearrange(
            feature_activations_BSF,
            "batch seq_len num_features -> (batch seq_len) num_features",
        )

        # Mask inactive (zero) entries with NaN so the median is taken over the
        # non-zero activations only. `masked_fill` is out-of-place, so the
        # model output itself is left untouched.
        feature_activations_nan_BsF = feature_activations_BsF.masked_fill(
            feature_activations_BsF == 0, float("nan")
        )

        # `t.median` propagates NaN; `t.nanmedian` ignores NaN entries and only
        # returns NaN for a feature whose whole column is NaN (never fired).
        local_median_activation_F = t.nanmedian(feature_activations_nan_BsF, dim=0).values
        no_median_F = t.isnan(local_median_activation_F)

        # Skip features without a median in this batch so a single silent
        # batch cannot turn the running sum into NaN.
        median_activations_sum_F += local_median_activation_F.masked_fill(no_median_F, 0.0)
        active_batches_F += (~no_median_F).to(t.int32)

        if idx % 1000 == 0 or (idx % 100 == 0 and idx < 1000):
            logger.info(f"Processed {idx}/{total_num_batches} batches")
            logger.info(f"Features hit: {features_hit_times_F}")
            logger.info(f"Features hit sorted: {features_hit_times_F.sort().values}")
            logger.info(f"Median feature hit: {features_hit_times_F.median()}")
            logger.info(f"Max feature hit: {features_hit_times_F.max()}")
            logger.info(f"Min feature hit: {features_hit_times_F.min()}")

            avg_median_activation_F = median_activations_sum_F / active_batches_F.clamp(min=1)

            logger.info(f"Average median activation: {avg_median_activation_F}")

    # Divide by the number of contributing batches per feature rather than by
    # `total_num_batches`: the dataloader may be shorter than the requested
    # limit, and batches where a feature never fired carry no median.
    # `clamp(min=1)` keeps never-active features at 0 instead of 0/0.
    avg_median_activation_F = median_activations_sum_F / active_batches_F.clamp(min=1)

    return features_hit_times_F, avg_median_activation_F


def get_ordered_feature_info(
    original_supermodel: FrozenTransformerAutoencoderSuperModel,
    config: AutoEncoderConfig,
    autoencoder_path: str,
    base_path: str = "artifacts/auto_encoder",
) -> tuple[t.Tensor, t.Tensor, t.Tensor]:
    """Return per-feature statistics ordered by how often each feature fires.

    Results are cached in three files named
    ``{base_path}/{autoencoder_path}_features_hit_order.pt``,
    ``..._features_hit_times.pt`` and ``..._avg_median_feature_activation.pt``.
    The cache is used only when all three files exist; otherwise the
    statistics are recomputed with ``get_features_hit_tensor`` over 1,000
    training batches and all three files are (re)written.

    Args:
        original_supermodel: Model passed through to ``get_features_hit_tensor``.
        config: Supplies ``transformer_model_name``, ``batch_size`` and
            ``num_features``.
        autoencoder_path: Identifier used as the cache file-name prefix.
        base_path: Directory holding the cache files.

    Returns:
        ``(ordered_features_hit_times_F, ordered_avg_median_feature_activation_F,
        features_hit_order_F)`` where ``features_hit_order_F`` is the argsort of
        the hit counts in descending order and the first two tensors are
        indexed by that order.
    """
    cache_prefix = f"{base_path}/{autoencoder_path}"
    hit_order_path = f"{cache_prefix}_features_hit_order.pt"
    hit_times_path = f"{cache_prefix}_features_hit_times.pt"
    avg_median_path = f"{cache_prefix}_avg_median_feature_activation.pt"

    # Only trust the cache when it is complete; a partial cache previously led
    # to an UnboundLocalError at the return statement.
    cache_paths = (hit_order_path, hit_times_path, avg_median_path)
    if all(os.path.exists(path) for path in cache_paths):
        print("Loading cached feature information (e.g. features_hit_times)")
        # `map_location` lets tensors saved on another device (e.g. CUDA) be
        # loaded on the current one. The cached files already hold the
        # *ordered* tensors (see the save calls below).
        features_hit_order_F = t.load(hit_order_path, map_location=device)
        ordered_features_hit_times_F = t.load(hit_times_path, map_location=device)
        ordered_avg_median_feature_activation_F = t.load(avg_median_path, map_location=device)
        return (
            ordered_features_hit_times_F,
            ordered_avg_median_feature_activation_F,
            features_hit_order_F,
        )

    print("Calculating feature information (e.g. features_hit_times)")
    try:
        tokenizer = AutoTokenizer.from_pretrained(config.transformer_model_name)
    except Exception as err:
        # Deliberate best-effort fallback, but make it visible in the logs.
        logger.warning(f"Falling back to the 'gpt2' tokenizer: {err}")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    dataloader = get_train_dataloader(batch_size=config.batch_size, tokenizer=tokenizer)

    features_hit_times_F, avg_median_feature_activation_F = get_features_hit_tensor(
        supermodel=original_supermodel,
        num_features=config.num_features,
        train_dataloader=dataloader,
        total_num_batches=1_000,
    )

    features_hit_order_F = t.argsort(features_hit_times_F, descending=True)
    ordered_avg_median_feature_activation_F = avg_median_feature_activation_F[
        features_hit_order_F
    ]
    ordered_features_hit_times_F = features_hit_times_F[features_hit_order_F]

    # Make sure the cache directory exists before writing to it.
    os.makedirs(os.path.dirname(hit_order_path) or ".", exist_ok=True)
    t.save(features_hit_order_F, hit_order_path)
    t.save(ordered_avg_median_feature_activation_F, avg_median_path)
    t.save(ordered_features_hit_times_F, hit_times_path)

    return (
        ordered_features_hit_times_F,
        ordered_avg_median_feature_activation_F,
        features_hit_order_F,
    )
