lancedb_impl #15

Open
maximevanhees wants to merge 7 commits from lancedb_impl into main
5 changed files with 132 additions and 211 deletions
Showing only changes of commit 4aa49e0d5c - Show all commits

200
Cargo.lock generated
View File

@@ -2358,15 +2358,6 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "encoding_rs"
version = "0.8.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3"
dependencies = [
"cfg-if",
]
[[package]]
name = "equivalent"
version = "1.0.2"
@@ -2488,6 +2479,16 @@ dependencies = [
"rustc_version",
]
[[package]]
name = "flate2"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a3d7db9596fecd151c5f638c0ee5d5bd487b6e0ea232e5dc96d5250f6f94b1d"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "fluent"
version = "0.16.1"
@@ -2544,21 +2545,6 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "form_urlencoded"
version = "1.2.2"
@@ -2926,7 +2912,6 @@ dependencies = [
"rand 0.8.5",
"redb",
"redis",
"reqwest",
"secrecy",
"serde",
"serde_json",
@@ -2935,6 +2920,7 @@ dependencies = [
"tantivy 0.25.0",
"thiserror 1.0.69",
"tokio",
"ureq",
"uuid",
"x25519-dalek",
]
@@ -3131,23 +3117,7 @@ dependencies = [
"tokio",
"tokio-rustls 0.26.2",
"tower-service",
"webpki-roots",
]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper 1.7.0",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
"webpki-roots 1.0.2",
]
[[package]]
@@ -3169,11 +3139,9 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"socket2 0.6.0",
"system-configuration",
"tokio",
"tower-service",
"tracing",
"windows-registry",
]
[[package]]
@@ -4512,12 +4480,6 @@ dependencies = [
"libc",
]
[[package]]
name = "mime"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@@ -4586,23 +4548,6 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b"
[[package]]
name = "native-tls"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework 2.11.1",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "nom"
version = "7.1.3"
@@ -4853,50 +4798,12 @@ dependencies = [
"uuid",
]
[[package]]
name = "openssl"
version = "0.10.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8"
dependencies = [
"bitflags 2.9.3",
"cfg-if",
"foreign-types",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.106",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
version = "0.9.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90096e2e47630d78b7d1c20952dc621f957103f8bc2c8359ec81290d75238571"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "option-ext"
version = "0.2.0"
@@ -5687,8 +5594,6 @@ checksum = "d429f34c8092b2d42c7c93cec323bb4adeb7c67698f70839adec842ec10c7ceb"
dependencies = [
"base64 0.22.1",
"bytes",
"encoding_rs",
"futures-channel",
"futures-core",
"futures-util",
"h2 0.4.12",
@@ -5697,12 +5602,9 @@ dependencies = [
"http-body-util",
"hyper 1.7.0",
"hyper-rustls 0.27.7",
"hyper-tls",
"hyper-util",
"js-sys",
"log",
"mime",
"native-tls",
"percent-encoding",
"pin-project-lite",
"quinn",
@@ -5714,7 +5616,6 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tokio-rustls 0.26.2",
"tokio-util",
"tower",
@@ -5725,7 +5626,7 @@ dependencies = [
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"webpki-roots",
"webpki-roots 1.0.2",
]
[[package]]
@@ -6595,27 +6496,6 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "system-configuration"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
dependencies = [
"bitflags 2.9.3",
"core-foundation 0.9.4",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "tagptr"
version = "0.2.0"
@@ -7067,16 +6947,6 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.24.1"
@@ -7344,6 +7214,24 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "ureq"
version = "2.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d"
dependencies = [
"base64 0.22.1",
"flate2",
"log",
"once_cell",
"rustls 0.23.31",
"rustls-pki-types",
"serde",
"serde_json",
"url",
"webpki-roots 0.26.11",
]
[[package]]
name = "url"
version = "2.5.6"
@@ -7397,12 +7285,6 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "version_check"
version = "0.9.5"
@@ -7581,6 +7463,15 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "webpki-roots"
version = "0.26.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9"
dependencies = [
"webpki-roots 1.0.2",
]
[[package]]
name = "webpki-roots"
version = "1.0.2"
@@ -7724,17 +7615,6 @@ dependencies = [
"windows-link 0.1.3",
]
[[package]]
name = "windows-registry"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b8a9ed28765efc97bbc954883f4e6796c33a06546ebafacbabee9696967499e"
dependencies = [
"windows-link 0.1.3",
"windows-result 0.3.4",
"windows-strings 0.4.2",
]
[[package]]
name = "windows-result"
version = "0.3.4"

View File

@@ -34,7 +34,7 @@ lance-index = "0.37.0"
arrow = "55.2.0"
lancedb = "0.22.1"
uuid = "1.18.1"
reqwest = { version = "0.12", features = ["blocking", "json", "rustls-tls"] }
ureq = { version = "2.10.0", features = ["json", "tls"] }
[dev-dependencies]
redis = { version = "0.24", features = ["aio", "tokio-comp"] }

View File

@@ -1363,9 +1363,20 @@ impl Cmd {
if !server.has_write_permission() {
return Ok(Protocol::err("ERR write permission denied"));
}
// Resolve embedder and embed text
// Resolve embedder and embed text on a plain OS thread so the (potentially blocking) embed call does not stall the tokio runtime
let embedder = server.get_embedder_for(&name)?;
let vector = embedder.embed(&text)?;
let (tx, rx) = tokio::sync::oneshot::channel();
let emb_arc = embedder.clone();
let text_cl = text.clone();
std::thread::spawn(move || {
let res = emb_arc.embed(&text_cl);
let _ = tx.send(res);
});
let vector = match rx.await {
Ok(Ok(v)) => v,
Ok(Err(e)) => return Ok(Protocol::err(&e.0)),
Err(recv_err) => return Ok(Protocol::err(&format!("ERR embedding thread error: {}", recv_err))),
};
let meta_map: std::collections::HashMap<String, String> = meta.into_iter().collect();
match server.lance_store()?.store_vector(&name, &id, vector, meta_map, Some(text)).await {
Ok(()) => Ok(Protocol::SimpleString("OK".to_string())),
@@ -1373,9 +1384,20 @@ impl Cmd {
}
}
Cmd::LanceSearchText { name, text, k, filter, return_fields } => {
// Resolve embedder and embed query text
// Resolve embedder and embed query text on a plain OS thread
let embedder = server.get_embedder_for(&name)?;
let qv = embedder.embed(&text)?;
let (tx, rx) = tokio::sync::oneshot::channel();
let emb_arc = embedder.clone();
let text_cl = text.clone();
std::thread::spawn(move || {
let res = emb_arc.embed(&text_cl);
let _ = tx.send(res);
});
let qv = match rx.await {
Ok(Ok(v)) => v,
Ok(Err(e)) => return Ok(Protocol::err(&e.0)),
Err(recv_err) => return Ok(Protocol::err(&format!("ERR embedding thread error: {}", recv_err))),
};
match server.lance_store()?.search_vectors(&name, qv, k, filter, return_fields).await {
Ok(results) => {
// Encode as array of [id, score, [k1, v1, k2, v2, ...]]

View File

@@ -23,8 +23,7 @@ use crate::error::DBError;
// Networking for OpenAI/Azure
use std::time::Duration;
use reqwest::blocking::Client;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE, AUTHORIZATION};
use ureq::{Agent, AgentBuilder};
use serde_json::json;
/// Provider identifiers. Extend as needed to mirror LanceDB-supported providers.
@@ -132,10 +131,9 @@ impl Embedder for TestHashEmbedder {
struct OpenAIEmbedder {
model: String,
dim: usize,
client: Client,
agent: Agent,
endpoint: String,
auth_header_name: HeaderName,
auth_header_value: HeaderValue,
headers: Vec<(String, String)>,
use_azure: bool,
}
@@ -184,40 +182,33 @@ impl OpenAIEmbedder {
.unwrap_or_else(|| "https://api.openai.com/v1/embeddings".to_string())
};
// Determine expected dimension:
// - Prefer params["dim"] or params["dimensions"]
// - Else default to 1536 (common for text-embedding-3-small; callers should override if needed)
// Determine expected dimension (default 1536 for text-embedding-3-small; callers should override if needed)
let dim = cfg
.get_param_usize("dim")
.or_else(|| cfg.get_param_usize("dimensions"))
.unwrap_or(1536);
// Build default headers
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
let (auth_name, auth_val) = if use_azure {
let name = HeaderName::from_static("api-key");
let val = HeaderValue::from_str(&api_key)
.map_err(|_| DBError("Invalid API key header value".into()))?;
(name, val)
} else {
let bearer = format!("Bearer {}", api_key);
(AUTHORIZATION, HeaderValue::from_str(&bearer).map_err(|_| DBError("Invalid Authorization header".into()))?)
};
// Build an HTTP agent with timeouts (blocking; no tokio runtime involved)
let agent = AgentBuilder::new()
.timeout_read(Duration::from_secs(30))
.timeout_write(Duration::from_secs(30))
.build();
let client = Client::builder()
.timeout(Duration::from_secs(30))
.default_headers(headers)
.build()
.map_err(|e| DBError(format!("Failed to build HTTP client: {}", e)))?;
// Headers
let mut headers: Vec<(String, String)> = Vec::new();
headers.push(("Content-Type".to_string(), "application/json".to_string()));
if use_azure {
headers.push(("api-key".to_string(), api_key));
} else {
headers.push(("Authorization".to_string(), format!("Bearer {}", api_key)));
}
Ok(Self {
model: cfg.model.clone(),
dim,
client,
agent,
endpoint,
auth_header_name: auth_name,
auth_header_value: auth_val,
headers,
use_azure,
})
}
@@ -237,21 +228,26 @@ impl OpenAIEmbedder {
.insert("dimensions".to_string(), json!(self.dim));
}
let mut req = self.client.post(&self.endpoint);
// Add auth header dynamically
req = req.header(self.auth_header_name.clone(), self.auth_header_value.clone());
let resp = req
.json(&body)
.send()
.map_err(|e| DBError(format!("HTTP request failed: {}", e)))?;
if !resp.status().is_success() {
let code = resp.status();
let text = resp.text().unwrap_or_default();
return Err(DBError(format!("Embeddings API error {}: {}", code, text)));
// Build request
let mut req = self.agent.post(&self.endpoint);
for (k, v) in &self.headers {
req = req.set(k, v);
}
let val: serde_json::Value = resp
.json()
// Send and handle errors
let resp = req.send_json(body);
let text = match resp {
Ok(r) => r
.into_string()
.map_err(|e| DBError(format!("Failed to read embeddings response: {}", e)))?,
Err(ureq::Error::Status(code, r)) => {
let body = r.into_string().unwrap_or_default();
return Err(DBError(format!("Embeddings API error {}: {}", code, body)));
}
Err(e) => return Err(DBError(format!("HTTP request failed: {}", e))),
};
let val: serde_json::Value = serde_json::from_str(&text)
.map_err(|e| DBError(format!("Invalid JSON from embeddings API: {}", e)))?;
let data = val

View File

@@ -1057,10 +1057,22 @@ impl RpcServer for RpcServerImpl {
if !server.has_write_permission() {
return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "write permission denied", None::<()>));
}
// Resolve embedder and run blocking embedding off the async runtime
// Resolve embedder and run embedding on a plain OS thread so the (potentially blocking) embed call does not stall the async runtime
let embedder = server.get_embedder_for(&name)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
let vector = embedder.embed(&text)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
let (tx, rx) = tokio::sync::oneshot::channel();
let emb_arc = embedder.clone();
let text_cl = text.clone();
std::thread::spawn(move || {
let res = emb_arc.embed(&text_cl);
let _ = tx.send(res);
});
let vector = match rx.await {
Ok(Ok(v)) => v,
Ok(Err(e)) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>)),
Err(recv_err) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("embedding thread error: {}", recv_err), None::<()>)),
};
server.lance_store()
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?
.store_vector(&name, &id, vector, meta.unwrap_or_default(), Some(text)).await
@@ -1087,10 +1099,21 @@ impl RpcServer for RpcServerImpl {
if !server.has_read_permission() {
return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "read permission denied", None::<()>));
}
// Resolve embedder and run embedding on a plain OS thread so the (potentially blocking) embed call does not stall the async runtime
let embedder = server.get_embedder_for(&name)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
let qv = embedder.embed(&text)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
let (tx, rx) = tokio::sync::oneshot::channel();
let emb_arc = embedder.clone();
let text_cl = text.clone();
std::thread::spawn(move || {
let res = emb_arc.embed(&text_cl);
let _ = tx.send(res);
});
let qv = match rx.await {
Ok(Ok(v)) => v,
Ok(Err(e)) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>)),
Err(recv_err) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("embedding thread error: {}", recv_err), None::<()>)),
};
let results = server.lance_store()
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?
.search_vectors(&name, qv, k, filter, return_fields).await