Bug fixes to the RL algorithm and some tests

This commit is contained in:
2023-09-06 15:06:20 +01:00
parent 1aa8ffa8fc
commit 6d4e364f8d
6 changed files with 124 additions and 45 deletions
+34 -8
View File
@@ -63,19 +63,44 @@ class Board:
if piece != 0:
if piece.colour == GREEN:
self.greenLeft -= 1
return
continue
self.whiteLeft -= 1
def getAllMoves(self, colour):
    """Return every board reachable in one move by the given colour.

    Captures are treated as mandatory: if any piece of ``colour`` has at
    least one move whose ``skip`` entry is truthy, only such moves are
    simulated. Otherwise every valid move of every piece is simulated.

    Args:
        colour: Passed to ``self.getAllPieces`` to select the side to move.

    Returns:
        A list of the objects returned by ``self._simulateMove``, one per
        considered move, ordered by piece and then by the iteration order
        of that piece's ``getValidMoves`` mapping.
    """
    # Ask each piece for its moves exactly once; the same mapping is used
    # both to detect captures and to drive the simulation below.
    movesByPiece = [
        (piece, self.getValidMoves(piece))
        for piece in self.getAllPieces(colour)
    ]

    # Keep only the capturing moves of each piece (truthy ``skip``), and
    # drop pieces that have none.
    # NOTE(review): ``skip`` is presumably the collection of jumped pieces
    # -- confirm against getValidMoves.
    capturesByPiece = []
    for piece, validMoves in movesByPiece:
        forced = {move: skip for move, skip in validMoves.items() if skip}
        if forced:
            capturesByPiece.append((piece, forced))

    # A non-empty capture list overrides the ordinary moves.
    candidates = capturesByPiece or movesByPiece

    moves = []
    for piece, validMoves in candidates:
        for move, skip in validMoves.items():
            # Simulate on a private copy so this board is left untouched.
            tempBoard = deepcopy(self)
            tempPiece = tempBoard.getPiece(piece.row, piece.col)
            moves.append(self._simulateMove(tempPiece, move, tempBoard, skip))

    return moves
def _simulateMove(self, piece, move, board, skip):
@@ -134,6 +159,7 @@ class Board:
forcedCapture = forced
else:
forcedCapture = forced
return forcedCapture
def scoreOfTheBoard(self):
@@ -241,7 +267,7 @@ class Board:
def _decode(self, move):
# Split digits back out
str_code = str(move)
print(str_code)
# print(str_code)
start_row = int(str_code[0])
start_col = int(str_code[1])
end_row = int(str_code[2])