import argparse
import os
import pickle
import random
from tqdm import trange
import chess
import chess.engine
import chess.syzygy
import torch  # Used only for logging compatibility

# Existing project imports
import config
import log_writer
from evaluate import initialize_engine, close_engine, evaluate_board
from utils import get_default_experiment_name, backup_sources
from chess_utils import get_board_tensor, tb_probe_wdl, open_egtb, close_egtb, tb_probe_wdl_ab


def propose(board: chess.Board) -> chess.Board:
    """
    Propose a neighbouring position by changing exactly one piece.

    A fair coin selects one of two atomic edits:

    * transform: a random non-king piece is removed and a freshly drawn piece
      (random type from pawn to queen, random colour) is placed on a random
      empty square;
    * relocate: a random piece of any kind, kings included, is moved unchanged
      to a random empty square.

    The destination is drawn from the squares that are empty *before* the
    source piece is lifted, so a piece never lands back on its own square.
    Neither edit changes the total piece count or the number of kings.

    If the position holds no non-king piece the transform edit is impossible;
    the proposal then falls back to relocating instead of raising
    ``IndexError`` from ``random.choice`` on an empty list.

    Args:
        board: Current position; it is not modified.

    Returns:
        A stack-less copy of ``board`` with the single edit applied. No
        validity check is performed here.
    """
    proposal_board = board.copy(stack=False)

    transform_piece = random.choice([True, False])

    # Squares are scanned in chess.SQUARES order so that a given RNG state
    # always maps to the same source/destination squares.
    occupied_squares = [sq for sq in chess.SQUARES if proposal_board.piece_at(sq) is not None]
    empty_squares = [sq for sq in chess.SQUARES if proposal_board.piece_at(sq) is None]

    candidate_squares = occupied_squares
    if transform_piece:
        non_king_squares = [sq for sq in occupied_squares if proposal_board.piece_type_at(sq) != chess.KING]
        if non_king_squares:
            candidate_squares = non_king_squares
        else:
            # Only kings on the board: nothing can be transformed, so relocate.
            transform_piece = False

    source_square = random.choice(candidate_squares)
    destination_square = random.choice(empty_squares)

    piece = proposal_board.piece_at(source_square)
    proposal_board.remove_piece_at(source_square)

    if transform_piece:
        piece_type = random.choice([chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN])
        color = random.choice([chess.WHITE, chess.BLACK])
        piece = chess.Piece(piece_type, color)

    proposal_board.set_piece_at(destination_square, piece)

    return proposal_board


def propose_v1(board: chess.Board) -> chess.Board:
    """
    Earlier proposal variant: lift one random piece and fill one random empty square.

    A piece is picked uniformly among all occupied squares and a destination
    uniformly among all squares that are empty before the piece is lifted.
    A king is carried over unchanged; any other piece is discarded and a newly
    drawn piece (random type from pawn to queen, random colour) is placed on
    the destination instead.

    Args:
        board: Current position; it is not modified.

    Returns:
        A stack-less copy of ``board`` with the edit applied.
    """
    candidate = board.copy(stack=False)

    # Split the 64 squares into occupied and vacant ones in a single pass.
    occupied, vacant = [], []
    for sq in chess.SQUARES:
        bucket = vacant if candidate.piece_at(sq) is None else occupied
        bucket.append(sq)

    # Draw order (source, destination, type, colour) determines RNG consumption.
    origin = random.choice(occupied)
    target = random.choice(vacant)

    lifted = candidate.piece_at(origin)
    candidate.remove_piece_at(origin)

    if lifted.piece_type == chess.KING:
        placed = lifted
    else:
        new_type = random.choice([chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN])
        new_color = random.choice([chess.WHITE, chess.BLACK])
        placed = chess.Piece(new_type, new_color)

    candidate.set_piece_at(target, placed)
    return candidate


def move_loss(board, limit):
    """
    Score how much the engine's chosen move worsens the tablebase verdict.

    The position is probed before and after the engine's move and the two
    results are compared after being bucketed with ``int(label / 2)``.
    NOTE(review): labels presumably follow the Syzygy WDL convention
    (-2 loss ... 0 draw ... 2 win, from the side to move) - confirm against
    ``chess_utils.tb_probe_wdl``.

    Args:
        board: Position to evaluate. The engine move is pushed and popped
            again, so the board is unchanged on normal return.
        limit: Search limit forwarded to ``evaluate_board``.

    Returns:
        A ``(loss, move)`` tuple:

        * ``(0.0, "null")`` when the initial probe is -2; the engine is not
          consulted in that case and the move slot holds the string "null";
        * ``(2.0, None)`` when the engine's move is not legal on ``board``;
        * otherwise the bucketed pre-move label minus the bucketed post-move
          label as a float, together with the engine's move.
    """
    pre_label = int(tb_probe_wdl(board))
    # pre_label = 2 - tb_probe_result(board)
    if pre_label == -2:
        # Lowest probe value: nothing is scored and the engine is skipped.
        return 0.0, "null"
    # evaluate_board returns a sequence; index 1 is used as the engine's move.
    engine_move = evaluate_board(board, limit=limit)[1]
    if not board.is_legal(engine_move):
        return 2.0, None
    board.push(engine_move)
    # int(x / 2) truncates toward zero: -1, 0 and 1 collapse to 0, while 2 maps to 1.
    # A pre-move label in that middle bucket is replaced by -1, so alpha below becomes 1;
    # for a pre-move label of 2, alpha is -2.
    if int(pre_label / 2) == 0:
        pre_label = -1
    # The probe runs on the position after the push and its result is negated.
    # NOTE(review): presumably this converts the verdict back to the original mover's
    # point of view, and alpha is a bound that lets the probe stop early - confirm in chess_utils.
    post_label = -int(tb_probe_wdl_ab(board, alpha=-pre_label))
    board.pop()
    # Difference of the bucketed labels; int(-1 / 2) is 0, so the rewritten pre_label counts as 0.
    return 1.0 * (int(pre_label / 2) - int(post_label / 2)), engine_move


def calculate_reward(board: chess.Board, limit, args) -> "tuple[float, chess.Move | str | None]":
    """
    Compute the (unnormalised) target density of a position and the engine's move.

    Args:
        board: Position to score.
        limit: Engine search limit, forwarded to ``move_loss``.
        args: Namespace providing ``num_pieces``, ``base_reward``,
            ``reward_balance`` and ``reward_fool``.

    Returns:
        A ``(reward, move)`` tuple. Positions that are invalid, already over,
        hold more than ``args.num_pieces`` pieces or do not have exactly one
        king per side get the floor reward ``1e-5`` and ``None`` as move.
        Otherwise the reward is
        ``base_reward + reward_balance * piece_score + reward_fool * blunder ** 2``
        (never below ``1e-5``), and ``move`` is whatever ``move_loss`` returned
        as its second element.
    """
    floor_reward = 1e-5
    pieces = board.piece_map()

    # 1. Reject positions the chain should essentially never settle on.
    if not board.is_valid() or board.is_game_over() or len(pieces) > args.num_pieces:
        return floor_reward, None
    # Explicit guard: exactly one king of each colour.
    if len(board.pieces(chess.KING, chess.WHITE)) != 1 or len(board.pieces(chess.KING, chess.BLACK)) != 1:
        return floor_reward, None

    # 2. Material balance: white counts positive, black negative, kings count 0.
    piece_values = {
        chess.PAWN: 1.0,
        chess.KNIGHT: 2.7,
        chess.BISHOP: 3.2,
        chess.ROOK: 5.0,
        chess.QUEEN: 9.0,
    }
    balance = sum((piece_values.get(p.piece_type, 0.0) * (1 if p.color else -1)) for p in pieces.values())
    # Binary bonus for positions whose material difference is at most 5.
    piece_score = 1.0 if abs(balance) <= 5 else 0.0

    # 3. How badly the engine's move spoils the tablebase result.
    engine_blunder, move = move_loss(board, limit)

    # 4. Combine the terms; the blunder is squared before weighting.
    reward = args.base_reward + args.reward_balance * piece_score + args.reward_fool * (engine_blunder ** 2)
    return max(reward, floor_reward), move  # Keep the reward strictly positive.


def log_mcmc(samples: list, iteration: int, acceptance_rate: float, args):
    """
    Log statistics for one window of collected samples and save their FENs.

    Args:
        samples: ``(board, reward, move)`` tuples. Every reward must be at
            least ``args.reward_fool``. ``move`` may be a move object, ``None``
            or a plain string sentinel (``move_loss`` can return ``"null"``).
        iteration: Step value used for the scalar logs and the FEN file.
        acceptance_rate: Acceptance rate of the chain so far.
        args: Namespace providing ``log_frequency`` and ``reward_fool``.

    Nothing is logged when ``samples`` is empty.
    """
    if not samples:
        return

    # log_writer.save_fens(samples, "fens_mcmc", iteration, path="fens_mcmc")

    # Number of proposals in one logging window; 1024 mirrors `it_mult` in run_mcmc.
    num_samples = args.log_frequency * 1024
    max_error_count = 0
    success_fens = []

    for board, reward, move in samples:
        assert reward >= 1.0 * args.reward_fool
        if not move:
            # None, or a move object that evaluates as false.
            move_label = "illegal"
        elif hasattr(move, "uci"):
            move_label = move.uci()
        else:
            # String sentinel such as "null": it has no .uci(), store it as-is.
            move_label = str(move)
        success_fens.append((board.fen(), move_label))
        if reward >= 4.0 * args.reward_fool:
            max_error_count += 1

    # Every sample passed the threshold above, so all of them count as errors.
    reg_error_count = len(samples)
    error_percent = 100.0 * reg_error_count / num_samples
    win_to_loss_percent = 100.0 * max_error_count / num_samples

    # For visual logging, convert boards to tensors to use the existing log_writer function.
    # board_tensors = torch.stack([get_board_tensor(b) for b in samples])
    # log_writer.add_board_tensor_dist("val/dist", board_tensors.view(-1, 12, 64), iteration)

    log_writer.add_scalar("mcmc/acceptance_rate", acceptance_rate, iteration)
    log_writer.add_scalar("val/error_percent", error_percent, iteration)
    log_writer.add_scalar("val/win_to_loss_percent", win_to_loss_percent, iteration)
    # log_writer.save_fens(success_fens, "fens_mcmc", iteration, path="fens_mcmc")
    log_writer.save_fens(success_fens, "fens", iteration, path="fens_v2")

    print(
        f"MCMC Iter {iteration}: Acceptance Rate: {acceptance_rate:.4f}, Error Pct: {error_percent:.2f}%")


def save_checkpoint(filepath, iteration, board, accept_count, random_state):
    """
    Atomically save the MCMC state to ``filepath``.

    The state is pickled to a sibling ``<filepath>.tmp`` file first and then
    moved over the target with ``os.replace``. An interruption or a pickling
    error therefore never truncates an existing checkpoint.

    Args:
        filepath: Destination of the checkpoint file.
        iteration: Iteration number the state belongs to.
        board: Current board; only its FEN string is stored.
        accept_count: Number of accepted proposals so far.
        random_state: RNG state as returned by ``random.getstate()``.

    Raises:
        Whatever ``open`` / ``pickle.dump`` / ``os.replace`` raise; the
        temporary file is removed before the exception propagates.
    """
    checkpoint = {
        'iteration': iteration,
        'board_fen': board.fen(),
        'accept_count': accept_count,
        'random_state': random_state
    }
    tmp_path = f"{filepath}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(checkpoint, f)
        os.replace(tmp_path, filepath)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial file, keep the old checkpoint.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_checkpoint(filepath, args):
    """
    Restore the chain from a checkpoint, or build a random starting position.

    Args:
        filepath: Location of the pickled checkpoint.
        args: Namespace providing ``load`` and ``num_pieces``.

    Returns:
        ``(start_iteration, accept_count, board)``. When ``args.load`` is set
        and the file exists, the stored RNG state is reinstated and the
        iteration following the saved one is returned. Otherwise the result is
        ``(1, 0, board)`` where ``board`` holds two kings that neither share a
        square nor stand next to each other, plus ``args.num_pieces - 2``
        random non-king pieces on distinct empty squares.
    """
    checkpoint_exists = os.path.exists(filepath)

    if args.load and checkpoint_exists:
        # NOTE: pickle is only safe for checkpoint files this program wrote itself.
        with open(filepath, 'rb') as handle:
            state = pickle.load(handle)
        random.setstate(state['random_state'])
        resume_iter = state['iteration'] + 1
        print(f"Loading with FEN: {state['board_fen']} at iteration {resume_iter}")
        print(f"Accept count is {state['accept_count']}")
        return resume_iter, state['accept_count'], chess.Board(state['board_fen'])

    # No usable checkpoint: report why and start from a fresh random position.
    print(f"Default load state is {args.load} and filepath exists: {checkpoint_exists}")
    existence = "exists" if checkpoint_exists else "does not exist"
    print(f"Filepath {filepath} {existence}")

    start_board = chess.Board(None)

    # Kings first: the black king may not sit on or next to the white king.
    white_king = random.choice(chess.SQUARES)
    forbidden = {white_king} | set(chess.SquareSet(chess.BB_KING_ATTACKS[white_king]))
    black_king = random.choice([sq for sq in chess.SQUARES if sq not in forbidden])
    start_board.set_piece_at(white_king, chess.Piece(chess.KING, chess.WHITE))
    start_board.set_piece_at(black_king, chess.Piece(chess.KING, chess.BLACK))

    # Remaining pieces: square, then type, then colour (this order fixes RNG consumption).
    for _ in range(args.num_pieces - 2):
        free_squares = [sq for sq in chess.SQUARES if start_board.piece_at(sq) is None]
        spot = random.choice(free_squares)
        kind = random.choice([chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN])
        side = random.choice([chess.WHITE, chess.BLACK])
        start_board.set_piece_at(spot, chess.Piece(kind, side))

    return 1, 0, start_board


def run_mcmc(args):
    """
    Run the Metropolis-Hastings sampling loop configured by ``args``.

    Each step proposes a neighbouring position, scores it with
    ``calculate_reward`` and accepts it with probability
    ``min(1, proposed_reward / current_reward)``. After the burn-in, every
    proposal whose reward reaches ``args.reward_fool`` is recorded as a sample,
    whether or not the chain accepted it. Every ``args.log_frequency * 1024``
    steps the collected samples (if any) are logged and a checkpoint is
    written.

    Args:
        args: Namespace providing ``tb_path``, ``engine``, ``nodes``, ``name``,
            ``num_pieces``, ``mcmc_steps``, ``burn_in``, ``log_frequency``,
            ``load`` and the reward weights; an optional ``depth`` attribute
            is used for naming the run when ``nodes`` is ``None``.

    Raises:
        ValueError: If neither ``args.nodes`` nor ``args.depth`` is available.
            This is raised before any engine or tablebase is opened.
    """
    from contextlib import ExitStack

    # `--depth` may not be defined by the argument parser; read it defensively and
    # fail early with a clear message rather than after the engine has been started.
    depth = getattr(args, "depth", None)
    if args.nodes is None and depth is None:
        raise ValueError(
            "No engine search limit configured: pass --nodes (args.nodes) or provide args.depth.")

    limit_str = f"{depth}" if args.nodes is None else f"n{args.nodes}"
    setting_str = f"{args.engine}_us__move_{limit_str}_{args.num_pieces}"

    it_mult = 1024  # Our batch size is 1, so this matches what we have in the main code.
    log_period = args.log_frequency * it_mult

    config.tb_path = args.tb_path

    # Each resource registers its cleanup right after it is acquired, so everything
    # is released (in reverse order) on normal exit, on errors and on Ctrl-C.
    with ExitStack() as cleanup:
        open_egtb()
        cleanup.callback(close_egtb)

        initialize_engine(args.engine, uci_separate=True)
        cleanup.callback(close_engine)
        limit = chess.engine.Limit(nodes=args.nodes)

        log_writer.init(log_path=os.path.join("..", "tensorboard", setting_str), name=args.name)
        cleanup.callback(log_writer.close)

        # Initialize the chain state from a checkpoint or from a random position.
        checkpoint_path = os.path.join("..", "checkpoints", setting_str, args.name)
        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint_file = os.path.join(checkpoint_path, "mcmc_checkpoint.pkl")
        start_iter, accept_count, current_board = load_checkpoint(checkpoint_file, args)
        if start_iter > 1:
            print(f"Resuming experiment from iteration {start_iter - 1}")

        current_reward, _ = calculate_reward(current_board, limit, args)

        samples = []
        # === MCMC Sampling Loop ===
        for i in trange(start_iter, args.mcmc_steps + 1):
            proposed_board = propose(current_board)
            proposed_reward, engine_move = calculate_reward(proposed_board, limit, args)

            # Metropolis acceptance; rewards are always >= 1e-5, so the ratio is defined.
            acceptance_prob = min(1.0, proposed_reward / current_reward)
            if random.random() < acceptance_prob:
                current_board = proposed_board
                current_reward = proposed_reward
                accept_count += 1

            if i > args.burn_in and proposed_reward >= args.reward_fool:
                samples.append((proposed_board.copy(), proposed_reward, engine_move))

            if i % log_period == 0:
                if samples:
                    log_mcmc(samples, i // it_mult, accept_count / i, args)
                    samples = []
                # Checkpoint at every window boundary, even when no sample was found,
                # so a long unproductive stretch can still be resumed.
                save_checkpoint(checkpoint_file, i, current_board, accept_count, random.getstate())


def main():
    """Parse the command-line options and start the MCMC run."""
    parser = argparse.ArgumentParser(description='Run PyChess MCMC baseline for chess position generation.')
    # Arguments mirroring mcmc_baseline.py and main.py
    parser.add_argument('--mcmc-steps', type=int, default=1_000_000_000)
    parser.add_argument('--burn-in', type=int, default=50_000)
    parser.add_argument('--log-frequency', type=int, default=20_000)
    parser.add_argument('--name', type=str, default=get_default_experiment_name())
    parser.add_argument('--num-pieces', type=int, default=5)
    parser.add_argument('--engine', type=str, default='Winter')
    # NOTE: there is no --depth option; --nodes is the only search limit exposed here.
    parser.add_argument('--nodes', type=int, default=None)
    parser.add_argument('--tb-path', type=str, default="/home/Chess/TB_Merged")
    parser.add_argument('--base-reward', type=float, default=0.1)
    parser.add_argument('--reward-balance', type=float, default=0.9)
    parser.add_argument('--reward-fool', type=float, default=125)
    parser.add_argument('--load', default=False, action=argparse.BooleanOptionalAction,
                        help="Whether to continue from a previous save point.")

    args = parser.parse_args()
    if args.nodes is None:
        # Without a node limit the run cannot be configured; report it as a usage error
        # here instead of failing later inside run_mcmc.
        parser.error("--nodes is required: no other engine search limit option is available")
    print(args)
    run_mcmc(args)


# Start the sampler only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
