public class FMRegressor extends Regressor<Vector,FMRegressor,FMRegressionModel> implements FactorizationMachines, FMRegressorParams, DefaultParamsWritable, org.apache.spark.internal.Logging
The implementation is based upon: S. Rendle. "Factorization machines" 2010.
A factorization machine (FM) can estimate feature interactions even in problems with extreme sparsity, such as advertising and recommendation systems. The FM formula is:
$$ \begin{align} y = w_0 + \sum\limits^n_{i=1} w_i x_i + \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j \end{align} $$
The first two terms are the global bias and the linear term (the same as in linear regression), and the last term is the pairwise interaction term. v_i describes the i-th variable with k factors.
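The formula above can be evaluated directly, as in the following minimal sketch. It is plain Java for illustration only and is not Spark's implementation; the class name `FmPredictionSketch`, its method names, and the dense-array layout (`v[i][f]` holding factor f of variable i) are assumptions made for this example. The second method uses the identity from Rendle's paper that rewrites the pairwise term as half the sum over factors of (sum_i v_{i,f} x_i)^2 minus sum_i v_{i,f}^2 x_i^2, which costs O(kn) instead of O(kn^2).

```java
/** Illustrative sketch of the FM prediction formula; not Spark's implementation. */
final class FmPredictionSketch {

    /** Direct evaluation: bias + linear term + double sum over pairs i < j. */
    static double predictNaive(double w0, double[] w, double[][] v, double[] x) {
        int n = x.length;
        double y = w0;
        for (int i = 0; i < n; i++) {
            y += w[i] * x[i];
        }
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                // Inner product <v_i, v_j> over the k factors.
                double dot = 0.0;
                for (int f = 0; f < v[i].length; f++) {
                    dot += v[i][f] * v[j][f];
                }
                y += dot * x[i] * x[j];
            }
        }
        return y;
    }

    /** Same value, with the pairwise term computed in O(k * n) per example. */
    static double predictFast(double w0, double[] w, double[][] v, double[] x) {
        int n = x.length;
        int k = (n == 0) ? 0 : v[0].length;
        double y = w0;
        for (int i = 0; i < n; i++) {
            y += w[i] * x[i];
        }
        double pairwise = 0.0;
        for (int f = 0; f < k; f++) {
            double sum = 0.0;          // sum_i v[i][f] * x[i]
            double sumOfSquares = 0.0; // sum_i (v[i][f] * x[i])^2
            for (int i = 0; i < n; i++) {
                double t = v[i][f] * x[i];
                sum += t;
                sumOfSquares += t * t;
            }
            pairwise += sum * sum - sumOfSquares;
        }
        return y + 0.5 * pairwise;
    }
}
```

Both methods should agree up to floating-point rounding; the fast form is the one usually preferred when n is large.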
The FM regression model uses MSE loss, which can be minimized by gradient descent. Regularization terms such as L2 are usually added to the loss function to prevent overfitting.
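As a rough illustration of the loss described above, the sketch below computes a mean squared error with an L2 penalty and applies one plain gradient-descent update to a single parameter. It is an assumption-laden example, not Spark's code: the class `FmLossSketch`, both method names, and the exact scaling of the penalty (`regParam` times the sum of squared parameters) are chosen here for clarity, and Spark's solvers may scale or apply regularization differently.

```java
/** Illustrative MSE-plus-L2 objective and a plain gradient step; not Spark's implementation. */
final class FmLossSketch {

    /** Mean squared error over the examples plus regParam * sum(params^2). */
    static double mseWithL2(double[] predictions, double[] labels,
                            double[] params, double regParam) {
        if (predictions.length == 0 || predictions.length != labels.length) {
            throw new IllegalArgumentException("predictions and labels must be non-empty and equal in length");
        }
        double squaredError = 0.0;
        for (int i = 0; i < predictions.length; i++) {
            double diff = predictions[i] - labels[i];
            squaredError += diff * diff;
        }
        double penalty = 0.0;
        for (double p : params) {
            penalty += p * p;
        }
        return squaredError / predictions.length + regParam * penalty;
    }

    /**
     * One gradient-descent update for a single parameter.
     * lossGradient is the derivative of the unregularized loss with respect to the
     * parameter; the L2 penalty contributes the extra 2 * regParam * param term.
     */
    static double gradientStep(double param, double lossGradient,
                               double stepSize, double regParam) {
        return param - stepSize * (lossGradient + 2.0 * regParam * param);
    }
}
```

The names `regParam` and `stepSize` mirror the parameters listed on this page only to make the connection visible; their precise effect inside Spark depends on the chosen solver.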
Constructor and Description |
---|
FMRegressor() |
FMRegressor(String uid) |
Modifier and Type | Method and Description |
---|---|
FMRegressor | copy(ParamMap extra) Creates a copy of this instance with the same UID and some extra params. |
IntParam | factorSize() Param for dimensionality of the factors (>= 0) |
BooleanParam | fitIntercept() Param for whether to fit an intercept term. |
BooleanParam | fitLinear() Param for whether to fit linear term (aka 1-way term) |
DoubleParam | initStd() Param for standard deviation of initial coefficients |
static FMRegressor | load(String path) |
IntParam | maxIter() Param for maximum number of iterations (>= 0). |
DoubleParam | miniBatchFraction() Param for mini-batch fraction, must be in range (0, 1] |
static MLReader<T> | read() |
DoubleParam | regParam() Param for regularization parameter (>= 0). |
LongParam | seed() Param for random seed. |
FMRegressor | setFactorSize(int value) Set the dimensionality of the factors. |
FMRegressor | setFitIntercept(boolean value) Set whether to fit intercept term. |
FMRegressor | setFitLinear(boolean value) Set whether to fit linear term. |
FMRegressor | setInitStd(double value) Set the standard deviation of initial coefficients. |
FMRegressor | setMaxIter(int value) Set the maximum number of iterations. |
FMRegressor | setMiniBatchFraction(double value) Set the mini-batch fraction parameter. |
FMRegressor | setRegParam(double value) Set the L2 regularization parameter. |
FMRegressor | setSeed(long value) Set the random seed for weight initialization. |
FMRegressor | setSolver(String value) Set the solver algorithm used for optimization. |
FMRegressor | setStepSize(double value) Set the initial step size for the first step (like learning rate). |
FMRegressor | setTol(double value) Set the convergence tolerance of iterations. |
Param<String> | solver() The solver algorithm for optimization. |
DoubleParam | stepSize() Param for Step size to be used for each iteration of optimization (> 0). |
DoubleParam | tol() Param for the convergence tolerance for iterative algorithms (>= 0). |
String | uid() An immutable unique ID for the object and its derivatives. |
Param<String> | weightCol() Param for weight column name. |
featuresCol, fit, labelCol, predictionCol, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
params
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
initCoefficients, trainImpl
getFactorSize, getFitLinear, getInitStd, getMiniBatchFraction
extractInstances, extractInstances, validateAndTransformSchema
getLabelCol, labelCol
featuresCol, getFeaturesCol
getPredictionCol, predictionCol
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
toString
getMaxIter
getStepSize
getFitIntercept
getRegParam
getWeightCol
write
save
$init$, initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, initLock, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log__$eq, org$apache$spark$internal$Logging$$log_, uninitialize
public static FMRegressor load(String path)
public static MLReader<T> read()
public final IntParam factorSize()
FactorizationMachinesParams
factorSize in interface FactorizationMachinesParams
public final BooleanParam fitLinear()
FactorizationMachinesParams
fitLinear in interface FactorizationMachinesParams
public final DoubleParam miniBatchFraction()
FactorizationMachinesParams
miniBatchFraction in interface FactorizationMachinesParams
public final DoubleParam initStd()
FactorizationMachinesParams
initStd in interface FactorizationMachinesParams
public final Param<String> solver()
FactorizationMachinesParams
solver in interface HasSolver
solver in interface FactorizationMachinesParams
public final Param<String> weightCol()
HasWeightCol
weightCol in interface HasWeightCol
public final DoubleParam regParam()
HasRegParam
regParam in interface HasRegParam
public final BooleanParam fitIntercept()
HasFitIntercept
fitIntercept in interface HasFitIntercept
public final LongParam seed()
HasSeed
public final DoubleParam tol()
HasTol
public DoubleParam stepSize()
HasStepSize
stepSize in interface HasStepSize
public final IntParam maxIter()
HasMaxIter
maxIter in interface HasMaxIter
public String uid()
Identifiable
uid in interface Identifiable
public FMRegressor setFactorSize(int value)
value - (undocumented)
public FMRegressor setFitIntercept(boolean value)
value - (undocumented)
public FMRegressor setFitLinear(boolean value)
value - (undocumented)
public FMRegressor setRegParam(double value)
value - (undocumented)
public FMRegressor setMiniBatchFraction(double value)
value - (undocumented)
public FMRegressor setInitStd(double value)
value - (undocumented)
public FMRegressor setMaxIter(int value)
value - (undocumented)
public FMRegressor setStepSize(double value)
value - (undocumented)
public FMRegressor setTol(double value)
value - (undocumented)
public FMRegressor setSolver(String value)
value - (undocumented)
public FMRegressor setSeed(long value)
value - (undocumented)
public FMRegressor copy(ParamMap extra)
Params
defaultCopy().
copy in interface Params
copy in class Predictor<Vector,FMRegressor,FMRegressionModel>
extra - (undocumented)