              
                                                      
                                                                 

from typing import Callable
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Dict, Any, List
import asyncio

import torch
import torch.nn.functional as F


from megatron.core import mpu

from gpatch.core.utils import print_with_rank_and_datetime
from gpatch.core.parallel_state import (
    cpu_barrier,
)
from gpatch.core.wecube import report_ppo_metrics
from gpatch.rpc import call_once_rpc


@dataclass
class GptPpoGenRmClientV3:
    """Client that issues RPCs to a set of ``http://<ip>:<port>`` endpoints.

    Every request goes through ``call_once_rpc``. Control-plane calls
    (``setup``, ``sleep``, ``mark_gen_rm_ppo_step_begin`` / ``_end``,
    ``infer_engine_flush_cache``, ``update_gen_rm_weight_by_model_idx``) are
    sent only by global rank 0 and are fanned out concurrently to every
    endpoint. ``generate_rewards`` is sent by the calling rank to a single
    endpoint chosen by ``pick_ep_idx``.
    """

    # Host/IP of each endpoint; paired index-wise with ``ep_ports``.
    ep_ips: List[str]
    # Port of each endpoint; must have the same length as ``ep_ips``.
    ep_ports: List[int]
    # NOTE(review): ``timeout``, ``rpc_max_retries``, ``num_rm``,
    # ``unwrap_model_func`` and ``infer_engine_impl`` are not read by any
    # method of this class; presumably they are consumed by callers — confirm.
    timeout: int
    rpc_max_retries: int
    num_rm: int = 1
    unwrap_model_func: Callable = None
    infer_engine_impl: str = None
    # When True, ``generate_rewards`` drops the first element of each per-token
    # sequence; otherwise it drops the last element.
    ppo_value_truncate_head: bool = False

    def __post_init__(self):
        """Check that every endpoint IP has a matching port."""
        assert len(self.ep_ips) == len(self.ep_ports), (
            f'ep_ips and ep_ports must have equal length, '
            f'got {len(self.ep_ips)} vs {len(self.ep_ports)}'
        )

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    def _broadcast_from_rank0(self, route, req_dict=None):
        """POST ``req_dict`` to ``/<route>`` on every endpoint, from rank 0 only.

        Ranks other than global rank 0 return immediately. On rank 0 the
        requests are issued concurrently with ``asyncio.gather`` and this call
        blocks until all of them have completed.

        Args:
            route: URL path component appended after ``http://ip:port/``.
            req_dict: Request payload; ``None`` sends a fresh empty dict to
                each endpoint.
        """
        if torch.distributed.get_rank() != 0:
            return

        async def post(ep_ip, ep_port):
            url = f'http://{ep_ip}:{ep_port}/{route}'
            payload = {} if req_dict is None else req_dict
            return await call_once_rpc(url, payload)

        async def fan_out():
            # strict=True: a length mismatch must fail loudly instead of
            # silently skipping endpoints.
            pending = [
                post(ep_ip, ep_port)
                for ep_ip, ep_port in zip(self.ep_ips, self.ep_ports, strict=True)
            ]
            return await asyncio.gather(*pending)

        asyncio.run(fan_out())

    @staticmethod
    def _select_endpoint(dp_rank, dp_size, num_endpoints, sample_idx):
        """Map ``(dp_rank, sample_idx)`` onto an endpoint index.

        * ``dp_size == num_endpoints``: rank ``i`` uses endpoint ``i``
          (``sample_idx`` is ignored).
        * ``dp_size < num_endpoints``: start at
          ``dp_rank * (num_endpoints // dp_size)`` and offset by ``sample_idx``.
        * ``dp_size > num_endpoints``: start at
          ``dp_rank // (dp_size // num_endpoints)`` and offset by ``sample_idx``.

        The result is always reduced modulo ``num_endpoints``.
        """
        if dp_size == num_endpoints:
            ep_idx = dp_rank
        elif dp_size < num_endpoints:
            ep_idx = dp_rank * (num_endpoints // dp_size) + sample_idx
        else:
            ep_idx = dp_rank // (dp_size // num_endpoints) + sample_idx
        return ep_idx % num_endpoints

    @staticmethod
    def _has_payload(resp, key):
        """Return True if ``resp[key]`` is present, non-empty and its first item is not None."""
        items = resp.get(key)
        # ``len`` (not truthiness) so that a multi-element tensor is accepted too.
        return items is not None and len(items) > 0 and items[0] is not None

    def _truncate_token_seqs(self, resp, key):
        """Drop one edge element from every 1-D sequence in ``resp[key]``, in place.

        Does nothing when ``resp[key]`` is missing, empty, or starts with None.
        The head is dropped when ``ppo_value_truncate_head`` is set, the tail
        otherwise.
        """
        if not self._has_payload(resp, key):
            return
        seqs = resp[key]
        assert seqs[0].ndim == 1, f"ndim {seqs[0].ndim}"
        if self.ppo_value_truncate_head:
            resp[key] = [seq[1:].contiguous() for seq in seqs]
        else:
            resp[key] = [seq[:-1].contiguous() for seq in seqs]

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def wait_until_gen_rm_server_is_ready(self):
        """Call ``/heartbeat`` on every endpoint, one after another.

        Runs on every rank that calls it (no rank-0 gating) and blocks until
        each heartbeat RPC has returned.
        """
        for ep_ip, ep_port in zip(self.ep_ips, self.ep_ports, strict=True):
            url = f'http://{ep_ip}:{ep_port}/heartbeat'
            asyncio.run(call_once_rpc(url, {}))

    def setup(self):
        """Rank 0 calls ``/setup`` on every endpoint concurrently."""
        self._broadcast_from_rank0('setup')

    def sleep(self):
        """Rank 0 calls ``/sleep`` on every endpoint concurrently."""
        self._broadcast_from_rank0('sleep')

    def mark_gen_rm_ppo_step_begin(self, ppo_step):
        """Rank 0 notifies every endpoint that PPO step ``ppo_step`` begins.

        The payload carries the step, the caller's data-parallel rank and the
        fixed tag names ``['weights', 'kv_cache']``.
        """
        req_dict = {
            'ppo_step': ppo_step,
            'actor_dp_rank': mpu.get_data_parallel_rank(),
            'tag_names': ['weights', 'kv_cache'],
        }
        self._broadcast_from_rank0('mark_gen_rm_ppo_step_begin', req_dict)

    def mark_gen_rm_ppo_step_end(self, ppo_step):
        """Rank 0 notifies every endpoint that PPO step ``ppo_step`` ended."""
        req_dict = {
            'ppo_step': ppo_step,
            'actor_dp_rank': mpu.get_data_parallel_rank(),
        }
        self._broadcast_from_rank0('mark_gen_rm_ppo_step_end', req_dict)

    def pick_ep_idx(self, sample_idx):
        """Return the endpoint index this data-parallel rank uses for ``sample_idx``.

        See ``_select_endpoint`` for the mapping rules.
        """
        return self._select_endpoint(
            dp_rank=mpu.get_data_parallel_rank(),
            dp_size=mpu.get_data_parallel_world_size(),
            num_endpoints=len(self.ep_ips),
            sample_idx=sample_idx,
        )

    async def generate_rewards(self, ppo_step, sample_idx, prompt_idx, batch) -> Dict[str, List[Any]]:
        """Send ``batch`` to one endpoint's ``/generate_rewards`` and post-process the reply.

        Args:
            ppo_step: Current PPO step, forwarded in the request.
            sample_idx: Sample index; also used to choose the endpoint.
            prompt_idx: Prompt index, forwarded in the request.
            batch: Payload forwarded unchanged under the ``'batch'`` key.

        Returns:
            The response dict from ``call_once_rpc``. If it contains
            ``'values'`` and/or ``'per_token_rewards'``, each 1-D sequence in
            them has one edge element removed (head or tail depending on
            ``ppo_value_truncate_head``). If it contains ``'rewards'``, their
            count is reported through ``report_ppo_metrics``.
        """
        dp_size = mpu.get_data_parallel_world_size()
        dp_rank = mpu.get_data_parallel_rank()
        ep_idx = self.pick_ep_idx(sample_idx)

        req_dict = {
            'actor_dp_rank': dp_rank,
            'actor_dp_size': dp_size,
            # The endpoint index doubles as the server-side DP rank.
            'gen_rm_dp_rank': ep_idx,
            'ppo_step': ppo_step,
            'sample_idx': sample_idx,
            'prompt_idx': prompt_idx,
            'batch': batch,
        }

        url = f'http://{self.ep_ips[ep_idx]}:{self.ep_ports[ep_idx]}/generate_rewards'
        resp = await call_once_rpc(url, req_dict)

        print_with_rank_and_datetime(f'generate {ppo_step=} {sample_idx=}')

        for key in ("values", "per_token_rewards"):
            self._truncate_token_seqs(resp, key)

        if self._has_payload(resp, 'rewards'):
            report_ppo_metrics({"rm_infer_results": len(resp['rewards'])})
        print_with_rank_and_datetime(f'get_infer_rm_critic_result {ppo_step=} {sample_idx=}')
        return resp

    def infer_engine_flush_cache(self):
        """Rank 0 calls ``/flush_cache`` on every endpoint; then all ranks hit ``cpu_barrier``."""
        self._broadcast_from_rank0('flush_cache')
        cpu_barrier()

    def update_gen_rm_weight_by_model_idx(self, rm_idx):
        """Rank 0 asks every endpoint to update weights for ``rm_idx``; then all ranks hit ``cpu_barrier``."""
        self._broadcast_from_rank0('update_gen_rm_weight_by_model_idx', {'rm_idx': rm_idx})
        cpu_barrier()