from tabulate import tabulate

# Training throughput for each model size, in samples per second (the time
# estimate divides sample counts by these values). Every figure was recorded
# on 8 H100 GPUs. The number pair beside each value is the model config as
# originally noted -- NOTE(review): presumably (width, depth); confirm.
FPS_200k = 99_000.0  # 64, 1
FPS_500k = 97_000.0  # 128, 1
FPS_1M = 97_000.0  # 192, 2
FPS_2M = 97_000.0  # 256, 2
FPS_5M = 96_000.0  # 320, 3
FPS_10M = 94_000.0  # 384, 4
FPS_20M = 90_000.0  # 512, 5
FPS_50M = 82_000.0  # 704, 7
FPS_100M = 74_000.0  # 896, 8
FPS_200M = 61_000.0  # 1088, 10 -- run with gradient accumulation of 2

# Model sizes (parameter counts) that form the table columns. They are kept
# as strings so they print verbatim as headers and serve directly as
# MODEL_SIZE_TO_FPS keys; arithmetic converts them with float().
MODEL_SIZES = [
    "200e3", "500e3",
    "1e6", "2e6", "5e6",
    "10e6", "20e6", "50e6",
    "100e6", "200e6",
]

# Total training compute budgets (FLOPs) that form the table rows, stored as
# strings so they print verbatim. The sweep follows a 1-2-5 pattern per
# decade, then goes denser (1e18..5e18 in unit steps) at the top end.
FLOPS = [
    "1e14", "2e14", "5e14",
    "1e15", "2e15", "5e15",
    "1e16", "2e16", "5e16",
    "1e17", "2e17", "5e17",
    "1e18", "2e18", "3e18", "4e18", "5e18",
]

# Throughput lookup keyed by the exact model-size strings used as table
# columns, so the time estimate can fetch each column's samples/second.
_FPS_BY_SIZE = (
    ("200e3", FPS_200k),
    ("500e3", FPS_500k),
    ("1e6", FPS_1M),
    ("2e6", FPS_2M),
    ("5e6", FPS_5M),
    ("10e6", FPS_10M),
    ("20e6", FPS_20M),
    ("50e6", FPS_50M),
    ("100e6", FPS_100M),
    ("200e6", FPS_200M),
)
MODEL_SIZE_TO_FPS = dict(_FPS_BY_SIZE)


# Unit conversion for the wall-clock estimate (seconds -> hours).
SECONDS_PER_HOUR = 3600

# Samples consumed per optimizer step, used to turn a sample count into a
# gradient-step count. NOTE(review): 8192 * 32 is carried over verbatim from
# the original estimate; presumably a per-step batch expressed as two factors
# -- confirm what each factor represents.
SAMPLES_PER_GRAD_STEP = 8192 * 32


def _samples_for(flops, model_size):
    """Return how many training samples a FLOP budget buys for a model size.

    Both arguments are the numeric strings used as table labels. Applies the
    ``compute ~= 6 * params * samples`` estimate solved for samples, i.e.
    ``float(flops) / (6 * float(model_size))``.
    """
    return float(flops) / (6 * float(model_size))


def _hours_cell(flops, model_size):
    """Format the estimated training time, in hours, for one table cell."""
    seconds = _samples_for(flops, model_size) / MODEL_SIZE_TO_FPS[model_size]
    return f"{seconds / SECONDS_PER_HOUR:.1f}"


def _samples_cell(flops, model_size):
    """Format the sample count for one table cell in scientific notation."""
    return f"{_samples_for(flops, model_size):.2e}"


def _grad_steps_cell(flops, model_size):
    """Format the number of optimizer steps for one table cell."""
    return f"{_samples_for(flops, model_size) / SAMPLES_PER_GRAD_STEP:.1f}"


def _print_table(title, format_cell):
    """Print ``title`` in an asterisk banner, then a FLOPs x model-size grid.

    Rows are the entries of FLOPS, columns the entries of MODEL_SIZES, and
    each cell is ``format_cell(flops, model_size)``.

    Returns:
        The table as a list of rows, header row first, so the raw strings
        remain available after printing.
    """
    # Banner width tracks the title so every table header lines up.
    banner = "*" * len(title)
    print(banner)
    print(title)
    print(banner)

    table = [[""] + MODEL_SIZES]
    for flops in FLOPS:
        cells = [format_cell(flops, size) for size in MODEL_SIZES]
        table.append([flops] + cells)

    print(tabulate(table, headers="firstrow", tablefmt="pretty"))
    return table


# Keep the original module-level names bound to each table's rows.
table = _print_table("TIME IN HOURS", _hours_cell)
print()
samples_table = _print_table("SAMPLES", _samples_cell)
print()
grad_table = _print_table("GRADIENT STEPS", _grad_steps_cell)
