1use std::{
25 collections::{BTreeMap, btree_map::Entry},
26 fmt::Debug,
27 mem,
28 ops::AddAssign,
29 sync::{Arc, Mutex},
30};
31
32use serde::{Deserialize, Serialize};
33use tracing::{
34 Level,
35 field::{Field, Visit},
36 span::{Attributes, Id},
37 warn,
38};
39use tracing_subscriber::{
40 filter::{Filtered, Targets},
41 layer::{Context, Layer},
42};
43
/// Communication recorded for one phase, i.e. for spans carrying a `phase` field.
///
/// Spans that share the same phase name under the same parent are merged into a single
/// `CommData` when they close.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct CommData {
    /// Name of the phase, taken from the span's `phase` field.
    pub phase: String,
    /// Bytes reported through `bytes_read` events.
    pub read: Counter,
    /// Bytes reported through `bytes_written` events.
    pub write: Counter,
    /// Nested phases, keyed by phase name.
    pub sub_comm_data: SubCommData,
}
52
/// Byte counter for one direction (read or write) of a phase.
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
pub struct Counter {
    /// Bytes recorded by events emitted directly inside spans of this phase.
    pub bytes: u64,
    /// `bytes` plus the bytes of all (transitively) nested sub-phases.
    pub bytes_with_sub_comm: u64,
}
60
/// Per-phase [`CommData`], keyed by phase name (so iteration is in phase-name order).
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct SubCommData(BTreeMap<String, CommData>);
64
/// The tracing layer produced by [`new_comm_layer`]: [`CommLayerData`] wrapped in a
/// per-layer [`Targets`] filter.
pub type CommLayer<S> = Filtered<CommLayerData, Targets, S>;
68
/// Layer state and user-facing handle for the collected communication metrics.
///
/// Clones share the same storage (the data lives behind an `Arc`), so the handle returned by
/// [`new_comm_layer`] sees everything recorded by the installed layer.
#[derive(Clone, Debug, Default)]
pub struct CommLayerData {
    /// Accumulated data of top-level phases, keyed by phase name.
    comm_data: Arc<Mutex<SubCommData>>,
}
75
76pub fn new_comm_layer<S>() -> (CommLayer<S>, CommLayerData)
78where
79 S: tracing::Subscriber,
80 S: for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
81{
82 let inner = CommLayerData::default();
83 let target_filter = Targets::new().with_target("cryprot_metrics", Level::TRACE);
84 (inner.clone().with_filter(target_filter), inner)
85}
86
87impl CommLayerData {
88 pub fn comm_data(&self) -> SubCommData {
90 self.comm_data.lock().expect("lock poisoned").clone()
91 }
92
93 pub fn reset(&self) -> SubCommData {
98 let mut comm_data = self.comm_data.lock().expect("lock poisoned");
99 mem::take(&mut *comm_data)
100 }
101}
102
103impl<S> Layer<S> for CommLayerData
104where
105 S: tracing::Subscriber,
106 S: for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
107{
108 fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
109 let span = ctx.span(id).expect("Id is valid");
110 let mut visitor = PhaseVisitor(None);
111 attrs.record(&mut visitor);
112 if let Some(phase) = visitor.0 {
113 let data = CommData::new(phase);
114 span.extensions_mut().insert(data);
115 }
116 }
117
118 fn on_event(&self, event: &tracing::Event<'_>, ctx: Context<'_, S>) {
119 let Some(span) = ctx.event_span(event) else {
120 warn!(
121 "Received cryprot_metrics event outside of cryprot_metrics span. \
122 Communication is not tracked"
123 );
124 return;
125 };
126 let field_cnt = event
129 .fields()
130 .filter(|field| field.name() == "bytes_read" || field.name() == "bytes_written")
131 .count();
132 if field_cnt >= 2 {
133 warn!("Use individual events to record bytes_read and bytes_written");
134 return;
135 }
136 let mut vis = CommEventVisitor(None);
137 event.record(&mut vis);
138 if let Some(event) = vis.0 {
139 let mut extensions = span.extensions_mut();
140 let Some(comm_data) = extensions.get_mut::<CommData>() else {
141 warn!(
142 "Received cryprot_metrics event inside cryprot_metrics span with no phase. \
143 Communication is not tracked"
144 );
145 return;
146 };
147 match event {
148 CommEvent::Read(read) => {
149 comm_data.read += read;
150 }
151 CommEvent::Write(written) => {
152 comm_data.write += written;
153 }
154 }
155 }
156 }
157
158 fn on_close(&self, id: Id, ctx: Context<'_, S>) {
159 let span = ctx.span(&id).expect("Id is valid");
160 let mut extensions = span.extensions_mut();
161 let Some(comm_data) = extensions.get_mut::<CommData>().map(mem::take) else {
162 return;
164 };
165
166 if let Some(parent) = span.parent() {
169 if let Some(parent_comm_data) = parent.extensions_mut().get_mut::<CommData>() {
170 let entry = parent_comm_data
171 .sub_comm_data
172 .0
173 .entry(comm_data.phase.clone())
174 .or_insert_with(|| CommData::new(comm_data.phase.clone()));
175 parent_comm_data.read.bytes_with_sub_comm += comm_data.read.bytes_with_sub_comm;
176 parent_comm_data.write.bytes_with_sub_comm += comm_data.write.bytes_with_sub_comm;
177 merge(comm_data, entry)
178 }
179 } else {
180 let mut root_comm_data = self.comm_data.lock().expect("lock poisoned");
181 let phase_comm_data = root_comm_data
182 .0
183 .entry(comm_data.phase.clone())
184 .or_insert_with(|| CommData::new(comm_data.phase.clone()));
185 merge(comm_data, phase_comm_data);
186 }
187 }
188}
189
190fn merge(from: CommData, into: &mut CommData) {
191 into.read += from.read;
192 into.write += from.write;
193 for (phase, from_sub_comm) in from.sub_comm_data.0.into_iter() {
194 match into.sub_comm_data.0.entry(phase) {
195 Entry::Vacant(entry) => {
196 entry.insert(from_sub_comm);
197 }
198 Entry::Occupied(mut entry) => {
199 merge(from_sub_comm, entry.get_mut());
200 }
201 }
202 }
203}
204
205impl SubCommData {
206 pub fn get(&self, phase: &str) -> Option<&CommData> {
208 self.0.get(phase)
209 }
210
211 pub fn iter(&self) -> impl Iterator<Item = &CommData> {
213 self.0.values()
214 }
215}
216
217impl AddAssign for Counter {
218 fn add_assign(&mut self, rhs: Self) {
219 self.bytes += rhs.bytes;
220 self.bytes_with_sub_comm += rhs.bytes_with_sub_comm;
221 }
222}
223
224impl AddAssign<u64> for Counter {
225 fn add_assign(&mut self, rhs: u64) {
226 self.bytes += rhs;
227 self.bytes_with_sub_comm += rhs;
228 }
229}
230
231impl CommData {
232 fn new(phase: String) -> Self {
233 Self {
234 phase,
235 ..Default::default()
236 }
237 }
238}
239
/// Field visitor that captures a span's `phase` field as a `String` (`None` if absent).
struct PhaseVisitor(Option<String>);
241
242impl Visit for PhaseVisitor {
243 fn record_str(&mut self, field: &Field, value: &str) {
244 if field.name() == "phase" {
245 self.0 = Some(value.to_owned());
246 }
247 }
248
249 fn record_debug(&mut self, field: &Field, value: &dyn Debug) {
250 if field.name() == "phase" {
251 self.0 = Some(format!("{value:?}"));
252 }
253 }
254}
255
/// A single byte count extracted from a `cryprot_metrics` event.
enum CommEvent {
    /// Value of a `bytes_read` field.
    Read(u64),
    /// Value of a `bytes_written` field.
    Write(u64),
}
260
/// Field visitor that turns a `bytes_read`/`bytes_written` event field into a [`CommEvent`].
///
/// Holds `None` if no byte count was recorded.
struct CommEventVisitor(Option<CommEvent>);
262
263impl CommEventVisitor {
264 fn record<T>(&mut self, field: &Field, value: T)
265 where
266 T: TryInto<u64>,
267 T::Error: Debug,
268 {
269 let name = field.name();
270 if name != "bytes_written" && name != "bytes_read" {
271 return;
272 }
273 let value = value
274 .try_into()
275 .expect("recorded bytes must be convertible to u64");
276 if name == "bytes_written" {
277 self.0 = Some(CommEvent::Write(value))
278 } else if name == "bytes_read" {
279 self.0 = Some(CommEvent::Read(value))
280 }
281 }
282}
283
284impl Visit for CommEventVisitor {
285 fn record_i64(&mut self, field: &Field, value: i64) {
286 self.record(field, value);
287 }
288 fn record_u64(&mut self, field: &Field, value: u64) {
289 self.record(field, value)
290 }
291 fn record_i128(&mut self, field: &Field, value: i128) {
292 self.record(field, value)
293 }
294 fn record_u128(&mut self, field: &Field, value: u128) {
295 self.record(field, value)
296 }
297 fn record_debug(&mut self, field: &Field, value: &dyn Debug) {
298 warn!(
299 "cryprot_metrics event with field which is not an integer. {}: {:?}",
300 field.name(),
301 value
302 )
303 }
304}
305
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::{join, time::sleep};
    use tracing::{Instrument, Level, event, instrument, subscriber::DefaultGuard};
    use tracing_subscriber::{Registry, layer::SubscriberExt};

    use super::{CommLayerData, new_comm_layer};

    /// Installs a fresh communication layer as this thread's default subscriber.
    ///
    /// Metrics are only collected while the returned guard is alive.
    fn install_comm_layer() -> (DefaultGuard, CommLayerData) {
        let (comm_layer, comm_data) = new_comm_layer();
        let guard = tracing::subscriber::set_default(Registry::default().with(comm_layer));
        (guard, comm_data)
    }

    #[tokio::test]
    async fn test_communication_metrics() {
        #[instrument(target = "cryprot_metrics", fields(phase = "TopLevel"))]
        async fn top_level_operation() {
            event!(target: "cryprot_metrics", Level::TRACE, bytes_read = 100);
            event!(target: "cryprot_metrics", Level::TRACE, bytes_written = 200);
            sub_operation().await;
        }

        #[instrument(target = "cryprot_metrics", fields(phase = "SubOperation"))]
        async fn sub_operation() {
            event!(target: "cryprot_metrics", Level::TRACE, bytes_read = 50);
            event!(target: "cryprot_metrics", Level::TRACE, bytes_written = 100);
        }

        let (_guard, comm_data) = install_comm_layer();
        top_level_operation().await;
        let metrics = comm_data.comm_data();

        // Top level: own traffic, plus the nested phase in `bytes_with_sub_comm`.
        let top_phase = metrics
            .get("TopLevel")
            .expect("TopLevel phase should exist");
        assert_eq!(top_phase.phase, "TopLevel");
        assert_eq!(top_phase.read.bytes, 100);
        assert_eq!(top_phase.write.bytes, 200);
        assert_eq!(top_phase.read.bytes_with_sub_comm, 150);
        assert_eq!(top_phase.write.bytes_with_sub_comm, 300);

        // Nested phase: it has no children, so both counters agree.
        let sub_phase = top_phase
            .sub_comm_data
            .get("SubOperation")
            .expect("SubOperation phase should exist");
        assert_eq!(sub_phase.phase, "SubOperation");
        assert_eq!(sub_phase.read.bytes, 50);
        assert_eq!(sub_phase.write.bytes, 100);
        assert_eq!(sub_phase.read.bytes_with_sub_comm, 50);
        assert_eq!(sub_phase.write.bytes_with_sub_comm, 100);

        // `reset` hands out the collected data and leaves the store empty.
        let reset_metrics = comm_data.reset();
        assert!(reset_metrics.get("TopLevel").is_some());
        assert!(comm_data.comm_data().get("TopLevel").is_none());
    }

    #[tokio::test]
    async fn test_parallel_span_accumulation() {
        #[instrument(target = "cryprot_metrics", fields(phase = "ParentPhase"))]
        async fn parallel_operation(id: u32) {
            let child = tokio::spawn(sub_operation(id).in_current_span());
            child.await.unwrap();
        }

        #[instrument(target = "cryprot_metrics", fields(phase = "SubPhase"))]
        async fn sub_operation(id: u32) {
            event!(target: "cryprot_metrics", Level::TRACE, bytes_written = 100);
            event!(target: "cryprot_metrics", Level::TRACE, bytes_read = 50);
            sleep(Duration::from_millis(10)).await;
        }

        let (_guard, comm_data) = install_comm_layer();
        join!(parallel_operation(1), parallel_operation(2));

        let metrics = comm_data.comm_data();
        let phase = metrics
            .get("ParentPhase")
            .expect("ParentPhase should exist");
        let sub_phase = phase
            .sub_comm_data
            .get("SubPhase")
            .expect("SubPhase should exist");

        // Both same-named phases are merged into a single entry.
        assert_eq!(
            sub_phase.write.bytes, 200,
            "Expected accumulated writes from both parallel operations"
        );
        assert_eq!(
            sub_phase.read.bytes, 100,
            "Expected accumulated reads from both parallel operations"
        );
        assert_eq!(
            phase.write.bytes_with_sub_comm, 200,
            "Parent should include all sub-phase writes"
        );
        assert_eq!(
            phase.read.bytes_with_sub_comm, 100,
            "Parent should include all sub-phase reads"
        );
    }
}