rapx/verify/
helpers.rs

1use rustc_abi::FieldIdx;
2use rustc_hir::{Safety, def_id::DefId};
3use rustc_middle::{
4    mir::{BasicBlock, Operand, TerminatorKind, UnwindAction},
5    ty::{self, Ty, TyCtxt, TyKind},
6};
7use rustc_span::Span;
8use serde_json::Value;
9use std::collections::HashSet;
10use syn::Expr;
11
12/// Stable MIR location for a call terminator inside one function body.
13#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
14pub struct CallsiteLocation {
15    /// Function containing the call terminator.
16    pub caller: DefId,
17    /// Basic block whose terminator is the call.
18    pub block: BasicBlock,
19}
20
21/// A concrete unsafe callsite in one MIR body.
22#[derive(Clone, Debug)]
23pub struct Callsite<'tcx> {
24    /// Function containing this call.
25    pub caller: DefId,
26    /// Unsafe callee being invoked.
27    pub callee: DefId,
28    /// MIR block whose terminator is the call.
29    pub block: BasicBlock,
30    /// Source span attached to the MIR call terminator.
31    pub span: Span,
32    /// MIR operands passed to the callee.
33    pub args: Vec<Operand<'tcx>>,
34}
35
36impl<'tcx> Callsite<'tcx> {
37    /// Return the MIR location that identifies this callsite inside the verifier.
38    pub fn location(&self) -> CallsiteLocation {
39        CallsiteLocation {
40            caller: self.caller,
41            block: self.block,
42        }
43    }
44
45    /// Return a stable human-readable callee path for diagnostics.
46    pub fn callee_name(&self, tcx: TyCtxt<'tcx>) -> String {
47        get_cleaned_def_path_name(tcx, self.callee)
48    }
49}
50
51/// Collect all unsafe MIR callsites in `def_id`.
52pub fn collect_unsafe_callsites<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Vec<Callsite<'tcx>> {
53    let mut callsites = Vec::new();
54    if !tcx.is_mir_available(def_id) {
55        return callsites;
56    }
57
58    let body = tcx.optimized_mir(def_id);
59    for (bb, data) in body.basic_blocks.iter_enumerated() {
60        let TerminatorKind::Call {
61            func,
62            args,
63            fn_span,
64            ..
65        } = &data.terminator().kind
66        else {
67            continue;
68        };
69
70        let Operand::Constant(func_constant) = func else {
71            continue;
72        };
73
74        let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() else {
75            continue;
76        };
77
78        if check_safety(tcx, *callee_def_id) != Safety::Unsafe {
79            continue;
80        }
81
82        callsites.push(Callsite {
83            caller: def_id,
84            callee: *callee_def_id,
85            block: bb,
86            span: *fn_span,
87            args: args.iter().map(|arg| arg.node.clone()).collect(),
88        });
89    }
90
91    callsites
92}
93
94/// Return the set of unsafe callees invoked by `def_id`.
95pub fn get_unsafe_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
96    let mut unsafe_callees = HashSet::new();
97    if tcx.is_mir_available(def_id) {
98        let body = tcx.optimized_mir(def_id);
99        for bb in body.basic_blocks.iter() {
100            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
101                if let Operand::Constant(func_constant) = func {
102                    if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
103                        if check_safety(tcx, *callee_def_id) == Safety::Unsafe {
104                            unsafe_callees.insert(*callee_def_id);
105                        }
106                    }
107                }
108            }
109        }
110    }
111    unsafe_callees
112}
113
114/// A compact MIR CFG used by the verifier path extractor.
115#[derive(Clone, Debug)]
116pub struct CFG {
117    pub entry: BasicBlock,
118    pub successors: Vec<Vec<BasicBlock>>,
119}
120
121impl CFG {
122    /// Build a successor graph from optimized MIR.
123    pub fn new(tcx: TyCtxt<'_>, def_id: DefId) -> Self {
124        let body = tcx.optimized_mir(def_id);
125        let successors = body
126            .basic_blocks
127            .iter()
128            .map(|block| terminator_successors(&block.terminator().kind))
129            .collect();
130        Self {
131            entry: BasicBlock::from_usize(0),
132            successors,
133        }
134    }
135
136    /// Return successors of a block.
137    pub fn successors(&self, block: BasicBlock) -> &[BasicBlock] {
138        self.successors
139            .get(block.as_usize())
140            .map(Vec::as_slice)
141            .unwrap_or(&[])
142    }
143}
144
145/// Compute MIR successor blocks for one terminator.
146///
147/// The extractor includes normal successors and cleanup successors so the
148/// skeleton reflects all CFG edges that can affect reachability.  Later phases
149/// may decide whether a cleanup path is relevant to a particular obligation.
150fn terminator_successors(kind: &TerminatorKind<'_>) -> Vec<BasicBlock> {
151    let mut successors = Vec::new();
152    match kind {
153        TerminatorKind::Goto { target } => successors.push(*target),
154        TerminatorKind::SwitchInt { targets, .. } => {
155            successors.extend(targets.all_targets().iter().copied());
156        }
157        TerminatorKind::Drop { target, unwind, .. }
158        | TerminatorKind::Assert { target, unwind, .. } => {
159            successors.push(*target);
160            push_unwind_target(unwind, &mut successors);
161        }
162        TerminatorKind::Call { target, unwind, .. } => {
163            if let Some(target) = target {
164                successors.push(*target);
165            }
166            push_unwind_target(unwind, &mut successors);
167        }
168        TerminatorKind::Yield { resume, drop, .. } => {
169            successors.push(*resume);
170            if let Some(drop) = drop {
171                successors.push(*drop);
172            }
173        }
174        TerminatorKind::FalseEdge { real_target, .. } => successors.push(*real_target),
175        TerminatorKind::FalseUnwind {
176            real_target,
177            unwind,
178        } => {
179            successors.push(*real_target);
180            push_unwind_target(unwind, &mut successors);
181        }
182        TerminatorKind::InlineAsm {
183            targets, unwind, ..
184        } => {
185            successors.extend(targets.iter().copied());
186            push_unwind_target(unwind, &mut successors);
187        }
188        TerminatorKind::Return
189        | TerminatorKind::Unreachable
190        | TerminatorKind::UnwindResume
191        | TerminatorKind::UnwindTerminate(_)
192        | TerminatorKind::CoroutineDrop
193        | TerminatorKind::TailCall { .. } => {}
194    }
195    successors.sort_unstable_by_key(|bb| bb.as_usize());
196    successors.dedup();
197    successors
198}
199
200/// Append a cleanup unwind target when one exists.
201fn push_unwind_target(unwind: &UnwindAction, successors: &mut Vec<BasicBlock>) {
202    if let UnwindAction::Cleanup(target) = unwind {
203        successors.push(*target);
204    }
205}
206
207pub fn get_cleaned_def_path_name(tcx: TyCtxt<'_>, def_id: DefId) -> String {
208    let def_id_str = format!("{:?}", def_id);
209    let mut parts: Vec<&str> = def_id_str.split("::").collect();
210
211    let mut remove_first = false;
212    if let Some(first_part) = parts.get_mut(0) {
213        if first_part.contains("core") {
214            *first_part = "core";
215        } else if first_part.contains("std") {
216            *first_part = "std";
217        } else if first_part.contains("alloc") {
218            *first_part = "alloc";
219        } else {
220            remove_first = true;
221        }
222    }
223    if remove_first && !parts.is_empty() {
224        parts.remove(0);
225    }
226
227    let new_parts: Vec<String> = parts
228        .into_iter()
229        .filter_map(|s| {
230            if s.contains("{") {
231                if remove_first {
232                    get_struct_name(tcx, def_id)
233                } else {
234                    None
235                }
236            } else {
237                Some(s.to_string())
238            }
239        })
240        .collect();
241
242    let mut cleaned_path = new_parts.join("::");
243    cleaned_path = cleaned_path.trim_end_matches(')').to_string();
244    cleaned_path
245}
246
247pub fn parse_expr_into_local_and_ty<'tcx>(
248    tcx: TyCtxt<'tcx>,
249    def_id: DefId,
250    expr: &Expr,
251) -> Option<(usize, Vec<(usize, Ty<'tcx>)>, Ty<'tcx>)> {
252    if let Some((base_ident, fields)) = access_ident_recursive(expr) {
253        let (param_names, param_tys) = parse_signature(tcx, def_id);
254        if param_names[0] != "0" {
255            if let Some(param_index) = param_names.iter().position(|name| name == &base_ident) {
256                return resolve_projection_from_base_ident(
257                    tcx,
258                    base_ident,
259                    fields,
260                    param_index + 1,
261                    param_tys[param_index],
262                );
263            }
264        }
265
266        if let Some(struct_ty) = get_struct_self_ty(tcx, def_id) {
267            return resolve_projection_from_struct_ident(tcx, base_ident, fields, struct_ty);
268        }
269    }
270    None
271}
272
273pub fn parse_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
274    if def_id.as_local().is_some() {
275        parse_local_signature(tcx, def_id)
276    } else {
277        parse_outside_signature(tcx, def_id)
278    }
279}
280
281pub fn parse_local_signature<'tcx>(
282    tcx: TyCtxt<'tcx>,
283    def_id: DefId,
284) -> (Vec<String>, Vec<Ty<'tcx>>) {
285    let local_def_id = def_id.as_local().unwrap();
286    let hir_body = tcx.hir_body_owned_by(local_def_id);
287    if hir_body.params.is_empty() {
288        return (vec!["0".to_string()], Vec::new());
289    }
290
291    let params = hir_body.params;
292    let typeck_results = tcx.typeck_body(hir_body.id());
293    let mut param_names = Vec::new();
294    let mut param_tys = Vec::new();
295    for param in params {
296        match param.pat.kind {
297            rustc_hir::PatKind::Binding(_, _, ident, _) => {
298                param_names.push(ident.name.to_string());
299                let ty = typeck_results.pat_ty(param.pat);
300                param_tys.push(ty);
301            }
302            _ => {
303                param_names.push(String::new());
304                param_tys.push(typeck_results.pat_ty(param.pat));
305            }
306        }
307    }
308    (param_names, param_tys)
309}
310
311fn parse_outside_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
312    let sig = tcx.fn_sig(def_id).skip_binder();
313    let param_tys: Vec<Ty<'tcx>> = sig.inputs().skip_binder().iter().copied().collect();
314
315    if let Some(args_name) = get_known_std_names(tcx, def_id) {
316        return (args_name, param_tys);
317    }
318
319    let args_name = (0..param_tys.len()).map(|i| format!("{}", i)).collect();
320    (args_name, param_tys)
321}
322
323fn get_known_std_names<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Vec<String>> {
324    let std_func_name = get_cleaned_def_path_name(tcx, def_id);
325    let json_data = get_std_api_signature_json();
326
327    if let Some(arg_info) = json_data.get(&std_func_name) {
328        if let Some(args_name) = arg_info.as_array() {
329            if args_name.is_empty() {
330                return Some(vec!["0".to_string()]);
331            }
332            let mut result = Vec::new();
333            for arg in args_name {
334                if let Some(sp_name) = arg.as_str() {
335                    result.push(sp_name.to_string());
336                }
337            }
338            return Some(result);
339        }
340    }
341    None
342}
343
344fn get_std_api_signature_json() -> Value {
345    serde_json::from_str(include_str!("../analysis/utils/data/std_sig.json"))
346        .expect("Unable to parse JSON")
347}
348
349pub fn access_ident_recursive(expr: &Expr) -> Option<(String, Vec<String>)> {
350    match expr {
351        Expr::Path(syn::ExprPath { path, .. }) => {
352            if path.segments.len() == 1 {
353                let ident = path.segments[0].ident.to_string();
354                Some((ident, Vec::new()))
355            } else {
356                None
357            }
358        }
359        Expr::Field(syn::ExprField { base, member, .. }) => {
360            let (base_ident, mut fields) =
361                if let Some((base_ident, fields)) = access_ident_recursive(base) {
362                    (base_ident, fields)
363                } else {
364                    return None;
365                };
366            let field_name = match member {
367                syn::Member::Named(ident) => ident.to_string(),
368                syn::Member::Unnamed(index) => index.index.to_string(),
369            };
370            fields.push(field_name);
371            Some((base_ident, fields))
372        }
373        _ => None,
374    }
375}
376
377pub fn parse_expr_into_number(expr: &Expr) -> Option<usize> {
378    if let Expr::Lit(expr_lit) = expr {
379        if let syn::Lit::Int(lit_int) = &expr_lit.lit {
380            return lit_int.base10_parse::<usize>().ok();
381        }
382    }
383    None
384}
385
386pub fn match_ty_with_ident<'tcx>(
387    tcx: TyCtxt<'tcx>,
388    def_id: DefId,
389    type_ident: String,
390) -> Option<Ty<'tcx>> {
391    if let Some(primitive_ty) = match_primitive_type(tcx, type_ident.clone()) {
392        return Some(primitive_ty);
393    }
394    find_generic_param(tcx, def_id, type_ident)
395}
396
397fn match_primitive_type<'tcx>(tcx: TyCtxt<'tcx>, type_ident: String) -> Option<Ty<'tcx>> {
398    match type_ident.as_str() {
399        "i8" => Some(tcx.types.i8),
400        "i16" => Some(tcx.types.i16),
401        "i32" => Some(tcx.types.i32),
402        "i64" => Some(tcx.types.i64),
403        "i128" => Some(tcx.types.i128),
404        "isize" => Some(tcx.types.isize),
405        "u8" => Some(tcx.types.u8),
406        "u16" => Some(tcx.types.u16),
407        "u32" => Some(tcx.types.u32),
408        "u64" => Some(tcx.types.u64),
409        "u128" => Some(tcx.types.u128),
410        "usize" => Some(tcx.types.usize),
411        "f16" => Some(tcx.types.f16),
412        "f32" => Some(tcx.types.f32),
413        "f64" => Some(tcx.types.f64),
414        "f128" => Some(tcx.types.f128),
415        "bool" => Some(tcx.types.bool),
416        "char" => Some(tcx.types.char),
417        "str" => Some(tcx.types.str_),
418        _ => None,
419    }
420}
421
422fn find_generic_param<'tcx>(
423    tcx: TyCtxt<'tcx>,
424    def_id: DefId,
425    type_ident: String,
426) -> Option<Ty<'tcx>> {
427    let (_, param_tys) = parse_signature(tcx, def_id);
428    for &ty in &param_tys {
429        if let Some(found) = find_generic_in_ty(tcx, ty, &type_ident) {
430            return Some(found);
431        }
432    }
433
434    if let Some(struct_ty) = get_struct_self_ty(tcx, def_id) {
435        if let Some(found) = find_generic_in_ty(tcx, struct_ty, &type_ident) {
436            return Some(found);
437        }
438    }
439
440    None
441}
442
443fn find_generic_in_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, type_ident: &str) -> Option<Ty<'tcx>> {
444    match ty.kind() {
445        TyKind::Param(param_ty) => {
446            if param_ty.name.as_str() == type_ident {
447                return Some(ty);
448            }
449        }
450        TyKind::RawPtr(ty, _)
451        | TyKind::Ref(_, ty, _)
452        | TyKind::Slice(ty)
453        | TyKind::Array(ty, _) => {
454            if let Some(found) = find_generic_in_ty(tcx, *ty, type_ident) {
455                return Some(found);
456            }
457        }
458        TyKind::Tuple(tys) => {
459            for tuple_ty in tys.iter() {
460                if let Some(found) = find_generic_in_ty(tcx, tuple_ty, type_ident) {
461                    return Some(found);
462                }
463            }
464        }
465        TyKind::Adt(adt_def, substs) => {
466            let name = tcx.item_name(adt_def.did()).to_string();
467            if name == type_ident {
468                return Some(ty);
469            }
470            for field in adt_def.all_fields() {
471                let field_ty = field.ty(tcx, substs);
472                if let Some(found) = find_generic_in_ty(tcx, field_ty, type_ident) {
473                    return Some(found);
474                }
475            }
476        }
477        _ => {}
478    }
479    None
480}
481
482fn get_struct_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
483    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
484        if let Some(impl_id) = assoc_item.impl_container(tcx) {
485            let ty = tcx.type_of(impl_id).skip_binder();
486            let type_name = ty.to_string();
487            let struct_name = type_name
488                .split('<')
489                .next()
490                .unwrap_or("")
491                .split("::")
492                .last()
493                .unwrap_or("")
494                .to_string();
495            return Some(struct_name);
496        }
497    }
498    None
499}
500
501fn get_struct_self_ty<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Ty<'tcx>> {
502    let assoc_item = tcx.opt_associated_item(def_id)?;
503    let impl_id = assoc_item.impl_container(tcx)?;
504    let self_ty = tcx.type_of(impl_id).skip_binder();
505    match self_ty.kind() {
506        TyKind::Adt(_, _) => Some(self_ty),
507        _ => None,
508    }
509}
510
511fn resolve_projection_from_base_ident<'tcx>(
512    tcx: TyCtxt<'tcx>,
513    base_ident: String,
514    fields: Vec<String>,
515    base_local: usize,
516    base_ty: Ty<'tcx>,
517) -> Option<(usize, Vec<(usize, Ty<'tcx>)>, Ty<'tcx>)> {
518    let mut current_ty = base_ty;
519    let mut field_indices = Vec::new();
520    for field_name in fields {
521        let Some((field_idx, field_ty)) = resolve_next_field(tcx, current_ty, &field_name) else {
522            return if field_indices.is_empty() && base_ident.is_empty() {
523                None
524            } else {
525                None
526            };
527        };
528        current_ty = field_ty;
529        field_indices.push((field_idx, current_ty));
530    }
531    Some((base_local, field_indices, current_ty))
532}
533
534fn resolve_projection_from_struct_ident<'tcx>(
535    tcx: TyCtxt<'tcx>,
536    base_ident: String,
537    fields: Vec<String>,
538    struct_ty: Ty<'tcx>,
539) -> Option<(usize, Vec<(usize, Ty<'tcx>)>, Ty<'tcx>)> {
540    let Some((field_idx, field_ty)) = resolve_next_field(tcx, struct_ty, &base_ident) else {
541        return None;
542    };
543
544    let mut current_ty = field_ty;
545    let mut field_indices = vec![(field_idx, current_ty)];
546    for field_name in fields {
547        let Some((next_field_idx, next_field_ty)) =
548            resolve_next_field(tcx, current_ty, &field_name)
549        else {
550            return None;
551        };
552        current_ty = next_field_ty;
553        field_indices.push((next_field_idx, current_ty));
554    }
555
556    Some((1, field_indices, current_ty))
557}
558
559fn resolve_next_field<'tcx>(
560    tcx: TyCtxt<'tcx>,
561    base_ty: Ty<'tcx>,
562    field_name: &str,
563) -> Option<(usize, Ty<'tcx>)> {
564    let peeled_ty = base_ty.peel_refs();
565    if let TyKind::Adt(adt_def, arg_list) = *peeled_ty.kind() {
566        let variant = adt_def.non_enum_variant();
567        if let Ok(field_idx) = field_name.parse::<usize>() {
568            if field_idx < variant.fields.len() {
569                let field_ty = variant.fields[FieldIdx::from_usize(field_idx)].ty(tcx, arg_list);
570                return Some((field_idx, field_ty));
571            }
572        }
573        if let Some((idx, _)) = variant
574            .fields
575            .iter()
576            .enumerate()
577            .find(|(_, f)| f.ident(tcx).name.to_string() == field_name)
578        {
579            let field_ty = variant.fields[FieldIdx::from_usize(idx)].ty(tcx, arg_list);
580            return Some((idx, field_ty));
581        }
582    }
583    None
584}
585
586pub(crate) fn check_safety(tcx: TyCtxt<'_>, def_id: DefId) -> Safety {
587    let poly_fn_sig = tcx.fn_sig(def_id);
588    let fn_sig = poly_fn_sig.skip_binder();
589    fn_sig.safety()
590}