Update simulation class
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import uuid
|
||||||
from binding.python_c_ffi import Board
|
from binding.python_c_ffi import Board
|
||||||
from binding.python_c_ffi import Move
|
from binding.python_c_ffi import Move
|
||||||
from binding.python_c_ffi import ChessFFI
|
from binding.python_c_ffi import ChessFFI
|
||||||
@@ -6,11 +7,12 @@ from binding.python_c_ffi import print_board
|
|||||||
from binding.python_c_ffi import sq_to_coord
|
from binding.python_c_ffi import sq_to_coord
|
||||||
from scripts.evaluation import RandomEval
|
from scripts.evaluation import RandomEval
|
||||||
from scripts.evaluation import NegaMaxEval
|
from scripts.evaluation import NegaMaxEval
|
||||||
|
from scripts.format import LongPGNFormatter
|
||||||
|
|
||||||
|
|
||||||
YMD_HM = "%Y-%m-%d-%H-%M"
|
YMD_HM = "%Y-%m-%d-%H-%M"
|
||||||
MAX_MOVES = 256
|
MAX_MOVES = 256
|
||||||
DATA_PATH = "./data"
|
DATA_PATH = "./data/games"
|
||||||
|
|
||||||
|
|
||||||
class Engine:
|
class Engine:
|
||||||
@@ -19,11 +21,12 @@ class Engine:
|
|||||||
self.board = Board()
|
self.board = Board()
|
||||||
self.max_plys = max_plys
|
self.max_plys = max_plys
|
||||||
self.moves = []
|
self.moves = []
|
||||||
|
self.fens = []
|
||||||
self._load_attack_cache()
|
self._load_attack_cache()
|
||||||
self._seed_engine()
|
self._seed_engine()
|
||||||
|
|
||||||
random_strat = RandomEval(chess_ffi=self.chess_ffi)
|
random_strat = RandomEval(chess_ffi=self.chess_ffi)
|
||||||
nega_strat = NegaMaxEval(chess_ffi=self.chess_ffi, depth=2, cp_window=30)
|
|
||||||
if not strat_white:
|
if not strat_white:
|
||||||
self.strat_white = random_strat
|
self.strat_white = random_strat
|
||||||
else:
|
else:
|
||||||
@@ -34,21 +37,33 @@ class Engine:
|
|||||||
else:
|
else:
|
||||||
self.strat_black = strat_black
|
self.strat_black = strat_black
|
||||||
|
|
||||||
|
self.strategies = {
|
||||||
|
"white": {
|
||||||
|
"name": self.strat_white.NAME,
|
||||||
|
"params": self.strat_white.get_params()
|
||||||
|
},
|
||||||
|
"black": {
|
||||||
|
"name": self.strat_black.NAME,
|
||||||
|
"params": self.strat_black.get_params()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
def run(self, fen):
|
|
||||||
|
def run(self, fen, save=True):
|
||||||
self.chess_ffi.load_fen(self.board, fen)
|
self.chess_ffi.load_fen(self.board, fen)
|
||||||
self._clear_moves()
|
self._clear_moves()
|
||||||
plys = 0
|
plys = 0
|
||||||
|
|
||||||
|
self.fens.append(fen)
|
||||||
while plys <= self.max_plys:
|
while plys <= self.max_plys:
|
||||||
print_board(self.board)
|
print_board(self.board)
|
||||||
|
|
||||||
moves_buf = (Move * MAX_MOVES)()
|
moves_buf = (Move * MAX_MOVES)()
|
||||||
n_legal = self.chess_ffi.get_legal_moves(self.board, moves_buf)
|
n_legal = self.chess_ffi.get_legal_moves(self.board, moves_buf)
|
||||||
|
|
||||||
game_over, ending = self._is_game_over(n_legal)
|
game_over, result = self._is_game_over(n_legal)
|
||||||
if game_over:
|
if game_over:
|
||||||
print(ending)
|
print(result)
|
||||||
break
|
break
|
||||||
|
|
||||||
legal = [moves_buf[i] for i in range(n_legal)]
|
legal = [moves_buf[i] for i in range(n_legal)]
|
||||||
@@ -57,7 +72,6 @@ class Engine:
|
|||||||
else:
|
else:
|
||||||
best_move = self.strat_black.get_best_move(self.board, legal)
|
best_move = self.strat_black.get_best_move(self.board, legal)
|
||||||
|
|
||||||
|
|
||||||
new_board = Board()
|
new_board = Board()
|
||||||
if not self.chess_ffi.apply_move_on_copy(self.board, new_board, best_move):
|
if not self.chess_ffi.apply_move_on_copy(self.board, new_board, best_move):
|
||||||
print("ERROR: apply_move_on_copy failed")
|
print("ERROR: apply_move_on_copy failed")
|
||||||
@@ -66,22 +80,31 @@ class Engine:
|
|||||||
self.board = new_board
|
self.board = new_board
|
||||||
plys += 1
|
plys += 1
|
||||||
|
|
||||||
|
fen_string = self.chess_ffi.board_to_fen(self.board)
|
||||||
|
self.fens.append(fen_string)
|
||||||
|
|
||||||
move = self.to_uci(best_move)
|
move = self.to_uci(best_move)
|
||||||
self.moves.append(move)
|
self.moves.append(move)
|
||||||
|
|
||||||
if self.board.halfmove_clock >= 100:
|
if self.board.halfmove_clock >= 100:
|
||||||
print("draw (50-move rule)")
|
print("draw (50-move rule)")
|
||||||
|
result = "1/2-1/2"
|
||||||
break
|
break
|
||||||
|
|
||||||
for move in self.moves:
|
|
||||||
# print(move)
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
print("done")
|
print("done")
|
||||||
print(f"winner: {self.side_tag(not self.board.side_to_move)}")
|
|
||||||
print(len(self.moves))
|
print(len(self.moves))
|
||||||
|
|
||||||
|
if save:
|
||||||
|
print("saving game to disk.")
|
||||||
|
id = uuid.uuid4().hex
|
||||||
|
|
||||||
|
formatter = LongPGNFormatter(
|
||||||
|
save_path=DATA_PATH,
|
||||||
|
strategies=self.strategies,
|
||||||
|
result=result
|
||||||
|
)
|
||||||
|
formatter.save_game(id, self.moves, self.fens)
|
||||||
|
|
||||||
|
|
||||||
def side_tag(self, side):
|
def side_tag(self, side):
|
||||||
return 'w' if side == WHITE else 'b'
|
return 'w' if side == WHITE else 'b'
|
||||||
@@ -90,35 +113,32 @@ class Engine:
|
|||||||
def load_fen(self, fen):
|
def load_fen(self, fen):
|
||||||
self.chess_ffi.load_fen(self.board, fen)
|
self.chess_ffi.load_fen(self.board, fen)
|
||||||
|
|
||||||
|
|
||||||
def _clear_moves(self):
|
|
||||||
self.moves = []
|
|
||||||
|
|
||||||
|
|
||||||
def to_uci(self, move):
|
def to_uci(self, move):
|
||||||
fr = sq_to_coord(getattr(move, "from"))
|
fr = sq_to_coord(getattr(move, "from"))
|
||||||
to = sq_to_coord(move.to)
|
to = sq_to_coord(move.to)
|
||||||
promo = getattr(move, "promo", 0) or ""
|
|
||||||
|
# Only add a promotion letter if promo is set
|
||||||
if promo:
|
promo_letter = ""
|
||||||
MAP = {
|
promo_val = int(getattr(move, "promo", 0) or 0)
|
||||||
1:"n",
|
if promo_val:
|
||||||
2:"b",
|
# Normalize piece id to type 0..5 (P,N,B,R,Q,K)
|
||||||
3:"r",
|
pt = promo_val % 6
|
||||||
4:"q",
|
# Map N=1, B=2, R=3, Q=4
|
||||||
"n":"n",
|
letter_map = {1: "n", 2: "b", 3: "r", 4: "q"}
|
||||||
"b":"b",
|
promo_letter = letter_map.get(pt, "")
|
||||||
"r":"r",
|
|
||||||
"q":"q"
|
return f"{fr}{to}{promo_letter}"
|
||||||
}
|
|
||||||
promo = MAP.get(promo, "").lower()
|
|
||||||
return f"{fr}{to}{promo}"
|
|
||||||
|
|
||||||
|
|
||||||
def _load_attack_cache(self):
|
def _load_attack_cache(self):
|
||||||
self.chess_ffi.init_attack_cache()
|
self.chess_ffi.init_attack_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def _clear_moves(self):
|
||||||
|
self.moves = []
|
||||||
|
|
||||||
|
|
||||||
def _seed_engine(self):
|
def _seed_engine(self):
|
||||||
import ctypes as C, time
|
import ctypes as C, time
|
||||||
|
|
||||||
@@ -135,15 +155,28 @@ class Engine:
|
|||||||
def _is_game_over(self, legal_moves):
|
def _is_game_over(self, legal_moves):
|
||||||
if legal_moves == 0:
|
if legal_moves == 0:
|
||||||
if self.chess_ffi.in_check(self.board, self.board.side_to_move):
|
if self.chess_ffi.in_check(self.board, self.board.side_to_move):
|
||||||
ending = "checkmate"
|
# Checkmate.
|
||||||
|
if self.board.side_to_move == WHITE:
|
||||||
|
result = "0-1"
|
||||||
|
else:
|
||||||
|
result = "1-0"
|
||||||
else:
|
else:
|
||||||
ending = "stalemate"
|
# Stalemate.
|
||||||
return True, ending
|
result = "1/2-1/2"
|
||||||
|
return True, result
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
engine = Engine()
|
chess_ffi = ChessFFI()
|
||||||
|
nega_strat_1 = NegaMaxEval(chess_ffi=chess_ffi, depth=2, cp_window=5)
|
||||||
|
nega_strat_2 = NegaMaxEval(chess_ffi=chess_ffi, depth=1, cp_window=5)
|
||||||
|
|
||||||
|
|
||||||
|
engine = Engine(
|
||||||
|
strat_white=nega_strat_1,
|
||||||
|
strat_black=nega_strat_2,
|
||||||
|
)
|
||||||
|
|
||||||
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
||||||
engine.run(fen)
|
engine.run(fen, save=True)
|
||||||
Reference in New Issue
Block a user