public class MultilayerPerceptronClassifier extends Predictor<Vector,MultilayerPerceptronClassifier,MultilayerPerceptronClassificationModel>
Constructor and Description |
---|
MultilayerPerceptronClassifier() |
MultilayerPerceptronClassifier(java.lang.String uid) |
Modifier and Type | Method and Description |
---|---|
IntParam | blockSize() - Block size for stacking input data in matrices to speed up the computation. |
MultilayerPerceptronClassifier | copy(ParamMap extra) - Creates a copy of this instance with the same UID and some extra params. |
Param<java.lang.String> | featuresCol() - Param for features column name. |
int | getBlockSize() |
java.lang.String | getFeaturesCol() |
java.lang.String | getLabelCol() |
int[] | getLayers() |
java.lang.String | getPredictionCol() |
Param<java.lang.String> | labelCol() - Param for label column name. |
IntArrayParam | layers() - Layer sizes including input size and output size. |
Param<java.lang.String> | predictionCol() - Param for prediction column name. |
MultilayerPerceptronClassifier | setBlockSize(int value) |
MultilayerPerceptronClassifier | setLayers(int[] value) |
MultilayerPerceptronClassifier | setMaxIter(int value) - Set the maximum number of iterations. |
MultilayerPerceptronClassifier | setSeed(long value) - Set the seed for weights initialization. |
MultilayerPerceptronClassifier | setTol(double value) - Set the convergence tolerance of iterations. |
protected MultilayerPerceptronClassificationModel | train(DataFrame dataset) - Train a model using the given dataset and parameters. |
java.lang.String | uid() - An immutable unique ID for the object and its derivatives. |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) - Validates and transforms the input schema with the provided param map. |
Methods inherited from class Predictor: extractLabeledPoints, fit, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
Methods inherited from class PipelineStage: transformSchema
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
Methods inherited from interface Identifiable: toString
Methods inherited from interface Logging: initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
public MultilayerPerceptronClassifier(java.lang.String uid)
public MultilayerPerceptronClassifier()
public java.lang.String uid()
An immutable unique ID for the object and its derivatives.
Specified by:
uid in interface Identifiable
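The uid() contract only promises an immutable unique ID. As an illustration of one common way to generate such IDs, the plain-Java sketch below joins a prefix to a random suffix taken from java.util.UUID. The helper randomUid and the prefix "mlpc" are hypothetical choices for this example; they are not claimed to be the exact code behind MultilayerPerceptronClassifier().

```java
import java.util.UUID;

public class UidSketch {
    // Hypothetical helper: prefix + "_" + last 12 hex characters of a random UUID.
    // Illustrates the idea of a prefixed random ID; not the Spark implementation.
    static String randomUid(String prefix) {
        String uuid = UUID.randomUUID().toString();
        return prefix + "_" + uuid.substring(uuid.length() - 12);
    }

    public static void main(String[] args) {
        String first = randomUid("mlpc");
        String second = randomUid("mlpc");
        System.out.println(first);
        // Two calls draw independent UUIDs, so the IDs differ in practice.
        System.out.println(first.equals(second));
    }
}
```

When a stable, known ID is needed instead of a random one, the MultilayerPerceptronClassifier(java.lang.String uid) constructor accepts it directly.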
public MultilayerPerceptronClassifier setLayers(int[] value)
public MultilayerPerceptronClassifier setBlockSize(int value)
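public MultilayerPerceptronClassifier setMaxIter(int value)
Set the maximum number of iterations.
Parameters:
value - the maximum number of iterations
public MultilayerPerceptronClassifier setTol(double value)
Set the convergence tolerance of iterations.
Parameters:
value - the convergence tolerance
public MultilayerPerceptronClassifier setSeed(long value)
Set the seed for weights initialization.
Parameters:
value - the random seed used to initialize the weights
public MultilayerPerceptronClassifier copy(ParamMap extra)
Creates a copy of this instance with the same UID and some extra params.
Specified by:
copy in interface Params
copy in class Predictor<Vector,MultilayerPerceptronClassifier,MultilayerPerceptronClassificationModel>
Parameters:
extra - extra params to apply to the copy
See Also:
defaultCopy()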
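setMaxIter, setTol and setSeed together decide where training starts and when it stops. The sketch below shows the usual interplay on a toy problem: a seeded java.util.Random draws the initial weight, each iteration takes one gradient-descent step on the stand-in loss (w - 3)^2, and the loop ends either when the change in loss falls below tol or when maxIter iterations have run. Plain gradient descent, the learning rate 0.1 and the toy loss are assumptions for illustration only; the optimizer and exact convergence test used by this class are not described here.

```java
import java.util.Random;

public class TrainingControlSketch {
    // Outcome of a toy run: final weight and number of iterations performed.
    static final class Result {
        final double weight;
        final int iterations;

        Result(double weight, int iterations) {
            this.weight = weight;
            this.iterations = iterations;
        }
    }

    // Toy loss standing in for the network's training loss; minimum at w = 3.
    static double loss(double w) {
        return (w - 3.0) * (w - 3.0);
    }

    static Result train(long seed, int maxIter, double tol) {
        Random rnd = new Random(seed);
        double w = rnd.nextGaussian();         // seeded "weight initialization"
        double prev = loss(w);
        int iter = 0;
        while (iter < maxIter) {               // hard cap: maxIter
            w -= 0.1 * 2.0 * (w - 3.0);        // one gradient step, learning rate 0.1
            iter++;
            double cur = loss(w);
            if (Math.abs(prev - cur) < tol) {  // loss barely changed: converged
                break;
            }
            prev = cur;
        }
        return new Result(w, iter);
    }

    public static void main(String[] args) {
        Result a = train(42L, 100, 1e-6);
        Result b = train(42L, 100, 1e-6);
        System.out.println(a.weight == b.weight);            // prints true: same seed, same run
        System.out.println(a.iterations + " iterations");    // stopped by tol, below 100
        System.out.println(train(42L, 5, 1e-6).iterations);  // prints 5: capped by maxIter
    }
}
```

Because the same seed reproduces the same starting point, two runs with identical settings follow identical paths, which is the practical reason to fix the seed when comparing experiments.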
protected MultilayerPerceptronClassificationModel train(DataFrame dataset)
Train a model using the given dataset and parameters. Developers can implement this instead of fit() to avoid dealing with schema validation and copying parameters into the model.
Specified by:
train in class Predictor<Vector,MultilayerPerceptronClassifier,MultilayerPerceptronClassificationModel>
Parameters:
dataset - Training dataset
public IntArrayParam layers()
public int[] getLayers()
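Because layers() includes both the input and the output size, the array fixes the whole network shape: {4, 5, 4, 3}, for example, describes 4 input features, two hidden layers of 5 and 4 neurons, and 3 outputs (one per class in the usual setup). The sketch below counts the resulting weights under the standard fully connected formulation with one bias per non-input neuron; that formulation is an assumption here, and countWeights is a hypothetical helper, not part of this API.

```java
public class LayerSizesSketch {
    // Counts weights of a fully connected network whose sizes are given as
    // {input, hidden..., output}, assuming one bias per non-input neuron.
    static int countWeights(int[] layers) {
        if (layers.length < 2) {
            throw new IllegalArgumentException("need at least input and output sizes");
        }
        int total = 0;
        for (int i = 0; i + 1 < layers.length; i++) {
            // weight matrix (in x out) plus bias vector (out)
            total += layers[i] * layers[i + 1] + layers[i + 1];
        }
        return total;
    }

    public static void main(String[] args) {
        int[] layers = {4, 5, 4, 3};
        // (4*5 + 5) + (5*4 + 4) + (4*3 + 3) = 25 + 24 + 15
        System.out.println(countWeights(layers)); // prints 64
    }
}
```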
public IntParam blockSize()
public int getBlockSize()
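blockSize() is described as stacking input data in matrices to speed up the computation. The general idea is that a block of rows can pass through a layer as one matrix-matrix product instead of many separate matrix-vector products. The plain-Java sketch below shows only the grouping step: rows are cut into blocks of at most blockSize rows, with a smaller final block when the count does not divide evenly. The stack helper is hypothetical, and how this class actually lays out and distributes its blocks is not shown.

```java
import java.util.ArrayList;
import java.util.List;

public class BlockStackingSketch {
    // Groups rows into blocks of at most blockSize rows each, so each block
    // can be handled as one matrix; the last block may be smaller.
    static List<double[][]> stack(List<double[]> rows, int blockSize) {
        if (blockSize <= 0) {
            throw new IllegalArgumentException("blockSize must be positive");
        }
        List<double[][]> blocks = new ArrayList<>();
        for (int start = 0; start < rows.size(); start += blockSize) {
            int end = Math.min(start + blockSize, rows.size());
            double[][] block = new double[end - start][];
            for (int i = start; i < end; i++) {
                block[i - start] = rows.get(i); // one matrix row per input row
            }
            blocks.add(block);
        }
        return blocks;
    }

    public static void main(String[] args) {
        List<double[]> rows = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            rows.add(new double[] {i, 2.0 * i});
        }
        List<double[][]> blocks = stack(rows, 4);
        System.out.println(blocks.size());        // prints 3 (4 + 4 + 2 rows)
        System.out.println(blocks.get(2).length); // prints 2
    }
}
```

Larger blocks mean fewer, bigger matrix operations; the trade-off is more memory per block.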
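public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether the schema is being validated for fitting (training) rather than for transformation
featuresDataType - SQL DataType for FeaturesType, e.g., VectorUDT for vector features.
public Param<java.lang.String> labelCol()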
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()
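validateAndTransformSchema takes the input schema, a fitting flag and the expected features type. A rough plain-Java model of such a check follows, with a schema represented as an ordered map from column name to type name. Everything in it is an assumption for illustration: the map representation, the type strings "vector" and "double", the column names "features", "label" and "prediction", and the specific rules (features type must match, label required only when fitting, prediction column appended and must not already exist). It is not the Spark implementation, which works on StructType and DataType.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class SchemaSketch {
    // Hypothetical schema check: column name -> type name, in column order.
    static Map<String, String> validateAndTransform(Map<String, String> schema,
                                                    boolean fitting,
                                                    String featuresType,
                                                    String featuresCol,
                                                    String labelCol,
                                                    String predictionCol) {
        requireType(schema, featuresCol, featuresType);
        if (fitting) {
            requireType(schema, labelCol, "double"); // label only needed for training
        }
        if (schema.containsKey(predictionCol)) {
            throw new IllegalArgumentException("Column " + predictionCol + " already exists");
        }
        Map<String, String> out = new LinkedHashMap<>(schema); // input is not modified
        out.put(predictionCol, "double");                    // output gains the prediction column
        return out;
    }

    static void requireType(Map<String, String> schema, String col, String expected) {
        String actual = schema.get(col);
        if (!expected.equals(actual)) {
            throw new IllegalArgumentException(
                    "Column " + col + " must be of type " + expected + " but was " + actual);
        }
    }

    public static void main(String[] args) {
        Map<String, String> training = new LinkedHashMap<>();
        training.put("features", "vector");
        training.put("label", "double");
        // Fitting: features and label are checked, prediction is appended.
        System.out.println(validateAndTransform(training, true, "vector",
                "features", "label", "prediction").keySet()); // prints [features, label, prediction]

        Map<String, String> scoring = new LinkedHashMap<>();
        scoring.put("features", "vector");
        // Not fitting: no label column is required.
        System.out.println(validateAndTransform(scoring, false, "vector",
                "features", "label", "prediction").keySet()); // prints [features, prediction]
    }
}
```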