from typing import Optional, Tuple, Union
import chess
import torch
from gfn.env import DiscreteEnv
from gfn.preprocessors import IdentityPreprocessor
from gfn.states import States, DiscreteStates
from torchtyping import TensorType

from chess_utils import tb_probe_result, softmax, tensor_to_board, get_xy, tb_probe_wdl, tb_probe_wdl_ab
from evaluate import evaluate_board

# Lazily built cache, filled exactly once by `_init_king_configurations()`.
# After initialisation: a float tensor of shape (N, 12 * 64) holding one flattened
# board per pair of non-adjacent king squares (only the two king bits are set).
_king_configurations = None
# Companion cache, row-aligned with `_king_configurations`: a boolean tensor of
# shape (N, 12, 64) marking which (piece channel, square) placements remain
# allowed for that king pair after the static filtering done at build time.
_prefiltered = None


def _valid_and_adjacent(k, p, max_dis=1):
    if p < 0 or p >= 64:
        return False
    kx, ky = get_xy(k)
    px, py = get_xy(p)
    return abs(kx-px) <= max_dis and abs(ky-py) <= max_dis


def _init_king_configurations():
    """Build the module-level king-placement caches; later calls are no-ops.

    For every ordered pair of squares (white king, black king) whose
    coordinates differ by more than one along at least one axis, this stores:

    * a 12x64 board with only the two king bits set, and
    * a 12x64 boolean mask of still-allowed placements, where pawns of either
      colour are barred from the first and last rank and selected white piece
      channels are barred from squares next to (or a knight's jump from) the
      black king.

    Results land in ``_king_configurations`` (flattened to ``(N, 768)``) and
    ``_prefiltered`` (``(N, 12, 64)``).
    """
    global _king_configurations, _prefiltered
    if _king_configurations is not None:
        return

    pawn, knight, bishop = chess.PAWN - 1, chess.KNIGHT - 1, chess.BISHOP - 1
    rook, queen, king = chess.ROOK - 1, chess.QUEEN - 1, chess.KING - 1
    black = 6  # channel offset of the second colour

    # (square offsets from the black king, white channels barred there, max distance)
    barred_near_black_king = (
        ((-9, -7), [pawn, bishop, queen], 1),
        ((7, 9), [bishop, queen], 1),
        ((-8, -1, 1, 8), [rook, queen], 1),
        ((-17, -15, -10, -6, 6, 10, 15, 17), [knight], 2),
    )

    boards = []
    masks = []
    for w_king in range(64):
        w_row, w_col = divmod(w_king, 8)
        for b_king in range(64):
            b_row, b_col = divmod(b_king, 8)
            # Kings on the same or touching squares are never generated.
            if abs(w_col - b_col) <= 1 and abs(w_row - b_row) <= 1:
                continue

            board = torch.zeros((12, 64))
            board[king, w_king] = 1
            board[king + black, b_king] = 1
            boards.append(board)

            allowed = torch.ones((12, 64)).bool()
            for pawn_channel in (pawn, pawn + black):
                allowed[pawn_channel, 0:8] = 0
                allowed[pawn_channel, (7 * 8):(8 * 8)] = 0

            for offsets, channels, reach in barred_near_black_king:
                for offset in offsets:
                    square = b_king + offset
                    if _valid_and_adjacent(b_king, square, max_dis=reach):
                        allowed[channels, square] = 0
            masks.append(allowed)

    _prefiltered = torch.stack(masks, dim=0).view(-1, 12, 64)
    _king_configurations = torch.stack(boards, dim=0).view(-1, 12 * 64)


class ChessEnv(DiscreteEnv):
    """Discrete environment that builds a chess position one piece at a time.

    A state is a flattened 12x64 tensor: 12 piece channels (indexed
    ``piece_type - 1`` for the first colour and ``piece_type - 1 + 6`` for the
    second) times 64 squares. An action places one piece on one square; the
    last action index is the exit action.
    """

    def __init__(self, num_pieces=5, device=torch.device("cpu"), base_reward=0.1, reward_balance=0.9,
                 uniform_kings=True):
        """Chess environment. Based on FacesEnv from GFN tutorial.
        States are represented as 12x64-element binary tensors.

        Trajectories are forced to exit through ``states.forward_masks`` once a
        state holds ``num_pieces`` pieces.

        Args:
            num_pieces: Piece count at which the exit action is forced.
            device: Device used for the tensors this environment creates.
            base_reward: Constant part of the reward of a valid position.
            reward_balance: Extra reward granted when the material imbalance is
                at most 5 (see ``compute_valid_rewards``).
            uniform_kings: When True, ``reset`` starts from a uniformly sampled
                pair of non-adjacent kings and the two king channels are
                removed from the action space.
        """
        if uniform_kings:
            _init_king_configurations()
        self.num_pieces = num_pieces
        self.device = device
        self.base_reward = base_reward
        self.reward_balance = reward_balance
        # Per-channel material value; the second colour is negated so that the
        # weighted sum is a signed material balance. Kings count as 0.
        self.piece_values = torch.tensor([[1.0, 2.7, 3.2, 5.0, 9.0, 0, -1.0, -2.7, -3.2, -5.0, -9.0, 0]],
                                         device=device)
        self.uniform_kings = uniform_kings

        state_dim = 12 * 64
        # Regular actions + 1 exit action; the two king channels (2 * 64
        # placements) are not actions when kings are pre-placed.
        n_actions = (state_dim + 1 - 2 * 64) if uniform_kings else (state_dim + 1)

        # Per-batch placement masks, set by `reset` when `uniform_kings` is True.
        self.prefiltered = None

        super().__init__(
            n_actions=n_actions,
            # s0 is the empty board.
            s0=torch.zeros(state_dim, dtype=torch.float, device=self.device),
            state_shape=(state_dim,),
            # Sf represents when a trajectory is done (we selected the exit action).
            sf=torch.ones(state_dim, dtype=torch.float, device=self.device) * -1,
            device_str=self.device,
            # These are sometimes handy to generate tensors. In this case, not needed.
            preprocessor=IdentityPreprocessor(output_dim=state_dim)
        )

    def _board_indices(self, action_tensor):
        """Map action indices to positions in the flattened 12x64 state.

        Without pre-placed kings the two coincide. With them, the action space
        skips both king channels, so actions in the second colour's block
        (index >= 5 * 64) are shifted by one channel (64 squares).
        """
        if not self.uniform_kings:
            return action_tensor
        return action_tensor + 64 * (action_tensor // (64 * chess.QUEEN))

    def update_masks(self, states):
        """Update the masks based on the current states."""
        # Project the state onto the action space: with pre-placed kings the
        # king channels are dropped, leaving 5 channels per colour.
        if self.uniform_kings:
            previous_actions = torch.cat([states.tensor[:, :(chess.QUEEN * 64)],
                                          states.tensor[:, (chess.KING * 64):((chess.KING + chess.QUEEN) * 64)]],
                                         dim=1)
        else:
            previous_actions = states.tensor
        # Backward masks are simply any action we've already taken.
        states.backward_masks = previous_actions != 0  # n - 1 actions.

        # Forward masks begin as allowing any action. Allowed actions are 1.
        states.init_forward_masks(set_ones=True)

        # Then, we remove any done action, and also (usually) the exit action.
        # NOTE(review): the threshold 7 is hard-coded and only the first batch
        # element is inspected; confirm both are intended.
        allow_exit = bool(states.tensor[0].sum() >= 7)
        states.set_nonexit_action_masks(previous_actions == 1, allow_exit=allow_exit)

        board_tensor = states.tensor[..., :(12 * 64)].view(-1, 12, 64)
        if not self.uniform_kings:
            assert self.prefiltered is None
            valid_actions = torch.ones(board_tensor.shape, dtype=torch.bool, device=board_tensor.device)
            # No pawns of either colour on the first or last rank.
            valid_actions[..., chess.PAWN - 1, 0:8] = 0
            valid_actions[..., chess.PAWN - 1, (7 * 8):(8 * 8)] = 0
            valid_actions[..., chess.PAWN - 1 + 6, 0:8] = 0
            valid_actions[..., chess.PAWN - 1 + 6, (7 * 8):(8 * 8)] = 0
            piece_counts = board_tensor.sum(dim=-1)
            # At most one king per colour.
            valid_actions[..., chess.KING - 1, :] *= (piece_counts[:, chess.KING - 1] == 0).view(-1, 1)
            valid_actions[..., chess.KING - 1 + 6, :] *= (piece_counts[:, chess.KING - 1 + 6] == 0).view(-1, 1)
            # Total piece count is read from the first batch element only.
            num_pieces = piece_counts[0].sum()
            num_kings = (piece_counts[:, chess.KING - 1] + piece_counts[:, chess.KING - 1 + 6]).view(-1, 1)
            # Non-king placements are only allowed while enough of the piece
            # budget is left to still place the missing king(s).
            sufficient_time_for_kings = ((num_pieces - num_kings) < (self.num_pieces - 2)).view(-1, 1, 1)
            valid_actions[:, :(chess.KING - 1), :] *= sufficient_time_for_kings
            valid_actions[:, chess.KING:(chess.KING - 1 + 6), :] *= sufficient_time_for_kings
        else:
            valid_actions = self.prefiltered

        # A square can hold at most one piece. This is computed out of place so
        # that `self.prefiltered` (aliased by `valid_actions` above) is never
        # mutated and always reflects the king-dependent filter only.
        square_counts = board_tensor.sum(dim=-2)
        square_empty = (square_counts == 0).view(-1, 1, 64)
        valid_actions = valid_actions * square_empty

        if self.uniform_kings:
            valid_actions = torch.cat([valid_actions[:, :chess.QUEEN],
                                       valid_actions[:, chess.KING:(chess.KING + chess.QUEEN)]], dim=1)
        action_channels = 10 if self.uniform_kings else 12
        valid_actions = valid_actions.view(-1, action_channels * 64)
        states.forward_masks[..., :(action_channels * 64)] = (states.forward_masks[..., :(action_channels * 64)]
                                                              * valid_actions)

        # Any state that already holds `num_pieces` pieces is forced to exit.
        batch_idx = states.tensor.sum(-1) >= self.num_pieces
        states.set_exit_masks(batch_idx)

    def reset(
        self,
        batch_shape: Optional[Union[int, Tuple[int]]] = None,
        random: bool = False,
        sink: bool = False,
        seed: Optional[int] = None,
    ) -> States:
        """Instantiates a batch of initial states.

        `random` and `sink` cannot be both True. When `random` is `True` and `seed` is
            not `None`, environment randomization is fixed by the submitted seed for
            reproducibility.

        With ``uniform_kings`` each state starts from a uniformly sampled king
        pair and ``self.prefiltered`` is replaced by the matching masks;
        otherwise every state is the empty board.

        NOTE(review): the state tensor is overwritten regardless of ``random``
        and ``sink`` - confirm callers never rely on those flags here.
        """
        assert not (random and sink)

        if random and seed is not None:
            torch.manual_seed(seed)

        if batch_shape is None:
            batch_shape = (1,)
        if isinstance(batch_shape, int):
            batch_shape = (batch_shape,)
        states = self.States.from_batch_shape(
            batch_shape=batch_shape, random=random, sink=sink
        )

        batch_size = states.tensor.shape[0]
        if self.uniform_kings:
            indexes = torch.randint(_king_configurations.shape[0], (batch_size,))
            # Indexing with a tensor copies, so the module-level caches are
            # never aliased by the returned states.
            states.tensor = _king_configurations[indexes].to(self.device)
            self.prefiltered = _prefiltered[indexes].to(self.device)
        else:
            states.tensor = torch.zeros((batch_size, 12 * 64), device=self.device)
        self.update_masks(states)

        return states

    def step(self, states, actions):
        """Return the state tensor with the chosen piece placements added."""
        return states.tensor.scatter(-1, self._board_indices(actions.tensor), 1, reduce="add")

    def backward_step(self, states, actions):
        """Return the state tensor with the chosen piece placements removed."""
        return states.tensor.scatter(-1, self._board_indices(actions.tensor), -1, reduce="add")

    def compute_valid_rewards(self, board_tensors, boards):
        """Reward for positions that already passed the validity checks.

        Args:
            board_tensors: Tensor whose last dimension is the 64 squares
                (``reward`` passes shape ``(N, 12, 64)``).
            boards: Matching board objects; unused here, available to
                subclasses.

        Returns:
            ``(N, 1)`` tensor: ``base_reward``, plus ``reward_balance`` when the
            absolute material balance is at most 5.
        """
        assert board_tensors.shape[-1] == 64
        piece_score = (board_tensors.sum(dim=-1) * self.piece_values).sum(dim=1).abs() <= 5
        return (self.base_reward + self.reward_balance * piece_score).view(-1, 1)

    def make_states_class(self) -> type[States]:
        """Create the ``DiscreteStates`` subclass bound to this environment."""
        env = self

        class ChessEnvStates(DiscreteStates):
            state_shape = env.state_shape
            s0 = env.s0
            sf = env.sf
            make_random_states_tensor = env.make_random_states_tensor
            n_actions = env.n_actions
            device = env.device

            @property
            def is_initial_state(self) -> TensorType["batch_shape", torch.bool]:
                """Return a tensor that is True for states that are $s_0$ of the DAG."""
                s = self.tensor.sum(dim=1)
                if env.uniform_kings:
                    # Two pre-placed kings, or the all-zero s0.
                    # NOTE(review): the sink state (all -1) also satisfies
                    # this test - confirm that is acceptable.
                    return s <= 2
                return s == 0

            def set_exit_masks(self, batch_idx):
                """Sets forward masks such that the only allowable next action is to exit.

                A convenience function for common mask operations.

                Args:
                    batch_idx: A Boolean index along the batch dimension, along which to
                        enforce exits.
                """
                # Written in place so no temporary tensor on another device is needed.
                self.forward_masks[batch_idx, :-1] = False
                self.forward_masks[batch_idx, -1] = True

            def init_forward_masks(self, set_ones: bool = True):
                """Initializes forward masks.

                A convenience function for common mask operations.

                Args:
                    set_ones: if True, forward masks are initialized to all ones. Otherwise,
                        they are initialized to all zeros.
                """
                shape = self.batch_shape + (self.n_actions,)
                make = torch.ones if set_ones else torch.zeros
                self.forward_masks = make(shape, dtype=torch.bool, device=self.tensor.device)

        return ChessEnvStates

    def reward(self, states):
        """Compute one reward per state.

        Positions lacking exactly one king per colour, positions python-chess
        reports as invalid, and finished games receive ``1e-5``; all others get
        ``compute_valid_rewards``.

        Returns:
            Tensor of shape ``states.batch_shape``. Only the trailing singleton
            dimension is squeezed, so a batch of one keeps its batch dimension.
        """
        rewards = torch.ones(states.batch_shape + (1,), device=states.tensor.device)
        if states.batch_shape[0] > 0:
            board_tensor = states.tensor[..., :(12 * 64)].view(-1, 12, 64)
            piece_counts = board_tensor.sum(dim=-1)
            rewards[piece_counts[:, (chess.KING - 1):chess.KING] != 1] = 1e-5
            rewards[piece_counts[:, (chess.KING - 1 + 6):(chess.KING + 6)] != 1] = 1e-5

            r_ids = []
            r_boards = []
            for idx, planes in enumerate(board_tensor):
                board = tensor_to_board(planes.view(12 * 64))
                if not board.is_valid() or board.is_game_over():
                    rewards[idx] = 1e-5
                else:
                    r_ids.append(idx)
                    r_boards.append(board)
            if r_ids:
                rewards[r_ids] = self.compute_valid_rewards(board_tensor[r_ids], r_boards)

            # Progress output: total reward, batch shape, min-max piece count.
            print(f"Rewards: {rewards.sum()}/{states.batch_shape}: "
                  f"{states.tensor.sum(dim=-1).min()}-{states.tensor.sum(dim=-1).max()}")

        return rewards.squeeze(-1)

    def log_reward(self, states):
        """Return the natural logarithm of ``reward(states)``."""
        return torch.log(self.reward(states))


def point_reward(model, board_tensors, labels):
    """Absolute gap between a target value and the model's predicted value.

    The model is run without gradient tracking on the boards flattened to
    ``(-1, 768)``. Its softmaxed output is reduced to ``class 0 - class 2`` and
    compared with ``1 - labels``.

    Args:
        model: Callable mapping a ``(N, 768)`` tensor to per-class scores.
        board_tensors: Boards that can be viewed as ``(-1, 12 * 64)``.
        labels: Per-board labels, one per row of the flattened boards.

    Returns:
        ``(N, 1)`` tensor of absolute differences.
    """
    with torch.no_grad():
        flat_boards = board_tensors.view(-1, 12 * 64)
        class_probs = softmax(model(flat_boards))
        predicted = class_probs[:, 0] - class_probs[:, 2]
        target = 1 - labels
        gap = (target - predicted).abs()
    return gap.view(-1, 1)


class OutcomeEnv(ChessEnv):
    """ChessEnv variant whose reward compares a model against probed outcomes."""

    def __init__(self, target_model=None, num_pieces=5, device=torch.device("cpu"),
                 reward_function=lambda m, b, t: -1 + 1.5 * point_reward(m, b, t) ** 2):
        """Store the model and reward callable, then set up the base environment.

        Args:
            target_model: Model handed to ``reward_function`` as first argument.
            num_pieces: Forwarded to ``ChessEnv``.
            device: Forwarded to ``ChessEnv``.
            reward_function: Callable ``(model, board_tensors, labels)`` that
                returns the rewards of valid positions.
                NOTE(review): the default evaluates below zero whenever
                ``point_reward`` is small, and ``log_reward`` takes its log -
                confirm this is intended.
        """
        assert callable(reward_function)
        self.reward_function = reward_function
        self.target_model = target_model

        super().__init__(num_pieces, device)

    def compute_valid_rewards(self, board_tensors, boards):
        """Score valid positions with ``reward_function`` and probed labels.

        One ``tb_probe_result`` label (stored as int64) is gathered per board
        and passed, with the model and the board tensors, to the reward
        callable.
        """
        per_board = []
        for position in boards:
            per_board.append(torch.tensor([tb_probe_result(position)], dtype=torch.int64))
        outcome_labels = torch.cat(per_board)
        return self.reward_function(self.target_model, board_tensors, outcome_labels)


class MoveEnv(ChessEnv):
    """ChessEnv variant that rewards positions in which the engine's move loses value."""

    def __init__(self, num_pieces=5, device=torch.device("cpu"), depth=1, base_reward=0.1, uniform_kings=True,
                 reward_balance=0.9, reward_fool=125,
                 nodes=None, engine_limit=None):
        """Configure the engine limit and bookkeeping lists, then the base env.

        Args:
            num_pieces, device, base_reward, uniform_kings, reward_balance:
                Forwarded to ``ChessEnv``.
            depth: Search depth used when neither ``nodes`` nor
                ``engine_limit`` is given.
            reward_fool: Weight of the squared engine loss in the reward.
            nodes: Node budget; takes precedence over ``depth``.
            engine_limit: Ready-made limit object; takes precedence over both
                ``nodes`` and ``depth``.
        """
        if engine_limit is None:
            # `import chess` does not load the `chess.engine` submodule, so it
            # is imported explicitly rather than relying on another module
            # having loaded it first.
            from chess.engine import Limit
            engine_limit = Limit(nodes=nodes) if nodes is not None else Limit(depth=depth)
        # (fen, move-or-"illegal") pairs collected while computing rewards.
        self.good_fens = list()
        self.illegal_moves = list()
        self.engine_limit = engine_limit
        self.reward_fool = reward_fool
        super().__init__(num_pieces, device, base_reward=base_reward, reward_balance=reward_balance,
                         uniform_kings=uniform_kings)

    def _engine_loss(self, board):
        """Return a 1-element tensor measuring how much the engine's move gives away.

        * ``0`` when ``tb_probe_wdl`` reports ``-2`` for the position.
        * ``2`` when the engine's move is not legal (the position is recorded
          in both ``good_fens`` and ``illegal_moves``).
        * otherwise ``0.5 * relu(label - post_move_label)``, where
          ``post_move_label`` is the negated probe of the position after the
          engine's move.

        NOTE(review): the probe values are presumably tablebase WDL scores in
        ``-2..2`` - confirm against ``chess_utils``.
        """
        label = torch.tensor([tb_probe_wdl(board)], dtype=torch.float)
        if label == -2:
            return torch.zeros((1,))
        _, move = evaluate_board(board, limit=self.engine_limit)
        if not board.is_legal(move):
            self.good_fens.append((board.fen(), "illegal"))
            self.illegal_moves.append((board.fen(), "illegal"))
            return 2 * torch.ones((1,))
        board.push(move)
        # The probe is from the opponent's point of view after the move, hence the negation.
        post_move_label = -torch.tensor([tb_probe_wdl_ab(board, alpha=-label)], dtype=torch.float)
        board.pop()
        # Record the position when the halved (truncated) label dropped.
        if (int(label / 2) - int(post_move_label / 2)) > 0:
            self.good_fens.append((board.fen(), move.uci()))
        return 0.5 * torch.nn.functional.relu(label - post_move_label)

    def compute_valid_rewards(self, board_tensors, boards):
        """Base material-balance reward plus ``reward_fool * engine_loss ** 2``.

        Args:
            board_tensors: Tensor whose last dimension is the 64 squares.
            boards: Board objects matching ``board_tensors`` row by row.

        Returns:
            ``(N, 1)`` reward tensor.
        """
        assert board_tensors.shape[-1] == 64
        piece_score = (board_tensors.sum(dim=-1) * self.piece_values).sum(dim=1).abs() <= 5

        losses = [self._engine_loss(board) for board in boards]

        return (self.base_reward + self.reward_balance * piece_score
                + self.reward_fool * torch.cat(losses) ** 2).view(-1, 1)

