              
                                                      
                                                                 

from dataclasses import dataclass
import asyncio
import io
import time
import math
import json
from typing import List, Dict, Any

import torch
import transformers

from megatron.core import mpu

from gpatch.core.utils import print_with_rank_and_datetime, check_rollout_batch
from gpatch.core.wecube import report_ppo_metrics
from gpatch.rpc import call_once_rpc


@dataclass
class GptPpoRmCriticClientV3:
    """Actor-side RPC client for the reward-model / critic endpoints.

    Every request is an HTTP call made through ``call_once_rpc`` to
    ``http://<ip>:<port>/<route>``; ``ep_ips[i]`` and ``ep_ports[i]``
    together describe endpoint ``i``.

    NOTE(review): ``combine_rm_and_critic_server``, ``timeout``,
    ``pad_token_id``, ``tokenizer``, ``rpc_max_retries`` and ``num_rm`` are
    stored but not read by any method shown here; presumably other code
    consumes them -- confirm before removing.
    """
    # Endpoint addresses; the two lists must have the same length.
    ep_ips: List[str]
    ep_ports: List[int]
    combine_rm_and_critic_server: bool
    timeout: int
    pad_token_id: int
    tokenizer: transformers.PreTrainedTokenizer
    rpc_max_retries: int
    num_rm: int = 1
    # When set, ``wait_until_critic_server_is_ready`` returns immediately.
    ppo_debug_fake_rm_critic: bool = False
    # Selects which end of each per-token sequence is dropped in
    # ``get_infer_rm_critic_result``: the first element when True, the last
    # element when False.
    ppo_value_truncate_head: bool = False

    def __post_init__(self):
        """Validate the endpoint lists and initialise per-instance bookkeeping.

        Raises:
            AssertionError: if ``ep_ips`` and ``ep_ports`` differ in length.
        """
        assert len(self.ep_ips) == len(self.ep_ports)
        dp_size = mpu.get_data_parallel_world_size()
        dp_rank = mpu.get_data_parallel_rank()
        nep = len(self.ep_ips)
        # Per-rank endpoint offset. It is not read by the methods shown here;
        # presumably used by callers -- TODO confirm.
        self.infer_ep_rb_idx = (math.ceil(nep / dp_size) * dp_rank) % nep

        # Number of ``issue_infer_rm_critic`` requests sent, keyed by port.
        # NOTE(review): endpoints that share a port (different IPs) also share
        # one counter here.
        self.balance_dict = dict.fromkeys(self.ep_ports, 0)

        if torch.distributed.get_rank() == 0:
            print(f"GptPpoRmCriticClient.extra_reward_info {self.balance_dict}", flush=True)

    def _broadcast_from_rank0(self, route, req_dict=None):
        """Call ``route`` on every endpoint concurrently, from global rank 0 only.

        Args:
            route: path component appended to ``http://<ip>:<port>/``.
            req_dict: payload passed to every call. When ``None`` each call
                receives its own fresh empty dict.

        Ranks other than 0 return without issuing any request. The RPC
        results are discarded.
        """
        if torch.distributed.get_rank() != 0:
            return

        async def call_one(ep_ip, ep_port):
            payload = {} if req_dict is None else req_dict
            return await call_once_rpc(f'http://{ep_ip}:{ep_port}/{route}', payload)

        async def fan_out():
            pending = [
                call_one(ep_ip, ep_port)
                for ep_ip, ep_port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*pending)

        asyncio.run(fan_out())

    def wait_until_critic_server_is_ready(self):
        """Call ``/heartbeat`` on each endpoint in turn, one after another.

        Does nothing when ``ppo_debug_fake_rm_critic`` is set. Unlike the
        broadcast-style methods, this is not restricted to rank 0.
        """
        if self.ppo_debug_fake_rm_critic:
            return

        for ep_ip, ep_port in zip(self.ep_ips, self.ep_ports, strict=True):
            url = f'http://{ep_ip}:{ep_port}/heartbeat'
            asyncio.run(call_once_rpc(url, {}))

    def setup(self):
        """Call ``/setup`` on every endpoint (issued by global rank 0 only)."""
        self._broadcast_from_rank0('setup')

    def sleep(self):
        """Call ``/sleep`` on every endpoint (issued by global rank 0 only)."""
        self._broadcast_from_rank0('sleep')

    def mark_rm_ppo_step_begin(self, ppo_step):
        """Call ``/mark_rm_ppo_step_begin`` on every endpoint (global rank 0 only).

        Args:
            ppo_step: step identifier forwarded as ``'ppo_step'``.
        """
        req_dict = {
            'ppo_step': ppo_step,
            'actor_dp_rank': mpu.get_data_parallel_rank(),
        }
        self._broadcast_from_rank0('mark_rm_ppo_step_begin', req_dict)

    def mark_rm_ppo_step_end(self, ppo_step):
        """Call ``/mark_rm_ppo_step_end`` on every endpoint (global rank 0 only).

        Args:
            ppo_step: step identifier forwarded as ``'ppo_step'``.
        """
        req_dict = {
            'ppo_step': ppo_step,
            'actor_dp_rank': mpu.get_data_parallel_rank(),
        }
        self._broadcast_from_rank0('mark_rm_ppo_step_end', req_dict)

    def pick_ep_idx(self, sample_idx):
        """Return the index of the endpoint that serves ``sample_idx`` for this rank.

        With ``dp`` the data-parallel world size and ``n`` the endpoint count:

        * ``dp == n``: the rank's own index, regardless of ``sample_idx``;
        * ``dp < n``: ``dp_rank * (n // dp) + sample_idx``;
        * ``dp > n``: ``dp_rank // (dp // n) + sample_idx``;

        the result is always reduced modulo ``n``.
        """
        dp_size = mpu.get_data_parallel_world_size()
        dp_rank = mpu.get_data_parallel_rank()
        rdp_size = len(self.ep_ips)
        if dp_size == rdp_size:
            ep_idx = dp_rank
        elif dp_size < rdp_size:
            ep_idx = dp_rank * (rdp_size // dp_size) + sample_idx
        else:
            ep_idx = dp_rank // (dp_size // rdp_size) + sample_idx
        return ep_idx % rdp_size

    async def issue_infer_rm_critic(self, ppo_step, sample_idx, batch: Dict[str, List[Any]]):
        """Send one rollout batch to ``/issue_infer_rm_critic`` on the picked endpoint.

        Args:
            ppo_step: forwarded as ``'ppo_step'``.
            sample_idx: forwarded as ``'sample_idx'``; also selects the endpoint.
            batch: rollout batch merged into the request. It must pass
                ``check_rollout_batch`` and must not contain the keys
                ``'actor_dp_rank'``, ``'ppo_step'`` or ``'sample_idx'``.

        Returns:
            Whatever ``call_once_rpc`` returns for the request.

        Raises:
            AssertionError: on a malformed batch or a reserved-key clash.
        """
        assert check_rollout_batch(batch), "rb format error"
        dp_rank = mpu.get_data_parallel_rank()
        ep_idx = self.pick_ep_idx(sample_idx)
        ep_ip = self.ep_ips[ep_idx]
        ep_port = self.ep_ports[ep_idx]

        self.balance_dict[ep_port] += 1
        req_dict = {
            'actor_dp_rank': dp_rank,
            'ppo_step': ppo_step,
            'sample_idx': sample_idx,
        }
        # The batch must not overwrite the routing metadata set above.
        for reserved in req_dict:
            assert reserved not in batch
        req_dict.update(batch)

        url = f'http://{ep_ip}:{ep_port}/issue_infer_rm_critic'
        return await call_once_rpc(url, req_dict)

    @staticmethod
    def _has_payload(resp, key):
        """Return True when ``resp[key]`` exists, is non-empty and starts with a non-None item."""
        field = resp.get(key, None)
        if field is None:
            return False
        try:
            first = field[0]
        except IndexError:
            # Empty sequence: nothing to post-process or report.
            return False
        return first is not None

    def _drop_boundary_token(self, seqs):
        """Drop one element from each 1-D sequence: the first if truncating the head, else the last."""
        assert seqs[0].ndim == 1, f"ndim {seqs[0].ndim}"
        if self.ppo_value_truncate_head:
            return [seq[1:].contiguous() for seq in seqs]
        return [seq[:-1].contiguous() for seq in seqs]

    async def get_infer_rm_critic_result(self, ppo_step, sample_idx, sampling_repeat):
        """Fetch the result of an earlier ``issue_infer_rm_critic`` request.

        The endpoint is picked with the same ``pick_ep_idx(sample_idx)`` rule
        used when issuing the request.

        Post-processing applied to the response mapping, in place:

        * ``'values'`` and ``'per_token_rewards'``: when present, non-empty
          and not starting with ``None``, every sequence loses one element
          (the first when ``ppo_value_truncate_head`` is set, otherwise the
          last). Elements are presumably 1-D torch tensors -- only the first
          one's ``ndim`` is asserted.
        * ``'rewards'``: when present under the same conditions, its length
          is reported through ``report_ppo_metrics`` as ``'rm_infer_results'``.

        Returns:
            The (possibly modified) response from ``call_once_rpc``.
        """
        dp_rank = mpu.get_data_parallel_rank()
        ep_idx = self.pick_ep_idx(sample_idx)
        ep_ip = self.ep_ips[ep_idx]
        ep_port = self.ep_ports[ep_idx]

        req_dict = {
            'actor_dp_rank': dp_rank,
            'ppo_step': ppo_step,
            'sample_idx': sample_idx,
            'sampling_repeat': sampling_repeat,
        }

        url = f'http://{ep_ip}:{ep_port}/get_infer_rm_critic_result'
        resp = await call_once_rpc(url, req_dict)

        for key in ("values", "per_token_rewards"):
            if self._has_payload(resp, key):
                resp[key] = self._drop_boundary_token(resp[key])

        if self._has_payload(resp, 'rewards'):
            report_ppo_metrics({"rm_infer_results": len(resp['rewards'])})
        print_with_rank_and_datetime(f'get_infer_rm_critic_result {ppo_step=} {sample_idx=}')
        return resp
