name: SSH into our runners

# Manually triggered only. The caller picks the GPU family, the number of GPUs
# and the Docker image the SSH session should run in.
on:
  workflow_dispatch:
    inputs:
      # `choice` inputs restrict the value to the combinations the `get_runner`
      # job knows how to map to a runner group, so a typo is rejected when the
      # workflow is dispatched instead of surfacing later as a missing runner.
      runner_type:
        description: 'Type of runner to test (a10 or t4)'
        required: true
        type: choice
        options:
          - a10
          - t4
      docker_image:
        description: 'Name of the Docker image'
        required: true
        type: string
      num_gpus:
        description: 'Type of the number of gpus to use (`single` or `multi`)'
        required: true
        type: choice
        options:
          - single
          - multi

# Environment variables shared by every job of this workflow.
# Values are quoted so they are always read as strings, whatever YAML version
# the parser implements (`yes`, `true` and bare numbers are otherwise subject
# to implicit typing).
env:
  # For gated repositories, we still need to agree to share information on the
  # Hub repo page in order to get access.
  # This token is created under the bot `hf-transformers-bot`.
  HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
  # Same path as the cache mount target in the `ssh_runner` container options.
  HF_HOME: /mnt/cache
  TRANSFORMERS_IS_CI: "yes"
  OMP_NUM_THREADS: "8"
  MKL_NUM_THREADS: "8"
  RUN_SLOW: "yes"
  SIGOPT_API_TOKEN: ${{ secrets.SIGOPT_API_TOKEN }}
  TF_FORCE_GPU_ALLOW_GROWTH: "true"
  CUDA_VISIBLE_DEVICES: "0,1"
  RUN_PT_TF_CROSS_TESTS: "1"

jobs:
  # Maps the `runner_type` / `num_gpus` inputs to the runner group the
  # `ssh_runner` job runs on, and exposes it as the `RUNNER` job output.
  get_runner:
    name: "Get runner to use"
    runs-on: ubuntu-22.04
    outputs:
      RUNNER: ${{ steps.set_runner.outputs.RUNNER }}
    steps:
      - name: Get runner to use
        id: set_runner
        shell: bash
        # The inputs are handed over through the environment instead of being
        # interpolated into the script, so their content is never parsed as
        # shell code.
        env:
          REQUESTED_RUNNER_TYPE: ${{ github.event.inputs.runner_type }}
          REQUESTED_NUM_GPUS: ${{ github.event.inputs.num_gpus }}
        run: |
          case "${REQUESTED_NUM_GPUS}/${REQUESTED_RUNNER_TYPE}" in
            single/t4)  runner="aws-g4dn-2xlarge-cache" ;;
            multi/t4)   runner="aws-g4dn-12xlarge-cache" ;;
            single/a10) runner="aws-g5-4xlarge-cache" ;;
            multi/a10)  runner="aws-g5-12xlarge-cache" ;;
            *)
              # Fail here with an explicit message rather than handing an empty
              # runner group to the next job.
              echo "::error::Unsupported combination: runner_type='${REQUESTED_RUNNER_TYPE}', num_gpus='${REQUESTED_NUM_GPUS}'"
              exit 1
              ;;
          esac
          # Print the selection so it is visible in the job log.
          echo "${runner}"
          echo "RUNNER=${runner}" >> "$GITHUB_OUTPUT"

  # Runs the requested Docker image on the runner group selected by
  # `get_runner`, then prepares the checkout found at /transformers inside the
  # image before the SSH session is opened.
  ssh_runner:
    name: "SSH"
    needs: get_runner
    runs-on:
      group: ${{ needs.get_runner.outputs.RUNNER }}
    container:
      image: ${{ github.event.inputs.docker_image }}
      # All GPUs are exposed and the host Hugging Face cache directory is
      # mounted at /mnt/cache/ inside the container.
      options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/

    steps:
      # Move the clone inside the image to the commit this workflow runs for.
      - name: Update clone
        working-directory: /transformers
        run: git fetch && git checkout ${{ github.sha }}

      # Drop bytecode caches and any report directory left in the checkout.
      - name: Cleanup
        working-directory: /transformers
        run: rm -rf tests/__pycache__ tests/models/__pycache__ reports

      # Record the Python package versions in the job log.
      - name: Show installed libraries and their versions
        working-directory: /transformers
        run: pip freeze

      # Record the GPU state as seen from inside the container.
      - name: NVIDIA-SMI
        run: nvidia-smi

      # Derives the prefix of the per-user Slack secret (`<actor>_SLACK_ID`)
      # from the GitHub login of whoever triggered the run, and stores it in
      # the job environment for the next step.
      - name: Store GitHub actor
        shell: bash
        # The login is passed through the environment rather than interpolated
        # into the script, so it is never parsed as shell code.
        env:
          ACTOR: ${{ github.actor }}
        run: |
          echo "$ACTOR"
          # Secret names cannot contain `-`, so replace every hyphen of the
          # login with `_` (`//` substitutes all occurrences, not just the first).
          github_actor="${ACTOR//-/_}"
          echo "$github_actor"
          echo "github_actor=$github_actor" >> "$GITHUB_ENV"

      # Because the SSH can be enabled dynamically if the workflow failed, the
      # Slack infos are stored in the job environment so that they can be
      # retrieved during the waitForSSH step.
      - name: Store Slack infos
        shell: bash
        env:
          # Empty when no `<actor>_SLACK_ID` secret exists for this user.
          ACTOR_SLACK_ID: ${{ secrets[format('{0}_{1}', env.github_actor, 'SLACK_ID')] }}
          DEFAULT_SLACK_CHANNEL: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
        run: |
          echo "$github_actor"
          # Use the actor's own Slack ID when available, otherwise fall back to
          # the shared CI feedback channel.
          echo "SLACKCHANNEL=${ACTOR_SLACK_ID:-$DEFAULT_SLACK_CHANNEL}" >> "$GITHUB_ENV"

      # In order to be able to SSH when a test fails.
      # NOTE(review): `@main` is a mutable ref, so the code executed here can
      # change between runs; consider pinning to a release tag or commit SHA.
      - name: Tailscale
        uses: huggingface/tailscale-action@main
        with:
          authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}
          # Channel written to the job environment by the "Store Slack infos" step.
          slackChannel: ${{ env.SLACKCHANNEL }}
          slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
          # Presumably keeps the job alive while waiting for an SSH connection,
          # for at most `sshTimeout` — TODO confirm against the action's docs.
          waitForSSH: true
          sshTimeout: 15m
