/** Seed values; `runAll` schedules one training run per seed for every `numCenters` setting. */
const SEED: number[] = [1111, 2222, 3333];
import pAll from "p-all";
import runSingleTrain, { TrainArgs } from "./train.mjs";
import _ from "lodash";

/**
 * Parses the GPU-count CLI argument as a base-10 integer.
 *
 * The radix is passed explicitly so that a value such as "0x10" is not read
 * as hexadecimal. No validation happens here.
 *
 * @param raw - Raw command-line value, or `undefined` when it was not supplied.
 * @returns The parsed integer, or `NaN` when `raw` is missing or has no
 *   leading decimal integer.
 */
function parseGpuCount(raw: string | undefined): number {
  return Number.parseInt(raw ?? "", 10);
}

/** First CLI argument; used as the `concurrency` limit handed to `pAll`. */
const GPUS = parseGpuCount(process.argv[2]);
/** Second CLI argument; in the visible code it is only interpolated into the startup log line. */
const ENV = process.argv[3];

/**
 * Sweep configurations, one object per `numCenters` value to try.
 * `runAll` expands each entry into one command per seed.
 */
const args = [1, 2, 5, 10, 32, 48, 64, 96].map((numCenters) => ({
  numCenters,
}));

/**
 * Runs the whole sweep: builds one command for every (`numCenters`, seed)
 * pair from `args` × `SEED`, shuffles the list, logs it, and executes
 * `runSingleTrain("modality", cmd)` for each command with at most `GPUS`
 * runs in flight at a time.
 *
 * @throws {TypeError} When `GPUS` is not a positive integer (for example,
 *   when the GPU-count command-line argument is missing).
 */
async function runAll(): Promise<void> {
  // Fail fast with a usage-oriented message; otherwise the bad value only
  // surfaces from the concurrency limiter after the command list is built
  // and logged.
  if (!Number.isInteger(GPUS) || GPUS < 1) {
    throw new TypeError(
      `GPU count (first CLI argument) must be a positive integer, got ${GPUS}`
    );
  }

  // Iterate over SEED itself rather than a hard-coded count, so editing the
  // seed list can never yield missing runs or `seed: undefined`.
  const sweep = args.flatMap(({ numCenters }) =>
    SEED.map(
      (seed): Partial<TrainArgs> => ({
        numCenters,
        seed,
        skip: false,
        evalRatio: 4,
      })
    )
  );

  // Randomize the order in which the commands are started.
  const allCmds = _.shuffle(sweep);

  console.log(
    `${ENV} in total, ${allCmds.length} commands are: ${JSON.stringify(
      allCmds
    )}`
  );

  // Each entry is a thunk, so a run only starts once pAll picks it up.
  await pAll(
    allCmds.map((arg) => () => runSingleTrain("modality", arg)),
    { concurrency: GPUS }
  );
}

// Kick off the sweep at module load (top-level await).
await runAll();
