Add signed percentage error histogram to regression evaluator (#496)

* add regression threshold metrics case classes and logical skeleton

* add tests for threshold metrics

* WIP regression threshold metrics

* improve implementation of the signed percentage error histogram; add a test suite for it

* fix whitespace

* avoid hardcoded default evaluators but instead set them as defaults; backward compatible by prepending these defaults if necessary

* Revert "avoid hardcoded default evaluators but instead set them as defaults; backward compatible by prepending these defaults if necessary"

This reverts commit 218850b

* add enum

* nested SignedPercentageErrorHistogram case class; use Longs for counts

* rename test method

* drop smartCutoff parameter in favor of defined/undefined smartCutoffRatio

* optimization: don't get param inside .map

* docs

* make it explicit that we want a call by value

* make it explicit that we want a call by value

Co-authored-by: Tuan Nguyen <anhtuan277@gmail.com>
nicodv and TuanNguyen27 authored Jul 28, 2020
1 parent 964b58e commit d0b1038
Showing 5 changed files with 409 additions and 58 deletions.
@@ -35,8 +35,6 @@ import com.salesforce.op.utils.spark.RichMetadata._
import enumeratum.{Enum, EnumEntry}
import org.apache.spark.sql.types.Metadata

- import scala.util.Try


/**
* Trait for all different kinds of evaluation metrics
@@ -181,6 +179,8 @@ object RegressionEvalMetrics extends Enum[RegressionEvalMetric] {
case object MeanSquaredError extends RegressionEvalMetric("mse", "mean square error", false)
case object R2 extends RegressionEvalMetric("r2", "r2", true)
case object MeanAbsoluteError extends RegressionEvalMetric("mae", "mean absolute error", false)
case object SignedPercentageErrorHistogram extends RegressionEvalMetric("pctErrorHst",
"signed percentage error histogram", false)
}


@@ -30,20 +30,25 @@

package com.salesforce.op.evaluators

import com.fasterxml.jackson.databind.annotation.JsonDeserialize
import com.salesforce.op.UID
import com.salesforce.op.utils.spark.RichEvaluator._
import org.apache.spark.ml.evaluation.RegressionEvaluator
- import org.apache.spark.sql.Dataset
import org.apache.spark.ml.param.{DoubleArrayParam, DoubleParam}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType
import org.slf4j.LoggerFactory

/**
*
- * Instance to evaluate Regression metrics
- * The metrics are rmse, mse, r2 and mae
* Evaluator for regression metrics.
* The metrics are RMSE, MSE, R2, MAE, and a histogram of the signed percentage errors.
* For the percentage errors, it deals with the difficulties that occur with label
* values around 0, and exposes several parameters to control that behavior.
* Default evaluation returns Root Mean Squared Error
*
* @param name name of default metric
* @param isLargerBetter is metric better if larger
* @param uid uid for instance
*/

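To see why a denominator cutoff is needed, here is a minimal illustrative sketch (editorial, not part of the diff; it mirrors calculateSignedPercentageError further down):

// Plain signed percentage error explodes for labels near 0:
//   prediction = 0.1, label = 0.0001  =>  100 * (0.1 - 0.0001) / 0.0001 = 99900%
// Flooring the denominator at scaledErrorCutoff tames it, e.g. with cutoff = 1.0:
//   100 * (0.1 - 0.0001) / math.max(0.0001, 1.0) = 9.99%
def signedPctError(prediction: Double, label: Double, cutoff: Double): Double =
  100.0 * (prediction - label) / (label.abs max cutoff)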
@@ -55,6 +60,40 @@ private[op] class OpRegressionEvaluator

@transient private lazy val log = LoggerFactory.getLogger(this.getClass)

final val signedPercentageErrorHistogramBins = new DoubleArrayParam(
parent = this,
name = "signedPercentageErrorHistogramBins",
doc = "the sequence of error percentage bins for the signed percentage error histogram",
isValid = l => l.nonEmpty && (l sameElements l.sorted)
)
setDefault(signedPercentageErrorHistogramBins,
Array(Double.NegativeInfinity) ++ (-100.0 to 100.0 by 10) ++ Array(Double.PositiveInfinity)
)
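// The default yields 23 bin edges (-Inf, -100, -90, ..., 90, 100, +Inf) and hence 22 bins:
// everything below -100%, twenty 10%-wide bins, and everything above +100%.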

def setPercentageErrorHistogramBins(v: Array[Double]): this.type = set(signedPercentageErrorHistogramBins, v)

final val scaledErrorCutoff = new DoubleParam(
parent = this,
name = "scaledErrorCutoff",
doc = "the label value cutoff below which percentage error is implemented as a scaled error " +
"with a fixed denominator to avoid problems with label values around 0",
isValid = (d: Double) => d > 0.0
)
setDefault(scaledErrorCutoff, 1E-3)

def setScaledErrorCutoff(v: Double): this.type = set(scaledErrorCutoff, v)
def getScaledErrorCutoff: Option[Double] = get(scaledErrorCutoff)
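// Note: get(...) returns Some only for an explicitly set value (including one set by the
// smart-cutoff logic below); $(scaledErrorCutoff) falls back to the 1E-3 default.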

final val smartCutoffRatio = new DoubleParam(
parent = this,
name = "smartCutoffRatio",
doc = "if set, scaledErrorCutoff is determined smartly by taking the average absolute magnitude " +
"of the data multiplied with this ratio",
isValid = (d: Double) => d > 0.0
)

def setSmartCutoffRatio(v: Double): this.type = set(smartCutoffRatio, v)
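// Deliberately no setDefault here: isDefined(smartCutoffRatio) acts as the on/off switch,
// replacing the boolean smartCutoff parameter from an earlier revision of this PR.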

def getDefaultMetric: RegressionMetrics => Double = _.RootMeanSquaredError

override def evaluateAll(data: Dataset[_]): RegressionMetrics = {
@@ -64,8 +103,17 @@ private[op] class OpRegressionEvaluator
val r2 = getRegEvaluatorMetric(RegressionEvalMetrics.R2, dataUse, default = 0.0)
val mae = getRegEvaluatorMetric(RegressionEvalMetrics.MeanAbsoluteError, dataUse, default = 0.0)

val histogram = calculateSignedPercentageErrorHistogram(dataUse)

val metrics = RegressionMetrics(
- RootMeanSquaredError = rmse, MeanSquaredError = mse, R2 = r2, MeanAbsoluteError = mae
RootMeanSquaredError = rmse,
MeanSquaredError = mse,
R2 = r2,
MeanAbsoluteError = mae,
SignedPercentageErrorHistogram = SignedPercentageErrorHistogram(
bins = $(signedPercentageErrorHistogramBins).toArray,
counts = histogram
)
)

log.info("Evaluated metrics: {}", metrics.toString)
@@ -86,6 +134,61 @@ private[op] class OpRegressionEvaluator
.setMetricName(metricName.sparkEntryName)
.evaluateOrDefault(dataUse, default = default)
}

/**
* Gets the histogram of the signed percentage errors
*
* @param data Data to use
* @return Sequence of bin counts of the histogram
*/
private def calculateSignedPercentageErrorHistogram(data: Dataset[_]): Array[Long] = {
// Prep data
val predictionsAndLabels = data
.select(col(getPredictionValueCol).cast(DoubleType), col(getLabelCol).cast(DoubleType))
.rdd
.map { case Row(prediction: Double, label: Double) => (prediction, label) }

// If we need to set the scaledErrorCutoff smartly, use the label data for that
if (isDefined(smartCutoffRatio)) {
val smartCutoff = calculateSmartCutoff(predictionsAndLabels)
log.info(s"Smart scaledErrorCutoff was determined to be: $smartCutoff")
setScaledErrorCutoff(smartCutoff)
}

val cutoff = $(scaledErrorCutoff)
val errors: RDD[Double] = predictionsAndLabels
.map(x => calculateSignedPercentageError(x._1, x._2, cutoff))
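// Spark's RDD.histogram(sortedEdges) below returns edges.length - 1 counts; each bin is
// half-open [lo, hi) except the last, which is closed [lo, hi].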
errors.histogram($(signedPercentageErrorHistogramBins))
}

/**
* Smartly calculates the scaledErrorCutoff
*
* @param predictionAndLabels Data set containing predictions and labels
* @return Suggested cutoff level for scaledErrorCutoff
*/
protected def calculateSmartCutoff(predictionAndLabels: RDD[(Double, Double)]): Double = {
val meanAbsoluteLabel = predictionAndLabels.map(_._2.abs).mean()
// Take the max with scaledErrorCutoff to avoid a cutoff of 0 if labels are all 0.
($(smartCutoffRatio) * meanAbsoluteLabel) max $(scaledErrorCutoff)
}
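// Example: labels 1.0, -5.0, 14.0 give meanAbsoluteLabel = 20 / 3 ≈ 6.67; with
// smartCutoffRatio = 0.1 the returned cutoff is ≈ 0.67, well above the 1E-3 floor.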

/**
* Calculates the signed percentage error, with cutoff logic to avoid large results
* due to division by small numbers.
*
* @param prediction Predicted value
* @param label Actual value
* @return Signed percentage error
*/
private def calculateSignedPercentageError(
prediction: Double,
label: Double,
scaledErrorCutoff: Double
): Double = {
100.0 * (prediction - label) / (label.abs max scaledErrorCutoff)
}
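// Examples: prediction = 120, label = 100  =>  +20.0 (overprediction);
//           prediction = 80,  label = 100  =>  -20.0 (underprediction).
// Keeping the sign is what distinguishes this from an absolute percentage error.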

}


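For orientation, a hedged usage sketch (editorial, not part of the diff; the label/prediction setter names and the scoredData Dataset are assumptions inferred from the getters used above):

val evaluator = new OpRegressionEvaluator()
  .setLabelCol("label")            // assumed setter matching getLabelCol
  .setPredictionCol("prediction")  // assumed setter matching getPredictionValueCol
  .setSmartCutoffRatio(0.1)        // derive scaledErrorCutoff from the label magnitude
  .setPercentageErrorHistogramBins(
    Array(Double.NegativeInfinity, -50.0, -10.0, 0.0, 10.0, 50.0, Double.PositiveInfinity))
val metrics = evaluator.evaluateAll(scoredData)
metrics.SignedPercentageErrorHistogram.counts // six counts, one per bin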
@@ -96,11 +199,28 @@ private[op] class OpRegressionEvaluator
* @param MeanSquaredError
* @param R2
* @param MeanAbsoluteError
* @param SignedPercentageErrorHistogram
*/
case class RegressionMetrics
(
RootMeanSquaredError: Double,
MeanSquaredError: Double,
R2: Double,
- MeanAbsoluteError: Double
MeanAbsoluteError: Double,
SignedPercentageErrorHistogram: SignedPercentageErrorHistogram
) extends EvaluationMetrics


/**
* Histogram of signed percentage errors
*
* @param bins Histogram bin edges, where for example [-1, 0, 1] refers to the bins [-1, 0) and [0, 1]
* @param counts Histogram counts (one per bin, i.e. one fewer element than bins)
*/
case class SignedPercentageErrorHistogram
(
@JsonDeserialize(contentAs = classOf[java.lang.Double])
bins: Seq[Double],
@JsonDeserialize(contentAs = classOf[java.lang.Long])
counts: Seq[Long]
) extends EvaluationMetrics
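To make the bins/counts encoding concrete, an illustrative instance (editorial sketch, values invented):

// Edges [-Inf, -10, 0, 10, +Inf] define four bins:
//   [-Inf, -10), [-10, 0), [0, 10), [10, +Inf]
val hist = SignedPercentageErrorHistogram(
  bins = Seq(Double.NegativeInfinity, -10.0, 0.0, 10.0, Double.PositiveInfinity),
  counts = Seq(3L, 40L, 52L, 5L)
)
// counts always has one fewer element than bins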
@@ -43,7 +43,7 @@ import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.junit.runner.RunWith
- import org.scalatest.FlatSpec
import org.scalatest.{Assertion, FlatSpec}
import org.scalatest.junit.JUnitRunner


@@ -158,6 +158,7 @@ class OpBinaryClassificationEvaluatorTest extends FlatSpec with TestSparkContext
recall shouldBe metrics.Recall
f1 shouldBe metrics.F1
1.0 - sparkMulticlassEvaluator.setMetricName(Error.sparkEntryName).evaluate(flattenedData2) shouldBe metrics.Error
assertThresholdsNotEmpty(metrics)
}

it should "evaluate the metrics with one prediction input" in {
@@ -177,6 +178,7 @@
metrics.Precision shouldBe precision
metrics.Recall shouldBe recall
metrics.F1 shouldBe f1
assertThresholdsNotEmpty(metrics)
}

it should "evaluate the metrics on dataset with only the label and prediction 0" in {
@@ -193,6 +195,7 @@
metricsZero.Precision shouldBe 0.0
metricsZero.Recall shouldBe 0.0
metricsZero.Error shouldBe 0.0
assertThresholdsNotEmpty(metricsZero)
}


@@ -208,6 +211,14 @@
metricsOne.Precision shouldBe 1.0
metricsOne.Recall shouldBe 1.0
metricsOne.Error shouldBe 0.0
assertThresholdsNotEmpty(metricsOne)
}

private def assertThresholdsNotEmpty(metrics: BinaryClassificationMetrics): Assertion = {
metrics.thresholds should not be empty
metrics.precisionByThreshold should not be empty
metrics.recallByThreshold should not be empty
metrics.falsePositiveRateByThreshold should not be empty
}

private def getPosNegValues(rdd: RDD[Row]): (Double, Double, Double, Double, Double, Double, Double) = {
