public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Model produced by LogisticRegression.
Modifier and Type | Method | Description |
---|---|---|
void | checkThresholdConsistency() | If threshold and thresholds are both set, ensures they are consistent. |
LogisticRegressionModel | copy(ParamMap extra) | Creates a copy of this instance with the same UID and some extra params. |
double | getThreshold() | Get threshold for binary classification. |
double[] | getThresholds() | Get thresholds for binary or multiclass classification. |
boolean | hasSummary() | Indicates whether a training summary exists for this model instance. |
double | intercept() | |
int | numClasses() | Number of classes (values which the label can take). |
protected double | predict(Vector features) | Predict label for the given feature vector. |
protected Vector | predictRaw(Vector features) | Raw prediction for each possible label. |
protected double | probability2prediction(Vector probability) | Given a vector of class conditional probabilities, select the predicted label. |
protected double | raw2prediction(Vector rawPrediction) | Given a vector of raw predictions, select the predicted label. |
protected Vector | raw2probabilityInPlace(Vector rawPrediction) | Estimate the probability of each class given the raw prediction, doing the computation in-place. |
LogisticRegressionModel | setThreshold(double value) | Set threshold in binary classification, in range [0, 1]. |
LogisticRegressionModel | setThresholds(double[] value) | Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class. |
LogisticRegressionTrainingSummary | summary() | Gets summary of model on training set. |
java.lang.String | uid() | An immutable unique ID for the object and its derivatives. |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) | Validates and transforms the input schema with the provided param map. |
void | validateParams() | |
Vector | weights() | |
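Methods inherited from class ProbabilisticClassificationModel: normalizeToProbabilitiesInPlace, predictProbability, raw2probability, setProbabilityCol, transform
Methods inherited from class ClassificationModel: setRawPredictionCol
Methods inherited from class PredictionModel: featuresDataType, setFeaturesCol, setPredictionCol, transformImpl, transformSchema
Methods inherited from class Transformer: transform, transform, transform
Methods inherited from class PipelineStage: transformSchema
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
Methods inherited from interface Identifiable: toString
Methods inherited from interface Logging: initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning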
public java.lang.String uid()
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable
public Vector weights()
public double intercept()
public LogisticRegressionModel setThreshold(double value)
If the estimated probability of class label 1 is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 more often; a low threshold encourages the model to predict 1 more often.
Note: Calling this with threshold p is equivalent to calling setThresholds(Array(1-p, p)).
When setThreshold() is called, any user-set value for thresholds will be cleared.
If both threshold and thresholds are set in a ParamMap, then they must be equivalent.
Default is 0.5.
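A minimal plain-Java sketch of the binary decision rule described above may help. BinaryThresholdSketch and its predict method are hypothetical names introduced for illustration; they are not part of Spark, and the sketch does not call Spark at all. It shows that, for a fixed probability of class 1, raising the threshold can flip the prediction from 1 to 0.

```java
// Hypothetical helper class, not Spark API: a sketch of the binary rule
// "predict 1 if P(label = 1) > threshold, else 0".
final class BinaryThresholdSketch {

    /** Returns 1.0 if probabilityOfOne is strictly greater than threshold, else 0.0. */
    static double predict(double probabilityOfOne, double threshold) {
        if (threshold < 0.0 || threshold > 1.0) {
            throw new IllegalArgumentException("threshold must be in [0, 1], got " + threshold);
        }
        return probabilityOfOne > threshold ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        double p = 0.6; // estimated probability of class 1
        // With the default threshold 0.5 the model predicts 1 ...
        System.out.println(predict(p, 0.5)); // 1.0
        // ... while a higher threshold of 0.7 flips the prediction to 0.
        System.out.println(predict(p, 0.7)); // 0.0
    }
}
```

Because the comparison is strict, a probability exactly equal to the threshold yields 0 in this sketch, matching the "> threshold" wording above.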
Parameters: value - (undocumented)
public double getThreshold()
Get threshold for binary classification.
If threshold is set, returns that value. Otherwise, if thresholds is set with length 2 (i.e., binary classification), this returns the equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1)). Otherwise, returns the default value of threshold.
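The thresholds-to-threshold conversion can be checked numerically with a small sketch. ThresholdConversionSketch and equivalentThreshold are hypothetical names, not Spark API; the method applies the formula quoted above and rejects arrays whose length is not 2, mirroring the documented exception.

```java
// Hypothetical helper, not Spark API: converts a length-2 thresholds array
// into the equivalent binary threshold 1 / (1 + thresholds[0] / thresholds[1]).
final class ThresholdConversionSketch {

    static double equivalentThreshold(double[] thresholds) {
        if (thresholds.length != 2) {
            // Mirrors the documented IllegalArgumentException for non-binary arrays.
            throw new IllegalArgumentException(
                "expected thresholds of length 2, got " + thresholds.length);
        }
        return 1.0 / (1.0 + thresholds[0] / thresholds[1]);
    }

    public static void main(String[] args) {
        // (1 - p, p) with p = 0.3 maps back to 0.3, up to floating-point rounding.
        System.out.println(equivalentThreshold(new double[] {0.7, 0.3}));
        // Scaling both entries by the same factor gives the same result.
        System.out.println(equivalentThreshold(new double[] {1.4, 0.6}));
    }
}
```

Because the formula depends only on the ratio thresholds(0) / thresholds(1), multiplying both entries by the same positive factor leaves the equivalent threshold unchanged, which the second call illustrates.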
Throws: IllegalArgumentException - if thresholds is set to an array of length other than 2.
public LogisticRegressionModel setThresholds(double[] value)
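Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class.
Note: When setThresholds() is called, any user-set value for threshold will be cleared.
If both threshold and thresholds are set in a ParamMap, then they must be equivalent.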
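For intuition about how per-class thresholds can shift predictions, here is a sketch that assumes a specific rule: predict the class with the largest ratio of probability to threshold. That rule is an assumption of this example rather than something stated on this page, and MulticlassThresholdSketch is a hypothetical name, not Spark API.

```java
// Hypothetical helper, not Spark API. Assumed rule: predict the class whose
// ratio probabilities[i] / thresholds[i] is largest. This sketch requires
// every threshold to be strictly positive.
final class MulticlassThresholdSketch {

    static int predict(double[] probabilities, double[] thresholds) {
        if (probabilities.length != thresholds.length) {
            throw new IllegalArgumentException("probabilities and thresholds differ in length");
        }
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < probabilities.length; i++) {
            if (thresholds[i] <= 0.0) {
                throw new IllegalArgumentException("thresholds must be positive in this sketch");
            }
            double score = probabilities[i] / thresholds[i];
            if (score > bestScore) { // strict '>' keeps the lowest index on ties
                best = i;
                bestScore = score;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probs = {0.5, 0.3, 0.2};
        // Equal thresholds: the most probable class wins.
        System.out.println(predict(probs, new double[] {1.0, 1.0, 1.0})); // 0
        // A low threshold for class 2 boosts it: 0.2 / 0.1 = 2.0 beats 0.5 / 1.0.
        System.out.println(predict(probs, new double[] {1.0, 1.0, 0.1})); // 2
    }
}
```

Under this assumed rule, binary thresholds (1-p, p) with class-1 probability q predict class 1 exactly when q / p > (1-q) / (1-p), which for 0 < p < 1 simplifies to q > p. That agrees with the setThreshold(p) behavior described for setThreshold.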
Overrides: setThresholds in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters: value - (undocumented)
public double[] getThresholds()
If thresholds is set, returns its value. Otherwise, if threshold is set, returns the equivalent thresholds for binary classification: (1-threshold, threshold). If neither is set, throws an exception.
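The reverse mapping, and the kind of equivalence that checkThresholdConsistency() enforces, can be sketched as follows. ThresholdEquivalenceSketch, its methods, and the 1e-5 tolerance are hypothetical choices for this illustration; this page does not state what tolerance, if any, Spark uses.

```java
// Hypothetical helpers, not Spark API: the threshold -> thresholds mapping
// (1 - threshold, threshold), plus an approximate equivalence check between
// a single threshold and a length-2 thresholds array.
final class ThresholdEquivalenceSketch {

    static double[] equivalentThresholds(double threshold) {
        return new double[] {1.0 - threshold, threshold};
    }

    static boolean isConsistent(double threshold, double[] thresholds, double tolerance) {
        if (thresholds.length != 2) {
            return false; // a single threshold only describes binary classification
        }
        // Convert back with 1 / (1 + t0 / t1) and compare within the tolerance.
        double implied = 1.0 / (1.0 + thresholds[0] / thresholds[1]);
        return Math.abs(implied - threshold) <= tolerance;
    }

    public static void main(String[] args) {
        double t = 0.3;
        double[] ts = equivalentThresholds(t);
        System.out.println(java.util.Arrays.toString(ts));
        System.out.println(isConsistent(t, ts, 1e-5));                        // true
        System.out.println(isConsistent(t, new double[] {0.5, 0.5}, 1e-5));   // false
    }
}
```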
public int numClasses()
Number of classes (values which the label can take).
Specified by: numClasses in class ClassificationModel<Vector,LogisticRegressionModel>
public LogisticRegressionTrainingSummary summary()
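Gets summary of model on training set. An exception is thrown if trainingSummary == None.
public boolean hasSummary()
Indicates whether a training summary exists for this model instance.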
protected double predict(Vector features)
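Predict label for the given feature vector. The behavior of this can be adjusted using thresholds.
Overrides: predict in class ClassificationModel<Vector,LogisticRegressionModel>
Parameters: features - (undocumented)
protected Vector raw2probabilityInPlace(Vector rawPrediction)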
Estimate the probability of each class given the raw prediction, doing the computation in-place. This internal method is used to implement transform() and output probabilityCol.
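As an illustration of an in-place probability computation, the sketch below overwrites each raw score with its logistic (sigmoid) value. It assumes that a binary raw prediction has the form {-margin, margin}; under that assumption the two outputs sum to 1. RawToProbabilitySketch is a hypothetical name, plain double[] arrays stand in for Spark's Vector, and this is not a copy of Spark's implementation.

```java
// Hypothetical sketch, not Spark's implementation: overwrite each raw score
// with 1 / (1 + exp(-score)), in place. Assumes a binary raw prediction of
// the form {-margin, margin}, so the two results sum to 1.
final class RawToProbabilitySketch {

    static double[] raw2probabilityInPlace(double[] raw) {
        for (int i = 0; i < raw.length; i++) {
            raw[i] = 1.0 / (1.0 + Math.exp(-raw[i]));
        }
        return raw; // same array instance, now holding probabilities
    }

    public static void main(String[] args) {
        double margin = 2.0;
        double[] raw = {-margin, margin};
        double[] probs = raw2probabilityInPlace(raw);
        System.out.println(probs == raw); // true: no new array was allocated
        System.out.println(probs[0] + probs[1]); // 1 up to floating-point rounding
    }
}
```

Mutating the input avoids allocating a second array per row, which is the point of an "in-place" variant; callers must not reuse the raw scores afterwards.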
Specified by: raw2probabilityInPlace in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters: rawPrediction - (undocumented)
protected Vector predictRaw(Vector features)
Raw prediction for each possible label. This internal method is used to implement transform() and output rawPredictionCol.
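A sketch of how a binary raw prediction could be formed from weights() and intercept(), assuming the standard logistic regression margin weights · features + intercept and a raw vector of {-margin, margin}. Both that layout and the PredictRawSketch name are assumptions for illustration; plain double[] arrays stand in for Spark's Vector.

```java
// Hypothetical sketch, not Spark's implementation. Assumes the binary raw
// prediction is {-margin, margin}, where margin = weights . features + intercept.
final class PredictRawSketch {

    static double margin(double[] weights, double intercept, double[] features) {
        if (weights.length != features.length) {
            throw new IllegalArgumentException("weights and features differ in length");
        }
        double sum = intercept;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * features[i]; // dot product accumulated onto the intercept
        }
        return sum;
    }

    static double[] predictRaw(double[] weights, double intercept, double[] features) {
        double m = margin(weights, intercept, features);
        // Index 0 scores the negative class, index 1 the positive class.
        return new double[] {-m, m};
    }

    public static void main(String[] args) {
        double[] w = {0.5, -1.0};
        double[] x = {2.0, 1.0};
        double[] raw = predictRaw(w, 0.25, x);
        System.out.println(java.util.Arrays.toString(raw)); // [-0.25, 0.25]
    }
}
```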
Specified by: predictRaw in class ClassificationModel<Vector,LogisticRegressionModel>
Parameters: features - (undocumented)
public LogisticRegressionModel copy(ParamMap extra)
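Creates a copy of this instance with the same UID and some extra params.
Specified by: copy in interface Params
Specified by: copy in class Model<LogisticRegressionModel>
Parameters: extra - (undocumented)
See also: defaultCopy()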
protected double raw2prediction(Vector rawPrediction)
Given a vector of raw predictions, select the predicted label.
Overrides: raw2prediction in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters: rawPrediction - (undocumented)
protected double probability2prediction(Vector probability)
Given a vector of class conditional probabilities, select the predicted label.
Overrides: probability2prediction in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters: probability - (undocumented)
public void checkThresholdConsistency()
If threshold and thresholds are both set, ensures they are consistent.
Throws: java.lang.IllegalArgumentException - if threshold and thresholds are not equivalent
public void validateParams()
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether this is in fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.