refactor: cleanup AhoCorasick (#5358)

This commit is contained in:
Alex Klymenko 2024-08-22 10:08:17 +02:00 committed by GitHub
parent 8a89b42cf8
commit 622a3bf795
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 24 deletions

View File

@ -14,6 +14,7 @@ package com.thealgorithms.strings;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Queue; import java.util.Queue;
@ -24,7 +25,7 @@ public final class AhoCorasick {
// Trie Node Class // Trie Node Class
private static class Node { private static class Node {
// Represents a character in the trie // Represents a character in the trie
private HashMap<Character, Node> child = new HashMap<>(); // Child nodes of the current node private final Map<Character, Node> child = new HashMap<>(); // Child nodes of the current node
private Node suffixLink; // Suffix link to another node in the trie private Node suffixLink; // Suffix link to another node in the trie
private Node outputLink; // Output link to another node in the trie private Node outputLink; // Output link to another node in the trie
private int patternInd; // Index of the pattern that ends at this node private int patternInd; // Index of the pattern that ends at this node
@ -35,7 +36,7 @@ public final class AhoCorasick {
this.patternInd = -1; this.patternInd = -1;
} }
public HashMap<Character, Node> getChild() { public Map<Character, Node> getChild() {
return child; return child;
} }
@ -148,16 +149,16 @@ public final class AhoCorasick {
} }
} }
private ArrayList<ArrayList<Integer>> initializePositionByStringIndexValue() { private List<List<Integer>> initializePositionByStringIndexValue() {
ArrayList<ArrayList<Integer>> positionByStringIndexValue = new ArrayList<>(patterns.length); // Stores positions where patterns are found in the text List<List<Integer>> positionByStringIndexValue = new ArrayList<>(patterns.length); // Stores positions where patterns are found in the text
for (int i = 0; i < patterns.length; i++) { for (int i = 0; i < patterns.length; i++) {
positionByStringIndexValue.add(new ArrayList<Integer>()); positionByStringIndexValue.add(new ArrayList<>());
} }
return positionByStringIndexValue; return positionByStringIndexValue;
} }
// Searches for patterns in the input text and records their positions // Searches for patterns in the input text and records their positions
public ArrayList<ArrayList<Integer>> searchIn(final String text) { public List<List<Integer>> searchIn(final String text) {
var positionByStringIndexValue = initializePositionByStringIndexValue(); // Initialize a list to store positions of the current pattern var positionByStringIndexValue = initializePositionByStringIndexValue(); // Initialize a list to store positions of the current pattern
Node parent = root; // Start searching from the root node Node parent = root; // Start searching from the root node
@ -187,7 +188,7 @@ public final class AhoCorasick {
// by default positionByStringIndexValue contains end-points. This function converts those // by default positionByStringIndexValue contains end-points. This function converts those
// endpoints to start points // endpoints to start points
private void setUpStartPoints(ArrayList<ArrayList<Integer>> positionByStringIndexValue) { private void setUpStartPoints(List<List<Integer>> positionByStringIndexValue) {
for (int i = 0; i < patterns.length; i++) { for (int i = 0; i < patterns.length; i++) {
for (int j = 0; j < positionByStringIndexValue.get(i).size(); j++) { for (int j = 0; j < positionByStringIndexValue.get(i).size(); j++) {
int endpoint = positionByStringIndexValue.get(i).get(j); int endpoint = positionByStringIndexValue.get(i).get(j);
@ -198,20 +199,15 @@ public final class AhoCorasick {
} }
// Class to handle pattern position recording // Class to handle pattern position recording
private static class PatternPositionRecorder { private record PatternPositionRecorder(List<List<Integer>> positionByStringIndexValue) {
private ArrayList<ArrayList<Integer>> positionByStringIndexValue;
// Constructor to initialize the recorder with the position list // Constructor to initialize the recorder with the position list
PatternPositionRecorder(final ArrayList<ArrayList<Integer>> positionByStringIndexValue) {
this.positionByStringIndexValue = positionByStringIndexValue;
}
/** /**
* Records positions for a pattern when it's found in the input text and follows * Records positions for a pattern when it's found in the input text and follows
* output links to record positions of other patterns. * output links to record positions of other patterns.
* *
* @param parent The current node representing a character in the pattern trie. * @param parent The current node representing a character in the pattern trie.
* @param currentPosition The current position in the input text. * @param currentPosition The current position in the input text.
*/ */
public void recordPatternPositions(final Node parent, final int currentPosition) { public void recordPatternPositions(final Node parent, final int currentPosition) {
// Check if the current node represents the end of a pattern // Check if the current node represents the end of a pattern
@ -229,19 +225,20 @@ public final class AhoCorasick {
} }
} }
} }
// method to search for patterns in text // method to search for patterns in text
public static Map<String, ArrayList<Integer>> search(final String text, final String[] patterns) { public static Map<String, List<Integer>> search(final String text, final String[] patterns) {
final var trie = new Trie(patterns); final var trie = new Trie(patterns);
final var positionByStringIndexValue = trie.searchIn(text); final var positionByStringIndexValue = trie.searchIn(text);
return convert(positionByStringIndexValue, patterns); return convert(positionByStringIndexValue, patterns);
} }
// method for converting results to a map // method for converting results to a map
private static Map<String, ArrayList<Integer>> convert(final ArrayList<ArrayList<Integer>> positionByStringIndexValue, final String[] patterns) { private static Map<String, List<Integer>> convert(final List<List<Integer>> positionByStringIndexValue, final String[] patterns) {
Map<String, ArrayList<Integer>> positionByString = new HashMap<>(); Map<String, List<Integer>> positionByString = new HashMap<>();
for (int i = 0; i < patterns.length; i++) { for (int i = 0; i < patterns.length; i++) {
String pattern = patterns[i]; String pattern = patterns[i];
ArrayList<Integer> positions = positionByStringIndexValue.get(i); List<Integer> positions = positionByStringIndexValue.get(i);
positionByString.put(pattern, new ArrayList<>(positions)); positionByString.put(pattern, new ArrayList<>(positions));
} }
return positionByString; return positionByString;

View File

@ -12,6 +12,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.Map; import java.util.Map;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -42,7 +43,7 @@ class AhoCorasickTest {
@Test @Test
void testSearch() { void testSearch() {
// Define the expected results for each pattern // Define the expected results for each pattern
final var expected = Map.of("ACC", new ArrayList<>(Arrays.asList()), "ATC", new ArrayList<>(Arrays.asList(2)), "CAT", new ArrayList<>(Arrays.asList(1)), "GCG", new ArrayList<>(Arrays.asList()), "C", new ArrayList<>(Arrays.asList(1, 4)), "T", new ArrayList<>(Arrays.asList(3))); final var expected = Map.of("ACC", new ArrayList<>(List.of()), "ATC", new ArrayList<>(List.of(2)), "CAT", new ArrayList<>(List.of(1)), "GCG", new ArrayList<>(List.of()), "C", new ArrayList<>(List.of(1, 4)), "T", new ArrayList<>(List.of(3)));
assertEquals(expected, AhoCorasick.search(text, patterns)); assertEquals(expected, AhoCorasick.search(text, patterns));
} }
@ -77,7 +78,7 @@ class AhoCorasickTest {
void testPatternAtBeginning() { void testPatternAtBeginning() {
// Define patterns that start at the beginning of the text // Define patterns that start at the beginning of the text
final var searchPatterns = new String[] {"GC", "GCA", "GCAT"}; final var searchPatterns = new String[] {"GC", "GCA", "GCAT"};
final var expected = Map.of("GC", new ArrayList<Integer>(Arrays.asList(0)), "GCA", new ArrayList<Integer>(Arrays.asList(0)), "GCAT", new ArrayList<Integer>(Arrays.asList(0))); final var expected = Map.of("GC", new ArrayList<>(List.of(0)), "GCA", new ArrayList<>(List.of(0)), "GCAT", new ArrayList<>(List.of(0)));
assertEquals(expected, AhoCorasick.search(text, searchPatterns)); assertEquals(expected, AhoCorasick.search(text, searchPatterns));
} }
@ -89,7 +90,7 @@ class AhoCorasickTest {
void testPatternAtEnd() { void testPatternAtEnd() {
// Define patterns that end at the end of the text // Define patterns that end at the end of the text
final var searchPatterns = new String[] {"CG", "TCG", "ATCG"}; final var searchPatterns = new String[] {"CG", "TCG", "ATCG"};
final var expected = Map.of("CG", new ArrayList<Integer>(Arrays.asList(4)), "TCG", new ArrayList<Integer>(Arrays.asList(3)), "ATCG", new ArrayList<Integer>(Arrays.asList(2))); final var expected = Map.of("CG", new ArrayList<>(List.of(4)), "TCG", new ArrayList<>(List.of(3)), "ATCG", new ArrayList<>(List.of(2)));
assertEquals(expected, AhoCorasick.search(text, searchPatterns)); assertEquals(expected, AhoCorasick.search(text, searchPatterns));
} }
@ -102,7 +103,7 @@ class AhoCorasickTest {
void testMultipleOccurrencesOfPattern() { void testMultipleOccurrencesOfPattern() {
// Define patterns with multiple occurrences in the text // Define patterns with multiple occurrences in the text
final var searchPatterns = new String[] {"AT", "T"}; final var searchPatterns = new String[] {"AT", "T"};
final var expected = Map.of("AT", new ArrayList<Integer>(Arrays.asList(2)), "T", new ArrayList<Integer>(Arrays.asList(3))); final var expected = Map.of("AT", new ArrayList<>(List.of(2)), "T", new ArrayList<>(List.of(3)));
assertEquals(expected, AhoCorasick.search(text, searchPatterns)); assertEquals(expected, AhoCorasick.search(text, searchPatterns));
} }
@ -114,7 +115,7 @@ class AhoCorasickTest {
void testCaseInsensitiveSearch() { void testCaseInsensitiveSearch() {
// Define patterns with different cases // Define patterns with different cases
final var searchPatterns = new String[] {"gca", "aTc", "C"}; final var searchPatterns = new String[] {"gca", "aTc", "C"};
final var expected = Map.of("gca", new ArrayList<Integer>(), "aTc", new ArrayList<Integer>(), "C", new ArrayList<Integer>(Arrays.asList(1, 4))); final var expected = Map.of("gca", new ArrayList<Integer>(), "aTc", new ArrayList<Integer>(), "C", new ArrayList<>(Arrays.asList(1, 4)));
assertEquals(expected, AhoCorasick.search(text, searchPatterns)); assertEquals(expected, AhoCorasick.search(text, searchPatterns));
} }
} }