              
                                                      
                                                                 

import asyncio
import io
import sys
import threading
import uuid
import time

import httpx
import torch

from gpatch.core.utils import print_with_rank_and_datetime


async def call_once_rpc(
    url,
    req_dict,
    timeout=60,
    max_retry=None,
    retry_backoff_wait=1,
    retry_backoff_max_wait=16,
):
    """Send ``req_dict`` to ``url`` as a "once RPC", poll for the result, then clean up.

    The request is serialized with ``torch.save`` together with an extra
    ``'once_rpc'`` entry (``cmd='rpc'``, a unique id and a timestamp) and
    POSTed to ``url``. The identical payload is re-sent until the server
    answers with ``resp['once_rpc']['ready']`` truthy. Afterwards a second
    request with ``cmd='clean_up'`` and the same id/timestamp is sent until
    the server acknowledges it with ``resp['once_rpc']['ret'] == 'ok'``.

    Args:
        url: HTTP endpoint that receives the POST requests.
        req_dict: Request payload (a dict). It must not contain the key
            ``'once_rpc'``. The key is added only for the duration of the
            serialization and is always removed again, even if
            serialization fails.
        timeout: Per-request timeout handed to ``httpx``.
        max_retry: Not implemented; must be ``None`` (both phases retry
            without an upper bound).
        retry_backoff_wait: Initial backoff value in seconds. The value is
            doubled *before* every sleep, so the first actual wait is
            ``2 * retry_backoff_wait`` (capped by ``retry_backoff_max_wait``).
        retry_backoff_max_wait: Upper bound for a single sleep, in seconds.

    Returns:
        The deserialized response dict of the RPC with its ``'once_rpc'``
        entry removed.

    Raises:
        AssertionError: If ``max_retry`` is not ``None``, if ``req_dict``
            already has a ``'once_rpc'`` key, or if the clean-up response
            does not report ``'ok'``.

    NOTE(security): responses are decoded with
    ``torch.load(..., weights_only=False)``, which unpickles arbitrary
    objects. Only point this at endpoints you trust.
    """
    assert max_retry is None

    # Prefix the id with the rank when running under torch.distributed.
    if not torch.distributed.is_initialized():
        uniq_id = uuid.uuid4().hex
    else:
        uniq_id = f'{torch.distributed.get_rank()}-{uuid.uuid4().hex}'
    ts = time.time()
    assert 'once_rpc' not in req_dict
    req_dict['once_rpc'] = {
        'cmd': 'rpc',
        'uniq_id': uniq_id,
        'ts': ts,
    }
    fio = io.BytesIO()
    try:
        torch.save(req_dict, fio)
    finally:
        # Always hand the caller's dict back untouched, even when the
        # payload turns out not to be serializable.
        req_dict.pop('once_rpc')
    req_data = fio.getvalue()

    async def rpc_co():
        """POST the RPC payload until the server reports the result as ready."""
        backoff = retry_backoff_wait
        retry_cnt = -1

        async with httpx.AsyncClient() as cli:
            while True:
                try:
                    retry_cnt += 1
                    resp = await cli.post(url, content=req_data, timeout=timeout)
                    resp.raise_for_status()
                    resp_dict = torch.load(io.BytesIO(resp.content), weights_only=False)
                    if not resp_dict['once_rpc']['ready']:
                        # Result not available yet: back off and poll again
                        # with the very same payload (same uniq_id).
                        backoff = min(backoff * 2, retry_backoff_max_wait)
                        await asyncio.sleep(backoff)
                    else:
                        resp_dict.pop('once_rpc')
                        return resp_dict
                except httpx.HTTPError:
                    # Transport errors, timeouts and non-2xx statuses are
                    # all retried with the same exponential backoff.
                    backoff = min(backoff * 2, retry_backoff_max_wait)
                    await asyncio.sleep(backoff)

                # Periodic progress log so a long poll is visible.
                if retry_cnt > 1 and retry_cnt % 100 == 0:
                    print_with_rank_and_datetime(f"once rpc polling {retry_cnt} {url=}")

    ret_resp_dict = await rpc_co()

    # Second phase: send a clean_up command carrying the same id/timestamp.
    req_dict = {
        'once_rpc': {
            'cmd': 'clean_up',
            'uniq_id': uniq_id,
            'ts': ts,
        },
    }
    fio = io.BytesIO()
    torch.save(req_dict, fio)
    req_data = fio.getvalue()

    async def clean_up_co():
        """POST the clean-up command until the server acknowledges it."""
        backoff = retry_backoff_wait

        async with httpx.AsyncClient() as cli:
            while True:
                try:
                    resp = await cli.post(url, content=req_data, timeout=timeout)
                    # Same handling as the RPC phase: an HTTP error response
                    # is retried instead of being fed to torch.load, where
                    # it would fail with an unrelated unpickling error.
                    resp.raise_for_status()
                    resp_dict = torch.load(io.BytesIO(resp.content), weights_only=False)
                    assert resp_dict['once_rpc']['ret'] == 'ok'
                    break
                except httpx.HTTPError:
                    backoff = min(backoff * 2, retry_backoff_max_wait)
                    await asyncio.sleep(backoff)

    await clean_up_co()

    return ret_resp_dict


if __name__ == '__main__':
    # Manual smoke test: issue one RPC against an endpoint on localhost and
    # print whatever comes back.
    host, port = '127.0.0.1', 8080
    endpoint = f'http://{host}:{port}/generate'
    payload = {'a': 1, 'b': 2}

    result = asyncio.run(call_once_rpc(endpoint, payload, timeout=1))
    print(result)
