import argparse
import contextlib
import os
import pickle
import random

import chess
import chess.engine
import chess.syzygy
import torch  # Used only for logging compatibility
from tqdm import trange

# Existing project imports
import config
import log_writer
from evaluate import initialize_engine, close_engine, evaluate_board
from utils import get_default_experiment_name, backup_sources
from chess_utils import get_board_tensor, tb_probe_wdl, open_egtb, close_egtb, tb_probe_wdl_ab


def is_valid(board):
    """Return True when *board* is a legal position whose game is still in progress.

    The position must pass ``board.is_valid()``; only then is
    ``board.is_game_over()`` consulted, and it must report False.
    """
    if not board.is_valid():
        return False
    return not board.is_game_over()


def propose(board: chess.Board) -> chess.Board:
    """
    Return a copy of *board* with one random atomic change applied.

    A fair coin picks the kind of change:

    * transform -- a random non-king piece is taken off the board and a new
      random non-king piece (pawn, knight, bishop, rook or queen, random colour)
      is placed on a random empty square;
    * move -- a random piece, kings included, is moved to a random empty square.

    Both kinds keep the number of pieces on the board constant. *board* itself
    is not modified, and the result is not checked for validity. Because kings
    are never transformed, drawing a transform on a board that holds only kings
    raises ``IndexError`` (``random.choice`` on an empty list).
    """
    candidate = board.copy(stack=False)
    swap_piece = random.choice([True, False])

    occupied = [sq for sq in chess.SQUARES if candidate.piece_at(sq) is not None]
    if swap_piece:
        # Kings are never replaced, so they are not eligible sources here.
        occupied = [sq for sq in occupied if candidate.piece_type_at(sq) != chess.KING]
    origin = random.choice(occupied)
    target = random.choice([sq for sq in chess.SQUARES if candidate.piece_at(sq) is None])

    mover = candidate.piece_at(origin)
    candidate.remove_piece_at(origin)

    if swap_piece:
        mover = chess.Piece(
            random.choice([chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN]),
            random.choice([chess.WHITE, chess.BLACK]),
        )

    candidate.set_piece_at(target, mover)
    return candidate


def engine_is_fooled(board, limit):
    """
    Check whether the engine's move throws away tablebase value for the side to move.

    NOTE(review): assumes ``tb_probe_wdl`` returns a Syzygy-style WDL score in
    [-2, 2] from the perspective of the side to move (-2 = loss ... 2 = win)
    -- confirm in chess_utils.

    Args:
        board: Position to test. The engine move is pushed and then popped, so
            on a normal return the board is back in its original state.
        limit: Search limit forwarded to ``evaluate_board``.

    Returns:
        A ``(fooled, move)`` pair:
        - ``(False, "null")`` when the initial probe returns -2;
        - ``(True, None)`` when the engine's move is not legal in ``board``;
        - otherwise ``(fooled, engine_move)``, where ``fooled`` is True when the
          bucketed score ``int(wdl / 2)`` (-1, 0 or 1) of the side to move is
          lower after the engine's move than before it.
    """
    pre_label = int(tb_probe_wdl(board))
    # pre_label = 2 - tb_probe_result(board)
    if pre_label == -2:
        return False, "null"
    # Element 1 of evaluate_board's result is used as the engine's chosen move.
    engine_move = evaluate_board(board, limit=limit)[1]
    if not board.is_legal(engine_move):
        # An illegal engine move is counted as fooling the engine.
        return True, None
    board.push(engine_move)
    # -1, 0 and 1 all fall into bucket int(x / 2) == 0; collapse them to -1 so the
    # alpha passed below is 1 for every such start (a starting 2 stays 2 -> alpha -2).
    if int(pre_label / 2) == 0:
        pre_label = -1
    # After the push the opponent is to move; negate to get the mover's score
    # (same perspective assumption as above). NOTE(review): alpha presumably
    # bounds the probe inside tb_probe_wdl_ab -- confirm its semantics in chess_utils.
    post_label = -int(tb_probe_wdl_ab(board, alpha=-pre_label))
    board.pop()
    # Fooled when the mover's bucket drops: 1 -> 0/-1, or 0 -> -1.
    return ((int(pre_label / 2) - int(post_label / 2)) > 0), engine_move


def log_novel(samples: list, iteration: int, args):
    """Log one window of engine-fooling samples, mirroring the GFN logger.

    Args:
        samples: ``(board, move)`` pairs collected since the last log; a falsy
            ``move`` is recorded as ``"illegal"``.
        iteration: Log step (run_mcmc passes the MCMC step divided by 1024).
        args: Parsed CLI options; only ``args.log_frequency`` is read.

    Writes the ``val/error_percent`` scalar (number of samples as a percentage
    of ``args.log_frequency * 1024`` steps), saves ``(fen, uci-or-"illegal")``
    pairs under ``fens_v2``, and prints a one-line summary. An empty
    *samples* list is a no-op.
    """
    if not samples:
        return

    window_steps = args.log_frequency * 1024
    fen_records = [
        (board.fen(), move.uci() if move else "illegal")
        for board, move in samples
    ]
    # Every collected sample fooled the engine, so the count is just the length.
    error_pct = 100.0 * len(fen_records) / window_steps

    log_writer.add_scalar("val/error_percent", error_pct, iteration)
    log_writer.save_fens(fen_records, "fens", iteration, path="fens_v2")

    print(f"MCMC Iter {iteration}: Error Pct: {error_pct:.2f}%")


def save_checkpoint(filepath, iteration, board, random_state):
    """Save the MCMC state to *filepath* as a pickle, replacing it atomically.

    Args:
        filepath: Destination checkpoint path (str or path-like).
        iteration: Last completed MCMC step; ``load_checkpoint`` resumes at
            ``iteration + 1``.
        board: Current chain state; only ``board.fen()`` is stored.
        random_state: Value handed to ``random.setstate`` on resume
            (``run_mcmc`` passes ``random.getstate()``).

    The pickle is written to ``<filepath>.tmp``, flushed to disk, and then
    moved over *filepath* with ``os.replace``. Opening *filepath* directly in
    ``'wb'`` mode would truncate the previous checkpoint first. An interruption
    or pickling error mid-write would then leave an empty or partial file that
    cannot be loaded on resume. Here the previous checkpoint stays intact until
    the new one is complete.
    """
    checkpoint = {
        'iteration': iteration,
        'board_fen': board.fen(),
        'random_state': random_state,
    }
    tmp_path = f"{filepath}.tmp"
    with open(tmp_path, 'wb') as f:
        pickle.dump(checkpoint, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, filepath)


def _random_initial_board(num_pieces):
    """Return a random board with two non-adjacent kings plus ``num_pieces - 2`` other pieces.

    The extra pieces are pawns/knights/bishops/rooks/queens of random colour on
    random empty squares. Validity of the resulting position is not checked here.
    """
    board = chess.Board(None)
    white_king = random.choice(chess.SQUARES)
    # The black king may neither share nor touch the white king's square.
    blocked = set(chess.SquareSet(chess.BB_KING_ATTACKS[white_king]))
    blocked.add(white_king)
    black_king = random.choice([sq for sq in chess.SQUARES if sq not in blocked])
    board.set_piece_at(white_king, chess.Piece(chess.KING, chess.WHITE))
    board.set_piece_at(black_king, chess.Piece(chess.KING, chess.BLACK))

    for _ in range(num_pieces - 2):
        empty = [sq for sq in chess.SQUARES if board.piece_at(sq) is None]
        square = random.choice(empty)
        kind = random.choice([chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN])
        colour = random.choice([chess.WHITE, chess.BLACK])
        board.set_piece_at(square, chess.Piece(kind, colour))
    return board


def load_checkpoint(filepath, args):
    """Return ``(start_iteration, board)`` from a checkpoint, or a fresh random start.

    When ``args.load`` is set and *filepath* exists, the pickled checkpoint is
    read, the global ``random`` state is restored from it, and the chain
    resumes at the step after the saved iteration. Otherwise a random board with
    ``args.num_pieces`` pieces is generated and the chain starts at step 1.

    Security: ``pickle.load`` can execute arbitrary code; only load checkpoints
    written by this script.
    """
    exists = os.path.exists(filepath)

    if args.load and exists:
        with open(filepath, 'rb') as fh:
            saved = pickle.load(fh)
        random.setstate(saved['random_state'])
        print(f"Loading with FEN: {saved['board_fen']} at iteration {saved['iteration'] + 1}")
        return saved['iteration'] + 1, chess.Board(saved['board_fen'])

    print(f"Default load state is {args.load} and filepath exists: {exists}")
    print(f"Filepath {filepath} exists" if exists else f"Filepath {filepath} does not exist")
    return 1, _random_initial_board(args.num_pieces)


def add_neighbors(board, queue):
    """Append to *queue* every valid board one relocation away from *board*.

    Each candidate moves one piece onto one empty square. For a non-king piece
    there is a 1-in-10 chance (``random.randint(0, 9) == 0``) that the moved
    piece is replaced by a random non-king piece of random colour. Destination
    squares are visited in shuffled order, and the source squares are
    reshuffled for every destination. Only candidates passing ``is_valid`` are
    appended. *board* is left unchanged (asserted via its FEN).
    """
    targets = [sq for sq in chess.SQUARES if board.piece_at(sq) is None]
    random.shuffle(targets)
    origins = [sq for sq in chess.SQUARES if board.piece_at(sq) is not None]
    fen_before = board.fen()

    for target in targets:
        random.shuffle(origins)
        for origin in origins:
            neighbor = board.copy(stack=False)
            mover = neighbor.piece_at(origin)
            mutate = mover.piece_type != chess.KING and random.randint(0, 9) == 0
            if mutate:
                mover = chess.Piece(
                    random.choice([chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN]),
                    random.choice([chess.WHITE, chess.BLACK]),
                )
            neighbor.remove_piece_at(origin)
            neighbor.set_piece_at(target, mover)
            if is_valid(neighbor):
                queue.append(neighbor)

    # Candidates are built on copies; the source board must be untouched.
    assert fen_before == board.fen()


def add_neighbors_old(board, queue):
    """Append to *queue* every valid board obtained by moving one piece to one empty square.

    Earlier variant of ``add_neighbors`` without piece transformation (not
    called from ``run_mcmc``). Source squares are shuffled once, and the empty
    squares are reshuffled for each source. *board* is left unchanged
    (asserted via its FEN).
    """
    origins = [sq for sq in chess.SQUARES if board.piece_at(sq) is not None]
    random.shuffle(origins)
    fen_before = board.fen()
    # The board never changes, so its empty squares can be listed once and copied per source.
    empty_squares = [sq for sq in chess.SQUARES if board.piece_at(sq) is None]

    for origin in origins:
        targets = list(empty_squares)
        random.shuffle(targets)
        for target in targets:
            neighbor = board.copy(stack=False)
            mover = neighbor.piece_at(origin)
            neighbor.remove_piece_at(origin)
            neighbor.set_piece_at(target, mover)
            if is_valid(neighbor):
                queue.append(neighbor)

    assert fen_before == board.fen()


def _pop_untried(queue, tried):
    """Pop boards off the end of *queue* until one whose FEN is not in *tried*; None if exhausted."""
    while queue:
        candidate = queue.pop()
        if candidate.fen() not in tried:
            return candidate
    return None


def _random_walk_to_valid(board):
    """Apply ``propose`` at least three times, continuing until the result passes ``is_valid``."""
    proposed = propose(board)
    steps = 1
    while steps < 3 or not is_valid(proposed):
        proposed = propose(proposed)
        steps += 1
    return proposed


def run_mcmc(args):
    """Main MCMC sampling loop.

    Opens the endgame tablebases, the engine and the log writer, then runs steps
    ``start_iter .. args.mcmc_steps``. Each step takes the next untried board
    from the neighbour queue. That queue is filled around every engine-fooling
    board found after ``args.burn_in``. When the queue is empty, the step
    instead random-walks from the current board with ``propose`` until it
    reaches a valid board. Every ``args.log_frequency * 1024`` steps the
    collected samples (if any) are logged and a checkpoint is written.

    The log writer, engine and tablebases are released in that order even if
    the loop raises or is interrupted (e.g. KeyboardInterrupt).
    """
    with contextlib.ExitStack() as cleanup:
        config.tb_path = args.tb_path
        open_egtb()
        cleanup.callback(close_egtb)

        initialize_engine(args.engine, uci_separate=True)
        cleanup.callback(close_engine)
        limit = chess.engine.Limit(nodes=args.nodes)

        # The --depth option is commented out in main(), so args may lack `depth`;
        # reading it unconditionally crashed every run started without --nodes.
        depth = getattr(args, "depth", None)
        limit_str = f"{depth}" if args.nodes is None else f"n{args.nodes}"
        setting_str = f"{args.engine}_us__move_{limit_str}_{args.num_pieces}"
        log_writer.init(log_path=os.path.join("..", "tensorboard", setting_str), name=args.name)
        cleanup.callback(log_writer.close)

        # Initialize state from checkpoint or default.
        checkpoint_path = os.path.join("..", "checkpoints", setting_str, args.name)
        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint_file = os.path.join(checkpoint_path, "mcmc_checkpoint.pkl")
        start_iter, current_board = load_checkpoint(checkpoint_file, args)
        if start_iter > 1:
            print(f"Resuming experiment from iteration {start_iter - 1}")

        samples = []
        it_mult = 1024  # Our batch size is 1, so this matches what we have in the main code.
        log_every = args.log_frequency * it_mult
        tried = set()
        queue = []

        for i in trange(start_iter, args.mcmc_steps + 1):
            proposed_board = _pop_untried(queue, tried)

            # Avoid memory issues if we have a high density of successful samples.
            if len(queue) > 10000:
                queue = queue[::2]

            if proposed_board is None:
                tried.clear()
                proposed_board = _random_walk_to_valid(current_board)

            current_board = proposed_board
            assert is_valid(current_board)
            tried.add(current_board.fen())
            board_fools_engine, engine_move = engine_is_fooled(current_board, limit)

            if i > args.burn_in and board_fools_engine:
                samples.append((current_board.copy(), engine_move))
                add_neighbors(current_board, queue)

            if i % log_every == 0:
                if samples:
                    log_novel(samples, i // it_mult, args)
                    samples = []
                # Checkpoint at every window boundary, not only when samples were
                # found, so a resume never rewinds by more than one window.
                save_checkpoint(checkpoint_file, i, current_board, random.getstate())


def _build_arg_parser():
    """Create the command-line parser for the PyChess MCMC baseline."""
    ap = argparse.ArgumentParser(description='Run PyChess MCMC baseline for chess position generation.')
    # Option names and defaults mirror mcmc_baseline.py and main.py.
    ap.add_argument('--mcmc-steps', type=int, default=1_000_000_000)
    ap.add_argument('--burn-in', type=int, default=50_000)
    ap.add_argument('--log-frequency', type=int, default=20_000)
    ap.add_argument('--name', type=str, default=get_default_experiment_name())
    ap.add_argument('--num-pieces', type=int, default=5)
    ap.add_argument('--engine', type=str, default='Winter')
    # --depth (previously default 1) is disabled; run_mcmc builds its engine Limit from --nodes.
    ap.add_argument('--nodes', type=int, default=None)
    ap.add_argument('--tb-path', type=str, default="/home/Chess/TB_Merged")
    # NOTE(review): the three reward options are not read anywhere in this file;
    # presumably kept for CLI parity with main.py.
    ap.add_argument('--base-reward', type=float, default=0.1)
    ap.add_argument('--reward-balance', type=float, default=0.9)
    ap.add_argument('--reward-fool', type=float, default=125)
    ap.add_argument('--load', default=False, action=argparse.BooleanOptionalAction,
                    help="Whether to continue from a previous save point.")
    return ap


def main():
    """Parse the command line, echo the parsed options, and run the MCMC search."""
    options = _build_arg_parser().parse_args()
    print(options)
    run_mcmc(options)


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
