from classifiers.core import AttributionClassifier


# Directory settings shared by every run in main(); each value is passed
# unchanged to AttributionClassifier.run_classification as the keyword
# argument of the same name.
base_path, output_path, log_dir = "datasets", "results", "logs"
def _as_config_list(requested, default):
    """Return *requested* as a list of config keys, or a copy of *default* if None.

    A bare string is treated as a single key rather than being split into
    characters by ``list()``.
    """
    if requested is None:
        return list(default)
    if isinstance(requested, str):
        return [requested]
    return list(requested)


def main(*, models=None, domains=None, shots=None):
    """Run the attribution-classification sweep over all configurations.

    Calls ``AttributionClassifier.run_classification`` once per combination
    of model, attribution domain and shot setting (model outermost, shot
    innermost). Every call uses items 1..10, 3 runs, and the module-level
    ``base_path`` / ``output_path`` / ``log_dir`` settings.

    The keyword-only filters let you re-run a subset without editing this
    function, e.g.::

        main(models=["llama3.2_3B"], domains=["SAS_UAS_NEU"], shots=["fewshot"])
        main(models=["llama3.3_70B"], domains=["GAS_SPAS_NEU"], shots=["zeroshot"])

    Args:
        models: Model config keys to run; ``None`` runs every default model.
        domains: Attribution-domain keys to run; ``None`` runs every default
            domain.
        shots: Shot settings to run; ``None`` runs the default list
            (currently only ``"fewshot"``).
    """
    import itertools

    default_models = [
        "gemma3_4B",
        "gemma3_1B",
        "gemma3_12B",
        "gemma3_27B",
        "llama3.2_1B",
        "llama3.2_3B",
        "llama3.3_70B",
    ]
    default_domains = [
        "IAS_EAS_NEU",
        "GAS_SPAS_NEU",
        "SAS_UAS_NEU",
    ]
    # Every element ends with a comma: without it, uncommenting "zeroshot"
    # would silently join it to "fewshot" via implicit string-literal
    # concatenation, yielding the single bogus key "fewshotzeroshot".
    default_shots = [
        "fewshot",
        # "zeroshot",
    ]

    model_configs = _as_config_list(models, default_models)
    domain_configs = _as_config_list(domains, default_domains)
    shot_configs = _as_config_list(shots, default_shots)

    # product() iterates in the same order as nested loops
    # model -> domain -> shot.
    for model_config, domain_config, shot_config in itertools.product(
        model_configs, domain_configs, shot_configs
    ):
        print(f"Running {model_config} {domain_config} {shot_config}")
        AttributionClassifier.run_classification(
            model_name=model_config,
            attribution_domain=domain_config,
            shot=shot_config,
            start_idx=1,
            end_idx=10,
            runs=3,
            base_path=base_path,
            output_path=output_path,
            log_dir=log_dir,
        )
        print(f"{model_config} {domain_config} {shot_config} done")


# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


# MODEL_CONFIGS = {
#     "llama3.2_1B": {
#         "name": "llama3.2:1b",
#         "suffix": ""
#     },
#     "llama3.2_3B": {
#         "name": "llama3.2",
#         "suffix": "_3b"
#     },
#     "llama3.3_70B": {
#         "name": "llama3.3",
#         "suffix": "_70b"
#     },
#     "gemma3_1B": {
#         "name": "gemma3:1b",
#         "suffix": "_1b"
#     },
#     "gemma3_4B": {
#         "name": "gemma3:4b",
#         "suffix": "_4b"
#     },
#     "gemma3_12B": {
#         "name": "gemma3:12b",
#         "suffix": "_12b"
#     },
#     "gemma3_27B": {
#         "name": "gemma3:27b",
#         "suffix": "_27b"
#     }
# }

