import pytest
import torch as t
from torch.utils.data import DataLoader, Dataset
from transformer.config import TRANSFORMER_MODELS
from transformer.models.model import TinyTransformer

from auto_encoder import device
from auto_encoder.config import AutoEncoderConfig
from auto_encoder.helpers.buffer import ActivationBuffer
from auto_encoder.training.supermodel import FrozenTransformerAutoencoderSuperModel


class DummyDataset(Dataset):
    """Map-style dataset of random token-id sequences for tests.

    Each item is a dict ``{"input_ids": tensor}`` where the tensor has shape
    ``(seq_len,)`` and integer values drawn uniformly from ``[0, vocab_size)``.

    Args:
        num_samples: Number of items reported by ``len()``.
        seq_len: Length of each ``input_ids`` sequence.
        vocab_size: Exclusive upper bound for sampled token ids.
        seed: Optional base seed. When given, item ``idx`` is generated from a
            private generator seeded with ``seed + idx``, so repeated access
            returns identical ids. When ``None`` (the default), every access
            draws fresh ids from torch's global RNG.
    """

    def __init__(self, num_samples, seq_len, vocab_size, seed=None):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.seed = seed

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Support Python-style negative indices.
        if idx < 0:
            idx += self.num_samples
        # torch's Dataset has no __iter__, so plain iteration (``for x in ds``,
        # ``list(ds)``) uses the legacy sequence protocol and only stops when
        # __getitem__ raises IndexError. Without this check it never ends.
        if not 0 <= idx < self.num_samples:
            raise IndexError(
                f"index out of range for DummyDataset of length {self.num_samples}"
            )
        generator = None
        if self.seed is not None:
            generator = t.Generator().manual_seed(self.seed + idx)
        return {
            "input_ids": t.randint(
                0, self.vocab_size, (self.seq_len,), generator=generator
            )
        }


@pytest.fixture
def mock_supermodel():
    """Build the frozen-transformer + autoencoder supermodel used by the tests.

    Despite its name this is not a mock. It constructs a ``TinyTransformer``
    from the pretrained weights file registered under ``"main_transformer"``.
    That transformer is wrapped with a default ``AutoEncoderConfig`` on the
    package-level ``device``. No medoid/expert initial tensors or scaling
    factor are supplied.
    """
    config = AutoEncoderConfig()

    weights_path = TRANSFORMER_MODELS["main_transformer"].file_name
    frozen_transformer = TinyTransformer(pretrained_load_path=weights_path)

    return FrozenTransformerAutoencoderSuperModel(
        frozen_transformer,
        ae_config=config,
        device=device,
        medoid_initial_tensor_N=None,
        expert_initial_tensors=None,
        scaling_factor=None,
    )


@pytest.fixture
def train_dataloader():
    """Return a DataLoader over 100 random 10-token sequences, 2 per batch.

    Token ids are sampled from a vocabulary of size 30522.
    """
    return DataLoader(
        DummyDataset(num_samples=100, seq_len=10, vocab_size=30522),
        batch_size=2,
    )


@pytest.fixture
def activation_buffer(mock_supermodel, train_dataloader):
    """Provide a fresh ActivationBuffer wired to the supermodel and dataloader
    fixtures, with batch_size=2 and max_buffer_size=12."""
    buffer_kwargs = dict(
        batch_size=2,
        supermodel=mock_supermodel,
        train_dataloader=train_dataloader,
        max_buffer_size=12,
    )
    return ActivationBuffer(**buffer_kwargs)


def test_initial_state(activation_buffer: ActivationBuffer):
    """A newly built buffer reports its configured sizes and holds no data."""
    buf = activation_buffer
    assert buf.max_buffer_size == 12
    assert buf.batch_size == 2
    # Both storage containers start out empty before any refresh().
    for stored in (buf.activations, buf.input_ids):
        assert len(stored) == 0


def test_refresh(activation_buffer: ActivationBuffer):
    """After one refresh() the buffer holds 12 activation entries, matching
    the fixture's max_buffer_size."""
    activation_buffer.refresh()
    filled = len(activation_buffer.activations)
    assert filled == 12


def test_get_batch_input(activation_buffer: ActivationBuffer):
    """Drawing a batch from a refreshed buffer yields a (2, 10, 512) tensor
    and leaves 11 entries behind."""
    activation_buffer.refresh()
    activation_batch = activation_buffer.get_activation_batch()

    # 2 = fixture batch_size, 10 = DummyDataset seq_len; 512 is presumably the
    # transformer's hidden size. TODO confirm against the model config.
    expected_shape = (2, 10, 512)
    assert activation_batch.shape == expected_shape

    # NOTE(review): 12 -> 11 implies each buffer entry holds a whole batch.
    # If entries were per-sequence, this would be 10. Confirm against
    # ActivationBuffer.get_activation_batch.
    assert len(activation_buffer.activations) == 12 - 1


def test_refresh_shuffles_activations(activation_buffer: ActivationBuffer):
    """Two consecutive refreshes must not leave identical buffer contents.

    NOTE(review): DummyDataset draws new random token ids on every access.
    If refresh() reads new batches from the dataloader, the two snapshots
    will almost surely differ even without any shuffling. This test then
    does not isolate shuffle behavior. Confirm against ActivationBuffer.
    """
    snapshots = []
    for _ in range(2):
        activation_buffer.refresh()
        snapshots.append(t.stack(activation_buffer.activations))
    first, second = snapshots
    assert not t.equal(first, second)


if __name__ == "__main__":
    import sys

    # With no args, pytest.main() reads sys.argv. `python <this file>` would
    # then collect from the current directory instead of this module, so
    # target this file explicitly. Extra CLI args (e.g. -k, -x) are still
    # forwarded, and pytest's exit code becomes the process exit status.
    raise SystemExit(pytest.main([__file__, *sys.argv[1:]]))
