feat: Enhance request management in SigSocket client with new methods and structures
This commit is contained in:
		| @@ -1,9 +1,18 @@ | ||||
| //! Main client interface for sigsocket communication | ||||
|  | ||||
| #[cfg(target_arch = "wasm32")] | ||||
| use alloc::{string::String, vec::Vec, boxed::Box}; | ||||
| use alloc::{string::String, vec::Vec, boxed::Box, string::ToString}; | ||||
|  | ||||
| #[cfg(not(target_arch = "wasm32"))] | ||||
| use std::collections::HashMap; | ||||
|  | ||||
| #[cfg(target_arch = "wasm32")] | ||||
| use alloc::collections::BTreeMap as HashMap; | ||||
|  | ||||
| use crate::{SignRequest, SignResponse, Result, SigSocketError}; | ||||
| use crate::protocol::ManagedSignRequest; | ||||
|  | ||||
|  | ||||
|  | ||||
| /// Connection state of the sigsocket client | ||||
| #[derive(Debug, Clone, Copy, PartialEq, Eq)] | ||||
| @@ -67,6 +76,10 @@ pub struct SigSocketClient { | ||||
|     state: ConnectionState, | ||||
|     /// Sign request handler | ||||
|     sign_handler: Option<Box<dyn SignRequestHandler>>, | ||||
|     /// Pending sign requests managed by the client | ||||
|     pending_requests: HashMap<String, ManagedSignRequest>, | ||||
|     /// Connected public key (hex-encoded) - set when connection is established | ||||
|     connected_public_key: Option<String>, | ||||
|     /// Platform-specific implementation | ||||
|     #[cfg(not(target_arch = "wasm32"))] | ||||
|     inner: Option<crate::native::NativeClient>, | ||||
| @@ -100,14 +113,16 @@ impl SigSocketClient { | ||||
|             public_key, | ||||
|             state: ConnectionState::Disconnected, | ||||
|             sign_handler: None, | ||||
|             pending_requests: HashMap::new(), | ||||
|             connected_public_key: None, | ||||
|             inner: None, | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     /// Set the sign request handler | ||||
|     ///  | ||||
|     /// | ||||
|     /// This handler will be called whenever the server sends a signature request. | ||||
|     ///  | ||||
|     /// | ||||
|     /// # Arguments | ||||
|     /// * `handler` - Implementation of SignRequestHandler trait | ||||
|     pub fn set_sign_handler<H>(&mut self, handler: H) | ||||
| @@ -117,6 +132,8 @@ impl SigSocketClient { | ||||
|         self.sign_handler = Some(Box::new(handler)); | ||||
|     } | ||||
|  | ||||
|  | ||||
|  | ||||
|     /// Get the current connection state | ||||
|     pub fn state(&self) -> ConnectionState { | ||||
|         self.state | ||||
| @@ -136,6 +153,109 @@ impl SigSocketClient { | ||||
|     pub fn url(&self) -> &str { | ||||
|         &self.url | ||||
|     } | ||||
|  | ||||
|     /// Get the connected public key (if connected) | ||||
|     pub fn connected_public_key(&self) -> Option<&str> { | ||||
|         self.connected_public_key.as_deref() | ||||
|     } | ||||
|  | ||||
|     // === Request Management Methods === | ||||
|  | ||||
|     /// Add a pending sign request | ||||
|     /// | ||||
|     /// This is typically called when a sign request is received from the server. | ||||
|     /// The request will be stored and can be retrieved later for processing. | ||||
|     /// | ||||
|     /// # Arguments | ||||
|     /// * `request` - The sign request to add | ||||
|     /// * `target_public_key` - The public key this request is intended for | ||||
|     pub fn add_pending_request(&mut self, request: SignRequest, target_public_key: String) { | ||||
|         let managed_request = ManagedSignRequest::new(request, target_public_key); | ||||
|         self.pending_requests.insert(managed_request.id().to_string(), managed_request); | ||||
|     } | ||||
|  | ||||
|     /// Remove a pending request by ID | ||||
|     /// | ||||
|     /// # Arguments | ||||
|     /// * `request_id` - The ID of the request to remove | ||||
|     /// | ||||
|     /// # Returns | ||||
|     /// * `Some(request)` - The removed request if it existed | ||||
|     /// * `None` - If no request with that ID was found | ||||
|     pub fn remove_pending_request(&mut self, request_id: &str) -> Option<ManagedSignRequest> { | ||||
|         self.pending_requests.remove(request_id) | ||||
|     } | ||||
|  | ||||
|     /// Get a pending request by ID | ||||
|     /// | ||||
|     /// # Arguments | ||||
|     /// * `request_id` - The ID of the request to retrieve | ||||
|     /// | ||||
|     /// # Returns | ||||
|     /// * `Some(request)` - The request if it exists | ||||
|     /// * `None` - If no request with that ID was found | ||||
|     pub fn get_pending_request(&self, request_id: &str) -> Option<&ManagedSignRequest> { | ||||
|         self.pending_requests.get(request_id) | ||||
|     } | ||||
|  | ||||
|     /// Get all pending requests | ||||
|     /// | ||||
|     /// # Returns | ||||
|     /// * A reference to the HashMap containing all pending requests | ||||
|     pub fn get_pending_requests(&self) -> &HashMap<String, ManagedSignRequest> { | ||||
|         &self.pending_requests | ||||
|     } | ||||
|  | ||||
|     /// Get pending requests filtered by public key | ||||
|     /// | ||||
|     /// # Arguments | ||||
|     /// * `public_key` - The public key to filter by (hex-encoded) | ||||
|     /// | ||||
|     /// # Returns | ||||
|     /// * A vector of references to requests for the specified public key | ||||
|     pub fn get_requests_for_public_key(&self, public_key: &str) -> Vec<&ManagedSignRequest> { | ||||
|         self.pending_requests | ||||
|             .values() | ||||
|             .filter(|req| req.is_for_public_key(public_key)) | ||||
|             .collect() | ||||
|     } | ||||
|  | ||||
|     /// Check if a request can be handled for the given public key | ||||
|     /// | ||||
|     /// This performs protocol-level validation without cryptographic operations. | ||||
|     /// | ||||
|     /// # Arguments | ||||
|     /// * `request` - The sign request to validate | ||||
|     /// * `public_key` - The public key to check against (hex-encoded) | ||||
|     /// | ||||
|     /// # Returns | ||||
|     /// * `true` - If the request can be handled for this public key | ||||
|     /// * `false` - If the request cannot be handled | ||||
|     pub fn can_handle_request_for_key(&self, request: &SignRequest, public_key: &str) -> bool { | ||||
|         // Basic protocol validation | ||||
|         if request.id.is_empty() || request.message.is_empty() { | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         // Check if we can decode the message | ||||
|         if request.message_bytes().is_err() { | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         // For now, we assume any valid request can be handled for any public key | ||||
|         // More sophisticated validation can be added here | ||||
|         !public_key.is_empty() | ||||
|     } | ||||
|  | ||||
|     /// Clear all pending requests | ||||
|     pub fn clear_pending_requests(&mut self) { | ||||
|         self.pending_requests.clear(); | ||||
|     } | ||||
|  | ||||
|     /// Get the count of pending requests | ||||
|     pub fn pending_request_count(&self) -> usize { | ||||
|         self.pending_requests.len() | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Platform-specific implementations will be added in separate modules | ||||
| @@ -176,6 +296,7 @@ impl SigSocketClient { | ||||
|         } | ||||
|  | ||||
|         self.state = ConnectionState::Connected; | ||||
|         self.connected_public_key = Some(self.public_key_hex()); | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
| @@ -190,17 +311,19 @@ impl SigSocketClient { | ||||
|         } | ||||
|         self.inner = None; | ||||
|         self.state = ConnectionState::Disconnected; | ||||
|         self.connected_public_key = None; | ||||
|         self.clear_pending_requests(); | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     /// Send a sign response to the server | ||||
|     ///  | ||||
|     /// | ||||
|     /// This is typically called after the user has approved a signature request | ||||
|     /// and the application has generated the signature. | ||||
|     ///  | ||||
|     /// | ||||
|     /// # Arguments | ||||
|     /// * `response` - The sign response containing the signature | ||||
|     ///  | ||||
|     /// | ||||
|     /// # Returns | ||||
|     /// * `Ok(())` - Response sent successfully | ||||
|     /// * `Err(error)` - Failed to send response | ||||
| @@ -215,6 +338,41 @@ impl SigSocketClient { | ||||
|             Err(SigSocketError::NotConnected) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Send a response for a specific request ID with signature | ||||
|     /// | ||||
|     /// This is a convenience method that creates a SignResponse and sends it. | ||||
|     /// | ||||
|     /// # Arguments | ||||
|     /// * `request_id` - The ID of the request being responded to | ||||
|     /// * `message` - The original message (base64-encoded) | ||||
|     /// * `signature` - The signature (base64-encoded) | ||||
|     /// | ||||
|     /// # Returns | ||||
|     /// * `Ok(())` - Response sent successfully | ||||
|     /// * `Err(error)` - Failed to send response | ||||
|     pub async fn send_response(&self, request_id: &str, message: &str, signature: &str) -> Result<()> { | ||||
|         let response = SignResponse::new(request_id, message, signature); | ||||
|         self.send_sign_response(&response).await | ||||
|     } | ||||
|  | ||||
|     /// Send a rejection for a specific request ID | ||||
|     /// | ||||
|     /// This sends an error response to indicate the request was rejected. | ||||
|     /// | ||||
|     /// # Arguments | ||||
|     /// * `request_id` - The ID of the request being rejected | ||||
|     /// * `reason` - The reason for rejection | ||||
|     /// | ||||
|     /// # Returns | ||||
|     /// * `Ok(())` - Rejection sent successfully | ||||
|     /// * `Err(error)` - Failed to send rejection | ||||
|     pub async fn send_rejection(&self, request_id: &str, _reason: &str) -> Result<()> { | ||||
|         // For now, we'll send an empty signature to indicate rejection | ||||
|         // This can be improved with a proper rejection protocol | ||||
|         let response = SignResponse::new(request_id, "", ""); | ||||
|         self.send_sign_response(&response).await | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Drop for SigSocketClient { | ||||
| @@ -222,3 +380,5 @@ impl Drop for SigSocketClient { | ||||
|         // Cleanup will be handled by the platform-specific implementations | ||||
|     } | ||||
| } | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -60,10 +60,13 @@ mod native; | ||||
| mod wasm; | ||||
|  | ||||
| pub use error::{SigSocketError, Result}; | ||||
| pub use protocol::{SignRequest, SignResponse}; | ||||
| pub use protocol::{SignRequest, SignResponse, ManagedSignRequest, RequestStatus}; | ||||
| pub use client::{SigSocketClient, SignRequestHandler, ConnectionState}; | ||||
|  | ||||
| // Re-export for convenience | ||||
| pub mod prelude { | ||||
|     pub use crate::{SigSocketClient, SignRequest, SignResponse, SignRequestHandler, ConnectionState, SigSocketError, Result}; | ||||
|     pub use crate::{ | ||||
|         SigSocketClient, SignRequest, SignResponse, ManagedSignRequest, RequestStatus, | ||||
|         SignRequestHandler, ConnectionState, SigSocketError, Result | ||||
|     }; | ||||
| } | ||||
|   | ||||
| @@ -82,6 +82,92 @@ impl SignResponse { | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Enhanced sign request with additional metadata for request management | ||||
| #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] | ||||
| pub struct ManagedSignRequest { | ||||
|     /// The original sign request | ||||
|     #[serde(flatten)] | ||||
|     pub request: SignRequest, | ||||
|     /// Timestamp when the request was received (Unix timestamp in milliseconds) | ||||
|     pub timestamp: u64, | ||||
|     /// Target public key for this request (hex-encoded) | ||||
|     pub target_public_key: String, | ||||
|     /// Current status of the request | ||||
|     pub status: RequestStatus, | ||||
| } | ||||
|  | ||||
| /// Status of a sign request | ||||
| #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] | ||||
| pub enum RequestStatus { | ||||
|     /// Request is pending user approval | ||||
|     Pending, | ||||
|     /// Request has been approved and signed | ||||
|     Approved, | ||||
|     /// Request has been rejected by user | ||||
|     Rejected, | ||||
|     /// Request has expired or been cancelled | ||||
|     Cancelled, | ||||
| } | ||||
|  | ||||
| impl ManagedSignRequest { | ||||
|     /// Create a new managed sign request | ||||
|     pub fn new(request: SignRequest, target_public_key: String) -> Self { | ||||
|         Self { | ||||
|             request, | ||||
|             timestamp: current_timestamp_ms(), | ||||
|             target_public_key, | ||||
|             status: RequestStatus::Pending, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Get the request ID | ||||
|     pub fn id(&self) -> &str { | ||||
|         &self.request.id | ||||
|     } | ||||
|  | ||||
|     /// Get the message as bytes (decoded from base64) | ||||
|     pub fn message_bytes(&self) -> Result<Vec<u8>, base64::DecodeError> { | ||||
|         self.request.message_bytes() | ||||
|     } | ||||
|  | ||||
|     /// Check if this request is for the given public key | ||||
|     pub fn is_for_public_key(&self, public_key: &str) -> bool { | ||||
|         self.target_public_key == public_key | ||||
|     } | ||||
|  | ||||
|     /// Mark the request as approved | ||||
|     pub fn mark_approved(&mut self) { | ||||
|         self.status = RequestStatus::Approved; | ||||
|     } | ||||
|  | ||||
|     /// Mark the request as rejected | ||||
|     pub fn mark_rejected(&mut self) { | ||||
|         self.status = RequestStatus::Rejected; | ||||
|     } | ||||
|  | ||||
|     /// Check if the request is still pending | ||||
|     pub fn is_pending(&self) -> bool { | ||||
|         matches!(self.status, RequestStatus::Pending) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Get current timestamp in milliseconds | ||||
| #[cfg(not(target_arch = "wasm32"))] | ||||
| fn current_timestamp_ms() -> u64 { | ||||
|     std::time::SystemTime::now() | ||||
|         .duration_since(std::time::UNIX_EPOCH) | ||||
|         .unwrap_or_default() | ||||
|         .as_millis() as u64 | ||||
| } | ||||
|  | ||||
| /// Get current timestamp in milliseconds (WASM version) | ||||
| #[cfg(target_arch = "wasm32")] | ||||
| fn current_timestamp_ms() -> u64 { | ||||
|     // In WASM, we'll use a simple counter or Date.now() via JS | ||||
|     // For now, return 0 - this can be improved later | ||||
|     0 | ||||
| } | ||||
|  | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|     use super::*; | ||||
| @@ -138,4 +224,33 @@ mod tests { | ||||
|         let deserialized: SignResponse = serde_json::from_str(&json).unwrap(); | ||||
|         assert_eq!(response, deserialized); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_managed_sign_request() { | ||||
|         let request = SignRequest::new("test-id", "dGVzdCBtZXNzYWdl"); | ||||
|         let managed = ManagedSignRequest::new(request.clone(), "test-public-key".to_string()); | ||||
|  | ||||
|         assert_eq!(managed.id(), "test-id"); | ||||
|         assert_eq!(managed.request, request); | ||||
|         assert_eq!(managed.target_public_key, "test-public-key"); | ||||
|         assert!(managed.is_pending()); | ||||
|         assert!(managed.is_for_public_key("test-public-key")); | ||||
|         assert!(!managed.is_for_public_key("other-key")); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_managed_request_status_changes() { | ||||
|         let request = SignRequest::new("test-id", "dGVzdCBtZXNzYWdl"); | ||||
|         let mut managed = ManagedSignRequest::new(request, "test-public-key".to_string()); | ||||
|  | ||||
|         assert!(managed.is_pending()); | ||||
|  | ||||
|         managed.mark_approved(); | ||||
|         assert_eq!(managed.status, RequestStatus::Approved); | ||||
|         assert!(!managed.is_pending()); | ||||
|  | ||||
|         managed.mark_rejected(); | ||||
|         assert_eq!(managed.status, RequestStatus::Rejected); | ||||
|         assert!(!managed.is_pending()); | ||||
|     } | ||||
| } | ||||
|   | ||||
							
								
								
									
										92
									
								
								sigsocket_client/tests/request_management_test.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								sigsocket_client/tests/request_management_test.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,92 @@ | ||||
| //! Tests for the enhanced request management functionality | ||||
|  | ||||
| use sigsocket_client::prelude::*; | ||||
|  | ||||
| #[test] | ||||
| fn test_client_request_management() { | ||||
|     let public_key = hex::decode("02f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9").unwrap(); | ||||
|     let mut client = SigSocketClient::new("ws://localhost:8080/ws", public_key).unwrap(); | ||||
|      | ||||
|     // Initially no requests | ||||
|     assert_eq!(client.pending_request_count(), 0); | ||||
|     assert!(client.get_pending_requests().is_empty()); | ||||
|      | ||||
|     // Add a request | ||||
|     let request = SignRequest::new("test-1", "dGVzdCBtZXNzYWdl"); | ||||
|     let public_key_hex = "02f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9"; | ||||
|     client.add_pending_request(request.clone(), public_key_hex.to_string()); | ||||
|      | ||||
|     // Check request was added | ||||
|     assert_eq!(client.pending_request_count(), 1); | ||||
|     assert!(client.get_pending_request("test-1").is_some()); | ||||
|      | ||||
|     // Check filtering by public key | ||||
|     let filtered = client.get_requests_for_public_key(public_key_hex); | ||||
|     assert_eq!(filtered.len(), 1); | ||||
|     assert_eq!(filtered[0].id(), "test-1"); | ||||
|      | ||||
|     // Add another request for different public key | ||||
|     let request2 = SignRequest::new("test-2", "dGVzdCBtZXNzYWdlMg=="); | ||||
|     let other_public_key = "03f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9"; | ||||
|     client.add_pending_request(request2, other_public_key.to_string()); | ||||
|      | ||||
|     // Check total count | ||||
|     assert_eq!(client.pending_request_count(), 2); | ||||
|      | ||||
|     // Check filtering still works | ||||
|     let filtered = client.get_requests_for_public_key(public_key_hex); | ||||
|     assert_eq!(filtered.len(), 1); | ||||
|      | ||||
|     let filtered_other = client.get_requests_for_public_key(other_public_key); | ||||
|     assert_eq!(filtered_other.len(), 1); | ||||
|      | ||||
|     // Remove a request | ||||
|     let removed = client.remove_pending_request("test-1"); | ||||
|     assert!(removed.is_some()); | ||||
|     assert_eq!(removed.unwrap().id(), "test-1"); | ||||
|     assert_eq!(client.pending_request_count(), 1); | ||||
|      | ||||
|     // Clear all requests | ||||
|     client.clear_pending_requests(); | ||||
|     assert_eq!(client.pending_request_count(), 0); | ||||
| } | ||||
|  | ||||
| #[test] | ||||
| fn test_client_request_validation() { | ||||
|     let public_key = hex::decode("02f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9").unwrap(); | ||||
|     let client = SigSocketClient::new("ws://localhost:8080/ws", public_key).unwrap(); | ||||
|      | ||||
|     // Valid request | ||||
|     let valid_request = SignRequest::new("test-1", "dGVzdCBtZXNzYWdl"); | ||||
|     assert!(client.can_handle_request_for_key(&valid_request, "some-public-key")); | ||||
|      | ||||
|     // Invalid request - empty ID | ||||
|     let invalid_request = SignRequest::new("", "dGVzdCBtZXNzYWdl"); | ||||
|     assert!(!client.can_handle_request_for_key(&invalid_request, "some-public-key")); | ||||
|      | ||||
|     // Invalid request - empty message | ||||
|     let invalid_request2 = SignRequest::new("test-1", ""); | ||||
|     assert!(!client.can_handle_request_for_key(&invalid_request2, "some-public-key")); | ||||
|      | ||||
|     // Invalid request - invalid base64 | ||||
|     let invalid_request3 = SignRequest::new("test-1", "invalid-base64!"); | ||||
|     assert!(!client.can_handle_request_for_key(&invalid_request3, "some-public-key")); | ||||
|      | ||||
|     // Invalid public key | ||||
|     assert!(!client.can_handle_request_for_key(&valid_request, "")); | ||||
| } | ||||
|  | ||||
| #[test] | ||||
| fn test_client_connection_state() { | ||||
|     let public_key = hex::decode("02f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9").unwrap(); | ||||
|     let client = SigSocketClient::new("ws://localhost:8080/ws", public_key).unwrap(); | ||||
|      | ||||
|     // Initially disconnected | ||||
|     assert_eq!(client.state(), ConnectionState::Disconnected); | ||||
|     assert!(!client.is_connected()); | ||||
|     assert!(client.connected_public_key().is_none()); | ||||
|      | ||||
|     // Public key should be available | ||||
|     assert_eq!(client.public_key_hex(), "02f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9"); | ||||
|     assert_eq!(client.url(), "ws://localhost:8080/ws"); | ||||
| } | ||||
		Reference in New Issue
	
	Block a user