WIP2: implementing LanceDB: created embedding abstraction, server-side per-dataset embedding config, and updated RPC endpoints
This commit is contained in:
186
src/cmd.rs
186
src/cmd.rs
@@ -1,4 +1,4 @@
|
||||
use crate::{error::DBError, protocol::Protocol, server::Server};
|
||||
use crate::{error::DBError, protocol::Protocol, server::Server, embedding::{EmbeddingConfig, EmbeddingProvider}};
|
||||
use tokio::time::{timeout, Duration};
|
||||
use futures::future::select_all;
|
||||
|
||||
@@ -127,20 +127,20 @@ pub enum Cmd {
|
||||
reducers: Vec<String>,
|
||||
},
|
||||
|
||||
// LanceDB vector search commands
|
||||
// LanceDB text-first commands (no user-provided vectors)
|
||||
LanceCreate {
|
||||
name: String,
|
||||
dim: usize,
|
||||
},
|
||||
LanceStore {
|
||||
LanceStoreText {
|
||||
name: String,
|
||||
id: String,
|
||||
vector: Vec<f32>,
|
||||
text: String,
|
||||
meta: Vec<(String, String)>,
|
||||
},
|
||||
LanceSearch {
|
||||
LanceSearchText {
|
||||
name: String,
|
||||
vector: Vec<f32>,
|
||||
text: String,
|
||||
k: usize,
|
||||
filter: Option<String>,
|
||||
return_fields: Option<Vec<String>>,
|
||||
@@ -150,6 +150,16 @@ pub enum Cmd {
|
||||
index_type: String,
|
||||
params: Vec<(String, String)>,
|
||||
},
|
||||
// Embedding configuration per dataset
|
||||
LanceEmbeddingConfigSet {
|
||||
name: String,
|
||||
provider: String,
|
||||
model: String,
|
||||
params: Vec<(String, String)>,
|
||||
},
|
||||
LanceEmbeddingConfigGet {
|
||||
name: String,
|
||||
},
|
||||
LanceList,
|
||||
LanceInfo {
|
||||
name: String,
|
||||
@@ -862,9 +872,9 @@ impl Cmd {
|
||||
Cmd::LanceCreate { name, dim }
|
||||
}
|
||||
"lance.store" => {
|
||||
// LANCE.STORE name ID id VECTOR v1 v2 ... [META k v ...]
|
||||
// LANCE.STORE name ID <id> TEXT <text> [META k v ...]
|
||||
if cmd.len() < 6 {
|
||||
return Err(DBError("ERR LANCE.STORE requires: name ID <id> VECTOR v1 v2 ... [META k v ...]".to_string()));
|
||||
return Err(DBError("ERR LANCE.STORE requires: name ID <id> TEXT <text> [META k v ...]".to_string()));
|
||||
}
|
||||
let name = cmd[1].clone();
|
||||
let mut i = 2;
|
||||
@@ -873,16 +883,16 @@ impl Cmd {
|
||||
}
|
||||
let id = cmd[i + 1].clone();
|
||||
i += 2;
|
||||
if i >= cmd.len() || cmd[i].to_uppercase() != "VECTOR" {
|
||||
return Err(DBError("ERR LANCE.STORE requires VECTOR <f32...>".to_string()));
|
||||
if i >= cmd.len() || cmd[i].to_uppercase() != "TEXT" {
|
||||
return Err(DBError("ERR LANCE.STORE requires TEXT <text>".to_string()));
|
||||
}
|
||||
i += 1;
|
||||
let mut vector: Vec<f32> = Vec::new();
|
||||
while i < cmd.len() && cmd[i].to_uppercase() != "META" {
|
||||
let v: f32 = cmd[i].parse().map_err(|_| DBError("ERR vector element must be a float32".to_string()))?;
|
||||
vector.push(v);
|
||||
i += 1;
|
||||
if i >= cmd.len() {
|
||||
return Err(DBError("ERR LANCE.STORE requires TEXT <text>".to_string()));
|
||||
}
|
||||
let text = cmd[i].clone();
|
||||
i += 1;
|
||||
|
||||
let mut meta: Vec<(String, String)> = Vec::new();
|
||||
if i < cmd.len() && cmd[i].to_uppercase() == "META" {
|
||||
i += 1;
|
||||
@@ -891,28 +901,28 @@ impl Cmd {
|
||||
i += 2;
|
||||
}
|
||||
}
|
||||
Cmd::LanceStore { name, id, vector, meta }
|
||||
Cmd::LanceStoreText { name, id, text, meta }
|
||||
}
|
||||
"lance.search" => {
|
||||
// LANCE.SEARCH name K k VECTOR v1 v2 ... [FILTER expr] [RETURN n fields...]
|
||||
// LANCE.SEARCH name K <k> QUERY <text> [FILTER expr] [RETURN n fields...]
|
||||
if cmd.len() < 6 {
|
||||
return Err(DBError("ERR LANCE.SEARCH requires: name K <k> VECTOR v1 v2 ... [FILTER expr] [RETURN n fields...]".to_string()));
|
||||
return Err(DBError("ERR LANCE.SEARCH requires: name K <k> QUERY <text> [FILTER expr] [RETURN n fields...]".to_string()));
|
||||
}
|
||||
let name = cmd[1].clone();
|
||||
if cmd[2].to_uppercase() != "K" {
|
||||
return Err(DBError("ERR LANCE.SEARCH requires K <k>".to_string()));
|
||||
}
|
||||
let k: usize = cmd[3].parse().map_err(|_| DBError("ERR K must be an integer".to_string()))?;
|
||||
if cmd[4].to_uppercase() != "VECTOR" {
|
||||
return Err(DBError("ERR LANCE.SEARCH requires VECTOR <f32...>".to_string()));
|
||||
if cmd[4].to_uppercase() != "QUERY" {
|
||||
return Err(DBError("ERR LANCE.SEARCH requires QUERY <text>".to_string()));
|
||||
}
|
||||
let mut i = 5;
|
||||
let mut vector: Vec<f32> = Vec::new();
|
||||
while i < cmd.len() && !["FILTER","RETURN"].contains(&cmd[i].to_uppercase().as_str()) {
|
||||
let v: f32 = cmd[i].parse().map_err(|_| DBError("ERR vector element must be a float32".to_string()))?;
|
||||
vector.push(v);
|
||||
i += 1;
|
||||
if i >= cmd.len() {
|
||||
return Err(DBError("ERR LANCE.SEARCH requires QUERY <text>".to_string()));
|
||||
}
|
||||
let text = cmd[i].clone();
|
||||
i += 1;
|
||||
|
||||
let mut filter: Option<String> = None;
|
||||
let mut return_fields: Option<Vec<String>> = None;
|
||||
while i < cmd.len() {
|
||||
@@ -942,7 +952,7 @@ impl Cmd {
|
||||
_ => { i += 1; }
|
||||
}
|
||||
}
|
||||
Cmd::LanceSearch { name, vector, k, filter, return_fields }
|
||||
Cmd::LanceSearchText { name, text, k, filter, return_fields }
|
||||
}
|
||||
"lance.createindex" => {
|
||||
// LANCE.CREATEINDEX name TYPE t [PARAM k v ...]
|
||||
@@ -962,6 +972,60 @@ impl Cmd {
|
||||
}
|
||||
Cmd::LanceCreateIndex { name, index_type, params }
|
||||
}
|
||||
"lance.embedding" => {
|
||||
// LANCE.EMBEDDING CONFIG SET name PROVIDER p MODEL m [PARAM k v ...]
|
||||
// LANCE.EMBEDDING CONFIG GET name
|
||||
if cmd.len() < 3 || cmd[1].to_uppercase() != "CONFIG" {
|
||||
return Err(DBError("ERR LANCE.EMBEDDING requires CONFIG subcommand".to_string()));
|
||||
}
|
||||
if cmd.len() >= 4 && cmd[2].to_uppercase() == "SET" {
|
||||
if cmd.len() < 8 {
|
||||
return Err(DBError("ERR LANCE.EMBEDDING CONFIG SET requires: SET name PROVIDER p MODEL m [PARAM k v ...]".to_string()));
|
||||
}
|
||||
let name = cmd[3].clone();
|
||||
let mut i = 4;
|
||||
let mut provider: Option<String> = None;
|
||||
let mut model: Option<String> = None;
|
||||
let mut params: Vec<(String, String)> = Vec::new();
|
||||
while i < cmd.len() {
|
||||
match cmd[i].to_uppercase().as_str() {
|
||||
"PROVIDER" => {
|
||||
if i + 1 >= cmd.len() {
|
||||
return Err(DBError("ERR PROVIDER requires a value".to_string()));
|
||||
}
|
||||
provider = Some(cmd[i + 1].clone());
|
||||
i += 2;
|
||||
}
|
||||
"MODEL" => {
|
||||
if i + 1 >= cmd.len() {
|
||||
return Err(DBError("ERR MODEL requires a value".to_string()));
|
||||
}
|
||||
model = Some(cmd[i + 1].clone());
|
||||
i += 2;
|
||||
}
|
||||
"PARAM" => {
|
||||
i += 1;
|
||||
while i + 1 < cmd.len() {
|
||||
params.push((cmd[i].clone(), cmd[i + 1].clone()));
|
||||
i += 2;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Unknown token: skip it (no break); advancing `i` keeps the loop from spinning forever
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
let provider = provider.ok_or_else(|| DBError("ERR missing PROVIDER".to_string()))?;
|
||||
let model = model.ok_or_else(|| DBError("ERR missing MODEL".to_string()))?;
|
||||
Cmd::LanceEmbeddingConfigSet { name, provider, model, params }
|
||||
} else if cmd.len() == 4 && cmd[2].to_uppercase() == "GET" {
|
||||
let name = cmd[3].clone();
|
||||
Cmd::LanceEmbeddingConfigGet { name }
|
||||
} else {
|
||||
return Err(DBError("ERR LANCE.EMBEDDING CONFIG supports: SET ... | GET name".to_string()));
|
||||
}
|
||||
}
|
||||
"lance.list" => {
|
||||
if cmd.len() != 1 {
|
||||
return Err(DBError("ERR LANCE.LIST takes no arguments".to_string()));
|
||||
@@ -1070,8 +1134,10 @@ impl Cmd {
|
||||
| Cmd::Command(..)
|
||||
| Cmd::Info(..)
|
||||
| Cmd::LanceCreate { .. }
|
||||
| Cmd::LanceStore { .. }
|
||||
| Cmd::LanceSearch { .. }
|
||||
| Cmd::LanceStoreText { .. }
|
||||
| Cmd::LanceSearchText { .. }
|
||||
| Cmd::LanceEmbeddingConfigSet { .. }
|
||||
| Cmd::LanceEmbeddingConfigGet { .. }
|
||||
| Cmd::LanceCreateIndex { .. }
|
||||
| Cmd::LanceList
|
||||
| Cmd::LanceInfo { .. }
|
||||
@@ -1104,8 +1170,10 @@ impl Cmd {
|
||||
if !is_lance_backend {
|
||||
match &self {
|
||||
Cmd::LanceCreate { .. }
|
||||
| Cmd::LanceStore { .. }
|
||||
| Cmd::LanceSearch { .. }
|
||||
| Cmd::LanceStoreText { .. }
|
||||
| Cmd::LanceSearchText { .. }
|
||||
| Cmd::LanceEmbeddingConfigSet { .. }
|
||||
| Cmd::LanceEmbeddingConfigGet { .. }
|
||||
| Cmd::LanceCreateIndex { .. }
|
||||
| Cmd::LanceList
|
||||
| Cmd::LanceInfo { .. }
|
||||
@@ -1249,18 +1317,66 @@ impl Cmd {
|
||||
Err(e) => Ok(Protocol::err(&e.0)),
|
||||
}
|
||||
}
|
||||
Cmd::LanceStore { name, id, vector, meta } => {
|
||||
Cmd::LanceEmbeddingConfigSet { name, provider, model, params } => {
|
||||
if !server.has_write_permission() {
|
||||
return Ok(Protocol::err("ERR write permission denied"));
|
||||
}
|
||||
let meta_map: std::collections::HashMap<String, String> = meta.into_iter().collect();
|
||||
match server.lance_store()?.store_vector(&name, &id, vector, meta_map).await {
|
||||
// Map provider string to enum
|
||||
let p_lc = provider.to_lowercase();
|
||||
let prov = match p_lc.as_str() {
|
||||
"test-hash" | "testhash" => EmbeddingProvider::TestHash,
|
||||
"fastembed" | "lancefastembed" => EmbeddingProvider::LanceFastEmbed,
|
||||
"openai" | "lanceopenai" => EmbeddingProvider::LanceOpenAI,
|
||||
other => EmbeddingProvider::LanceOther(other.to_string()),
|
||||
};
|
||||
let cfg = EmbeddingConfig {
|
||||
provider: prov,
|
||||
model,
|
||||
params: params.into_iter().collect(),
|
||||
};
|
||||
match server.set_dataset_embedding_config(&name, &cfg) {
|
||||
Ok(()) => Ok(Protocol::SimpleString("OK".to_string())),
|
||||
Err(e) => Ok(Protocol::err(&e.0)),
|
||||
}
|
||||
}
|
||||
Cmd::LanceSearch { name, vector, k, filter, return_fields } => {
|
||||
match server.lance_store()?.search_vectors(&name, vector, k, filter, return_fields).await {
|
||||
Cmd::LanceEmbeddingConfigGet { name } => {
|
||||
match server.get_dataset_embedding_config(&name) {
|
||||
Ok(cfg) => {
|
||||
let mut arr = Vec::new();
|
||||
arr.push(Protocol::BulkString("provider".to_string()));
|
||||
arr.push(Protocol::BulkString(match cfg.provider {
|
||||
EmbeddingProvider::TestHash => "test-hash".to_string(),
|
||||
EmbeddingProvider::LanceFastEmbed => "lancefastembed".to_string(),
|
||||
EmbeddingProvider::LanceOpenAI => "lanceopenai".to_string(),
|
||||
EmbeddingProvider::LanceOther(ref s) => s.clone(),
|
||||
}));
|
||||
arr.push(Protocol::BulkString("model".to_string()));
|
||||
arr.push(Protocol::BulkString(cfg.model.clone()));
|
||||
arr.push(Protocol::BulkString("params".to_string()));
|
||||
arr.push(Protocol::BulkString(serde_json::to_string(&cfg.params).unwrap_or_else(|_| "{}".to_string())));
|
||||
Ok(Protocol::Array(arr))
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&e.0)),
|
||||
}
|
||||
}
|
||||
Cmd::LanceStoreText { name, id, text, meta } => {
|
||||
if !server.has_write_permission() {
|
||||
return Ok(Protocol::err("ERR write permission denied"));
|
||||
}
|
||||
// Resolve embedder and embed text
|
||||
let embedder = server.get_embedder_for(&name)?;
|
||||
let vector = embedder.embed(&text)?;
|
||||
let meta_map: std::collections::HashMap<String, String> = meta.into_iter().collect();
|
||||
match server.lance_store()?.store_vector(&name, &id, vector, meta_map, Some(text)).await {
|
||||
Ok(()) => Ok(Protocol::SimpleString("OK".to_string())),
|
||||
Err(e) => Ok(Protocol::err(&e.0)),
|
||||
}
|
||||
}
|
||||
Cmd::LanceSearchText { name, text, k, filter, return_fields } => {
|
||||
// Resolve embedder and embed query text
|
||||
let embedder = server.get_embedder_for(&name)?;
|
||||
let qv = embedder.embed(&text)?;
|
||||
match server.lance_store()?.search_vectors(&name, qv, k, filter, return_fields).await {
|
||||
Ok(results) => {
|
||||
// Encode as array of [id, score, [k1, v1, k2, v2, ...]]
|
||||
let mut arr = Vec::new();
|
||||
|
Reference in New Issue
Block a user