"""Script for training an encoder (mapping music embeddings to brain activity).

Encoding with MuLan16d:
python -m fmri2music.scripts.train_encoder \
    --model_name "mulan16d" \
    --reg_type "mono" \
    --train_target_emb "window10s-stride1s-mv109" \
    --subject_name "Subject01" \
    --device "-1"

Encoding with MuLan128d:
python -m fmri2music.scripts.train_encoder \
    --model_name "mulan128d" \
    --reg_type "mono" \
    --train_target_emb "window10s-stride1s-mv101" \
    --subject_name "Subject01" \
    --device "1"

Encoding with w2vbert:
python -m fmri2music.scripts.train_encoder \
    --model_name "w2vbert" \
    --reg_type "mono" \
    --train_target_emb "window5s-stride1.5s-w2vbert-avg" \
    --subject_name "Subject01" \
    --device "1"

Encoding with hum2search-v11:
python -m fmri2music.scripts.train_encoder \
    --model_name "hum2search-v11" \
    --reg_type "mono" \
    --train_target_emb "window4s-stride1.5s-hum2search-v11" \
    --subject_name "Subject01" \
    --device "1"

Encoding with hum2search-v13:
python -m fmri2music.scripts.train_encoder \
    --model_name "hum2search-v13" \
    --reg_type "mono" \
    --train_target_emb "window4s-stride1.5s-hum2search-v13" \
    --subject_name "Subject01" \
    --device "1"

Encoding with SoundStream:
python -m fmri2music.scripts.train_encoder \
    --model_name "soundstream-1_5s" \
    --reg_type "mono" \
    --train_target_emb "window1_5s-stride1_5s-soundstream-avg" \
    --subject_name "Subject01" \
    --device "1"

Encoding with Mulan16d-text:
python -m fmri2music.scripts.train_encoder \
    --model_name "mulan16d-text" \
    --reg_type "mono" \
    --train_target_emb "window15s-stride15s-mv101txt-avg" \
    --subject_name "Subject01" \
    --device "1"
    

Encoding with multiple feature spaces:

python -m fmri2music.scripts.train_encoder \
    --model_name "mulan128d_mulan16d_w2vbert_soundstream" \
    --reg_type "multi" \
    --n_iter "200" \
    --train_target_emb "window15s-stride15s-mv101-avg" "window15s-stride15s-mv109-avg" "window15s-stride15s-w2vbert-avg" "window15s-stride15s-soundstream-avg" \
    --subject_name "Subject01" \
    --device "-1"

python -m fmri2music.scripts.train_encoder \
    --model_name "mulan128d_mulan16d" \
    --reg_type "multi" \
    --n_iter "200" \
    --train_target_emb "window15s-stride15s-mv101-avg" "window15s-stride15s-mv109-avg" \
    --subject_name "Subject01" \
    --device "1"

"""

import argparse

from dotenv import find_dotenv, load_dotenv

from fmri2music import training_encoder, data_const


def main(args):
    """Train a brain-activity encoder from the parsed command-line options.

    Forwards the CLI values positionally to
    ``training_encoder.train_predictor`` in this order: model name,
    regressor type, number of search iterations, target embedding name(s),
    subject name, device.

    Args:
        args: Namespace returned by this script's argument parser. It must
            expose ``model_name``, ``reg_type``, ``n_iter``,
            ``train_target_emb``, ``subject_name`` and ``device``.
    """
    # Positional order expected by train_predictor.
    ordered_fields = (
        "model_name",
        "reg_type",
        "n_iter",
        "train_target_emb",
        "subject_name",
        "device",
    )
    training_encoder.train_predictor(
        *(getattr(args, field) for field in ordered_fields)
    )


def _build_parser():
    """Build the command-line parser for encoder training.

    Returns:
        argparse.ArgumentParser: Parser for the model, regressor, search,
        subject, target-embedding and device options.
    """
    parser = argparse.ArgumentParser(
        description="Train a model to predict brain activity from music embeddings."
    )

    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Name of the trained model.",
    )

    parser.add_argument(
        "--reg_type",
        choices=["mono", "multi"],
        required=True,
        help="Type of the regressor.",
    )

    parser.add_argument(
        "--n_iter",
        type=int,
        default=100,
        help="Number of random search iterations.",
    )

    parser.add_argument(
        "--subject_name",
        choices=data_const.SUBJECTS,
        required=True,
        help="Name of the subject to train the model on.",
    )

    parser.add_argument(
        "--train_target_emb",
        # "+" rather than "*": `required=True` only checks that the flag is
        # present, so "*" would accept `--train_target_emb` with no values
        # and forward an empty list to training.
        nargs="+",
        type=str,
        required=True,
        help="Training target embedding .npz file name(s).",
    )

    parser.add_argument(
        "--device",
        type=int,
        required=True,
        help="Number of CUDA_VISIBLE_DEVICES. If you set -1, train regressor on CPU.",
    )

    return parser


if __name__ == "__main__":
    # Load variables from the nearest .env file before building the parser
    # and starting training.
    load_dotenv(find_dotenv())

    main(_build_parser().parse_args())
