public class GBTClassificationModel extends ProbabilisticClassificationModel<Vector,GBTClassificationModel> implements GBTClassifierParams, TreeEnsembleModel<DecisionTreeRegressionModel>, MLWritable, scala.Serializable
param: _trees Decision trees in the ensemble.
param: _treeWeights Weights for the decision trees in the ensemble.
Constructor and Description |
---|
GBTClassificationModel(String uid, DecisionTreeRegressionModel[] _trees, double[] _treeWeights): Construct a GBTClassificationModel. |
Modifier and Type | Method and Description |
---|---|
BooleanParam | cacheNodeIds(): If false, the algorithm will pass trees to executors to match instances with nodes. |
IntParam | checkpointInterval(): Param for setting the checkpoint interval (>= 1) or disabling checkpointing (-1). |
GBTClassificationModel | copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params. |
double[] | evaluateEachIteration(Dataset<?> dataset): Method to compute error or loss for every iteration of gradient boosting. |
Vector | featureImportances() |
Param<String> | featureSubsetStrategy(): The number of features to consider for splits at each tree node. |
int | getNumTrees(): Number of trees in the ensemble. |
Param<String> | impurity(): Criterion used for information gain calculation (case-insensitive). |
Param<String> | leafCol(): Leaf indices column name. |
static GBTClassificationModel | load(String path) |
Param<String> | lossType(): Loss function which GBT tries to minimize. |
IntParam | maxBins(): Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. |
IntParam | maxDepth(): Maximum depth of the tree (nonnegative). |
IntParam | maxIter(): Param for maximum number of iterations (>= 0). |
IntParam | maxMemoryInMB(): Maximum memory in MB allocated to histogram aggregation. |
DoubleParam | minInfoGain(): Minimum information gain for a split to be considered at a tree node. |
IntParam | minInstancesPerNode(): Minimum number of instances each child must have after split. |
DoubleParam | minWeightFractionPerNode(): Minimum fraction of the weighted sample count that each child must have after split. |
int | numClasses(): Number of classes (values which the label can take). |
int | numFeatures(): Returns the number of features the model was trained on. |
double | predict(Vector features): Predict label for the given features. |
Vector | predictRaw(Vector features): Raw prediction for each possible label. |
static MLReader<GBTClassificationModel> | read() |
LongParam | seed(): Param for random seed. |
DoubleParam | stepSize(): Param for step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator. |
DoubleParam | subsamplingRate(): Fraction of the training data used for learning each decision tree, in range (0, 1]. |
String | toString(): Summary of the model. |
int | totalNumNodes(): Total number of nodes, summed over all trees in the ensemble. |
Dataset<Row> | transform(Dataset<?> dataset): Transforms dataset by reading from featuresCol and appending new columns as specified by parameters: predicted labels as predictionCol (type Double), raw predictions (confidences) as rawPredictionCol (type Vector), and the probability of each class as probabilityCol (type Vector). |
StructType | transformSchema(StructType schema): Check transform validity and derive the output schema from the input schema. |
DecisionTreeRegressionModel[] | trees(): Trees in this ensemble. |
double[] | treeWeights(): Weights for each tree, zippable with trees. |
String | uid(): An immutable unique ID for the object and its derivatives. |
Param<String> | validationIndicatorCol(): Param for name of the column that indicates whether each row is for training or for validation. |
DoubleParam | validationTol(): Threshold for stopping early when fit with validation is used. |
Param<String> | weightCol(): Param for weight column name. |
MLWriter | write(): Returns an MLWriter instance for this ML instance. |
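The summary rows for predictRaw, predict and transform fit together as follows for a binary GBT model. This is a standalone, Spark-free sketch, not Spark code. It assumes that the ensemble margin m is laid out as the raw prediction [-m, m] and mapped to a class-1 probability with the logistic link 1 / (1 + e^(-2m)). Both assumptions reflect our reading of Spark's logistic loss and should be checked against the Spark version in use. predict follows the documented rule of taking the index of the maximum raw value. The class GbtPredictionSketch and its methods are hypothetical.

```java
/** Hypothetical, Spark-free sketch of how rawPrediction, prediction and probability relate. */
final class GbtPredictionSketch {

    private GbtPredictionSketch() {}

    /** ASSUMPTION: two-class raw prediction laid out as [-margin, margin]. */
    static double[] predictRaw(double margin) {
        return new double[] {-margin, margin};
    }

    /** Index of the largest raw value; ties keep the lower index. */
    static double predict(double[] raw) {
        int best = 0;
        for (int i = 1; i < raw.length; i++) {
            if (raw[i] > raw[best]) {
                best = i;
            }
        }
        return best;
    }

    /** ASSUMPTION: logistic link p(class 1) = 1 / (1 + exp(-2 * margin)). */
    static double[] probability(double[] raw) {
        double margin = raw[1];
        double p1 = 1.0 / (1.0 + Math.exp(-2.0 * margin));
        return new double[] {1.0 - p1, p1};
    }
}
```

Under these assumptions, p(class 1) exceeds 0.5 exactly when the margin is positive. Taking the argmax over the raw vector or over the probability vector therefore gives the same label.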
Methods inherited from superclasses and interfaces (one line per declaring type):
normalizeToProbabilitiesInPlace, predictProbability, probabilityCol, setProbabilityCol, setThresholds, thresholds
rawPredictionCol, setRawPredictionCol, transformImpl
featuresCol, labelCol, predictionCol, setFeaturesCol, setPredictionCol
transform, transform, transform
params
getLossType, getOldLossType
getOldBoostingStrategy, getValidationTol
getMaxIter
getStepSize
getValidationIndicatorCol
validateAndTransformSchema
getFeatureSubsetStrategy, getOldStrategy, getSubsamplingRate
getCacheNodeIds, getLeafCol, getMaxBins, getMaxDepth, getMaxMemoryInMB, getMinInfoGain, getMinInstancesPerNode, getMinWeightFractionPerNode, getOldStrategy, setLeafCol
getCheckpointInterval
getWeightCol
extractInstances
extractInstances, extractInstances
getLabelCol, labelCol
featuresCol, getFeaturesCol
getPredictionCol, predictionCol
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
getRawPredictionCol, rawPredictionCol
getProbabilityCol, probabilityCol
getThresholds, thresholds
getImpurity, getOldImpurity
getLeafField, javaTreeWeights, predictLeaf, toDebugString
save
$init$, initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, initLock, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log__$eq, org$apache$spark$internal$Logging$$log_, uninitialize
public GBTClassificationModel(String uid, DecisionTreeRegressionModel[] _trees, double[] _treeWeights)
Parameters:
_trees - Decision trees in the ensemble.
_treeWeights - Weights for the decision trees in the ensemble.
uid - (undocumented)
public static MLReader<GBTClassificationModel> read()
public static GBTClassificationModel load(String path)
public int totalNumNodes()
Description copied from interface: TreeEnsembleModel
Specified by: totalNumNodes in interface TreeEnsembleModel<DecisionTreeRegressionModel>
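totalNumNodes() is the sum of the node counts of the individual trees. The sketch below shows that sum over a hypothetical binary Node type, not Spark's node classes. A leaf counts as one node, and an internal node counts itself plus its subtrees.

```java
import java.util.List;

/** Hypothetical sketch: totalNumNodes as the sum of per-tree node counts. */
final class TreeNodeCountSketch {

    private TreeNodeCountSketch() {}

    /** Hypothetical binary tree node; a leaf has no children. */
    static final class Node {
        final Node left;
        final Node right;

        Node(Node left, Node right) {
            this.left = left;
            this.right = right;
        }

        static Node leaf() {
            return new Node(null, null);
        }
    }

    /** Counts this node plus every node below it. */
    static int numNodes(Node node) {
        if (node == null) {
            return 0;
        }
        return 1 + numNodes(node.left) + numNodes(node.right);
    }

    /** Sums node counts over all trees in the ensemble. */
    static int totalNumNodes(List<Node> roots) {
        int total = 0;
        for (Node root : roots) {
            total += numNodes(root);
        }
        return total;
    }
}
```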
public Param<String> lossType()
Description copied from interface: GBTClassifierParams
Specified by: lossType in interface GBTClassifierParams
public final Param<String> impurity()
Description copied from interface: HasVarianceImpurity
Specified by: impurity in interface HasVarianceImpurity
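The impurity param comes from HasVarianceImpurity, so the trees inside a GBT classifier score splits with a variance criterion. This fits because they are DecisionTreeRegressionModel instances. As a hedged illustration only, variance gain can be computed as the parent's variance minus the count-weighted variance of the two children. This sketch ignores instance weights (weightCol), and VarianceGainSketch is hypothetical.

```java
/** Hypothetical sketch of variance-based information gain for one candidate split. */
final class VarianceGainSketch {

    private VarianceGainSketch() {}

    /** Population variance; 0 for an empty array. */
    static double variance(double[] ys) {
        if (ys.length == 0) {
            return 0.0;
        }
        double mean = 0.0;
        for (double y : ys) {
            mean += y;
        }
        mean /= ys.length;
        double sumSq = 0.0;
        for (double y : ys) {
            sumSq += (y - mean) * (y - mean);
        }
        return sumSq / ys.length;
    }

    /** Parent variance minus the count-weighted variance of the two children. */
    static double gain(double[] left, double[] right) {
        double[] parent = new double[left.length + right.length];
        System.arraycopy(left, 0, parent, 0, left.length);
        System.arraycopy(right, 0, parent, left.length, right.length);
        if (parent.length == 0) {
            return 0.0;
        }
        double n = parent.length;
        return variance(parent)
            - (left.length / n) * variance(left)
            - (right.length / n) * variance(right);
    }
}
```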
public final DoubleParam validationTol()
Description copied from interface: GBTParams
Specified by: validationTol in interface GBTParams
See Also: validationIndicatorCol
public final DoubleParam stepSize()
Description copied from interface: GBTParams
Specified by: stepSize in interface HasStepSize, stepSize in interface GBTParams
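The step size (learning rate) scales each new tree's contribution to the ensemble. To our understanding, Spark's boosting loop gives the first tree weight 1.0 and every later tree a weight equal to stepSize. That would make treeWeights() look like [1.0, stepSize, stepSize, ...]. Treat this as an assumption to verify. The hypothetical helper below builds such an array after checking that stepSize lies in the (0, 1] range.

```java
import java.util.Arrays;

/** Hypothetical sketch of tree weights under the assumed stepSize convention. */
final class StepSizeWeightsSketch {

    private StepSizeWeightsSketch() {}

    /** Returns [1.0, stepSize, stepSize, ...] with numTrees entries. */
    static double[] treeWeights(int numTrees, double stepSize) {
        if (!(stepSize > 0.0 && stepSize <= 1.0)) {
            throw new IllegalArgumentException("stepSize must be in (0, 1]: " + stepSize);
        }
        if (numTrees < 1) {
            throw new IllegalArgumentException("numTrees must be >= 1: " + numTrees);
        }
        double[] weights = new double[numTrees];
        Arrays.fill(weights, stepSize);
        weights[0] = 1.0; // ASSUMPTION: the first tree is not shrunk
        return weights;
    }
}
```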
public final Param<String> validationIndicatorCol()
Description copied from interface: HasValidationIndicatorCol
Specified by: validationIndicatorCol in interface HasValidationIndicatorCol
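validationIndicatorCol flags each row as training or validation data, and validationTol is the threshold for stopping early when such a validation set is used. The sketch below covers both steps. It first splits rows on a boolean flag and then applies a stopping test. The test is an assumption based on our reading of Spark's boosting loop, not something this page documents. Training stops once the improvement over the best validation error so far falls below validationTol * max(currentError, 0.01). The Row type and all method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of validation splitting and an assumed early-stopping rule. */
final class EarlyStoppingSketch {

    private EarlyStoppingSketch() {}

    /** Hypothetical row: a label plus the boolean read from validationIndicatorCol. */
    static final class Row {
        final double label;
        final boolean isValidation;

        Row(double label, boolean isValidation) {
            this.label = label;
            this.isValidation = isValidation;
        }
    }

    /** Returns [trainingRows, validationRows]. */
    static List<List<Row>> split(List<Row> rows) {
        List<Row> train = new ArrayList<>();
        List<Row> validation = new ArrayList<>();
        for (Row r : rows) {
            if (r.isValidation) {
                validation.add(r);
            } else {
                train.add(r);
            }
        }
        List<List<Row>> parts = new ArrayList<>();
        parts.add(train);
        parts.add(validation);
        return parts;
    }

    /** ASSUMED rule: stop when the improvement is below validationTol * max(currentError, 0.01). */
    static boolean shouldStop(double bestError, double currentError, double validationTol) {
        return bestError - currentError < validationTol * Math.max(currentError, 0.01);
    }
}
```

Under this rule, a rise in validation error also stops training, because the improvement is then negative.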
public final IntParam maxIter()
Description copied from interface: HasMaxIter
Specified by: maxIter in interface HasMaxIter
public final DoubleParam subsamplingRate()
Description copied from interface: TreeEnsembleParams
Specified by: subsamplingRate in interface TreeEnsembleParams
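subsamplingRate is the fraction of the training data used for each tree, in the range (0, 1]. The sketch below is a simplified stand-in that keeps each row independently with that probability, using a seeded java.util.Random. Spark's own sampler may differ, and SubsamplingSketch is hypothetical. With a rate of 1.0, every row is kept, because Random.nextDouble() returns values in [0, 1).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical per-tree row subsampling (Bernoulli sampling without replacement). */
final class SubsamplingSketch {

    private SubsamplingSketch() {}

    static <T> List<T> subsample(List<T> rows, double subsamplingRate, long seed) {
        if (!(subsamplingRate > 0.0 && subsamplingRate <= 1.0)) {
            throw new IllegalArgumentException("subsamplingRate must be in (0, 1]: " + subsamplingRate);
        }
        Random rng = new Random(seed);
        List<T> sample = new ArrayList<>();
        for (T row : rows) {
            // Keep each row with probability subsamplingRate.
            if (rng.nextDouble() < subsamplingRate) {
                sample.add(row);
            }
        }
        return sample;
    }
}
```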
public final Param<String> featureSubsetStrategy()
Description copied from interface: TreeEnsembleParams
These various settings are based on the following references:
- log2: tested in Breiman (2001)
- sqrt: recommended by Breiman manual for random forests
- The defaults of sqrt (classification) and onethird (regression) match the R randomForest package.
Specified by: featureSubsetStrategy in interface TreeEnsembleParams
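The named strategies above (log2, sqrt, onethird) each turn the total feature count into the number of candidate features per node. The mapping below reflects our understanding of those rules plus "all". Treat the exact rounding as an assumption. "auto" and numeric fractions or counts are deliberately left out, and FeatureSubsetSketch is hypothetical.

```java
import java.util.Locale;

/** Hypothetical mapping from a featureSubsetStrategy name to features considered per node. */
final class FeatureSubsetSketch {

    private FeatureSubsetSketch() {}

    static int numFeaturesPerNode(String strategy, int numFeatures) {
        switch (strategy.toLowerCase(Locale.ROOT)) {
            case "all":
                return numFeatures;
            case "sqrt":
                return (int) Math.ceil(Math.sqrt(numFeatures));
            case "log2":
                // At least one feature, even when log2(numFeatures) is 0.
                return Math.max(1, (int) Math.ceil(Math.log(numFeatures) / Math.log(2)));
            case "onethird":
                return (int) Math.ceil(numFeatures / 3.0);
            default:
                throw new IllegalArgumentException("strategy not covered by this sketch: " + strategy);
        }
    }
}
```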
public final Param<String> leafCol()
Description copied from interface: DecisionTreeParams
Specified by: leafCol in interface DecisionTreeParams
public final IntParam maxDepth()
Description copied from interface: DecisionTreeParams
Specified by: maxDepth in interface DecisionTreeParams
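Because maxDepth is nonnegative and each tree is binary, maxDepth bounds the size of every tree and therefore the value totalNumNodes() can report. The sketch assumes that depth 0 means a single leaf. Under that assumption, one tree has at most 2^(maxDepth+1) - 1 nodes, and an ensemble of numTrees trees has at most numTrees times that. The helper class is hypothetical.

```java
/** Hypothetical upper bounds on node counts implied by maxDepth. */
final class DepthBoundSketch {

    private DepthBoundSketch() {}

    /** Depth 0 is a single leaf, so a full binary tree of depth d has 2^(d+1) - 1 nodes. */
    static long maxNodesPerTree(int maxDepth) {
        if (maxDepth < 0 || maxDepth > 61) {
            // Negative depths are invalid; the upper cap only keeps this sketch inside a long.
            throw new IllegalArgumentException("maxDepth out of range for this sketch: " + maxDepth);
        }
        return (1L << (maxDepth + 1)) - 1;
    }

    /** Upper bound on totalNumNodes for numTrees trees of at most maxDepth. */
    static long maxTotalNodes(int numTrees, int maxDepth) {
        return Math.multiplyExact((long) numTrees, maxNodesPerTree(maxDepth));
    }
}
```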
public final IntParam maxBins()
Description copied from interface: DecisionTreeParams
Specified by: maxBins in interface DecisionTreeParams
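maxBins caps how finely a continuous feature is discretized before split search. A feature divided into maxBins bins has at most maxBins - 1 candidate thresholds. The sketch below uses simple equal-frequency cut points taken from the sorted values, with duplicates removed. This is a simplified stand-in, not Spark's quantile procedure. The class name and the maxBins >= 2 check are assumptions for illustration.

```java
import java.util.Arrays;
import java.util.TreeSet;

/** Hypothetical equal-frequency threshold selection bounded by maxBins. */
final class BinningSketch {

    private BinningSketch() {}

    static double[] candidateThresholds(double[] values, int maxBins) {
        if (maxBins < 2) {
            throw new IllegalArgumentException("maxBins must be >= 2 in this sketch: " + maxBins);
        }
        double[] sorted = values.clone();
        Arrays.sort(sorted);
        TreeSet<Double> cuts = new TreeSet<>();
        // At most maxBins - 1 cut points, spaced evenly by rank.
        for (int i = 1; i < maxBins; i++) {
            int idx = (int) ((long) i * sorted.length / maxBins);
            if (idx < sorted.length) {
                cuts.add(sorted[idx]);
            }
        }
        double[] out = new double[cuts.size()];
        int k = 0;
        for (double c : cuts) {
            out[k++] = c;
        }
        return out;
    }
}
```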
public final IntParam minInstancesPerNode()
Description copied from interface: DecisionTreeParams
Specified by: minInstancesPerNode in interface DecisionTreeParams
public final DoubleParam minWeightFractionPerNode()
Description copied from interface: DecisionTreeParams
Specified by: minWeightFractionPerNode in interface DecisionTreeParams
public final DoubleParam minInfoGain()
Description copied from interface: DecisionTreeParams
Specified by: minInfoGain in interface DecisionTreeParams
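minInstancesPerNode, minWeightFractionPerNode and minInfoGain together decide whether a candidate split is kept. The hypothetical check below applies all three as their summaries describe. Each child must have at least minInstancesPerNode instances and at least minWeightFractionPerNode of the total weight, and the split's gain must reach minInfoGain. How Spark orders or combines these checks internally is not shown here.

```java
/** Hypothetical split-validity check built from the three documented params. */
final class SplitValiditySketch {

    private SplitValiditySketch() {}

    static boolean isValidSplit(long leftCount, long rightCount,
                                double leftWeight, double rightWeight, double totalWeight,
                                double gain,
                                int minInstancesPerNode,
                                double minWeightFractionPerNode,
                                double minInfoGain) {
        // Each child needs enough instances.
        if (leftCount < minInstancesPerNode || rightCount < minInstancesPerNode) {
            return false;
        }
        // Each child needs enough of the total sample weight.
        double minWeight = minWeightFractionPerNode * totalWeight;
        if (leftWeight < minWeight || rightWeight < minWeight) {
            return false;
        }
        // The split must be informative enough.
        return gain >= minInfoGain;
    }
}
```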
public final IntParam maxMemoryInMB()
Description copied from interface: DecisionTreeParams
Specified by: maxMemoryInMB in interface DecisionTreeParams
public final BooleanParam cacheNodeIds()
Description copied from interface: DecisionTreeParams
Specified by: cacheNodeIds in interface DecisionTreeParams
public final Param<String> weightCol()
Description copied from interface: HasWeightCol
Specified by: weightCol in interface HasWeightCol
public final LongParam seed()
Description copied from interface: HasSeed
public final IntParam checkpointInterval()
Description copied from interface: HasCheckpointInterval
Specified by: checkpointInterval in interface HasCheckpointInterval
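checkpointInterval accepts either a value >= 1 or -1, which disables checkpointing. The sketch validates that documented range. It then assumes a simple schedule that checkpoints on every checkpointInterval-th iteration, counted from 1. The schedule and the helper names are assumptions, not documented Spark behavior.

```java
/** Hypothetical checkpointInterval validation and an assumed checkpoint schedule. */
final class CheckpointSketch {

    private CheckpointSketch() {}

    /** Accepts -1 (disabled) or any value >= 1, as the param summary states. */
    static void validate(int checkpointInterval) {
        if (checkpointInterval != -1 && checkpointInterval < 1) {
            throw new IllegalArgumentException(
                "checkpointInterval must be >= 1 or -1: " + checkpointInterval);
        }
    }

    /** ASSUMED schedule: checkpoint when the 1-based iteration is a multiple of the interval. */
    static boolean shouldCheckpoint(int iteration, int checkpointInterval) {
        validate(checkpointInterval);
        return checkpointInterval != -1 && iteration % checkpointInterval == 0;
    }
}
```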
public String uid()
Description copied from interface: Identifiable
Specified by: uid in interface Identifiable
public int numFeatures()
Description copied from class: PredictionModel
Overrides: numFeatures in class PredictionModel<Vector,GBTClassificationModel>
public int numClasses()
Description copied from class: ClassificationModel
Specified by: numClasses in class ClassificationModel<Vector,GBTClassificationModel>
public DecisionTreeRegressionModel[] trees()
Description copied from interface: TreeEnsembleModel
Specified by: trees in interface TreeEnsembleModel<DecisionTreeRegressionModel>
public int getNumTrees()
public double[] treeWeights()
Description copied from interface: TreeEnsembleModel
Weights for each tree, zippable with trees.
Specified by: treeWeights in interface TreeEnsembleModel<DecisionTreeRegressionModel>
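treeWeights() is zippable with trees(), so entry i is the weight of tree i. This hedged sketch shows how a weighted ensemble combines them. It assumes that each tree's regression output for one row is already available as a double, whereas the real model would evaluate each DecisionTreeRegressionModel. The weighted sum it returns is the ensemble's margin for that row. The class name is hypothetical.

```java
/** Hypothetical sketch: zip per-tree outputs with their weights into one margin. */
final class WeightedEnsembleSketch {

    private WeightedEnsembleSketch() {}

    static double margin(double[] treePredictions, double[] treeWeights) {
        if (treePredictions.length != treeWeights.length) {
            throw new IllegalArgumentException("trees and treeWeights must have the same length");
        }
        double sum = 0.0;
        for (int i = 0; i < treeWeights.length; i++) {
            // Tree i contributes its output scaled by treeWeights[i].
            sum += treeWeights[i] * treePredictions[i];
        }
        return sum;
    }
}
```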
public StructType transformSchema(StructType schema)
Description copied from class: PipelineStage
We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().
A typical implementation should first verify the schema change and parameter validity, including complex parameter interaction checks.
Overrides: transformSchema in class ProbabilisticClassificationModel<Vector,GBTClassificationModel>
Parameters:
schema - (undocumented)
public Dataset<Row> transform(Dataset<?> dataset)
Description copied from class: ProbabilisticClassificationModel
Transforms dataset by reading from featuresCol and appending new columns as specified by parameters:
- predicted labels as predictionCol of type Double
- raw predictions (confidences) as rawPredictionCol of type Vector
- probability of each class as probabilityCol of type Vector
Overrides: transform in class ProbabilisticClassificationModel<Vector,GBTClassificationModel>
Parameters:
dataset - input dataset
public double predict(Vector features)
Description copied from class: ClassificationModel
Predict label for the given features. This method is used to implement transform() and output predictionCol.
This default implementation for classification predicts the index of the maximum value from predictRaw().
Overrides: predict in class ClassificationModel<Vector,GBTClassificationModel>
Parameters:
features - (undocumented)
public Vector predictRaw(Vector features)
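Description copied from class: ClassificationModel
Raw prediction for each possible label. This method is used to implement transform() and output rawPredictionCol.
Specified by: predictRaw in class ClassificationModel<Vector,GBTClassificationModel>
Parameters:
features - (undocumented)
public GBTClassificationModel copy(ParamMap extra)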
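Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. See defaultCopy().
Specified by: copy in interface Params
Specified by: copy in class Model<GBTClassificationModel>
Parameters:
extra - (undocumented)
public String toString()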
Description copied from interface: TreeEnsembleModel
Specified by: toString in interface TreeEnsembleModel<DecisionTreeRegressionModel>
Specified by: toString in interface Identifiable
Overrides: toString in class Object
public Vector featureImportances()
public double[] evaluateEachIteration(Dataset<?> dataset)
Parameters:
dataset - Dataset for validation.
public MLWriter write()
Description copied from interface: MLWritable
Returns an MLWriter instance for this ML instance.
Specified by: write in interface MLWritable
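evaluateEachIteration(dataset) returns one error or loss value per boosting iteration. This Spark-free sketch shows the idea. It adds one weighted tree at a time to a running margin per row and records the mean loss after each addition. The loss is an assumption chosen for illustration: the logistic loss log(1 + exp(-2 * y * margin)), with labels in {0, 1} remapped to y in {-1, +1}. Spark's exact scaling and weighting may differ, and the array layout and class name are hypothetical.

```java
/** Hypothetical per-iteration loss curve for a binary boosted ensemble. */
final class PerIterationLossSketch {

    private PerIterationLossSketch() {}

    /**
     * treePredictions[i][t] is tree t's output for row i; labels are 0 or 1.
     * Entry k of the result uses the first k + 1 trees.
     */
    static double[] evaluateEachIteration(double[][] treePredictions, double[] labels,
                                          double[] treeWeights) {
        if (labels.length == 0) {
            throw new IllegalArgumentException("need at least one row");
        }
        int numRows = labels.length;
        int numTrees = treeWeights.length;
        double[] margins = new double[numRows];
        double[] losses = new double[numTrees];
        for (int t = 0; t < numTrees; t++) {
            double total = 0.0;
            for (int i = 0; i < numRows; i++) {
                margins[i] += treeWeights[t] * treePredictions[i][t]; // add tree t incrementally
                double y = 2.0 * labels[i] - 1.0;                      // map {0, 1} to {-1, +1}
                total += Math.log1p(Math.exp(-2.0 * y * margins[i]));  // ASSUMED logistic loss
            }
            losses[t] = total / numRows;
        }
        return losses;
    }
}
```

A curve like this is what you would inspect to choose a useful number of trees. In this sketch, the loss drops while the added trees push margins toward the correct sign.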