# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os
import subprocess
import threading
from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence

# Configure the root logger at import time with INFO as the threshold, then
# create this module's own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# tomllib is in the standard library from Python 3.11 onwards; on older
# interpreters fall back to the third-party tomli package under the same name.
try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib


@dataclass
class OverrideDefinitions:
    """
    Describe one integration-test variation layered on top of a root config file.

    Attributes:
        override_args: Groups of command-line override strings (for example
            ``"--training.batch_size 128"``); each inner sequence is one group.
            NOTE(review): how a group is turned into a command is decided by
            the runner code, which is outside this class.
        test_descr: Human-readable description; also returned by ``repr()``.
        test_name: Name identifying the variation.
        requires_seed_checkpoint: Off by default. NOTE(review): presumably asks
            the runner to create a seed checkpoint first; confirm in the runner.
        ngpu: Defaults to 4. NOTE(review): presumably the number of GPUs to
            launch with; confirm in the runner.
        model_flavor: Model flavor name, ``"debugmodel"`` by default.
    """

    # One group holding a single blank string. The previous default,
    # ``tuple(tuple(" "))``, collapsed to the flat ``(" ",)`` and so did not
    # match the nested annotation. Iterating or joining the default group
    # still produces the same single blank, so consumers see no difference.
    override_args: Sequence[Sequence[str]] = ((" ",),)
    test_descr: str = "default"
    test_name: str = "default"
    requires_seed_checkpoint: bool = False
    ngpu: int = 4
    model_flavor: str = "debugmodel"

    def __repr__(self) -> str:
        # Replace the generated field-by-field repr with just the description.
        return self.test_descr


def build_test_list(args):
    """
    key is the config file name and value is a list of OverrideDefinitions
    that is used to generate variations of integration tests based on the
    same root config file.
    """
    conf = args.test
    integration_tests_flavors = defaultdict(list)

    if conf == "debug_model":
        integration_tests_flavors["debug_model.toml"] = [
            OverrideDefinitions(
                [
                    [
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:287,0:288,0:546,0:1781,0:1782,0:2039}'"
                    ],
                ],
                "debug_model",
                "debug_model",
                ngpu = 1,
                model_flavor = "debugmodel",
            ),
        ]
    elif conf == "8b_128_2k_4ubs":
        integration_tests_flavors["llama3_8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.16",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 32",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 128",
                        "--training.seq_len 2048",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:727,0:48464,1:700,1:1237,1:49377,1:49913}'"
                    ],
                ],
                "8b_128_2k_4ubs",
                "8b_128_2k_4ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "8B",
            ),
        ]
    elif conf == "8b_64_1k_8ubs":
        integration_tests_flavors["llama3_8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.16",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 8",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 64",
                        "--training.seq_len 1024",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:615,0:13680,1:612,1:6149,1:13921,1:14457}'"
                    ],
                ],
                "8b_64_1k_8ubs",
                "8b_64_1k_8ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "8B",
            ),
        ]
    elif conf == "8b_64_2k_4ubs":
        integration_tests_flavors["llama3_8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.16",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 16",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 64",
                        "--training.seq_len 2048",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:631,0:25232,1:620,1:1157,1:25697,1:26233}'"
                    ],
                ],
                "8b_64_2k_4ubs",
                "8b_64_2k_4ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "8B",
            ),
        ]
    elif conf == "8b_64_3k_4ubs":
        integration_tests_flavors["llama3_8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.16",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 16",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 64",
                        "--training.seq_len 3072",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{}'"
                    ],
                ],
                "8b_64_3k_4ubs",
                "8b_64_3k_4ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "8B",
            ),
        ]
    elif conf == "8b_64_4k_2ubs":
        integration_tests_flavors["llama3_8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.16",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 32",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 64",
                        "--training.seq_len 4096",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:663,1:636,1:1173,0:48336,1:49249,1:49685}'"
                    ],
                ],
                "8b_64_4k_2ubs",
                "8b_64_4k_2ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "8B",
            ),
        ]
    elif conf == "8b_32_2k_4ubs":
        integration_tests_flavors["llama3_8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.16",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 8",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 32",
                        "--training.seq_len 2048",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:583,1:580,1:1117,0:13616,1:13857,1:14393}'"
                    ],
                ],
                "8b_32_2k_4ubs",
                "8b_32_2k_4ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "8B",
            ),
        ]
    elif conf == "8b_16_2k_4ubs":
        integration_tests_flavors["llama3_8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.16",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 4",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 16",
                        "--training.seq_len 2048",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:559,1:560,1:1097,0:7808,1:7937,1:8473}'"
                    ],
                ],
                "8b_16_2k_4ubs",
                "8b_16_2k_4ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "8B",
            ),
        ]
    elif conf == "70b_8_2k_2ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 4",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 8",
                        "--training.seq_len 2048",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:1319,0:19360,1:2649,1:1320,1:20817,1:19489}'"
                    ],
                ],
                "70b_8_2k_2ubs",
                "70b_8_2k_2ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_8_2k_1ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 8",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 8",
                        "--training.seq_len 2048",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:1327,1:1324,1:2653,0:33776,1:35345,1:34017}'"
                    ],
                ],
                "70b_8_2k_1ubs",
                "70b_8_2k_1ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_16_2k_1ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 16",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 16",
                        "--training.seq_len 2048",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:1351,1:1340,1:2669,0:62624,1:63089,1:64417}'"
                    ],
                ],
                "70b_16_2k_1ubs",
                "70b_16_2k_1ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_16_2k_2ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 8",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 16",
                        "--training.seq_len 2048",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:1335,0:33792,1:2661,1:1332,1:35361,1:34033}'"
                    ],
                ],
                "70b_16_2k_2ubs",
                "70b_16_2k_2ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_16_1k_1ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 16",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 16",
                        "--training.seq_len 1024",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:1351,1:1340,1:2669,0:62624,1:63089,1:64417}'"
                    ],
                ],
                "70b_16_1k_1ubs",
                "70b_16_1k_1ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_16_1k_2ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 8",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 16",
                        "--training.seq_len 1024",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:1335,0:33792,1:2661,1:1332,1:35361,1:34033}'"
                    ],
                ],
                "70b_16_1k_2ubs",
                "70b_16_1k_2ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_16_1k_4ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 4",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 16",
                        "--training.seq_len 1024",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:1327,0:19376,1:2657,1:1328,1:20833,1:19505}'"
                    ],
                ],
                "70b_16_1k_4ubs",
                "70b_16_1k_4ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_16_3k_1ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 16",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 16",
                        "--training.seq_len 3072",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:1351,1:1340,1:2669,0:62624,1:63089,1:64417}'"
                    ],
                ],
                "70b_16_3k_1ubs",
                "70b_16_3k_1ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_16_3k_2ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 8",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 16",
                        "--training.seq_len 3072",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{1:6292,1:1332,1:35361,1:34033,0:6250,0:1335,0:38706,0:33792}'"
                    ],
                ],
                "70b_16_3k_2ubs",
                "70b_16_3k_2ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_16_4k_1ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 16",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 16",
                        "--training.seq_len 4096",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:1351,1:1340,1:2669,0:62624,1:63089,1:64417}'"
                    ],
                ],
                "70b_16_4k_1ubs",
                "70b_16_4k_1ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_16_4k_2ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 8",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 16",
                        "--training.seq_len 4096",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{1:6292,1:1332,1:35361,1:34033,0:6250,0:1335,0:38706,0:33792}'"
                    ],
                ],
                "70b_16_4k_2ubs",
                "70b_16_4k_2ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_32_2k_1ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 32",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 32",
                        "--training.seq_len 2048",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:1399,1:2701,1:1372,0:120320,1:122561,1:121233}'"
                    ],
                ],
                "70b_32_2k_1ubs",
                "70b_32_2k_1ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_32_2k_2ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 16",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 32",
                        "--training.seq_len 2048",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:1367,0:62656,1:2685,1:1356,1:64449,1:63121}'"
                    ],
                ],
                "70b_32_2k_2ubs",
                "70b_32_2k_2ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "70b_64_2k_1ubs":
        integration_tests_flavors["llama3_70b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.40",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 64",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 64",
                        "--training.seq_len 2048",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:1495,1:2765,1:1436,0:235712,1:238849,1:237521}'"
                    ],
                ],
                "70b_64_2k_1ubs",
                "70b_64_2k_1ubs",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "70B",
            ),
        ]
    elif conf == "granite_8b_bs16_seq4k_ubs4":
        integration_tests_flavors["granite-code-8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.18",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 4",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 16",
                        "--training.seq_len 4096",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:623,1:2878,1:624,0:8772,1:8901,1:9503}'"
                    ],
                ],
                "granite_8b_bs16_seq4k_ubs4",
                "granite_8b_bs16_seq4k_ubs4",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "Granite-8B",
            ),   
        ]
    elif conf == "granite_8b_bs32_seq4k_ubs4":
        integration_tests_flavors["granite-code-8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.18",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 8",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 32",
                        "--training.seq_len 4096",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:647,1:1247,1:644,0:15300,1:16143,1:15541}'"
                    ],
                ],
                "granite_8b_bs32_seq4k_ubs4",
                "granite_8b_bs32_seq4k_ubs4",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "Granite-8B",
            ),   
        ]
    elif conf == "granite_8b_bs64_seq4k_ubs4":
        integration_tests_flavors["granite-code-8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.18",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 16",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 64",
                        "--training.seq_len 4096",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:695,1:1287,1:684,0:28356,1:29423,1:28821}'"
                    ],
                ],
                "granite_8b_bs64_seq4k_ubs4",
                "granite_8b_bs64_seq4k_ubs4",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "Granite-8B",
            ),   
        ]
    elif conf == "granite_8b_bs64_seq2k_ubs8":
        print("!")
        integration_tests_flavors["granite-code-8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.18",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 8",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 64",
                        "--training.seq_len 2048",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:679,1:1279,1:676,0:15364,1:16207,1:15605}'"
                    ],
                ],
                "granite_8b_bs64_seq2k_ubs8",
                "granite_8b_bs64_seq2k_ubs8",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "Granite-8B",
            ),   
        ]
    elif conf == "granite_8b_bs64_seq8k_ubs2":
        integration_tests_flavors["granite-code-8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.18",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 32",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 64",
                        "--training.seq_len 8192",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:727,1:1303,1:700,0:54340,1:55855,1:55253}'"
                    ],
                ],
                "granite_8b_bs64_seq8k_ubs2",
                "granite_8b_bs64_seq8k_ubs2",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "Granite-8B",
            ),   
        ]
    elif conf == "granite_8b_bs128_seq4k_ubs4":
        integration_tests_flavors["granite-code-8b.toml"] = [
            OverrideDefinitions(
                [
                    [
                        "--experimental.pipeline_parallel_degree 2",
                        "--experimental.pipeline_parallel_split_points layers.18",
                        "--experimental.pipeline_parallel_schedule 1f1b",
                        "--experimental.pipeline_parallel_microbatches 32",
                        "--training.tensor_parallel_degree 1",
                        "--training.batch_size 128",
                        "--training.seq_len 4096",
                        "--training.dataset \"c4_test\"",
                        "--model.tokenizer_path \"./test/assets/test_tiktoken.model\"",
                        f"--offloading.print_info {args.print_info}",
                        f"--offloading.liveness_path {args.liveness_path}",
                        f"--offloading.plan_path {args.plan_path}",
                        f"--offloading.emu_bandwidth {args.offloadingbw}",
                        "--offloading.skip_list '{0:791,1:1367,1:764,0:54468,1:55983,1:55381}'"
                    ],
                ],
                "granite_8b_bs128_seq4k_ubs4",
                "granite_8b_bs128_seq4k_ubs4",
                requires_seed_checkpoint=False,
                ngpu = 2,
                model_flavor = "Granite-8B",
            ),   
        ]
    
    
    return integration_tests_flavors

def read_output(process, output_type):
    """
    Echo a stream to standard output, one tagged line at a time.

    ``process`` is the readable stream itself (anything exposing
    ``readline()``); ``_run_cmd`` passes a child's stdout or stderr pipe here.
    Every line is stripped of surrounding whitespace and printed as
    ``[<output_type>] <line>``. Reading stops at the first falsy value
    returned by ``readline()``, i.e. at end of stream.
    """
    current = process.readline()
    while current:
        print(f"[{output_type}] {current.strip()}")
        current = process.readline()

def _run_cmd(cmd):
    """
    Run ``cmd`` through the shell, echo its output live and wait for it to exit.

    The child's stdout and stderr are piped in text mode and each is drained
    on its own thread by ``read_output``, which prints the lines prefixed with
    ``[STDOUT]`` / ``[STDERR]``.

    Returns the finished ``subprocess.Popen`` object, so ``returncode`` is
    set. Both pipes have already been read to end-of-stream by then, so
    ``.stdout`` / ``.stderr`` on the returned object hold no further output.
    """
    child = subprocess.Popen(
        [cmd],
        shell=True,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    # One reader per pipe, running concurrently, so that neither pipe can
    # fill up and stall the child while the other is being read.
    readers = [
        threading.Thread(target=read_output, args=(stream, label))
        for stream, label in ((child.stdout, "STDOUT"), (child.stderr, "STDERR"))
    ]
    for reader in readers:
        reader.start()
    for reader in readers:
        reader.join()
    child.wait()
    return child

def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
    """
    Run every override variation of one integration-test flavor.

    Args:
        test_flavor: definition of the test; supplies the name, description,
            model flavor, GPU count and the list of override-argument groups.
        full_path: path of the config file, exported to the launched script
            as ``CONFIG_FILE``.
        output_dir: root output directory; the job dump folder is
            ``<output_dir>/<test_name>``.

    Raises:
        Exception: if the seed-checkpoint command or any training command
            exits with a non-zero return code.
    """
    # run_test supports sequence of tests.
    test_name = test_flavor.test_name
    dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}"
    model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}"
    all_ranks = ",".join(map(str, range(test_flavor.ngpu)))

    if test_flavor.requires_seed_checkpoint:
        cmd = f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg} {model_flavor_arg}"
        logger.info(
            f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
        )
        result = _run_cmd(cmd)
        # _run_cmd has already echoed the command's output; result.stdout is
        # only the drained pipe object, so report the exit status instead.
        logger.info(f"Seed checkpoint command exited with code {result.returncode}")
        if result.returncode != 0:
            # Do not start training on top of a seed checkpoint that was
            # never created.
            raise Exception(
                f"Seed checkpoint creation failed, flavor : {test_flavor.test_descr}, command : {cmd}"
            )

    # The memory-tracker test uses its own launcher; every other test uses
    # the regular training script.
    if test_name == "fsdp2_mem_tracker":
        launcher = "./run_memory_estimation.sh"
    else:
        launcher = "./run_llama_train.sh"
    base_cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} {launcher}"

    for override_arg in test_flavor.override_args:
        cmd = f"{base_cmd} {dump_folder_arg} {model_flavor_arg}"
        if override_arg:
            cmd += " " + " ".join(override_arg)
        logger.info(
            f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
        )
        result = _run_cmd(cmd)
        if result.returncode != 0:
            raise Exception(
                f"Integration test failed, flavor : {test_flavor.test_descr}, command : {cmd}"
            )


def run_tests(args):
    """
    Run the selected integration tests for every eligible config file.

    Each ``*.toml`` file in ``args.config_dir`` is parsed; files whose
    ``[job]`` table does not set ``use_for_integration_test`` to a truthy
    value are skipped. For the remaining files, the flavors that
    ``build_test_list`` registered under that file name are run when they
    match ``args.test`` (or ``args.test`` is ``"all"``) and the GPU
    constraints below are met. After each test, ``*.liveness`` files in the
    current working directory are moved into the test's output folder.

    Args:
        args: parsed command-line namespace; reads ``config_dir``, ``test``,
            ``ngpu`` and ``output_dir`` (plus whatever ``build_test_list``
            reads).
    """
    integration_tests_flavors = build_test_list(args)
    for config_file in os.listdir(args.config_dir):
        if not config_file.endswith(".toml"):
            continue

        full_path = os.path.join(args.config_dir, config_file)
        # Only parsing needs the file; close it before launching the
        # long-running test subprocesses.
        with open(full_path, "rb") as f:
            config = tomllib.load(f)
        if not config["job"].get("use_for_integration_test", False):
            continue

        # .get() avoids inserting an empty entry for every config file that
        # has no registered flavors.
        for test_flavor in integration_tests_flavors.get(config_file, ()):
            if args.test != "all" and test_flavor.test_name != args.test:
                continue
            if args.ngpu < test_flavor.ngpu:
                logger.info(
                    f"Skipping test {test_flavor.test_name} that requires {test_flavor.ngpu} gpus,"
                    f" because --ngpu arg is {args.ngpu}"
                )
                continue
            if args.ngpu == 8 and test_flavor.ngpu != 8:
                logger.info(
                    f"Skipping non-8gpu test {test_flavor.test_name} on 8-gpu runner"
                )
                continue

            run_test(test_flavor, full_path, args.output_dir)

            # Collect the liveness dumps left in the working directory.
            folder = f"{args.output_dir}/{test_flavor.test_name}"
            cmd = f"mv *.liveness {folder}"
            logger.info(cmd)
            result = _run_cmd(cmd)
            if result.returncode != 0:
                # Best effort: the move fails e.g. when the run left no
                # *.liveness files behind, which must not fail the test run.
                logger.warning(
                    f"Command `{cmd}` exited with code {result.returncode}"
                )
                                


def main():
    """
    Command-line entry point: parse the arguments, prepare the output
    directory and run the integration tests.

    Raises:
        RuntimeError: if the output directory already contains entries.
    """

    def _parse_bool(value):
        # Anything other than these spellings (case-insensitive) is False.
        return value.lower() in ("true", "1", "yes", "on")

    parser = argparse.ArgumentParser()
    parser.add_argument("output_dir")
    parser.add_argument("--config_dir", default="./train_configs")
    parser.add_argument(
        "--test",
        default="all",
        help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)",
    )
    parser.add_argument("--ngpu", default=4, type=int)
    parser.add_argument(
        "--print_info",
        default=False,
        type=_parse_bool,
        help="forwarded as --offloading.print_info; true/1/yes/on enable it",
    )
    parser.add_argument(
        "--offloadingbw",
        default=-1.0,
        type=float,
        help="forwarded as --offloading.emu_bandwidth",
    )
    parser.add_argument(
        "--liveness_path",
        default="unknown",
        type=str,
        help="forwarded as --offloading.liveness_path",
    )
    parser.add_argument(
        "--plan_path",
        default="unknown",
        type=str,
        help="forwarded as --offloading.plan_path",
    )
    args = parser.parse_args()

    # Create-if-missing in a single call instead of check-then-create.
    os.makedirs(args.output_dir, exist_ok=True)
    if os.listdir(args.output_dir):
        raise RuntimeError("Please provide an empty output directory.")
    run_tests(args)


# Run the integration tests only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
