31
31
package com .salesforce .op
32
32
33
33
import java .io .File
34
+ import java .nio .charset .StandardCharsets
34
35
35
36
import com .salesforce .op .features .FeatureJsonHelper
36
37
import com .salesforce .op .filters .RawFeatureFilterResults
37
38
import com .salesforce .op .stages .{OPStage , OpPipelineStageWriter }
38
- import com .salesforce .op .utils .spark .{JobGroupUtil , OpStep }
39
39
import enumeratum ._
40
40
import org .apache .hadoop .conf .Configuration
41
41
import org .apache .hadoop .fs .{FileSystem , Path }
42
- import org .apache .hadoop .io .compress .GzipCodec
43
42
import org .apache .spark .ml .util .MLWriter
44
43
import org .json4s .JsonAST .{JArray , JObject , JString }
45
44
import org .json4s .JsonDSL ._
@@ -59,11 +58,39 @@ class OpWorkflowModelWriter(val model: OpWorkflowModel) extends MLWriter {
59
58
60
59
implicit val jsonFormats : Formats = DefaultFormats
61
60
61
+ protected var modelStagingDir : String = WorkflowFileReader .modelStagingDir
62
+
63
+ /**
64
+ * Set the local folder to copy and unpack stored model to for loading
65
+ */
66
+ def setModelStagingDir (localDir : String ): this .type = {
67
+ modelStagingDir = localDir
68
+ this
69
+ }
70
+
62
71
override protected def saveImpl (path : String ): Unit = {
63
- JobGroupUtil .withJobGroup(OpStep .ModelIO ) {
64
- sc.parallelize(Seq (toJsonString(path)), 1 )
65
- .saveAsTextFile(OpWorkflowModelReadWriteShared .jsonPath(path), classOf [GzipCodec ])
66
- }(this .sparkSession)
72
+ val conf = new Configuration ()
73
+ val localFileSystem = FileSystem .getLocal(conf)
74
+ val localPath = localFileSystem.makeQualified(new Path (modelStagingDir))
75
+ localFileSystem.delete(localPath, true )
76
+ val raw = new Path (localPath, WorkflowFileReader .rawModel)
77
+
78
+ val rawPathStr = raw.toString
79
+ val modelJson = toJsonString(rawPathStr)
80
+ val jsonPath = OpWorkflowModelReadWriteShared .jsonPath(rawPathStr)
81
+ val os = localFileSystem.create(new Path (jsonPath))
82
+ try {
83
+ os.write(modelJson.getBytes(StandardCharsets .UTF_8 .toString))
84
+ } finally {
85
+ os.close()
86
+ }
87
+
88
+ val compressed = new Path (localPath, WorkflowFileReader .zipModel)
89
+ ZipUtil .pack(new File (raw.toUri.getPath), new File (compressed.toUri.getPath))
90
+
91
+ val finalPath = new Path (path, WorkflowFileReader .zipModel)
92
+ val destinationFileSystem = finalPath.getFileSystem(conf)
93
+ destinationFileSystem.moveFromLocalFile(compressed, finalPath)
67
94
}
68
95
69
96
/**
@@ -207,21 +234,9 @@ object OpWorkflowModelWriter {
207
234
overwrite : Boolean = true ,
208
235
modelStagingDir : String = WorkflowFileReader .modelStagingDir
209
236
): Unit = {
210
- val localPath = new Path (modelStagingDir)
211
- val conf = new Configuration ()
212
- val localFileSystem = FileSystem .getLocal(conf)
213
- if (overwrite) localFileSystem.delete(localPath, true )
214
- val raw = new Path (modelStagingDir, WorkflowFileReader .rawModel)
215
-
216
- val w = new OpWorkflowModelWriter (model)
237
+ val w = new OpWorkflowModelWriter (model).setModelStagingDir(modelStagingDir)
217
238
val writer = if (overwrite) w.overwrite() else w
218
- writer.save(raw.toString)
219
- val compressed = new Path (modelStagingDir, WorkflowFileReader .zipModel)
220
- ZipUtil .pack(new File (raw.toString), new File (compressed.toString))
221
-
222
- val finalPath = new Path (path, WorkflowFileReader .zipModel)
223
- val destinationFileSystem = finalPath.getFileSystem(conf)
224
- destinationFileSystem.moveFromLocalFile(compressed, finalPath)
239
+ writer.save(path)
225
240
}
226
241
227
242
/**
0 commit comments