...
This commit is contained in:
165
src/cmd.rs
165
src/cmd.rs
@@ -1,6 +1,8 @@
|
||||
use crate::{error::DBError, protocol::Protocol, server::Server};
|
||||
use tokio::time::{timeout, Duration};
|
||||
use futures::future::select_all;
|
||||
use std::sync::Arc;
|
||||
use base64::Engine;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Cmd {
|
||||
@@ -1006,11 +1008,11 @@ impl Cmd {
|
||||
Cmd::AgeList => Ok(crate::age::cmd_age_list(server).await),
|
||||
|
||||
// Lance vector database commands
|
||||
Cmd::LanceCreate { dataset, dim, schema } => lance_create_cmd(server, &dataset, *dim, &schema).await,
|
||||
Cmd::LanceStore { dataset, text, image_base64, metadata } => lance_store_cmd(server, &dataset, text.as_deref(), image_base64.as_deref(), metadata).await,
|
||||
Cmd::LanceSearch { dataset, vector, k, nprobes, refine_factor } => lance_search_cmd(server, &dataset, vector, *k, nprobes, refine_factor).await,
|
||||
Cmd::LanceSearchText { dataset, query_text, k, nprobes, refine_factor } => lance_search_text_cmd(server, &dataset, &query_text, *k, nprobes, refine_factor).await,
|
||||
Cmd::LanceEmbedText { texts } => lance_embed_text_cmd(server, texts).await,
|
||||
Cmd::LanceCreate { dataset, dim, schema } => lance_create_cmd(server, &dataset, dim, &schema).await,
|
||||
Cmd::LanceStore { dataset, text, image_base64, metadata } => lance_store_cmd(server, &dataset, text.as_deref(), image_base64.as_deref(), &metadata).await,
|
||||
Cmd::LanceSearch { dataset, vector, k, nprobes, refine_factor } => lance_search_cmd(server, &dataset, &vector, k, nprobes, refine_factor).await,
|
||||
Cmd::LanceSearchText { dataset, query_text, k, nprobes, refine_factor } => lance_search_text_cmd(server, &dataset, &query_text, k, nprobes, refine_factor).await,
|
||||
Cmd::LanceEmbedText { texts } => lance_embed_text_cmd(server, &texts).await,
|
||||
Cmd::LanceCreateIndex { dataset, index_type, num_partitions, num_sub_vectors } => lance_create_index_cmd(server, &dataset, &index_type, num_partitions, num_sub_vectors).await,
|
||||
Cmd::LanceList => lance_list_cmd(server).await,
|
||||
Cmd::LanceDrop { dataset } => lance_drop_cmd(server, &dataset).await,
|
||||
@@ -1800,6 +1802,36 @@ fn command_cmd(args: &[String]) -> Result<Protocol, DBError> {
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create Arrow schema from field specifications
|
||||
fn create_schema_from_fields(dim: usize, fields: &[(String, String)]) -> arrow::datatypes::Schema {
|
||||
let mut schema_fields = Vec::new();
|
||||
|
||||
// Always add the vector field first
|
||||
let vector_field = arrow::datatypes::Field::new(
|
||||
"vector",
|
||||
arrow::datatypes::DataType::FixedSizeList(
|
||||
Arc::new(arrow::datatypes::Field::new("item", arrow::datatypes::DataType::Float32, true)),
|
||||
dim as i32
|
||||
),
|
||||
false
|
||||
);
|
||||
schema_fields.push(vector_field);
|
||||
|
||||
// Add custom fields
|
||||
for (name, field_type) in fields {
|
||||
let data_type = match field_type.to_lowercase().as_str() {
|
||||
"string" | "text" => arrow::datatypes::DataType::Utf8,
|
||||
"int" | "integer" => arrow::datatypes::DataType::Int64,
|
||||
"float" => arrow::datatypes::DataType::Float64,
|
||||
"bool" | "boolean" => arrow::datatypes::DataType::Boolean,
|
||||
_ => arrow::datatypes::DataType::Utf8, // Default to string
|
||||
};
|
||||
schema_fields.push(arrow::datatypes::Field::new(name, data_type, true));
|
||||
}
|
||||
|
||||
arrow::datatypes::Schema::new(schema_fields)
|
||||
}
|
||||
|
||||
// Lance vector database command implementations
|
||||
async fn lance_create_cmd(
|
||||
server: &Server,
|
||||
@@ -1809,12 +1841,12 @@ async fn lance_create_cmd(
|
||||
) -> Result<Protocol, DBError> {
|
||||
match server.lance_store() {
|
||||
Ok(lance_store) => {
|
||||
match lance_store.create_dataset(dataset, dim, schema.to_vec()).await {
|
||||
match lance_store.create_dataset(dataset, create_schema_from_fields(dim, schema)).await {
|
||||
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR {}", e)))),
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR Lance store not available: {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR Lance store not available: {}", e)))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1827,12 +1859,14 @@ async fn lance_store_cmd(
|
||||
) -> Result<Protocol, DBError> {
|
||||
match server.lance_store() {
|
||||
Ok(lance_store) => {
|
||||
match lance_store.store_data(dataset, text, image_base64, metadata.clone()).await {
|
||||
match lance_store.store_multimodal(server, dataset, text.map(|s| s.to_string()),
|
||||
image_base64.and_then(|s| base64::engine::general_purpose::STANDARD.decode(s).ok()),
|
||||
metadata.clone()).await {
|
||||
Ok(id) => Ok(Protocol::BulkString(id)),
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR {}", e)))),
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR Lance store not available: {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR Lance store not available: {}", e)))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1846,24 +1880,14 @@ async fn lance_search_cmd(
|
||||
) -> Result<Protocol, DBError> {
|
||||
match server.lance_store() {
|
||||
Ok(lance_store) => {
|
||||
match lance_store.search_vector(dataset, vector, k, nprobes, refine_factor).await {
|
||||
match lance_store.search_vectors(dataset, vector.to_vec(), k, nprobes, refine_factor).await {
|
||||
Ok(results) => {
|
||||
let mut response = Vec::new();
|
||||
for result in results {
|
||||
for (distance, metadata) in results {
|
||||
let mut item = Vec::new();
|
||||
item.push(Protocol::BulkString("id".to_string()));
|
||||
item.push(Protocol::BulkString(result.id));
|
||||
item.push(Protocol::BulkString("score".to_string()));
|
||||
item.push(Protocol::BulkString(result.score.to_string()));
|
||||
if let Some(text) = result.text {
|
||||
item.push(Protocol::BulkString("text".to_string()));
|
||||
item.push(Protocol::BulkString(text));
|
||||
}
|
||||
if let Some(image) = result.image_base64 {
|
||||
item.push(Protocol::BulkString("image".to_string()));
|
||||
item.push(Protocol::BulkString(image));
|
||||
}
|
||||
for (key, value) in result.metadata {
|
||||
item.push(Protocol::BulkString("distance".to_string()));
|
||||
item.push(Protocol::BulkString(distance.to_string()));
|
||||
for (key, value) in metadata {
|
||||
item.push(Protocol::BulkString(key));
|
||||
item.push(Protocol::BulkString(value));
|
||||
}
|
||||
@@ -1871,10 +1895,10 @@ async fn lance_search_cmd(
|
||||
}
|
||||
Ok(Protocol::Array(response))
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR {}", e)))),
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR Lance store not available: {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR Lance store not available: {}", e)))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1888,24 +1912,14 @@ async fn lance_search_text_cmd(
|
||||
) -> Result<Protocol, DBError> {
|
||||
match server.lance_store() {
|
||||
Ok(lance_store) => {
|
||||
match lance_store.search_text(dataset, query_text, k, nprobes, refine_factor).await {
|
||||
match lance_store.search_with_text(server, dataset, query_text.to_string(), k, nprobes, refine_factor).await {
|
||||
Ok(results) => {
|
||||
let mut response = Vec::new();
|
||||
for result in results {
|
||||
for (distance, metadata) in results {
|
||||
let mut item = Vec::new();
|
||||
item.push(Protocol::BulkString("id".to_string()));
|
||||
item.push(Protocol::BulkString(result.id));
|
||||
item.push(Protocol::BulkString("score".to_string()));
|
||||
item.push(Protocol::BulkString(result.score.to_string()));
|
||||
if let Some(text) = result.text {
|
||||
item.push(Protocol::BulkString("text".to_string()));
|
||||
item.push(Protocol::BulkString(text));
|
||||
}
|
||||
if let Some(image) = result.image_base64 {
|
||||
item.push(Protocol::BulkString("image".to_string()));
|
||||
item.push(Protocol::BulkString(image));
|
||||
}
|
||||
for (key, value) in result.metadata {
|
||||
item.push(Protocol::BulkString("distance".to_string()));
|
||||
item.push(Protocol::BulkString(distance.to_string()));
|
||||
for (key, value) in metadata {
|
||||
item.push(Protocol::BulkString(key));
|
||||
item.push(Protocol::BulkString(value));
|
||||
}
|
||||
@@ -1913,10 +1927,26 @@ async fn lance_search_text_cmd(
|
||||
}
|
||||
Ok(Protocol::Array(response))
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR {}", e)))),
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR Lance store not available: {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR Lance store not available: {}", e)))),
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to sanitize error messages for Redis protocol
|
||||
fn sanitize_error_message(msg: &str) -> String {
|
||||
// Remove newlines, carriage returns, and limit length
|
||||
let sanitized = msg
|
||||
.replace('\n', " ")
|
||||
.replace('\r', " ")
|
||||
.replace('\t', " ");
|
||||
|
||||
// Limit to 200 characters to avoid overly long error messages
|
||||
if sanitized.len() > 200 {
|
||||
format!("{}...", &sanitized[..197])
|
||||
} else {
|
||||
sanitized
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1926,7 +1956,7 @@ async fn lance_embed_text_cmd(
|
||||
) -> Result<Protocol, DBError> {
|
||||
match server.lance_store() {
|
||||
Ok(lance_store) => {
|
||||
match lance_store.embed_texts(texts).await {
|
||||
match lance_store.embed_text(server, texts.to_vec()).await {
|
||||
Ok(embeddings) => {
|
||||
let mut response = Vec::new();
|
||||
for embedding in embeddings {
|
||||
@@ -1938,10 +1968,10 @@ async fn lance_embed_text_cmd(
|
||||
}
|
||||
Ok(Protocol::Array(response))
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR {}", e)))),
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR Lance store not available: {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR Lance store not available: {}", e)))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1956,10 +1986,10 @@ async fn lance_create_index_cmd(
|
||||
Ok(lance_store) => {
|
||||
match lance_store.create_index(dataset, index_type, num_partitions, num_sub_vectors).await {
|
||||
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR {}", e)))),
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR Lance store not available: {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR Lance store not available: {}", e)))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1974,10 +2004,10 @@ async fn lance_list_cmd(server: &Server) -> Result<Protocol, DBError> {
|
||||
.collect();
|
||||
Ok(Protocol::Array(response))
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR {}", e)))),
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR Lance store not available: {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR Lance store not available: {}", e)))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1986,41 +2016,28 @@ async fn lance_drop_cmd(server: &Server, dataset: &str) -> Result<Protocol, DBEr
|
||||
Ok(lance_store) => {
|
||||
match lance_store.drop_dataset(dataset).await {
|
||||
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR {}", e)))),
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR Lance store not available: {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR Lance store not available: {}", e)))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn lance_info_cmd(server: &Server, dataset: &str) -> Result<Protocol, DBError> {
|
||||
match server.lance_store() {
|
||||
Ok(lance_store) => {
|
||||
match lance_store.dataset_info(dataset).await {
|
||||
match lance_store.get_dataset_info(dataset).await {
|
||||
Ok(info) => {
|
||||
let mut response = Vec::new();
|
||||
response.push(Protocol::BulkString("name".to_string()));
|
||||
response.push(Protocol::BulkString(info.name));
|
||||
response.push(Protocol::BulkString("dimension".to_string()));
|
||||
response.push(Protocol::BulkString(info.dimension.to_string()));
|
||||
response.push(Protocol::BulkString("num_rows".to_string()));
|
||||
response.push(Protocol::BulkString(info.num_rows.to_string()));
|
||||
response.push(Protocol::BulkString("schema".to_string()));
|
||||
let schema_items: Vec<Protocol> = info.schema
|
||||
.into_iter()
|
||||
.map(|(field, field_type)| {
|
||||
Protocol::Array(vec![
|
||||
Protocol::BulkString(field),
|
||||
Protocol::BulkString(field_type),
|
||||
])
|
||||
})
|
||||
.collect();
|
||||
response.push(Protocol::Array(schema_items));
|
||||
for (key, value) in info {
|
||||
response.push(Protocol::BulkString(key));
|
||||
response.push(Protocol::BulkString(value));
|
||||
}
|
||||
Ok(Protocol::Array(response))
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR {}", e)))),
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(Protocol::err(&format!("ERR Lance store not available: {}", e))),
|
||||
Err(e) => Ok(Protocol::err(&sanitize_error_message(&format!("ERR Lance store not available: {}", e)))),
|
||||
}
|
||||
}
|
||||
|
@@ -3,9 +3,10 @@ use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use arrow::array::{Float32Array, StringArray, ArrayRef, FixedSizeListArray};
|
||||
use arrow::datatypes::{DataType, Field, Schema, FieldRef};
|
||||
use arrow::record_batch::RecordBatch;
|
||||
use arrow::array::{Float32Array, StringArray, ArrayRef, FixedSizeListArray, Array};
|
||||
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
|
||||
use arrow::record_batch::{RecordBatch, RecordBatchReader};
|
||||
use arrow::error::ArrowError;
|
||||
use lance::dataset::{Dataset, WriteParams, WriteMode};
|
||||
use lance::index::vector::VectorIndexParams;
|
||||
use lance_index::vector::pq::PQBuildParams;
|
||||
@@ -13,10 +14,39 @@ use lance_index::vector::ivf::IvfBuildParams;
|
||||
use lance_index::DatasetIndexExt;
|
||||
use lance_linalg::distance::MetricType;
|
||||
use futures::TryStreamExt;
|
||||
use base64::Engine;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::error::DBError;
|
||||
use crate::protocol::Protocol;
|
||||
|
||||
// Simple RecordBatchReader implementation for Vec<RecordBatch>
|
||||
struct VecRecordBatchReader {
|
||||
batches: std::vec::IntoIter<Result<RecordBatch, ArrowError>>,
|
||||
}
|
||||
|
||||
impl VecRecordBatchReader {
|
||||
fn new(batches: Vec<RecordBatch>) -> Self {
|
||||
let result_batches = batches.into_iter().map(Ok).collect::<Vec<_>>();
|
||||
Self {
|
||||
batches: result_batches.into_iter(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for VecRecordBatchReader {
|
||||
type Item = Result<RecordBatch, ArrowError>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.batches.next()
|
||||
}
|
||||
}
|
||||
|
||||
impl RecordBatchReader for VecRecordBatchReader {
|
||||
fn schema(&self) -> SchemaRef {
|
||||
// This is a simplified implementation - in practice you'd want to store the schema
|
||||
Arc::new(Schema::empty())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct EmbeddingRequest {
|
||||
@@ -32,6 +62,18 @@ struct EmbeddingResponse {
|
||||
usage: Option<HashMap<String, u32>>,
|
||||
}
|
||||
|
||||
// Ollama-specific request/response structures
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct OllamaEmbeddingRequest {
|
||||
model: String,
|
||||
prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct OllamaEmbeddingResponse {
|
||||
embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
pub struct LanceStore {
|
||||
datasets: Arc<RwLock<HashMap<String, Arc<Dataset>>>>,
|
||||
data_dir: PathBuf,
|
||||
@@ -56,64 +98,104 @@ impl LanceStore {
|
||||
})
|
||||
}
|
||||
|
||||
/// Get embedding service URL from Redis config
|
||||
/// Get embedding service URL from Redis config, default to local Ollama
|
||||
async fn get_embedding_url(&self, server: &crate::server::Server) -> Result<String, DBError> {
|
||||
// Get the embedding URL from Redis config
|
||||
let key = "config:core:aiembed:url";
|
||||
|
||||
// Use HGET to retrieve the URL from Redis hash
|
||||
let cmd = crate::cmd::Cmd::HGet(key.to_string(), "url".to_string());
|
||||
|
||||
// Execute command to get the config
|
||||
let result = cmd.run(&mut server.clone()).await?;
|
||||
|
||||
match result {
|
||||
Protocol::BulkString(url) => Ok(url),
|
||||
Protocol::SimpleString(url) => Ok(url),
|
||||
Protocol::Nil => Err(DBError(
|
||||
"Embedding service URL not configured. Set it with: HSET config:core:aiembed:url url <YOUR_EMBEDDING_SERVICE_URL>".to_string()
|
||||
)),
|
||||
_ => Err(DBError("Invalid embedding URL configuration".to_string())),
|
||||
// Get the embedding URL from Redis config directly from storage
|
||||
let storage = server.current_storage()?;
|
||||
match storage.hget("config:core:aiembed", "url")? {
|
||||
Some(url) => Ok(url),
|
||||
None => Ok("http://localhost:11434".to_string()), // Default to local Ollama
|
||||
}
|
||||
}
|
||||
|
||||
/// Call external embedding service
|
||||
/// Check if we're using Ollama (default) or custom embedding service
|
||||
async fn is_ollama_service(&self, server: &crate::server::Server) -> Result<bool, DBError> {
|
||||
let url = self.get_embedding_url(server).await?;
|
||||
Ok(url.contains("localhost:11434") || url.contains("127.0.0.1:11434"))
|
||||
}
|
||||
|
||||
/// Call external embedding service (Ollama or custom)
|
||||
async fn call_embedding_service(
|
||||
&self,
|
||||
server: &crate::server::Server,
|
||||
texts: Option<Vec<String>>,
|
||||
images: Option<Vec<String>>,
|
||||
) -> Result<Vec<Vec<f32>>, DBError> {
|
||||
let url = self.get_embedding_url(server).await?;
|
||||
let base_url = self.get_embedding_url(server).await?;
|
||||
let is_ollama = self.is_ollama_service(server).await?;
|
||||
|
||||
let request = EmbeddingRequest {
|
||||
texts,
|
||||
images,
|
||||
model: None, // Let the service use its default
|
||||
};
|
||||
|
||||
let response = self.http_client
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| DBError(format!("Failed to call embedding service: {}", e)))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
return Err(DBError(format!(
|
||||
"Embedding service returned error {}: {}",
|
||||
status, error_text
|
||||
)));
|
||||
if is_ollama {
|
||||
// Use Ollama API format
|
||||
if let Some(texts) = texts {
|
||||
let mut embeddings = Vec::new();
|
||||
for text in texts {
|
||||
let url = format!("{}/api/embeddings", base_url);
|
||||
let request = OllamaEmbeddingRequest {
|
||||
model: "nomic-embed-text".to_string(),
|
||||
prompt: text,
|
||||
};
|
||||
|
||||
let response = self.http_client
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| DBError(format!("Failed to call Ollama embedding service: {}", e)))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
return Err(DBError(format!(
|
||||
"Ollama embedding service returned error {}: {}",
|
||||
status, error_text
|
||||
)));
|
||||
}
|
||||
|
||||
let ollama_response: OllamaEmbeddingResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| DBError(format!("Failed to parse Ollama embedding response: {}", e)))?;
|
||||
|
||||
embeddings.push(ollama_response.embedding);
|
||||
}
|
||||
Ok(embeddings)
|
||||
} else if let Some(_images) = images {
|
||||
// Ollama doesn't support image embeddings with this API yet
|
||||
Err(DBError("Image embeddings not supported with Ollama. Please configure a custom embedding service.".to_string()))
|
||||
} else {
|
||||
Err(DBError("No text or images provided for embedding".to_string()))
|
||||
}
|
||||
} else {
|
||||
// Use custom embedding service API format
|
||||
let request = EmbeddingRequest {
|
||||
texts,
|
||||
images,
|
||||
model: None, // Let the service use its default
|
||||
};
|
||||
|
||||
let response = self.http_client
|
||||
.post(&base_url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| DBError(format!("Failed to call embedding service: {}", e)))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
return Err(DBError(format!(
|
||||
"Embedding service returned error {}: {}",
|
||||
status, error_text
|
||||
)));
|
||||
}
|
||||
|
||||
let embedding_response: EmbeddingResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| DBError(format!("Failed to parse embedding response: {}", e)))?;
|
||||
|
||||
Ok(embedding_response.embeddings)
|
||||
}
|
||||
|
||||
let embedding_response: EmbeddingResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| DBError(format!("Failed to parse embedding response: {}", e)))?;
|
||||
|
||||
Ok(embedding_response.embeddings)
|
||||
}
|
||||
|
||||
pub async fn embed_text(
|
||||
@@ -162,10 +244,11 @@ impl LanceStore {
|
||||
|
||||
// Create an empty RecordBatch with the schema
|
||||
let empty_batch = RecordBatch::new_empty(Arc::new(schema));
|
||||
let batches = vec![empty_batch];
|
||||
|
||||
// Use RecordBatchReader for Lance 0.33
|
||||
let reader = VecRecordBatchReader::new(vec![empty_batch]);
|
||||
let dataset = Dataset::write(
|
||||
batches,
|
||||
reader,
|
||||
dataset_path.to_str().unwrap(),
|
||||
Some(write_params)
|
||||
).await
|
||||
@@ -186,7 +269,7 @@ impl LanceStore {
|
||||
let dataset_path = self.data_dir.join(format!("{}.lance", dataset_name));
|
||||
|
||||
// Open or get cached dataset
|
||||
let dataset = self.get_or_open_dataset(dataset_name).await?;
|
||||
let _dataset = self.get_or_open_dataset(dataset_name).await?;
|
||||
|
||||
// Build RecordBatch
|
||||
let num_vectors = vectors.len();
|
||||
@@ -200,10 +283,13 @@ impl LanceStore {
|
||||
|
||||
// Flatten vectors
|
||||
let flat_vectors: Vec<f32> = vectors.into_iter().flatten().collect();
|
||||
let vector_array = Float32Array::from(flat_vectors);
|
||||
let vector_array = arrow::array::FixedSizeListArray::try_new_from_values(
|
||||
vector_array,
|
||||
dim as i32
|
||||
let values_array = Float32Array::from(flat_vectors);
|
||||
let field = Arc::new(Field::new("item", DataType::Float32, true));
|
||||
let vector_array = FixedSizeListArray::try_new(
|
||||
field,
|
||||
dim as i32,
|
||||
Arc::new(values_array),
|
||||
None
|
||||
).map_err(|e| DBError(format!("Failed to create vector array: {}", e)))?;
|
||||
|
||||
let mut arrays: Vec<ArrayRef> = vec![Arc::new(vector_array)];
|
||||
@@ -241,8 +327,9 @@ impl LanceStore {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let reader = VecRecordBatchReader::new(vec![batch]);
|
||||
Dataset::write(
|
||||
vec![batch],
|
||||
reader,
|
||||
dataset_path.to_str().unwrap(),
|
||||
Some(write_params)
|
||||
).await
|
||||
@@ -261,25 +348,27 @@ impl LanceStore {
|
||||
query_vector: Vec<f32>,
|
||||
k: usize,
|
||||
nprobes: Option<usize>,
|
||||
refine_factor: Option<usize>,
|
||||
_refine_factor: Option<usize>,
|
||||
) -> Result<Vec<(f32, HashMap<String, String>)>, DBError> {
|
||||
let dataset = self.get_or_open_dataset(dataset_name).await?;
|
||||
|
||||
// Build query
|
||||
let query_array = Float32Array::from(query_vector.clone());
|
||||
let mut query = dataset.scan();
|
||||
query = query.nearest(
|
||||
query.nearest(
|
||||
"vector",
|
||||
&query_vector,
|
||||
&query_array,
|
||||
k,
|
||||
).map_err(|e| DBError(format!("Failed to build search query: {}", e)))?;
|
||||
|
||||
if let Some(nprobes) = nprobes {
|
||||
query = query.nprobes(nprobes);
|
||||
query.nprobs(nprobes);
|
||||
}
|
||||
|
||||
if let Some(refine) = refine_factor {
|
||||
query = query.refine_factor(refine);
|
||||
}
|
||||
// Note: refine_factor might not be available in this Lance version
|
||||
// if let Some(refine) = refine_factor {
|
||||
// query.refine_factor(refine);
|
||||
// }
|
||||
|
||||
// Execute search
|
||||
let results = query
|
||||
@@ -399,33 +488,41 @@ impl LanceStore {
|
||||
num_partitions: Option<usize>,
|
||||
num_sub_vectors: Option<usize>,
|
||||
) -> Result<(), DBError> {
|
||||
let dataset = self.get_or_open_dataset(dataset_name).await?;
|
||||
|
||||
let mut params = VectorIndexParams::default();
|
||||
let _dataset = self.get_or_open_dataset(dataset_name).await?;
|
||||
|
||||
match index_type.to_uppercase().as_str() {
|
||||
"IVF_PQ" => {
|
||||
params.ivf = IvfBuildParams {
|
||||
let ivf_params = IvfBuildParams {
|
||||
num_partitions: num_partitions.unwrap_or(256),
|
||||
..Default::default()
|
||||
};
|
||||
params.pq = PQBuildParams {
|
||||
let pq_params = PQBuildParams {
|
||||
num_sub_vectors: num_sub_vectors.unwrap_or(16),
|
||||
..Default::default()
|
||||
};
|
||||
let params = VectorIndexParams::with_ivf_pq_params(
|
||||
MetricType::L2,
|
||||
ivf_params,
|
||||
pq_params,
|
||||
);
|
||||
|
||||
// Get a mutable reference to the dataset
|
||||
let mut dataset_mut = Dataset::open(self.data_dir.join(format!("{}.lance", dataset_name)).to_str().unwrap())
|
||||
.await
|
||||
.map_err(|e| DBError(format!("Failed to open dataset for indexing: {}", e)))?;
|
||||
|
||||
dataset_mut.create_index(
|
||||
&["vector"],
|
||||
lance_index::IndexType::Vector,
|
||||
None,
|
||||
¶ms,
|
||||
true
|
||||
).await
|
||||
.map_err(|e| DBError(format!("Failed to create index: {}", e)))?;
|
||||
}
|
||||
_ => return Err(DBError(format!("Unsupported index type: {}", index_type))),
|
||||
}
|
||||
|
||||
dataset.create_index(
|
||||
&["vector"],
|
||||
lance::index::IndexType::Vector,
|
||||
None,
|
||||
¶ms,
|
||||
true
|
||||
).await
|
||||
.map_err(|e| DBError(format!("Failed to create index: {}", e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -496,14 +593,14 @@ impl LanceStore {
|
||||
|
||||
let mut info = HashMap::new();
|
||||
info.insert("name".to_string(), name.to_string());
|
||||
info.insert("version".to_string(), dataset.version().to_string());
|
||||
info.insert("num_rows".to_string(), dataset.count_rows().await?.to_string());
|
||||
info.insert("version".to_string(), dataset.version().version.to_string());
|
||||
info.insert("num_rows".to_string(), dataset.count_rows(None).await?.to_string());
|
||||
|
||||
// Get schema info
|
||||
let schema = dataset.schema();
|
||||
let fields: Vec<String> = schema.fields()
|
||||
let fields: Vec<String> = schema.fields
|
||||
.iter()
|
||||
.map(|f| format!("{}:{}", f.name(), f.data_type()))
|
||||
.map(|f| format!("{}:{}", f.name, f.data_type()))
|
||||
.collect();
|
||||
info.insert("schema".to_string(), fields.join(", "));
|
||||
|
||||
|
Reference in New Issue
Block a user