from transformers import BertConfig, LongformerConfig


class LocalWindowsConfig(BertConfig):
    """BERT configuration that additionally carries a ``window_size`` value.

    Every standard BERT hyper-parameter is forwarded to ``BertConfig``; the
    only additions are the project-specific defaults below and the extra
    ``window_size`` attribute stored on the instance.

    Args:
        hidden_size: Dimensionality of the hidden representations.
        num_hidden_layers: Number of transformer layers.
        num_attention_heads: Attention heads per layer.
        intermediate_size: Width of the feed-forward sublayer.
        max_position_embeddings: Maximum number of positions (default 2048
            here).
        window_size: Stored verbatim as ``self.window_size``; presumably the
            width of a local attention window used by the model that consumes
            this config -- TODO confirm against that model.
        **kwargs: Any other ``BertConfig`` / ``PretrainedConfig`` argument.
    """

    # Model identifier (note: contains a space). Kept unchanged because it is
    # written into saved config files; renaming it would presumably break
    # loading configs saved under the old name.
    model_type = 'local windows'

    def __init__(self, hidden_size=768, num_hidden_layers=12,
                 num_attention_heads=12, intermediate_size=3072,
                 max_position_embeddings=2048, window_size=512, **kwargs):
        # Merge the explicitly named BERT hyper-parameters into the
        # pass-through kwargs so the parent receives them in a single mapping.
        # None of these names can already be in kwargs, because each is a
        # named parameter above.
        kwargs.update(
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            max_position_embeddings=max_position_embeddings,
        )
        super().__init__(**kwargs)
        # Extra attribute that BertConfig itself does not define.
        self.window_size = window_size


class CustomLongformerConfig(LongformerConfig):
    """Longformer configuration with project-specific defaults.

    Differs from ``LongformerConfig`` in its defaults
    (``max_position_embeddings=2048``) and in that ``attention_window`` is
    always normalized to a list with one window size per hidden layer.

    Args:
        hidden_size: Dimensionality of the hidden representations.
        num_hidden_layers: Number of transformer layers; also the length of
            the normalized ``attention_window`` list.
        num_attention_heads: Attention heads per layer.
        intermediate_size: Width of the feed-forward sublayer.
        max_position_embeddings: Maximum number of positions.
        attention_window: Either a single int applied to every layer, or a
            sequence of per-layer window sizes. A one-element sequence (as
            produced by earlier versions of this class) is broadcast to every
            layer; longer sequences are kept as given.
        **kwargs: Any other ``LongformerConfig`` / ``PretrainedConfig``
            argument.
    """

    def __init__(
            self,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            max_position_embeddings=2048,
            attention_window=512,
            **kwargs
    ):
        # LongformerModel requires len(attention_window) == num_hidden_layers
        # and indexes it by layer id. A bare ``[attention_window]`` therefore
        # only works for a single-layer model.
        if isinstance(attention_window, int):
            attention_window = [attention_window] * num_hidden_layers
        else:
            # Reached on reload (from_pretrained/from_dict), where the saved
            # value is already a list. Copy it instead of wrapping it again,
            # which previously produced a nested [[...]] list.
            attention_window = list(attention_window)
            if len(attention_window) == 1:
                attention_window = attention_window * num_hidden_layers
        super().__init__(
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            max_position_embeddings=max_position_embeddings,
            attention_window=attention_window,
            **kwargs)


# add custom configs here


# Dataset name -> relative path of its serialized dataset object. The ``.pt``
# suffix suggests files written with ``torch.save``; TODO confirm against the
# loader. The keys are presumably looked up by the same names as
# ``column_labels``, so keep the two dicts in sync.
dataset_paths = {
    'diabetes': 'objects/diabetes1024.pt',
    'heart': 'objects/heart1024.pt',
    'image': 'objects/image4096.pt',
    'generated': 'objects/generated.pt',
    # Add other datasets here (keep the trailing comma on every entry).
}
# Dataset name -> list of column names for that dataset, or None when no
# column names are defined. Presumably used to label features or outputs;
# TODO confirm with the consumer. The keys mirror ``dataset_paths``.
column_labels = {
    'diabetes': ['bmi', 'hbA1c', 'blood glucose'],
    'heart': ['oldpeak', 'thalach', 'cp'],
    'image': None,
    'generated': None,
    # Add column labels here (keep the trailing comma on every entry).
}
