"""
Multiprocessing self-play worker entry point.

"""

import time
import numpy as np
from typing import Any, Optional
import multiprocessing as mp

from src.alphazero.mcts import MCTS, GameExample
try:
    from src.alphazero.shaping import ShapingConfig
    from src.alphazero import shaping as shaping_mod
except Exception:
    ShapingConfig = None  # type: ignore
    shaping_mod = None  # type: ignore
try:
    from src.alphazero.featurizer import TransformerFeaturizer
except Exception:
    TransformerFeaturizer = None  # type: ignore
from src.alphazero.mp_infer import MPInferenceClient


def _assign_outcomes(examples, result, last_player):
    """Return copies of *examples* with ``value`` filled in from the game result.

    ``result`` is the outcome from the point of view of ``last_player`` (the
    player who made the final move). Examples are recorded one per move with
    player ``1`` moving first and sides alternating, so the sign of the value
    flips on every step.
    """
    labelled = []
    mover = 1
    for ex in examples:
        value = result if mover == last_player else -result
        labelled.append(GameExample(board=ex.board, pi=ex.pi, value=value))
        mover = -mover
    return labelled


def _play_one_game(game_cls, mcts: MCTS, temperature: float = 1.0, history_steps: int = 0, featurizer=None):
    """Play one self-play game and return its value-labelled training examples.

    Args:
        game_cls: Game implementation exposing ``get_initial_state``,
            ``canonical_form``, ``encode_board``, ``legal_actions``,
            ``next_state`` and ``is_terminal``.
        mcts: Search object; ``search(board, player)`` must return a policy
            vector over actions.
        temperature: ``0`` selects the arg-max action; otherwise the action is
            sampled from ``probs ** (1 / temperature)`` after renormalisation.
        history_steps: Number of previous encoded boards stacked behind the
            current one when no featurizer is supplied.
        featurizer: Optional object providing ``reset``, ``set_root_context``,
            ``make_input`` and ``push``. When given, it builds the stored
            network input instead of the manual history stacking.

    Returns:
        A list of ``GameExample`` (one per move played) whose ``value`` is the
        final outcome from the point of view of the player to move in that
        example. If ``next_state`` raises, the game is scored as a loss for
        the player who attempted the move and the examples collected so far
        are returned.
    """
    from collections import deque

    n_hist = max(0, int(history_steps or 0))
    examples = []
    board = game_cls.get_initial_state()
    player = 1
    # Encoded canonical boards of previous moves, most recent first.
    hist = deque(maxlen=n_hist)
    # Byte keys of canonical boards already used as a search root (repetition detection).
    seen = set()
    feat = featurizer
    if feat is not None:
        try:
            feat.reset()
        except Exception:
            # Best effort: a featurizer without a usable reset() is still used.
            pass

    while True:
        # Provide history frames to MCTS
        if n_hist > 0:
            mcts.set_root_history(list(hist))

        # Repetition key for this root; stays None when it cannot be computed
        # (no featurizer, or the context update failed before producing it).
        key = None
        if feat is not None:
            try:
                steps_left_norm = 0.0
                if hasattr(game_cls, 'STEP_CHANNEL') and hasattr(game_cls, 'MAX_STEPS'):
                    try:
                        ch = int(getattr(game_cls, 'STEP_CHANNEL'))
                        max_steps = float(getattr(game_cls, 'MAX_STEPS') or 1.0)
                        val = float(board[ch, 0, 0])
                        steps_left_norm = max(0.0, min(1.0, val / max_steps))
                    except Exception:
                        steps_left_norm = 0.0
                try:
                    canonical_tmp = game_cls.canonical_form(board, player)
                except Exception:
                    canonical_tmp = board if player == 1 else -board
                try:
                    key = np.ascontiguousarray(canonical_tmp).tobytes()
                except Exception:
                    key = bytes(memoryview(np.ascontiguousarray(canonical_tmp)))
                feat.set_root_context(is_repetition=(key in seen), since_last_damage_norm=0.0, steps_left_norm=steps_left_norm)
            except Exception:
                # Root context is auxiliary; search proceeds without it.
                pass

        probs = mcts.search(board, player)
        if temperature == 0:
            action = int(np.argmax(probs))
        else:
            probs_temp = probs ** (1.0 / temperature)
            if probs_temp.sum() == 0:
                # Degenerate policy: fall back to uniform over legal actions.
                legal = game_cls.legal_actions(board)
                probs_temp = np.zeros_like(probs)
                probs_temp[legal] = 1.0
            probs_temp /= probs_temp.sum()
            action = int(np.random.choice(len(probs_temp), p=probs_temp))

        canonical = game_cls.canonical_form(board, player)
        board_enc = game_cls.encode_board(canonical)
        if feat is not None:
            stacked = feat.make_input(board_enc.astype(np.float32, copy=False))
            examples.append(GameExample(board=stacked, pi=probs.copy(), value=None))
            try:
                feat.push(board_enc)
            except Exception:
                pass
        elif n_hist > 0:
            # Current board first, then up to n_hist past boards, zero-padded
            # while the game is still shorter than the history window.
            planes = [board_enc.astype(np.float32, copy=False)]
            for i in range(n_hist):
                if i < len(hist):
                    planes.append(hist[i].astype(np.float32, copy=False))
                else:
                    planes.append(np.zeros(board_enc.shape, dtype=np.float32))
            stacked = np.concatenate(planes, axis=0)
            examples.append(GameExample(board=stacked, pi=probs.copy(), value=None))
            hist.appendleft(board_enc)
        else:
            examples.append(GameExample(board=board_enc, pi=probs.copy(), value=None))

        # Mark repetition only after this root has been used.
        if key is not None:
            seen.add(key)

        prev_board = board
        try:
            board = game_cls.next_state(board, action, player)
        except Exception:
            # Illegal transition: score the game as a loss for the acting
            # player and keep the examples gathered so far.
            return _assign_outcomes(examples, -1.0, player)
        try:
            mcts.advance_root(prev_board, action, player, next_board=board)
        except Exception:
            # Tree reuse is an optimisation; ignore failures to advance the root.
            pass

        terminal, winner = game_cls.is_terminal(board)
        if terminal:
            # result is expressed from the point of view of the last mover.
            if winner == 0:
                result = 0.0
            else:
                result = 1.0 if winner == player else -1.0
            return _assign_outcomes(examples, result, player)
        player = -player


def _set_coerced(target, name, value):
    """Set ``target.name`` to *value*, cast to the attribute's current type when possible."""
    try:
        setattr(target, name, type(getattr(target, name))(value))
    except Exception:
        setattr(target, name, value)


def run_selfplay_proc(request_queue: mp.Queue, out_queue: mp.Queue, game_cls: Any, num_sims: int = 50, temperature: float = 1.0, history_steps: int = 0, featurizer_config: dict | None = None,
                      shaping_config: Optional[dict] = None, phi_name: Optional[str] = None,
                      root_noise: Optional[dict] = None, board_rows: Optional[int] = None, board_cols: Optional[int] = None,
                      shoot_range: Optional[int] = None, shrink_interval: Optional[int] = None, num_obstacles: Optional[int] = None,
                      max_steps: Optional[int] = None, capture_steps: Optional[int] = None, max_health: Optional[int] = None):
    """Worker-process entry point: play self-play games forever and emit examples.

    Applies the optional game overrides to ``game_cls`` (and its base classes),
    builds an inference client, optional featurizer, optional shaping objects
    and an ``MCTS`` instance, then loops playing games and putting every
    resulting example on ``out_queue``.

    Args:
        request_queue: Queue handed to ``MPInferenceClient`` for inference requests.
        out_queue: Queue receiving one item per training example.
        game_cls: Game class; mutated in place by the overrides below.
        num_sims: Number of MCTS simulations per move.
        temperature: Action-sampling temperature passed to each game.
        history_steps: Number of history frames for MCTS / example stacking.
        featurizer_config: Optional dict with ``history_steps`` and the
            ``include_*_plane`` flags used to build a ``TransformerFeaturizer``.
        shaping_config: Optional keyword arguments for ``ShapingConfig``.
        phi_name: Optional attribute name looked up on the shaping module.
        root_noise: Optional dict with ``enable``, ``alpha`` and ``frac``.
        board_rows, board_cols, shoot_range, shrink_interval, num_obstacles,
        max_steps, capture_steps, max_health: Optional overrides; ``None``
            leaves the corresponding class attribute untouched.

    The function only returns when a ``KeyboardInterrupt`` is raised inside the
    loop. Any other per-game exception is logged and the loop retries after a
    short sleep.
    """
    import logging

    log = logging.getLogger(__name__)

    # Apply board geometry overrides inside the child process (spawn doesn't carry class attr mutations)
    # Apply overrides; if configure() fails or is absent, fall back to direct attribute writes.
    override_kwargs = {
        k: v for k, v in dict(
            board_rows=board_rows, board_cols=board_cols,
            SHOOT_RANGE=shoot_range, SHRINK_INTERVAL=shrink_interval,
            NUM_OBSTACLES=num_obstacles, MAX_STEPS=max_steps,
            CAPTURE_STEPS=capture_steps, MAX_HEALTH=max_health,
        ).items() if v is not None
    }
    used_configure = False
    if override_kwargs and callable(getattr(game_cls, 'configure', None)):
        try:
            game_cls.configure(**override_kwargs)
            used_configure = True
        except Exception:
            # Fall through to direct setattr below
            log.debug("game_cls.configure(%r) failed; falling back to setattr", override_kwargs, exc_info=True)
    if not used_configure:
        try:
            if (board_rows is not None) and hasattr(game_cls, 'ROWS'):
                setattr(game_cls, 'ROWS', int(board_rows))
            if (board_cols is not None) and hasattr(game_cls, 'COLS'):
                setattr(game_cls, 'COLS', int(board_cols))
            for attr, val in [
                ('SHOOT_RANGE', shoot_range), ('SHRINK_INTERVAL', shrink_interval),
                ('NUM_OBSTACLES', num_obstacles), ('MAX_STEPS', max_steps),
                ('CAPTURE_STEPS', capture_steps), ('MAX_HEALTH', max_health),
            ]:
                if val is not None and hasattr(game_cls, attr):
                    _set_coerced(game_cls, attr, val)
        except Exception:
            # As a last resort, leave defaults (may cause mismatches caught downstream)
            log.warning("could not apply game overrides to %r", game_cls, exc_info=True)
    # Also propagate overrides to base classes that expose the same attributes, because
    # some game implementations reference parent static attributes directly.
    try:
        bases = getattr(game_cls, '__mro__', ())[1:]  # exclude self
        attrs = {
            'ROWS': board_rows,
            'COLS': board_cols,
            'SHOOT_RANGE': shoot_range,
            'SHRINK_INTERVAL': shrink_interval,
            'NUM_OBSTACLES': num_obstacles,
            'MAX_STEPS': max_steps,
            'CAPTURE_STEPS': capture_steps,
            'MAX_HEALTH': max_health,
        }
        for base in bases:
            if base is object:
                continue
            for name, val in attrs.items():
                if val is None or not hasattr(base, name):
                    continue
                try:
                    _set_coerced(base, name, val)
                except Exception:
                    # Some bases may be immutable; skip them.
                    pass
    except Exception:
        pass
    client = MPInferenceClient(request_queue)
    # Build local featurizer if configured
    feat = None
    if featurizer_config is not None and TransformerFeaturizer is not None:
        try:
            feat = TransformerFeaturizer(
                game_cls,
                history_steps=int(featurizer_config.get('history_steps', 0) or 0),
                include_steps_left_plane=bool(featurizer_config.get('include_steps_left_plane', False)),
                include_repetition_plane=bool(featurizer_config.get('include_repetition_plane', False)),
                include_since_damage_plane=bool(featurizer_config.get('include_since_damage_plane', False)),
            )
        except Exception:
            # Continuing without the featurizer changes the example format,
            # so make the failure visible instead of dropping it silently.
            log.warning("TransformerFeaturizer construction failed; continuing without featurizer", exc_info=True)
            feat = None
    # Resolve shaping if provided
    phi_fn = None
    cfg_obj = None
    if shaping_config and ShapingConfig is not None:
        try:
            cfg_obj = ShapingConfig(**shaping_config)
        except Exception:
            log.warning("invalid shaping_config %r; shaping config disabled", shaping_config, exc_info=True)
            cfg_obj = None
    if phi_name and shaping_mod is not None:
        try:
            phi_fn = getattr(shaping_mod, phi_name)
        except Exception:
            log.warning("shaping function %r not found; phi disabled", phi_name)
            phi_fn = None
    rn_enable = True if root_noise is None else bool(root_noise.get('enable', True))
    rn_alpha = 0.3 if root_noise is None else float(root_noise.get('alpha', 0.3))
    rn_frac = 0.25 if root_noise is None else float(root_noise.get('frac', 0.25))
    mcts = MCTS(game_cls, client, num_simulations=num_sims, add_root_noise=rn_enable,
                root_dirichlet_alpha=rn_alpha, root_exploration_frac=rn_frac,
                history_steps=max(0, int(history_steps or 0)), featurizer=feat,
                shaping_config=cfg_obj, phi_fn=phi_fn)
    # Consecutive failed games; used only to rate-limit the error log.
    failures = 0
    while True:
        try:
            examples = _play_one_game(game_cls, mcts, temperature=temperature, history_steps=history_steps, featurizer=feat)
            # Tolerate a game that produced no examples (None or empty).
            for ex in examples or ():
                out_queue.put(ex)
            failures = 0
        except KeyboardInterrupt:
            break
        except Exception:
            failures += 1
            # Log the first failure of a streak and then every 100th, so a
            # permanently broken worker is visible without flooding the log.
            if failures == 1 or failures % 100 == 0:
                log.exception("self-play game failed (%d consecutive failure(s)); retrying", failures)
            time.sleep(0.05)
