import argparse
from distutils.util import strtobool
import os


def config():
    parser = argparse.ArgumentParser(
        prog='llmtelora',
        description='...',
        epilog = 'Authors ©')
    
    parser.add_argument('--name',
        type=str,
        help='Name of the computation (it relates to the subfolder name in the "result" folder)',
        default='base')
    parser.add_argument('--mode',
        type=str,
        help='Kind of the model: "bs" (original model), "cp" (multi-head model with CP-based heads), "lr2" (low-rank based model with two heads), "mt" (multi-token) or "tt" (tensor-train-based optimization)',
        choices=['bs', 'cp', 'lr2', 'mt', 'tt', 'cp_sparse'],
        default='bs')
    parser.add_argument('--seed',
        type=int,
        help='Initial (global) random seed value',
        default=42)
    parser.add_argument('--save_model',
        type=lambda x: bool(strtobool(x)),
        help='Do we auto-save model after train (it make sense only if "optinet.py" script is called)',
        nargs="?",
        const=True,
        default=False)
    parser.add_argument('--lr',
        type=float,
        help='Learning rate for the model training',
        default=1.E-3)
    parser.add_argument('--lr_min',
        type=float,
        help='Learning rate minimum value for the model training',
        default=5.E-5)
    parser.add_argument('--batch_trn',
        type=int,
        help='Batch size while model training',
        default=6)
    parser.add_argument('--batch_tst',
        type=int,
        help='Batch size while model testing',
        default=6)
    parser.add_argument('--n_layer',
        type=int,
        help='Number of model layers',
        default=2)
    parser.add_argument('--n_head',
        type=int,
        help='Number of model heads',
        default=2)
    parser.add_argument('--n_embd',
        type=int,
        help='Number of model embeddings',
        default=500)
    parser.add_argument('--block_size',
        type=int,
        help='The block size for the model (this parameter should not be changed)',
        default=1024)
    parser.add_argument('--vocab_size',
        type=int,
        help='The vocabulary size for the model (this parameter should not be changed)',
        default=50304)
    parser.add_argument('--dropout',
        type=float,
        help='The value of dropout for the model',
        default=0.1)
    parser.add_argument('--warmup_steps',
        type=int,
        help='Number of warmup steps while training',
        default=10)
    parser.add_argument('--epochs',
        type=int,
        help='Number of training iterations (epochs)',
        default=50)
    parser.add_argument('--grad_acc_steps',
        type=int,
        help='Number of gradient accumulation steps while training',
        default=1)
    parser.add_argument('--prompt_demo',
        type=str,
        help='Simple prompt for model demo after training',
        default='There was a')
    parser.add_argument('--d',
        type=int,
        help='Dimension (output sequence length) for the low-rank / TT-based model or number of heads for multi-head model',
        default=4)
    parser.add_argument('--r',
        type=int,
        help='TT-rank for the TT-based model or rank for low-rank model',
        default=3)
    parser.add_argument('--gpu',
        type=int,
        help='Optional number of the GPU for computation',
        default=None)
    parser.add_argument('--rewrite',
        type=lambda x: bool(strtobool(x)),
        help='Do we auto delete the computation with the same name',
        nargs="?",
        const=True,
        default=False)
    parser.add_argument('--from-pretrained', 
        type=str, 
        default=None, 
        help='Path to pretrained model weights')
    parser.add_argument('--config-override',
        metavar="KEY=VALUE",
        nargs='+',
        help='Override the config parameters'),
    parser.add_argument('--temp', 
        type=float, 
        default=1.0, 
        help='Temperature for text generation')

    if 'JPY_PARENT_PID' in os.environ:
        # Jupyter can not use the console arguments:
        return parser.parse_args([])
    else:
        return parser.parse_args()