from collections import defaultdict
from contextlib import asynccontextmanager, nullcontext
import asyncio
import io
import random
import os
import time
import sys
import traceback

from fastapi import FastAPI, HTTPException
from fastapi import Request
from fastapi.responses import Response
import fastapi
import torch
import uvicorn

from gpatch.rpc.monitor import set_exit_flag
from gpatch.core.utils import print_with_rank_and_datetime


def once_rpc(do_monitor=False, monitor_server_ip='localhost', monitor_port=62000):
    """Build a decorator that turns an async handler into an at-most-once RPC endpoint.

    The returned decorator wraps ``fn`` (an ``async`` callable taking the
    deserialized request dict and returning a dict) into an
    ``async (Request) -> Response`` endpoint.

    Request bodies are ``torch.save``-serialized dicts whose ``'once_rpc'``
    entry carries ``'uniq_id'`` and ``'cmd'``:

    - ``cmd == 'rpc'``: the first request for a ``uniq_id`` runs ``fn`` (with
      the ``'once_rpc'`` entry removed from its argument) and caches the result.
      Later requests with the same id get the cached dict back instead of
      re-running ``fn``. Every response carries ``resp['once_rpc']['ready']``
      and ``resp['once_rpc']['deleted']``.
    - ``cmd == 'clean_up'``: replaces the cached entry with a "deleted"
      tombstone, so a late duplicate ``'rpc'`` for that id is rejected with
      HTTP 400. Tombstoned ids are dropped from the cache by an hourly
      timing wheel.

    Args:
        do_monitor: if True, call ``set_exit_flag`` on the monitor server
            when ``fn`` raises.
        monitor_server_ip: host passed to ``set_exit_flag``.
        monitor_port: port passed to ``set_exit_flag``.

    Returns:
        A decorator mapping ``fn`` to the endpoint coroutine function.
    """
    # Tombstone retention: a bucket for hour h is collected once the current
    # hour is >= h + gc_after_hours (so roughly 1-2 hours, and only when a
    # later clean_up request triggers the GC pass).
    gc_after_hours = 2

    def once_rpc_wrapper(fn):
        # uniq_id -> cached response dict: a not-ready placeholder while fn is
        # running, fn's result once ready, or a deleted tombstone after clean_up.
        cached_resp_dicts = {}
        # floor(unix_time / 3600) -> uniq_ids tombstoned during that hour.
        gc_wheel = defaultdict(list)

        # NOTE(review): functools.wraps is deliberately not applied here;
        # FastAPI builds the endpoint's parameters from this function's
        # signature, which wraps would redirect to fn's signature.
        async def rpc_fn(req: Request) -> Response:
            req_body = await req.body()
            # SECURITY NOTE(review): weights_only=False unpickles the raw request
            # body, which can execute arbitrary code. Only expose this endpoint
            # to trusted peers.
            req_dict = torch.load(io.BytesIO(req_body), weights_only=False)

            uniq_id = req_dict['once_rpc']['uniq_id']
            cmd = req_dict['once_rpc']['cmd']

            if cmd == 'rpc':
                if uniq_id in cached_resp_dicts:
                    # Duplicate / retried request: reply with whatever is cached.
                    # While the first call is still awaiting fn, this is the
                    # not-ready placeholder.
                    resp_dict = cached_resp_dicts[uniq_id]
                    if resp_dict['once_rpc']['deleted']:
                        # Already cleaned up: a late duplicate must not re-run fn.
                        raise HTTPException(status_code=400, detail=f'bad req rare case wandering dup {req=}')
                else:
                    # Reserve the id before awaiting fn. Nothing awaits between
                    # the membership check above and this assignment, so a
                    # concurrent duplicate sees the placeholder instead of
                    # invoking fn a second time.
                    cached_resp_dicts[uniq_id] = {
                        'once_rpc': {
                            'deleted': False,
                            'ready': False,
                        },
                    }

                    req_dict.pop('once_rpc')
                    try:
                        resp_dict = await fn(req_dict)
                    except Exception:
                        # NOTE(review): the not-ready placeholder stays cached, so
                        # retries with this uniq_id keep getting ready=False and
                        # fn is never re-run for it.
                        print_with_rank_and_datetime(f"once rpc traceback {uniq_id=}\n{traceback.format_exc()}")
                        if do_monitor:
                            set_exit_flag(monitor_server_ip, monitor_port)
                        raise HTTPException(status_code=500, detail='fn exception')

                    assert 'once_rpc' not in resp_dict, 'once_rpc is a reserved name'
                    resp_dict['once_rpc'] = {
                        'deleted': False,
                        'ready': True,
                    }
                    cached_resp_dicts[uniq_id] = resp_dict

            elif cmd == 'clean_up':
                now = time.time()
                # Keep a tombstone rather than removing the entry, so a wandering
                # duplicate 'rpc' for this id is rejected instead of re-running fn.
                resp_dict = {
                    'once_rpc': {
                        'deleted': True,
                        'deleted_time': now,
                        'ret': 'ok',
                    },
                }
                cached_resp_dicts[uniq_id] = resp_dict

                # Hourly timing wheel: bucket this tombstone by its hour, then
                # drop every bucket that is at least gc_after_hours old. After
                # collection the id is forgotten entirely.
                now_as_hour = now // 3600
                gc_wheel[now_as_hour].append(uniq_id)
                for gc_hour in list(gc_wheel):  # snapshot: we pop while iterating
                    if now_as_hour - gc_hour < gc_after_hours:
                        continue
                    gc_bucket = gc_wheel.pop(gc_hour)
                    rank = -1
                    if torch.distributed.is_initialized():
                        rank = torch.distributed.get_rank()
                    print(f'once rpc gc {rank=} {gc_hour=} {now_as_hour=} {len(gc_bucket)=}')
                    for gc_uniq_id in gc_bucket:
                        cached_resp_dicts.pop(gc_uniq_id, None)

            else:
                # Fixed: previously missing the f prefix, so {cmd=} was never interpolated.
                raise ValueError(f'once_rpc unknown {cmd=}')

            fio = io.BytesIO()
            torch.save(resp_dict, fio)
            return Response(content=fio.getvalue())

        return rpc_fn

    return once_rpc_wrapper
