"""FFI - Foreign Function Interface.

ctypes mirror of the function interfaces defined in our C modules. Every
struct layout and signature below is maintained by hand alongside the C
declarations, so in effect there are two copies of each interface to keep in
sync; whether that is a good way to test the C code is an open question.
"""

import ctypes as C


def _lib_path():
    """Return the shared-library path handed to the loader below.

    The path is relative ("./build/..."), so it resolves against the current
    working directory; this assumes the process is started from the directory
    in which the makefile produces ``build/``.
    """
    return "./build/libchess.so"


# Loaded at import time: importing this module fails (OSError) if the
# library cannot be loaded.
_lib = C.CDLL(str(_lib_path()))

# File letter -> 0-based file index ("a" -> 0 ... "h" -> 7).
FILES = dict(zip("abcdefgh", range(8)))

# Side indices; BOTH presumably selects the combined entry of Board.occ.
WHITE, BLACK, BOTH = range(3)

# Piece indices: uppercase P..K are 0-5, lowercase p..k are 6-11
# (presumably white and black respectively).
P, N, B, R, Q, K, p, n, b, r, q, k = range(12)


class Board(C.Structure):
    """ctypes view of the C board struct; field order and types must match it."""

    _fields_ = [
        ("pieces", C.c_uint64 * 12),      # 12 masks, presumably indexed by P..k
        ("occ", C.c_uint64 * 3),          # presumably indexed by WHITE/BLACK/BOTH
        ("king_square", C.c_uint64 * 2),  # 2 entries, presumably one per side
        ("castling_rights", C.c_uint8),
        ("ep_square", C.c_int),
        ("side_to_move", C.c_int),
        ("halfmove_clock", C.c_int),
        ("fullmove_number", C.c_int),
    ]


class Move(C.Structure):
    """ctypes view of the C move struct; field order and types must match it.

    "from" is a Python keyword, so that field has to be read with
    ``getattr(move, "from")``.
    """

    _fields_ = [
        ("from", C.c_uint16),
        ("to", C.c_uint16),
        ("piece", C.c_uint8),
        ("promo", C.c_uint8),
        ("flags", C.c_uint8),
    ]


# Required entry points: if the library lacks either symbol, the attribute
# access on _lib raises AttributeError and the import fails.
_lib.load_fen.argtypes = (C.POINTER(Board), C.c_char_p)
_lib.load_fen.restype = C.c_int

_lib.gen_pseudo_moves.argtypes = (C.POINTER(Board), C.POINTER(Move), C.c_bool)
_lib.gen_pseudo_moves.restype = C.c_int


def _bind_opt(name, argtypes=(), restype=None):
    """Look up the optional C function ``name`` and set its signature.

    Returns the configured ctypes function, or None when the library does not
    export ``name``, so callers can test availability with a truth check.
    """
    fn = getattr(_lib, name, None)
    if fn is None:
        return None
    fn.argtypes = argtypes
    fn.restype = restype
    return fn


# Attack-cache builders take no arguments and return nothing, which is
# exactly _bind_opt's default signature (argtypes=(), restype=None).
create_knight_attack_cache = _bind_opt("create_knight_attack_cache")
create_king_attack_cache = _bind_opt("create_king_attack_cache")
create_pawn_attack_cache = _bind_opt("create_pawn_attack_cache")

# PAWN move generation.
# Pawn generators: (Board*, Move* buffer, int*) -> void.
# The int* is presumably a move counter maintained by the C side — confirm.
PAWN_SIG = ((C.POINTER(Board), C.POINTER(Move), C.POINTER(C.c_int)), None)

gen_white_pawn_quiet_pushes = _bind_opt("gen_white_pawn_quiet_pushes", *PAWN_SIG)
gen_black_pawn_quiet_pushes = _bind_opt("gen_black_pawn_quiet_pushes", *PAWN_SIG)

gen_white_pawn_push_promotions = _bind_opt(
    "gen_white_pawn_push_promotions", *PAWN_SIG
)
gen_black_pawn_push_promotions = _bind_opt(
    "gen_black_pawn_push_promotions", *PAWN_SIG
)

gen_white_pawn_captures = _bind_opt("gen_white_pawn_captures", *PAWN_SIG)
gen_black_pawn_captures = _bind_opt("gen_black_pawn_captures", *PAWN_SIG)

gen_white_pawn_capture_promotions = _bind_opt(
    "gen_white_pawn_capture_promotions", *PAWN_SIG
)
gen_black_pawn_capture_promotions = _bind_opt(
    "gen_black_pawn_capture_promotions", *PAWN_SIG
)

# Non pawn move generation: the pawn signature plus a trailing C bool flag.
PIECE_SIG = (
    (C.POINTER(Board), C.POINTER(Move), C.POINTER(C.c_int), C.c_bool),
    None,
)

gen_knight_moves = _bind_opt("gen_knight_moves", *PIECE_SIG)
gen_bishop_moves = _bind_opt("gen_bishop_moves", *PIECE_SIG)
gen_rook_moves = _bind_opt("gen_rook_moves", *PIECE_SIG)
gen_queen_moves = _bind_opt("gen_queen_moves", *PIECE_SIG)
gen_king_moves = _bind_opt("gen_king_moves", *PIECE_SIG)

# Attack checks and legal move generation. Each argtypes tuple sits next to
# the single binding that uses it.
ATTACKED_SIG = (C.POINTER(Board), C.c_int, C.c_int)
square_attacked = _bind_opt(
    "square_attacked", argtypes=ATTACKED_SIG, restype=C.c_bool
)

INCHECK_ARGS = (C.POINTER(Board), C.c_int)
in_check = _bind_opt("in_check", argtypes=INCHECK_ARGS, restype=C.c_bool)

# Returns a 64-bit mask (c_uint64); presumably a bitboard of attackers.
ATTACKERS_TO = (C.POINTER(Board), C.c_int, C.c_int)
attackers_to = _bind_opt(
    "attackers_to", argtypes=ATTACKERS_TO, restype=C.c_uint64
)

GEN_LEGAL_MOVES = (C.POINTER(Board), C.POINTER(Move))
get_legal_moves = _bind_opt(
    "get_legal_moves", argtypes=GEN_LEGAL_MOVES, restype=C.c_int
)

# perft(board, depth?) -> c_uint64; the int argument is presumably depth.
PERFT_SIG = (C.POINTER(Board), C.c_int)
perft = _bind_opt("perft", argtypes=PERFT_SIG, restype=C.c_uint64)

# Attack cache tables.
# ctypes array types overlaying the exported C attack tables: 64 uint64
# entries (presumably one per square); pawns have two such rows.
KnightArr = (C.c_uint64 * 64)
KingArr = (C.c_uint64 * 64)
PawnRow = (C.c_uint64 * 64)
PawnArr = PawnRow * 2


def _table_opt(ctype, name):
    """Return exported C global ``name`` viewed as ``ctype``, or None.

    ``in_dll`` raises ValueError when the symbol is not exported. Each table
    is looked up on its own, so one missing symbol does not hide the others.
    """
    try:
        return ctype.in_dll(_lib, name)
    except ValueError:  # symbol not exported
        return None


KNIGHT_ATTACKS = _table_opt(KnightArr, "KNIGHT_ATTACKS")
KING_ATTACKS = _table_opt(KingArr, "KING_ATTACKS")
PAWN_ATTACKS = _table_opt(PawnArr, "PAWN_ATTACKS")


def init_attack_caches():
    """Run each available ``create_*_attack_cache`` C routine.

    The knight, king and pawn builders are run in that order. Builders whose
    binding is None (symbol not exported) are skipped.
    """
    for create in (
        create_knight_attack_cache,
        create_king_attack_cache,
        create_pawn_attack_cache,
    ):
        if create:
            create()


def load_fen(board, fen):
    """Load FEN string ``fen`` into ``board`` via the C ``load_fen``.

    ``fen`` is encoded as ASCII (UnicodeEncodeError otherwise). Returns the C
    function's int result unchanged.
    """
    return _lib.load_fen(C.byref(board), fen.encode("ascii"))


def gen_moves(board, captures_only=False, cap=256):
    """Run the C ``gen_pseudo_moves`` into a fresh buffer of ``cap`` moves.

    Returns ``(buf, n)``: the ctypes ``Move`` array and the int returned by
    the C function (presumably the number of moves written).

    NOTE(review): ``cap`` is not passed to C. The C side must never write
    more than ``cap`` moves into ``buf``.
    """
    buf = (Move * cap)()
    n = _lib.gen_pseudo_moves(C.byref(board), buf, captures_only)
    return buf, n


def gen_legal_moves(board, out):
    """Fill ``out`` (a ctypes ``Move`` array) via the C ``get_legal_moves``.

    Returns the C result as an int. Raises TypeError if the library does not
    export ``get_legal_moves`` (its binding is None).
    """
    return int(get_legal_moves(C.byref(board), out))


def is_square_attacked(board, sq, by):
    """Return the C ``square_attacked(board, sq, by)`` result as a bool."""
    return bool(square_attacked(C.byref(board), int(sq), int(by)))


def is_in_check(board, side):
    """Return the C ``in_check(board, side)`` result as a bool."""
    return bool(in_check(C.byref(board), int(side)))


def get_attackers_to(board, sq, by):
    """Return the C ``attackers_to(board, sq, by)`` 64-bit mask as an int."""
    return int(attackers_to(C.byref(board), int(sq), int(by)))


def sq_to_coord(sq):
    """Return the algebraic name of square index ``sq`` (0 -> "a1", 63 -> "h8")."""
    return chr(ord('a') + (sq % 8)) + chr(ord('1') + (sq // 8))


def sq(name):
    """Return the 0..63 index of algebraic square ``name`` (e.g. "e4"; a1 == 0).

    Raises KeyError for an unknown file letter. Raises ValueError for a name
    that is not exactly two characters or whose rank is outside 1..8.
    """
    if len(name) != 2:
        raise ValueError(f"bad square name: {name!r}")
    f = FILES[name[0].lower()]
    r = int(name[1]) - 1
    if not 0 <= r < 8:
        raise ValueError(f"bad square name: {name!r}")
    return f + 8 * r


def bb_from(*algebraic):
    """Return a mask with one bit set per named square, e.g. bb_from("a1", "h8")."""
    m = 0
    for s in algebraic:
        m |= (1 << sq(s))
    return m


def popcount(x: int) -> int:
    """Return the number of set bits in ``x``."""
    return x.bit_count()


def draw_bb(mask, origin=None):
    """Print ``mask`` as an 8x8 board, rank 8 at the top.

    Set bits are shown as 'x' and clear bits as '.'. Square index ``origin``,
    if given, is shown as 'O' regardless of its bit.
    """
    print("\n")
    rows = []
    for r in range(7, -1, -1):
        cells = []
        for f in range(8):
            sqi = r * 8 + f
            if origin is not None and sqi == origin:
                cells.append('O')
            elif (mask >> sqi) & 1:
                cells.append('x')
            else:
                cells.append('.')
        rows.append(f"{r+1} " + " ".join(cells))
    # Two leading spaces match the "N " rank label, so the file letters line
    # up under their columns.
    rows.append("  " + " ".join(FILES))
    print("\n".join(rows), "\n")