"""
Multiprocessing inference broker and client for batched inference.
"""

from typing import Optional, Tuple, Dict, Any
import time
import threading
import numpy as np
import torch
import torch.nn as nn
import multiprocessing as mp
from multiprocessing.connection import Connection


class _MPConfig(object):
    """Settings container handed to the broker process.

    Stores every constructor argument unchanged as an attribute of the same
    name. No validation or conversion is performed here.
    """

    def __init__(
        self,
        model_type,  # 'cnn' or 'transformer'
        board_shape,  # Tuple[int,int,int]
        action_size,
        embed_dim=None,
        num_heads=None,
        num_layers=None,
        use_sinusoidal_2d_pe=False,
        use_relative_bias=False,
        enable_global_context=False,
        include_action_tokens=False,
        action_token_in_dim=None,
        cross_attn_layers=1,
        device='cuda',
        max_batch_size=64,
        max_batch_wait=0.02,
        auto_batch_wait: bool = False,
    ):
        # Which network to build and its input/output sizes.
        self.model_type, self.board_shape, self.action_size = (
            model_type, board_shape, action_size,
        )

        # Transformer sizing; None means "let the model builder pick a default".
        self.embed_dim, self.num_heads, self.num_layers = (
            embed_dim, num_heads, num_layers,
        )

        # Optional transformer feature switches.
        self.use_sinusoidal_2d_pe, self.use_relative_bias = (
            use_sinusoidal_2d_pe, use_relative_bias,
        )
        self.enable_global_context, self.include_action_tokens = (
            enable_global_context, include_action_tokens,
        )
        self.action_token_in_dim, self.cross_attn_layers = (
            action_token_in_dim, cross_attn_layers,
        )

        # Where to run and how to batch.
        self.device = device
        self.max_batch_size, self.max_batch_wait = max_batch_size, max_batch_wait
        self.auto_batch_wait = auto_batch_wait


class MPInferenceBroker(object):
    """Owns the model in a dedicated process and serves batched inference over an IPC queue."""

    def __init__(self, state_dict: Dict[str, Any], config: "_MPConfig"):
        """Create the request/control queues; no process is spawned until ``start()``.

        Args:
            state_dict: Model weights passed to the broker process at start.
            config: Settings object passed to the broker process at start.
        """
        self._request_q = mp.Queue()
        self._control_q = mp.Queue()
        self._proc = None  # type: Optional[mp.Process]
        self._state_dict = state_dict
        self._config = config

    @property
    def request_queue(self) -> mp.Queue:
        """Queue on which clients enqueue their requests."""
        return self._request_q

    def start(self):
        """Spawn the broker process as a daemon; a no-op if one was already started."""
        if self._proc is not None:
            return
        self._proc = mp.Process(
            target=_broker_main,
            args=(self._request_q, self._control_q, self._state_dict, self._config),
            daemon=True,
        )
        self._proc.start()

    def stop(self, timeout: float = 5.0):
        """Ask the broker process to stop and wait for it to exit.

        Does nothing when no process has been started. In particular it does
        not enqueue a ``'stop'`` message, which would otherwise linger in the
        control queue and shut down a broker started later.

        Args:
            timeout: Seconds to wait for a graceful exit before the process is
                terminated (and again for the terminated process to be reaped).
        """
        proc = self._proc
        if proc is None:
            return
        try:
            self._control_q.put(('stop', None))
        except Exception:
            # Best effort: if the queue is unusable we fall through to terminate().
            pass
        proc.join(timeout=timeout)
        if proc.is_alive():
            proc.terminate()
            # Join again so the terminated child is reaped instead of lingering.
            proc.join(timeout=timeout)
        self._proc = None

    def update_weights(self, state_dict: Dict[str, Any]):
        """Non-blocking request to refresh the model weights in the broker."""
        try:
            self._control_q.put(('update', state_dict))
        except Exception:
            pass


class MPInferenceClient(object):
    """Client-side proxy compatible with InferenceServer.submit API (event + container)."""

    def __init__(self, request_queue: mp.Queue):
        """Remember the queue on which requests are enqueued."""
        self._request_q = request_queue

    def submit(self, board_tensor):
        """Enqueue one inference request and return ``(event, container)``.

        The event is set once ``container['out']`` has been filled in. The
        value is the reply received on the per-request pipe, or ``None`` when
        the request could not be enqueued or the pipe closed without a reply.

        Args:
            board_tensor: Payload forwarded unchanged on the request queue.

        Returns:
            Tuple of ``threading.Event`` and the result ``dict``.
        """
        # One-way pipe: we keep the receiving end, the sending end travels with the request.
        parent_conn, child_conn = mp.Pipe(duplex=False)
        ev = threading.Event()
        container = {}  # type: Dict[str, Any]

        try:
            self._request_q.put((board_tensor, child_conn))
        except Exception:
            # Enqueue failed: nobody will ever reply, so release BOTH pipe ends
            # explicitly rather than leaving the receiving end to the GC.
            for end in (parent_conn, child_conn):
                try:
                    end.close()
                except Exception:
                    pass
            container['out'] = None
            ev.set()
            return ev, container

        # NOTE: child_conn is deliberately not closed here; the queue may still
        # need to serialise it after put() returns.
        t = threading.Thread(target=self._await_reply, args=(parent_conn, ev, container), daemon=True)
        t.start()
        return ev, container

    @staticmethod
    def _wrap_reply(res):
        """Convert a reply dict to ``InferenceResult`` when possible, else return it unchanged."""
        try:
            # Convert dict to InferenceResult namedtuple for compatibility
            from src.alphazero.mcts import InferenceResult
            if isinstance(res, dict) and 'policy_logits' in res and 'value' in res:
                return InferenceResult(policy_logits=res['policy_logits'], value=res['value'])
        except Exception:
            # Import or construction failed: hand back the raw reply.
            pass
        return res

    @staticmethod
    def _await_reply(conn, ev, container):
        """Block until a reply (or EOF/error) arrives, store it, then set the event."""
        try:
            try:
                res = conn.recv()
            except Exception:
                # Includes EOFError: the sender closed without replying.
                container['out'] = None
            else:
                container['out'] = MPInferenceClient._wrap_reply(res)
        finally:
            try:
                conn.close()
            except Exception:
                pass
            ev.set()


def _build_model(cfg: _MPConfig) -> nn.Module:
    """Instantiate the network selected by ``cfg.model_type``.

    ``'cnn'`` yields an ``AlphaNet`` and ``'transformer'`` a
    ``TransformerAlphaNet``; any other value raises ``ValueError``. For the
    transformer, unset (falsy) ``embed_dim`` / ``num_layers`` / ``num_heads``
    fall back to 128 / 4 / 4, and the optional feature attributes are read
    with ``getattr`` so a config lacking them still works.
    """
    # Imported lazily so the parent process does not pull in the model module.
    from src.alphazero.models import AlphaNet, TransformerAlphaNet

    if cfg.model_type == 'cnn':
        return AlphaNet(board_shape=cfg.board_shape, action_size=cfg.action_size)

    if cfg.model_type != 'transformer':
        raise ValueError(f"Unknown model_type: {cfg.model_type}")

    transformer_kwargs = {
        'board_shape': cfg.board_shape,
        'action_size': cfg.action_size,
        'embed_dim': int(cfg.embed_dim or 128),
        'depth': int(cfg.num_layers or 4),
        'num_heads': int(cfg.num_heads or 4),
    }
    # Boolean feature switches, all defaulting to off.
    for flag in ('use_sinusoidal_2d_pe', 'use_relative_bias',
                 'enable_global_context', 'include_action_tokens'):
        transformer_kwargs[flag] = bool(getattr(cfg, flag, False))
    transformer_kwargs['action_token_in_dim'] = getattr(cfg, 'action_token_in_dim', None)
    transformer_kwargs['cross_attn_layers'] = int(getattr(cfg, 'cross_attn_layers', 1))
    return TransformerAlphaNet(**transformer_kwargs)


def _reply_and_close(conn, payload) -> None:
    """Send ``payload`` on ``conn`` and close it, ignoring errors from either step."""
    try:
        conn.send(payload)
    except Exception:
        # Best effort: the receiving side may already be gone.
        pass
    finally:
        try:
            conn.close()
        except Exception:
            pass


def _to_input_tensor(x, expected_C):
    """Convert one request payload to a float32 tensor, zero-padding channels when possible.

    Padding only happens for 3-D inputs whose first dimension is smaller than
    ``expected_C`` and divides it evenly; otherwise the tensor is returned
    as converted. Conversion errors propagate to the caller.
    """
    if isinstance(x, np.ndarray):
        t = torch.from_numpy(x).to(dtype=torch.float32)
    else:
        t = torch.as_tensor(x, dtype=torch.float32)
    try:
        if expected_C is not None and t.dim() == 3 and t.shape[0] != expected_C:
            c_in, h, w = t.shape
            if expected_C > c_in and (expected_C % c_in == 0):
                pad = torch.zeros((expected_C - c_in, h, w), dtype=t.dtype, device=t.device)
                t = torch.cat([t, pad], dim=0)
    except Exception:
        # Padding is opportunistic; on any problem keep the tensor unpadded.
        pass
    return t


def _broker_main(request_q: mp.Queue, control_q: mp.Queue, state_dict: Dict[str, Any], cfg: "_MPConfig"):
    """Broker process entry point: batch requests, run the model, reply per request.

    Loop behaviour:
      * Requests are ``(payload, conn)`` tuples read from ``request_q``; a
        ``None`` item ends the loop.
      * Control messages are ``(cmd, payload)`` tuples read from ``control_q``:
        ``'stop'`` ends the loop, ``'update'`` loads ``payload`` as a new
        state dict. Other commands are ignored.
      * A batch is flushed when ``cfg.max_batch_size`` requests are buffered
        or the oldest buffered request has waited ``cfg.max_batch_wait``
        seconds. Each request receives
        ``{'policy_logits': ..., 'value': float}`` or ``None`` on failure,
        and its connection is closed afterwards.
      * With ``cfg.auto_batch_wait`` set, ``cfg.max_batch_wait`` is re-tuned
        periodically from the observed request rate.

    On exit every request still buffered or still queued is answered with
    ``None`` so that no client is left without a reply.

    Args:
        request_q: Queue of incoming requests.
        control_q: Queue of control messages.
        state_dict: Initial model weights.
        cfg: Settings object (model, device and batching options).
    """
    import logging
    import queue

    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] broker: %(message)s")
    device = torch.device(cfg.device if (isinstance(cfg.device, str) and cfg.device) else ('cuda' if torch.cuda.is_available() else 'cpu'))
    try:
        net = _build_model(cfg)
        net.load_state_dict(state_dict)
        net.eval()
    except Exception:
        logging.exception("Failed to build/load model in broker; exiting")
        return
    try:
        net.to(device)
    except Exception:
        logging.exception("Failed to move model to %s; falling back to CPU", device)
        device = torch.device('cpu')
        try:
            net.to(device)
        except Exception:
            logging.exception("Failed to move model to CPU; exiting")
            return
    logging.info("Broker running on device %s (batch=%d wait=%.3fs)", device, cfg.max_batch_size, cfg.max_batch_wait)

    # A non-positive batch size would never drain the buffer; use at least 1.
    max_batch = max(1, int(cfg.max_batch_size))

    buffer = []    # pending (payload, conn) pairs, oldest first
    arrivals = []  # monotonic arrival time of each buffered request
    last_flush = time.monotonic()
    # Auto-tuning state
    lambda_ema = None  # requests per second
    alpha = 0.2
    last_adjust = time.monotonic()
    adjust_period = 2.0  # seconds
    # Reasonable clamps for wait window
    W_MIN, W_MAX = 0.001, 0.03  # 1–30 ms
    B_TARGET = 32 if str(getattr(cfg, 'model_type', 'cnn')) == 'cnn' else 64
    running = True
    while running:
        # Pull one request with a short timeout so control messages stay responsive.
        try:
            item = request_q.get(timeout=min(cfg.max_batch_wait, 0.01))
        except queue.Empty:
            pass
        except Exception:
            logging.exception("Failed to read from request queue")
            time.sleep(0.01)  # avoid a hot spin if the queue fails immediately
        else:
            if item is None:
                # sentinel from parent (optional)
                running = False
            else:
                try:
                    x, conn = item
                except (TypeError, ValueError):
                    logging.warning("Ignoring malformed request of type %s", type(item).__name__)
                else:
                    buffer.append((x, conn))
                    arrivals.append(time.monotonic())

        # Handle all pending control messages.
        while True:
            try:
                cmd, payload = control_q.get_nowait()
            except queue.Empty:
                break
            except Exception:
                logging.exception("Failed to read control message")
                break
            if cmd == 'stop':
                running = False
                break
            if cmd == 'update':
                try:
                    net.load_state_dict(payload)
                    net.eval()
                    logging.info("Weights updated in broker")
                except Exception:
                    logging.exception("Failed to update weights in broker")

        # Flush if the batch is full or the oldest request has waited long enough.
        now_ts = time.monotonic()
        if buffer and (len(buffer) >= max_batch or (now_ts - arrivals[0]) >= cfg.max_batch_wait):
            prev_flush = last_flush
            batch = buffer[:max_batch]
            del buffer[:len(batch)]
            del arrivals[:len(batch)]

            outputs = None
            try:
                expected_C = getattr(net, 'C', None)
                tensors = [_to_input_tensor(x, expected_C) for x, _ in batch]
                # Stack before touching the device: a shape mismatch is a bad
                # request, not a device problem, and must not trigger the
                # permanent CPU fallback below.
                bat = torch.stack(tensors, dim=0)
                try:
                    bat = bat.to(device, non_blocking=True)
                except Exception:
                    logging.exception("Failed to move batch to %s; falling back to CPU", device)
                    device = torch.device('cpu')
                    net.to(device)
                    bat = bat.to(device)
                with torch.no_grad():
                    p_logits, v = net(bat)
                outputs = (p_logits.detach().cpu().numpy(), v.detach().cpu().numpy())
            except Exception:
                logging.exception("Broker batch forward failed; sending None to clients")

            for i, (_, conn) in enumerate(batch):
                reply = None
                if outputs is not None:
                    try:
                        # .item() handles both scalar and single-element values
                        # without relying on implicit array-to-float conversion.
                        reply = {'policy_logits': outputs[0][i], 'value': float(outputs[1][i].item())}
                    except Exception:
                        logging.exception("Could not build reply for request %d of batch", i)
                _reply_and_close(conn, reply)

            # Update flush time and arrival rate estimates
            last_flush = now_ts
            dt = max(1e-3, last_flush - prev_flush)
            obs_lambda = len(batch) / dt
            lambda_ema = obs_lambda if (lambda_ema is None) else ((1 - alpha) * lambda_ema + alpha * obs_lambda)

            # Periodically adjust wait window if enabled
            if getattr(cfg, 'auto_batch_wait', False) and (last_flush - last_adjust) >= adjust_period:
                w_new = B_TARGET / max(lambda_ema, 1e-3)
                # Clamp and smooth minor oscillations
                w_new = min(max(W_MIN, w_new), W_MAX)
                if abs(w_new - cfg.max_batch_wait) > 0.001:
                    cfg.max_batch_wait = w_new
                    logging.info("Auto-tune: lambda≈%.1f req/s, B*= %d -> wait=%.3f s", lambda_ema, B_TARGET, cfg.max_batch_wait)
                last_adjust = last_flush

    # Requests accepted into the buffer but never flushed still need an answer.
    for _, conn in buffer:
        _reply_and_close(conn, None)

    # Drain any remaining queued requests with None.
    while True:
        try:
            item = request_q.get_nowait()
        except queue.Empty:
            break
        except Exception:
            logging.exception("Failed to drain request queue")
            break
        if item is None:
            continue  # leftover sentinel, nothing to answer
        try:
            _, conn = item
        except (TypeError, ValueError):
            continue
        _reply_and_close(conn, None)
    logging.info("Broker stopped")
