# Name of this configuration. No interpolation in this file refers to it.
name: default

# Boolean switches selecting which tasks to run. Every switch is off here.
tasks:
  # Datastore-side tasks. Presumably these gate the work configured under
  # `datastore.embedding` and `datastore.index` -- confirm in the runner.
  datastore:
    embedding: false
    index: false
  eval:
    # Task name used to load the eval data.
    # Documented options: "perplexity", "lm-eval".
    # NOTE(review): the value `gen` is not one of the documented options --
    # confirm the loader accepts it, or extend the option list.
    task_name: gen
    # Search top-k offline.
    search: false
    # Merge searched results from multiple sources.
    merge_search: false
    # Run inference with the LM.
    inference: false

# Model and tokenizer identifiers.
model:
  # Sparse retriever; choices: null, bm25.
  sparse_retriever: null

  # `datastore_tokenizer` and `datastore_encoder` are interpolated into
  # `datastore.embedding` (as `tokenizer` and `model_name_or_path`).
  # The `query_*` entries are not referenced by any interpolation in this file.
  datastore_tokenizer: GritLM/GritLM-7B
  query_tokenizer: GritLM/GritLM-7B
  query_encoder: GritLM/GritLM-7B
  datastore_encoder: GritLM/GritLM-7B
  # LM for evaluation.
  lm_model: meta-llama/Llama-3.1-8B-Instruct


# Datastore settings. `raw_data_path` and `chunk_size` are defined once here
# and reused by the `embedding` and `index` sub-sections via interpolation.
datastore:
  domain: massiveds-math
  raw_data_path: /weka_data/xinxil/private-retrieval-lm/deduped_dataset_sampled
  chunk_size: 256  # chunk size in number of words
  
  # Settings for chunking the raw data into passages and embedding them.
  embedding:
    raw_data_path: ${datastore.raw_data_path}
    # `evaluation.eval_output_dir` is derived from this directory.
    output_dir: /weka_data/xinxil/private-retrieval-lm/ours_v1_mini_grit
    chunk_size: ${datastore.chunk_size}
    keep_last_chunk: true
    # NOTE(review): this lives under `ours_v1_mini/`, while `output_dir` and
    # `embedding_dir` live under `ours_v1_mini_grit/` -- confirm the passages
    # are intentionally taken from a different tree.
    passages_dir: /weka_data/xinxil/private-retrieval-lm/ours_v1_mini/passages/
    max_files_per_shard: 1

    per_gpu_batch_size: 512
    # Needs to be set to a larger number than the chunk size.
    # NOTE(review): this currently resolves to exactly `datastore.chunk_size`
    # (256, counted in words), which does not satisfy the requirement above.
    # Presumably this limit is in tokens, so passages may be truncated --
    # confirm and raise it if so.
    passage_maxlength: ${datastore.chunk_size}
    model_name_or_path: ${model.datastore_encoder}
    tokenizer: ${model.datastore_tokenizer}
    no_fp16: false
    no_title: false
    lowercase: false
    normalize_text: false
    # Empty by default; "subreddit" was noted by the author as a candidate.
    fields_to_add: []
    # logloc: output  # was used to write passages to a log for sanity checking / debugging

    prefix: "passages"
    # `datastore.index.passages_embeddings` globs `*.pkl` under this directory.
    embedding_dir: /weka_data/xinxil/private-retrieval-lm/ours_v1_mini_grit/embeddings/
    use_saved_if_exists: true

  # Settings for building the index over the saved embeddings.
  index:
    raw_data_path: ${datastore.raw_data_path}
    chunk_size: ${datastore.chunk_size}
    # Glob over the `*.pkl` files in the embedding directory.
    # `datastore.embedding.embedding_dir` already ends with "/", so this
    # resolves with a doubled slash ("embeddings//*.pkl").
    passages_embeddings: ${datastore.embedding.embedding_dir}/*.pkl
    # Number of subsampled embedding files; -1 means use all. Subsampling is
    # not supported yet, so all embeddings in the directory are assumed used.
    num_subsampled_embedding_files: -1
    index_shard_ids: null
    save_or_load_index: true
    no_fp16: false
    index_type: IVFPQ
    sample_train_size: 100000
    indexing_batch_size: 1000000
    projection_size: 4096
    # Alternative values noted by the author: 64, 8.
    probe: 256
    # 4096 in silo.
    ncentroids: 2048
    # Number of subquantizers used for vector quantization; if 0, a flat index
    # is used. Introduces a compression rate of
    # embedding_dimension / n_subquantizers. 64 in silo.
    n_subquantizers: 256
    # Number of bits per subquantizer. Introduces a compression rate of
    # embedding_precision / n_bits.
    n_bits: 8
    overwrite: false


# Evaluation settings: retrieval search, evaluation data and LM options.
evaluation:
  domain: mmlu
  search:
    n_docs: 1200
    # Batch size for query encoding.
    per_gpu_batch_size: 32
    # Maximum number of tokens in a question.
    question_maxlength: 512
    lowercase: false
    normalize_text: false
    # Overwrite the search results if they exist.
    overwrite: true
    # merge_multi_index_results: true  # Merge the searched results by multiple shards (same source)
    # merge_multi_source_results: false  # Merge the searched results by multiple sources
    # Text file where each line is a file with searched results to merge, used
    # when merge_multi_source_results is set to True.
    # NOTE(review): `merge_multi_source_results` is commented out above --
    # confirm which switch (possibly `tasks.eval.merge_search`) enables the
    # merge now.
    paths_to_merge: null
    # Path to save the multi-source merged results.
    merged_path: null
    # Subsample from the top-k by coin flipping with probability p if p < 1.
    topk_subsample_p: 1
    subsample_seed: 1000
    # Rerank the results with the method specified here (supported: lexical).
    # Currently only supported in the multi-source situation.
    rerank_method: null
    # Path to load answers for lm-eval.
    answer_path: null
    # Number of documents used for reranking; set to null if not removing any
    # data.
    rerank_n_docs: null
    # Reuse the saved deduplicated data for efficient subsampling.
    use_saved_dedup_data: false
  # Evaluation data source and sampling options.
  data:
    eval_data: s3://ai2-llm/pretraining-data/sources/ds-olmo-data/oracle_retrieval/mmlu/out/cot_queries.jsonl
    max_eval_data_seq_length: 1024
    eval_stride: 512
    merge: true
    # Number of evaluation samples; pass null to evaluate on all samples.
    num_eval_samples: null
    # Random seed for subsampling.
    seed: 310
  concate_k: 3    # Number of retrieved passages for concatenation, 0 means LM-only
  max_retrieval_len: 1024
  calibration_out_dir: null
  # Resolves to the `retrieved_results/` directory under
  # `datastore.embedding.output_dir`.
  eval_output_dir: ${datastore.embedding.output_dir}/retrieved_results/
  # Relative path; what it is resolved against is not visible in this file.
  results_only_log_file: scaling_out/test_massiveds_ppl.log
  debug_mode: false
  # Decontamination is switched off here.
  # NOTE(review): `contamination_threshold` and `decontamination_method` are
  # presumably only consulted when `decontamination` is true -- confirm.
  decontamination: false
  contamination_threshold: 32
  decontamination_method: longest
  use_continuation: false
  use_both_doc_and_continuation: false


# Hydra job logging: records at INFO and above are sent to both handlers
# defined below.
hydra:
  job_logging:
    root:
      level: INFO
      handlers:
        - console
        - file
    handlers:
      # Writes to standard output.
      console:
        class: logging.StreamHandler
        stream: ext://sys.stdout
        formatter: simple
      # Appends (mode "a") to `run.log`. The path is relative; the directory
      # it lands in is not determined by this file.
      file:
        class: logging.FileHandler
        filename: run.log
        formatter: simple
        mode: a
    formatters:
      # Timestamp, logger name, level and message, separated by " - ".
      simple:
        format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
