rapx/check/senryx/
driver.rs

//! Top-level driver for Senryx verification and annotation discovery.
//!
//! `SenryxCheck` collects target functions, prepares supporting analyses, runs
//! path-sensitive `BodyVisitor` checks, and formats the resulting contract
//! diagnostics.
6
7use super::{
8    dominated_graph::InterResultNode,
9    visitor::{BodyVisitor, CheckResult},
10};
11use rustc_data_structures::fx::FxHashMap;
12use rustc_hir::{Safety, def_id::DefId};
13use rustc_middle::{
14    mir::{BasicBlock, Operand, TerminatorKind},
15    ty::{self, TyCtxt},
16};
17use std::collections::HashSet;
18
19use crate::analysis::{
20    Analysis,
21    core::alias_analysis::{AliasAnalysis, FnAliasPairs, default::AliasAnalyzer},
22    upg::{fn_collector::FnCollector, hir_visitor::ContainsUnsafe},
23    utils::fn_info::*,
24};
25
/// Emits a formatted log message whose severity is picked at the call site.
///
/// The first argument is a boolean expression. When it evaluates to `true`
/// the remaining format arguments are forwarded to `rap_warn!`; otherwise they
/// are forwarded to `rap_info!`.
macro_rules! cond_print {
    ($cond:expr, $($t:tt)*) => {
        if $cond {
            rap_warn!($($t)*)
        } else {
            rap_info!($($t)*)
        }
    };
}
29
/// Controls how aggressively Senryx filters candidate verification targets.
///
/// Within this file the level is only consulted by
/// `SenryxCheck::filter_by_check_level`: `High` restricts the analysis to
/// functions that pass the visibility check, while `Medium` and `Low`
/// currently accept every function.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckLevel {
    /// Only functions that pass the visibility filter are analyzed.
    High,
    /// No filtering is applied (currently behaves exactly like `Low`).
    Medium,
    /// No filtering is applied.
    Low,
}
36
37/// Entry point for running Senryx analyses over a Rust crate.
38pub struct SenryxCheck<'tcx> {
39    pub tcx: TyCtxt<'tcx>,
40    pub threshhold: usize,
41}
42
43impl<'tcx> SenryxCheck<'tcx> {
44    /// Create a new SenryxCheck instance.
45    ///
46    /// Parameters:
47    /// - `tcx`: compiler TyCtxt for querying types/definitions.
48    /// - `threshhold`: a numeric threshold used by checks.
49    pub fn new(tcx: TyCtxt<'tcx>, threshhold: usize) -> Self {
50        Self { tcx, threshhold }
51    }
52
53    /// Start the checking pass over the collected functions.
54    ///
55    /// - `check_level` controls filtering of which functions to analyze.
56    /// - `is_verify` toggles verification mode (vs. annotation mode).
57    pub fn start(&mut self, check_level: CheckLevel, is_verify: bool) {
58        let tcx = self.tcx;
59        // Build alias information for all functions first.
60        let mut analyzer = AliasAnalyzer::new(self.tcx);
61        analyzer.run(); // populate alias results
62        let fn_map = &analyzer.get_all_fn_alias();
63
64        // Collect functions of interest (e.g. from UPG/collector)
65        let related_items = FnCollector::collect(tcx);
66        for vec in related_items.clone().values() {
67            for (body_id, _span) in vec {
68                // Check whether the function/block contains unsafe code
69                let (function_unsafe, block_unsafe) =
70                    ContainsUnsafe::contains_unsafe(tcx, *body_id);
71
72                let def_id = tcx.hir_body_owner_def_id(*body_id).to_def_id();
73
74                // Gather std unsafe callees used by this function
75                let std_unsafe_callee = get_all_std_unsafe_callees(self.tcx, def_id);
76
77                // Apply filtering by configured check level
78                if !Self::filter_by_check_level(tcx, &check_level, def_id) {
79                    continue;
80                }
81
82                // If the body-level contains unsafe ops and we are verifying, run soundness checks
83                if block_unsafe && is_verify && !std_unsafe_callee.is_empty() {
84                    self.check_soundness(def_id, fn_map);
85                }
86
87                // In non-verify mode we might annotate or produce diagnostics (disabled here)
88                if function_unsafe && !is_verify && !std_unsafe_callee.is_empty() {
89                    // annotation or other non-verification actions can be placed here
90                }
91            }
92        }
93    }
94
95    /// Iterate standard library `alloc` functions and run verification for those
96    /// that match the verification target predicate.
97    pub fn start_analyze_std_func(&mut self) {
98        // Gather function definitions from the `alloc` crate
99        let v_fn_def: Vec<_> = rustc_public::find_crates("alloc")
100            .iter()
101            .flat_map(|krate| krate.fn_defs())
102            .collect();
103        for fn_def in &v_fn_def {
104            let def_id = crate::def_id::to_internal(fn_def, self.tcx);
105            if is_verify_target_func(self.tcx, def_id) {
106                rap_info!(
107                    "Begin verification process for: {:?}",
108                    get_cleaned_def_path_name(self.tcx, def_id)
109                );
110
111                // Run main body visitor/check for this def_id
112                let check_results = self.body_visit_and_check(def_id, &FxHashMap::default());
113                if !check_results.is_empty() {
114                    Self::show_check_results(self.tcx, def_id, check_results);
115                }
116            }
117        }
118    }
119
120    /// Analyze unsafe call chains across standard library functions and print
121    /// the last non-intrinsic nodes for manual inspection.
122    pub fn start_analyze_std_func_chains(&mut self) {
123        let all_std_fn_def = get_all_std_fns_by_rustc_public(self.tcx);
124        let mut last_nodes = HashSet::new();
125        for &def_id in &all_std_fn_def {
126            // Skip non-public functions based on visibility filter
127            if !check_visibility(self.tcx, def_id) {
128                continue;
129            }
130
131            // Get unsafe call chains for the function
132            let chains = get_all_std_unsafe_chains(self.tcx, def_id);
133
134            // Filter out trivial chains unless the function is explicitly unsafe
135            let valid_chains: Vec<Vec<String>> = chains
136                .into_iter()
137                .filter(|chain| {
138                    if chain.len() > 1 {
139                        return true;
140                    }
141                    if chain.len() == 1 {
142                        if check_safety(self.tcx, def_id) == Safety::Unsafe {
143                            return true;
144                        }
145                    }
146                    false
147                })
148                .collect();
149
150            // Collect last nodes that are relevant for further inspection
151            let mut last = true;
152            for chain in &valid_chains {
153                if let Some(last_node) = chain.last() {
154                    if !last_node.contains("intrinsic") && !last_node.contains("aarch64") {
155                        last_nodes.insert(last_node.clone());
156                        last = false;
157                    }
158                }
159            }
160            if last {
161                continue;
162            }
163        }
164        Self::print_last_nodes(&last_nodes);
165    }
166
167    /// Pretty-print a set of last nodes discovered in unsafe call chains.
168    pub fn print_last_nodes(last_nodes: &HashSet<String>) {
169        if last_nodes.is_empty() {
170            println!("No unsafe call chain last nodes found.");
171            return;
172        }
173
174        println!(
175            "Found {} unique unsafe call chain last nodes:",
176            last_nodes.len()
177        );
178        for (i, node) in last_nodes.iter().enumerate() {
179            println!("{}. {}", i + 1, node);
180        }
181    }
182
183    /// Filter functions by configured check level.
184    /// - High: only publicly visible functions are considered.
185    /// - Medium/Low: accept all functions.
186    pub fn filter_by_check_level(
187        tcx: TyCtxt<'tcx>,
188        check_level: &CheckLevel,
189        def_id: DefId,
190    ) -> bool {
191        match *check_level {
192            CheckLevel::High => check_visibility(tcx, def_id),
193            _ => true,
194        }
195    }
196
197    /// Run soundness checks on a single function identified by `def_id` using
198    /// the provided alias analysis map `fn_map`.
199    pub fn check_soundness(&mut self, def_id: DefId, fn_map: &FxHashMap<DefId, FnAliasPairs>) {
200        let check_results = self.body_visit_and_check(def_id, fn_map);
201        let tcx = self.tcx;
202        if !check_results.is_empty() {
203            // Display aggregated results for this function
204            Self::show_check_results(tcx, def_id, check_results);
205        }
206    }
207
208    /// Collect safety annotations for `def_id` and display them if present.
209    pub fn annotate_safety(&self, def_id: DefId) {
210        let annotation_results = self.get_annotation(def_id);
211        if annotation_results.is_empty() {
212            return;
213        }
214        Self::show_annotate_results(self.tcx, def_id, annotation_results);
215    }
216
217    /// Visit the function body and run path-sensitive checks, returning
218    /// a list of `CheckResult`s summarizing passed/failed contracts.
219    ///
220    /// If the function is a method, constructor results are merged into the
221    /// method's initial state before analyzing the method body.
222    pub fn body_visit_and_check(
223        &mut self,
224        def_id: DefId,
225        fn_map: &FxHashMap<DefId, FnAliasPairs>,
226    ) -> Vec<CheckResult> {
227        // Create a body visitor for the target function
228        let mut body_visitor = BodyVisitor::new(self.tcx, def_id, 0);
229        let target_name = get_cleaned_def_path_name(self.tcx, def_id);
230        rap_info!("Begin verification process for: {:?}", target_name);
231
232        // If this is a method, gather constructor-derived state first
233        if get_type(self.tcx, def_id) == FnKind::Method {
234            let cons = get_cons(self.tcx, def_id);
235            // Start with a default inter-result node for ADT fields
236            let mut base_inter_result = InterResultNode::new_default(get_adt_ty(self.tcx, def_id));
237            for con in cons {
238                let mut cons_body_visitor = BodyVisitor::new(self.tcx, con, 0);
239                // Analyze constructor and merge its field states
240                let cons_fields_result = cons_body_visitor.path_forward_check(fn_map);
241                // cache and merge fields' states
242                let cons_name = get_cleaned_def_path_name(self.tcx, con);
243                println!(
244                    "cons {cons_name} state results {:?}",
245                    cons_fields_result.clone()
246                );
247                base_inter_result.merge(cons_fields_result);
248            }
249
250            // Seed the method visitor with constructor-derived field states
251            body_visitor.update_fields_states(base_inter_result);
252
253            // Optionally inspect mutable methods - diagnostic only
254            let mutable_methods = get_all_mutable_methods(self.tcx, def_id);
255            for mm in mutable_methods {
256                println!("mut method {:?}", get_cleaned_def_path_name(self.tcx, mm.0));
257            }
258
259            // Analyze the method body
260            body_visitor.path_forward_check(fn_map);
261        } else {
262            // Non-method functions: just analyze body directly
263            body_visitor.path_forward_check(fn_map);
264        }
265        body_visitor.check_results
266    }
267
268    /// Variant of `body_visit_and_check` used for UI-guided annotation flows.
269    pub fn body_visit_and_check_uig(&self, def_id: DefId) {
270        let func_type = get_type(self.tcx, def_id);
271        if func_type == FnKind::Method && !self.get_annotation(def_id).is_empty() {
272            let func_cons = search_constructor(self.tcx, def_id);
273            for func_con in func_cons {
274                if check_safety(self.tcx, func_con) == Safety::Unsafe {
275                    // Display annotations for unsafe constructors
276                    Self::show_annotate_results(self.tcx, func_con, self.get_annotation(def_id));
277                }
278            }
279        }
280    }
281
282    /// Collect annotation strings for a function by scanning calls in MIR.
283    /// For each call, if the callee has a safety annotation it is aggregated; otherwise
284    /// the callee's annotations (recursively) are collected.
285    pub fn get_annotation(&self, def_id: DefId) -> HashSet<String> {
286        let mut results = HashSet::new();
287        if !self.tcx.is_mir_available(def_id) {
288            return results;
289        }
290        let body = self.tcx.optimized_mir(def_id);
291        let basicblocks = &body.basic_blocks;
292        for i in 0..basicblocks.len() {
293            let iter = BasicBlock::from(i);
294            let terminator = basicblocks[iter].terminator.clone().unwrap();
295            if let TerminatorKind::Call {
296                ref func,
297                args: _,
298                destination: _,
299                target: _,
300                unwind: _,
301                call_source: _,
302                fn_span: _,
303            } = terminator.kind
304            {
305                match func {
306                    Operand::Constant(c) => {
307                        if let ty::FnDef(id, ..) = c.ty().kind() {
308                            // If the callee has direct annotations, extend results.
309                            if !get_sp(self.tcx, *id).is_empty() {
310                                results.extend(get_sp(self.tcx, *id));
311                            } else {
312                                // Otherwise, recurse into callee's annotations.
313                                results.extend(self.get_annotation(*id));
314                            }
315                        }
316                    }
317                    _ => {}
318                }
319            }
320        }
321        results
322    }
323
324    /// Pretty-print aggregated check results for a function.
325    /// Shows succeeded and failed contracts grouped across all arguments.
326    pub fn show_check_results(tcx: TyCtxt<'tcx>, def_id: DefId, check_results: Vec<CheckResult>) {
327        rap_info!(
328            "--------In safe function {:?}---------",
329            get_cleaned_def_path_name(tcx, def_id)
330        );
331        for check_result in &check_results {
332            // Aggregate all failed contracts from all arguments
333            let mut all_failed = HashSet::new();
334            for set in check_result.failed_contracts.values() {
335                for sp in set {
336                    all_failed.insert(sp);
337                }
338            }
339
340            // Aggregate all passed contracts from all arguments
341            let mut all_passed = HashSet::new();
342            for set in check_result.passed_contracts.values() {
343                for sp in set {
344                    all_passed.insert(sp);
345                }
346            }
347
348            // Print the API name with conditional coloring
349            cond_print!(
350                !all_failed.is_empty(),
351                "  Use unsafe api {:?}.",
352                check_result.func_name
353            );
354
355            // Print aggregated Failed set
356            if !all_failed.is_empty() {
357                let mut failed_sorted: Vec<&String> = all_failed.into_iter().collect();
358                failed_sorted.sort();
359                cond_print!(true, "      Failed: {:?}", failed_sorted);
360            }
361
362            // Print aggregated Passed set
363            if !all_passed.is_empty() {
364                let mut passed_sorted: Vec<&String> = all_passed.into_iter().collect();
365                passed_sorted.sort();
366                cond_print!(false, "      Passed: {:?}", passed_sorted);
367            }
368        }
369    }
370
371    /// Show annotation results for unsafe functions (diagnostic output).
372    pub fn show_annotate_results(
373        tcx: TyCtxt<'tcx>,
374        def_id: DefId,
375        annotation_results: HashSet<String>,
376    ) {
377        rap_info!(
378            "--------In unsafe function {:?}---------",
379            get_cleaned_def_path_name(tcx, def_id)
380        );
381        rap_warn!("Lack safety annotations: {:?}.", annotation_results);
382    }
383}