public class RandomForest
extends Object
This is a sketch of the algorithm to help new developers.
The algorithm partitions data by instances (rows). On each iteration, the algorithm splits a set of nodes. To choose the best split for a given node, sufficient statistics are collected from the distributed data. For each node, the statistics are aggregated on one worker node, and that worker selects the best split.
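To make that aggregation pattern concrete, here is a minimal Scala sketch. The NodeStats type and aggregatePerNode helper are hypothetical stand-ins (the implementation's actual statistics classes are internal); only the per-node reduce pattern itself is the point:

```scala
import org.apache.spark.rdd.RDD

// Hypothetical sufficient statistics for one node: per-(feature, bin) label counts.
// The real implementation uses internal aggregator classes; this is only a sketch.
case class NodeStats(counts: Array[Double]) {
  def merge(other: NodeStats): NodeStats =
    NodeStats(counts.zip(other.counts).map { case (a, b) => a + b })
}

// partial: (nodeId, localStats) pairs emitted by each partition's single pass.
def aggregatePerNode(partial: RDD[(Int, NodeStats)]): RDD[(Int, NodeStats)] =
  // All statistics for a given node land on one worker, which can then
  // scan the merged counts and pick the best (feature, split) pair.
  partial.reduceByKey(_ merge _)
```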
This setup requires discretization of continuous features. This binning is done in the findSplits() method during initialization, after which each continuous feature becomes an ordered discretized feature with at most maxBins possible values.
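As a rough illustration of what such binning might look like (the real findSplits() samples the input and also handles categorical features; both helpers below are hypothetical):

```scala
// Hypothetical helper: derive at most (maxBins - 1) split thresholds from a
// sample of a continuous feature, so each value maps to one of maxBins bins.
def findSplitsForFeature(sample: Array[Double], maxBins: Int): Array[Double] = {
  val sorted = sample.sorted
  val numSplits = math.min(maxBins - 1, sorted.length - 1)
  // Take evenly spaced quantiles of the sample as candidate thresholds.
  (1 to numSplits).map { i =>
    sorted((i * sorted.length) / (numSplits + 1))
  }.distinct.toArray
}

// Binning a raw value: index of the first threshold it does not exceed,
// or the last bin if it exceeds all thresholds.
def binIndex(value: Double, thresholds: Array[Double]): Int = {
  val idx = thresholds.indexWhere(value <= _)
  if (idx == -1) thresholds.length else idx
}
```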
The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes lie at the periphery of the tree being trained. If multiple trees are being trained at once, then this queue contains nodes from all of them. Each iteration works roughly as follows (see the sketch after this list):

On the master node:
- Some number of nodes are pulled off of the queue, based on the amount of memory required for their sufficient statistics.
- For random forests, if featureSubsetStrategy is not "all", then a subset of candidate features is chosen for each node. See method selectNodesToSplit().

On worker nodes, via method findBestSplits():
- The worker makes one pass over its subset of instances.
- For each (tree, node, feature, split) tuple, the worker collects statistics about splitting. Note that the set of (tree, node) pairs is limited to the nodes selected from the queue for this iteration. The set of features considered can also be limited based on featureSubsetStrategy.
- For each node, the statistics for that node are aggregated on a particular worker via reduceByKey(). The designated worker chooses the best (feature, split) pair, or chooses to stop splitting if the stopping criteria are met.

On the master node:
- The master collects all decisions about splitting nodes and updates the model.
- The updated model is passed to the workers on the next iteration.

This process continues until the node queue is empty.
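A highly simplified Scala sketch of that driver loop. The Node and SplitDecision types and both function parameters are hypothetical stand-ins for the internal ones (the real selectNodesToSplit() and findBestSplits() take many more arguments); the point is only the queue-driven structure described above:

```scala
import scala.collection.mutable

// Hypothetical stand-ins for the real node and split-decision types.
case class Node(treeIndex: Int, nodeIndex: Int)
case class SplitDecision(node: Node, stop: Boolean, children: Seq[Node])

// Sketch of the master's main loop; not the actual implementation.
def trainLoop(
    roots: Seq[Node],
    selectNodesToSplit: mutable.Stack[Node] => Seq[Node], // memory-bounded selection
    findBestSplits: Seq[Node] => Seq[SplitDecision]       // distributed stats + best split
): Unit = {
  val nodeStack = mutable.Stack[Node]()
  roots.foreach(nodeStack.push) // one root per tree when training a forest
  while (nodeStack.nonEmpty) {
    // Master: pull as many nodes as memory for their statistics allows.
    val selected = selectNodesToSplit(nodeStack)
    // Workers: one pass over the data, aggregate stats, choose splits.
    val decisions = findBestSplits(selected)
    // Master: update the model and enqueue the new frontier nodes.
    for (d <- decisions if !d.stop; child <- d.children) nodeStack.push(child)
  }
}
```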
Most of the methods in this implementation support the statistics aggregation, which is the heaviest part of the computation. In general, this implementation is bound either by the cost of computing statistics on the workers or by the cost of communicating those sufficient statistics.
| Constructor and Description |
| --- |
| RandomForest() |
| Modifier and Type | Method and Description |
| --- | --- |
| static DecisionTreeModel[] | run(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, long seed, scala.Option<org.apache.spark.ml.util.Instrumentation> instr, boolean prune, scala.Option<String> parentUID) Train a random forest. |
public static DecisionTreeModel[] run(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, long seed, scala.Option<org.apache.spark.ml.util.Instrumentation> instr, boolean prune, scala.Option<String> parentUID)
Parameters:

input - Training data: RDD of LabeledPoint

strategy - (undocumented)

numTrees - (undocumented)

featureSubsetStrategy - (undocumented)

seed - (undocumented)

instr - (undocumented)

prune - (undocumented)

parentUID - (undocumented)
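For orientation, here is a hedged sketch of calling run() against the signature above. The imports are assumptions (the exact LabeledPoint and Strategy classes this method expects vary across Spark versions), and since this is an internal helper it may not be callable from user code at all; the public train* entry points are generally preferred:

```scala
import org.apache.spark.ml.feature.LabeledPoint // assumed variant; varies by Spark version
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.rdd.RDD

// Assumes an existing SparkContext named sc and a prepared training set.
val input: RDD[LabeledPoint] = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(1.0, 2.0)),
  LabeledPoint(1.0, Vectors.dense(3.0, 4.0))
))

// defaultStrategy fills in reasonable defaults (impurity, maxDepth, maxBins, ...).
val strategy = Strategy.defaultStrategy(Algo.Classification)

val trees = RandomForest.run(
  input,
  strategy,
  3,       // numTrees
  "sqrt",  // featureSubsetStrategy: features sampled per node
  42L,     // seed
  None,    // instr: optional Instrumentation
  prune = true,
  parentUID = None
)
```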