# Predict command (currently disabled: the key below is commented out)
## COMMAND: [ python, predict.py ]
# Train / cross-validation commands and related run settings

# Interval / early-stopping settings.
# NOTE(review): the units (epochs vs. iterations) are not visible in this
# file; confirm against the training script.
SAVE_INTERVAL: 50
VALID_INTERVAL: 5
PATIENCE: 5

# Command lines, written as argument lists (program first, then script).
TRAIN_COMMAND:
  - python
  - train.py
LOO_CV_COMMAND:
  - python
  - loo_cross_validation.py

# Label table: each entry pairs an integer label with its display name.
CLASSES:
  - label: 1
    name: 'walking'
  - label: 2
    name: 'running'
  - label: 3
    name: 'shuffling'
  - label: 4
    name: 'stairs (ascending)'
  - label: 5
    name: 'stairs (descending)'
  - label: 6
    name: 'standing'
  - label: 7
    name: 'sitting'
  - label: 8
    name: 'lying'
  - label: 13
    name: 'cycling (sit)'
  - label: 14
    name: 'cycling (stand)'
  - label: 130
    name: 'cycling (sit, inactive)'
  - label: 140
    name: 'cycling (stand, inactive)'

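# --- Information about data
# Path to training dataset
# NOTE(review): no dataset-path key follows this comment; the keys below
# describe class counts and class names instead.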
# Mapping from integer class id to a count.
# NOTE(review): presumably the number of samples used per class; confirm
# against the data loader.
class_count:
  0: 1000
  1: 162
  2: 77
  3: 800
# NOTE(review): looks like label fractions for semi-supervised runs; confirm.
class_count_semi:
  - 0.01
  - 0.05
  - 0.1

# Mapping from integer class id to class name (names kept quoted so that
# short ones such as 'N' always load as strings).
class_dict:
  0: 'AFIB'
  1: 'AFL'
  2: 'J'
  3: 'N'

# Fraction of the training data held out for validation (between 0 and 1).
VALID_SPLIT: 0.2
display_batch: 12000

# Randomly selected test subjects (not used during LOSO!).
TEST_SUBJECTS:
  - S027.csv
  - S023.csv
  - S008.csv
  - S019.csv
  - S006.csv
  - S024.csv
  - S010.csv
  - S014.csv
  - S017.csv
  - S020.csv
SEED: 77

# Number of workers for data loading.
NUM_WORKERS: 4
# How many/which GPUs to use.
NUM_GPUS: [0]
WANDB: false

# Seed values, one entry per seed.
seeds:
  - 42
  - 53
  - 64
  - 75

# Names of the baseline methods.
baselines:
  - vanilla_cl
  - double_cl
  - margin_cl
  - ts2vec
  - infoTS
  - triplet
  - tnc
  - cpc

# Output directory (relative path, trailing slash kept).
output_dir: 'output/ecg/'



# Dataset selector and its arguments.
DATASET: ECG2
DATASET_ARGS:
  # STFT settings for ECG (each value is wrapped in a single-element list).
  # NOTE(review): class_to_exclude holds label 7, which does not appear in
  # class_dict (ids 0-3) — confirm it is still meaningful for this dataset.
  class_to_exclude: [7]
  n_fft: [1500]
  hop_length: [750]
  win_length: [1500]
  seq_length: [200]
  num_labels: [4]

# -- Model
# Which classifier to use.
ALGORITHM: DownstreamMLP
# Arguments for the classifier
# (given as lists so that a grid search can be run over them).
ALGORITHM_ARGS:
  epochs: [1]
  supervised_epochs: [50]
  iterations: [1500]
  linear_epochs: [10]
  margin: [100]
  p_overlap: [0.5]
  n_split: [10]
  feature_dim: [2]
  out_features: [256]
  temperature: [0.5]
  mask_fraction: [0.3]
  margin_thresh: [0.4]
  mse_lambda: [1]
  batch_size: [16]
  tnc_window: [20]
  # NOTE(review): fft_features and save_model are plain scalars, unlike the
  # list-valued entries around them — confirm the consumer expects that.
  fft_features: 2502
  sequence_sample: [2500]
  save_model: true

  # Loss options: MARGIN_LOSS, LS_MARGIN_LOSS, LS_HATCL_LOSS, HATCL_LOSS
  loss: [HATCL_LOSS]
  mloss: [LS_MARGIN_LOSS]
  optimizer: [AdamW]
  # Floats in exponent notation are written with an explicit mantissa dot:
  # YAML 1.1 loaders such as PyYAML resolve a bare `3e-4` as the string
  # "3e-4" rather than a float.
  weight_decay: [3.0e-4]
  # output_activation: [softmax]
  # metrics: [[F1Score,Accuracy,Precision,Recall]]  # Which metrics to log
  lr: [3.0e-4]
  lr_scheduler: [ExponentialTFDecay]
  number_subject: [10]
  number_sample: [1]
  total_subjects: [30000]
  # lr_scheduler: [LinearWUSchedule]
  # Architecture params

# Store test cmats on disk.
STORE_CMATS: true
SKIP_FINISHED_ARGS: false

# Metric that defines the best model.
EVAL_METRIC: average_f1score
