WIP6: implementing image embedding as first step towards multi-model support

docs/lancedb_text_and_images_example.md (new file, +138)
# LanceDB Text and Images: End-to-End Example

This guide demonstrates creating a Lance-backed database, ingesting two text documents and two images, performing searches over both, and cleaning up the datasets.

Prerequisites

- Build HeroDB and start the server with JSON-RPC enabled.

Commands:

```bash
cargo build --release
./target/release/herodb --dir /tmp/herodb --admin-secret mysecret --port 6379 --enable-rpc
```

We'll use:

- redis-cli for RESP commands against port 6379
- curl for JSON-RPC against port 8080, if desired
- Deterministic local embedders, to avoid external dependencies: testhash (text, dim 64) and testimagehash (image, dim 512)

0) Create a Lance-backed database (JSON-RPC)

Request:

```json
{ "jsonrpc": "2.0", "id": 1, "method": "herodb_createDatabase", "params": ["Lance", { "name": "media-db", "storage_path": null, "max_size": null, "redis_version": null }, null] }
```

The response returns a db_id (assume 1). Select the DB over RESP:

```bash
redis-cli -p 6379 SELECT 1
# → OK
```

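The prerequisites mention curl for JSON-RPC, but the request above is shown as bare JSON. A sketch of sending it with curl (the port 8080 and root path are assumptions carried over from the prerequisites, not confirmed by this commit):

```shell
# Same createDatabase request as above, as a curl invocation.
# Port 8080 and the root path are assumptions; adjust to your RPC endpoint.
payload='{"jsonrpc": "2.0", "id": 1, "method": "herodb_createDatabase", "params": ["Lance", {"name": "media-db", "storage_path": null, "max_size": null, "redis_version": null}, null]}'
curl -s -X POST http://localhost:8080 \
  -H 'Content-Type: application/json' \
  --data "$payload" || true   # ignore connection errors when no server is running
```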
1) Configure embedding providers

We'll create two datasets with independent embedding configs:

- textset → provider testhash, dim 64
- imageset → provider testimagehash, dim 512

Text config:

```bash
redis-cli -p 6379 LANCE.EMBEDDING CONFIG SET textset PROVIDER testhash MODEL any PARAM dim 64
# → OK
```

Image config:

```bash
redis-cli -p 6379 LANCE.EMBEDDING CONFIG SET imageset PROVIDER testimagehash MODEL any PARAM dim 512
# → OK
```

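This commit also wires up a matching read path (Cmd::LanceEmbeddingConfigGet), so a config should be retrievable for verification. The exact GET syntax and reply shape are not shown in this commit; the following is an assumed mirror of the SET form:

```shell
# Assumed syntax, mirroring CONFIG SET; verify against the actual parser
redis-cli -p 6379 LANCE.EMBEDDING CONFIG GET textset
```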
2) Create datasets

```bash
redis-cli -p 6379 LANCE.CREATE textset DIM 64
# → OK
redis-cli -p 6379 LANCE.CREATE imageset DIM 512
# → OK
```

3) Ingest two text documents (server-side embedding)

```bash
redis-cli -p 6379 LANCE.STORE textset ID doc-1 TEXT "The quick brown fox jumps over the lazy dog" META title "Fox" category "animal"
# → OK
redis-cli -p 6379 LANCE.STORE textset ID doc-2 TEXT "A fast auburn fox vaulted a sleepy canine" META title "Paraphrase" category "animal"
# → OK
```

4) Ingest two images

You can provide either a URI (URI <uri>) or base64-encoded image bytes (BYTES <base64>).

Example using free placeholder images:

```bash
# Store via URI
redis-cli -p 6379 LANCE.STOREIMAGE imageset ID img-1 URI "https://picsum.photos/seed/1/256/256" META title "Seed1" group "demo"
# → OK
redis-cli -p 6379 LANCE.STOREIMAGE imageset ID img-2 URI "https://picsum.photos/seed/2/256/256" META title "Seed2" group "demo"
# → OK
```

If your environment blocks outbound HTTP, you can send the image bytes directly:

```bash
# Example: read a local file and base64 it (replace path)
b64=$(base64 -w0 ./image1.png)
redis-cli -p 6379 LANCE.STOREIMAGE imageset ID img-b64-1 BYTES "$b64" META title "Local1" group "demo"
```

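Note that `base64 -w0` is GNU-specific; BSD/macOS base64 has no `-w` flag. A portable sketch of the encode step, using a throwaway file in place of a real image to show the round-trip:

```shell
# Create a throwaway "image" file just to demonstrate the round-trip
printf 'not-really-a-png' > /tmp/demo.bin

# Portable single-line base64: strip newlines with tr instead of relying on -w0
b64=$(base64 < /tmp/demo.bin | tr -d '\n')

# Verify the encoding round-trips before sending it to LANCE.STOREIMAGE
printf '%s' "$b64" | base64 -d | cmp -s - /tmp/demo.bin && echo "round-trip OK"
# → round-trip OK
```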
5) Search text

```bash
# Top-2 nearest neighbors for a query
redis-cli -p 6379 LANCE.SEARCH textset K 2 QUERY "quick brown fox" RETURN 1 title
# → 1) [id, score, [k1,v1,...]]
```

With a filter (a single equality clause on a schema column or meta key):

```bash
redis-cli -p 6379 LANCE.SEARCH textset K 2 QUERY "fox jumps" FILTER "category = 'animal'" RETURN 1 title
```

6) Search images

```bash
# Provide a URI as the query
redis-cli -p 6379 LANCE.SEARCHIMAGE imageset K 2 QUERYURI "https://picsum.photos/seed/1/256/256" RETURN 1 title

# Or provide base64 bytes as the query
qb64=$(curl -s https://picsum.photos/seed/3/256/256 | base64 -w0)
redis-cli -p 6379 LANCE.SEARCHIMAGE imageset K 2 QUERYBYTES "$qb64" RETURN 1 title
```

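LANCE.SEARCHIMAGE accepts the same optional FILTER and RETURN clauses as LANCE.SEARCH, so image results can be restricted by a meta key as well. For example, limiting results to the demo group ingested above:

```shell
redis-cli -p 6379 LANCE.SEARCHIMAGE imageset K 2 QUERYURI "https://picsum.photos/seed/1/256/256" FILTER "group = 'demo'" RETURN 1 title
```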
7) Inspect datasets

```bash
redis-cli -p 6379 LANCE.LIST
redis-cli -p 6379 LANCE.INFO textset
redis-cli -p 6379 LANCE.INFO imageset
```

8) Delete by id and drop datasets

```bash
# Delete one record
redis-cli -p 6379 LANCE.DEL textset doc-2
# → OK

# Drop entire datasets
redis-cli -p 6379 LANCE.DROP textset
redis-cli -p 6379 LANCE.DROP imageset
# → OK
```

Appendix: Using OpenAI embeddings instead of test providers

Text:

```bash
export OPENAI_API_KEY=sk-...
redis-cli -p 6379 LANCE.EMBEDDING CONFIG SET textset PROVIDER openai MODEL text-embedding-3-small PARAM dim 512
redis-cli -p 6379 LANCE.CREATE textset DIM 512
```

Azure OpenAI:

```bash
export AZURE_OPENAI_API_KEY=...
redis-cli -p 6379 LANCE.EMBEDDING CONFIG SET textset PROVIDER openai MODEL text-embedding-3-small \
  PARAM use_azure true \
  PARAM azure_endpoint https://myresource.openai.azure.com \
  PARAM azure_deployment my-embed-deploy \
  PARAM azure_api_version 2024-02-15 \
  PARAM dim 512
```

Notes:

- Ensure the dataset DIM matches the configured embedding dimension.
- Lance is only available for non-admin databases (db_id >= 1).
- On Lance DBs, only LANCE.* and basic control commands are allowed.
src/cmd.rs (+259)

@@ -1,4 +1,5 @@
 use crate::{error::DBError, protocol::Protocol, server::Server, embedding::{EmbeddingConfig, EmbeddingProvider}};
+use base64::{engine::general_purpose, Engine as _};
 use tokio::time::{timeout, Duration};
 use futures::future::select_all;

@@ -145,6 +146,22 @@ pub enum Cmd {
         filter: Option<String>,
         return_fields: Option<Vec<String>>,
     },
+    // Image-first commands (no user-provided vectors)
+    LanceStoreImage {
+        name: String,
+        id: String,
+        uri: Option<String>,
+        bytes_b64: Option<String>,
+        meta: Vec<(String, String)>,
+    },
+    LanceSearchImage {
+        name: String,
+        k: usize,
+        uri: Option<String>,
+        bytes_b64: Option<String>,
+        filter: Option<String>,
+        return_fields: Option<Vec<String>>,
+    },
     LanceCreateIndex {
         name: String,
         index_type: String,

@@ -903,6 +920,46 @@ impl Cmd {
                 }
                 Cmd::LanceStoreText { name, id, text, meta }
             }
+            "lance.storeimage" => {
+                // LANCE.STOREIMAGE name ID <id> (URI <uri> | BYTES <base64>) [META k v ...]
+                if cmd.len() < 6 {
+                    return Err(DBError("ERR LANCE.STOREIMAGE requires: name ID <id> (URI <uri> | BYTES <base64>) [META k v ...]".to_string()));
+                }
+                let name = cmd[1].clone();
+                let mut i = 2;
+                if cmd[i].to_uppercase() != "ID" || i + 1 >= cmd.len() {
+                    return Err(DBError("ERR LANCE.STOREIMAGE requires ID <id>".to_string()));
+                }
+                let id = cmd[i + 1].clone();
+                i += 2;
+
+                let mut uri_opt: Option<String> = None;
+                let mut bytes_b64_opt: Option<String> = None;
+
+                if i < cmd.len() && cmd[i].to_uppercase() == "URI" {
+                    if i + 1 >= cmd.len() { return Err(DBError("ERR LANCE.STOREIMAGE URI requires a value".to_string())); }
+                    uri_opt = Some(cmd[i + 1].clone());
+                    i += 2;
+                } else if i < cmd.len() && cmd[i].to_uppercase() == "BYTES" {
+                    if i + 1 >= cmd.len() { return Err(DBError("ERR LANCE.STOREIMAGE BYTES requires a value".to_string())); }
+                    bytes_b64_opt = Some(cmd[i + 1].clone());
+                    i += 2;
+                } else {
+                    return Err(DBError("ERR LANCE.STOREIMAGE requires either URI <uri> or BYTES <base64>".to_string()));
+                }
+
+                // Parse optional META pairs
+                let mut meta: Vec<(String, String)> = Vec::new();
+                if i < cmd.len() && cmd[i].to_uppercase() == "META" {
+                    i += 1;
+                    while i + 1 < cmd.len() {
+                        meta.push((cmd[i].clone(), cmd[i + 1].clone()));
+                        i += 2;
+                    }
+                }
+
+                Cmd::LanceStoreImage { name, id, uri: uri_opt, bytes_b64: bytes_b64_opt, meta }
+            }
             "lance.search" => {
                 // LANCE.SEARCH name K <k> QUERY <text> [FILTER expr] [RETURN n fields...]
                 if cmd.len() < 6 {

@@ -954,6 +1011,65 @@ impl Cmd {
                 }
                 Cmd::LanceSearchText { name, text, k, filter, return_fields }
             }
+            "lance.searchimage" => {
+                // LANCE.SEARCHIMAGE name K <k> (QUERYURI <uri> | QUERYBYTES <base64>) [FILTER expr] [RETURN n fields...]
+                if cmd.len() < 6 {
+                    return Err(DBError("ERR LANCE.SEARCHIMAGE requires: name K <k> (QUERYURI <uri> | QUERYBYTES <base64>) [FILTER expr] [RETURN n fields...]".to_string()));
+                }
+                let name = cmd[1].clone();
+                if cmd[2].to_uppercase() != "K" {
+                    return Err(DBError("ERR LANCE.SEARCHIMAGE requires K <k>".to_string()));
+                }
+                let k: usize = cmd[3].parse().map_err(|_| DBError("ERR K must be an integer".to_string()))?;
+                let mut i = 4;
+
+                let mut uri_opt: Option<String> = None;
+                let mut bytes_b64_opt: Option<String> = None;
+
+                if i < cmd.len() && cmd[i].to_uppercase() == "QUERYURI" {
+                    if i + 1 >= cmd.len() { return Err(DBError("ERR QUERYURI requires a value".to_string())); }
+                    uri_opt = Some(cmd[i + 1].clone());
+                    i += 2;
+                } else if i < cmd.len() && cmd[i].to_uppercase() == "QUERYBYTES" {
+                    if i + 1 >= cmd.len() { return Err(DBError("ERR QUERYBYTES requires a value".to_string())); }
+                    bytes_b64_opt = Some(cmd[i + 1].clone());
+                    i += 2;
+                } else {
+                    return Err(DBError("ERR LANCE.SEARCHIMAGE requires QUERYURI <uri> or QUERYBYTES <base64>".to_string()));
+                }
+
+                let mut filter: Option<String> = None;
+                let mut return_fields: Option<Vec<String>> = None;
+                while i < cmd.len() {
+                    match cmd[i].to_uppercase().as_str() {
+                        "FILTER" => {
+                            if i + 1 >= cmd.len() {
+                                return Err(DBError("ERR FILTER requires an expression".to_string()));
+                            }
+                            filter = Some(cmd[i + 1].clone());
+                            i += 2;
+                        }
+                        "RETURN" => {
+                            if i + 1 >= cmd.len() {
+                                return Err(DBError("ERR RETURN requires field count".to_string()));
+                            }
+                            let n: usize = cmd[i + 1].parse().map_err(|_| DBError("ERR RETURN count must be integer".to_string()))?;
+                            i += 2;
+                            let mut fields = Vec::new();
+                            for _ in 0..n {
+                                if i < cmd.len() {
+                                    fields.push(cmd[i].clone());
+                                    i += 1;
+                                }
+                            }
+                            return_fields = Some(fields);
+                        }
+                        _ => { i += 1; }
+                    }
+                }
+
+                Cmd::LanceSearchImage { name, k, uri: uri_opt, bytes_b64: bytes_b64_opt, filter, return_fields }
+            }
             "lance.createindex" => {
                 // LANCE.CREATEINDEX name TYPE t [PARAM k v ...]
                 if cmd.len() < 4 || cmd[2].to_uppercase() != "TYPE" {

@@ -1136,6 +1252,8 @@ impl Cmd {
             | Cmd::LanceCreate { .. }
             | Cmd::LanceStoreText { .. }
             | Cmd::LanceSearchText { .. }
+            | Cmd::LanceStoreImage { .. }
+            | Cmd::LanceSearchImage { .. }
             | Cmd::LanceEmbeddingConfigSet { .. }
             | Cmd::LanceEmbeddingConfigGet { .. }
             | Cmd::LanceCreateIndex { .. }

@@ -1172,6 +1290,8 @@ impl Cmd {
             Cmd::LanceCreate { .. }
             | Cmd::LanceStoreText { .. }
             | Cmd::LanceSearchText { .. }
+            | Cmd::LanceStoreImage { .. }
+            | Cmd::LanceSearchImage { .. }
             | Cmd::LanceEmbeddingConfigSet { .. }
             | Cmd::LanceEmbeddingConfigGet { .. }
             | Cmd::LanceCreateIndex { .. }

@@ -1421,6 +1541,145 @@ impl Cmd {
                     Err(e) => Ok(Protocol::err(&e.0)),
                 }
             }
+
+            // New: Image store
+            Cmd::LanceStoreImage { name, id, uri, bytes_b64, meta } => {
+                if !server.has_write_permission() {
+                    return Ok(Protocol::err("ERR write permission denied"));
+                }
+                let use_uri = uri.is_some();
+                let use_b64 = bytes_b64.is_some();
+                if (use_uri && use_b64) || (!use_uri && !use_b64) {
+                    return Ok(Protocol::err("ERR Provide exactly one of URI or BYTES for LANCE.STOREIMAGE"));
+                }
+                let max_bytes: usize = std::env::var("HERODB_IMAGE_MAX_BYTES")
+                    .ok()
+                    .and_then(|s| s.parse::<u64>().ok())
+                    .unwrap_or(10 * 1024 * 1024) as usize;
+
+                let media_uri_opt = if let Some(u) = uri.clone() {
+                    match server.fetch_image_bytes_from_uri(&u) {
+                        Ok(_) => {}
+                        Err(e) => return Ok(Protocol::err(&e.0)),
+                    }
+                    Some(u)
+                } else {
+                    None
+                };
+
+                let bytes: Vec<u8> = if let Some(u) = uri {
+                    match server.fetch_image_bytes_from_uri(&u) {
+                        Ok(b) => b,
+                        Err(e) => return Ok(Protocol::err(&e.0)),
+                    }
+                } else {
+                    let b64 = bytes_b64.unwrap_or_default();
+                    let data = match general_purpose::STANDARD.decode(b64.as_bytes()) {
+                        Ok(d) => d,
+                        Err(e) => return Ok(Protocol::err(&format!("ERR base64 decode error: {}", e))),
+                    };
+                    if data.len() > max_bytes {
+                        return Ok(Protocol::err(&format!("ERR image exceeds max allowed bytes {}", max_bytes)));
+                    }
+                    data
+                };
+
+                let img_embedder = match server.get_image_embedder_for(&name) {
+                    Ok(e) => e,
+                    Err(e) => return Ok(Protocol::err(&e.0)),
+                };
+                let (tx, rx) = tokio::sync::oneshot::channel();
+                let emb_arc = img_embedder.clone();
+                let bytes_cl = bytes.clone();
+                std::thread::spawn(move || {
+                    let res = emb_arc.embed_image(&bytes_cl);
+                    let _ = tx.send(res);
+                });
+                let vector = match rx.await {
+                    Ok(Ok(v)) => v,
+                    Ok(Err(e)) => return Ok(Protocol::err(&e.0)),
+                    Err(recv_err) => return Ok(Protocol::err(&format!("ERR embedding thread error: {}", recv_err))),
+                };
+
+                let meta_map: std::collections::HashMap<String, String> = meta.into_iter().collect();
+                match server.lance_store()?.store_vector_with_media(
+                    &name,
+                    &id,
+                    vector,
+                    meta_map,
+                    None,
+                    Some("image".to_string()),
+                    media_uri_opt,
+                ).await {
+                    Ok(()) => Ok(Protocol::SimpleString("OK".to_string())),
+                    Err(e) => Ok(Protocol::err(&e.0)),
+                }
+            }
+
+            // New: Image search
+            Cmd::LanceSearchImage { name, k, uri, bytes_b64, filter, return_fields } => {
+                let use_uri = uri.is_some();
+                let use_b64 = bytes_b64.is_some();
+                if (use_uri && use_b64) || (!use_uri && !use_b64) {
+                    return Ok(Protocol::err("ERR Provide exactly one of QUERYURI or QUERYBYTES for LANCE.SEARCHIMAGE"));
+                }
+                let max_bytes: usize = std::env::var("HERODB_IMAGE_MAX_BYTES")
+                    .ok()
+                    .and_then(|s| s.parse::<u64>().ok())
+                    .unwrap_or(10 * 1024 * 1024) as usize;
+
+                let bytes: Vec<u8> = if let Some(u) = uri {
+                    match server.fetch_image_bytes_from_uri(&u) {
+                        Ok(b) => b,
+                        Err(e) => return Ok(Protocol::err(&e.0)),
+                    }
+                } else {
+                    let b64 = bytes_b64.unwrap_or_default();
+                    let data = match general_purpose::STANDARD.decode(b64.as_bytes()) {
+                        Ok(d) => d,
+                        Err(e) => return Ok(Protocol::err(&format!("ERR base64 decode error: {}", e))),
+                    };
+                    if data.len() > max_bytes {
+                        return Ok(Protocol::err(&format!("ERR image exceeds max allowed bytes {}", max_bytes)));
+                    }
+                    data
+                };
+
+                let img_embedder = match server.get_image_embedder_for(&name) {
+                    Ok(e) => e,
+                    Err(e) => return Ok(Protocol::err(&e.0)),
+                };
+                let (tx, rx) = tokio::sync::oneshot::channel();
+                std::thread::spawn(move || {
+                    let res = img_embedder.embed_image(&bytes);
+                    let _ = tx.send(res);
+                });
+                let qv = match rx.await {
+                    Ok(Ok(v)) => v,
+                    Ok(Err(e)) => return Ok(Protocol::err(&e.0)),
+                    Err(recv_err) => return Ok(Protocol::err(&format!("ERR embedding thread error: {}", recv_err))),
+                };
+
+                match server.lance_store()?.search_vectors(&name, qv, k, filter, return_fields).await {
+                    Ok(results) => {
+                        let mut arr = Vec::new();
+                        for (id, score, meta) in results {
+                            let mut meta_arr: Vec<Protocol> = Vec::new();
+                            for (k, v) in meta {
+                                meta_arr.push(Protocol::BulkString(k));
+                                meta_arr.push(Protocol::BulkString(v));
+                            }
+                            arr.push(Protocol::Array(vec![
+                                Protocol::BulkString(id),
+                                Protocol::BulkString(score.to_string()),
+                                Protocol::Array(meta_arr),
+                            ]));
+                        }
+                        Ok(Protocol::Array(arr))
+                    }
+                    Err(e) => Ok(Protocol::err(&e.0)),
+                }
+            }
             Cmd::LanceCreateIndex { name, index_type, params } => {
                 if !server.has_write_permission() {
                     return Ok(Protocol::err("ERR write permission denied"));

@@ -1,8 +1,8 @@
 use chacha20poly1305::{
-    aead::{Aead, KeyInit, OsRng},
+    aead::{Aead, KeyInit},
     XChaCha20Poly1305, XNonce,
 };
-use rand::RngCore;
+use rand::{rngs::OsRng, RngCore};
 use sha2::{Digest, Sha256};

 const VERSION: u8 = 1;
@@ -31,7 +31,7 @@ pub struct CryptoFactory {
 impl CryptoFactory {
     /// Accepts any secret bytes; turns them into a 32-byte key (SHA-256).
     pub fn new<S: AsRef<[u8]>>(secret: S) -> Self {
-        let mut h = Sha256::new();
+        let mut h = Sha256::default();
         h.update(b"xchacha20poly1305-factory:v1"); // domain separation
         h.update(secret.as_ref());
         let digest = h.finalize(); // 32 bytes

@@ -1,18 +1,4 @@
 // Embedding abstraction and minimal providers.
-//
-// This module defines a provider-agnostic interface to produce vector embeddings
-// from text, so callers never need to supply vectors manually. It includes:
-// - Embedder trait
-// - EmbeddingProvider and EmbeddingConfig (serde-serializable)
-// - TestHashEmbedder: deterministic, CPU-only, no-network embedder suitable for CI
-// - Factory create_embedder(..) to instantiate an embedder from config
-//
-// Integration plan:
-// - Server will resolve per-dataset EmbeddingConfig from sidecar JSON and cache Arc<dyn Embedder>
-// - LanceStore will call embedder.embed(text) then persist id, vector, text, meta
-//
-// Note: Real LanceDB-backed embedding providers can be added by implementing Embedder
-// and extending create_embedder(..). This file keeps no direct dependency on LanceDB.
-
 use std::collections::HashMap;
 use std::sync::Arc;

@@ -112,6 +112,8 @@ impl LanceStore {
             Field::new("id", DataType::Utf8, false),
             Self::vector_field(dim),
             Field::new("text", DataType::Utf8, true),
+            Field::new("media_type", DataType::Utf8, true),
+            Field::new("media_uri", DataType::Utf8, true),
             Field::new("meta", DataType::Utf8, true),
         ]))
     }

@@ -121,6 +123,8 @@ impl LanceStore {
         vector: &[f32],
         meta: &HashMap<String, String>,
         text: Option<&str>,
+        media_type: Option<&str>,
+        media_uri: Option<&str>,
         dim: i32,
     ) -> Result<(Arc<Schema>, RecordBatch), DBError> {
         if vector.len() as i32 != dim {

@@ -156,6 +160,24 @@ impl LanceStore {
         }
         let text_arr = Arc::new(text_builder.finish()) as Arc<dyn Array>;

+        // media_type column (optional)
+        let mut mt_builder = StringBuilder::new();
+        if let Some(mt) = media_type {
+            mt_builder.append_value(mt);
+        } else {
+            mt_builder.append_null();
+        }
+        let mt_arr = Arc::new(mt_builder.finish()) as Arc<dyn Array>;
+
+        // media_uri column (optional)
+        let mut mu_builder = StringBuilder::new();
+        if let Some(mu) = media_uri {
+            mu_builder.append_value(mu);
+        } else {
+            mu_builder.append_null();
+        }
+        let mu_arr = Arc::new(mu_builder.finish()) as Arc<dyn Array>;
+
         // meta column (JSON string)
         let meta_json = if meta.is_empty() {
             None
@@ -171,7 +193,7 @@ impl LanceStore {
         let meta_arr = Arc::new(meta_builder.finish()) as Arc<dyn Array>;

         let batch =
-            RecordBatch::try_new(schema.clone(), vec![id_arr, vec_arr, text_arr, meta_arr]).map_err(|e| {
+            RecordBatch::try_new(schema.clone(), vec![id_arr, vec_arr, text_arr, mt_arr, mu_arr, meta_arr]).map_err(|e| {
                 DBError(format!("RecordBatch build failed: {e}"))
             })?;

@@ -207,10 +229,12 @@ impl LanceStore {
         let mut list_builder = FixedSizeListBuilder::new(v_builder, dim_i32);
         let empty_vec = Arc::new(list_builder.finish()) as Arc<dyn Array>;
         let empty_text = Arc::new(StringArray::new_null(0));
+        let empty_media_type = Arc::new(StringArray::new_null(0));
+        let empty_media_uri = Arc::new(StringArray::new_null(0));
         let empty_meta = Arc::new(StringArray::new_null(0));

         let empty_batch =
-            RecordBatch::try_new(schema.clone(), vec![empty_id, empty_vec, empty_text, empty_meta])
+            RecordBatch::try_new(schema.clone(), vec![empty_id, empty_vec, empty_text, empty_media_type, empty_media_uri, empty_meta])
                 .map_err(|e| DBError(format!("Build empty batch failed: {e}")))?;

         let write_params = WriteParams {

@@ -235,6 +259,21 @@ impl LanceStore {
         vector: Vec<f32>,
         meta: HashMap<String, String>,
         text: Option<String>,
+    ) -> Result<(), DBError> {
+        // Delegate to media-aware path with no media fields
+        self.store_vector_with_media(name, id, vector, meta, text, None, None).await
+    }
+
+    /// Store/Upsert a single vector with optional text and media fields (media_type/media_uri).
+    pub async fn store_vector_with_media(
+        &self,
+        name: &str,
+        id: &str,
+        vector: Vec<f32>,
+        meta: HashMap<String, String>,
+        text: Option<String>,
+        media_type: Option<String>,
+        media_uri: Option<String>,
     ) -> Result<(), DBError> {
         let path = self.dataset_path(name);

@@ -248,7 +287,15 @@ impl LanceStore {
                 .map_err(|_| DBError("Vector length too large".into()))?
         };

-        let (schema, batch) = Self::build_one_row_batch(id, &vector, &meta, text.as_deref(), dim_i32)?;
+        let (schema, batch) = Self::build_one_row_batch(
+            id,
+            &vector,
+            &meta,
+            text.as_deref(),
+            media_type.as_deref(),
+            media_uri.as_deref(),
+            dim_i32,
+        )?;

         // If LanceDB table exists and provides delete, we can upsert by deleting same id
         // Try best-effort delete; ignore errors to keep operation append-only on failure

@@ -355,21 +402,36 @@ impl LanceStore {
             .await
             .map_err(|e| DBError(format!("Open dataset failed: {}", e)))?;

-        // Build scanner with projection; filter if provided
+        // Build scanner with projection; we project needed fields and filter client-side to support meta keys
         let mut scan = ds.scan();
-        if let Err(e) = scan.project(&["id", "vector", "meta"]) {
+        if let Err(e) = scan.project(&["id", "vector", "meta", "text", "media_type", "media_uri"]) {
             return Err(DBError(format!("Project failed: {}", e)));
         }
-        if let Some(pred) = filter {
-            if let Err(e) = scan.filter(&pred) {
-                return Err(DBError(format!("Filter failed: {}", e)));
-            }
-        }
+        // Note: we no longer push down filter to Lance to allow filtering on meta fields client-side.

         let mut stream = scan
             .try_into_stream()
             .await
             .map_err(|e| DBError(format!("Scan stream failed: {}", e)))?;

+        // Parse simple equality clause from filter for client-side filtering (supports one `key = 'value'`)
+        let clause = filter.as_ref().and_then(|s| {
+            fn parse_eq(s: &str) -> Option<(String, String)> {
+                let s = s.trim();
+                let pos = s.find('=').or_else(|| s.find(" = "))?;
+                let (k, vraw) = s.split_at(pos);
+                let mut v = vraw.trim_start_matches('=').trim();
+                if (v.starts_with('\'') && v.ends_with('\'')) || (v.starts_with('"') && v.ends_with('"')) {
+                    if v.len() >= 2 {
+                        v = &v[1..v.len() - 1];
+                    }
+                }
+                let key = k.trim().trim_matches('"').trim_matches('\'').to_string();
+                if key.is_empty() { return None; }
+                Some((key, v.to_string()))
+            }
+            parse_eq(s)
+        });
+
         // Maintain a max-heap with reverse ordering to keep top-k smallest distances
         #[derive(Debug)]

@@ -412,20 +474,18 @@ impl LanceStore {
             let meta_arr = batch
                 .column_by_name("meta")
                 .map(|a| a.as_string::<i32>().clone());
+            let text_arr = batch
+                .column_by_name("text")
+                .map(|a| a.as_string::<i32>().clone());
+            let mt_arr = batch
+                .column_by_name("media_type")
+                .map(|a| a.as_string::<i32>().clone());
+            let mu_arr = batch
+                .column_by_name("media_uri")
+                .map(|a| a.as_string::<i32>().clone());

             for i in 0..batch.num_rows() {
-                // Compute L2 distance
-                let val = vec_arr.value(i);
-                let prim = val.as_primitive::<Float32Type>();
-                let mut dist: f32 = 0.0;
-                let plen = prim.len();
-                for j in 0..plen {
-                    let r = prim.value(j);
-                    let d = query[j] - r;
-                    dist += d * d;
-                }
-
-                // Parse id
+                // Extract id
                 let id_val = id_arr.value(i).to_string();

                 // Parse meta JSON if present

@@ -439,26 +499,54 @@ impl LanceStore {
                             meta.insert(k, vs.to_string());
                         } else if v.is_number() || v.is_boolean() {
                             meta.insert(k, v.to_string());
-                        } else {
-                            // skip complex entries
                         }
                     }
                 }
             }
         }

+            // Evaluate simple equality filter if provided (supports one clause)
+            let passes = if let Some((ref key, ref val)) = clause {
+                let candidate = match key.as_str() {
+                    "id" => Some(id_val.clone()),
+                    "text" => text_arr.as_ref().and_then(|col| if col.is_null(i) { None } else { Some(col.value(i).to_string()) }),
+                    "media_type" => mt_arr.as_ref().and_then(|col| if col.is_null(i) { None } else { Some(col.value(i).to_string()) }),
+                    "media_uri" => mu_arr.as_ref().and_then(|col| if col.is_null(i) { None } else { Some(col.value(i).to_string()) }),
+                    _ => meta.get(key).cloned(),
+                };
+                match candidate {
+                    Some(cv) => cv == *val,
+                    None => false,
+                }
+            } else { true };
+            if !passes {
+                continue;
+            }
+
+            // Compute L2 distance
+            let val = vec_arr.value(i);
+            let prim = val.as_primitive::<Float32Type>();
+            let mut dist: f32 = 0.0;
+            let plen = prim.len();
+            for j in 0..plen {
+                let r = prim.value(j);
+                let d = query[j] - r;
+                dist += d * d;
+            }
+
             // Apply return_fields on meta
+            let mut meta_out = meta;
             if let Some(fields) = &return_fields {
                 let mut filtered = HashMap::new();
                 for f in fields {
-                    if let Some(val) = meta.get(f) {
+                    if let Some(val) = meta_out.get(f) {
                         filtered.insert(f.clone(), val.clone());
                     }
                 }
-                meta = filtered;
+                meta_out = filtered;
             }

-            let hit = Hit { dist, id: id_val, meta };
+            let hit = Hit { dist, id: id_val, meta: meta_out };

             if heap.len() < k {
                 heap.push(hit);
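The loop above computes squared-L2 distances and keeps the k best hits in a max-heap, so the current worst candidate is always cheap to evict. A self-contained sketch of that pattern (the `Hit` type and eviction check here are simplified stand-ins, not HeroDB's actual implementation):

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

// Stand-in for the Hit struct: ordered by distance so BinaryHeap
// (a max-heap) keeps the worst of the current top-k on top.
#[derive(Debug, PartialEq)]
struct Hit { dist: f32, id: String }

impl Eq for Hit {}
impl PartialOrd for Hit {
    fn partial_cmp(&self, o: &Self) -> Option<Ordering> { self.dist.partial_cmp(&o.dist) }
}
impl Ord for Hit {
    fn cmp(&self, o: &Self) -> Ordering { self.partial_cmp(o).unwrap_or(Ordering::Equal) }
}

// Squared L2 distance, as in the diff's accumulation loop.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

fn top_k(query: &[f32], rows: &[(String, Vec<f32>)], k: usize) -> Vec<Hit> {
    let mut heap: BinaryHeap<Hit> = BinaryHeap::new();
    for (id, v) in rows {
        let hit = Hit { dist: l2_sq(query, v), id: id.clone() };
        if heap.len() < k {
            heap.push(hit);
        } else if let Some(worst) = heap.peek() {
            if hit.dist < worst.dist {
                heap.pop();   // evict current worst
                heap.push(hit);
            }
        }
    }
    let mut out = heap.into_vec();
    out.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap_or(Ordering::Equal));
    out
}

fn main() {
    let rows = vec![
        ("a".to_string(), vec![0.0, 0.0]),
        ("b".to_string(), vec![1.0, 0.0]),
        ("c".to_string(), vec![3.0, 4.0]),
    ];
    let hits = top_k(&[0.0, 0.0], &rows, 2);
    assert_eq!(hits[0].id, "a");
    assert_eq!(hits[1].id, "b");
}
```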
228 src/rpc.rs
@@ -10,6 +10,7 @@ use crate::server::Server;
 use crate::options::DBOption;
 use crate::admin_meta;
 use crate::embedding::{EmbeddingConfig, EmbeddingProvider};
+use base64::{engine::general_purpose, Engine as _};

 /// Database backend types
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -282,6 +283,33 @@ pub trait Rpc {
         filter: Option<String>,
         return_fields: Option<Vec<String>>,
     ) -> RpcResult<serde_json::Value>;

+    // ----- Image-first endpoints (no user-provided vectors) -----
+
+    /// Store an image; exactly one of uri or bytes_b64 must be provided.
+    #[method(name = "lanceStoreImage")]
+    async fn lance_store_image(
+        &self,
+        db_id: u64,
+        name: String,
+        id: String,
+        uri: Option<String>,
+        bytes_b64: Option<String>,
+        meta: Option<HashMap<String, String>>,
+    ) -> RpcResult<bool>;
+
+    /// Search using an image query; exactly one of uri or bytes_b64 must be provided.
+    #[method(name = "lanceSearchImage")]
+    async fn lance_search_image(
+        &self,
+        db_id: u64,
+        name: String,
+        k: usize,
+        uri: Option<String>,
+        bytes_b64: Option<String>,
+        filter: Option<String>,
+        return_fields: Option<Vec<String>>,
+    ) -> RpcResult<serde_json::Value>;
 }

 /// RPC Server implementation
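Based on the trait signatures above, the new endpoints could be called over JSON-RPC roughly as follows. Note these are illustrative sketches: the parameter order follows the Rust signatures, the exact wire method name may be namespaced by the rpc macro (e.g. with a `herodb_` prefix, as with `herodb_createDatabase` in the docs), and the URL and metadata values are made up:

```json
{ "jsonrpc": "2.0", "id": 2, "method": "lanceStoreImage", "params": [1, "imageset", "img1", "https://example.com/cat.jpg", null, { "label": "cat" }] }
```

```json
{ "jsonrpc": "2.0", "id": 3, "method": "lanceSearchImage", "params": [1, "imageset", 5, "https://example.com/query.jpg", null, "media_type = 'image'", null] }
```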
@@ -1131,4 +1159,204 @@ impl RpcServer for RpcServerImpl {

         Ok(serde_json::json!({ "results": json_results }))
     }
+
+    // ----- New image-first Lance RPC implementations -----
+
+    async fn lance_store_image(
+        &self,
+        db_id: u64,
+        name: String,
+        id: String,
+        uri: Option<String>,
+        bytes_b64: Option<String>,
+        meta: Option<HashMap<String, String>>,
+    ) -> RpcResult<bool> {
+        let server = self.get_or_create_server(db_id).await?;
+        if db_id == 0 {
+            return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "Lance not allowed on DB 0", None::<()>));
+        }
+        if !matches!(server.option.backend, crate::options::BackendType::Lance) {
+            return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "DB backend is not Lance", None::<()>));
+        }
+        if !server.has_write_permission() {
+            return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "write permission denied", None::<()>));
+        }
+
+        // Validate exactly one of uri or bytes_b64
+        let (use_uri, use_b64) = (uri.is_some(), bytes_b64.is_some());
+        if (use_uri && use_b64) || (!use_uri && !use_b64) {
+            return Err(jsonrpsee::types::ErrorObjectOwned::owned(
+                -32000,
+                "Provide exactly one of 'uri' or 'bytes_b64'",
+                None::<()>,
+            ));
+        }
+
+        // Acquire image bytes (with caps)
+        let max_bytes: usize = std::env::var("HERODB_IMAGE_MAX_BYTES")
+            .ok()
+            .and_then(|s| s.parse::<u64>().ok())
+            .unwrap_or(10 * 1024 * 1024) as usize;
+
+        let (bytes, media_uri_opt) = if let Some(u) = uri.clone() {
+            let data = server
+                .fetch_image_bytes_from_uri(&u)
+                .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
+            (data, Some(u))
+        } else {
+            let b64 = bytes_b64.unwrap_or_default();
+            let data = general_purpose::STANDARD
+                .decode(b64.as_bytes())
+                .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("base64 decode error: {}", e), None::<()>))?;
+            if data.len() > max_bytes {
+                return Err(jsonrpsee::types::ErrorObjectOwned::owned(
+                    -32000,
+                    format!("Image exceeds max allowed bytes {}", max_bytes),
+                    None::<()>,
+                ));
+            }
+            (data, None)
+        };
+
+        // Resolve image embedder and embed on a plain OS thread
+        let img_embedder = server
+            .get_image_embedder_for(&name)
+            .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
+        let (tx, rx) = tokio::sync::oneshot::channel();
+        let emb_arc = img_embedder.clone();
+        let bytes_cl = bytes.clone();
+        std::thread::spawn(move || {
+            let res = emb_arc.embed_image(&bytes_cl);
+            let _ = tx.send(res);
+        });
+        let vector = match rx.await {
+            Ok(Ok(v)) => v,
+            Ok(Err(e)) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>)),
+            Err(recv_err) => {
+                return Err(jsonrpsee::types::ErrorObjectOwned::owned(
+                    -32000,
+                    format!("embedding thread error: {}", recv_err),
+                    None::<()>,
+                ))
+            }
+        };
+
+        // Store vector with media fields
+        server
+            .lance_store()
+            .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?
+            .store_vector_with_media(
+                &name,
+                &id,
+                vector,
+                meta.unwrap_or_default(),
+                None,
+                Some("image".to_string()),
+                media_uri_opt,
+            )
+            .await
+            .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
+
+        Ok(true)
+    }
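Two of the guard checks in `lance_store_image` above are easy to isolate: the "exactly one of `uri` or `bytes_b64`" validation is an exclusive-or on the two options, and the byte cap is parsed from `HERODB_IMAGE_MAX_BYTES` with a 10 MiB default. A minimal sketch:

```rust
use std::env;

// Exactly one image source must be supplied (XOR of the two Options).
fn exactly_one(uri: Option<&str>, bytes_b64: Option<&str>) -> bool {
    uri.is_some() != bytes_b64.is_some()
}

// Env-configurable size cap, defaulting to 10 MiB, as in the handler above.
fn max_image_bytes() -> usize {
    env::var("HERODB_IMAGE_MAX_BYTES")
        .ok()
        .and_then(|s| s.parse::<u64>().ok())
        .unwrap_or(10 * 1024 * 1024) as usize
}

fn main() {
    assert!(exactly_one(Some("https://host/x.png"), None));
    assert!(exactly_one(None, Some("aGVsbG8=")));
    assert!(!exactly_one(None, None));
    assert!(!exactly_one(Some("u"), Some("b")));
    // Assuming HERODB_IMAGE_MAX_BYTES is unset, the 10 MiB default applies.
    assert_eq!(max_image_bytes(), 10 * 1024 * 1024);
}
```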
+
+    async fn lance_search_image(
+        &self,
+        db_id: u64,
+        name: String,
+        k: usize,
+        uri: Option<String>,
+        bytes_b64: Option<String>,
+        filter: Option<String>,
+        return_fields: Option<Vec<String>>,
+    ) -> RpcResult<serde_json::Value> {
+        let server = self.get_or_create_server(db_id).await?;
+        if db_id == 0 {
+            return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "Lance not allowed on DB 0", None::<()>));
+        }
+        if !matches!(server.option.backend, crate::options::BackendType::Lance) {
+            return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "DB backend is not Lance", None::<()>));
+        }
+        if !server.has_read_permission() {
+            return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, "read permission denied", None::<()>));
+        }
+
+        // Validate exactly one of uri or bytes_b64
+        let (use_uri, use_b64) = (uri.is_some(), bytes_b64.is_some());
+        if (use_uri && use_b64) || (!use_uri && !use_b64) {
+            return Err(jsonrpsee::types::ErrorObjectOwned::owned(
+                -32000,
+                "Provide exactly one of 'uri' or 'bytes_b64'",
+                None::<()>,
+            ));
+        }
+
+        // Acquire image bytes for query (with caps)
+        let max_bytes: usize = std::env::var("HERODB_IMAGE_MAX_BYTES")
+            .ok()
+            .and_then(|s| s.parse::<u64>().ok())
+            .unwrap_or(10 * 1024 * 1024) as usize;
+
+        let bytes = if let Some(u) = uri {
+            server
+                .fetch_image_bytes_from_uri(&u)
+                .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?
+        } else {
+            let b64 = bytes_b64.unwrap_or_default();
+            let data = general_purpose::STANDARD
+                .decode(b64.as_bytes())
+                .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("base64 decode error: {}", e), None::<()>))?;
+            if data.len() > max_bytes {
+                return Err(jsonrpsee::types::ErrorObjectOwned::owned(
+                    -32000,
+                    format!("Image exceeds max allowed bytes {}", max_bytes),
+                    None::<()>,
+                ));
+            }
+            data
+        };
+
+        // Resolve image embedder and embed on OS thread
+        let img_embedder = server
+            .get_image_embedder_for(&name)
+            .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
+        let (tx, rx) = tokio::sync::oneshot::channel();
+        let emb_arc = img_embedder.clone();
+        std::thread::spawn(move || {
+            let res = emb_arc.embed_image(&bytes);
+            let _ = tx.send(res);
+        });
+        let qv = match rx.await {
+            Ok(Ok(v)) => v,
+            Ok(Err(e)) => return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>)),
+            Err(recv_err) => {
+                return Err(jsonrpsee::types::ErrorObjectOwned::owned(
+                    -32000,
+                    format!("embedding thread error: {}", recv_err),
+                    None::<()>,
+                ))
+            }
+        };
+
+        // KNN search and return results
+        let results = server
+            .lance_store()
+            .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?
+            .search_vectors(&name, qv, k, filter, return_fields)
+            .await
+            .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
+
+        let json_results: Vec<serde_json::Value> = results
+            .into_iter()
+            .map(|(id, score, meta)| {
+                serde_json::json!({
+                    "id": id,
+                    "score": score,
+                    "meta": meta,
+                })
+            })
+            .collect();
+
+        Ok(serde_json::json!({ "results": json_results }))
+    }
 }
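Both RPC handlers above run the potentially CPU-heavy `embed_image` call on a plain OS thread and await the result through a `tokio::sync::oneshot` channel, keeping the async executor unblocked. The same handoff pattern, sketched synchronously with `std::sync::mpsc` and a stand-in "embedder" closure (not HeroDB's actual embedder):

```rust
use std::sync::mpsc;
use std::thread;

// Spawn an OS thread for blocking work and hand the result back over a
// channel, mirroring the oneshot pattern used by the RPC handlers above.
fn embed_on_thread(bytes: Vec<u8>) -> Result<Vec<f32>, String> {
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        // Stand-in for emb_arc.embed_image(&bytes): scale bytes to [0, 1].
        let v: Vec<f32> = bytes.iter().map(|b| *b as f32 / 255.0).collect();
        let _ = tx.send(Ok::<Vec<f32>, String>(v));
    });
    // A dropped sender (panicked thread) surfaces as a RecvError,
    // analogous to the handlers' "embedding thread error" branch.
    rx.recv().map_err(|e| format!("embedding thread error: {}", e))?
}

fn main() {
    let v = embed_on_thread(vec![0, 255]).unwrap();
    assert_eq!(v, vec![0.0, 1.0]);
}
```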
src/server.rs

@@ -15,8 +15,11 @@ use crate::storage_trait::StorageBackend;
 use crate::admin_meta;

 // Embeddings: config and cache
-use crate::embedding::{EmbeddingConfig, create_embedder, Embedder};
+use crate::embedding::{EmbeddingConfig, create_embedder, Embedder, create_image_embedder, ImageEmbedder};
 use serde_json;
+use ureq::{Agent, AgentBuilder};
+use std::time::Duration;
+use std::io::Read;

 #[derive(Clone)]
 pub struct Server {
@@ -33,9 +36,12 @@ pub struct Server {
     // Per-DB Lance stores (vector DB), keyed by db_id
     pub lance_stores: Arc<std::sync::RwLock<HashMap<u64, Arc<crate::lance_store::LanceStore>>>>,

-    // Per-(db_id, dataset) embedder cache
+    // Per-(db_id, dataset) embedder cache (text)
     pub embedders: Arc<std::sync::RwLock<HashMap<(u64, String), Arc<dyn Embedder>>>>,

+    // Per-(db_id, dataset) image embedder cache (image)
+    pub image_embedders: Arc<std::sync::RwLock<HashMap<(u64, String), Arc<dyn ImageEmbedder>>>>,
+
     // BLPOP waiter registry: per (db_index, key) FIFO of waiters
     pub list_waiters: Arc<Mutex<HashMap<u64, HashMap<String, Vec<Waiter>>>>>,
     pub waiter_seq: Arc<AtomicU64>,
@@ -66,6 +72,7 @@ impl Server {
             search_indexes: Arc::new(std::sync::RwLock::new(HashMap::new())),
             lance_stores: Arc::new(std::sync::RwLock::new(HashMap::new())),
             embedders: Arc::new(std::sync::RwLock::new(HashMap::new())),
+            image_embedders: Arc::new(std::sync::RwLock::new(HashMap::new())),
             list_waiters: Arc::new(Mutex::new(HashMap::new())),
             waiter_seq: Arc::new(AtomicU64::new(1)),
         }
@@ -189,6 +196,10 @@ impl Server {
             let mut map = self.embedders.write().unwrap();
             map.remove(&(self.selected_db, dataset.to_string()));
         }
+        {
+            let mut map_img = self.image_embedders.write().unwrap();
+            map_img.remove(&(self.selected_db, dataset.to_string()));
+        }
         Ok(())
     }

@@ -233,6 +244,88 @@ impl Server {
         Ok(emb)
     }

+    /// Resolve or build an IMAGE embedder for (db_id, dataset). Caches instance.
+    pub fn get_image_embedder_for(&self, dataset: &str) -> Result<Arc<dyn ImageEmbedder>, DBError> {
+        if self.selected_db == 0 {
+            return Err(DBError("Lance not available on admin DB 0".to_string()));
+        }
+        // Fast path
+        {
+            let map = self.image_embedders.read().unwrap();
+            if let Some(e) = map.get(&(self.selected_db, dataset.to_string())) {
+                return Ok(e.clone());
+            }
+        }
+        // Load config and instantiate
+        let cfg = self.get_dataset_embedding_config(dataset)?;
+        let emb = create_image_embedder(&cfg)?;
+        {
+            let mut map = self.image_embedders.write().unwrap();
+            map.insert((self.selected_db, dataset.to_string()), emb.clone());
+        }
+        Ok(emb)
+    }

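`get_image_embedder_for` above follows a read-lock fast path, then takes the write lock only to populate the cache on a miss. A generic sketch of that pattern (the `Arc<String>` value is a placeholder for the embedder trait object, and `format!` stands in for the real construction):

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Read-lock fast path, write-lock slow path, keyed by (db_id, dataset),
// mirroring the cache shape used by the Server above.
fn get_or_insert(
    cache: &RwLock<HashMap<(u64, String), Arc<String>>>,
    db_id: u64,
    dataset: &str,
) -> Arc<String> {
    // Fast path: shared read lock; most calls return here.
    {
        let map = cache.read().unwrap();
        if let Some(e) = map.get(&(db_id, dataset.to_string())) {
            return e.clone();
        }
    }
    // Slow path: build the value, then insert under an exclusive write lock.
    // entry().or_insert() avoids clobbering a value a racing thread inserted.
    let built = Arc::new(format!("embedder-for-{}", dataset));
    let mut map = cache.write().unwrap();
    map.entry((db_id, dataset.to_string())).or_insert(built).clone()
}

fn main() {
    let cache = RwLock::new(HashMap::new());
    let a = get_or_insert(&cache, 1, "imageset");
    let b = get_or_insert(&cache, 1, "imageset");
    // The second call hits the fast path and returns the same instance.
    assert!(Arc::ptr_eq(&a, &b));
}
```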
+    /// Download image bytes from a URI with safety checks (size, timeout, content-type, optional host allowlist).
+    /// Env overrides:
+    /// - HERODB_IMAGE_MAX_BYTES (u64, default 10485760)
+    /// - HERODB_IMAGE_FETCH_TIMEOUT_SECS (u64, default 30)
+    /// - HERODB_IMAGE_ALLOWED_HOSTS (comma-separated, optional)
+    pub fn fetch_image_bytes_from_uri(&self, uri: &str) -> Result<Vec<u8>, DBError> {
+        // Basic scheme validation
+        if !(uri.starts_with("http://") || uri.starts_with("https://")) {
+            return Err(DBError("Only http(s) URIs are supported for image fetch".into()));
+        }
+        // Parse host (naive) for allowlist check
+        let host = {
+            let after_scheme = match uri.find("://") {
+                Some(i) => &uri[i + 3..],
+                None => uri,
+            };
+            let end = after_scheme.find('/').unwrap_or(after_scheme.len());
+            let host_port = &after_scheme[..end];
+            host_port.split('@').last().unwrap_or(host_port).split(':').next().unwrap_or(host_port).to_string()
+        };
+
+        let max_bytes: u64 = std::env::var("HERODB_IMAGE_MAX_BYTES").ok().and_then(|s| s.parse::<u64>().ok()).unwrap_or(10 * 1024 * 1024);
+        let timeout_secs: u64 = std::env::var("HERODB_IMAGE_FETCH_TIMEOUT_SECS").ok().and_then(|s| s.parse::<u64>().ok()).unwrap_or(30);
+        let allowed_hosts_env = std::env::var("HERODB_IMAGE_ALLOWED_HOSTS").ok();
+        if let Some(allow) = allowed_hosts_env {
+            if !allow.split(',').map(|s| s.trim()).filter(|s| !s.is_empty()).any(|h| h.eq_ignore_ascii_case(&host)) {
+                return Err(DBError(format!("Host '{}' not allowed for image fetch (HERODB_IMAGE_ALLOWED_HOSTS)", host)));
+            }
+        }
+
+        let agent: Agent = AgentBuilder::new()
+            .timeout_read(Duration::from_secs(timeout_secs))
+            .timeout_write(Duration::from_secs(timeout_secs))
+            .build();
+
+        let resp = agent.get(uri).call().map_err(|e| DBError(format!("HTTP GET failed: {}", e)))?;
+        // Validate content-type
+        let ctype = resp.header("Content-Type").unwrap_or("");
+        let ctype_main = ctype.split(';').next().unwrap_or("").trim().to_ascii_lowercase();
+        if !ctype_main.starts_with("image/") {
+            return Err(DBError(format!("Remote content-type '{}' is not image/*", ctype)));
+        }
+
+        // Read with cap
+        let mut reader = resp.into_reader();
+        let mut buf: Vec<u8> = Vec::with_capacity(8192);
+        let mut tmp = [0u8; 8192];
+        let mut total: u64 = 0;
+        loop {
+            let n = reader.read(&mut tmp).map_err(|e| DBError(format!("Read error: {}", e)))?;
+            if n == 0 { break; }
+            total += n as u64;
+            if total > max_bytes {
+                return Err(DBError(format!("Image exceeds max allowed bytes {}", max_bytes)));
+            }
+            buf.extend_from_slice(&tmp[..n]);
+        }
+        Ok(buf)
+    }

     /// Check if current permissions allow read operations
     pub fn has_read_permission(&self) -> bool {
         // If an explicit permission is set for this connection, honor it.
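The host extraction inside `fetch_image_bytes_from_uri` above is deliberately naive: strip the scheme, cut at the first `/`, then drop any userinfo and port. Pulled out as a standalone function, it behaves like this:

```rust
// Naive host extraction used for the allowlist check above:
// "http://user@example.com:8080/x" → "example.com".
fn host_of(uri: &str) -> String {
    let after_scheme = match uri.find("://") {
        Some(i) => &uri[i + 3..],
        None => uri,
    };
    let end = after_scheme.find('/').unwrap_or(after_scheme.len());
    let host_port = &after_scheme[..end];
    host_port
        .split('@').last().unwrap_or(host_port)   // drop userinfo
        .split(':').next().unwrap_or(host_port)   // drop port
        .to_string()
}

fn main() {
    assert_eq!(host_of("https://example.com/a.png"), "example.com");
    assert_eq!(host_of("http://user@example.com:8080/x"), "example.com");
    assert_eq!(host_of("https://cdn.example.org"), "cdn.example.org");
}
```

Because it is string slicing rather than a real URL parser, edge cases (IPv6 literals, percent-encoding) are not handled; the allowlist is a coarse safety net, not a security boundary.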