FirebaseVertexAI/CHANGELOG.md (3 additions & 0 deletions)

@@ -1,6 +1,9 @@
 # 11.2.0
 - [fixed] Resolved a decoding error for citations without a `uri` and added
   support for decoding `title` fields, which were previously ignored. (#13518)
+- [changed] **Breaking Change**: The methods for starting streaming requests
+  (`generateContentStream` and `sendMessageStream`) and creating a chat instance
+  (`startChat`) are now asynchronous and must be called with `await`. (#13545)
 
 # 10.29.0
 - [feature] Added community support for watchOS. (#13215)
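A migration note for the breaking change above: call sites that previously created chats and streams synchronously now need an async context. The sketch below shows the before/after shape, using the model name from this PR's samples; the surrounding function is hypothetical.

```swift
import FirebaseVertexAI

// Hypothetical call site illustrating the 11.2.0 breaking change (#13545).
func askGemini(_ prompt: String) async throws {
  let model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")

  // Before 11.2.0 these were synchronous:
  //   let chat = model.startChat()
  //   let stream = chat.sendMessageStream(prompt)

  // From 11.2.0 they must be awaited:
  let chat = await model.startChat()
  let stream = await chat.sendMessageStream(prompt)

  for try await chunk in stream {
    if let text = chunk.text {
      print(text)
    }
  }
}
```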

@@ -30,18 +30,20 @@ class ConversationViewModel: ObservableObject {
   }
 
   private var model: GenerativeModel
-  private var chat: Chat
+  private var chat: Chat? = nil
   private var stopGenerating = false
 
   private var chatTask: Task<Void, Never>?
 
   init() {
     model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")
-    chat = model.startChat()
   }
 
   func sendMessage(_ text: String, streaming: Bool = true) async {
     error = nil
+    if chat == nil {
+      chat = await model.startChat()
+    }
     if streaming {
       await internalSendMessageStreaming(text)
     } else {
@@ -52,7 +54,7 @@
   func startNewChat() {
     stop()
     error = nil
-    chat = model.startChat()
+    chat = nil
     messages.removeAll()
   }

@@ -79,7 +81,10 @@
     messages.append(systemMessage)
 
     do {
-      let responseStream = chat.sendMessageStream(text)
+      guard let chat else {
+        throw ChatError.notInitialized
+      }
+      let responseStream = await chat.sendMessageStream(text)
       for try await chunk in responseStream {
         messages[messages.count - 1].pending = false
         if let text = chunk.text {
@@ -112,10 +117,12 @@
     messages.append(systemMessage)
 
     do {
-      var response: GenerateContentResponse?
-      response = try await chat.sendMessage(text)
+      guard let chat = chat else {
+        throw ChatError.notInitialized
+      }
+      let response = try await chat.sendMessage(text)
 
-      if let responseText = response?.text {
+      if let responseText = response.text {
         // replace pending message with backend response
         messages[messages.count - 1].message = responseText
         messages[messages.count - 1].pending = false
@@ -127,4 +134,8 @@
       }
     }
   }
+
+  enum ChatError: Error {
+    case notInitialized
+  }
 }
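Since `init()` is synchronous but `startChat()` is now async, the view model switches to an optional `chat` created lazily on first send and guarded thereafter. The same pattern in isolation, as a minimal sketch (the `Session` type is illustrative, not part of the SDK):

```swift
import FirebaseVertexAI

// Illustrative wrapper showing the lazy-initialization pattern used above.
final class Session {
  enum ChatError: Error {
    case notInitialized
  }

  private let model: GenerativeModel
  private var chat: Chat? = nil

  init(model: GenerativeModel) {
    self.model = model
  }

  func send(_ text: String) async throws -> String? {
    // startChat() is async, so the chat can't be created in init().
    if chat == nil {
      chat = await model.startChat()
    }
    guard let chat else { throw ChatError.notInitialized }
    let response = try await chat.sendMessage(text)
    return response.text
  }
}
```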

@@ -33,7 +33,7 @@ class FunctionCallingViewModel: ObservableObject {
   private var functionCalls = [FunctionCall]()
 
   private var model: GenerativeModel
-  private var chat: Chat
+  private var chat: Chat? = nil
 
   private var chatTask: Task<Void, Never>?

@@ -62,7 +62,6 @@
         ),
       ])]
     )
-    chat = model.startChat()
   }
 
   func sendMessage(_ text: String, streaming: Bool = true) async {
@@ -75,6 +74,10 @@
       busy = false
     }
 
+    if chat == nil {
+      chat = await model.startChat()
+    }
+
     // first, add the user's message to the chat
     let userMessage = ChatMessage(message: text, participant: .user)
     messages.append(userMessage)
@@ -103,7 +106,7 @@
   func startNewChat() {
     stop()
     error = nil
-    chat = model.startChat()
+    chat = nil
     messages.removeAll()
   }

@@ -114,14 +117,17 @@

   private func internalSendMessageStreaming(_ text: String) async throws {
     let functionResponses = try await processFunctionCalls()
+    guard let chat else {
+      throw ChatError.notInitialized
+    }
     let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
     if functionResponses.isEmpty {
-      responseStream = chat.sendMessageStream(text)
+      responseStream = await chat.sendMessageStream(text)
     } else {
       for functionResponse in functionResponses {
         messages.insert(functionResponse.chatMessage(), at: messages.count - 1)
       }
-      responseStream = chat.sendMessageStream(functionResponses.modelContent())
+      responseStream = await chat.sendMessageStream(functionResponses.modelContent())
     }
     for try await chunk in responseStream {
       processResponseContent(content: chunk)
@@ -130,6 +136,9 @@

   private func internalSendMessage(_ text: String) async throws {
     let functionResponses = try await processFunctionCalls()
+    guard let chat else {
+      throw ChatError.notInitialized
+    }
     let response: GenerateContentResponse
     if functionResponses.isEmpty {
       response = try await chat.sendMessage(text)
@@ -181,6 +190,10 @@
     return functionResponses
   }
 
+  enum ChatError: Error {
+    case notInitialized
+  }
+
   // MARK: - Callable Functions
 
   func getExchangeRate(args: JSONObject) -> JSONObject {

@@ -84,7 +84,7 @@ class PhotoReasoningViewModel: ObservableObject {
       }
     }
 
-    let outputContentStream = model.generateContentStream(prompt, images)
+    let outputContentStream = await model.generateContentStream(prompt, images)
 
     // stream response
     for try await outputContent in outputContentStream {

@@ -50,7 +50,7 @@ class SummarizeViewModel: ObservableObject {

     let prompt = "Summarize the following text for me: \(inputText)"
 
-    let outputContentStream = model.generateContentStream(prompt)
+    let outputContentStream = await model.generateContentStream(prompt)
 
     // stream response
     for try await outputContent in outputContentStream {

FirebaseVertexAI/Sources/Chat.swift (2 additions & 2 deletions)

@@ -17,7 +17,7 @@ import Foundation
 /// An object that represents a back-and-forth chat with a model, capturing the history and saving
 /// the context in memory between each message sent.
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
-public class Chat {
+public actor Chat {
   private let model: GenerativeModel
 
   /// Initializes a new chat representing a 1:1 conversation between model and user.
@@ -121,7 +121,7 @@ public class Chat {

     // Send the history alongside the new message as context.
     let request = history + newContent
-    let stream = model.generateContentStream(request)
+    let stream = await model.generateContentStream(request)
     do {
       for try await chunk in stream {
         // Capture any content that's streaming. This should be populated if there's no error.
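With `Chat` (and `GenerativeModel`) converted to actors, every access from outside the actor is asynchronous: method calls and even stored-property reads like `history` must be awaited, which is why the tests below read `await chat.history`. A minimal sketch of the calling convention, assuming an already-configured `model`:

```swift
// Hypothetical caller demonstrating cross-actor access to Chat.
func transcript(model: GenerativeModel) async throws {
  let chat = await model.startChat()  // actor method call: awaited
  let stream = await chat.sendMessageStream("Hello")
  for try await chunk in stream {
    print(chunk.text ?? "", terminator: "")
  }
  let history = await chat.history    // actor property read: awaited
  print("\nTurns recorded: \(history.count)")
}
```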

FirebaseVertexAI/Sources/GenerativeModel.swift (21 additions & 23 deletions)

@@ -19,7 +19,7 @@ import Foundation
 /// A type that represents a remote multimodal model (like Gemini), with the ability to generate
 /// content based on various input types.
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
-public final class GenerativeModel {
+public final actor GenerativeModel {
   /// The resource name of the model in the backend; has the format "models/model-name".
   let modelResourceName: String

@@ -217,33 +217,31 @@
       isStreaming: true,
       options: requestOptions)
 
-    var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
-      .makeAsyncIterator()
+    let responseStream = generativeAIService.loadRequestStream(request: generateContentRequest)
 
     return AsyncThrowingStream {
-      let response: GenerateContentResponse?
       do {
-        response = try await responseIterator.next()
-      } catch {
-        throw GenerativeModel.generateContentError(from: error)
-      }
-
-      // The responseIterator will return `nil` when it's done.
-      guard let response = response else {
-        // This is the end of the stream! Signal it by sending `nil`.
-        return nil
-      }
-
-      // Check the prompt feedback to see if the prompt was blocked.
-      if response.promptFeedback?.blockReason != nil {
-        throw GenerateContentError.promptBlocked(response: response)
-      }
-
-      // If the stream ended early unexpectedly, throw an error.
-      if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
-        throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
-      } else {
-        // Response was valid content, pass it along and continue.
-        return response
-      }
+        for try await response in responseStream {
+          // Check the prompt feedback to see if the prompt was blocked.
+          if response.promptFeedback?.blockReason != nil {
+            throw GenerateContentError.promptBlocked(response: response)
+          }
+
+          // If the stream ended early unexpectedly, throw an error.
+          if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
+            throw GenerateContentError.responseStoppedEarly(
+              reason: finishReason,
+              response: response
+            )
+          } else {
+            // Response was valid content, pass it along and continue.
+            return response
+          }
+        }
+        // This is the end of the stream! Signal it by sending `nil`.
+        return nil
+      } catch {
+        throw GenerativeModel.generateContentError(from: error)
+      }
     }
   }
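The rewrite above moves error mapping into a single do/catch around a `for try await` loop, inside what appears to be `AsyncThrowingStream`'s pull-based closure initializer (`init(unfolding:)`): the closure runs once per requested element, returning a value to emit it, `nil` to finish, or throwing to fail. A standalone sketch of that pattern, with an illustrative countdown in place of the SDK's network stream:

```swift
// Pull-based AsyncThrowingStream: the closure is invoked once per element.
// Returning nil ends the stream; throwing ends it with an error.
func countdown(from start: Int) -> AsyncThrowingStream<Int, Error> {
  var current = start
  return AsyncThrowingStream {
    guard current > 0 else { return nil }  // nil signals a normal end of stream
    defer { current -= 1 }
    return current
  }
}

// Consumed exactly like generateContentStream's result:
func printCountdown() async throws {
  for try await value in countdown(from: 3) {
    print(value)  // 3, 2, 1
  }
}
```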

FirebaseVertexAI/Tests/Unit/ChatTests.swift (6 additions & 5 deletions)

@@ -64,19 +64,20 @@ final class ChatTests: XCTestCase {
     )
     let chat = Chat(model: model, history: [])
     let input = "Test input"
-    let stream = chat.sendMessageStream(input)
+    let stream = await chat.sendMessageStream(input)
 
     // Ensure the values are parsed correctly
     for try await value in stream {
       XCTAssertNotNil(value.text)
     }
 
-    XCTAssertEqual(chat.history.count, 2)
-    XCTAssertEqual(chat.history[0].parts[0].text, input)
+    let history = await chat.history
+    XCTAssertEqual(history.count, 2)
+    XCTAssertEqual(history[0].parts[0].text, input)
 
     let finalText = "1 2 3 4 5 6 7 8"
     let assembledExpectation = ModelContent(role: "model", parts: finalText)
-    XCTAssertEqual(chat.history[0].parts[0].text, input)
-    XCTAssertEqual(chat.history[1], assembledExpectation)
+    XCTAssertEqual(history[0].parts[0].text, input)
+    XCTAssertEqual(history[1], assembledExpectation)
   }
 }