From 5ddb3596751ddf215010416f78413988d05e228a Mon Sep 17 00:00:00 2001 From: Josh Date: Fri, 22 Aug 2025 14:28:48 -0400 Subject: [PATCH] Fix attack to tests --- binding/python_c_ffi.py | 72 ++++++++++++++++++++--------------------- test/base.py | 8 +++-- test/test_attack_to.py | 26 ++++++--------- 3 files changed, 52 insertions(+), 54 deletions(-) diff --git a/binding/python_c_ffi.py b/binding/python_c_ffi.py index c570d1b..0394b37 100644 --- a/binding/python_c_ffi.py +++ b/binding/python_c_ffi.py @@ -28,12 +28,12 @@ FILES = {c:i for i,c in enumerate("abcdefgh")} WHITE, BLACK, BOTH = 0, 1, 2 P, N, B, R, Q, K, p, n, b, r, q, k = range(12) + """ Register C function bindings and interfaces in a dictionary. """ PAWN_MOVE_SIG = ((C.POINTER(Board), C.POINTER(Move), C.POINTER(C.c_int)), None) PIECE_MOVE_SIG = ((C.POINTER(Board), C.POINTER(Move), C.POINTER(C.c_int), C.c_bool), None) - FFI_SPEC = { "create_knight_attack_cache": None, "create_king_attack_cache": None, @@ -61,9 +61,40 @@ FFI_SPEC = { } +def draw_bb(mask, origin=None): + print("\n") + lines = [] + for r in range(7, -1, -1): + row = [] + for f in range(8): + sqi = r * 8 + f + bit = (mask >> sqi) & 1 + if origin is not None and sqi == origin: + ch = 'O' + elif bit: + ch = 'x' + else: + ch = '.' + row.append(ch) + lines.append(f"{r+1} " + " ".join(row)) + lines.append(" " + " ".join(FILES)) + lines = "\n".join(lines) + print(lines, "\n") + + +def sq_to_coord(sq): + return chr(ord('a') + (sq % 8)) + chr(ord('1') + (sq // 8)) + + +def sq(name): + f = FILES[name[0].lower()] + r = int(name[1]) - 1 + return f + 8 * r + + class ChessFFI: def __init__(self): - self.lib_path = "./build/lichess.so" + self.lib_path = "./build/libchess.so" self._load_lib() self._load_constants() self._bind_functions() @@ -76,7 +107,7 @@ class ChessFFI: def load_fen(self, board, fen_string): - return self._c_load_fen(C.byref(board), fen_string.encode("ascii")) + return self._c_load_fen(board, fen_string.encode("ascii")) def gen_pseudo_moves(self, board, captures_only=False, cap=256): @@ -98,44 +129,13 @@ class ChessFFI: def attackers_to(self, board, sq, by): - return int(self._c_attackers_to(C.byref(board), int(sq), int(by))) - - - def sq_to_coord(self, sq): - return chr(ord('a') + (sq % 8)) + chr(ord('1') + (sq // 8)) - - - def sq(self, name): - f = FILES[name[0].lower()] - r = int(name[1]) - 1 - return f + 8 * r + return self._c_attackers_to(board, sq, by) def popcount(self, x): return x.bit_count() - def draw_bb(self, mask, origin=None): - print("\n") - lines = [] - for r in range(7, -1, -1): - row = [] - for f in range(8): - sqi = r * 8 + f - bit = (mask >> sqi) & 1 - if origin is not None and sqi == origin: - ch = 'O' - elif bit: - ch = 'x' - else: - ch = '.' - row.append(ch) - lines.append(f"{r+1} " + " ".join(row)) - lines.append(" " + " ".join(FILES)) - lines = "\n".join(lines) - print(lines, "\n") - - def _load_lib(self): self._lib = C.CDLL(self.lib_path) @@ -164,7 +164,7 @@ class ChessFFI: if fn is not None: fn.argtypes = argtypes - fn.restype = restype + fn.restype = restype # Prepend each function name with _c to allow us to reuse the # function names as methods on this class. This means that we diff --git a/test/base.py b/test/base.py index c8143cd..9815980 100644 --- a/test/base.py +++ b/test/base.py @@ -1,4 +1,5 @@ import unittest +from binding.python_c_ffi import ChessFFI from test.chess_ffi import Board from test.chess_ffi import KING_ATTACKS from test.chess_ffi import KNIGHT_ATTACKS @@ -11,6 +12,9 @@ from test.chess_ffi import load_fen class ChessLibTestBase(unittest.TestCase): @classmethod def setUpClass(cls): + cls.chess_ffi = ChessFFI() + + init_attack_caches() cls.KNIGHT_ATTACKS = KNIGHT_ATTACKS @@ -25,8 +29,8 @@ class ChessLibTestBase(unittest.TestCase): def load_fen(self, fen, board=None): if board: - return load_fen(board, fen) - return load_fen(self.board, fen) + return self.chess_ffi.load_fen(board, fen) + return self.chess_ffi.load_fen(self.board, fen) def gen(self, captures_only: bool = False, cap: int = 256): diff --git a/test/test_attack_to.py b/test/test_attack_to.py index cc099d3..9633c8a 100644 --- a/test/test_attack_to.py +++ b/test/test_attack_to.py @@ -1,13 +1,8 @@ -import ctypes +from binding.python_c_ffi import Board +from binding.python_c_ffi import sq +from binding.python_c_ffi import BLACK +from binding.python_c_ffi import WHITE from test.base import ChessLibTestBase -from test.chess_ffi import get_attackers_to -from test.chess_ffi import is_square_attacked -from test.chess_ffi import is_in_check -from test.chess_ffi import sq -from test.chess_ffi import Board -from test.chess_ffi import BLACK -from test.chess_ffi import WHITE -from test.chess_ffi import draw_bb class TestAttackers(ChessLibTestBase): @@ -50,8 +45,8 @@ class TestAttackers(ChessLibTestBase): for fen, sq_str, by, expected, msg in cases: with self.subTest(msg=msg, fen=fen, sq=sq_str, by=by): b = Board() - self.load_fen(fen, board=b) - got = bool(is_square_attacked(b, sq(sq_str), by)) + self.chess_ffi.load_fen(b, fen) + got = self.chess_ffi.square_attacked(b, sq(sq_str), by) self.assertEqual(expected, got, msg) @@ -81,8 +76,8 @@ class TestAttackers(ChessLibTestBase): for fen, side, expected, msg in cases: with self.subTest(msg=msg, fen=fen, side=side): b = Board() - self.load_fen(fen, board=b) - actual = bool(is_in_check(b, side)) + self.chess_ffi.load_fen(b, fen) + actual = self.chess_ffi.in_check(b, side) self.assertEqual(expected, actual, msg) @@ -132,7 +127,6 @@ class TestAttackers(ChessLibTestBase): for fen, sq_str, by, expected_cnt, msg in cases: with self.subTest(msg=msg, fen=fen, sq=sq_str, by=by): b = Board() - self.load_fen(fen, board=b) - - mask = get_attackers_to(b, sq(sq_str), by) + self.chess_ffi.load_fen(b, fen) + mask = self.chess_ffi.attackers_to(b, sq(sq_str), by) self.assertEqual(expected_cnt, int(mask).bit_count()) \ No newline at end of file