import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchfly.flyconfig import GlobalFlyConfig

from compressive_transformer.compressive_transformer import CompressiveTransformerModel


# Debugging script: build the model from a config file, run one tiny input
# through it, then stop in the debugger.

# Only the `user_config` attribute of the loaded configuration is kept.
config = GlobalFlyConfig(
    config_path="config/base64_16.yml",
    disable_chdir=True,
    disable_logging=True,
).user_config

model = CompressiveTransformerModel(config.model)

# A single row of four int64 values (shape 1 x 4).
input_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long)

# The result of this call is not bound to a name.
model(input_ids)

# Pause here so `config`, `model` and `input_ids` can be examined.
breakpoint()