cryprot_net/
metrics.rs

1//! [`tracing_subscriber::Layer`] for structured communication metrics.
2//!
3//! The [`CommLayer`] is a [`tracing_subscriber::Layer`] which records numbers
4//! of bytes read and written. Metrics are collected by
5//! [`instrumenting`](`macro@tracing::instrument`) spans with the
6//! `cryprot_metrics` target and a phase. From within these spans, events with
7//! the same target can be emitted to track the number of bytes read/written.
8//!
9//! ```
10//! use tracing::{event, instrument, Level};
11//!
12//! #[instrument(target = "cryprot_metrics", fields(phase = "Online"))]
13//! async fn online() {
14//!     event!(target: "cryprot_metrics", Level::TRACE, bytes_written = 5);
15//!     interleaved_setup().await
16//! }
17//!
18//! #[instrument(target = "cryprot_metrics", fields(phase = "Setup"))]
19//! async fn interleaved_setup() {
20//!     // Will be recorded in the sub phase "Setup" of the online phase
21//!     event!(target: "cryprot_metrics", Level::TRACE, bytes_written = 10);
22//! }
23//! ```
24use std::{
25    collections::{BTreeMap, btree_map::Entry},
26    fmt::Debug,
27    mem,
28    ops::AddAssign,
29    sync::{Arc, Mutex},
30};
31
32use serde::{Deserialize, Serialize};
33use tracing::{
34    Level,
35    field::{Field, Visit},
36    span::{Attributes, Id},
37    warn,
38};
39use tracing_subscriber::{
40    filter::{Filtered, Targets},
41    layer::{Context, Layer},
42};
43
44#[derive(Debug, Default, Clone, Serialize, Deserialize)]
45/// Communication metrics for a phase and its sub phases.
46pub struct CommData {
47    pub phase: String,
48    pub read: Counter,
49    pub write: Counter,
50    pub sub_comm_data: SubCommData,
51}
52
53#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
54pub struct Counter {
55    /// Number of written/read directly in this phase.
56    pub bytes: u64,
57    /// Total number of bytes written/read in this phase an all sub phases.
58    pub bytes_with_sub_comm: u64,
59}
60
61#[derive(Debug, Default, Clone, Serialize, Deserialize)]
62/// Sub communication data for different phases
63pub struct SubCommData(BTreeMap<String, CommData>);
64
65/// Convenience type alias for a filtered `CommLayerData` which only handles
66/// spans and events with `target = "cryprot_metrics"`.
67pub type CommLayer<S> = Filtered<CommLayerData, Targets, S>;
68
69#[derive(Clone, Debug, Default)]
70/// The `CommLayerData` has shared ownership of the root [`SubCommData`].
71pub struct CommLayerData {
72    // TOOD use Atomics in SubCommData to not need lock, maybe?
73    comm_data: Arc<Mutex<SubCommData>>,
74}
75
76/// Instantiate a new [`CommLayer`] and corresponding [`CommLayerData`].
77pub fn new_comm_layer<S>() -> (CommLayer<S>, CommLayerData)
78where
79    S: tracing::Subscriber,
80    S: for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
81{
82    let inner = CommLayerData::default();
83    let target_filter = Targets::new().with_target("cryprot_metrics", Level::TRACE);
84    (inner.clone().with_filter(target_filter), inner)
85}
86
87impl CommLayerData {
88    /// Returns a clone of the root `SubCommData` at this moment.
89    pub fn comm_data(&self) -> SubCommData {
90        self.comm_data.lock().expect("lock poisoned").clone()
91    }
92
93    /// Resets the root `SubCommData` and returns it.
94    ///
95    /// Do not use this method while an instrumented `target = cryprot_metrics`
96    /// span is active, as this will result in inconsistent data.
97    pub fn reset(&self) -> SubCommData {
98        let mut comm_data = self.comm_data.lock().expect("lock poisoned");
99        mem::take(&mut *comm_data)
100    }
101}
102
103impl<S> Layer<S> for CommLayerData
104where
105    S: tracing::Subscriber,
106    S: for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
107{
108    fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
109        let span = ctx.span(id).expect("Id is valid");
110        let mut visitor = PhaseVisitor(None);
111        attrs.record(&mut visitor);
112        if let Some(phase) = visitor.0 {
113            let data = CommData::new(phase);
114            span.extensions_mut().insert(data);
115        }
116    }
117
118    fn on_event(&self, event: &tracing::Event<'_>, ctx: Context<'_, S>) {
119        let Some(span) = ctx.event_span(event) else {
120            warn!(
121                "Received cryprot_metrics event outside of cryprot_metrics span. \
122                Communication is not tracked"
123            );
124            return;
125        };
126        // Check that we only have one field per event, otherwise the CommEventVisitor
127        // will only record on of them
128        let field_cnt = event
129            .fields()
130            .filter(|field| field.name() == "bytes_read" || field.name() == "bytes_written")
131            .count();
132        if field_cnt >= 2 {
133            warn!("Use individual events to record bytes_read and bytes_written");
134            return;
135        }
136        let mut vis = CommEventVisitor(None);
137        event.record(&mut vis);
138        if let Some(event) = vis.0 {
139            let mut extensions = span.extensions_mut();
140            let Some(comm_data) = extensions.get_mut::<CommData>() else {
141                warn!(
142                    "Received cryprot_metrics event inside cryprot_metrics span with no phase. \
143                    Communication is not tracked"
144                );
145                return;
146            };
147            match event {
148                CommEvent::Read(read) => {
149                    comm_data.read += read;
150                }
151                CommEvent::Write(written) => {
152                    comm_data.write += written;
153                }
154            }
155        }
156    }
157
158    fn on_close(&self, id: Id, ctx: Context<'_, S>) {
159        let span = ctx.span(&id).expect("Id is valid");
160        let mut extensions = span.extensions_mut();
161        let Some(comm_data) = extensions.get_mut::<CommData>().map(mem::take) else {
162            // nothing to do
163            return;
164        };
165
166        // TODO can merging of comm data be done in a background thread? Benchmark
167        // first!
168        if let Some(parent) = span.parent() {
169            if let Some(parent_comm_data) = parent.extensions_mut().get_mut::<CommData>() {
170                let entry = parent_comm_data
171                    .sub_comm_data
172                    .0
173                    .entry(comm_data.phase.clone())
174                    .or_insert_with(|| CommData::new(comm_data.phase.clone()));
175                parent_comm_data.read.bytes_with_sub_comm += comm_data.read.bytes_with_sub_comm;
176                parent_comm_data.write.bytes_with_sub_comm += comm_data.write.bytes_with_sub_comm;
177                merge(comm_data, entry)
178            }
179        } else {
180            let mut root_comm_data = self.comm_data.lock().expect("lock poisoned");
181            let phase_comm_data = root_comm_data
182                .0
183                .entry(comm_data.phase.clone())
184                .or_insert_with(|| CommData::new(comm_data.phase.clone()));
185            merge(comm_data, phase_comm_data);
186        }
187    }
188}
189
190fn merge(from: CommData, into: &mut CommData) {
191    into.read += from.read;
192    into.write += from.write;
193    for (phase, from_sub_comm) in from.sub_comm_data.0.into_iter() {
194        match into.sub_comm_data.0.entry(phase) {
195            Entry::Vacant(entry) => {
196                entry.insert(from_sub_comm);
197            }
198            Entry::Occupied(mut entry) => {
199                merge(from_sub_comm, entry.get_mut());
200            }
201        }
202    }
203}
204
205impl SubCommData {
206    /// Get the [`CommData`] for a phase.
207    pub fn get(&self, phase: &str) -> Option<&CommData> {
208        self.0.get(phase)
209    }
210
211    /// Iterate over all [`CommData`].
212    pub fn iter(&self) -> impl Iterator<Item = &CommData> {
213        self.0.values()
214    }
215}
216
217impl AddAssign for Counter {
218    fn add_assign(&mut self, rhs: Self) {
219        self.bytes += rhs.bytes;
220        self.bytes_with_sub_comm += rhs.bytes_with_sub_comm;
221    }
222}
223
224impl AddAssign<u64> for Counter {
225    fn add_assign(&mut self, rhs: u64) {
226        self.bytes += rhs;
227        self.bytes_with_sub_comm += rhs;
228    }
229}
230
231impl CommData {
232    fn new(phase: String) -> Self {
233        Self {
234            phase,
235            ..Default::default()
236        }
237    }
238}
239
240struct PhaseVisitor(Option<String>);
241
242impl Visit for PhaseVisitor {
243    fn record_str(&mut self, field: &Field, value: &str) {
244        if field.name() == "phase" {
245            self.0 = Some(value.to_owned());
246        }
247    }
248
249    fn record_debug(&mut self, field: &Field, value: &dyn Debug) {
250        if field.name() == "phase" {
251            self.0 = Some(format!("{value:?}"));
252        }
253    }
254}
255
/// A single communication measurement carried by a `cryprot_metrics` event.
enum CommEvent {
    /// Number of bytes read (`bytes_read` field).
    Read(u64),
    /// Number of bytes written (`bytes_written` field).
    Write(u64),
}

/// Event visitor that extracts at most one [`CommEvent`] from an event's
/// fields.
struct CommEventVisitor(Option<CommEvent>);
262
263impl CommEventVisitor {
264    fn record<T>(&mut self, field: &Field, value: T)
265    where
266        T: TryInto<u64>,
267        T::Error: Debug,
268    {
269        let name = field.name();
270        if name != "bytes_written" && name != "bytes_read" {
271            return;
272        }
273        let value = value
274            .try_into()
275            .expect("recorded bytes must be convertible to u64");
276        if name == "bytes_written" {
277            self.0 = Some(CommEvent::Write(value))
278        } else if name == "bytes_read" {
279            self.0 = Some(CommEvent::Read(value))
280        }
281    }
282}
283
284impl Visit for CommEventVisitor {
285    fn record_i64(&mut self, field: &Field, value: i64) {
286        self.record(field, value);
287    }
288    fn record_u64(&mut self, field: &Field, value: u64) {
289        self.record(field, value)
290    }
291    fn record_i128(&mut self, field: &Field, value: i128) {
292        self.record(field, value)
293    }
294    fn record_u128(&mut self, field: &Field, value: u128) {
295        self.record(field, value)
296    }
297    fn record_debug(&mut self, field: &Field, value: &dyn Debug) {
298        warn!(
299            "cryprot_metrics event with field which is not an integer. {}: {:?}",
300            field.name(),
301            value
302        )
303    }
304}
305
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::{self, join, time::sleep};
    use tracing::{Instrument, Level, event, instrument, subscriber::DefaultGuard};
    use tracing_subscriber::{Registry, layer::SubscriberExt};

    use crate::metrics::{CommLayerData, new_comm_layer};

    /// Installs a fresh metrics layer on top of a `Registry` as this thread's
    /// default subscriber.
    ///
    /// Keep the returned guard alive for as long as metrics should be
    /// collected. The returned handle gives access to the recorded data.
    fn install_comm_layer() -> (DefaultGuard, CommLayerData) {
        let (comm_layer, comm_data) = new_comm_layer();
        let guard = tracing::subscriber::set_default(Registry::default().with(comm_layer));
        (guard, comm_data)
    }

    #[tokio::test]
    async fn test_communication_metrics() {
        #[instrument(target = "cryprot_metrics", fields(phase = "TopLevel"))]
        async fn top_level_operation() {
            // Communication attributed directly to the top level phase.
            event!(target: "cryprot_metrics", Level::TRACE, bytes_read = 100);
            event!(target: "cryprot_metrics", Level::TRACE, bytes_written = 200);
            // Communication of a nested phase.
            sub_operation().await;
        }

        #[instrument(target = "cryprot_metrics", fields(phase = "SubOperation"))]
        async fn sub_operation() {
            event!(target: "cryprot_metrics", Level::TRACE, bytes_read = 50);
            event!(target: "cryprot_metrics", Level::TRACE, bytes_written = 100);
        }

        let (_guard, comm_data) = install_comm_layer();
        top_level_operation().await;

        let metrics = comm_data.comm_data();
        let top = metrics
            .get("TopLevel")
            .expect("TopLevel phase should exist");
        assert_eq!(top.phase, "TopLevel");
        // Direct communication only.
        assert_eq!((top.read.bytes, top.write.bytes), (100, 200));
        // Direct plus nested: 100 + 50 read, 200 + 100 written.
        assert_eq!(
            (top.read.bytes_with_sub_comm, top.write.bytes_with_sub_comm),
            (150, 300)
        );

        let sub = top
            .sub_comm_data
            .get("SubOperation")
            .expect("SubOperation phase should exist");
        assert_eq!(sub.phase, "SubOperation");
        // A leaf phase has no nested communication, so both counters agree.
        assert_eq!((sub.read.bytes, sub.write.bytes), (50, 100));
        assert_eq!(
            (sub.read.bytes_with_sub_comm, sub.write.bytes_with_sub_comm),
            (50, 100)
        );

        // `reset` hands out the collected data and leaves nothing behind.
        assert!(comm_data.reset().get("TopLevel").is_some());
        assert!(comm_data.comm_data().get("TopLevel").is_none());
    }

    #[tokio::test]
    async fn test_parallel_span_accumulation() {
        #[instrument(target = "cryprot_metrics", fields(phase = "ParentPhase"))]
        async fn parallel_operation(id: u32) {
            // A spawned task does not pick up the current span by itself. The
            // future is instrumented with it to keep the phase hierarchy.
            let handle = tokio::spawn(sub_operation(id).in_current_span());
            handle.await.unwrap();
        }

        #[instrument(target = "cryprot_metrics", fields(phase = "SubPhase"))]
        async fn sub_operation(id: u32) {
            event!(target: "cryprot_metrics", Level::TRACE, bytes_written = 100);
            event!(target: "cryprot_metrics", Level::TRACE, bytes_read = 50);
            // Keep the span open for a moment so both operations overlap.
            sleep(Duration::from_millis(10)).await;
        }

        let (_guard, comm_data) = install_comm_layer();
        join!(parallel_operation(1), parallel_operation(2));

        let metrics = comm_data.comm_data();
        let parent = metrics
            .get("ParentPhase")
            .expect("ParentPhase should exist");
        let sub = parent
            .sub_comm_data
            .get("SubPhase")
            .expect("SubPhase should exist");

        // Two operations, each writing 100 and reading 50 bytes in the sub phase.
        assert_eq!(
            sub.write.bytes, 200,
            "Expected accumulated writes from both parallel operations"
        );
        assert_eq!(
            sub.read.bytes, 100,
            "Expected accumulated reads from both parallel operations"
        );

        // All communication happened in the sub phase and shows up in the
        // parent's cumulative totals.
        assert_eq!(
            parent.write.bytes_with_sub_comm, 200,
            "Parent should include all sub-phase writes"
        );
        assert_eq!(
            parent.read.bytes_with_sub_comm, 100,
            "Parent should include all sub-phase reads"
        );
    }
}