
Commit 70ac3f6

gerashegalov authored and sanmitra committed
Fix writing of model in the local file system
1 parent 41c9a2b commit 70ac3f6

File tree: 3 files changed, +46 -24 lines

core/src/main/scala/com/salesforce/op/OpWorkflowModelReader.scala (+4 -3)
@@ -268,6 +268,8 @@ private object WorkflowFileReader {
   val rawModel = "rawModel"
   val zipModel = "Model.zip"
   def modelStagingDir: String = s"modelStagingDir/model-${System.currentTimeMillis}"
+  val confWithDefaultCodec = new Configuration(false)
+  val codecFactory = new CompressionCodecFactory(confWithDefaultCodec)

   def loadFile(pathString: String)(implicit conf: Configuration): String = {
     Try {
@@ -290,11 +292,10 @@
   }

   private def readAsString(path: Path)(implicit conf: Configuration): String = {
-    val codecFactory = new CompressionCodecFactory(conf)
-    val codec = Option(codecFactory.getCodec(path))
     val in = FileSystem.getLocal(conf).open(path)
     try {
-      val read = codec.map(c => Source.fromInputStream(c.createInputStream(in)).mkString)
+      val read = Option(codecFactory.getCodec(path))
+        .map(c => Source.fromInputStream(c.createInputStream(in)).mkString)
         .getOrElse(IOUtils.toString(in, "UTF-8"))
       read
     } finally {
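The reader change hoists codec detection into a single CompressionCodecFactory built from Configuration(false), so the codec is chosen from the file suffix rather than from the caller's Hadoop configuration. A minimal standalone sketch of that lookup (illustrative only, not code from this commit):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress.CompressionCodecFactory

// Codecs are discovered from the classpath, so a bare Configuration(false)
// (no cluster defaults loaded) is enough for suffix-based lookup.
val codecFactory = new CompressionCodecFactory(new Configuration(false))

// getCodec returns null when no codec matches the file extension.
val gzipped = Option(codecFactory.getCodec(new Path("op-model.json.gz"))) // Some(GzipCodec)
val plain = Option(codecFactory.getCodec(new Path("op-model.json")))      // None: read as plain UTF-8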

core/src/main/scala/com/salesforce/op/OpWorkflowModelWriter.scala (+35 -20)
@@ -31,15 +31,14 @@
 package com.salesforce.op

 import java.io.File
+import java.nio.charset.StandardCharsets

 import com.salesforce.op.features.FeatureJsonHelper
 import com.salesforce.op.filters.RawFeatureFilterResults
 import com.salesforce.op.stages.{OPStage, OpPipelineStageWriter}
-import com.salesforce.op.utils.spark.{JobGroupUtil, OpStep}
 import enumeratum._
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.io.compress.GzipCodec
 import org.apache.spark.ml.util.MLWriter
 import org.json4s.JsonAST.{JArray, JObject, JString}
 import org.json4s.JsonDSL._
@@ -59,11 +58,39 @@ class OpWorkflowModelWriter(val model: OpWorkflowModel) extends MLWriter {

   implicit val jsonFormats: Formats = DefaultFormats

+  protected var modelStagingDir: String = WorkflowFileReader.modelStagingDir
+
+  /**
+   * Set the local folder to copy and unpack stored model to for loading
+   */
+  def setModelStagingDir(localDir: String): this.type = {
+    modelStagingDir = localDir
+    this
+  }
+
   override protected def saveImpl(path: String): Unit = {
-    JobGroupUtil.withJobGroup(OpStep.ModelIO) {
-      sc.parallelize(Seq(toJsonString(path)), 1)
-        .saveAsTextFile(OpWorkflowModelReadWriteShared.jsonPath(path), classOf[GzipCodec])
-    }(this.sparkSession)
+    val conf = new Configuration()
+    val localFileSystem = FileSystem.getLocal(conf)
+    val localPath = localFileSystem.makeQualified(new Path(modelStagingDir))
+    localFileSystem.delete(localPath, true)
+    val raw = new Path(localPath, WorkflowFileReader.rawModel)
+
+    val rawPathStr = raw.toString
+    val modelJson = toJsonString(rawPathStr)
+    val jsonPath = OpWorkflowModelReadWriteShared.jsonPath(rawPathStr)
+    val os = localFileSystem.create(new Path(jsonPath))
+    try {
+      os.write(modelJson.getBytes(StandardCharsets.UTF_8.toString))
+    } finally {
+      os.close()
+    }
+
+    val compressed = new Path(localPath, WorkflowFileReader.zipModel)
+    ZipUtil.pack(new File(raw.toUri.getPath), new File(compressed.toUri.getPath))
+
+    val finalPath = new Path(path, WorkflowFileReader.zipModel)
+    val destinationFileSystem = finalPath.getFileSystem(conf)
+    destinationFileSystem.moveFromLocalFile(compressed, finalPath)
   }

   /**
@@ -207,21 +234,9 @@ object OpWorkflowModelWriter {
     overwrite: Boolean = true,
     modelStagingDir: String = WorkflowFileReader.modelStagingDir
   ): Unit = {
-    val localPath = new Path(modelStagingDir)
-    val conf = new Configuration()
-    val localFileSystem = FileSystem.getLocal(conf)
-    if (overwrite) localFileSystem.delete(localPath, true)
-    val raw = new Path(modelStagingDir, WorkflowFileReader.rawModel)
-
-    val w = new OpWorkflowModelWriter(model)
+    val w = new OpWorkflowModelWriter(model).setModelStagingDir(modelStagingDir)
     val writer = if (overwrite) w.overwrite() else w
-    writer.save(raw.toString)
-    val compressed = new Path(modelStagingDir, WorkflowFileReader.zipModel)
-    ZipUtil.pack(new File(raw.toString), new File(compressed.toString))
-
-    val finalPath = new Path(path, WorkflowFileReader.zipModel)
-    val destinationFileSystem = finalPath.getFileSystem(conf)
-    destinationFileSystem.moveFromLocalFile(compressed, finalPath)
+    writer.save(path)
   }

   /**
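With saveImpl rewritten to write the model json to a local staging directory, zip it, and move Model.zip to the destination filesystem, the companion save method reduces to configuring the writer. A hedged usage sketch; the paths are made up for illustration, and the exact save signature is inferred from this diff:

import com.salesforce.op.{OpWorkflowModel, OpWorkflowModelWriter}

// Sketch only: `model` stands for an already-fitted OpWorkflowModel.
val model: OpWorkflowModel = ???

// Companion-object entry point, as in this diff: the writer stages the model json
// locally, packs it into Model.zip, then moves the zip to the destination filesystem
// (which may be local or distributed).
OpWorkflowModelWriter.save(model, "/tmp/op-model", overwrite = true)

// Or make the local staging directory explicit via the setter added in this commit:
new OpWorkflowModelWriter(model)
  .setModelStagingDir("/tmp/op-model-staging")
  .overwrite()
  .save("/tmp/op-model")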

core/src/test/scala/com/salesforce/op/OpWorkflowRunnerTest.scala (+7 -1)
@@ -232,7 +232,13 @@ class OpWorkflowRunnerTest extends AsyncFlatSpec with PassengerSparkFixtureTest
       dirFile.isDirectory shouldBe true
       // TODO: maybe do a thorough files inspection here
       val files = FileUtils.listFiles(dirFile, null, true)
-      files.asScala.map(_.toString).exists(_.contains("_SUCCESS")) shouldBe true
+      val fileNames = files.asScala.map(_.getName)
+      if (outFile.getAbsolutePath.endsWith("/model")) {
+        fileNames should contain ("op-model.json")
+      }
+      else {
+        fileNames should contain ("_SUCCESS")
+      }
       files.size > 1
     }
     res shouldBe a[R]
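The test now expects op-model.json inside a saved model directory instead of Spark's _SUCCESS marker. For completeness, a round-trip sketch of loading such a model back, assuming the usual TransmogrifAI loadModel API and the illustrative path from the sketch above:

// Sketch only: `workflow` is the OpWorkflow that originally produced the model,
// an implicit SparkSession is in scope, and "/tmp/op-model" is illustrative.
val loaded: OpWorkflowModel = workflow.loadModel("/tmp/op-model")
// Scoring additionally assumes a data reader is set on the workflow, e.g. loaded.score()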
