Merge sort works by recursively dividing the array until each piece has one or two elements, sorting those pieces, and merging them back together.
To run merge sort in parallel, we divide a large array into multiple subtasks that are executed by a thread pool.
Here is the logic we will use.
(I) If the size of the array segment under consideration is larger than the threshold, divide it in two, create a sorting subtask for each half, and submit both subtasks to the thread pool for execution.
(II) Once the subtasks sorting both halves are done, merge the two sorted halves back together.
(III) If the size of the segment is not larger than the threshold, sort it with an ordinary sequential merge sort within the same task.
Below is the code for the classes used to perform the parallel merge sort.
The abstract base class extended by each subtask:
/**
 * Base type for every unit of work handed to the sorting thread pool.
 *
 * <p>A task is a {@link Runnable} that can additionally tell the pool whether
 * it is currently eligible to be executed.
 */
public abstract class SortTask implements Runnable {

    /**
     * Reports whether this task may be run right now.
     *
     * @return {@code true} if a worker may call {@link #run()} at this moment
     */
    public abstract boolean isReadyToProcess();
}
The MergeSortTask class, which extends the SortTask base class:
import java.util.concurrent.atomic.AtomicBoolean;
/**
* @author MaheshT
*
*/
public class MergeSortTask extends SortTask{
private int sortArray [];
private int iStart = 0;
private int iEnd =0;
private int noOfSubtask =0;
private final int SPLIT_THRESHOLD = 100000;
private int taskDone =0;
private MergeSortTask parentTask = null;
private SortingThreadPool<SortTask> threadPool = null;
private AtomicBoolean splitDone =new AtomicBoolean(false);
private Object waitForTaskDone = new Object();
public MergeSortTask(SortingThreadPool<SortTask> threadPool,MergeSortTask parentTask,int [] inList,int start, int end){
this.sortArray = inList;
this.iStart = start;
this.iEnd = end;
this.threadPool = threadPool;
this.parentTask = parentTask;
}
private void splitSortNMerge(){
splitSortNMerge(sortArray, iStart,iEnd);
}
private void splitSortNMerge(int inList[],int start, int end){
int diff = end - start;
int mid = 0;
int startIdxFirstArray =0;
int endIdxFirstArray =0;
int startIdxSecondArray=0;
int endIdxSecondArray=0;
if (diff > 1){ // more than two then split
mid = diff/2;
startIdxFirstArray = start;
endIdxFirstArray = start + mid;
startIdxSecondArray= start + mid +1;
endIdxSecondArray= end;
if (diff > SPLIT_THRESHOLD){
submitTaskToPool(new MergeSortTask(threadPool,this,sortArray,startIdxFirstArray, endIdxFirstArray));
submitTaskToPool(new MergeSortTask(threadPool,this,sortArray,startIdxSecondArray, endIdxSecondArray));
submitTaskToPool(this);
noOfSubtask=2;
//merge will be done by calling merge() function, when above tasks are done
splitDone.set(true);
return;
}else {
//recursive split and merge
splitSortNMerge (inList,startIdxFirstArray, endIdxFirstArray);
splitSortNMerge (inList,startIdxSecondArray, endIdxSecondArray);
merge(inList,startIdxFirstArray,endIdxFirstArray,startIdxSecondArray, endIdxSecondArray);
}
}else if (diff == 1){ // two element
if (inList[start] > inList[end]){ //swap
int tempVar = inList[start];
inList[start] = inList[end];
inList[end] = tempVar;
}
}
}
private void merge(){
int startIdxFirstArray =0;
int endIdxFirstArray =0;
int startIdxSecondArray=0;
int endIdxSecondArray=0;
int diff = iEnd - iStart;
int mid = 0;
mid = diff/2;
startIdxFirstArray = iStart;
endIdxFirstArray = iStart + mid;
startIdxSecondArray=iStart + mid +1;
endIdxSecondArray= iEnd;
merge(sortArray,startIdxFirstArray,endIdxFirstArray,startIdxSecondArray,endIdxSecondArray);
}
private static void merge(int [] inList,int startIdxFirstArray,int endIdxFirstArray,
int startIdxSecondArray, int endIdxSecondArray){
int firstArryPtr = startIdxFirstArray;
int secondArryPtr = startIdxSecondArray;
int [] tempArry = new int [endIdxSecondArray-startIdxFirstArray +1 ] ;
//merge in sorted order
int tempIdx =0;
for (tempIdx=0;tempIdx < tempArry.length ; tempIdx++){
if ( inList[firstArryPtr] < inList[secondArryPtr]){
tempArry[tempIdx] = inList[firstArryPtr];
firstArryPtr++;
if (firstArryPtr > endIdxFirstArray) break;
}else {
tempArry[tempIdx] = inList[secondArryPtr];
secondArryPtr++;
if (secondArryPtr > endIdxSecondArray) break;
}
}
if (firstArryPtr > endIdxFirstArray){
while (tempIdx < tempArry.length-1){
tempIdx++;
tempArry[tempIdx]= inList[secondArryPtr];
secondArryPtr++;
}
}else if (secondArryPtr > endIdxSecondArray){
while (tempIdx < tempArry.length-1){
tempIdx++;
tempArry[tempIdx]= inList[firstArryPtr];
firstArryPtr++;
}
}
//copy sorted array
for (int j=0;j < tempArry.length ; j++){
inList[startIdxFirstArray+j] = tempArry[j];
}
}
public boolean isReadyToProcess(){
if (!splitDone.get()){
return true;
}else if (taskDone==noOfSubtask){
return true;
}else {
return false;
}
}
public synchronized void subTaskDone(){//
taskDone++;
}
private void submitTaskToPool(MergeSortTask sortTask){
threadPool.addTask(sortTask);
}
public void run(){
if (splitDone.get()){
this.merge();
}else {
this.splitSortNMerge();
}
if (taskDone==noOfSubtask) {
if (parentTask !=null){
this.parentTask.subTaskDone();
}
synchronized(waitForTaskDone){
waitForTaskDone.notifyAll();
}
}
}
public int [] get() throws InterruptedException {
if (splitDone.get() && taskDone==noOfSubtask ){
return sortArray;
}else {
synchronized(waitForTaskDone){
waitForTaskDone.wait();
}
}
return sortArray;
}
public String toString() {
return "Task to sort from " + this.iStart +" To "+this.iEnd ;
}
}
The thread pool, which executes each task once it is ready to be processed:
import java.util.ArrayList;
import java.util.List;
/**
 * A minimal fixed-size pool of worker threads that executes {@link SortTask}s.
 *
 * <p>Each worker repeatedly scans the shared task list for a task whose
 * {@link SortTask#isReadyToProcess()} returns {@code true}. It removes that task
 * and runs it. When no task is ready, the worker waits until {@link #addTask} or
 * {@link #shutDown} notifies it. After {@link #shutDown}, workers exit once the
 * task list is empty.
 *
 * <p>NOTE(review): a task's {@code run()} is executed while the worker holds the
 * {@code taskList} monitor, so at most one task runs at a time and the pool gives
 * no real CPU parallelism. Running tasks outside the lock is only safe for tasks
 * whose readiness/completion logic tolerates concurrent execution. Confirm that
 * for every task type before changing it.
 *
 * @param <E> task type
 */
public class SortingThreadPool<E extends SortTask > {
    List<E> taskList = new ArrayList<E>();
    List<WorkerThread> workers = new ArrayList<WorkerThread>();
    // volatile: written by the thread calling shutDown(), read by every worker.
    volatile boolean shutdown = false;

    /** Creates and starts {@code noOfThreads} worker threads. */
    public SortingThreadPool(int noOfThreads){
        for (int i = 0; i < noOfThreads; i++) {
            WorkerThread workerThread = new WorkerThread("Worker: " + i);
            workers.add(workerThread);
            workerThread.start();
        }
    }

    /** Appends a task and wakes all waiting workers so they re-scan for ready work. */
    public void addTask(E e){
        synchronized (taskList) {
            taskList.add(e);
            taskList.notifyAll();
        }
    }

    /** Asks workers to exit once the task list has been drained. */
    public void shutDown(){
        shutdown = true;
        synchronized (taskList) {
            taskList.notifyAll();
        }
    }

    private class WorkerThread extends Thread {
        public WorkerThread(String name){
            super(name);
        }

        /**
         * Returns the last ready task in the list, or {@code null} if none is ready.
         * The caller must hold the {@code taskList} lock.
         */
        private E findReadyTask(){
            for (int i = taskList.size() - 1; i >= 0; i--) {
                E candidate = taskList.get(i);
                if (candidate.isReadyToProcess()) {
                    return candidate;
                }
            }
            return null;
        }

        @Override
        public void run(){
            System.out.println("Starting thread ");
            while (true) {
                try {
                    synchronized (taskList) {
                        // Checked under the lock: ArrayList is not thread-safe.
                        if (shutdown && taskList.isEmpty()) {
                            break;
                        }
                        E readyTask = findReadyTask();
                        if (readyTask != null) {
                            taskList.remove(readyTask);
                            readyTask.run(); // runs under the lock, see class NOTE
                        } else {
                            System.out.println("Going to wait " + this.getName() + " size " + taskList.size());
                            taskList.wait();
                        }
                    }
                    // Yield only after releasing the lock; yielding while holding it
                    // cannot let other workers pick up tasks.
                    Thread.yield();
                } catch (InterruptedException e) {
                    // Preserve the interrupt and stop this worker instead of spinning.
                    Thread.currentThread().interrupt();
                    break;
                } catch (RuntimeException e) {
                    // Best effort: a failing task is reported but does not kill the worker.
                    e.printStackTrace();
                }
            }
            System.out.println("Exiting thread " + this.getName());
        }
    }
}
Finally, the main class creates a thread pool and submits a task that sorts ten million random numbers.
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
/**
 * Demo entry point: fills an array with random ints, sorts it with a
 * {@code MergeSortTask} on a 10-worker {@code SortingThreadPool}, and prints the
 * elapsed time in milliseconds.
 */
public class MergeSortInParallel {
    public static void main (String [] args){
        final int NO_OF_ELEMENT = 10000000;
        Random random = new Random();
        int [] sortArray = new int [NO_OF_ELEMENT];
        for (int i = 0; i < NO_OF_ELEMENT; i++) {
            sortArray[i] = random.nextInt(NO_OF_ELEMENT);
        }
        SortingThreadPool<SortTask> threadPool = new SortingThreadPool<SortTask>(10);
        try {
            MergeSortTask sortTask = new MergeSortTask(threadPool, null, sortArray, 0, sortArray.length - 1);
            // nanoTime is monotonic and meant for measuring elapsed time.
            long startTime = System.nanoTime();
            threadPool.addTask(sortTask);
            sortTask.get();
            System.out.println("Time taken " + (System.nanoTime() - startTime) / 1000000L);
        } catch (InterruptedException ie) {
            // Restore the interrupt status instead of swallowing it.
            Thread.currentThread().interrupt();
            ie.printStackTrace();
        } finally {
            // Always release the (non-daemon) worker threads, even if sorting failed.
            threadPool.shutDown();
        }
    }
}