From ec91a15131ca9de468463b9fae28d83531a5a34d Mon Sep 17 00:00:00 2001
From: Lee Smet <lee.smet@hotmail.com>
Date: Thu, 21 Aug 2025 16:45:10 +0200
Subject: [PATCH] Implement DAG for flow

Signed-off-by: Lee Smet <lee.smet@hotmail.com>
---
 src/dag.rs         | 207 +++++++++++++++++++++++++++++++++++++++++++++
 src/lib.rs         |   1 +
 src/models/flow.rs |  15 ++++
 src/models/job.rs  |  21 +++++
 src/rpc.rs         |  34 +++++++-
 5 files changed, 277 insertions(+), 1 deletion(-)
 create mode 100644 src/dag.rs

diff --git a/src/dag.rs b/src/dag.rs
new file mode 100644
index 0000000..845cff5
--- /dev/null
+++ b/src/dag.rs
@@ -0,0 +1,207 @@
+use serde::{Deserialize, Serialize};
+use std::collections::{HashMap, HashSet, VecDeque};
+use std::fmt;
+
+use crate::{
+    models::{Flow, Job, ScriptType},
+    storage::RedisDriver,
+};
+
+pub type DagResult<T> = Result<T, DagError>;
+
+#[derive(Debug)]
+pub enum DagError {
+    Storage(Box<dyn std::error::Error + Send + Sync>),
+    MissingDependency { job: u32, depends_on: u32 },
+    CycleDetected { remaining: Vec<u32> },
+}
+
+impl fmt::Display for DagError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            DagError::Storage(e) => write!(f, "Storage error: {}", e),
+            DagError::MissingDependency { job, depends_on } => write!(
+                f,
+                "Job {} depends on {}, which is not part of the flow.jobs list",
+                job, depends_on
+            ),
+            DagError::CycleDetected { remaining } => {
+                write!(f, "Cycle detected; unresolved nodes: {:?}", remaining)
+            }
+        }
+    }
+}
+
+impl std::error::Error for DagError {}
+
+impl From<Box<dyn std::error::Error + Send + Sync>> for DagError {
+    fn from(e: Box<dyn std::error::Error + Send + Sync>) -> Self {
+        DagError::Storage(e)
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JobSummary {
+    pub id: u32,
+    pub depends: Vec<u32>,
+    pub prerequisites: Vec<String>,
+    pub script_type: ScriptType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FlowDag {
+    pub flow_id: u32,
+    pub caller_id: u32,
+    pub context_id: u32,
+    pub nodes: HashMap<u32, JobSummary>,
+    pub edges: Vec<(u32, u32)>,         // (from prerequisite, to job)
+    pub reverse_edges: Vec<(u32, u32)>, // (from job, to prerequisite)
+    pub roots: Vec<u32>,                // in_degree == 0
+    pub leaves: Vec<u32>,               // out_degree == 0
+    pub levels: Vec<Vec<u32>>,          // topological layers for parallel execution
+}
+
+pub async fn build_flow_dag(
+    redis: &RedisDriver,
+    context_id: u32,
+    flow_id: u32,
+) -> DagResult<FlowDag> {
+    // Load flow
+    let flow: Flow = redis
+        .load_flow(context_id, flow_id)
+        .await
+        .map_err(DagError::from)?;
+    let caller_id = flow.caller_id();
+    let flow_job_ids = flow.jobs().clone();
+
+    // Build a set for faster membership tests
+    let job_id_set: HashSet<u32> = flow_job_ids.iter().copied().collect();
+
+    // Load all jobs
+    let mut jobs: HashMap<u32, Job> = HashMap::with_capacity(flow_job_ids.len());
+    for jid in &flow_job_ids {
+        let job = redis
+            .load_job(context_id, caller_id, *jid)
+            .await
+            .map_err(DagError::from)?;
+        jobs.insert(*jid, job);
+    }
+
+    // Validate dependencies and construct adjacency
+    let mut edges: Vec<(u32, u32)> = Vec::new();
+    let mut reverse_edges: Vec<(u32, u32)> = Vec::new();
+    let mut adj: HashMap<u32, Vec<u32>> = HashMap::with_capacity(jobs.len());
+    let mut rev_adj: HashMap<u32, Vec<u32>> = HashMap::with_capacity(jobs.len());
+    let mut in_degree: HashMap<u32, u32> = HashMap::with_capacity(jobs.len());
+
+    for &jid in &flow_job_ids {
+        adj.entry(jid).or_default();
+        rev_adj.entry(jid).or_default();
+        in_degree.entry(jid).or_insert(0);
+    }
+
+    for (&jid, job) in &jobs {
+        for &dep in job.depends() {
+            if !job_id_set.contains(&dep) {
+                return Err(DagError::MissingDependency {
+                    job: jid,
+                    depends_on: dep,
+                });
+            }
+            // edge: dep -> jid
+            edges.push((dep, jid));
+            reverse_edges.push((jid, dep));
+            adj.get_mut(&dep).unwrap().push(jid);
+            rev_adj.get_mut(&jid).unwrap().push(dep);
+            *in_degree.get_mut(&jid).unwrap() += 1;
+        }
+    }
+
+    // Kahn's algorithm for topological sorting, with level construction
+    let mut zero_in: VecDeque<u32> = in_degree
+        .iter()
+        .filter_map(|(k, v)| if *v == 0 { Some(*k) } else { None })
+        .collect();
+
+    let mut processed_count = 0usize;
+    let mut levels: Vec<Vec<u32>> = Vec::new();
+
+    // To make deterministic, sort initial zero_in
+    {
+        let mut tmp: Vec<u32> = zero_in.iter().copied().collect();
+        tmp.sort_unstable();
+        zero_in = tmp.into_iter().collect();
+    }
+
+    while !zero_in.is_empty() {
+        let mut level: Vec<u32> = Vec::new();
+        // drain current frontier
+        let mut next_zero: Vec<u32> = Vec::new();
+        let mut current_frontier: Vec<u32> = zero_in.drain(..).collect();
+        current_frontier.sort_unstable();
+        for u in current_frontier {
+            level.push(u);
+            processed_count += 1;
+            if let Some(children) = adj.get(&u) {
+                let mut sorted_children = children.clone();
+                sorted_children.sort_unstable();
+                for &v in &sorted_children {
+                    let d = in_degree.get_mut(&v).unwrap();
+                    *d -= 1;
+                    if *d == 0 {
+                        next_zero.push(v);
+                    }
+                }
+            }
+        }
+        next_zero.sort_unstable();
+        zero_in = next_zero.into_iter().collect();
+        levels.push(level);
+    }
+
+    if processed_count != jobs.len() {
+        let remaining: Vec<u32> = in_degree
+            .into_iter()
+            .filter_map(|(k, v)| if v > 0 { Some(k) } else { None })
+            .collect();
+        return Err(DagError::CycleDetected { remaining });
+    }
+
+    // Roots and leaves
+    let roots: Vec<u32> = levels.first().cloned().unwrap_or_default();
+    let leaves: Vec<u32> = adj
+        .iter()
+        .filter_map(|(k, v)| if v.is_empty() { Some(*k) } else { None })
+        .collect();
+
+    // Nodes map (JobSummary)
+    let mut nodes: HashMap<u32, JobSummary> = HashMap::with_capacity(jobs.len());
+    for (&jid, job) in &jobs {
+        let summary = JobSummary {
+            id: jid,
+            depends: job.depends().to_vec(),
+            prerequisites: job.prerequisites().to_vec(),
+            script_type: job.script_type(),
+        };
+        nodes.insert(jid, summary);
+    }
+
+    // Sort edges deterministically
+    edges.sort_unstable();
+    reverse_edges.sort_unstable();
+
+    let dag = FlowDag {
+        flow_id,
+        caller_id,
+        context_id,
+        nodes,
+        edges,
+        reverse_edges,
+        roots,
+        leaves,
+        levels,
+    };
+
+    Ok(dag)
+}
diff --git a/src/lib.rs b/src/lib.rs
index f9fb8a0..83161a2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,5 @@
 pub mod models;
 pub mod storage;
 mod time;
+pub mod dag;
 pub mod rpc;
diff --git a/src/models/flow.rs b/src/models/flow.rs
index 542a8c5..d726cbc 100644
--- a/src/models/flow.rs
+++ b/src/models/flow.rs
@@ -31,3 +31,18 @@ pub enum FlowStatus {
     Error,
     Finished,
 }
+
+impl Flow {
+    pub fn id(&self) -> u32 {
+        self.id
+    }
+    pub fn caller_id(&self) -> u32 {
+        self.caller_id
+    }
+    pub fn context_id(&self) -> u32 {
+        self.context_id
+    }
+    pub fn jobs(&self) -> &Vec<u32> {
+        &self.jobs
+    }
+}
diff --git a/src/models/job.rs b/src/models/job.rs
index 4c62865..50205dd 100644
--- a/src/models/job.rs
+++ b/src/models/job.rs
@@ -36,3 +36,24 @@ pub enum JobStatus {
     Error,
     Finished,
 }
+
+impl Job {
+    pub fn id(&self) -> u32 {
+        self.id
+    }
+    pub fn caller_id(&self) -> u32 {
+        self.caller_id
+    }
+    pub fn context_id(&self) -> u32 {
+        self.context_id
+    }
+    pub fn depends(&self) -> &Vec<u32> {
+        &self.depends
+    }
+    pub fn prerequisites(&self) -> &Vec<String> {
+        &self.prerequisites
+    }
+    pub fn script_type(&self) -> ScriptType {
+        self.script_type.clone()
+    }
+}
diff --git a/src/rpc.rs b/src/rpc.rs
index a348bc7..851e0a1 100644
--- a/src/rpc.rs
+++ b/src/rpc.rs
@@ -9,10 +9,11 @@ use jsonrpsee::{
     server::{ServerBuilder, ServerHandle},
     types::error::ErrorObjectOwned,
 };
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
 use serde_json::{Value, json};
 
 use crate::{
+    dag::{DagError, FlowDag, build_flow_dag},
     models::{Actor, Context, Flow, Job, Message, MessageFormatType, Runner, ScriptType},
     storage::RedisDriver,
     time::current_timestamp,
@@ -45,6 +46,22 @@ fn storage_err(e: Box<dyn std::error::Error + Send + Sync>) -> ErrorObjectOwned
     }
 }
 
+fn dag_err(e: DagError) -> ErrorObjectOwned {
+    match e {
+        DagError::Storage(inner) => storage_err(inner),
+        DagError::MissingDependency { .. } => ErrorObjectOwned::owned(
+            -32020,
+            "DAG Missing Dependency",
+            Some(Value::String(e.to_string())),
+        ),
+        DagError::CycleDetected { .. } => ErrorObjectOwned::owned(
+            -32021,
+            "DAG Cycle Detected",
+            Some(Value::String(e.to_string())),
+        ),
+    }
+}
+
 // -----------------------------
 // Create DTOs and Param wrappers
 // -----------------------------
@@ -447,6 +464,21 @@ pub fn build_module(state: Arc<AppState>) -> RpcModule<()> {
             })
             .expect("register flow.load");
     }
+    {
+        let state = state.clone();
+        module
+            .register_async_method("flow.dag", move |params, _caller, _ctx| {
+                let state = state.clone();
+                async move {
+                    let p: FlowLoadParams = params.parse().map_err(invalid_params_err)?;
+                    let dag: FlowDag = build_flow_dag(&state.redis, p.context_id, p.id)
+                        .await
+                        .map_err(dag_err)?;
+                    Ok::<_, ErrorObjectOwned>(dag)
+                }
+            })
+            .expect("register flow.dag");
+    }
 
     // Job
     {