合并任务的结果

Fork/Join 框架提供了执行返回结果的任务的能力。此类任务由 RecursiveTask 类实现。 此类扩展 ForkJoinTask 并实现了 Executor 框架提供的 Future 接口。

在任务中,您必须使用 Java API 文档推荐的结构:

    If (problem size > size){
        tasks=Divide(task);
        execute(tasks);
        groupResults()
        return result;
    } else {
        resolve problem;
        return result;
    }

本节任务

开发一个应用程序在一份文档里查找一个单词。实现以下两种任务:

  • 一个文档任务,在一个文档的行集合里搜索一个单词
  • 一个行任务,在文档的一部分中搜索一个单词

所有任务将返回在它们处理部分的文档或行中指定单词的出现次数。

实现

本节的示例代码在 com.elanzone.books.noteeg.chpt5.sect03 package中

  • 文档类 : Document ,将生成一个字符串矩阵模拟一份文档

    • 创建一个字符串数组,用来生成字符串矩阵
            private String words[] = {
                    "the", "hello", "goodbye", "package", "java", "thread", "pool", "random", "class", "main"
            };
    
    • generateDocument() 方法:
      • 指定行数、每行指定单词数生成一个字符串矩阵,矩阵中的单词从 words 数组里随机抽取
      • 生成过程统计指定单词出现的次数
            public String[][] generateDocument(int numLines, int numWords,
                                               String word) {
                int counter = 0;
                String document[][] = new String[numLines][numWords];
                Random random = new Random();
    
                for (int i = 0; i < numLines; i++) {
                    for (int j = 0; j < numWords; j++) {
                        int index = random.nextInt(words.length);
                        document[i][j] = words[index];
                        if (document[i][j].equals(word)) {
                            counter++;
                        }
                    }
                }
    
                System.out.println("DocumentMock: The word appears " + counter + " times in the document");
    
                return document;
            }
    
  • 行处理任务类 : LineTask

    • 扩展 RecursiveTask<Integer> : (结果是整数,故泛型参数为 Integer)
        public class LineTask extends RecursiveTask<Integer> {
    
    • 声明类的序列版本号(因为 RecursiveTask 的父类 ForkJoinTask 实现了 Serializable 接口)
            private static final long serialVersionUID = 2L;
    
    • 属性及构造函数
            private String line[];
            private int start, end;
            private String word;
    
            public LineTask(String[] line, int start, int end, String word) {
                this.line = line;
                this.start = start;
                this.end = end;
                this.word = word;
            }
    
    • compute() 方法
      • 如果要处理的行片段(start 和 end 之间)的单词数小于 100 ,则使用 count() 方法直接计数
      • 否则:
        • 将这组字符串分为 2 个,创建 2 个新的 LineTask 对象来处理,调用 invokeAll() 方法在线程池中执行它们
        • 使用 groupResults() 方法将2个任务的结果合并
      • 返回最终结果
            @Override
            protected Integer compute() {
                Integer result = null;
                if (end - start < 100) {
                    result = count(line, start, end, word);
                } else {
                    int mid = (start + end) / 2;
                    LineTask task1 = new LineTask(line, start, mid, word);
                    LineTask task2 = new LineTask(line, mid, end, word);
                    invokeAll(task1, task2);
    
                    try {
                        result = groupResults(task1.get(), task2.get());
                    } catch (InterruptedException | ExecutionException e) {
                        e.printStackTrace();
                    }
                }
    
                return result;
            }
    
    • count() 方法
      • 在 start 和 end 限定的范围内逐个字符串比较计数
      • 为了演示而降低执行速度,睡上 10 毫秒
      • 返回结果
            private Integer count(String[] line, int start, int end, String word) {
                int counter;
                counter = 0;
                for (int i = start; i < end; i++) {
                    if (line[i].equals(word)) {
                        counter++;
                    }
                }
    
                try {
                    Thread.sleep(10);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
    
                return counter;
            }
    
    • groupResults 方法 :将2个数字相加后返回结果
  • 文档处理任务类 : DocumentTask

    • 扩展 RecursiveTask<Integer> : (结果是整数,故泛型参数为 Integer)
        public class DocumentTask extends RecursiveTask<Integer> {
    
    • 声明类的序列版本号(因为 RecursiveTask 的父类 ForkJoinTask 实现了 Serializable 接口)
            private static final long serialVersionUID = 1L;
    
    • 属性及构造函数
            private String line[];
            private int start, end;
            private String word;
    
            public LineTask(String[] line, int start, int end, String word) {
                this.line = line;
                this.start = start;
                this.end = end;
                this.word = word;
            }
    
    • compute() 方法
      • 如果要处理的行数(由 start 和 end 属性限定)小于 10 行,则调用 processLines() 方法处理
      • 否则:
        • 将其均分为2部分,创建2个新的 DocumentTask 对象来处理,并用 invokeAll() 方法执行它们
        • 使用 groupResults() 方法将2个任务的结果合并
      • 返回最终结果
            @Override
            protected Integer compute() {
                int result = 0;
                if (end - start < 10) {
                    result = processLines(document, start, end, word);
                } else {
                    int mid = (start + end) / 2;
                    DocumentTask task1 = new DocumentTask(document, start, mid, word);
                    DocumentTask task2 = new DocumentTask(document, mid, end, word);
                    invokeAll(task1, task2);
    
                    try {
                        result = groupResults(task1.get(), task2.get());
                    } catch (InterruptedException | ExecutionException e) {
                        e.printStackTrace();
                    }
                }
    
                return result;
            }
    
    • processLines() 方法
      • 对要处理的每行,创建一个 LineTask 对象来处理整行,并保存到一个 List
      • 调用 invokeAll() 方法执行 LineTask List中的所有任务
      • 把所有这些任务的结果求和并返回结果
            private Integer processLines(String[][] document, int start, int
                    end, String word) {
                List<LineTask> tasks = new ArrayList<LineTask>();
                for (int i = start; i < end; i++) {
                    LineTask task = new LineTask(document[i], 0, document[i].length, word);
                    tasks.add(task);
                }
                invokeAll(tasks);
    
                int result = 0;
                for (LineTask task : tasks) {
                    try {
                        result = result + task.get();
                    } catch (InterruptedException | ExecutionException e) {
                        e.printStackTrace();
                    }
                }
                return result;
            }
    
    • groupResults 方法 :将2个数字相加后返回结果
  • 控制类 : Main

    • 创建一个 Document 对象并用它生成一份有 100 行、每行有 1000 个单词的文档(字符串矩阵)
        Document mock = new Document();
        String[][] document = mock.generateDocument(100, 1000, "the");
    
    • 创建一个 DocumentTask 对象以处理整份文档(start = 0, end = 100)计算the的出现次数
        DocumentTask task = new DocumentTask(document, 0, 100, "the");
    
    • 创建一个 ForkJoinPool 对象并调用 execute() 方法执行 task
        ForkJoinPool pool = new ForkJoinPool();
        pool.execute(task);
    
    • 每秒输出一次线程池的信息直到任务完成
        do {
            System.out.printf("******************************************\n");
            System.out.printf("Main: Parallelism: %d\n", pool.getParallelism());
            System.out.printf("Main: Active Threads: %d\n", pool.getActiveThreadCount());
            System.out.printf("Main: Task Count: %d\n", pool.getQueuedTaskCount());
            System.out.printf("Main: Steal Count: %d\n", pool.getStealCount());
            System.out.printf("******************************************\n");
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } while (!task.isDone());
    
    • 使用 shutdown() 方法停止线程池
        pool.shutdown();
    
    • 等待任务完成
        try {
            pool.awaitTermination(1, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    
    • 输出任务的结果,与预期的比较
        try {
            System.out.printf("Main: The word appears %d in the document", task.get());
        } catch (InterruptedException | ExecutionException e) {
            e.printStackTrace();
        }
    

讲解

在本例中实现了2个不同的任务:

  • DocumentTask 类: 此类的任务是处理由 start 和 end 属性决定的文档的行集合。
    • 如果此集合的行数小于10,则每行创建一个 LineTask,当 LineTask 们执行结束,则将结果求和并返回。
    • 否则将要处理的行集合均分成2个,创建2个DocumentTask对象来处理新的行集合。当这些任务运行完后,再把结果求和并返回。
  • LineTask 类:处理文件的一行中的单词集。
    • 如果单词集中的单词数小于 100,此任务直接在单词集中搜索指定单词,并返回指定单词的出现次数
    • 否则将单词集均分成2个并创建2个 LineTask 对象来处理新的单词集。当这些任务运行完后,再把结果求和并返回。

在 Main 类中,使用默认的构造函数创建了一个 ForkJoinPool 对象, 并在其中运行一个 DocumentTask 任务处理一个有 100 行、每行 1000 个单词的文档。 此任务将用其他的 DocumentTask 对象和 LineTask 对象将此问题分解, 当所有任务运行完,可用原始任务获得单词在整个文档中出现的总次数。 因为这些任务返回结果,它们扩展 RecursiveTask 类。

为了获得任务返回的结果,使用了 get() 方法。此方法在 Future 接口中声明,在 RecursiveTask 类中实现。

了解更多

ForkJoinTask 类提供了另一个方法来结束任务的执行并返回结果。那就是 complete() 方法。 此方法接受一个对象(类型为在RecursiveTask泛型类中使用的类型)为参数,并在join()方法被调用时返回此对象作为任务的结果。 建议用于为异步任务提供结果。

因为 RecursiveTask 类实现了 Future 接口,所以有另一个版本的 get() 方法:

  • get(long timeout, TimeUnit unit): 如果任务的结果还没有出来,则等待指定的时间; 如果指定的时间期限已过而结果仍没有出来,则返回 null。