import hues
from core.model.datasets import Dataset
from core.controller import Controller
import argparse
import pandas as pd


def _build_parse():
    parser = argparse.ArgumentParser(description='Arguments to run the algorithm')
    parser.add_argument('--score', type=str, help='Score type: acc or f1', default='f1')
    parser.add_argument('--dataset', type=str, help='Datasets: GSM8k or ArSarcasm')
    parser.add_argument('--num_iterations', type=int, help='num_iterations', default=20)
    parser.add_argument('--population_size', type=int, help='num_iterations', default=15)
    return parser.parse_args()


def _load_data(path_or_name: str):
    """Create a Dataset from a built-in dataset name or an explicit path.

    The names 'GSM8k' and 'ArSarcasm' are resolved to their CSV files under
    ./datasets/. Any other value is passed to Dataset unchanged.
    """
    # Short dataset names accepted on the command line -> bundled CSV files.
    builtin_paths = {
        'GSM8k': './datasets/GSM8k.csv',
        'ArSarcasm': './datasets/ArSarcasm.csv',
    }
    # Unknown names fall through as-is and are treated as a path.
    resolved = builtin_paths.get(path_or_name, path_or_name)
    return Dataset(resolved)


def main():
    """Entry point: parse CLI args, load the dataset and run the Controller."""
    # Parse command-line arguments.
    arg = _build_parse()
    hues.info(f"Program running with args：{arg}")

    # Load the dataset (built-in name or explicit path).
    ds = _load_data(arg.dataset)

    # Initial prompt and problem description handed to the Controller.
    # NOTE(review): both are hard-coded for sarcasm detection and are used
    # unchanged whatever --dataset selects (including GSM8k) — confirm this
    # is intended.
    initial_prompt = '''
    ## Task
    Is this tweet sarcastic?
    ## Output format
    Answer Yes or No as labels.
    ## Prediction
    Text: {input}                          
    Label:
    '''

    problem_description = "Determine whether a given statement is sarcastic."

    # Build the controller first, then run it. Previously the return value of
    # run() was bound to a variable misleadingly named `controller`.
    controller = Controller(
        initial_prompt=initial_prompt,
        problem_description=problem_description,
        train_df=ds.train_set_df,
        test_df=ds.test_set_df,
        num_iterations=arg.num_iterations,
        population_size=arg.population_size,
        # Any --score value other than 'f1' (e.g. 'acc') yields is_f1=False.
        is_f1=arg.score == 'f1',
    )
    controller.run()


if __name__ == '__main__':
    main()
