/** Random seeds; each (config, net) combination in the sweep is launched once per seed. */
const SEED: number[] = [1111, 2222, 3333];
import pAll from "p-all";
import runSingleTrain, { TrainArgs } from "./train.mjs";
import _ from "lodash";
/**
 * Model shapes swept by this script, written as `[n, hiddenSize]` pairs:
 * element 0 is passed to training as `n`, element 1 as `hiddenSize`.
 * Trailing comments are the author's size annotations (presumably parameter
 * counts in millions — confirm against train.mjs).
 *
 * Typed as 2-tuples so `[n, hiddenSize]` destructuring is checked by the
 * compiler. This source list is left unmodified; `NETS` below is a reversed copy.
 */
const NET_SHAPES: Array<[number, number]> = [
  [2, 64], // 0.05 M
  [2, 128], // 0.2 M
  [4, 64], // 0.1 M
  [4, 128], // 0.5 M
  [4, 256], // 2.11 M
  [4, 512], // 8.45 M
  [4, 1024], // 34 M
  [8, 256], // 4.22 M
  [8, 512], // 16.88 M
];

/** `NET_SHAPES` in reverse order, so `NETS[0]` is `[8, 512]` and the last entry is `[2, 64]`. */
const NETS = [...NET_SHAPES].reverse();

/**
 * Maximum number of training runs executed at once, read from the first CLI
 * argument (presumably one run per available GPU). Validated up front so a
 * missing or malformed argument fails immediately with a clear message,
 * instead of reaching `pAll` as `NaN` after the whole sweep has been built.
 */
const GPUS = Number.parseInt(process.argv[2] ?? "", 10);
if (!Number.isInteger(GPUS) || GPUS < 1) {
  throw new Error(
    `Expected a positive integer GPU count as the first argument, got: ${String(process.argv[2])}`
  );
}

/** Free-form label from the second CLI argument; printed as a prefix of the summary log line. */
const ENV = process.argv[3];

/**
 * Sweep configurations. Each row is crossed with every entry of `NETS` and
 * every seed in `SEED`, and its fields are forwarded unchanged to
 * `runSingleTrain`. `shrink` is presumably a dataset reduction factor and
 * `lrRatio` a learning-rate multiplier — confirm in train.mjs.
 */
const args = [
  { shrink: 1, batchSize: 512, lrRatio: 1.0 },
  { shrink: 10, batchSize: 256, lrRatio: 1.5 },
  { shrink: 36, batchSize: 256, lrRatio: 1.5 },
  { shrink: 216, batchSize: 128, lrRatio: 2.0 },
  { shrink: 648, batchSize: 128, lrRatio: 2.0 },
  { shrink: 1296, batchSize: 64, lrRatio: 3.0 },
  { shrink: 2592, batchSize: 64, lrRatio: 3.5 },
  { shrink: 5184, batchSize: 64, lrRatio: 3.75 },
];

/**
 * Builds the full sweep and runs it with bounded concurrency.
 *
 * Every configuration in `args` is crossed with every `[n, hiddenSize]` pair
 * in `NETS` and every seed in `SEED`, producing one `Partial<TrainArgs>` per
 * run (with fixed `evalRatio: 4` and `skip: false`). The resulting list is
 * shuffled into random order, logged, and each entry is passed to
 * `runSingleTrain("scaling_law", …)` through `pAll` with at most `GPUS`
 * tasks in flight.
 *
 * @returns a promise that settles when the `pAll` call over all runs settles.
 */
async function runAll(): Promise<void> {
  // Cartesian product config × net × seed, in the same order as before the
  // shuffle. flatMap avoids re-copying an accumulator on every step, and
  // iterating SEED directly keeps the per-net run count equal to SEED.length.
  const orderedCmds: Partial<TrainArgs>[] = args.flatMap(
    ({ shrink, batchSize, lrRatio }) =>
      NETS.flatMap(([n, hiddenSize]) =>
        SEED.map(
          (seed): Partial<TrainArgs> => ({
            shrink,
            lrRatio,
            batchSize,
            n,
            hiddenSize,
            seed,
            evalRatio: 4,
            skip: false,
          })
        )
      )
  );

  // Randomize execution order; lodash's shuffle returns a new array.
  const allCmds = _.shuffle(orderedCmds);

  console.log(
    `${ENV} in total, ${allCmds.length} commands are: ${JSON.stringify(
      allCmds
    )}`
  );
  await pAll(
    allCmds.map((arg) => () => runSingleTrain("scaling_law", arg)),
    { concurrency: GPUS }
  );
}

await runAll();
