180 lines
5.9 KiB
Java
180 lines
5.9 KiB
Java
|
package Searches;
|
||
|
|
||
|
import java.util.Collections;
|
||
|
import java.util.ArrayList;
|
||
|
import java.util.Comparator;
|
||
|
import java.util.Random;
|
||
|
|
||
|
/**
 * Monte Carlo Tree Search (MCTS) is a heuristic search algorithm
 * used in decision-making problems, especially games.
 *
 * <p>This implementation is a skeleton: the "simulation" step is a random
 * coin flip (see {@link #simulateRandomPlay(Node)}) rather than a real game
 * playout, and every expansion adds a fixed number of children.
 *
 * See more: https://en.wikipedia.org/wiki/Monte_Carlo_tree_search,
 * https://www.baeldung.com/java-monte-carlo-tree-search
 */
public class MonteCarloTreeSearch {

    /**
     * A node of the game tree. Kept as a (non-static) inner class so that
     * existing callers using {@code mcts.new Node(...)} keep compiling.
     */
    public class Node {
        Node parent;
        ArrayList<Node> childNodes;
        boolean isPlayersTurn; // True if it is the player's turn.
        boolean playerWon; // True if the player won; false if the opponent won.
        int score;
        int visitCount;

        /**
         * Creates a detached node (no parent, not the player's turn).
         * The child list is initialised so the node can be expanded safely.
         */
        public Node() {
            this(null, false);
        }

        /**
         * Creates a node with an empty child list and zeroed statistics.
         *
         * @param parent        Parent node, or {@code null} for a root.
         * @param isPlayersTurn True if it is the player's turn at this node.
         */
        public Node(Node parent, boolean isPlayersTurn) {
            this.parent = parent;
            this.childNodes = new ArrayList<>();
            this.isPlayersTurn = isPlayersTurn;
            this.playerWon = false;
            this.score = 0;
            this.visitCount = 0;
        }
    }

    static final int WIN_SCORE = 10;
    static final int TIME_LIMIT = 500; // Time the algorithm will be running for (in milliseconds).

    /** Number of children created every time a node is expanded. */
    private static final int CHILDREN_PER_EXPANSION = 10;

    /** Exploration weight of the UCT formula (approximately sqrt(2)). */
    private static final double EXPLORATION_WEIGHT = 1.41;

    /** Shared random source for the simulated playouts. */
    private final Random random = new Random();

    public static void main(String[] args) {
        MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();

        mcts.monteCarloTreeSearch(mcts.new Node(null, true));
    }

    /**
     * Explores a game tree using Monte Carlo Tree Search (MCTS)
     * and returns the most promising node.
     *
     * <p>Runs for {@link #TIME_LIMIT} milliseconds, then prints the statistics
     * of the root's children and the 1-based index of the chosen child.
     *
     * @param rootNode Root node of the game tree.
     * @return The most promising child of the root node.
     */
    public Node monteCarloTreeSearch(Node rootNode) {
        // Expand the root node.
        addChildNodes(rootNode, CHILDREN_PER_EXPANSION);

        // nanoTime is monotonic, so the budget is immune to wall-clock adjustments.
        final long deadline = System.nanoTime() + TIME_LIMIT * 1_000_000L;

        // Explore the tree until the time limit is reached.
        while (System.nanoTime() - deadline < 0) {
            // Get a promising node using UCT.
            Node promisingNode = getPromisingNode(rootNode);

            // Expand the promising node.
            if (promisingNode.childNodes.isEmpty()) {
                addChildNodes(promisingNode, CHILDREN_PER_EXPANSION);
            }

            simulateRandomPlay(promisingNode);
        }

        Node winnerNode = getWinnerNode(rootNode);
        printScores(rootNode);
        System.out.format("\nThe optimal node is: %02d\n", rootNode.childNodes.indexOf(winnerNode) + 1);

        return winnerNode;
    }

    /**
     * Appends new children to a node. Each child belongs to the opposite
     * side of {@code node} (its {@code isPlayersTurn} flag is negated).
     *
     * @param node       Node to expand.
     * @param childCount Number of children to append.
     */
    public void addChildNodes(Node node, int childCount) {
        for (int i = 0; i < childCount; i++) {
            node.childNodes.add(new Node(node, !node.isPlayersTurn));
        }
    }

    /**
     * Uses UCT to find a promising child node to be explored.
     *
     * UCT: Upper Confidence bounds applied to Trees.
     *
     * @param rootNode Root node of the tree.
     * @return The most promising node according to UCT.
     */
    public Node getPromisingNode(Node rootNode) {
        Node promisingNode = rootNode;

        // Descend until a node that hasn't been expanded is found.
        while (!promisingNode.childNodes.isEmpty()) {
            // Start below every possible value; Double.MIN_VALUE would be wrong
            // here because it is the smallest POSITIVE double, not the most negative.
            double bestUct = Double.NEGATIVE_INFINITY;
            int bestIndex = 0;

            // Iterate through child nodes and pick the most promising one
            // using UCT (Upper Confidence bounds applied to Trees).
            for (int i = 0; i < promisingNode.childNodes.size(); i++) {
                Node childNode = promisingNode.childNodes.get(i);

                // A child that has never been visited has the highest
                // possible UCT value, so it is taken immediately.
                if (childNode.visitCount == 0) {
                    bestIndex = i;
                    break;
                }

                double exploitation = (double) childNode.score / childNode.visitCount;
                double exploration = EXPLORATION_WEIGHT
                        * Math.sqrt(Math.log(promisingNode.visitCount) / (double) childNode.visitCount);
                double uct = exploitation + exploration;

                if (uct > bestUct) {
                    bestUct = uct;
                    bestIndex = i;
                }
            }

            promisingNode = promisingNode.childNodes.get(bestIndex);
        }

        return promisingNode;
    }

    /**
     * Simulates a random play from a node's current state
     * and back propagates the result.
     *
     * <p>Every node from {@code promisingNode} up to the root gets its visit
     * count incremented; a node additionally receives {@link #WIN_SCORE} when
     * the side to move at that node is the simulated winner.
     *
     * @param promisingNode Node that will be simulated.
     */
    public void simulateRandomPlay(Node promisingNode) {
        // The following line randomly determines whether the simulated play is a win or loss.
        // To use the MCTS algorithm correctly this should be a simulation of the node's current
        // state of the game until it finishes (if possible) and use an evaluation function to
        // determine how good or bad the play was.
        // e.g. Play tic tac toe choosing random squares until the game ends.
        promisingNode.playerWon = (random.nextInt(6) == 0);

        final boolean isPlayerWinner = promisingNode.playerWon;

        // Back propagation of the random play.
        for (Node current = promisingNode; current != null; current = current.parent) {
            current.visitCount++;

            // Add winning scores to both player and opponent depending on the turn:
            // the node is credited when the side to move at it is the winner.
            // NOTE(review): many MCTS formulations credit the side that moved INTO the
            // node instead; confirm this perspective is the intended one.
            if (current.isPlayersTurn == isPlayerWinner) {
                current.score += WIN_SCORE;
            }
        }
    }

    /**
     * Picks the child of the root with the highest accumulated score.
     *
     * @param rootNode Root node whose children are compared.
     * @return The child with the maximum score.
     * @throws java.util.NoSuchElementException if the root has no children.
     */
    public Node getWinnerNode(Node rootNode) {
        return Collections.max(rootNode.childNodes, Comparator.comparingInt(c -> c.score));
    }

    /**
     * Prints a table with the score and visit count of every child of the root.
     *
     * @param rootNode Root node whose children are listed.
     */
    public void printScores(Node rootNode) {
        System.out.println("N.\tScore\t\tVisits");

        for (int i = 0; i < rootNode.childNodes.size(); i++) {
            Node child = rootNode.childNodes.get(i);
            System.out.println(String.format("%02d\t%d\t\t%d", i + 1, child.score, child.visitCount));
        }
    }
}
|