from typing import List

import pytest

from olmo.config import (
    DataConfig,
    InitFnType,
    ModelConfig,
    OptimizerConfig,
    PaddingDirection,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
)
from olmo.tokenizer import Tokenizer

# Tokenizer identifier shared by the fixtures in this module: it is passed to
# ``Tokenizer.from_pretrained`` and to ``TokenizerConfig(identifier=...)``.
TEST_MODEL = "gpt2"

LOREM_IPSUM_1 = 

LOREM_IPSUM_2 = 


@pytest.fixture(scope="function")
def model_config() -> ModelConfig:
    """Build a fresh, small ``ModelConfig`` for each test.

    The configuration uses a 128-wide model with 2 heads and 3 layers, a
    maximum sequence length of 512, and the ``normal`` init function. The
    EOS and padding token ids are both set to 50256.
    """
    small_config = ModelConfig(
        # Architecture.
        d_model=128,
        n_heads=2,
        n_layers=3,
        max_sequence_length=512,
        init_fn=InitFnType.normal,
        # Vocabulary and special tokens (EOS doubles as the pad token).
        vocab_size=50257,
        eos_token_id=50256,
        pad_token_id=50256,
    )
    return small_config


@pytest.fixture(scope="function")
def tokenizer() -> Tokenizer:
    """Load the tokenizer identified by ``TEST_MODEL`` for a single test."""
    loaded = Tokenizer.from_pretrained(TEST_MODEL)
    return loaded


@pytest.fixture(scope="function")
def train_config(tmp_path, model_config) -> TrainConfig:
    """Assemble a ``TrainConfig`` around the ``model_config`` fixture.

    Args:
        tmp_path: pytest's per-test temporary directory; checkpoints are
            pointed at a ``checkpoints`` sub-path of it.
        model_config: The model configuration to embed in the train config.

    Returns:
        A ``TrainConfig`` with default optimizer and scheduler configs, the
        three C4 sample files as data (right-padded), and a tokenizer config
        identified by ``TEST_MODEL``.
    """
    sample_files = [
        "test_fixtures/c4-sample.01.json.gz",
        "test_fixtures/c4-sample.02.json.gz",
        "test_fixtures/c4-sample.03.json.gz",
    ]
    data_config = DataConfig(paths=sample_files, pad_direction=PaddingDirection.right)
    # Keep checkpoints inside the per-test temp dir so tests don't share output.
    checkpoint_dir = str(tmp_path / "checkpoints")

    return TrainConfig(
        model=model_config,
        optimizer=OptimizerConfig(),
        scheduler=SchedulerConfig(),
        data=data_config,
        tokenizer=TokenizerConfig(identifier=TEST_MODEL),
        save_folder=checkpoint_dir,
    )


@pytest.fixture(scope="function")
def eos_token_id(tokenizer: Tokenizer) -> int:
    """Return the end-of-sequence token id of the ``tokenizer`` fixture.

    This fixture is function-scoped on purpose. It depends on the
    function-scoped ``tokenizer`` fixture, and pytest raises ``ScopeMismatch``
    when a wider-scoped (e.g. module) fixture requests a narrower-scoped one.

    Args:
        tokenizer: The tokenizer supplied by the ``tokenizer`` fixture.

    Returns:
        The value of ``tokenizer.eos_token_id``.
    """
    return tokenizer.eos_token_id


@pytest.fixture(scope="module")
def lorem_ipsum() -> str:
    """Return ``LOREM_IPSUM_1`` flattened onto a single line.

    Each newline is replaced by a space, then leading and trailing
    whitespace is stripped.
    """
    single_line = LOREM_IPSUM_1.replace("\n", " ")
    trimmed = single_line.strip()
    return trimmed


@pytest.fixture(scope="module")
def lorem_ipsum_docs() -> List[str]:
    """Return both lorem-ipsum passages as single-line documents.

    For ``LOREM_IPSUM_1`` and ``LOREM_IPSUM_2``, in that order, each newline
    is replaced by a space and surrounding whitespace is stripped.
    """
    documents: List[str] = []
    for passage in (LOREM_IPSUM_1, LOREM_IPSUM_2):
        flattened = passage.replace("\n", " ")
        documents.append(flattened.strip())
    return documents


@pytest.fixture(scope="function")
def model_path() -> str:
    """Return the relative path string of the test OLMo model fixture."""
    fixture_location = "test_fixtures/test-olmo-model"
    return fixture_location
