Skip to content

Commit 1878792

Browse files
manjanacryanwilson
andauthored
Custom model (#6594)
* Initial commit for CustomModel class. * Initial commit for CustomModel class. * API scaffolding for CustomModel and ModelDownloader classes * API scaffolding for CustomModel and ModelDownloader classes * Delete ModelDownloadConditions.swift * Changed Swift structs to NSObject subclasses for objc visibility * Design for a Swift-first SDK * Design for a Swift-first SDK * Design for a Swift-first SDK * Updated design w/ errors and download progress handler * Style compliance * Refactor Test folder structure * Refactor Test folder structure * Fix pod lint errors * Fix pod lint errors * Add xcscheme for SwiftPM builds. * Disable catalyst temporarily, re-enable issue #6790 * Better classification of download errors * Documentation comments for custom model * Refactor errors for model downloading Co-authored-by: Ryan Wilson <wilsonryan@google.com>
1 parent 4ba3d42 commit 1878792

File tree

6 files changed

+232
-15
lines changed

6 files changed

+232
-15
lines changed

.github/workflows/mlmodeldownloader.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ jobs:
4646
runs-on: macOS-latest
4747
strategy:
4848
matrix:
49-
target: [tvOS, macOS, catalyst]
49+
# TODO: manjanac@ Add catalyst back here.
50+
target: [tvOS, macOS]
5051
steps:
5152
- uses: actions/checkout@v2
5253
- name: Xcode 12
@@ -58,8 +59,9 @@ jobs:
5859

5960
catalyst:
6061
# Don't run on private repo unless it is a PR.
61-
if: github.repository != 'FirebasePrivate/firebase-ios-sdk' || github.event_name == 'pull_request'
62-
62+
# TODO: manjanac@ Uncomment line below to re-enable catalyst.
63+
# if: github.repository != 'FirebasePrivate/firebase-ios-sdk' || github.event_name == 'pull_request'
64+
if: false
6365
runs-on: macOS-latest
6466
steps:
6567
- uses: actions/checkout@v2
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<Scheme
3+
LastUpgradeVersion = "1200"
4+
version = "1.3">
5+
<BuildAction
6+
parallelizeBuildables = "YES"
7+
buildImplicitDependencies = "YES">
8+
</BuildAction>
9+
<TestAction
10+
buildConfiguration = "Debug"
11+
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
12+
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
13+
shouldUseLaunchSchemeArgsEnv = "YES">
14+
<Testables>
15+
<TestableReference
16+
skipped = "NO">
17+
<BuildableReference
18+
BuildableIdentifier = "primary"
19+
BlueprintIdentifier = "FirebaseMLModelDownloaderUnit"
20+
BuildableName = "FirebaseMLModelDownloaderUnit"
21+
BlueprintName = "FirebaseMLModelDownloaderUnit"
22+
ReferencedContainer = "container:">
23+
</BuildableReference>
24+
</TestableReference>
25+
</Testables>
26+
</TestAction>
27+
<LaunchAction
28+
buildConfiguration = "Debug"
29+
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
30+
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
31+
launchStyle = "0"
32+
useCustomWorkingDirectory = "NO"
33+
ignoresPersistentStateOnLaunch = "NO"
34+
debugDocumentVersioning = "YES"
35+
debugServiceExtension = "internal"
36+
allowLocationSimulation = "YES">
37+
</LaunchAction>
38+
<ProfileAction
39+
buildConfiguration = "Release"
40+
shouldUseLaunchSchemeArgsEnv = "YES"
41+
savedToolIdentifier = ""
42+
useCustomWorkingDirectory = "NO"
43+
debugDocumentVersioning = "YES">
44+
</ProfileAction>
45+
<AnalyzeAction
46+
buildConfiguration = "Debug">
47+
</AnalyzeAction>
48+
<ArchiveAction
49+
buildConfiguration = "Release"
50+
revealArchiveInOrganizer = "YES">
51+
</ArchiveAction>
52+
</Scheme>

FirebaseMLModelDownloader/Tests/Unit/MLDownloaderTests.swift renamed to FirebaseMLModelDownloader/Sources/CustomModel.swift

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
import XCTest
16-
@testable import FirebaseMLModelDownloader
15+
import Foundation
1716

18-
final class MLModelDownloaderTests: XCTestCase {
19-
func testExample() {
20-
// This is an example of a functional test case.
21-
// Use XCTAssert and related functions to verify your tests produce the correct
22-
// results.
23-
XCTAssertEqual(MLModelDownloader().text, "Hello, World!")
24-
}
17+
/// A custom model that is stored remotely on the server and downloaded to the device.
18+
public struct CustomModel: Hashable {
19+
/// Name of the model.
20+
public let name: String
21+
/// Size of the custom model, provided by the server.
22+
public let size: Int
23+
/// Path where the model is stored on device.
24+
public let path: String
25+
/// Hash for the model, used for model verification.
26+
public let hash: String
2527
}

FirebaseMLModelDownloader/Sources/MLModelDownloader.swift renamed to FirebaseMLModelDownloader/Sources/ModelDownloadConditions.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
public struct MLModelDownloader {
16-
public var text = "Hello, World!"
17-
}
15+
import Foundation
16+
17+
/// Model download conditions.
18+
public struct ModelDownloadConditions {}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// Copyright 2020 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// Possible errors with model downloading.
18+
public enum DownloadError: Error {
19+
/// No model with this name found on server.
20+
case notFound
21+
/// Caller does not have necessary permissions for this operation.
22+
case permissionDenied
23+
/// Conditions not met to perform download.
24+
case failedPrecondition
25+
/// Not enough space for model on device.
26+
case notEnoughSpace
27+
/// Malformed model name.
28+
case invalidArgument
29+
/// Other errors with description.
30+
case internalError(description: String)
31+
}
32+
33+
/// Possible errors with locating model on device.
34+
public enum DownloadedModelError: Error {
35+
/// File system error.
36+
case fileIOError
37+
/// Model not found on device.
38+
case notFound
39+
}
40+
41+
/// Possible ways to get a custom model.
42+
public enum ModelDownloadType {
43+
/// Get local model stored on device.
44+
case localModel
45+
/// Get local model on device and update to latest model from server in the background.
46+
case localModelUpdateInBackground
47+
/// Get latest model from server.
48+
case latestModel
49+
}
50+
51+
/// Downloader to manage custom model downloads.
52+
public struct ModelDownloader {
53+
/// Downloads a custom model to device or gets a custom model already on device, w/ optional handler for progress.
54+
public func getModel(name modelName: String, downloadType: ModelDownloadType,
55+
conditions: ModelDownloadConditions,
56+
progressHandler: ((Float) -> Void)? = nil,
57+
completion: @escaping (Result<CustomModel, DownloadError>) -> Void) {
58+
// TODO: Model download
59+
let modelSize = Int()
60+
let modelPath = String()
61+
let modelHash = String()
62+
63+
let customModel = CustomModel(
64+
name: modelName,
65+
size: modelSize,
66+
path: modelPath,
67+
hash: modelHash
68+
)
69+
completion(.success(customModel))
70+
completion(.failure(.notFound))
71+
}
72+
73+
/// Gets all downloaded models.
74+
public func listDownloadedModels(completion: @escaping (Result<Set<CustomModel>,
75+
DownloadedModelError>) -> Void) {
76+
let customModels = Set<CustomModel>()
77+
// TODO: List downloaded models
78+
completion(.success(customModels))
79+
completion(.failure(.notFound))
80+
}
81+
82+
/// Deletes a custom model from device.
83+
public func deleteDownloadedModel(name modelName: String,
84+
completion: @escaping (Result<Void, DownloadedModelError>)
85+
-> Void) {
86+
// TODO: Delete previously downloaded model
87+
completion(.success(()))
88+
completion(.failure(.notFound))
89+
}
90+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright 2020 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
@testable import FirebaseMLModelDownloader
17+
18+
final class ModelDownloaderTests: XCTestCase {
19+
func testExample() {
20+
// This is an example of a functional test case.
21+
// Use XCTAssert and related functions to verify your tests produce the correct
22+
// results.
23+
let modelDownloader = ModelDownloader()
24+
let conditions = ModelDownloadConditions()
25+
26+
// Download model w/ progress handler
27+
modelDownloader.getModel(
28+
name: "your_model_name",
29+
downloadType: .latestModel,
30+
conditions: conditions,
31+
progressHandler: { progress in
32+
// Handle progress
33+
}
34+
) { result in
35+
switch result {
36+
case .success:
37+
// Use model with your inference API
38+
// let interpreter = Interpreter(modelPath: customModel.modelPath)
39+
break
40+
case .failure:
41+
// Handle download error
42+
break
43+
}
44+
}
45+
46+
// Access array of downloaded models
47+
modelDownloader.listDownloadedModels { result in
48+
switch result {
49+
case .success:
50+
// Pick model(s) for further use
51+
break
52+
case .failure:
53+
// Handle failure
54+
break
55+
}
56+
}
57+
58+
// Delete downloaded model
59+
modelDownloader.deleteDownloadedModel(name: "your_model_name") { result in
60+
switch result {
61+
case .success():
62+
// Apply any other clean up
63+
break
64+
case .failure:
65+
// Handle failure
66+
break
67+
}
68+
}
69+
}
70+
}

0 commit comments

Comments
 (0)