from needle_haystack_similarity import ExperimentConfig


"""
Workflow:
1. Get the embeddings and ablation similarities here by running ``python queue_file.py`` with the proper changes.
    - If it's a new model, add it into ``constants.py``
2. Run ``python needle_haystack_similarity_plotting.py`` with the proper changes. 
     - If you have a huggingface model e.g. `"junnyu/roformer_chinese_base", switch out the `/`
"""


# Default experiment grid used by ``main()``. Edit these (or pass overrides to
# ``main``) to choose which experiments get queued.
DEFAULT_MODES = (
    "insert",
    "remove",
)
DEFAULT_MODELS = (
    "mosaicml/mosaic-bert-base-seqlen-1024",
    "intfloat/e5-large-v2",
    "BAAI/bge-m3",
    # Other models previously run through this queue:
    # "nomic-ai/nomic-embed-text-v1.5",
    # "jinaai/jina-embeddings-v2-base-en",
    # "/teamspace/studios/this_studio/embedding-text-structure/project/_models/baai-tuned",
    # "_models/baai-tuned",
    # "BAAI/bge-small-en-v1.5",
    # "reaganjlee/baai-truncate-finetune",  # uses BAAI/bge-small-en-v1.5
    # "junnyu/roformer_chinese_base",
    # "google-bert/bert-base-uncased",
    # "dwzhu/e5rope-base",
    # "intfloat/e5-mistral-7b-instruct",
)
DEFAULT_DATASETS = (
    "scientific_papers",
    "paul_graham",
    "amazon_polarity",
    "arguana",
    "reddit",
)


def _needle_settings(mode):
    """Return fresh ``(needle_sizes, needle_keywords)`` lists for *mode*.

    ``"insert"`` uses five needle sizes and the ``"lorem"`` keyword; every
    other mode uses the first four sizes and a single ``None`` keyword.
    New lists are built on each call so that no two queued configs share a
    mutable list object.
    """
    # For a quick run, narrow the sweep, e.g. needle_sizes = [0.2].
    if mode == "insert":
        return [0.05, 0.1, 0.2, 0.5, 1], ["lorem"]
    return [0.05, 0.1, 0.2, 0.5], [None]


def build_queue(modes, models, datasets):
    """Return one ``ExperimentConfig`` keyword-argument dict per grid point.

    Args:
        modes: Sequence of experiment modes (e.g. ``"insert"``, ``"remove"``).
        models: Sequence of model names or paths, passed as ``model_name``.
            Iterated once per mode, so it must be re-iterable.
        datasets: Sequence of dataset names, passed as ``dataset_name``.
            Iterated once per (mode, model), so it must be re-iterable.

    Returns:
        A list of dicts ordered with modes outermost, then models, then
        datasets. The list is empty if any argument is empty.
    """
    queue = []
    for mode in modes:
        for model in models:
            for dataset in datasets:
                needle_sizes, needle_keywords = _needle_settings(mode)
                queue.append(
                    {
                        "mode": mode,
                        "dataset_name": dataset,
                        # Use 50 when fine-tuning to limit the number of examples.
                        "num_examples": 200,
                        "needle_keywords": needle_keywords,
                        "needle_sizes": needle_sizes,
                        # For a quick run, e.g. [0, 0.5, 1].
                        "needle_posns": [0, 0.25, 0.5, 0.75, 1],
                        "model_name": model,
                        "max_length": 2048,
                    }
                )
    return queue


def main(modes=None, models=None, datasets=None):
    """Build the experiment queue and run every config in order.

    Args:
        modes: Optional override for ``DEFAULT_MODES``; ``None`` uses the default.
        models: Optional override for ``DEFAULT_MODELS``; ``None`` uses the default.
        datasets: Optional override for ``DEFAULT_DATASETS``; ``None`` uses
            the default.

    Each queued dict is passed as keyword arguments to ``ExperimentConfig``,
    and ``run()`` is called on the result. Runs are sequential; an exception
    from any run propagates, and the remaining configs are not run.
    """
    queue = build_queue(
        DEFAULT_MODES if modes is None else modes,
        DEFAULT_MODELS if models is None else models,
        DEFAULT_DATASETS if datasets is None else datasets,
    )
    for item in queue:
        ExperimentConfig(**item).run()


if __name__ == "__main__":
    main()
