from typing import Optional, Tuple, Union
import chess
import torch
import gymnasium
from gymnasium import spaces
from torchtyping import TensorType

from chess_utils import tb_probe_result, softmax, tensor_to_board, get_xy, tb_probe_wdl, tb_probe_wdl_ab
from evaluate import evaluate_board, get_engine_instance, engine_eval
import config

# Lazily filled by _init_king_configurations(): one flattened (12 * 64) board per
# placement of the white and black kings on distinct, non-adjacent squares,
# stacked into a tensor of shape (N, 12 * 64).
_king_configurations = None
# Lazily filled by _init_king_configurations(): a boolean (N, 12, 64) mask, one per
# entry of _king_configurations. NOTE(review): not read by ChessGymEnv (which sets
# self.prefiltered = None) — confirm whether it is still needed.
_prefiltered = None


def _valid_and_adjacent(k, p, max_dis=1):
    """Return True if square ``p`` is on the board and within ``max_dis`` of ``k``.

    Both arguments are square indices; ``p`` must lie in ``[0, 64)``. The distance
    is taken separately along each of the two coordinates returned by ``get_xy``
    (Chebyshev distance), so ``max_dis=1`` accepts ``k`` itself and its direct
    neighbours, while ``max_dis=2`` also reaches knight-move squares.
    """
    if not 0 <= p < 64:
        return False
    (k_x, k_y), (p_x, p_y) = get_xy(k), get_xy(p)
    return max(abs(k_x - p_x), abs(k_y - p_y)) <= max_dis


def _init_king_configurations():
    """Populate the module-level king-placement caches (no-op after the first call).

    For every ordered pair of squares ``(w_king, b_king)`` whose kings are neither
    on the same square nor adjacent, this builds two things:

    * a 12x64 board holding only the white king (plane ``chess.KING - 1``) on
      ``w_king`` and the black king (plane ``chess.KING - 1 + 6``) on ``b_king``.
      All such boards are stored flattened in ``_king_configurations`` with shape
      ``(N, 12 * 64)``.
    * a 12x64 boolean mask, stored in ``_prefiltered`` with shape ``(N, 12, 64)``.
      It clears pawn placements (both colours) on squares 0-7 and 56-63. It also
      clears white pawn/bishop/queen/rook/knight placements at the fixed offsets
      from ``b_king`` listed below. Assuming python-chess square numbering
      (a1 = 0), these are the adjacent or knight-move squares from which the
      piece would attack the black king.
    """
    global _king_configurations, _prefiltered
    if _king_configurations is not None:
        return

    black = 6  # plane offset from a white piece to the matching black piece
    # (offsets from the black king, white planes to clear, max coordinate distance)
    attack_table = (
        ((-9, -7), [chess.PAWN - 1, chess.BISHOP - 1, chess.QUEEN - 1], 1),
        ((7, 9), [chess.BISHOP - 1, chess.QUEEN - 1], 1),
        ((-8, -1, 1, 8), [chess.ROOK - 1, chess.QUEEN - 1], 1),
        ((-17, -15, -10, -6, 6, 10, 15, 17), [chess.KNIGHT - 1], 2),
    )

    boards, masks = [], []
    king_pairs = ((w, b) for w in range(64) for b in range(64))
    for w_king, b_king in king_pairs:
        w_file, w_rank = w_king % 8, w_king // 8
        b_file, b_rank = b_king % 8, b_king // 8
        if max(abs(w_file - b_file), abs(w_rank - b_rank)) <= 1:
            continue  # kings would touch (or share a square)

        board = torch.zeros((12, 64))
        board[chess.KING - 1, w_king] = 1
        board[chess.KING - 1 + black, b_king] = 1
        boards.append(board)

        mask = torch.ones((12, 64), dtype=torch.bool)
        for pawn_plane in (chess.PAWN - 1, chess.PAWN - 1 + black):
            mask[pawn_plane, :8] = 0
            mask[pawn_plane, 56:] = 0
        for offsets, planes, reach in attack_table:
            for off in offsets:
                target = b_king + off
                if _valid_and_adjacent(b_king, target, max_dis=reach):
                    mask[planes, target] = 0
        masks.append(mask)

    _prefiltered = torch.stack(masks, dim=0).view(-1, 12, 64)
    _king_configurations = torch.stack(boards, dim=0).view(-1, 12 * 64)


class ChessGymEnv(gymnasium.Env):
    """Gymnasium environment that builds a chess position one piece at a time.

    States are ``(1, 12 * 64)`` float tensors. Plane ``p - 1`` holds white pieces
    of python-chess type ``p``, plane ``p - 1 + 6`` holds the black ones, and there
    is one column per square. An episode starts from a random placement of the two
    kings, and each action adds one non-king piece.
    """

    def __init__(self, num_pieces=5):
        """Chess environment. Based on FacesEnv from GFN tutorial.

        Args:
            num_pieces: total number of pieces (kings included) at which an
                episode terminates and the position is scored.
        """
        _init_king_configurations()
        self.num_pieces = num_pieces
        # Material value per plane: white positive, black negative, kings 0.
        self.piece_values = torch.tensor([[1.0, 2.7, 3.2, 5.0, 9.0, 0, -1.0, -2.7, -3.2, -5.0, -9.0, 0]],
                                         device=config.device)

        state_dim = 12 * 64
        self.prefiltered = None
        # 10 placeable piece planes (5 per colour, kings excluded) x 64 squares.
        self.action_space = spaces.Discrete(10 * 64)
        self.observation_space = spaces.MultiBinary(state_dim)

    def reset(
        self,
        batch_shape: Optional[Union[int, Tuple[int]]] = None,
        random: bool = False,
        sink: bool = False,
        seed: int = None,
        options: Optional[dict] = None,
    ):
        """Start a new episode from a random king configuration.

        Only a single environment is simulated. ``batch_shape`` and ``options``
        are accepted for interface compatibility (Gymnasium wrappers pass
        ``seed=`` and ``options=``) and are ignored. ``random`` and ``sink`` are
        only checked for mutual exclusion. ``seed`` is currently ignored as well:
        the configuration index is drawn from the global torch RNG.

        Returns:
            ``(states, info)``, where ``states`` is a ``(1, 12 * 64)`` CPU tensor
            and ``info`` is an empty dict.
        """
        assert not (random and sink)
        index = int(torch.randint(_king_configurations.shape[0], (1,)))
        self.states = _king_configurations[index:index + 1].clone().detach().cpu()
        return self.states, {}

    def step(self, actions):
        """Place one piece and return the Gymnasium 5-tuple.

        Action ``a`` in ``[0, 640)`` indexes ``plane * 64 + square`` over the 10
        non-king planes. Actions ``>= 5 * 64`` (black pieces) are shifted by one
        plane so that the white-king plane is skipped.

        Returns:
            ``(states, reward, terminated, truncated, info)``. ``terminated`` is a
            plain ``bool``: True once at least ``num_pieces`` pieces are on the
            board or the position is invalid. ``truncated`` is always False.
        """
        self.states[0, actions + 64 * (actions // (64 * chess.QUEEN))] += 1
        reward = self.reward(self.states)
        # bool(...) so Gymnasium receives a bool, never a 0-dim tensor.
        terminated = bool(self.states.sum() >= self.num_pieces) or self.invalid()
        return self.states, reward, terminated, False, dict()

    def invalid(self, states=None):
        """Return True if the position breaks basic placement rules.

        A position is invalid when any square holds more than one piece, or when
        a pawn of either colour stands on squares 0-7 or 56-63 (rows 0 and 7 of
        the 8x8 view).

        Args:
            states: state tensor to check. Defaults to the current ``self.states``.
        """
        board_tensor = (self.states if states is None else states).view(12, 64)
        if board_tensor.sum(dim=0).max() > 1:
            return True
        pawns = board_tensor.view(12, 8, 8)[chess.PAWN - 1::6]
        return bool(pawns[:, 0].sum() + pawns[:, 7].sum() > 0)

    def backward_step(self, actions):
        """Return a copy of ``self.states`` with the piece for ``actions`` removed.

        Uses the same action encoding as :meth:`step`. ``scatter`` is not
        in-place, so ``self.states`` itself is left unchanged.
        """
        return self.states.scatter(-1, actions + 64 * (actions // (64 * chess.QUEEN)), -1, reduce="add")

    def compute_valid_rewards(self, board_tensors, boards):
        """Score valid positions by material balance.

        Args:
            board_tensors: ``(B, 12, 64)`` piece planes.
            boards: the matching parsed boards (unused here; subclasses use them).

        Returns:
            A ``(B, 1)`` tensor: 1.0 where the absolute material balance is at most
            5, otherwise 0.1.
        """
        assert board_tensors.shape[-1] == 64
        # States are kept on the CPU while piece_values lives on config.device.
        board_tensors = board_tensors.to(self.piece_values.device)
        balance = (board_tensors.sum(dim=-1) * self.piece_values).sum(dim=1)
        piece_score = balance.abs() <= 5
        return (0.1 + 0.9 * piece_score).view(-1, 1)

    def reward(self, states):
        """Return the reward of a finished, legal position, or 0.

        Returns the int 0 unless ``states`` holds exactly ``num_pieces`` pieces
        and the board built by ``tensor_to_board`` meets three conditions: it is
        valid, it is not game-over, and it passes :meth:`invalid`. Otherwise
        returns the squeezed output of ``compute_valid_rewards``.
        """
        board_tensor = states.view(-1, 12, 64)
        assert board_tensor.shape[0] == 1
        if states.sum() != self.num_pieces:
            return 0
        board = tensor_to_board(board_tensor[0].view(12 * 64))
        if not board.is_valid() or board.is_game_over() or self.invalid(states):
            return 0
        return self.compute_valid_rewards(board_tensor, [board]).squeeze()


class MoveGymEnv(ChessGymEnv):
    """ChessGymEnv whose reward additionally measures engine mistakes.

    For every scored position the engine picks a move. Tablebase probes before and
    after that move measure how much it worsened the outcome, and that loss is
    added (heavily weighted) to the material-balance reward.
    """

    def __init__(self, num_pieces=5, nodes=400, engine="Stockfish"):
        """
        Args:
            num_pieces: see ChessGymEnv.
            nodes: node budget passed to ``engine_eval`` for each move search.
            engine: engine name, instantiated lazily with ``get_engine_instance``
                on the first reward computation. A non-string value is used as-is.
        """
        self.engine = engine
        print(f"Initializing MoveGymEnv with {engine}")
        self.nodes = nodes
        # (fen, uci) pairs for positions where the engine's move worsened the result.
        self.success = list()
        super().__init__(num_pieces)

    def compute_valid_rewards(self, board_tensors, boards):
        """Return ``0.1 + 0.9 * material_ok + 125 * engine_loss ** 2`` as ``(B, 1)``.

        ``material_ok`` is True where the absolute material balance is at most 5.
        ``engine_loss`` is ``0.5 * relu(wdl_before - wdl_after)`` for the engine's
        chosen move; each board is restored to its original position afterwards.
        """
        assert board_tensors.shape[-1] == 64
        board_tensors = board_tensors.to(config.device)
        piece_score = (board_tensors.sum(dim=-1) * self.piece_values).sum(dim=1).abs() <= 5
        if isinstance(self.engine, str):
            self.engine = get_engine_instance(self.engine)

        def engine_loss(board):
            """Loss of the engine's move on ``board``. Leaves ``board`` unchanged."""
            label = torch.tensor([tb_probe_wdl(board)], dtype=torch.float)
            # Presumably -2 is a tablebase loss for the side to move (python-chess
            # Syzygy WDL convention), so no move can make it worse. TODO confirm.
            if label == -2:
                return torch.zeros((1,))
            _, move = engine_eval(board, self.engine, nodes=self.nodes)
            board.push(move)
            try:
                # Negated: the probe is from the opponent's point of view after the move.
                post_move_label = -torch.tensor([tb_probe_wdl_ab(board, alpha=-label)], dtype=torch.float)
            finally:
                board.pop()  # always restore the caller's board
            # int(0.5 * wdl) truncates toward zero, so (assuming WDL in -2..2)
            # cursed wins / blessed losses count as draws.
            if int(0.5 * label) > int(0.5 * post_move_label):
                self.success.append((board.fen(), move.uci()))
            return 0.5 * torch.nn.functional.relu(label - post_move_label)

        # Losses are built on the CPU; move them next to piece_score before combining.
        labels = torch.cat([engine_loss(board) for board in boards]).to(piece_score.device)
        return (0.1 + 0.9 * piece_score + 125 * labels ** 2).view(-1, 1)
