# Dataset and Model Configuration
dataset: 'listops'  # Dataset to use for training and evaluation
# Dataset choices: ['imdb', 'imdb_long', 'imdb_lra', 'cifar10', 'listops']
# imdb_long uses the canine-c tokenizer with data from Hugging Face; imdb_lra uses a custom ASCII tokenizer with downloaded data
model_name: 'custom'  # Pretrained model to use / model whose tokenizer is imported
# Model choices: ['bert-base-uncased', 'bert-large-uncased', 'bert-base-cased', 'bert-large-cased', 'roberta-base',
# 'roberta-large', 'google/canine-c', 'custom']

# Training Parameters
batch_size: 48  # Batch size for training
max_seq_len: 512  # Maximum sequence length for input texts
num_epochs: 30  # Number of training epochs
# NOTE(review): the value is quoted, so YAML loaders return the string '5e-4';
# presumably the consumer casts it to float - confirm before unquoting.
learning_rate: '5e-4'  # Learning rate
weight_decay: 0.1  # Weight decay factor
use_scheduler: true  # Use the default scheduler (currently OneCycleLR)

# Model Architecture
num_heads: 2  # Number of attention heads
hidden_dim: 2048  # Dimension of the feedforward network model
num_layers: 1  # Number of transformer encoder layers
dropout: 0.1  # Dropout factor of the model
# NOTE(review): 'astro' presumably selects the attention tuned by the
# "AstroAttention Parameters" section of this file - confirm against the model code.
attention_type: 'astro'  # Type of attention to use, choices: ['astro', 'softmax']

# Memory Replay Parameters
memory_replay_backprop: false  # Use memory replay backpropagation
num_segments: 1  # Number of segments to split input sequences into for memory replay backpropagation
num_memory_tokens: 8  # Number of memory tokens
astro_mem: true  # Use the astrocytic memory
mem_sum: false  # Use the summed memory from all the segments

# AstroAttention Parameters
scaleD: 100  # Scaling factor for AstroAttention
alpha: 0.25  # Alpha parameter for AstroAttention
add_Hrel: true  # Add Hrel/scaling_factor
astro_sigmoid_nonlinearity: false  # Use a sigmoid nonlinearity over H_neuron and H_astro
clip: 500  # Clip value for relative positional encoding

# Pooling Method
# NOTE(review): presumably controls how per-token outputs are reduced to one
# sequence-level representation - confirm against the model code.
pooling: 'average'  # Pooling method, choices: ['cls', 'average']

# Dataset Sampling
sample_percentage: 100  # Percentage of the dataset to use for training (100 = entire dataset)

# Pretrained Model Handling
freeze_pretrained: false  # Freeze pretrained model parameters
use_only_embeddings: true  # Use only the embeddings from the pretrained model

# Attention Replacement Parameters
# Replace attention layers with AstroAttention
# (only set to true when use_only_embeddings is false).
replace_attention: false
# Comma-separated string of layer indices whose attention is replaced,
# e.g. '0,1,2'. null (None in Python) replaces all layers.
layers_to_replace: null

# Logging and Saving
wandb: true  # Use Weights & Biases for logging
wandb_run_name: 'rmaat_v8_listops_t35'  # Name for the Weights & Biases run
model_save_path: './models'  # Path to save the model and plots

# Seed for Reproducibility
# NOTE(review): which random number generators receive this seed is not
# visible in this file - confirm in the training script.
seed: 42  # Seed value for reproducibility

# Comments:
# starting rmaat_v5:
# Relative_pos (non-skewed) as Hrel used
# OneCycleLR scheduler is being used
# Added yelp_review_full, imdb_long datasets
# AdamW is being used as optimizer (some of v4 also used AdamW)
# Up to v5_imdb_t6 => scheduler was applied outside the batch loop; from t7 => scheduler applied inside the batch loop
# From v5_imdb_t10 onwards => optimizer was applied after every segment; previously it was used after every batch
# In v5_imdb_t14 and t15 => optimizer.zero_grad() applied after each segment
# In v5_imdb_t16, t17 => optimizer.step() after all the segments (gradient accumulation)
# From v5_imdb_t17 => Softmax attention option and clip value was added as arguments
# clip was 10 up to v5_imdb_t16
# In v5_imdb_t26_value_masked => value was masked_filled with 0 instead of query

# v7 is basically continuation of v5
# From v7_imdb_t4, weighted loss applied: weighted_loss = loss * current_segment / total_segment
# In v7_imdb_t4,t5 xavier_uniform, gain changed to 0.02
# In v7_imdb_t7, used scaled down memory tokens (1e-3)
# In v7_imdb_t12, used scaled down memory tokens (1e-5)

# From v8, all GLUE tasks, wikitext, and yelp have been removed. Three Long Range Arena datasets have been introduced: imdb_lra, cifar10, listops
# From v8_cifar10_t13, 3072 max sequence length used by concatenating r,g,b channels in Cifar10Dataset class in lra_datasets.py
# From v8_listops_t4=> used custom_word_tokenizer
# From v8_listops_t19 to t21=> used normal_word_tokenizer
# From v8_listops_t20 => used test set instead of validation set
# From v8_listops_t22 => used custom word tokenizer
# From v8_cifar10_t20, v8_listops_t26 => Relpos changed to account for sequence length
# From v8_cifar10_t21, v8_listops_t26 => used torch.sign in astro_attention
