import os
import requests
import torch
from megatron.core import parallel_state
from gpatch.core.utils import print_with_rank_and_datetime
from torch.multiprocessing import Process, Queue
import httpx
import asyncio
import io
import queue
import time

# Business id stamped into every WeCube report payload (the "biz_id" key built by
# the report_ppo_metrics_* functions below).
_grpo_biz_id = 16644
# WeCube report endpoint; payloads are POSTed here as JSON by both the sync
# (requests) and async (httpx) reporting paths.
_cube_report_url = 'http://api.cube.wx.com/cube/report/reportbizdata?f=json'


async def do_report_async(report_data):
    """Asynchronously POST one report payload to the WeCube report endpoint.

    Makes up to 3 attempts, each with a 2-second timeout, pausing 1 second
    between failed attempts. Every failed attempt is logged.

    Args:
        report_data: JSON-serializable dict sent as the request body.

    Returns:
        The ``httpx.Response`` of the first attempt that completed with a
        success (2xx) status, or ``None`` if every attempt failed.
    """
    headers = {"Content-Type": "application/json"}
    timeout = httpx.Timeout(2)
    num_attempts = 3
    async with httpx.AsyncClient() as cli:
        for attempt in range(num_attempts):
            try:
                resp = await cli.post(
                    _cube_report_url, headers=headers, json=report_data, timeout=timeout
                )
                # An error status is a failed report too. raise_for_status() raises
                # httpx.HTTPStatusError (a subclass of httpx.HTTPError), so it takes
                # the same log-and-retry path as transport errors instead of being
                # returned to the caller as a "successful" response.
                resp.raise_for_status()
                return resp
            except httpx.HTTPError as e:
                print_with_rank_and_datetime(f"[WECUBE] report fail, error {e}")
                # Only back off when another attempt will follow.
                if attempt + 1 < num_attempts:
                    await asyncio.sleep(1)
        return None


# Hand-off queue: report_ppo_metrics_async puts payload dicts here and the
# wecube_report_consumer process drains it. Created at module import time.
# NOTE(review): the consumer only sees this same queue if the child process
# inherits it (fork start method); under spawn the child re-imports this module
# and gets a fresh, empty queue — confirm the start method in use.
_global_que = Queue()
# Maximum number of queued payloads the consumer drains and sends concurrently
# in one iteration.
_max_batch_size = 512
# Flipped to True by init_wecube_reporter; report_ppo_metrics is a no-op until then.
_enable_report = False


def wecube_report_consumer():
    """Drain queued report payloads in batches and send them to WeCube, forever.

    This is the target of the background process started by
    ``init_wecube_reporter``. Each iteration pulls up to ``_max_batch_size``
    payloads from ``_global_que`` without blocking and sends them concurrently
    via ``do_report_async``. It then logs how many succeeded and how many failed.
    When fewer than a full batch was available, it sleeps 1 second before
    polling again.
    """
    while True:
        report_batches = []
        for _ in range(_max_batch_size):
            try:
                report_batches.append(_global_que.get(block=False))
            except queue.Empty:
                break

        if report_batches:

            async def report_co_func():
                # return_exceptions=True: an unexpected error in one report (e.g. a
                # payload the JSON encoder rejects) must not cancel the rest of the
                # batch or escape asyncio.run() and permanently stop this loop.
                return await asyncio.gather(
                    *(do_report_async(report_data) for report_data in report_batches),
                    return_exceptions=True,
                )

            report_results = asyncio.run(report_co_func())
            num_fail_items = 0
            for result in report_results:
                if isinstance(result, BaseException):
                    print_with_rank_and_datetime(f"[WECUBE] report fail, error {result!r}")
                    num_fail_items += 1
                elif result is None:
                    # do_report_async exhausted its retries.
                    num_fail_items += 1
            num_succ_items = len(report_results) - num_fail_items
            print_with_rank_and_datetime(
                f"[WECUBE] report batch, tot {len(report_batches)}, succ {num_succ_items} fail {num_fail_items}."
            )

        # A full batch suggests more is waiting, so only idle when the queue ran dry.
        if len(report_batches) < _max_batch_size:
            time.sleep(1)


def init_wecube_reporter():
    """Start the background consumer process and enable queued metric reporting.

    Once this returns, ``report_ppo_metrics`` enqueues payloads onto
    ``_global_que`` and the child process running ``wecube_report_consumer``
    sends them. Calling it again after reporting is enabled is a no-op, so a
    repeated call does not spawn a second consumer.
    """
    global _enable_report
    if _enable_report:
        return
    # daemon=True: the consumer loops forever, and multiprocessing's exit handler
    # joins every non-daemon child, so a non-daemon consumer would hang the parent
    # at interpreter exit. A daemon child is terminated when the parent exits
    # instead; payloads still queued at that moment are dropped.
    p = Process(target=wecube_report_consumer, daemon=True)
    p.start()
    # Enable enqueueing only after the consumer really started, so a failed start
    # does not leave callers filling a queue nobody drains.
    _enable_report = True
    print_with_rank_and_datetime("[WECUBE] wecube_reporter start.")


def report_ppo_metrics_async(metrics: dict):
    """Enqueue ``metrics`` for the background WeCube consumer and return at once.

    The payload carries the biz id, the task name (``GCORE_TASK_NAME`` env var,
    defaulting to ``'rl_default'``), the global rank and the data-parallel rank.
    ``metrics`` is merged in last, so its entries win on a key clash.
    """
    payload = dict(
        biz_id=_grpo_biz_id,
        task=os.environ.get('GCORE_TASK_NAME', 'rl_default'),
        rank=torch.distributed.get_rank(),
        dp_rank=parallel_state.get_data_parallel_rank(),
    )
    payload.update(metrics)
    _global_que.put(payload)


def report_ppo_metrics_sync(metrics: dict):
    """POST ``metrics`` to WeCube on the calling thread (blocking, 2 s timeout).

    The payload has the same layout as in the queued path: biz id, task name
    from ``GCORE_TASK_NAME`` (default ``'rl_default'``), global rank and
    data-parallel rank, then ``metrics`` merged on top. Sending is best effort:
    a non-200 status or any exception raised by the request is logged and
    swallowed. Errors while building the payload still propagate.
    """
    task_name = os.environ.get('GCORE_TASK_NAME', 'rl_default')
    payload = {
        "biz_id": _grpo_biz_id,
        "task": task_name,
        "rank": torch.distributed.get_rank(),
        "dp_rank": parallel_state.get_data_parallel_rank(),
    }
    payload.update(metrics)
    try:
        response = requests.post(
            _cube_report_url,
            headers={"Content-Type": "application/json"},
            timeout=2,
            json=payload,
        )
        if response.status_code != 200:
            print_with_rank_and_datetime(
                f"[WECUBE] report fail, status_code {response.status_code}"
            )
    except Exception as e:
        print_with_rank_and_datetime(f"[WECUBE] report fail, error {e=}")


def report_ppo_metrics(metrics: dict):
    """Hand ``metrics`` to the queued reporter when reporting is enabled.

    Reporting is switched on by ``init_wecube_reporter``. Until then, this
    does nothing and returns ``None``.
    """
    if not _enable_report:
        return None
    return report_ppo_metrics_async(metrics)
