import argparse

from loguru import logger

from vds_cfg import Configure
from vds_finetune import Finetune
from vds_pipeline import Pipeline
from vds_util import stabilize


def main(args=None):
    """Configure the experiment, then run either the PEFT path or the VAI pipeline.

    Args:
        args: Parsed command-line namespace (the result of the parser built in
            the ``__main__`` block). When omitted, falls back to the
            module-level ``args`` that the ``__main__`` block binds before
            calling ``main()``, so the existing script entry point keeps
            working unchanged.

    Raises:
        RuntimeError: If no namespace is passed and no module-level ``args``
            exists (e.g. ``main()`` called after importing this module).
    """
    if args is None:
        # The script entry point stores the parsed namespace as a module
        # global. Look it up explicitly instead of relying on an implicit
        # global read, which failed with NameError when the module was imported.
        args = globals().get('args')
        if args is None:
            raise RuntimeError(
                'main() needs an argparse.Namespace: pass one explicitly or '
                'run this file as a script'
            )

    logger.info(f'{args=}')
    cfg = Configure(args)

    # NOTE(review): presumably seeds RNGs / enables deterministic behaviour;
    # confirm in vds_util.stabilize. It runs after the config is built and
    # before either runner is constructed.
    stabilize()
    if args.peft:
        # --peft given: PEFT path via Finetune.optimize_model().
        finetune = Finetune(cfg)
        finetune.optimize_model()
    else:
        # Default (per the --peft help text, "else run vai"): VAI pipeline.
        pipeline = Pipeline(cfg)
        pipeline.optimize_reprs()


def build_parser():
    """Build the command-line parser for this experiment script.

    The bracketed value in each help string marks that option's default.
    Integer options use integer defaults. argparse would also convert string
    defaults through ``type=int``, but int literals state the intent directly.

    Returns:
        argparse.ArgumentParser: Parser whose ``parse_args()`` result is the
        ``args`` namespace consumed by ``main()``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', default='webss', type=str, help='[webss] trec ...')
    parser.add_argument('--model', default='pythia-xs', type=str, help='[pythia-xs] gpt2 ...')
    parser.add_argument('--abla', default='sc', type=str, help='[sc] mm de4sc de4mm ...')
    parser.add_argument('--exp', default='standard.approach', type=str, help='[standard.approach] ...')
    parser.add_argument('--epoch-num', default=100, type=int, help='10 20 50 [100] ...')
    parser.add_argument('--batch-size', default=256, type=int, help='[256] 512 1024 ...')
    parser.add_argument('--peft', action='store_true', help='run peft, else run vai ...')
    parser.add_argument('--peft-vai', action='store_true', help='run peft with vai logits ...')
    parser.add_argument('--peft-algo', default='lora', type=str, help='[lora] ia3 ...')
    parser.add_argument('--peft-epoch-num', default=1, type=int, help='[1] 5 10 ...')
    parser.add_argument('--peft-batch-size', default=1, type=int, help='[1] 2 4 8 16 32 64 ...')
    return parser


if __name__ == '__main__':
    # main() reads this namespace from module scope, so bind it at module
    # level before calling main().
    args = build_parser().parse_args()

    main()
