import logging
from socketserver import BaseRequestHandler, ThreadingMixIn, TCPServer
import threading
import traceback

from ray.rllib.env.policy_client import (
    Commands,
)
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.utils.typing import SampleBatchType
import json

logger = logging.getLogger(__name__)


class SocketServerInputEval(ThreadingMixIn, TCPServer, InputReader):
    """Threaded TCP server that answers policy queries from remote clients.

    Each accepted connection is handled (in its own thread, via
    ``ThreadingMixIn``) by the handler class returned from
    ``_make_handler(trainer)``.
    """

    def __init__(self, address, port, trainer):
        """Bind a listening socket on ``(address, port)``.

        Args:
            address: Host/interface to bind to.
            port: TCP port to bind to; also stored as ``self.port``.
            trainer: Passed to ``_make_handler`` to build the request handler.

        Raises:
            OSError: If the socket cannot be bound or put in listening mode.
                A failure message is printed and the call sleeps for one
                second before re-raising.
        """
        import time

        self.port = port
        handler = _make_handler(trainer)
        try:
            TCPServer.__init__(self, (address, port), handler)
        except OSError:
            print(f"Creating a PolicyServer on {address}:{port} failed!")
            # NOTE(review): the pause presumably throttles callers that retry
            # the bind in a loop — confirm intent.
            time.sleep(1)
            raise
        print(f"Creating a PolicyServer on {address}:{port}")

        logger.info("Starting connector server at %s", self.server_address)

    def serve_forever1(self):
        """Run ``serve_forever`` on a background daemon thread named "server".

        Returns:
            The started ``threading.Thread``, so callers can join or inspect it.
        """
        serving_thread = threading.Thread(
            name="server", target=self.serve_forever, daemon=True
        )
        serving_thread.start()
        return serving_thread

    def next(self) -> SampleBatchType:
        """Delegate to the next ``next`` implementation in the MRO.

        None of the ``socketserver`` bases define ``next``, so this resolves
        to the ``InputReader`` side of the class hierarchy.
        """
        return super().next()


def _make_handler(trainer, request_len=131072, response_len=100):
    """Build a ``BaseRequestHandler`` subclass that serves ``trainer``.

    Protocol implemented by the returned handler:

    * Every request is a fixed-size frame of exactly ``request_len`` bytes.
      The frame is decoded as UTF-8, whitespace-stripped and cut at the
      first newline; the remaining text is parsed as a JSON object.
    * Every reply is ``json.dumps(response) + "\\n"`` encoded as UTF-8 and
      right-padded with ``b"0"`` to ``response_len`` bytes (a longer reply
      is sent unpadded).
    * The per-connection loop ends when the peer closes the socket (a
      partially received frame is discarded) or after an ``END_EPISODE``
      command has been answered.
    * A frame that fails to decode, parse or execute is reported with its
      traceback and skipped; the connection stays open.

    Args:
        trainer: Object whose ``compute_single_action(observation)`` result
            is converted with ``int()`` to answer ``GET_ACTION``.
        request_len: Size in bytes of each incoming frame (2**17 by default).
        response_len: Size in bytes that replies are padded to.

    Returns:
        A handler class suitable for passing to a ``socketserver`` server.

    Raises:
        ValueError: If ``request_len`` is not positive (an empty frame size
            would make the handler spin without ever reading the socket).
    """
    if request_len <= 0:
        raise ValueError(f"request_len must be positive, got {request_len}")

    # Upper bound on the size of a single recv() call.
    recv_chunk = 4096

    class Handler(BaseRequestHandler):
        def handle(self) -> None:
            """Serve frames on ``self.request`` until disconnect or END_EPISODE."""
            while True:
                frame = self._recv_frame()
                if frame is None:
                    print('DISCONNECTED')
                    break
                try:
                    # Only the text before the first newline is the payload;
                    # partition() keeps the whole text if no newline exists
                    # instead of silently chopping the last character.
                    body = frame.decode("utf-8").strip().partition("\n")[0]
                    args = json.loads(body)
                    response = self.execute_command(args)
                    reply = (json.dumps(response) + "\n").encode("utf-8")
                    # Pad to a fixed size so the client sees uniform replies.
                    self.request.sendall(reply.ljust(response_len, b"0"))

                    if args["command"] == Commands.END_EPISODE.value:
                        print('END_EPISODE', args["episode_id"])
                        break
                except Exception:
                    # Best-effort: report the bad frame and keep serving.
                    print('exception', traceback.format_exc())

        def _recv_frame(self):
            """Read exactly ``request_len`` bytes from the client.

            Returns:
                The frame as a ``bytearray``, or ``None`` if the peer closed
                the connection before a complete frame arrived.
            """
            frame = bytearray()
            while len(frame) < request_len:
                wanted = min(recv_chunk, request_len - len(frame))
                # Chunks are kept verbatim: stripping them would drop bytes
                # that count toward the frame size and break framing.
                chunk = self.request.recv(wanted)
                if not chunk:
                    # recv() returns b"" once the peer has closed the socket.
                    return None
                frame.extend(chunk)
            return frame

        def execute_command(self, args):
            """Dispatch one decoded request and build its JSON-able reply.

            Only ``GET_ACTION`` produces a payload:
            ``{"action": int(trainer.compute_single_action(observation))}``.
            Every other command (including ``END_EPISODE``) yields ``{}``.
            """
            command = args["command"]
            response = {}

            if command == Commands.GET_ACTION.value:
                response["action"] = int(trainer.compute_single_action(args["observation"]))

            return response

    return Handler
