Add python c binding class interface #29

Merged
Josh merged 10 commits from 28-move-bindings-to-class into main 2025-08-22 19:28:03 +00:00
Showing only changes of commit dbb2b8ae42 - Show all commits

View File

@@ -1,188 +0,0 @@
"""
FFI - Foreign Function Interface
This module needs to reflect the function interfaces that are
defined in our C modules. This may or may not be a good way to
test C since we essentially need to maintain two sets of interfaces.
"""
import ctypes as C
def _lib_path():
# Just use a relative path from makefile.
return "./build/libchess.so"
_lib = C.CDLL(str(_lib_path()))
FILES = {c:i for i,c in enumerate("abcdefgh")}
WHITE, BLACK, BOTH = 0, 1, 2
P, N, B, R, Q, K, p, n, b, r, q, k = range(12)
class Board(C.Structure):
_fields_ = [
("pieces", C.c_uint64 * 12),
("occ", C.c_uint64 * 3),
("king_square", C.c_uint64 * 2),
("castling_rights", C.c_uint8),
("ep_square", C.c_int),
("side_to_move", C.c_int),
("halfmove_clock", C.c_int),
("fullmove_number", C.c_int),
]
class Move(C.Structure):
_fields_ = [
("from", C.c_uint16),
("to", C.c_uint16),
("piece", C.c_uint8),
("promo", C.c_uint8),
("flags", C.c_uint8),
]
_lib.load_fen.argtypes = (C.POINTER(Board), C.c_char_p)
_lib.load_fen.restype = C.c_int
_lib.gen_pseudo_moves.argtypes = (C.POINTER(Board), C.POINTER(Move), C.c_bool)
_lib.gen_pseudo_moves.restype = C.c_int
def _bind_opt(name, argtypes=(), restype=None):
fn = getattr(_lib, name, None)
if fn is not None:
fn.argtypes = argtypes
fn.restype = restype
return fn
create_knight_attack_cache = _bind_opt("create_knight_attack_cache", (), None)
create_king_attack_cache = _bind_opt("create_king_attack_cache", (), None)
create_pawn_attack_cache = _bind_opt("create_pawn_attack_cache", (), None)
# PAWN move generation.
PAWN_SIG = (C.POINTER(Board), C.POINTER(Move), C.POINTER(C.c_int)), None
gen_white_pawn_quiet_pushes = _bind_opt("gen_white_pawn_quiet_pushes", *PAWN_SIG)
gen_black_pawn_quiet_pushes = _bind_opt("gen_black_pawn_quiet_pushes", *PAWN_SIG)
gen_white_pawn_push_promotions = _bind_opt("gen_white_pawn_push_promotions", *PAWN_SIG)
gen_black_pawn_push_promotions = _bind_opt("gen_black_pawn_push_promotions", *PAWN_SIG)
gen_white_pawn_captures = _bind_opt("gen_white_pawn_captures", *PAWN_SIG)
gen_black_pawn_captures = _bind_opt("gen_black_pawn_captures", *PAWN_SIG)
gen_white_pawn_capture_promotions = _bind_opt("gen_white_pawn_capture_promotions", *PAWN_SIG)
gen_black_pawn_capture_promotions = _bind_opt("gen_black_pawn_capture_promotions", *PAWN_SIG)
# Non pawn move generation.
PIECE_SIG = ((C.POINTER(Board), C.POINTER(Move), C.POINTER(C.c_int), C.c_bool), None)
gen_knight_moves = _bind_opt("gen_knight_moves", *PIECE_SIG)
gen_bishop_moves = _bind_opt("gen_bishop_moves", *PIECE_SIG)
gen_rook_moves = _bind_opt("gen_rook_moves", *PIECE_SIG)
gen_queen_moves = _bind_opt("gen_queen_moves", *PIECE_SIG)
gen_king_moves = _bind_opt("gen_king_moves", *PIECE_SIG)
# Attack checks.
ATTACKED_SIG = (C.POINTER(Board), C.c_int, C.c_int)
INCHECK_ARGS = (C.POINTER(Board), C.c_int)
ATTACKERS_TO = (C.POINTER(Board), C.c_int, C.c_int)
GEN_LEGAL_MOVES = (C.POINTER(Board), C.POINTER(Move))
square_attacked = _bind_opt("square_attacked", ATTACKED_SIG, C.c_bool)
in_check = _bind_opt("in_check", INCHECK_ARGS, C.c_bool)
attackers_to = _bind_opt("attackers_to", ATTACKERS_TO, C.c_uint64)
get_legal_moves = _bind_opt("get_legal_moves", GEN_LEGAL_MOVES, C.c_int)
PERFT_SIG = (C.POINTER(Board), C.c_int)
perft = _bind_opt("perft", PERFT_SIG, C.c_uint64)
# Attack cache tables.
KnightArr = (C.c_uint64 * 64)
KingArr = (C.c_uint64 * 64)
PawnRow = (C.c_uint64 * 64)
PawnArr = PawnRow * 2
try:
KNIGHT_ATTACKS = KnightArr.in_dll(_lib, "KNIGHT_ATTACKS")
KING_ATTACKS = KingArr.in_dll(_lib, "KING_ATTACKS")
PAWN_ATTACKS = PawnArr.in_dll(_lib, "PAWN_ATTACKS")
except ValueError:
KNIGHT_ATTACKS = KING_ATTACKS = PAWN_ATTACKS = None # symbols not exported
def init_attack_caches():
if create_knight_attack_cache: create_knight_attack_cache()
if create_king_attack_cache: create_king_attack_cache()
if create_pawn_attack_cache: create_pawn_attack_cache()
def load_fen(board, fen):
return _lib.load_fen(C.byref(board), fen.encode("ascii"))
def gen_moves(board, captures_only=False, cap=256):
buf = (Move * cap)()
n = _lib.gen_pseudo_moves(C.byref(board), buf, captures_only)
return buf, n
def gen_legal_moves(board, out):
return int(get_legal_moves(C.byref(board), out))
def is_square_attacked(board, sq, by):
return bool(square_attacked(C.byref(board), int(sq), int(by)))
def is_in_check(board, side):
return bool(in_check(C.byref(board), int(side)))
def get_attackers_to(board, sq, by):
return int(attackers_to(C.byref(board), int(sq), int(by)))
def sq_to_coord(sq):
return chr(ord('a') + (sq % 8)) + chr(ord('1') + (sq // 8))
def sq(name):
f = FILES[name[0].lower()]
r = int(name[1]) - 1
return f + 8 * r
def bb_from(*algebraic):
m = 0
for s in algebraic:
m |= (1 << sq(s))
return m
def popcount(x: int) -> int:
return x.bit_count()
def draw_bb(mask, origin=None):
print("\n")
lines = []
for r in range(7, -1, -1):
row = []
for f in range(8):
sqi = r * 8 + f
bit = (mask >> sqi) & 1
if origin is not None and sqi == origin:
ch = 'O'
elif bit:
ch = 'x'
else:
ch = '.'
row.append(ch)
lines.append(f"{r+1} " + " ".join(row))
lines.append(" " + " ".join(FILES))
lines = "\n".join(lines)
print(lines, "\n")