Skip to content

Commit

Permalink
Add confusion matrix (#533)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifjiang authored Dec 10, 2020
1 parent 13ad9cd commit 91724f1
Show file tree
Hide file tree
Showing 5 changed files with 420 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ import com.twitter.algebird.Operators._
import com.twitter.algebird.Tuple2Semigroup
import com.salesforce.op.utils.spark.RichEvaluator._
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{DoubleArrayParam, IntArrayParam}
import org.apache.spark.ml.linalg.{Vector, DenseVector}
import org.apache.spark.ml.param.{DoubleArrayParam, IntArrayParam, IntParam, ParamValidators}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{Dataset, Row}
import org.slf4j.LoggerFactory
import scala.collection.Searching._


/**
* Instance to evaluate Multi Classification metrics
Expand Down Expand Up @@ -94,6 +96,35 @@ private[op] class OpMultiClassificationEvaluator

def setThresholds(v: Array[Double]): this.type = set(thresholds, v)

final val confMatrixNumClasses = new IntParam(
parent = this,
name = "confMatrixNumClasses",
doc = "# of the top most frequent classes used for confusion matrix metrics",
isValid = ParamValidators.inRange(1, 30, lowerInclusive = true, upperInclusive = true)
)
setDefault(confMatrixNumClasses, 15)

def setConfMatrixNumClasses(v: Int): this.type = set(confMatrixNumClasses, v)

final val confMatrixMinSupport = new IntParam(
parent = this,
name = "confMatrixMinSupport",
doc = "# of the top most frequent misclassified classes in each label/prediction category",
isValid = ParamValidators.inRange(1, 10, lowerInclusive = false, upperInclusive = true)
)
setDefault(confMatrixMinSupport, 5)

def setConfMatrixMinSupport(v: Int): this.type = set(confMatrixMinSupport, v)

final val confMatrixThresholds = new DoubleArrayParam(
parent = this,
name = "confMatrixThresholds",
doc = "sequence of threshold values used for confusion matrix metrics",
isValid = _.forall(x => x >= 0.0 && x < 1.0)
)
setDefault(confMatrixThresholds, Array(0.0, 0.2, 0.4, 0.6, 0.8))
def setConfMatrixThresholds(v: Array[Double]): this.type = set(confMatrixThresholds, v)

override def evaluateAll(data: Dataset[_]): MultiClassificationMetrics = {
val labelColName = getLabelCol
val dataUse = makeDataToUse(data, labelColName)
Expand All @@ -112,7 +143,9 @@ private[op] class OpMultiClassificationEvaluator
log.warn("The dataset is empty. Returning empty metrics.")
MultiClassificationMetrics(0.0, 0.0, 0.0, 0.0,
MulticlassThresholdMetrics(Seq.empty, Seq.empty, Map.empty, Map.empty, Map.empty),
MultiClassificationMetricsTopK(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty))
MultiClassificationMetricsTopK(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty),
MulticlassConfMatrixMetricsByThreshold($(confMatrixNumClasses), Seq.empty, $(confMatrixThresholds), Seq.empty),
MisClassificationMetrics($(confMatrixMinSupport), Seq.empty, Seq.empty))
} else {
val multiclassMetrics = new MulticlassMetrics(rdd)
val error = 1.0 - multiclassMetrics.accuracy
Expand All @@ -133,20 +166,183 @@ private[op] class OpMultiClassificationEvaluator
topKs = $(topKs)
)

val rddCm = dataUse.select(col(labelColName), col(predictionColName), col(probabilityColName)).rdd.map{
case Row(label: Double, pred: Double, prob: DenseVector) => (label, pred, prob.toArray)
}
val confusionMatrixByThreshold = calculateConfMatrixMetricsByThreshold(rddCm)
val misClassifications = calculateMisClassificationMetrics( rddCm.map{ case (label, pred, _) => (label, pred)} )

val metrics = MultiClassificationMetrics(
Precision = precision,
Recall = recall,
F1 = f1,
Error = error,
ThresholdMetrics = thresholdMetrics,
TopKMetrics = topKMetrics
TopKMetrics = topKMetrics,
ConfusionMatrixMetrics = confusionMatrixByThreshold,
MisClassificationMetrics = misClassifications
)

log.info("Evaluated metrics: {}", metrics.toString)
metrics
}
}

/**
* function to construct the confusion matrix for the top n most occurring labels
* @param labelPredictionCtRDD RDD of ((label, prediction, confidence), count)
* @param cmClasses the top n most occurring classes, sorted by counts in descending order
* @return an array of counts
*/
def constructConfusionMatrix(
labelPredictionCtRDD: RDD[((Double, Double, Double), Long)],
cmClasses: Seq[Double]): Seq[Long] = {

val confusionMatrixMap = labelPredictionCtRDD.map {
case ((label, prediction, _), count) => ((label, prediction), count)
}.reduceByKey(_ + _).collectAsMap()

for {
label <- cmClasses
prediction <- cmClasses
} yield {
confusionMatrixMap.getOrElse((label, prediction), 0L)
}
}

private[evaluators] object SearchHelper extends Serializable{

/**
* function to search the confidence threshold corresponding to a probability score
*
* @param arr a sorted array of confidence thresholds
* @param element the probability score to be searched
* @return the confidence threshold corresponding of the element. It equals to the element if there is an exact
* match. Otherwise it's the element right before the insertion point.
*/
def findThreshold(arr: IndexedSeq[Double], element: Double): Double = {
require(!arr.isEmpty, "Array of confidence thresholds can't be empty!")
if (element > arr.last) arr.last
else if (element < arr.head) 0.0
else {
val insertionPoint = new SearchImpl(arr).search(element).insertionPoint
val insertionPointValue = arr(insertionPoint)
if (element == insertionPointValue) insertionPointValue
else arr(insertionPoint-1)
}
}
}

/**
* function to calculate confusion matrix for TopK most occurring labels by confidence threshold
*
* @param data RDD of (label, prediction, prediction probability vector)
* @return a MulticlassConfMatrixMetricsByThreshold instance
*/
def calculateConfMatrixMetricsByThreshold(
data: RDD[(Double, Double, Array[Double])]): MulticlassConfMatrixMetricsByThreshold = {

val labelCountsRDD = data.map { case (label, _, _) => (label, 1L) }.reduceByKey(_ + _)
val cmClasses = labelCountsRDD.sortBy(-_._2).map(_._1).take($(confMatrixNumClasses)).toSeq
val cmClassesSet = cmClasses.toSet

val dataTopNLabels = data.filter { case (label, prediction, _) =>
cmClassesSet.contains(label) && cmClassesSet.contains(prediction)
}

val sortedThresholds = $(confMatrixThresholds).sorted.toIndexedSeq

// reduce data to a coarser RDD (with size N * N * thresholds at most) for further aggregation
val labelPredictionConfidenceCountRDD = dataTopNLabels.map{
case (label, prediction, proba) => {
( (label, prediction, SearchHelper.findThreshold(sortedThresholds, proba.max)), 1L )
}
}.reduceByKey(_ + _)

labelPredictionConfidenceCountRDD.persist()

val cmByThreshold = sortedThresholds.map( threshold => {
val filteredRDD = labelPredictionConfidenceCountRDD.filter {
case ((_, _, confidence), _) => confidence >= threshold
}
constructConfusionMatrix(filteredRDD, cmClasses)
})

labelPredictionConfidenceCountRDD.unpersist()

MulticlassConfMatrixMetricsByThreshold(
ConfMatrixNumClasses = $(confMatrixNumClasses),
ConfMatrixClassIndices = cmClasses,
ConfMatrixThresholds = $(confMatrixThresholds),
ConfMatrices = cmByThreshold
)
}

/**
* function to calculate the mostly frequently mis-classified classes for each label/prediction category
*
* @param data RDD of (label, prediction)
* @return a MisClassificationMetrics instance
*/
def calculateMisClassificationMetrics(data: RDD[(Double, Double)]): MisClassificationMetrics = {

val labelPredictionCountRDD = data.map {
case (label, prediction) => ((label, prediction), 1L) }
.reduceByKey(_ + _)

val misClassificationsByLabel = labelPredictionCountRDD.map {
case ((label, prediction), count) => (label, Seq((prediction, count)))
}.reduceByKey(_ ++ _)
.map { case (label, predictionCountsIter) => {
val misClassificationCtMap = predictionCountsIter
.filter { case (pred, _) => pred != label }
.sortBy(-_._2)
.take($(confMatrixMinSupport)).toMap

val labelCount = predictionCountsIter.map(_._2).reduce(_ + _)
val correctCount = predictionCountsIter
.collect { case (pred, count) if pred == label => count }
.reduceOption(_ + _).getOrElse(0L)

MisClassificationsPerCategory(
Category = label,
TotalCount = labelCount,
CorrectCount = correctCount,
MisClassifications = misClassificationCtMap
)
}
}.sortBy(-_.TotalCount).collect()

val misClassificationsByPrediction = labelPredictionCountRDD.map {
case ((label, prediction), count) => (prediction, Seq((label, count)))
}.reduceByKey(_ ++ _)
.map { case (prediction, labelCountsIter) => {
val sortedMisclassificationCt = labelCountsIter
.filter { case (label, _) => label != prediction }
.sortBy(-_._2)
.take($(confMatrixMinSupport)).toMap

val predictionCount = labelCountsIter.map(_._2).reduce(_ + _)
val correctCount = labelCountsIter
.collect { case (label, count) if label == prediction => count }
.reduceOption(_ + _).getOrElse(0L)

MisClassificationsPerCategory(
Category = prediction,
TotalCount = predictionCount,
CorrectCount = correctCount,
MisClassifications = sortedMisclassificationCt
)
}
}.sortBy(-_.TotalCount).collect()

MisClassificationMetrics(
ConfMatrixMinSupport = $(confMatrixMinSupport),
MisClassificationsByLabel = misClassificationsByLabel,
MisClassificationsByPrediction = misClassificationsByPrediction
)
}

/**
* Function that calculates Multi Classification Metrics for different topK most occuring labels given an RDD
* of scores & labels, and a list of topK values to consider.
Expand Down Expand Up @@ -187,7 +383,6 @@ private[op] class OpMultiClassificationEvaluator
)
}


/**
* Function that calculates a set of threshold metrics for different topN values given an RDD of scores & labels,
* a list of topN values to consider, and a list of thresholds to use.
Expand Down Expand Up @@ -308,7 +503,6 @@ private[op] class OpMultiClassificationEvaluator
.setMetricName(metricName.sparkEntryName)
.evaluateOrDefault(dataUse, default = default)
}

}


Expand All @@ -328,7 +522,9 @@ case class MultiClassificationMetrics
F1: Double,
Error: Double,
ThresholdMetrics: MulticlassThresholdMetrics,
TopKMetrics: MultiClassificationMetricsTopK
TopKMetrics: MultiClassificationMetricsTopK,
ConfusionMatrixMetrics: MulticlassConfMatrixMetricsByThreshold,
MisClassificationMetrics: MisClassificationMetrics
) extends EvaluationMetrics

/**
Expand All @@ -352,6 +548,56 @@ case class MultiClassificationMetricsTopK
Error: Seq[Double]
) extends EvaluationMetrics

/**
* Metrics for multi-class confusion matrix. It captures confusion matrix of records of which
* 1) the labels belong to the top n most occurring classes (n = confMatrixNumClasses)
* 2) the top predicted probability exceeds a certain threshold in confMatrixThresholds
*
* @param confMatrixNumClasses value of the top n most occurring classes in the dataset
* @param confMatrixClassIndices label index of the top n most occuring classes
* @param confMatrixThresholds a sequence of thresholds
* @param confMatrices a sequence of counts that stores the confusion matrix for each threshold in confMatrixThresholds
*/
case class MulticlassConfMatrixMetricsByThreshold
(
ConfMatrixNumClasses: Int,
@JsonDeserialize(contentAs = classOf[java.lang.Double])
ConfMatrixClassIndices: Seq[Double],
@JsonDeserialize(contentAs = classOf[java.lang.Double])
ConfMatrixThresholds: Seq[Double],
ConfMatrices: Seq[Seq[Long]]
) extends EvaluationMetrics

/**
* Multiclass mis-classification metrics, including the top n (n = confMatrixMinSupport) most frequently
* mis-classified classes for each label or prediction category.
*
*/
case class MisClassificationMetrics
(
ConfMatrixMinSupport: Int,
MisClassificationsByLabel: Seq[MisClassificationsPerCategory],
MisClassificationsByPrediction: Seq[MisClassificationsPerCategory]
)

/**
* container to store the most frequently mis-classified classes for each label/prediction category
*
* @param category a category which a record's label or prediction equals to
* @param totalCount total # of records in that category
* @param correctCount # of correctly predicted records in that category
* @param misClassifications the top n most frequently misclassified classes (n = confMatrixMinSupport) and
* their respective counts in that category. Ordered by counts in descending order.
*/
case class MisClassificationsPerCategory
(
Category: Double,
TotalCount: Long,
CorrectCount: Long,
@JsonDeserialize(keyAs = classOf[java.lang.Double])
MisClassifications: Map[Double, Long]
)

/**
* Threshold-based metrics for multiclass classification
*
Expand Down
Loading

0 comments on commit 91724f1

Please sign in to comment.