import torch as t
from einops import rearrange
from loguru import logger
from torch.utils.data import DataLoader

from auto_encoder import device
from auto_encoder.training.supermodel import FrozenTransformerAutoencoderSuperModel


class ActivationBuffer:
    """Buffer of model activations that is refilled from a dataloader on demand.

    Batches of ``input_ids`` are drawn from ``train_dataloader``, run through
    the supermodel under ``torch.no_grad()``, and the resulting activations are
    kept on the CPU in ``self.activations``.  ``get_activation_batch`` hands
    them out one tensor at a time and triggers a refill when the buffer runs
    low.

    Args:
        batch_size: Stored on the instance; not read by this class itself.
        supermodel: Model wrapper providing ``get_feature_activations`` and
            ``get_neuron_activations``.
        train_dataloader: Iterable of batches; each batch must support
            ``batch["input_ids"]``.
        do_shuffle: If true, buffered activations are shuffled across all
            buffered tokens after every refill.
        max_buffer_size: Number of activation tensors held after a refill.
            Must be at least 1.
        inverted: If true, buffer ``get_feature_activations`` outputs;
            otherwise buffer ``get_neuron_activations`` outputs.

    Raises:
        ValueError: If ``max_buffer_size`` is smaller than 1.
    """

    def __init__(
        self,
        batch_size: int,
        supermodel: FrozenTransformerAutoencoderSuperModel,
        train_dataloader: DataLoader,
        do_shuffle: bool = True,
        max_buffer_size: int = 32,
        inverted: bool = False,
    ):
        if max_buffer_size < 1:
            raise ValueError(
                f"max_buffer_size must be at least 1, got {max_buffer_size}"
            )

        self.input_ids: list[t.Tensor] = []
        self.activations: list[t.Tensor] = []

        self.max_buffer_size = max_buffer_size
        self.batch_size = batch_size
        self.supermodel = supermodel
        self.train_dataloader = train_dataloader
        self.do_shuffle = do_shuffle
        self.inverted = inverted

        # Persistent iterator over ``train_dataloader``, created lazily so that
        # constructing the buffer does not start iterating the dataloader.
        self._dataloader_iter = None

    def get_activation_batch(self) -> t.Tensor:
        """Return one buffered activation tensor, refilling the buffer if needed.

        The buffer is refilled when it is empty or less than half full.  The
        emptiness check matters for ``max_buffer_size == 1``, where the
        half-full threshold is 0 and would otherwise never trigger a refill.
        """
        if (
            not self.activations
            or len(self.activations) < self.max_buffer_size // 2
        ):
            self.refresh()

        return self.activations.pop()

    def _next_batch(self):
        """Return the next dataloader batch, starting a new pass when exhausted.

        A single iterator is kept alive between calls.  Creating a fresh
        iterator per batch would restart the dataloader each time, which for
        an unshuffled dataloader yields the same first batch over and over.
        """
        if self._dataloader_iter is None:
            self._dataloader_iter = iter(self.train_dataloader)

        try:
            return next(self._dataloader_iter)
        except StopIteration:
            # The current pass is finished; begin another one.
            self._dataloader_iter = iter(self.train_dataloader)

        try:
            return next(self._dataloader_iter)
        except StopIteration:
            raise RuntimeError("train_dataloader yielded no batches") from None

    @t.no_grad()
    def refresh(self) -> None:
        """Fill the buffer up to ``max_buffer_size`` tensors, then optionally shuffle.

        Puts the supermodel into eval mode (it is not switched back here).
        """
        self.supermodel.eval()

        # The extraction method depends only on ``self.inverted``; pick it once.
        if self.inverted:
            get_activations = self.supermodel.get_feature_activations
        else:
            get_activations = self.supermodel.get_neuron_activations

        while len(self.activations) < self.max_buffer_size:
            batch = self._next_batch()

            batch_inputs_BS: t.Tensor = batch["input_ids"]
            batch_inputs_BS = batch_inputs_BS.to(device)

            # Keep the buffer on the CPU so it does not occupy accelerator memory.
            # Each entry is expected to be (batch, seq_len, num_neurons).
            self.activations.append(get_activations(batch_inputs_BS).to("cpu"))

        # Shuffle the buffer
        if self.do_shuffle:
            self.shuffle_activations()

    def shuffle_activations(self) -> None:
        """Shuffle buffered activations across every buffered token position.

        All buffered tensors are flattened to rows of size ``num_neurons``
        (the last dimension), the rows are permuted jointly, and the result is
        cut back into tensors with the same shapes, in the same order, as
        before.  Buffered tensors may differ in their leading dimensions (for
        example a smaller final batch of a pass); only the last dimension must
        match.  Does nothing when the buffer is empty.
        """
        if not self.activations:
            return

        original_shapes = [activation.shape for activation in self.activations]
        flat_pieces = [
            activation.reshape(-1, activation.shape[-1])
            for activation in self.activations
        ]
        rows_per_piece = [piece.shape[0] for piece in flat_pieces]

        activations_tensor_TN = t.cat(flat_pieces, dim=0)  # tokens, num_neurons
        shuffle_order = t.randperm(activations_tensor_TN.shape[0])
        activations_tensor_TN = activations_tensor_TN[shuffle_order]

        shuffled_pieces = t.split(activations_tensor_TN, rows_per_piece, dim=0)
        self.activations = [
            piece.reshape(shape)
            for piece, shape in zip(shuffled_pieces, original_shapes)
        ]
