Implement DAG for flow

Signed-off-by: Lee Smet <lee.smet@hotmail.com>
Lee Smet
2025-08-21 16:45:10 +02:00
parent eb69a44039
commit ec91a15131
5 changed files with 277 additions and 1 deletion

src/dag.rs — new file, 207 lines

@@ -0,0 +1,207 @@
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt;
use crate::{
models::{Flow, Job, ScriptType},
storage::RedisDriver,
};
pub type DagResult<T> = Result<T, DagError>;
#[derive(Debug)]
pub enum DagError {
Storage(Box<dyn std::error::Error + Send + Sync>),
MissingDependency { job: u32, depends_on: u32 },
CycleDetected { remaining: Vec<u32> },
}
impl fmt::Display for DagError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
DagError::Storage(e) => write!(f, "Storage error: {}", e),
DagError::MissingDependency { job, depends_on } => write!(
f,
"Job {} depends on {}, which is not part of the flow.jobs list",
job, depends_on
),
DagError::CycleDetected { remaining } => {
write!(f, "Cycle detected; unresolved nodes: {:?}", remaining)
}
}
}
}
impl std::error::Error for DagError {}
impl From<Box<dyn std::error::Error + Send + Sync>> for DagError {
fn from(e: Box<dyn std::error::Error + Send + Sync>) -> Self {
DagError::Storage(e)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JobSummary {
pub id: u32,
pub depends: Vec<u32>,
pub prerequisites: Vec<String>,
pub script_type: ScriptType,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlowDag {
pub flow_id: u32,
pub caller_id: u32,
pub context_id: u32,
pub nodes: HashMap<u32, JobSummary>,
pub edges: Vec<(u32, u32)>, // (from prerequisite, to job)
pub reverse_edges: Vec<(u32, u32)>, // (from job, to prerequisite)
pub roots: Vec<u32>, // in_degree == 0
pub leaves: Vec<u32>, // out_degree == 0
pub levels: Vec<Vec<u32>>, // topological layers for parallel execution
}
pub async fn build_flow_dag(
redis: &RedisDriver,
context_id: u32,
flow_id: u32,
) -> DagResult<FlowDag> {
// Load flow
let flow: Flow = redis
.load_flow(context_id, flow_id)
.await
.map_err(DagError::from)?;
let caller_id = flow.caller_id();
let flow_job_ids = flow.jobs().clone();
// Build a set for faster membership tests
let job_id_set: HashSet<u32> = flow_job_ids.iter().copied().collect();
// Load all jobs
let mut jobs: HashMap<u32, Job> = HashMap::with_capacity(flow_job_ids.len());
for jid in &flow_job_ids {
let job = redis
.load_job(context_id, caller_id, *jid)
.await
.map_err(DagError::from)?;
jobs.insert(*jid, job);
}
// Validate dependencies and construct adjacency
let mut edges: Vec<(u32, u32)> = Vec::new();
let mut reverse_edges: Vec<(u32, u32)> = Vec::new();
let mut adj: HashMap<u32, Vec<u32>> = HashMap::with_capacity(jobs.len());
let mut rev_adj: HashMap<u32, Vec<u32>> = HashMap::with_capacity(jobs.len());
let mut in_degree: HashMap<u32, usize> = HashMap::with_capacity(jobs.len());
for &jid in &flow_job_ids {
adj.entry(jid).or_default();
rev_adj.entry(jid).or_default();
in_degree.entry(jid).or_insert(0);
}
for (&jid, job) in &jobs {
for &dep in job.depends() {
if !job_id_set.contains(&dep) {
return Err(DagError::MissingDependency {
job: jid,
depends_on: dep,
});
}
// edge: dep -> jid
edges.push((dep, jid));
reverse_edges.push((jid, dep));
adj.get_mut(&dep).unwrap().push(jid);
rev_adj.get_mut(&jid).unwrap().push(dep);
*in_degree.get_mut(&jid).unwrap() += 1;
}
}
// Kahn's algorithm for topological sorting, with level construction
let mut zero_in: VecDeque<u32> = in_degree
.iter()
.filter_map(|(k, v)| if *v == 0 { Some(*k) } else { None })
.collect();
let mut processed_count = 0usize;
let mut levels: Vec<Vec<u32>> = Vec::new();
// Sort the initial frontier so iteration order is deterministic
{
let mut tmp: Vec<u32> = zero_in.iter().copied().collect();
tmp.sort_unstable();
zero_in = tmp.into_iter().collect();
}
while !zero_in.is_empty() {
let mut level: Vec<u32> = Vec::new();
// drain current frontier
let mut next_zero: Vec<u32> = Vec::new();
let mut current_frontier: Vec<u32> = zero_in.drain(..).collect();
current_frontier.sort_unstable();
for u in current_frontier {
level.push(u);
processed_count += 1;
if let Some(children) = adj.get(&u) {
let mut sorted_children = children.clone();
sorted_children.sort_unstable();
for &v in &sorted_children {
let d = in_degree.get_mut(&v).unwrap();
*d -= 1;
if *d == 0 {
next_zero.push(v);
}
}
}
}
next_zero.sort_unstable();
zero_in = next_zero.into_iter().collect();
levels.push(level);
}
if processed_count != jobs.len() {
let remaining: Vec<u32> = in_degree
.into_iter()
.filter_map(|(k, v)| if v > 0 { Some(k) } else { None })
.collect();
return Err(DagError::CycleDetected { remaining });
}
// Roots and leaves
let roots: Vec<u32> = levels.first().cloned().unwrap_or_default();
let leaves: Vec<u32> = adj
.iter()
.filter_map(|(k, v)| if v.is_empty() { Some(*k) } else { None })
.collect();
// Nodes map (JobSummary)
let mut nodes: HashMap<u32, JobSummary> = HashMap::with_capacity(jobs.len());
for (&jid, job) in &jobs {
let summary = JobSummary {
id: jid,
depends: job.depends().to_vec(),
prerequisites: job.prerequisites().to_vec(),
script_type: job.script_type(),
};
nodes.insert(jid, summary);
}
// Sort edges deterministically
edges.sort_unstable();
reverse_edges.sort_unstable();
let dag = FlowDag {
flow_id,
caller_id,
context_id,
nodes,
edges,
reverse_edges,
roots,
leaves,
levels,
};
Ok(dag)
}
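Because levels is a topological layering, a consumer can execute a flow wave by wave: every job in a level depends only on jobs in earlier levels. A minimal sketch of such a driver follows; the run_job stub is a hypothetical placeholder for a real executor and is not part of this commit. Jobs within one level are mutually independent, so they could also be dispatched concurrently (e.g. with futures::future::try_join_all).

use crate::{
    dag::{DagResult, build_flow_dag},
    storage::RedisDriver,
};

// Placeholder executor: stands in for whatever actually dispatches a job.
async fn run_job(_context_id: u32, _caller_id: u32, job_id: u32) -> DagResult<()> {
    println!("executing job {job_id}");
    Ok(())
}

// Execute a flow level by level, in topological order.
async fn execute_flow(redis: &RedisDriver, context_id: u32, flow_id: u32) -> DagResult<()> {
    let dag = build_flow_dag(redis, context_id, flow_id).await?;
    for level in &dag.levels {
        for &job_id in level {
            run_job(context_id, dag.caller_id, job_id).await?;
        }
    }
    Ok(())
}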


@@ -1,4 +1,5 @@
pub mod models;
pub mod storage;
mod time;
pub mod dag;
pub mod rpc;


@@ -31,3 +31,18 @@ pub enum FlowStatus {
Error,
Finished,
}
impl Flow {
pub fn id(&self) -> u32 {
self.id
}
pub fn caller_id(&self) -> u32 {
self.caller_id
}
pub fn context_id(&self) -> u32 {
self.context_id
}
pub fn jobs(&self) -> &Vec<u32> {
&self.jobs
}
}
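For orientation, the accessors above imply roughly this shape for Flow. This is a sketch only; the real struct is defined elsewhere in the crate, and the field set beyond the four getters is an assumption.

pub struct Flow {
    id: u32,
    caller_id: u32,
    context_id: u32,
    jobs: Vec<u32>, // ids of the jobs that make up this flow
    // ... further fields (e.g. a FlowStatus) assumed but not shown here
}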


@@ -36,3 +36,24 @@ pub enum JobStatus {
Error,
Finished,
}
impl Job {
pub fn id(&self) -> u32 {
self.id
}
pub fn caller_id(&self) -> u32 {
self.caller_id
}
pub fn context_id(&self) -> u32 {
self.context_id
}
pub fn depends(&self) -> &Vec<u32> {
&self.depends
}
pub fn prerequisites(&self) -> &Vec<String> {
&self.prerequisites
}
pub fn script_type(&self) -> ScriptType {
self.script_type.clone()
}
}
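Likewise, the Job getters imply roughly this shape (again a sketch, not the actual definition):

pub struct Job {
    id: u32,
    caller_id: u32,
    context_id: u32,
    depends: Vec<u32>,          // ids of jobs that must complete first
    prerequisites: Vec<String>, // named external preconditions
    script_type: ScriptType,
    // ... further fields (e.g. status, script payload) assumed but not shown
}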


@@ -9,10 +9,11 @@ use jsonrpsee::{
server::{ServerBuilder, ServerHandle},
types::error::ErrorObjectOwned,
};
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
use serde_json::{Value, json};
use crate::{
dag::{DagError, FlowDag, build_flow_dag},
models::{Actor, Context, Flow, Job, Message, MessageFormatType, Runner, ScriptType},
storage::RedisDriver,
time::current_timestamp,
@@ -45,6 +46,22 @@ fn storage_err(e: Box<dyn std::error::Error + Send + Sync>) -> ErrorObjectOwned
}
}
fn dag_err(e: DagError) -> ErrorObjectOwned {
match e {
DagError::Storage(inner) => storage_err(inner),
DagError::MissingDependency { .. } => ErrorObjectOwned::owned(
-32020,
"DAG Missing Dependency",
Some(Value::String(e.to_string())),
),
DagError::CycleDetected { .. } => ErrorObjectOwned::owned(
-32021,
"DAG Cycle Detected",
Some(Value::String(e.to_string())),
),
}
}
// -----------------------------
// Create DTOs and Param wrappers
// -----------------------------
@@ -447,6 +464,21 @@ pub fn build_module(state: Arc<AppState>) -> RpcModule<()> {
})
.expect("register flow.load");
}
{
let state = state.clone();
module
.register_async_method("flow.dag", move |params, _caller, _ctx| {
let state = state.clone();
async move {
let p: FlowLoadParams = params.parse().map_err(invalid_params_err)?;
let dag: FlowDag = build_flow_dag(&state.redis, p.context_id, p.id)
.await
.map_err(dag_err)?;
Ok::<_, ErrorObjectOwned>(dag)
}
})
.expect("register flow.dag");
}
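The new method reuses FlowLoadParams, so a flow.dag request takes the same parameters as flow.load. A sketch of the wire format, assuming named params with context_id and id fields (the ids are illustrative):

let request = serde_json::json!({
    "jsonrpc": "2.0",
    "id": 1,
    "method": "flow.dag",
    "params": { "context_id": 2, "id": 7 } // illustrative ids
});
// On success the result is the serialized FlowDag from src/dag.rs:
// nodes, edges, reverse_edges, roots, leaves and levels.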
// Job
{