diff --git a/scala/src/main/scala/ch12_sorts/QuickSort.scala b/scala/src/main/scala/ch12_sorts/QuickSort.scala new file mode 100644 index 0000000..afcea21 --- /dev/null +++ b/scala/src/main/scala/ch12_sorts/QuickSort.scala @@ -0,0 +1,54 @@ +package ch12_sorts + +object QuickSort { + + //find the K th smallest element int the array + def findKthElement(items: Array[Int], k: Int): Int = { + _findKthElement(items, k, 0, items.length - 1) + } + + private[this] def _findKthElement(items: Array[Int], k: Int, p: Int, r: Int): Int = { + val q = _partition(items, p, r) + + if (k == q + 1) { + items(q) + } else if (k < q + 1) { + _findKthElement(items, k, p, q - 1) + } else { + _findKthElement(items, k, q + 1, r) + } + } + + def quickSort(items: Array[Int]): Array[Int] = { + _quickSort(items, 0, items.length - 1) + items + } + + private[this] def _quickSort(items: Array[Int], p: Int, r: Int): Unit = { + if (p >= r) { + return + } + val q = _partition(items, p, r) + _quickSort(items, p, q - 1) + _quickSort(items, q + 1, r) + } + + private[this] def _partition(items: Array[Int], p: Int, r: Int): Int = { + val pivot = items(r) + var i = p + for (j <- Range(p, r)) { + if (items(j) < pivot) { + val temp = items(i) + items(i) = items(j) + items(j) = temp + i += 1 + } + } + + val temp = items(i) + items(i) = items(r) + items(r) = temp + + i + } +} diff --git a/scala/src/test/scala/ch11_sorts/SortsTest.scala b/scala/src/test/scala/ch11_sorts/SortsTest.scala index ee00637..d18854a 100644 --- a/scala/src/test/scala/ch11_sorts/SortsTest.scala +++ b/scala/src/test/scala/ch11_sorts/SortsTest.scala @@ -1,5 +1,6 @@ package ch11_sorts +import ch12_sorts.{MergeSort, QuickSort} import org.scalatest.{FlatSpec, Matchers} import scala.util.Random @@ -50,6 +51,8 @@ class SortsTest extends FlatSpec with Matchers { timed("bubbleSort", Sorts.bubbleSort, array.clone()) timed("insertSort", Sorts.insertSort, array.clone()) timed("selectionSort", Sorts.selectionSort, array.clone()) + timed("mergeSort", MergeSort.mergeSort, array.clone()) + timed("quickSort", QuickSort.quickSort, array.clone()) } def reportElapsed(name: String, time: Long): Unit = println(name + " takes in " + time + "ms") diff --git a/scala/src/test/scala/ch12_sorts/QuickSortTest.scala b/scala/src/test/scala/ch12_sorts/QuickSortTest.scala new file mode 100644 index 0000000..cafeedb --- /dev/null +++ b/scala/src/test/scala/ch12_sorts/QuickSortTest.scala @@ -0,0 +1,31 @@ +package ch12_sorts + +import org.scalatest.{FlatSpec, Matchers} + +class QuickSortTest extends FlatSpec with Matchers { + + behavior of "QuickSortTest" + + it should "quickSort" in { + var array = Array(4, 5, 6, 3, 2, 1) + array = QuickSort.quickSort(array) + array.mkString("") should equal("123456") + + array = Array(4) + array = QuickSort.quickSort(array) + array.mkString("") should equal("4") + + array = Array(4, 2) + array = QuickSort.quickSort(array) + array.mkString("") should equal("24") + } + + it should "find the Kth element in the array" in { + val array = Array(4, 2, 5, 12, 3) + + QuickSort.findKthElement(array, 3) should equal(4) + QuickSort.findKthElement(array, 5) should equal(12) + QuickSort.findKthElement(array, 1) should equal(2) + } + +}