import time
import random
import numpy as np
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed  # NEW

try:
    from tqdm.auto import tqdm
except Exception:
    tqdm = None

from .mcts import MCTS, InferenceServer


class GameResult:
    """Outcome of a single arena game.

    Attributes:
        winner: 1 or -1 for the winning side, 0 for a draw.
        moves: number of moves applied to the board before the game ended.
        history: list of (player_to_move, action) pairs in play order.
    """

    def __init__(self, winner, moves, history):
        self.winner = winner     # 1, -1, or 0 for draw
        self.moves = moves
        self.history = history   # list of (player_to_move, action)

    def __repr__(self):
        # Show only the history length: full move lists can be very long.
        n_hist = len(self.history) if self.history is not None else 0
        return f"{type(self).__name__}(winner={self.winner!r}, moves={self.moves!r}, history_len={n_hist})"


class Player:
    """Base class for arena participants.

    Subclasses override ``select_action(game, board, to_play)`` and return
    the index of the action they want to play.
    """

    # Display name; Arena uses it as the results-dict key for this player.
    name = "Player"

    def select_action(self, game, board, to_play):
        """Return the action chosen for side ``to_play`` on ``board``.

        The base implementation is abstract and always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()


class RandomPlayer(Player):
    """Player that picks uniformly at random among the legal actions."""

    def __init__(self, name="Random"):
        self.name = name

    def select_action(self, game, board, to_play):
        """Return a uniformly random legal action, or -1 if there is none.

        Emptiness is tested with ``len()`` rather than truthiness. If
        ``legal_actions`` returns a NumPy array, ``bool(array)`` is ambiguous
        for several elements. For a single element it equals that element,
        so a lone legal action ``0`` would be mistaken for "no moves".
        """
        legal = game.legal_actions(board)
        if len(legal) == 0:
            return -1
        return int(random.choice(legal))


class MCTSPlayer(Player):
    """Synchronous MCTS player using the shared InferenceServer.

    The player owns one ``MCTS`` instance that is reused across moves;
    ``on_action_applied`` asks it to advance its root. When
    ``history_steps > 0`` the player also keeps a bounded deque of the
    canonical encodings of positions it has moved from. It hands that deque
    to the search as root history.
    """

    def __init__(self, game, inference, num_sims=50, temperature=0.0, name=None, history_steps: int = 0, featurizer=None, featurizer_config: dict | None = None,
                 shaping_config=None, phi_fn=None):
        """
        Args:
            game: game implementation (legal_actions, canonical_form, encode_board, ...).
            inference: inference backend, passed straight to ``MCTS``.
            num_sims: number of MCTS simulations per move.
            temperature: values <= 1e-6 play greedily. Otherwise search
                probabilities are raised to ``1 / temperature`` before sampling.
                Negative values are clamped to 0.
            name: display name; defaults to a summary of sims and temperature.
            history_steps: how many of our own past positions to feed as history.
            featurizer: optional pre-built featurizer forwarded to ``MCTS``.
            featurizer_config: used to build a ``TransformerFeaturizer`` when
                ``featurizer`` is None. It is also stored so the player can be
                re-created with matching input channels.
            shaping_config, phi_fn: forwarded unchanged to ``MCTS`` (stored for cloning).
        """
        self.game = game
        self.inference = inference
        self.num_sims = num_sims
        self.temperature = max(0.0, float(temperature))
        self.name = name or f"MCTS(sims={num_sims},T={self.temperature})"
        self.history_steps = max(0, int(history_steps or 0))
        self._featurizer = featurizer
        self._featurizer_config = dict(featurizer_config) if featurizer_config else None
        if self._featurizer is None and self._featurizer_config is not None:
            self._featurizer = self._build_featurizer(self._featurizer_config)
        # Shaping (stored for cloning)
        self._shaping_config = shaping_config
        self._phi_fn = phi_fn
        self._mcts = MCTS(
            self.game,
            self.inference,
            num_simulations=self.num_sims,
            history_steps=self.history_steps,
            featurizer=self._featurizer,
            shaping_config=self._shaping_config,
            phi_fn=self._phi_fn,
        )
        # Our own history of canonical encoded boards, most recent first.
        from collections import deque
        self._hist = deque(maxlen=self.history_steps)

    def _build_featurizer(self, cfg):
        """Build a ``TransformerFeaturizer`` from ``cfg``; log and return None on failure."""
        try:
            from .featurizer import TransformerFeaturizer as _TF
            return _TF(
                self.game,
                history_steps=cfg.get('history_steps', 0),
                include_steps_left_plane=cfg.get('include_steps_left_plane', False),
                include_repetition_plane=cfg.get('include_repetition_plane', False),
                include_since_damage_plane=cfg.get('include_since_damage_plane', False),
            )
        except Exception:
            # Continuing without a featurizer is easy to miss, and presumably
            # changes how MCTS encodes positions, so make the fallback visible.
            logging.getLogger("alphazero.arena").warning(
                "MCTSPlayer could not build featurizer from config %r; continuing without one",
                cfg,
                exc_info=True,
            )
            return None

    def select_action(self, game, board, to_play):
        """Search from ``board`` for side ``to_play`` and return an action index.

        Returns -1 when ``game.legal_actions(board)`` is empty. The search tree
        is reused across moves; ``MCTS.search`` is expected to rebuild it
        when it no longer matches ``board``.
        """
        if self.history_steps > 0:
            try:
                self._mcts.set_root_history(list(self._hist))
            except Exception:
                self._hist.clear()
        probs = self._mcts.search(board, to_play)  # expects canonicalization inside
        legal = game.legal_actions(board)
        # len() instead of truthiness: legal may be a NumPy array.
        if len(legal) == 0:
            return -1
        action = self._choose_action(probs, legal)
        # Record history on every path (greedy included) so the next search
        # sees the position we just moved from.
        self._record_history(game, board, to_play)
        return action

    def _choose_action(self, probs, legal):
        """Pick an action from ``probs`` restricted to ``legal`` according to the temperature."""
        probs = np.asarray(probs)
        p = np.zeros(probs.shape, dtype=np.float64)
        p[legal] = probs[legal]
        if self.temperature <= 1e-6:
            # Greedy over legal actions, breaking ties randomly.
            return int(max(legal, key=lambda a: (p[a], random.random())))
        peak = float(p.max())
        if peak <= 0:
            return int(random.choice(legal))
        # Scaling so the best action has weight 1 before exponentiating gives
        # the same distribution after normalisation. It also stops small
        # probabilities underflowing to an all-zero vector at low temperature.
        p = (p / peak) ** (1.0 / self.temperature)
        p /= p.sum()
        return int(np.random.choice(len(p), p=p))

    def _record_history(self, game, board, to_play):
        """Push the canonical encoding of ``board`` onto the history deque (if enabled)."""
        if self.history_steps <= 0:
            return
        try:
            canonical = game.canonical_form(board, to_play)
            self._hist.appendleft(game.encode_board(canonical))
        except Exception:
            # If anything goes wrong, reset history to avoid shape mismatches
            self._hist.clear()

    def on_action_applied(self, game, board_before, action, player_who_moved, next_board=None):
        """Advance the MCTS root past ``action``; on failure the tree is rebuilt at the next search."""
        try:
            self._mcts.advance_root(board_before, action, player_who_moved, next_board=next_board)
        except Exception:
            logging.exception("MCTSPlayer failed to advance root; will rebuild on next search")


class Arena:
    """Pit two players against each other on a given game.

    ``player1`` moves when ``to_play == 1`` and ``player2`` when
    ``to_play == -1``. A ``GameResult.winner`` of ``1``/``-1`` therefore
    credits ``player1``/``player2``, and ``0`` is a draw.
    """

    def __init__(self, game, player1, player2, verbose=False, max_moves=None, visualize=False, render_every=1, render_delay=0.0):
        """
        Args:
            game: game exposing get_initial_state, legal_actions, next_state
                and is_terminal; render and action_name are optional.
            player1, player2: objects with ``name`` and ``select_action``.
            verbose: print illegal-move fallbacks and game endings.
            max_moves: if truthy, declare a draw after this many moves.
            visualize: render the board and print each move.
            render_every: render only every N-th move (clamped to >= 1).
            render_delay: seconds to sleep after each render.
        """
        self.game = game
        self.p1 = player1
        self.p2 = player2
        self.verbose = verbose
        self.max_moves = max_moves
        # Visualization options
        self.visualize = bool(visualize)
        self.render_every = max(1, int(render_every))
        self.render_delay = float(render_delay)
        self.logger = logging.getLogger("alphazero.arena")

    def _fresh_player(self, p):
        """Return a per-game copy of ``p`` for use in a worker thread.

        ``MCTSPlayer`` and ``RandomPlayer`` are rebuilt from their settings.
        Any other player is returned as-is and is shared across threads.
        """
        # Recreate players for thread safety (MCTSPlayer holds stateful MCTS)
        if isinstance(p, MCTSPlayer):
            # Preserve history_steps and featurizer_config for channel alignment.
            # NOTE(review): a featurizer passed as an object without a
            # featurizer_config is not carried over (featurizer=None below).
            cfg = getattr(p, "_featurizer_config", None)
            return MCTSPlayer(
                self.game,
                p.inference,
                num_sims=p.num_sims,
                temperature=p.temperature,
                name=p.name,
                history_steps=p.history_steps,
                featurizer=None,
                featurizer_config=cfg,
                shaping_config=getattr(p, '_shaping_config', None),
                phi_fn=getattr(p, '_phi_fn', None),
            )
        if isinstance(p, RandomPlayer):
            return RandomPlayer(name=p.name)
        # Fallback: try to reuse (may not be thread-safe)
        return p

    def _maybe_render(self, board, move_idx):
        """Render ``board`` every ``render_every`` moves when visualization is on."""
        if not self.visualize:
            return
        if hasattr(self.game, 'render') and callable(getattr(self.game, 'render')):
            try:
                if move_idx % self.render_every == 0:
                    self.game.render(board)
                    if self.render_delay > 0:
                        time.sleep(self.render_delay)
            except Exception:
                self.logger.exception("Rendering failed on move %d", move_idx)
        else:
            # Only warn once per game start
            if move_idx == 0:
                self.logger.warning("Game does not implement render(state); visualization disabled.")

    def _maybe_show_action(self, to_play, action, player, move_idx):
        """When visualizing, print which action was taken on this move."""
        if not self.visualize:
            return
        try:
            side = "P1" if to_play == 1 else "P2"
            name = getattr(player, "name", side)
            # Prefer game-provided action name when available
            readable = None
            if hasattr(self.game, 'action_name') and callable(getattr(self.game, 'action_name')):
                try:
                    readable = self.game.action_name(action)
                except Exception:
                    readable = None
            if readable is None:
                readable = str(action)
            print(f"[Arena] Move {move_idx + 1}: {name} ({side}) played {readable} (action={action})")
        except Exception:
            self.logger.exception("Failed to display action on move %d", move_idx)

    def _notify_action(self, board_before, action, to_play, next_board):
        """Tell each distinct player about an applied move so it can advance its tree.

        When the same object plays both sides it is notified only once.
        Otherwise it would try to advance its root twice for a single move.
        """
        players = (self.p1,) if self.p2 is self.p1 else (self.p1, self.p2)
        for pl in players:
            cb = getattr(pl, 'on_action_applied', None)
            if callable(cb):
                try:
                    cb(self.game, board_before, action, to_play, next_board)
                except Exception:
                    self.logger.exception("Player '%s' on_action_applied failed", getattr(pl, 'name', '?'))

    def play_game(self, starting_player=1, seed=None):
        """Play a single game and return its ``GameResult``.

        Args:
            starting_player: values >= 0 let side 1 (player1) move first;
                negative values let side -1 (player2) move first.
            seed: if given, reseeds the process-global ``random`` and
                ``numpy.random`` generators before the game.

        A player that raises in ``select_action``, or whose action makes
        ``next_state`` raise, forfeits the game. An action that is not in
        ``legal_actions`` is replaced by a random legal one (or -1 if none).
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
        board = self.game.get_initial_state()
        to_play = 1 if starting_player >= 0 else -1
        history = []
        moves = 0

        # Render initial state
        self._maybe_render(board, moves)

        while True:
            player = self.p1 if to_play == 1 else self.p2
            try:
                action = player.select_action(self.game, board, to_play)
            except Exception:
                self.logger.exception("Player '%s' crashed while selecting action.", getattr(player, "name", "?"))
                # Opponent wins by forfeit
                return GameResult(winner=-to_play, moves=moves, history=history)

            legal = self.game.legal_actions(board)
            if action not in legal:
                # Fallback to a random legal move to keep the match going
                if self.verbose:
                    print(f"[Arena] Illegal action {action} by {getattr(player, 'name', '?')}. Picking random legal.")
                # len() instead of truthiness: legal may be a NumPy array.
                action = int(random.choice(legal)) if len(legal) > 0 else -1

            history.append((to_play, action))

            # Show the action taken by the current player when visualizing
            self._maybe_show_action(to_play, action, player, moves)

            # Apply move, then notify players so they can advance their trees
            prev_board = board
            try:
                board = self.game.next_state(prev_board, action, to_play)
            except Exception:
                self.logger.warning("Illegal next_state by '%s' (action=%s). Forfeit.", getattr(player, "name", "?"), action)
                return GameResult(winner=-to_play, moves=moves, history=history)

            # Notify after the move so we can pass next_board
            self._notify_action(prev_board, action, to_play, board)

            moves += 1

            # Render after move
            self._maybe_render(board, moves)

            terminal, winner = self.game.is_terminal(board)
            if terminal:
                if self.verbose:
                    print(f"[Arena] Game ended. Winner: {winner} after {moves} moves.")
                return GameResult(winner=winner, moves=moves, history=history)

            if self.max_moves and moves >= self.max_moves:
                if self.verbose:
                    print(f"[Arena] Reached max_moves={self.max_moves}. Declaring draw.")
                return GameResult(winner=0, moves=moves, history=history)

            to_play = -to_play

    def _run_games(self, starts, show_progress, num_workers, serial_desc, parallel_desc, failure_msg):
        """Play one game per entry of ``starts`` (the starting side) and tally outcomes.

        Args:
            starts: sequence of starting sides (1 or -1), one per game.
            show_progress: show a tqdm bar if tqdm is available.
            num_workers: ``None`` or <= 1 plays serially on this arena's
                players. Larger values run games in a thread pool with fresh
                per-game players (threads share the inference backend).
            serial_desc, parallel_desc: progress-bar labels for each mode.
            failure_msg: %-format log message (with the game index) used when
                a parallel game raises; that game is counted as a draw.

        Returns:
            dict ``{p1.name: wins, p2.name: wins, "draws": count}``.
        """
        p1_name, p2_name = self.p1.name, self.p2.name
        if p1_name == p2_name:
            self.logger.warning("Both players are named %r; their wins are merged under one results key.", p1_name)
        results = {p1_name: 0, p2_name: 0, "draws": 0}
        parallel = num_workers is not None and num_workers > 1
        desc = parallel_desc if parallel else serial_desc
        pbar = tqdm(total=len(starts), desc=desc) if (tqdm and show_progress) else None

        def _record(winner):
            if winner == 1:
                results[p1_name] += 1
            elif winner == -1:
                results[p2_name] += 1
            else:
                results["draws"] += 1
            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix(results)

        try:
            if not parallel:
                for start in starts:
                    _record(self.play_game(starting_player=start).winner)
                return results

            seed_base = int(time.time())

            def _one(i, start):
                # Build fresh players for this game (thread-safe)
                local_p1 = self._fresh_player(self.p1)
                local_p2 = self._fresh_player(self.p2)
                local_arena = Arena(self.game, local_p1, local_p2, verbose=False, max_moves=self.max_moves,
                                    visualize=self.visualize, render_every=self.render_every, render_delay=self.render_delay)
                try:
                    # NOTE: play_game reseeds the process-global RNGs, which all
                    # worker threads share. Per-game seeds therefore do not make
                    # parallel runs reproducible.
                    return local_arena.play_game(starting_player=start, seed=seed_base + i).winner
                except Exception:
                    self.logger.exception(failure_msg, i)
                    return 0

            with ThreadPoolExecutor(max_workers=num_workers) as ex:
                futures = [ex.submit(_one, i, start) for i, start in enumerate(starts)]
                for fut in as_completed(futures):
                    _record(fut.result())
            return results
        finally:
            # "is not None": a tqdm bar with total == 0 is falsy, so a plain
            # truthiness check would skip close().
            if pbar is not None:
                pbar.close()

    def play_games(self, n, alternate_colors=True, show_progress=True, num_workers=1):
        """Play ``n`` games and return ``{p1.name: wins, p2.name: wins, "draws": count}``.

        With ``alternate_colors`` even-indexed games start with player1 and
        odd-indexed games with player2. Otherwise player1 always starts.
        ``num_workers > 1`` runs games in parallel threads.
        """
        starts = [1 if (not alternate_colors or i % 2 == 0) else -1 for i in range(n)]
        return self._run_games(
            starts,
            show_progress=show_progress,
            num_workers=num_workers,
            serial_desc="Arena",
            parallel_desc=f"Arena x{num_workers}",
            failure_msg="Arena worker failed on game %d",
        )

    def play_games_balanced(self, n, alternate_colors=None, show_progress=True, num_workers=1):
        """
        Play n games ensuring each player moves first exactly n//2 times (balanced).
        Parameters:
            n (int): total games. If odd, one is discarded.
            alternate_colors: accepted for API compatibility (ignored).
            show_progress (bool)
            num_workers (int)
        Returns the same results dict as ``play_games``.
        """
        if n < 2:
            return self.play_games(n, alternate_colors=False, show_progress=show_progress, num_workers=num_workers)
        if n % 2 != 0:
            self.logger.warning("play_games_balanced: n=%d is odd; using %d games (discarding one).", n, n - 1)
            n = n - 1
        half = n // 2
        starts = [1] * half + [-1] * half  # First half P1 starts, second half P2 starts
        return self._run_games(
            starts,
            show_progress=show_progress,
            num_workers=num_workers,
            serial_desc="Arena(balanced)",
            parallel_desc=f"Arena(bal) x{num_workers}",
            failure_msg="Arena balanced worker failed on game %d",
        )


def pit(game, player1, player2, games=20, alternate_colors=True, verbose=False, num_workers=1, visualize=False, render_every=1, render_delay=0.0, max_moves=None):
    """Play ``games`` games between two players and return the results dict.

    The results dict has the form ``{player1.name: wins, player2.name: wins,
    "draws": count}``. A progress bar is shown only when ``verbose`` is False.
    ``max_moves`` (optional) declares a draw once a game reaches that many
    moves; the default ``None`` means no limit.
    """
    arena = Arena(game, player1, player2, verbose=verbose, max_moves=max_moves,
                  visualize=visualize, render_every=render_every, render_delay=render_delay)
    return arena.play_games(games, alternate_colors=alternate_colors, show_progress=not verbose, num_workers=num_workers)

def pit_balanced(game, player1, player2, games=20, verbose=False, num_workers=1,
                 visualize=False, render_every=1, render_delay=0.0, max_moves=None):
    """
    Convenience wrapper to run a balanced pit: first-move split evenly.

    An odd ``games`` count is reduced by one so each player starts exactly
    half of the games. ``max_moves`` (optional) declares a draw once a game
    reaches that many moves; the default ``None`` means no limit.
    """
    arena = Arena(game, player1, player2, verbose=verbose, max_moves=max_moves,
                  visualize=visualize, render_every=render_every, render_delay=render_delay)
    return arena.play_games_balanced(games, show_progress=not verbose, num_workers=num_workers)
