diff --git a/src/main/java/com/thealgorithms/maths/FindKthNumber.java b/src/main/java/com/thealgorithms/maths/FindKthNumber.java index daea3f96..a9b26767 100644 --- a/src/main/java/com/thealgorithms/maths/FindKthNumber.java +++ b/src/main/java/com/thealgorithms/maths/FindKthNumber.java @@ -1,10 +1,9 @@ package com.thealgorithms.maths; -import java.util.Arrays; import java.util.Random; /** - * use quick sort algorithm to get kth largest or kth smallest element in given array + * Use a quicksort-based approach to identify the k-th largest or k-th max element within the provided array. */ public final class FindKthNumber { private FindKthNumber() { @@ -12,66 +11,55 @@ public final class FindKthNumber { private static final Random RANDOM = new Random(); - public static void main(String[] args) { - /* generate an array with random size and random elements */ - int[] nums = generateArray(100); - - /* get 3th largest element */ - int kth = 3; - int kthMaxIndex = nums.length - kth; - int targetMax = findKthMax(nums, kthMaxIndex); - - /* get 3th smallest element */ - int kthMinIndex = kth - 1; - int targetMin = findKthMax(nums, kthMinIndex); - - Arrays.sort(nums); - assert nums[kthMaxIndex] == targetMax; - assert nums[kthMinIndex] == targetMin; - } - - private static int[] generateArray(int capacity) { - int size = RANDOM.nextInt(capacity) + 1; - int[] array = new int[size]; - - for (int i = 0; i < size; i++) { - array[i] = RANDOM.nextInt() % 100; + public static int findKthMax(int[] array, int k) { + if (k <= 0 || k > array.length) { + throw new IllegalArgumentException("k must be between 1 and the size of the array"); } - return array; + + // Convert k-th largest to index for QuickSelect + return quickSelect(array, 0, array.length - 1, array.length - k); } - private static int findKthMax(int[] nums, int k) { - int start = 0; - int end = nums.length; - while (start < end) { - int pivot = partition(nums, start, end); - if (k == pivot) { - return nums[pivot]; - } else if (k > pivot) { - start = pivot + 1; - } else { - end = pivot; + private static int quickSelect(int[] array, int left, int right, int kSmallest) { + if (left == right) { + return array[left]; + } + + // Randomly select a pivot index + int pivotIndex = left + RANDOM.nextInt(right - left + 1); + pivotIndex = partition(array, left, right, pivotIndex); + + if (kSmallest == pivotIndex) { + return array[kSmallest]; + } else if (kSmallest < pivotIndex) { + return quickSelect(array, left, pivotIndex - 1, kSmallest); + } else { + return quickSelect(array, pivotIndex + 1, right, kSmallest); + } + } + + private static int partition(int[] array, int left, int right, int pivotIndex) { + int pivotValue = array[pivotIndex]; + // Move pivot to end + swap(array, pivotIndex, right); + int storeIndex = left; + + // Move all smaller elements to the left + for (int i = left; i < right; i++) { + if (array[i] < pivotValue) { + swap(array, storeIndex, i); + storeIndex++; } } - return -1; + + // Move pivot to its final place + swap(array, storeIndex, right); + return storeIndex; } - private static int partition(int[] nums, int start, int end) { - int pivot = nums[start]; - int j = start; - for (int i = start + 1; i < end; i++) { - if (nums[i] < pivot) { - j++; - swap(nums, i, j); - } - } - swap(nums, start, j); - return j; - } - - private static void swap(int[] nums, int a, int b) { - int tmp = nums[a]; - nums[a] = nums[b]; - nums[b] = tmp; + private static void swap(int[] array, int i, int j) { + int temp = array[i]; + array[i] = array[j]; + array[j] = temp; } } diff --git a/src/test/java/com/thealgorithms/maths/FindKthNumberTest.java b/src/test/java/com/thealgorithms/maths/FindKthNumberTest.java new file mode 100644 index 00000000..ddf62ff1 --- /dev/null +++ b/src/test/java/com/thealgorithms/maths/FindKthNumberTest.java @@ -0,0 +1,59 @@ +package com.thealgorithms.maths; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.Arrays; +import java.util.Random; +import org.junit.jupiter.api.Test; + +public class FindKthNumberTest { + @Test + public void testFindKthMaxTypicalCases() { + int[] array1 = {3, 2, 1, 4, 5}; + assertEquals(3, FindKthNumber.findKthMax(array1, 3)); + assertEquals(4, FindKthNumber.findKthMax(array1, 2)); + assertEquals(5, FindKthNumber.findKthMax(array1, 1)); + + int[] array2 = {7, 5, 8, 2, 1, 6}; + assertEquals(5, FindKthNumber.findKthMax(array2, 4)); + assertEquals(6, FindKthNumber.findKthMax(array2, 3)); + assertEquals(8, FindKthNumber.findKthMax(array2, 1)); + } + + @Test + public void testFindKthMaxEdgeCases() { + int[] array1 = {1}; + assertEquals(1, FindKthNumber.findKthMax(array1, 1)); + + int[] array2 = {5, 3}; + assertEquals(5, FindKthNumber.findKthMax(array2, 1)); + assertEquals(3, FindKthNumber.findKthMax(array2, 2)); + } + + @Test + public void testFindKthMaxInvalidK() { + int[] array = {1, 2, 3, 4, 5}; + assertThrows(IllegalArgumentException.class, () -> FindKthNumber.findKthMax(array, 0)); + assertThrows(IllegalArgumentException.class, () -> FindKthNumber.findKthMax(array, 6)); + } + + @Test + public void testFindKthMaxLargeArray() { + int[] array = generateArray(1000); + int k = new Random().nextInt(array.length); + int result = FindKthNumber.findKthMax(array, k); + Arrays.sort(array); + assertEquals(array[array.length - k], result); + } + + public static int[] generateArray(int capacity) { + int size = new Random().nextInt(capacity) + 1; + int[] array = new int[size]; + + for (int i = 0; i < size; i++) { + array[i] = new Random().nextInt(100); // Ensure positive values for testing + } + return array; + } +}