...
This commit is contained in:
227
packages/ai/codemonkey/src/lib.rs
Normal file
227
packages/ai/codemonkey/src/lib.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use async_trait::async_trait;
|
||||
use openrouter_rs::{OpenRouterClient, api::chat::*, types::Role, ChatCompletionResponse}; // Added ChatCompletionResponse here
|
||||
use std::env;
|
||||
use std::error::Error;
|
||||
|
||||
// Re-export Message and MessageRole for easier use in client code
|
||||
pub use openrouter_rs::api::chat::Message;
|
||||
pub use openrouter_rs::types::Role as MessageRole;
|
||||
// Removed the problematic import for ChatCompletionResponse
|
||||
// pub use openrouter_rs::api::chat::chat_completion::ChatCompletionResponse;
|
||||
|
||||
#[async_trait]
|
||||
pub trait AIProvider {
|
||||
async fn completion(
|
||||
&mut self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn Error>>;
|
||||
}
|
||||
|
||||
pub struct CompletionRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
pub temperature: Option<f64>,
|
||||
pub max_tokens: Option<i64>,
|
||||
pub top_p: Option<f64>,
|
||||
pub stream: Option<bool>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
pub struct CompletionRequestBuilder<'a> {
|
||||
provider: &'a mut dyn AIProvider,
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
temperature: Option<f64>,
|
||||
max_tokens: Option<i64>,
|
||||
top_p: Option<f64>,
|
||||
stream: Option<bool>,
|
||||
stop: Option<Vec<String>>,
|
||||
provider_type: AIProviderType,
|
||||
}
|
||||
|
||||
impl<'a> CompletionRequestBuilder<'a> {
|
||||
pub fn new(provider: &'a mut dyn AIProvider, model: String, messages: Vec<Message>, provider_type: AIProviderType) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
model,
|
||||
messages,
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
top_p: None,
|
||||
stream: None,
|
||||
stop: None,
|
||||
provider_type,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn temperature(mut self, temperature: f64) -> Self {
|
||||
self.temperature = Some(temperature);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_tokens(mut self, max_tokens: i64) -> Self {
|
||||
self.max_tokens = Some(max_tokens);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn top_p(mut self, top_p: f64) -> Self {
|
||||
self.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stream(mut self, stream: bool) -> Self {
|
||||
self.stream = Some(stream);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stop(mut self, stop: Vec<String>) -> Self {
|
||||
self.stop = Some(stop);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn completion(self) -> Result<ChatCompletionResponse, Box<dyn Error>> {
|
||||
let request = CompletionRequest {
|
||||
model: self.model,
|
||||
messages: self.messages,
|
||||
temperature: self.temperature,
|
||||
max_tokens: self.max_tokens,
|
||||
top_p: self.top_p,
|
||||
stream: self.stream,
|
||||
stop: self.stop,
|
||||
};
|
||||
self.provider.completion(request).await
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GroqAIProvider {
|
||||
client: OpenRouterClient,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AIProvider for GroqAIProvider {
|
||||
async fn completion(
|
||||
&mut self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn Error>> {
|
||||
let chat_request = ChatCompletionRequest::builder()
|
||||
.model(request.model)
|
||||
.messages(request.messages)
|
||||
.temperature(request.temperature.unwrap_or(1.0))
|
||||
.max_tokens(request.max_tokens.map(|x| x as u32).unwrap_or(2048))
|
||||
.top_p(request.top_p.unwrap_or(1.0))
|
||||
.stream(request.stream.unwrap_or(false)) // Corrected to field assignment
|
||||
.stop(request.stop.unwrap_or_default())
|
||||
.build()?;
|
||||
|
||||
let result = self.client.send_chat_completion(&chat_request).await?;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OpenAIProvider {
|
||||
client: OpenRouterClient,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AIProvider for OpenAIProvider {
|
||||
async fn completion(
|
||||
&mut self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn Error>> {
|
||||
let chat_request = ChatCompletionRequest::builder()
|
||||
.model(request.model)
|
||||
.messages(request.messages)
|
||||
.temperature(request.temperature.unwrap_or(1.0))
|
||||
.max_tokens(request.max_tokens.map(|x| x as u32).unwrap_or(2048))
|
||||
.top_p(request.top_p.unwrap_or(1.0))
|
||||
.stream(request.stream.unwrap_or(false)) // Corrected to field assignment
|
||||
.stop(request.stop.unwrap_or_default())
|
||||
.build()?;
|
||||
|
||||
let result = self.client.send_chat_completion(&chat_request).await?;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OpenRouterAIProvider {
|
||||
client: OpenRouterClient,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AIProvider for OpenRouterAIProvider {
|
||||
async fn completion(
|
||||
&mut self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn Error>> {
|
||||
let chat_request = ChatCompletionRequest::builder()
|
||||
.model(request.model)
|
||||
.messages(request.messages)
|
||||
.temperature(request.temperature.unwrap_or(1.0))
|
||||
.max_tokens(request.max_tokens.map(|x| x as u32).unwrap_or(2048))
|
||||
.top_p(request.top_p.unwrap_or(1.0))
|
||||
.stream(request.stream.unwrap_or(false)) // Corrected to field assignment
|
||||
.stop(request.stop.unwrap_or_default())
|
||||
.build()?;
|
||||
|
||||
let result = self.client.send_chat_completion(&chat_request).await?;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CerebrasAIProvider {
|
||||
client: OpenRouterClient,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AIProvider for CerebrasAIProvider {
|
||||
async fn completion(
|
||||
&mut self,
|
||||
request: CompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn Error>> {
|
||||
let chat_request = ChatCompletionRequest::builder()
|
||||
.model(request.model)
|
||||
.messages(request.messages)
|
||||
.temperature(request.temperature.unwrap_or(1.0))
|
||||
.max_tokens(request.max_tokens.map(|x| x as u32).unwrap_or(2048))
|
||||
.top_p(request.top_p.unwrap_or(1.0))
|
||||
.stream(request.stream.unwrap_or(false)) // Corrected to field assignment
|
||||
.stop(request.stop.unwrap_or_default())
|
||||
.build()?;
|
||||
|
||||
let result = self.client.send_chat_completion(&chat_request).await?;
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum AIProviderType {
|
||||
Groq,
|
||||
OpenAI,
|
||||
OpenRouter,
|
||||
Cerebras,
|
||||
}
|
||||
|
||||
pub fn create_ai_provider(provider_type: AIProviderType) -> Result<(Box<dyn AIProvider>, AIProviderType), Box<dyn Error>> {
|
||||
match provider_type {
|
||||
AIProviderType::Groq => {
|
||||
let api_key = env::var("GROQ_API_KEY")?;
|
||||
let client = OpenRouterClient::builder().api_key(api_key).build()?;
|
||||
Ok((Box::new(GroqAIProvider { client }), AIProviderType::Groq))
|
||||
}
|
||||
AIProviderType::OpenAI => {
|
||||
let api_key = env::var("OPENAI_API_KEY")?;
|
||||
let client = OpenRouterClient::builder().api_key(api_key).build()?;
|
||||
Ok((Box::new(OpenAIProvider { client }), AIProviderType::OpenAI))
|
||||
}
|
||||
AIProviderType::OpenRouter => {
|
||||
let api_key = env::var("OPENROUTER_API_KEY")?;
|
||||
let client = OpenRouterClient::builder().api_key(api_key).build()?;
|
||||
Ok((Box::new(OpenRouterAIProvider { client }), AIProviderType::OpenRouter))
|
||||
}
|
||||
AIProviderType::Cerebras => {
|
||||
let api_key = env::var("CEREBRAS_API_KEY")?;
|
||||
let client = OpenRouterClient::builder().api_key(api_key).build()?;
|
||||
Ok((Box::new(CerebrasAIProvider { client }), AIProviderType::Cerebras))
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user