import numpy as np
import chess
import chess.engine
import chess.pgn
import os
from tqdm import tqdm
from utils import _z_to_wp, _softmax, SF_PATH, CFG_DIR
import sys
from multiprocessing import Pool
import zstandard as zstd


class MatchBoard(chess.Board):
    """A ``chess.Board`` that also carries per-game match bookkeeping.

    Attributes:
        history: one dict per examined ply, filled in by the match code
            (e.g. FEN key, legal moves, network outputs, chosen move).
        winner: final score once the game is decided, ``None`` while ongoing.

    Note: as with ``chess.Board``, ``fen=None`` yields an empty board; pass
    ``chess.STARTING_FEN`` for the initial position.
    """

    def __init__(self, fen=None, *, chess960: bool = False):
        # Accept and forward ``chess960``: python-chess's own constructors
        # (e.g. ``Board.empty``) instantiate ``type(self)(None, chess960=...)``.
        super().__init__(fen, chess960=chess960)
        self.history: list[dict] = []
        self.winner: float | None = None


def get_zero_fen(board: MatchBoard):
    """Return the board's FEN with the last two fields (move counters) removed.

    Positions that differ only in the halfmove clock / fullmove number share
    the same key; ``Match`` uses this string to look up its ``board_cache``.
    """
    fields = board.fen().split(" ")
    del fields[-2:]
    return " ".join(fields)


class SF_Match:
    """Stockfish-vs-Stockfish games with temperature-sampled moves.

    Every move is drawn from ``_softmax`` (with this player's ``temp``) over
    the ``_z_to_wp`` values of all legal moves, taken from a MultiPV search
    limited to this player's ``nodes``. Index 0 of ``self.sf`` / ``self.temp``
    / ``self.nodes`` configures player 0, index 1 player 1; the ``*2``
    arguments default to player 0's settings.

    Args:
        sf, sf2: UCI engines for players 0 and 1 (``sf2`` defaults to ``sf``).
        temp, temp2: sampling temperatures for players 0 and 1.
        n: games per half match.
        max_game: maximum plies per game; unfinished games count as draws.
        nodes, nodes2: node limits per move for players 0 and 1.
        seed: RNG seed; ``None`` draws a fresh one from ``os.urandom``.
        tqdm_kwargs: per-game progress-bar options (default ``{"ncols": 80}``).
    """

    def __init__(
        self,
        sf: chess.engine.SimpleEngine,
        sf2=None,
        temp=1,
        temp2=None,
        n=1,
        max_game=600,
        nodes=1000000,
        nodes2=None,
        seed: int | None = 0,
        tqdm_kwargs: dict | None = None,
    ):
        # Per-player settings: index 0 is player 0, index 1 is player 1.
        self.sf = [sf, sf if sf2 is None else sf2]
        self.temp = [temp, temp if temp2 is None else temp2]
        self.nodes = [nodes, nodes if nodes2 is None else nodes2]
        self.n = n
        self.max_game = max_game
        self.seed = seed
        if not tqdm_kwargs:
            self.tqdm_kwargs = {"ncols": 80}
        else:
            self.tqdm_kwargs = tqdm_kwargs

    def _reset_rngs(self) -> int:
        """(Re)seed the RNGs and return the seed actually used.

        ``self.seed`` is used when set; otherwise a fresh seed is drawn from
        ``os.urandom`` so that it can still be recorded with the games.
        """
        if self.seed is None:
            seed = int.from_bytes(os.urandom(4), byteorder="little")
        else:
            seed = self.seed
        self._rng_choice = np.random.default_rng(seed)
        self._rng_model = np.random.default_rng(seed)  # samples the moves
        return seed

    def _get_move(self, board: chess.Board, turn: int):
        """Sample a move for player ``turn`` from a full MultiPV analysis.

        Returns:
            ``(uci, best_val, temp_val)``: the sampled move, the highest
            ``_z_to_wp`` value among legal moves, and the expected value under
            the sampling distribution.
        """
        num_legal = len(list(board.legal_moves))
        result = self.sf[turn].analyse(
            board, chess.engine.Limit(nodes=self.nodes[turn]), multipv=num_legal
        )

        ucis, wps = [], []
        for info in result:
            assert "pv" in info and "score" in info
            ucis.append(info["pv"][0].uci())
            wdl = np.array(list(info["score"].wdl()))
            wdl = wdl / wdl.sum()
            wps.append(_z_to_wp(wdl))
        wps = np.array(wps)
        wps_probs = _softmax(wps, self.temp[turn])

        best_val = wps.max()
        temp_val = wps @ wps_probs
        move = self._rng_model.choice(ucis, p=wps_probs)

        return move, best_val, temp_val

    def run_half_match(self, reverse=False, savepath=None):
        """Play ``self.n`` games from the standard start position.

        The RNGs are reseeded at the start of every call.

        Args:
            reverse: if true, player 1 moves first (plays White).
            savepath: if given, one line per game is appended to this
                zstd-compressed CSV file.

        Returns:
            One list of per-move record dicts per game.
        """
        seed = self._reset_rngs()

        reverse = int(reverse)
        data = []
        for i in range(self.n):
            board = chess.Board()
            data_i = []
            # Reset per game: each saved line must hold only this game's moves.
            moves = []
            for movenum in tqdm(
                range(self.max_game),
                **self.tqdm_kwargs,
                desc=f"Match {i}/{self.n}",
            ):
                turn = (movenum + reverse) % 2
                move, best_val, temp_val = self._get_move(board, turn)
                data_i.append(
                    {
                        "seed": seed,
                        # Settings of the player moving first (White) and
                        # second; both pairs are indexed the same way.
                        "temp1": self.temp[reverse],
                        "temp2": self.temp[not reverse],
                        "nodes1": self.nodes[reverse],
                        "nodes2": self.nodes[not reverse],
                        "i": i,
                        "movei": movenum,
                        "uci": move,
                        "best_val": round(best_val, 4),
                        "temp_val": round(temp_val, 4),
                    }
                )
                moves.append(move)
                board.push_uci(move)
                if board.is_game_over():
                    break
            # Games cut off at max_game plies are scored as draws.
            outcome = board.result() if board.is_game_over() else "1/2-1/2"
            result = 1 if outcome == "1-0" else 0 if outcome == "0-1" else 0.5
            data.append(data_i)

            if savepath:
                with zstd.open(savepath, "ab") as f:
                    # NOTE: the list reprs of self.nodes / self.temp contain
                    # commas, so readers must not split this line naively;
                    # the format is kept for compatibility with existing files.
                    line = f"{seed},{chess.STARTING_FEN},{' '.join(moves)},{self.nodes},{self.temp},{result}\n"
                    f.write(line.encode("utf-8"))

        return data

    def run_full_match(self, savepath=None):
        """Play a half match with each player moving first.

        Returns:
            A list of per-game record lists: first the ``reverse=False``
            games, then the ``reverse=True`` ones. Games generally differ in
            length, so the lists are concatenated directly rather than with
            ``np.concatenate``, which fails on ragged input.
        """
        data1 = self.run_half_match(False, savepath)
        data2 = self.run_half_match(True, savepath)
        return data1 + data2


if __name__ == "__main__":
    # Usage: <script> TEMP NODES SEED_START [NUM_SEEDS=100]
    # Plays one Stockfish self-play game per seed in
    # [SEED_START, SEED_START + NUM_SEEDS). Each game is appended as one line
    # to a zstd-compressed CSV named after TEMP, NODES and SEED_START.
    if len(sys.argv) < 4:
        sys.exit(f"usage: {sys.argv[0]} TEMP NODES SEED_START [NUM_SEEDS]")
    temp = float(sys.argv[1])
    nodes = int(sys.argv[2])
    s0 = int(sys.argv[3])
    s1 = s0 + (int(sys.argv[4]) if len(sys.argv) > 4 else 100)

    print(temp, nodes, s0, s1)

    out_path = f"/storage1/fs1/XXXX-1/Active/chess/data/sf_matches/{temp}_{nodes}_{s0}.csv.zst"
    for seed in tqdm(range(s0, s1)):
        # A fresh engine per seed, rather than one shared engine, presumably
        # keeps each game independent of search state (e.g. the hash table)
        # left behind by earlier games.
        with chess.engine.SimpleEngine.popen_uci(SF_PATH) as engine:
            SF_Match(
                engine,
                temp=temp,
                nodes=nodes,
                n=1,
                seed=seed,
                tqdm_kwargs={"disable": True},
            ).run_half_match(savepath=out_path)


class Match:
    """Batched games between two networks (or the string ``"random"``).

    In a half match, all boards advance in lockstep. The first network moves
    on even plies and the second on odd plies. ``run_full_match`` plays a
    second half with the roles swapped.

    Each ``MatchBoard`` gets one ``history`` dict per examined ply. The keys
    are ``"fen"``, ``"legals"``, network outputs ``"y"``/``"z"``/``"m"``,
    ``"choice"``, and optional assistant keys. The board's final ``winner``
    is the score of the player who moved first in that half (1 win,
    0.5 draw, 0 loss).

    Args:
        tfp1, tfp2: objects exposing ``model(arrs)``, or ``"random"``.
        tfp_assistant: optional network consulted via ``set_asst_policy``.
        asst_thresh: if set, on the first player's plies the assistant's move
            replaces the policy whenever its expected score beats the
            player's own value estimate by more than this. If ``None``, the
            assistant's move replaces the policy on ply 0 only.
        first_moves: optional UCI moves (one per board) played on ply 0
            instead of sampling.
        n: number of boards when ``boards`` is not a list.
        seed: RNG seed; ``None`` draws one from ``os.urandom``.
        temp: sampling temperature; falsy means greedy (argmax).
        max_game: maximum plies before an undecided game counts as a draw.
        boards: start position(s); default is the standard start position.
        board_cache: optional mapping from ``get_zero_fen`` keys to cached
            ``(y, z, m)`` network outputs.
        batch_size: maximum positions per model call (``None`` = all at once).
        tqdm_kwargs: progress-bar options; the default depends on ``verbose``.
        verbose: show the progress bar when ``tqdm_kwargs`` is not given.
    """

    _POLICY_SIZE = 1858  # length of a policy ("y") vector

    def __init__(
        self,
        tfp1,
        tfp2,
        tfp_assistant=None,
        asst_thresh: float | None = None,
        first_moves: list[str] | None = None,
        n=1,
        seed: int | None = 0,
        temp=1,
        max_game=600,
        boards: chess.Board | list[chess.Board] | None = None,
        board_cache=None,
        batch_size=None,
        tqdm_kwargs: dict | None = None,
        verbose=False,
    ):
        self.tfp1 = tfp1
        self.tfp2 = tfp2
        self.tfp_assistant = tfp_assistant
        self.first_moves = first_moves
        self.asst_thresh = asst_thresh
        self.n = n
        self.seed = seed
        self.temp = temp
        self.max_game = max_game
        self.board_cache = board_cache if board_cache else {}
        self.boards = boards
        self.batch_size = batch_size
        if tqdm_kwargs is None:
            if verbose:
                self.tqdm_kwargs = {"ncols": 80}
            else:
                self.tqdm_kwargs = {"disable": True}
        else:
            self.tqdm_kwargs = tqdm_kwargs

    def _reset_rngs(self):
        """Seed the choice and model RNGs (from ``os.urandom`` if ``seed`` is None)."""
        if self.seed is None:
            seed = int.from_bytes(os.urandom(4), byteorder="little")
        else:
            seed = self.seed
        self._rng_choice = np.random.default_rng(seed)
        self._rng_model = np.random.default_rng(seed)

    def _initial_boards(self) -> list[MatchBoard]:
        """Build fresh ``MatchBoard``s from ``self.boards`` (or the start position)."""
        if isinstance(self.boards, chess.Board):
            fen = self.boards.fen()
            return [MatchBoard(fen) for _ in range(self.n)]
        if isinstance(self.boards, list):
            return [MatchBoard(board.fen()) for board in self.boards]
        return [MatchBoard(chess.STARTING_FEN) for _ in range(self.n)]

    def _get_results(self, tfp, arrs):
        """Return one ``{"y", "z", "m"}`` dict per row of ``arrs``.

        For ``tfp == "random"`` the policy is noise from the model RNG and
        ``z``/``m`` are ``None``. Otherwise ``tfp.model`` is called in chunks
        of ``self.batch_size`` rows.
        """
        if tfp == "random":
            return [
                {"y": self._rng_model.random(self._POLICY_SIZE), "z": None, "m": None}
                for _ in range(len(arrs))
            ]

        bs = len(arrs) if self.batch_size is None else self.batch_size

        res: list[dict[str, np.ndarray | None]] = []
        for start in range(0, len(arrs), bs):
            out = tfp.model(arrs[start : start + bs])
            for y, z, m in zip(*out[:3]):
                res.append({"y": y.numpy(), "z": z.numpy(), "m": m.numpy()})
        return res

    def _get_choices(self, preds: np.ndarray, legals: np.ndarray, verbose=False):
        """Pick one policy index per row, restricted to legal entries.

        With a truthy ``self.temp`` the index is sampled (using the choice RNG)
        from the temperature-scaled softmax over the legal entries. Otherwise
        the highest legal logit is taken. ``verbose`` is unused.
        """
        masked = np.where(legals, preds, -np.inf)
        if not self.temp:
            return np.argmax(masked, axis=1)

        probs = _softmax(masked / self.temp)
        probs = np.where(legals, probs, 0)
        # Renormalise each row into a distribution over its legal entries.
        probs = probs / probs.sum(axis=1, keepdims=True)
        cdf = np.cumsum(probs, axis=1)
        # Scale r slightly below 1 so rounding in the cumulative sum cannot
        # push it past the last entry. Counting entries with cdf <= r picks
        # the first index whose cdf exceeds r. That index always has positive
        # probability (i.e. is legal), even when r == 0.
        r = self._rng_choice.random(size=(probs.shape[0], 1)) * 0.999999
        return np.sum(cdf <= r, axis=1)

    def _push_policies(self, boards: list[MatchBoard], moves=None):
        """Play one move on every board and record it as ``"choice"``.

        If ``moves`` is given, those UCI moves are played (one per board).
        Otherwise each board's latest ``"y"`` policy is masked to its legal
        moves and passed to ``_get_choices``.
        """
        from data import board_to_uci

        if moves is not None:
            for board, move in zip(boards, moves):
                board.history[-1]["choice"] = move
                board.push_uci(move)
            return

        legals = np.zeros((len(boards), self._POLICY_SIZE), dtype=int)
        index_to_uci = [{} for _ in boards]
        for row, board in enumerate(boards):
            legal_list = board.history[-1]["legals"]
            idxs = board_to_uci(board, legal_list)
            legals[row, idxs] = 1
            for idx, uci in zip(idxs, legal_list):
                index_to_uci[row][idx] = uci

        policies = np.stack([board.history[-1]["y"] for board in boards])

        choices = self._get_choices(policies, legals)
        for board, choice, lookup in zip(boards, choices, index_to_uci):
            uci = lookup[choice]
            board.history[-1]["choice"] = uci
            board.push_uci(uci)

    def _record_positions(self, boards: list[MatchBoard], turn: int):
        """Append a history entry to each board and settle finished games.

        Non-terminal boards get their network outputs from ``board_cache`` when
        possible. Returns ``(misses, arrs)``: the boards still needing an
        evaluation and their encoded inputs, in matching order.
        """
        from data import board_to_arr

        misses, arrs = [], []
        for board in boards:
            entry = {"fen": get_zero_fen(board)}
            board.history.append(entry)
            entry["legals"] = sorted(m.uci() for m in board.legal_moves)
            if (
                board.is_fifty_moves()
                or board.is_insufficient_material()
                or board.is_fivefold_repetition()
                or (not entry["legals"] and not board.is_check())  # stalemate
            ):
                board.winner = 0.5
            elif not entry["legals"] and board.is_check():
                # Checkmate: player `turn` is to move and lost, so the first
                # player's score is 0 when turn == 0 and 1 when turn == 1.
                board.winner = turn
            else:
                hit = self.board_cache.get(entry["fen"]) if self.board_cache else None
                if hit:
                    entry["y"], entry["z"], entry["m"] = hit[0], hit[1], hit[2]
                else:
                    # Collected in the same pass so results align with boards.
                    misses.append(board)
                    arrs.append(board_to_arr(board)[0])
        return misses, arrs

    def _apply_assistant(self, movenum: int, ongoing_boards: list[MatchBoard]):
        """Let ``tfp_assistant`` override the current ply's policies.

        The replaced policy is kept under ``"y_orig"``.
        """
        if not self.tfp_assistant:
            return
        if self.asst_thresh is not None:
            if movenum % 2 != 0:
                return
            set_asst_policy(self.tfp_assistant, ongoing_boards)
            for board in ongoing_boards:
                entry = board.history[-1]
                z = _softmax(entry["z"])
                z = z[0] + 0.5 * z[1]
                z_asst = entry["z_asst"]
                if np.ndim(z_asst) > 0:
                    # set_asst_policy stores the raw value-head row for the
                    # position after the assistant's move, i.e. as seen by the
                    # opponent. Reduce it to the mover's expected score with
                    # the same formula set_asst_policy ranks moves by
                    # (index 2 + 0.5 * index 1). Comparing the raw vector
                    # against the threshold would raise.
                    p = _softmax(z_asst)
                    z_asst = p[2] + 0.5 * p[1]
                if (z_asst - z) > self.asst_thresh:
                    entry["y_orig"] = entry["y"]
                    entry["y"] = entry["y_asst"]
        elif movenum == 0:
            # Only ongoing boards: finished ones have no legal moves and no
            # "y" for set_asst_policy to work with.
            set_asst_policy(self.tfp_assistant, ongoing_boards)
            for board in ongoing_boards:
                entry = board.history[-1]
                entry["y_orig"] = entry["y"]
                entry["y"] = entry["y_asst"]

    def _run_half_match(self, tfp1, tfp2):
        """Play one half match in which ``tfp1`` moves first on every board.

        Returns the boards; each ``winner`` is ``tfp1``'s score.
        """
        self._reset_rngs()
        boards = self._initial_boards()
        ongoing_boards = boards

        tfps = [tfp1, tfp2]
        pbar = tqdm(range(self.max_game), **self.tqdm_kwargs)  # type: ignore ; just raise error if tqdm_kwargs is invalid
        for movenum in pbar:
            turn = movenum % 2
            misses, arrs = self._record_positions(ongoing_boards, turn)
            ongoing_boards = [board for board in boards if board.winner is None]
            if not ongoing_boards:
                pbar.set_description(
                    f"Finished {len(boards) - len(ongoing_boards)} games",
                    refresh=False,
                )
                break
            if arrs:
                results = self._get_results(tfps[turn], np.stack(arrs))
                for result, board in zip(results, misses):
                    entry = board.history[-1]
                    entry["y"] = result["y"]
                    entry["z"] = result["z"]
                    entry["m"] = result["m"]

            self._apply_assistant(movenum, ongoing_boards)

            if self.first_moves and movenum == 0:
                self._push_policies(ongoing_boards, self.first_moves)
            else:
                self._push_policies(ongoing_boards)
            pbar.set_description(f"{len(ongoing_boards)} left", refresh=False)

        # Boards still undecided after max_game plies: the last pushed move was
        # never examined. A checkmate there is a loss for the side to move,
        # player max_game % 2. Anything else counts as a draw.
        for board in boards:
            if board.winner is None:
                board.winner = self.max_game % 2 if board.is_checkmate() else 0.5
        return boards

    def run_half_match(self):
        """Play one half (``tfp1`` first); return ``(tfp1 mean score, boards)``."""
        boards1 = self._run_half_match(self.tfp1, self.tfp2)
        res = np.mean([board.winner for board in boards1 if board.winner is not None])
        return res, boards1

    def run_full_match(self):
        """Play both halves; return ``(tfp1 mean score, boards1, boards2)``."""
        boards1 = self._run_half_match(self.tfp1, self.tfp2)
        boards2 = self._run_half_match(self.tfp2, self.tfp1)
        res = np.mean([board.winner for board in boards1 if board.winner is not None])
        # In the second half `winner` is tfp2's score, so flip it.
        res += 1 - np.mean(
            [board.winner for board in boards2 if board.winner is not None]
        )
        return res / 2, boards1, boards2


def set_asst_policy(tfp, boards: list[MatchBoard], batch_size: int = 1024):
    """Let an assistant network pick one move per board by one-ply lookahead.

    Every legal move is played, the resulting position is evaluated with
    ``tfp.model``, and the move with the best expected score for the mover is
    chosen. The latest ``history`` entry of each board gains:

      * ``"y_asst"``: a policy shaped like ``"y"`` with 1 at the chosen move;
      * ``"z_asst"``: the assistant's raw value-head row for the position
        after the chosen move;
      * ``"intervention"``: the chosen move (UCI);
      * ``"m_asst"``: the third model output for that position, plus one.

    Boards without legal moves are skipped.

    Args:
        tfp: object whose ``model(arrs)`` returns four outputs, the second and
            third being the value and ``m`` heads.
        boards: boards whose ``history[-1]`` already holds ``"legals"`` and ``"y"``.
        batch_size: maximum number of positions per ``tfp.model`` call.
    """
    from data import board_to_arr, board_to_uci

    arrs = []
    for board in boards:
        for uci in board.history[-1]["legals"]:
            board.push_uci(uci)
            arrs.append(board_to_arr(board)[0])
            board.pop()
    if not arrs:
        return
    arrs = np.stack(arrs)

    z, m = [], []
    for start in range(0, len(arrs), batch_size):
        _, z_batch, m_batch, _ = tfp.model(arrs[start : start + batch_size])
        z.append(z_batch.numpy())
        m.append(m_batch.numpy())
    z, m = np.concatenate(z), np.concatenate(m)

    # Row-wise (numerically stable) softmax of the value logits. The previous
    # code computed a softmax and then overwrote it, ranking moves on raw
    # logits. Each row describes the position after a candidate move, so,
    # presumably with (win, draw, loss) order from the side to move, the
    # mover's expected score is P[:, 2] + 0.5 * P[:, 1].
    probs = np.exp(z - z.max(axis=1, keepdims=True))
    probs = probs / probs.sum(axis=1, keepdims=True)
    scores = probs[:, 2] + 0.5 * probs[:, 1]

    offset = 0
    for board in boards:
        entry = board.history[-1]
        legals = entry["legals"]
        if not legals:
            continue
        best = int(np.argmax(scores[offset : offset + len(legals)]))
        y = np.zeros_like(entry["y"])
        y[board_to_uci(board, legals[best])] = 1
        entry["y_asst"] = y
        entry["z_asst"] = z[offset + best]
        entry["intervention"] = legals[best]
        entry["m_asst"] = m[offset + best] + 1

        offset += len(legals)


def run_match_wrapper(kwargs):
    """Load networks from weight paths if needed and run a full ``Match``.

    ``kwargs`` holds ``Match`` keyword arguments plus:

      * ``tfp1`` / ``tfp2``: a network object, ``"random"``, or a weights path
        string, which is loaded into a fresh ``TFProcess``;
      * optional ``cfg``: the ``TFProcess`` config for path-loaded networks
        (default: ``t74-config.yaml`` in ``CFG_DIR``). It is never forwarded
        to ``Match``.

    The caller's dict is not modified. Returns ``Match.run_full_match()``.
    """
    import copy

    from utils import import_tf

    import_tf()
    from nets.tfprocess import TFProcess
    import yaml

    kwargs = dict(kwargs)  # pop from a copy so callers can reuse their dicts
    tfp1 = kwargs.pop("tfp1")
    tfp2 = kwargs.pop("tfp2")
    cfg = kwargs.pop("cfg", None)  # Match does not accept a `cfg` argument

    def is_weights_path(tfp):
        return isinstance(tfp, str) and tfp != "random"

    if cfg is None and any(is_weights_path(t) for t in (tfp1, tfp2)):
        with open(f"{CFG_DIR}/t74-config.yaml", "r") as f:
            cfg = yaml.safe_load(f)

    def load(tfp):
        if not is_weights_path(tfp):
            return tfp
        # Each network gets its own copy in case TFProcess mutates the config.
        net = TFProcess(copy.deepcopy(cfg))
        net.init_net()
        net.replace_weights(tfp, ignore_errors=True)
        return net

    return Match(tfp1=load(tfp1), tfp2=load(tfp2), **kwargs).run_full_match()


def run_multi_match(kwargs_list, workers=1):
    """Apply ``run_match_wrapper`` to each kwargs dict, in order.

    With ``workers > 1`` the matches run in a process pool of that size;
    otherwise they run sequentially in this process. Returns the list of
    per-match results.
    """
    if workers > 1:
        with Pool(workers) as pool:
            return pool.map(run_match_wrapper, kwargs_list)
    return [run_match_wrapper(match_kwargs) for match_kwargs in kwargs_list]
