public class GBTClassifier extends ProbabilisticClassifier<Vector,GBTClassifier,GBTClassificationModel> implements GBTClassifierParams, DefaultParamsWritable, org.apache.spark.internal.Logging
The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
Notes on Gradient Boosting vs. TreeBoost:
- This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
- Both algorithms learn tree ensembles by minimizing loss functions.
- TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes based on the loss function, whereas the original gradient boosting method does not.
- We expect to implement TreeBoost in the future: [https://issues.apache.org/jira/browse/SPARK-4240]
Constructor and Description |
---|
GBTClassifier() |
GBTClassifier(String uid) |
Modifier and Type | Method and Description |
---|---|
BooleanParam |
cacheNodeIds()
If false, the algorithm will pass trees to executors to match instances with nodes.
|
IntParam |
checkpointInterval()
Param for setting the checkpoint interval (>= 1) or disabling checkpointing (-1).
|
GBTClassifier |
copy(ParamMap extra)
Creates a copy of this instance with the same UID and some extra params.
|
Param<String> |
featureSubsetStrategy()
The number of features to consider for splits at each tree node.
|
Param<String> |
impurity()
Criterion used for information gain calculation (case-insensitive).
|
Param<String> |
leafCol()
Leaf indices column name.
|
static GBTClassifier |
load(String path) |
Param<String> |
lossType()
Loss function which GBT tries to minimize.
|
IntParam |
maxBins()
Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node.
|
IntParam |
maxDepth()
Maximum depth of the tree (nonnegative).
|
IntParam |
maxIter()
Param for maximum number of iterations (>= 0).
|
IntParam |
maxMemoryInMB()
Maximum memory in MB allocated to histogram aggregation.
|
DoubleParam |
minInfoGain()
Minimum information gain for a split to be considered at a tree node.
|
IntParam |
minInstancesPerNode()
Minimum number of instances each child must have after split.
|
DoubleParam |
minWeightFractionPerNode()
Minimum fraction of the weighted sample count that each child must have after split.
|
static MLReader<T> |
read() |
LongParam |
seed()
Param for random seed.
|
GBTClassifier |
setCacheNodeIds(boolean value) |
GBTClassifier |
setCheckpointInterval(int value)
Specifies how often to checkpoint the cached node IDs.
|
GBTClassifier |
setFeatureSubsetStrategy(String value) |
GBTClassifier |
setImpurity(String value)
The impurity setting is ignored for GBT models.
|
GBTClassifier |
setLossType(String value) |
GBTClassifier |
setMaxBins(int value) |
GBTClassifier |
setMaxDepth(int value) |
GBTClassifier |
setMaxIter(int value) |
GBTClassifier |
setMaxMemoryInMB(int value) |
GBTClassifier |
setMinInfoGain(double value) |
GBTClassifier |
setMinInstancesPerNode(int value) |
GBTClassifier |
setMinWeightFractionPerNode(double value) |
GBTClassifier |
setSeed(long value) |
GBTClassifier |
setStepSize(double value) |
GBTClassifier |
setSubsamplingRate(double value) |
GBTClassifier |
setValidationIndicatorCol(String value) |
GBTClassifier |
setWeightCol(String value)
Sets the value of param weightCol. |
DoubleParam |
stepSize()
Param for step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.
|
DoubleParam |
subsamplingRate()
Fraction of the training data used for learning each decision tree, in range (0, 1].
|
static String[] |
supportedLossTypes()
Accessor for supported loss settings: logistic
|
String |
uid()
An immutable unique ID for the object and its derivatives.
|
Param<String> |
validationIndicatorCol()
Param for name of the column that indicates whether each row is for training or for validation.
|
DoubleParam |
validationTol()
Threshold for stopping early when fit with validation is used.
|
Param<String> |
weightCol()
Param for weight column name.
|
probabilityCol, setProbabilityCol, setThresholds, thresholds
rawPredictionCol, setRawPredictionCol
featuresCol, fit, labelCol, predictionCol, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
params
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
getLossType, getOldLossType
getOldBoostingStrategy, getValidationTol
getMaxIter
getStepSize
getValidationIndicatorCol
validateAndTransformSchema
getFeatureSubsetStrategy, getOldStrategy, getSubsamplingRate
getCacheNodeIds, getLeafCol, getMaxBins, getMaxDepth, getMaxMemoryInMB, getMinInfoGain, getMinInstancesPerNode, getMinWeightFractionPerNode, getOldStrategy, setLeafCol
getCheckpointInterval
getWeightCol
extractInstances
extractInstances, extractInstances
getLabelCol, labelCol
featuresCol, getFeaturesCol
getPredictionCol, predictionCol
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
toString
getRawPredictionCol, rawPredictionCol
getProbabilityCol, probabilityCol
getThresholds, thresholds
getImpurity, getOldImpurity
write
save
$init$, initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, initLock, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log__$eq, org$apache$spark$internal$Logging$$log_, uninitialize
public GBTClassifier(String uid)
public GBTClassifier()
public static final String[] supportedLossTypes()
public static GBTClassifier load(String path)
public static MLReader<T> read()
public Param<String> lossType()
GBTClassifierParams
lossType
in interface GBTClassifierParams
public final Param<String> impurity()
HasVarianceImpurity
impurity
in interface HasVarianceImpurity
public final DoubleParam validationTol()
GBTParams
validationTol
in interface GBTParams
See Also: validationIndicatorCol
public final DoubleParam stepSize()
GBTParams
stepSize
in interface HasStepSize
stepSize
in interface GBTParams
public final Param<String> validationIndicatorCol()
HasValidationIndicatorCol
validationIndicatorCol
in interface HasValidationIndicatorCol
public final IntParam maxIter()
HasMaxIter
maxIter
in interface HasMaxIter
public final DoubleParam subsamplingRate()
TreeEnsembleParams
subsamplingRate
in interface TreeEnsembleParams
public final Param<String> featureSubsetStrategy()
TreeEnsembleParams
These various settings are based on the following references:
- log2: tested in Breiman (2001)
- sqrt: recommended by Breiman manual for random forests
- The defaults of sqrt (classification) and onethird (regression) match the R randomForest package.
featureSubsetStrategy
in interface TreeEnsembleParams
public final Param<String> leafCol()
DecisionTreeParams
leafCol
in interface DecisionTreeParams
public final IntParam maxDepth()
DecisionTreeParams
maxDepth
in interface DecisionTreeParams
public final IntParam maxBins()
DecisionTreeParams
maxBins
in interface DecisionTreeParams
public final IntParam minInstancesPerNode()
DecisionTreeParams
minInstancesPerNode
in interface DecisionTreeParams
public final DoubleParam minWeightFractionPerNode()
DecisionTreeParams
minWeightFractionPerNode
in interface DecisionTreeParams
public final DoubleParam minInfoGain()
DecisionTreeParams
minInfoGain
in interface DecisionTreeParams
public final IntParam maxMemoryInMB()
DecisionTreeParams
maxMemoryInMB
in interface DecisionTreeParams
public final BooleanParam cacheNodeIds()
DecisionTreeParams
cacheNodeIds
in interface DecisionTreeParams
public final Param<String> weightCol()
HasWeightCol
weightCol
in interface HasWeightCol
public final LongParam seed()
HasSeed
public final IntParam checkpointInterval()
HasCheckpointInterval
checkpointInterval
in interface HasCheckpointInterval
public String uid()
Identifiable
uid
in interface Identifiable
public GBTClassifier setMaxDepth(int value)
public GBTClassifier setMaxBins(int value)
public GBTClassifier setMinInstancesPerNode(int value)
public GBTClassifier setMinWeightFractionPerNode(double value)
public GBTClassifier setMinInfoGain(double value)
public GBTClassifier setMaxMemoryInMB(int value)
public GBTClassifier setCacheNodeIds(boolean value)
public GBTClassifier setCheckpointInterval(int value)
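Specifies how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the checkpoint directory is set in SparkContext. Must be at least 1. (default = 10)
value - (undocumented)
public GBTClassifier setImpurity(String value)
The impurity setting is ignored for GBT models. Individual trees are built using impurity "Variance."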
value - (undocumented)
public GBTClassifier setSubsamplingRate(double value)
public GBTClassifier setSeed(long value)
public GBTClassifier setMaxIter(int value)
public GBTClassifier setStepSize(double value)
public GBTClassifier setFeatureSubsetStrategy(String value)
public GBTClassifier setLossType(String value)
public GBTClassifier setValidationIndicatorCol(String value)
public GBTClassifier setWeightCol(String value)
Sets the value of param weightCol.
If this is not set or empty, we treat all instance weights as 1.0.
By default the weightCol is not set, so all instances have weight 1.0.
value - (undocumented)
public GBTClassifier copy(ParamMap extra)
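Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
Specified by: copy in interface Params
Specified by: copy in class Predictor<Vector,GBTClassifier,GBTClassificationModel>
extra - (undocumented)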