rapx/analysis/utils/
fn_info.rs

1use super::draw_dot::render_dot_string;
2use crate::def_id::*;
3use crate::{
4    analysis::core::dataflow::{DataFlowAnalysis, default::DataFlowAnalyzer},
5    check::senryx::{callsite::has_unsafe_api_contract, contract::PropertyContract},
6};
7use crate::{rap_debug, rap_warn};
8use rustc_ast::ItemKind;
9use rustc_data_structures::fx::{FxHashMap, FxHashSet};
10use rustc_hir::{
11    Attribute, ImplItemKind, Safety,
12    def::DefKind,
13    def_id::{CrateNum, DefId, DefIndex},
14};
15use rustc_middle::{
16    hir::place::PlaceBase,
17    mir::{
18        BasicBlock, BinOp, Body, Local, Operand, Place, PlaceElem, PlaceRef, ProjectionElem,
19        Rvalue, StatementKind, Terminator, TerminatorKind,
20    },
21    ty,
22    ty::{AssocKind, ConstKind, Mutability, Ty, TyCtxt, TyKind},
23};
24use rustc_span::{def_id::LocalDefId, kw, sym};
25use serde::de;
26use serde::{Deserialize, Serialize};
27use std::{
28    collections::{HashMap, HashSet},
29    fmt::Debug,
30    hash::Hash,
31};
32use syn::Expr;
33
/// Coarse classification of a function-like definition.
///
/// `get_type` in this file produces `Method`, `Constructor` and `Fn`;
/// nothing in the visible part of this file produces `Intrinsic`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FnKind {
    /// A free function, or an associated function not classified otherwise.
    Fn,
    /// An associated function that takes a `self` receiver.
    Method,
    /// An associated function without `self` whose return type mentions the
    /// implementing type.
    Constructor,
    /// NOTE(review): presumably a compiler intrinsic — confirm against users.
    Intrinsic,
}
41
42#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
43pub struct FnInfo {
44    pub def_id: DefId,
45    pub fn_safety: Safety,
46    pub fn_kind: FnKind,
47}
48
49impl FnInfo {
50    pub fn new(def_id: DefId, fn_safety: Safety, fn_kind: FnKind) -> Self {
51        FnInfo {
52            def_id,
53            fn_safety,
54            fn_kind,
55        }
56    }
57}
58
59#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
60pub struct AdtInfo {
61    pub def_id: DefId,
62    pub literal_cons_enabled: bool,
63}
64
65impl AdtInfo {
66    pub fn new(def_id: DefId, literal_cons_enabled: bool) -> Self {
67        AdtInfo {
68            def_id,
69            literal_cons_enabled,
70        }
71    }
72}
73
74pub fn check_visibility(tcx: TyCtxt, func_defid: DefId) -> bool {
75    if !tcx.visibility(func_defid).is_public() {
76        return false;
77    }
78    true
79}
80
81pub fn is_re_exported(tcx: TyCtxt, target_defid: DefId, module_defid: LocalDefId) -> bool {
82    for child in tcx.module_children_local(module_defid) {
83        if child.vis.is_public() {
84            if let Some(def_id) = child.res.opt_def_id() {
85                if def_id == target_defid {
86                    return true;
87                }
88            }
89        }
90    }
91    false
92}
93
/// Prints every element of `set` with its `Debug` formatting, one per line
/// (in the set's iteration order), followed by a separator line.
pub fn print_hashset<T: std::fmt::Debug>(set: &HashSet<T>) {
    set.iter().for_each(|element| println!("{:?}", element));
    println!("---------------");
}
100
101pub fn get_cleaned_def_path_name_ori(tcx: TyCtxt, def_id: DefId) -> String {
102    let def_id_str = format!("{:?}", def_id);
103    let mut parts: Vec<&str> = def_id_str.split("::").collect();
104
105    let mut remove_first = false;
106    if let Some(first_part) = parts.get_mut(0) {
107        if first_part.contains("core") {
108            *first_part = "core";
109        } else if first_part.contains("std") {
110            *first_part = "std";
111        } else if first_part.contains("alloc") {
112            *first_part = "alloc";
113        } else {
114            remove_first = true;
115        }
116    }
117    if remove_first && !parts.is_empty() {
118        parts.remove(0);
119    }
120
121    let new_parts: Vec<String> = parts
122        .into_iter()
123        .filter_map(|s| {
124            if s.contains("{") {
125                if remove_first {
126                    get_struct_name(tcx, def_id)
127                } else {
128                    None
129                }
130            } else {
131                Some(s.to_string())
132            }
133        })
134        .collect();
135
136    let mut cleaned_path = new_parts.join("::");
137    cleaned_path = cleaned_path.trim_end_matches(')').to_string();
138    cleaned_path
139}
140
141pub fn get_sp_tags_json() -> serde_json::Value {
142    let json_data: serde_json::Value =
143        serde_json::from_str(include_str!("data/std_sps.json")).expect("Unable to parse JSON");
144    json_data
145}
146
147pub fn get_std_api_signature_json() -> serde_json::Value {
148    let json_data: serde_json::Value =
149        serde_json::from_str(include_str!("data/std_sig.json")).expect("Unable to parse JSON");
150    json_data
151}
152
153pub fn get_sp_tags_and_args_json() -> serde_json::Value {
154    let json_data: serde_json::Value =
155        serde_json::from_str(include_str!("data/std_sps_args.json")).expect("Unable to parse JSON");
156    json_data
157}
158
159#[derive(Debug, Serialize, Deserialize, Clone)]
160pub struct ContractEntry {
161    pub tag: String,
162    pub args: Vec<String>,
163}
164
165pub fn get_std_contracts(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<ContractEntry> {
166    let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
167    let json_data: serde_json::Value = get_sp_tags_and_args_json();
168
169    if let Some(entries) = json_data.get(&cleaned_path_name) {
170        if let Ok(contracts) = serde_json::from_value::<Vec<ContractEntry>>(entries.clone()) {
171            return contracts;
172        }
173    }
174    Vec::new()
175}
176
177pub fn get_sp(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<String> {
178    let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
179    let json_data: serde_json::Value = get_sp_tags_json();
180
181    if let Some(function_info) = json_data.get(&cleaned_path_name) {
182        if let Some(sp_list) = function_info.get("0") {
183            let mut result = HashSet::new();
184            if let Some(sp_array) = sp_list.as_array() {
185                for sp in sp_array {
186                    if let Some(sp_name) = sp.as_str() {
187                        result.insert(sp_name.to_string());
188                    }
189                }
190            }
191            return result;
192        }
193    }
194    HashSet::new()
195}
196
197pub fn get_struct_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
198    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
199        if let Some(impl_id) = assoc_item.impl_container(tcx) {
200            let ty = tcx.type_of(impl_id).skip_binder();
201            let type_name = ty.to_string();
202            let struct_name = type_name
203                .split('<')
204                .next()
205                .unwrap_or("")
206                .split("::")
207                .last()
208                .unwrap_or("")
209                .to_string();
210
211            return Some(struct_name);
212        }
213    }
214    None
215}
216
217pub fn check_safety(tcx: TyCtxt<'_>, def_id: DefId) -> Safety {
218    let poly_fn_sig = tcx.fn_sig(def_id);
219    let fn_sig = poly_fn_sig.skip_binder();
220    fn_sig.safety()
221}
222
223pub fn get_type(tcx: TyCtxt<'_>, def_id: DefId) -> FnKind {
224    let mut node_type = 2;
225    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
226        match assoc_item.kind {
227            AssocKind::Fn { has_self, .. } => {
228                if has_self {
229                    return FnKind::Method;
230                } else {
231                    let fn_sig = tcx.fn_sig(def_id).skip_binder();
232                    let output = fn_sig.output().skip_binder();
233                    // return type is 'Self'
234                    if output.is_param(0) {
235                        return FnKind::Constructor;
236                    }
237                    // return type is struct's name
238                    if let Some(impl_id) = assoc_item.impl_container(tcx) {
239                        let ty = tcx.type_of(impl_id).skip_binder();
240                        if output == ty {
241                            return FnKind::Constructor;
242                        }
243                    }
244                    match output.kind() {
245                        TyKind::Ref(_, ref_ty, _) => {
246                            if ref_ty.is_param(0) {
247                                return FnKind::Constructor;
248                            }
249                            if let Some(impl_id) = assoc_item.impl_container(tcx) {
250                                let ty = tcx.type_of(impl_id).skip_binder();
251                                if *ref_ty == ty {
252                                    return FnKind::Constructor;
253                                }
254                            }
255                        }
256                        TyKind::Adt(adt_def, substs) => {
257                            if adt_def.is_enum()
258                                && (tcx.is_diagnostic_item(sym::Option, adt_def.did())
259                                    || tcx.is_diagnostic_item(sym::Result, adt_def.did())
260                                    || tcx.is_diagnostic_item(kw::Box, adt_def.did()))
261                            {
262                                let inner_ty = substs.type_at(0);
263                                if inner_ty.is_param(0) {
264                                    return FnKind::Constructor;
265                                }
266                                if let Some(impl_id) = assoc_item.impl_container(tcx) {
267                                    let ty_impl = tcx.type_of(impl_id).skip_binder();
268                                    if inner_ty == ty_impl {
269                                        return FnKind::Constructor;
270                                    }
271                                }
272                            }
273                        }
274                        _ => {}
275                    }
276                }
277            }
278            _ => todo!(),
279        }
280    }
281    return FnKind::Fn;
282}
283
284pub fn get_adt_ty(tcx: TyCtxt, def_id: DefId) -> Option<Ty> {
285    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
286        if let Some(impl_id) = assoc_item.impl_container(tcx) {
287            return Some(tcx.type_of(impl_id).skip_binder());
288        }
289    }
290    None
291}
292
293// check whether this adt contains a literal constructor
294// result: adt_def_id, is_literal
295pub fn get_adt_via_method(tcx: TyCtxt<'_>, method_def_id: DefId) -> Option<AdtInfo> {
296    let assoc_item = tcx.opt_associated_item(method_def_id)?;
297    let impl_id = assoc_item.impl_container(tcx)?;
298    let ty = tcx.type_of(impl_id).skip_binder();
299    let adt_def = ty.ty_adt_def()?;
300    let adt_def_id = adt_def.did();
301
302    let all_fields: Vec<_> = adt_def.all_fields().collect();
303    let total_count = all_fields.len();
304
305    if total_count == 0 {
306        return Some(AdtInfo::new(adt_def_id, true));
307    }
308
309    let pub_count = all_fields
310        .iter()
311        .filter(|field| tcx.visibility(field.did).is_public())
312        .count();
313
314    if pub_count == 0 {
315        return None;
316    }
317    Some(AdtInfo::new(adt_def_id, pub_count == total_count))
318}
319
320fn place_has_raw_deref<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: &Place<'tcx>) -> bool {
321    let mut local = place.local;
322    for proj in place.projection.iter() {
323        if let ProjectionElem::Deref = proj.kind() {
324            let ty = body.local_decls[local].ty;
325            if let TyKind::RawPtr(_, _) = ty.kind() {
326                return true;
327            }
328        }
329    }
330    false
331}
332
333/// Analyzes the MIR of the given function to collect all local variables
334/// that are involved in dereferencing raw pointers (`*const T` or `*mut T`).
335pub fn get_rawptr_deref(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<Local> {
336    let mut raw_ptrs = HashSet::new();
337    if tcx.is_mir_available(def_id) {
338        let body = tcx.optimized_mir(def_id);
339        for bb in body.basic_blocks.iter() {
340            for stmt in &bb.statements {
341                if let StatementKind::Assign(box (lhs, rhs)) = &stmt.kind {
342                    if place_has_raw_deref(tcx, &body, lhs) {
343                        raw_ptrs.insert(lhs.local);
344                    }
345                    if let Rvalue::Use(op) = rhs {
346                        match op {
347                            Operand::Copy(place) | Operand::Move(place) => {
348                                if place_has_raw_deref(tcx, &body, place) {
349                                    raw_ptrs.insert(place.local);
350                                }
351                            }
352                            _ => {}
353                        }
354                    }
355                    if let Rvalue::Ref(_, _, place) = rhs {
356                        if place_has_raw_deref(tcx, &body, place) {
357                            raw_ptrs.insert(place.local);
358                        }
359                    }
360                }
361            }
362            if let Some(terminator) = &bb.terminator {
363                match &terminator.kind {
364                    rustc_middle::mir::TerminatorKind::Call { args, .. } => {
365                        for arg in args {
366                            match arg.node {
367                                Operand::Copy(place) | Operand::Move(place) => {
368                                    if place_has_raw_deref(tcx, &body, &place) {
369                                        raw_ptrs.insert(place.local);
370                                    }
371                                }
372                                _ => {}
373                            }
374                        }
375                    }
376                    _ => {}
377                }
378            }
379        }
380    }
381    raw_ptrs
382}
383
384/* Example mir of static mutable access.
385
386static mut COUNTER: i32 = {
387    let mut _0: i32;
388
389    bb0: {
390        _0 = const 0_i32;
391        return;
392    }
393}
394
395fn main() -> () {
396    let mut _0: ();
397    let mut _1: *mut i32;
398
399    bb0: {
400        StorageLive(_1);
401        _1 = const {alloc1: *mut i32};
402        (*_1) = const 1_i32;
403        StorageDead(_1);
404        return;
405    }
406}
407
408alloc1 (static: COUNTER, size: 4, align: 4) {
409    00 00 00 00                                     │ ....
410}
411
412*/
413
414/// Collects pairs of global static variables and their corresponding local variables
415/// within a function's MIR that are assigned from statics.
416pub fn collect_global_local_pairs(tcx: TyCtxt<'_>, def_id: DefId) -> HashMap<DefId, Vec<Local>> {
417    let mut globals: HashMap<DefId, Vec<Local>> = HashMap::new();
418
419    if !tcx.is_mir_available(def_id) {
420        return globals;
421    }
422
423    let body = tcx.optimized_mir(def_id);
424
425    for bb in body.basic_blocks.iter() {
426        for stmt in &bb.statements {
427            if let StatementKind::Assign(box (lhs, rhs)) = &stmt.kind {
428                if let Rvalue::Use(Operand::Constant(c)) = rhs {
429                    if let Some(static_def_id) = c.check_static_ptr(tcx) {
430                        globals.entry(static_def_id).or_default().push(lhs.local);
431                    }
432                }
433            }
434        }
435    }
436
437    globals
438}
439
440pub fn get_unsafe_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
441    let mut unsafe_callees = HashSet::new();
442    if tcx.is_mir_available(def_id) {
443        let body = tcx.optimized_mir(def_id);
444        for bb in body.basic_blocks.iter() {
445            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
446                if let Operand::Constant(func_constant) = func {
447                    if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
448                        if check_safety(tcx, *callee_def_id) == Safety::Unsafe {
449                            unsafe_callees.insert(*callee_def_id);
450                        }
451                    }
452                }
453            }
454        }
455    }
456    unsafe_callees
457}
458
459pub fn get_all_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
460    let mut callees = HashSet::new();
461    if tcx.is_mir_available(def_id) {
462        let body = tcx.optimized_mir(def_id);
463        for bb in body.basic_blocks.iter() {
464            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
465                if let Operand::Constant(func_constant) = func {
466                    if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
467                        callees.insert(*callee_def_id);
468                    }
469                }
470            }
471        }
472    }
473    callees
474}
475
476// return all the impls def id of corresponding struct
477pub fn get_impls_for_struct(tcx: TyCtxt<'_>, struct_def_id: DefId) -> Vec<DefId> {
478    let mut impls = Vec::new();
479    for item_id in tcx.hir_crate_items(()).free_items() {
480        let item = tcx.hir_item(item_id);
481        if let rustc_hir::ItemKind::Impl(impl_details) = &item.kind {
482            if let rustc_hir::TyKind::Path(rustc_hir::QPath::Resolved(_, path)) =
483                &impl_details.self_ty.kind
484            {
485                if let rustc_hir::def::Res::Def(_, def_id) = path.res {
486                    if def_id == struct_def_id {
487                        impls.push(item_id.owner_id.to_def_id());
488                    }
489                }
490            }
491        }
492    }
493    impls
494}
495
496pub fn get_adt_def_id_by_adt_method(tcx: TyCtxt<'_>, def_id: DefId) -> Option<DefId> {
497    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
498        if let Some(impl_id) = assoc_item.impl_container(tcx) {
499            // get struct ty
500            let ty = tcx.type_of(impl_id).skip_binder();
501            if let Some(adt_def) = ty.ty_adt_def() {
502                return Some(adt_def.did());
503            }
504        }
505    }
506    None
507}
508
509// get the pointee or wrapped type
510pub fn get_pointee(matched_ty: Ty<'_>) -> Ty<'_> {
511    // progress_info!("get_pointee: > {:?} as type: {:?}", matched_ty, matched_ty.kind());
512    let pointee = if let ty::RawPtr(ty_mut, _) = matched_ty.kind() {
513        get_pointee(*ty_mut)
514    } else if let ty::Ref(_, referred_ty, _) = matched_ty.kind() {
515        get_pointee(*referred_ty)
516    } else {
517        matched_ty
518    };
519    pointee
520}
521
522pub fn is_ptr(matched_ty: Ty<'_>) -> bool {
523    if let ty::RawPtr(_, _) = matched_ty.kind() {
524        return true;
525    }
526    false
527}
528
529pub fn is_ref(matched_ty: Ty<'_>) -> bool {
530    if let ty::Ref(_, _, _) = matched_ty.kind() {
531        return true;
532    }
533    false
534}
535
536pub fn is_slice(matched_ty: Ty) -> Option<Ty> {
537    if let ty::Slice(inner) = matched_ty.kind() {
538        return Some(*inner);
539    }
540    None
541}
542
543pub fn has_mut_self_param(tcx: TyCtxt, def_id: DefId) -> bool {
544    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
545        match assoc_item.kind {
546            AssocKind::Fn { has_self, .. } => {
547                if has_self && tcx.is_mir_available(def_id) {
548                    let body = tcx.optimized_mir(def_id);
549                    let fst_arg = body.local_decls[Local::from_usize(1)].clone();
550                    let ty = fst_arg.ty;
551                    let is_mut_ref =
552                        matches!(ty.kind(), ty::Ref(_, _, mutbl) if *mutbl == Mutability::Mut);
553                    return fst_arg.mutability.is_mut() || is_mut_ref;
554                }
555            }
556            _ => (),
557        }
558    }
559    false
560}
561
562// Check each field's visibility, return the public fields vec
563pub fn get_public_fields(tcx: TyCtxt, def_id: DefId) -> HashSet<usize> {
564    let adt_def = tcx.adt_def(def_id);
565    adt_def
566        .all_fields()
567        .enumerate()
568        .filter_map(|(index, field_def)| tcx.visibility(field_def.did).is_public().then_some(index))
569        .collect()
570}
571
/// Prints every `key: value` pair of `map` with `Debug` formatting, one per
/// line, ordered by key and indented by two spaces per `level`.
pub fn display_hashmap<K, V>(map: &HashMap<K, V>, level: usize)
where
    K: Ord + Debug + Hash,
    V: Debug,
{
    let prefix = "  ".repeat(level);

    // Keys of a map are unique, so ordering the pairs by key is unambiguous.
    let mut entries: Vec<(&K, &V)> = map.iter().collect();
    entries.sort_by(|left, right| left.0.cmp(right.0));

    for (key, value) in entries {
        println!("{}{:?}: {:?}", prefix, key, value);
    }
}
588
589pub fn match_std_unsafe_chains_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
590    let mut results = Vec::new();
591    if let TerminatorKind::Call { func, .. } = &terminator.kind {
592        if let Operand::Constant(func_constant) = func {
593            if let ty::FnDef(callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
594                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
595            }
596        }
597    }
598    results
599}
600
601pub fn get_all_std_unsafe_callees(tcx: TyCtxt, def_id: DefId) -> Vec<String> {
602    let mut results = Vec::new();
603    let body = tcx.optimized_mir(def_id);
604    let bb_len = body.basic_blocks.len();
605    for i in 0..bb_len {
606        let callees = match_std_unsafe_callee(
607            tcx,
608            body.basic_blocks[BasicBlock::from_usize(i)]
609                .clone()
610                .terminator(),
611        );
612        results.extend(callees);
613    }
614    results
615}
616
617pub fn get_all_std_unsafe_callees_block_id(tcx: TyCtxt, def_id: DefId) -> Vec<usize> {
618    let mut results = Vec::new();
619    let body = tcx.optimized_mir(def_id);
620    let bb_len = body.basic_blocks.len();
621    for i in 0..bb_len {
622        if match_std_unsafe_callee(
623            tcx,
624            body.basic_blocks[BasicBlock::from_usize(i)]
625                .clone()
626                .terminator(),
627        )
628        .is_empty()
629        {
630            results.push(i);
631        }
632    }
633    results
634}
635
636pub fn match_std_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
637    let mut results = Vec::new();
638    if let TerminatorKind::Call { func, .. } = &terminator.kind {
639        if let Operand::Constant(func_constant) = func {
640            if let ty::FnDef(callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
641                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
642                // rap_info!("{func_name}");
643                if has_unsafe_api_contract(&func_name) {
644                    results.push(func_name);
645                }
646            }
647        }
648    }
649    results
650}
651
652// Bug definition: (1) strict -> weak & dst is mutable;
653//                 (2) _ -> strict
654pub fn is_strict_ty_convert<'tcx>(tcx: TyCtxt<'tcx>, src_ty: Ty<'tcx>, dst_ty: Ty<'tcx>) -> bool {
655    (is_strict_ty(tcx, src_ty) && dst_ty.is_mutable_ptr()) || is_strict_ty(tcx, dst_ty)
656}
657
658// strict ty: bool, str, adt fields containing bool or str;
659pub fn is_strict_ty<'tcx>(tcx: TyCtxt<'tcx>, ori_ty: Ty<'tcx>) -> bool {
660    let ty = get_pointee(ori_ty);
661    let mut flag = false;
662    if let TyKind::Adt(adt_def, substs) = ty.kind() {
663        if adt_def.is_struct() {
664            for field_def in adt_def.all_fields() {
665                flag |= is_strict_ty(tcx, field_def.ty(tcx, substs))
666            }
667        }
668    }
669    ty.is_bool() || ty.is_str() || flag
670}
671
672pub fn reverse_op(op: BinOp) -> BinOp {
673    match op {
674        BinOp::Lt => BinOp::Ge,
675        BinOp::Ge => BinOp::Lt,
676        BinOp::Le => BinOp::Gt,
677        BinOp::Gt => BinOp::Le,
678        BinOp::Eq => BinOp::Eq,
679        BinOp::Ne => BinOp::Ne,
680        _ => op,
681    }
682}
683
684/// Generate contracts from pre-defined std-lib JSON configuration (std_sps_args.json).
685pub fn generate_contract_from_std_annotation_json(
686    tcx: TyCtxt<'_>,
687    def_id: DefId,
688) -> Vec<(usize, Vec<usize>, PropertyContract<'_>)> {
689    let mut results = Vec::new();
690    let std_contracts = get_std_contracts(tcx, def_id);
691
692    for entry in std_contracts {
693        let tag_name = entry.tag;
694        let raw_args = entry.args;
695
696        if raw_args.is_empty() {
697            continue;
698        }
699
700        let arg_index_str = &raw_args[0];
701        let local_id = if let Ok(arg_idx) = arg_index_str.parse::<usize>() {
702            arg_idx
703        } else {
704            rap_error!(
705                "JSON Contract Error: First argument must be an arg index number, got {}",
706                arg_index_str
707            );
708            continue;
709        };
710
711        let mut exprs: Vec<Expr> = Vec::new();
712        for arg_str in &raw_args {
713            match syn::parse_str::<Expr>(arg_str) {
714                Ok(expr) => exprs.push(expr),
715                Err(_) => {
716                    rap_error!(
717                        "JSON Contract Error: Failed to parse arg '{}' as Rust Expr for tag {}",
718                        arg_str,
719                        tag_name
720                    );
721                }
722            }
723        }
724
725        // Robustness check of arguments transition
726        if exprs.len() != raw_args.len() {
727            rap_error!(
728                "Parse std API args error: Failed to parse arg '{:?}'",
729                raw_args
730            );
731            continue;
732        }
733        let fields: Vec<usize> = Vec::new();
734        let contract = PropertyContract::new(tcx, def_id, &tag_name, &exprs);
735        results.push((local_id, fields, contract));
736    }
737
738    // rap_warn!("Get contract {:?}.", results);
739    results
740}
741
742/// Same with `generate_contract_from_annotation` but does not contain field types.
743pub fn generate_contract_from_annotation_without_field_types<'tcx>(
744    tcx: TyCtxt<'tcx>,
745    def_id: DefId,
746) -> Vec<(usize, Vec<usize>, PropertyContract<'tcx>)> {
747    let contracts_with_ty = generate_contract_from_annotation(tcx, def_id);
748
749    contracts_with_ty
750        .into_iter()
751        .map(|(local_id, fields_with_ty, contract)| {
752            let fields: Vec<usize> = fields_with_ty
753                .into_iter()
754                .map(|(field_idx, _)| field_idx)
755                .collect();
756            (local_id, fields, contract)
757        })
758        .collect()
759}
760
761/// Filter the function which contains "rapx::proof"
762pub fn is_verify_target_func(tcx: TyCtxt, def_id: DefId) -> bool {
763    for attr in tcx.get_all_attrs(def_id).into_iter() {
764        let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
765        // Find proof placeholder
766        if attr_str.contains("#[rapx::proof(proof)]") {
767            return true;
768        }
769    }
770    false
771}
772
773/// Get the annotation in tag-std style.
774/// Then generate contract facts for the args.
775/// This function will recognize the args name and record states to MIR variable (represent by usize).
776/// Return value means Vec<(local_id, fields of this local, contracts)>
777pub fn generate_contract_from_annotation<'tcx>(
778    tcx: TyCtxt<'tcx>,
779    def_id: DefId,
780) -> Vec<(usize, Vec<(usize, Ty<'tcx>)>, PropertyContract<'tcx>)> {
781    const REGISTER_TOOL: &str = "rapx";
782    let tool_attrs = tcx.get_all_attrs(def_id).into_iter().filter(|attr| {
783        if let Attribute::Unparsed(tool_attr) = attr {
784            if tool_attr.path.segments[0].as_str() == REGISTER_TOOL {
785                return true;
786            }
787        }
788        false
789    });
790    let mut results = Vec::new();
791    for attr in tool_attrs {
792        let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
793        // Find proof placeholder, skip it
794        if attr_str.contains("#[rapx::proof(proof)]") {
795            continue;
796        }
797        rap_debug!("{:?}", attr_str);
798        let safety_attr = safety_parser::safety::parse_attr_and_get_properties(attr_str.as_str());
799        for par in safety_attr.iter() {
800            for property in par.tags.iter() {
801                let tag_name = property.tag.name();
802                let exprs = property.args.clone().into_vec();
803                let contract = PropertyContract::new(tcx, def_id, tag_name, &exprs);
804                let (local, fields) = parse_contract_target(tcx, def_id, exprs);
805                results.push((local, fields, contract));
806            }
807        }
808    }
809    // if results.len() > 0 {
810    //     rap_warn!("results:\n{:?}", results);
811    // }
812    results
813}
814
815/// Parse attr.expr into local id and local fields.
816///
817/// Example:
818/// ```
819/// #[rapx::inner(property = ValidPtr(ptr, u32, 1), kind = "precond")]
820/// #[rapx::inner(property = ValidNum(region.size>=0), kind = "precond")]
821/// pub fn xor_secret_region(ptr: *mut u32, region:SecretRegion) -> u32 {...}
822/// ```
823///
824/// The first attribute will be parsed as (1, []).
825///     -> "1" means the first arg "ptr", "[]" means no fields.
826/// The second attribute will be parsed as (2, [1]).
827///     -> "2" means the second arg "region", "[1]" means "size" is region's second field.
828///
829/// If this function doesn't have args, then it will return default pattern: (0, Vec::new())
830pub fn parse_contract_target(
831    tcx: TyCtxt,
832    def_id: DefId,
833    expr: Vec<Expr>,
834) -> (usize, Vec<(usize, Ty)>) {
835    // Match expressions with the graph node that should receive the contract fact.
836    for e in expr {
837        if let Some((base, fields, _ty)) = parse_expr_into_local_and_ty(tcx, def_id, &e) {
838            return (base, fields);
839        }
840    }
841    (0, Vec::new())
842}
843
844/// parse single expr into (local, fields, ty)
845pub fn parse_expr_into_local_and_ty<'tcx>(
846    tcx: TyCtxt<'tcx>,
847    def_id: DefId,
848    expr: &Expr,
849) -> Option<(usize, Vec<(usize, Ty<'tcx>)>, Ty<'tcx>)> {
850    if let Some((base_ident, fields)) = access_ident_recursive(&expr) {
851        let (param_names, param_tys) = parse_signature(tcx, def_id);
852        if param_names[0] == "0".to_string() {
853            return None;
854        }
855        if let Some(param_index) = param_names.iter().position(|name| name == &base_ident) {
856            let mut current_ty = param_tys[param_index];
857            let mut field_indices = Vec::new();
858            for field_name in fields {
859                // peel the ref and ptr
860                let peeled_ty = current_ty.peel_refs();
861                if let rustc_middle::ty::TyKind::Adt(adt_def, arg_list) = *peeled_ty.kind() {
862                    let variant = adt_def.non_enum_variant();
863                    // 1. if field_name is number, then parse it as usize
864                    if let Ok(field_idx) = field_name.parse::<usize>() {
865                        if field_idx < variant.fields.len() {
866                            current_ty = variant.fields[rustc_abi::FieldIdx::from_usize(field_idx)]
867                                .ty(tcx, arg_list);
868                            field_indices.push((field_idx, current_ty));
869                            continue;
870                        }
871                    }
872                    // 2. if field_name is String, then compare it with current ty's field names
873                    if let Some((idx, _)) = variant
874                        .fields
875                        .iter()
876                        .enumerate()
877                        .find(|(_, f)| f.ident(tcx).name.to_string() == field_name.clone())
878                    {
879                        current_ty =
880                            variant.fields[rustc_abi::FieldIdx::from_usize(idx)].ty(tcx, arg_list);
881                        field_indices.push((idx, current_ty));
882                    }
883                    // 3. if field_name can not match any fields, then break
884                    else {
885                        break; // TODO:
886                    }
887                }
888                // if current ty is not Adt, then break the loop
889                else {
890                    break; // TODO:
891                }
892            }
893            // It's different from default one, we return the result as param_index+1 because param_index count from 0.
894            // But 0 in MIR is the ret index, the args' indexes begin from 1.
895            return Some((param_index + 1, field_indices, current_ty));
896        }
897    }
898    None
899}
900
901/// Return the Vecs of args' names and types
902/// This function will handle outside def_id by different way.
903pub fn parse_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
904    // 0. If the def id is local
905    if def_id.as_local().is_some() {
906        return parse_local_signature(tcx, def_id);
907    } else {
908        rap_debug!("{:?} is not local def id.", def_id);
909        return parse_outside_signature(tcx, def_id);
910    };
911}
912
913/// Return the Vecs of args' names and types of outside functions.
914fn parse_outside_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
915    let sig = tcx.fn_sig(def_id).skip_binder();
916    let param_tys: Vec<Ty<'tcx>> = sig.inputs().skip_binder().iter().copied().collect();
917
918    // 1. check pre-defined std unsafe api signature
919    if let Some(args_name) = get_known_std_names(tcx, def_id) {
920        // rap_warn!(
921        //     "function {:?} has arg: {:?}, arg types: {:?}",
922        //     def_id,
923        //     args_name,
924        //     param_tys
925        // );
926        return (args_name, param_tys);
927    }
928
929    // 2. TODO: If can not find known std apis, then use numbers like `0`,`1`,... to represent args.
930    let args_name = (0..param_tys.len()).map(|i| format!("{}", i)).collect();
931    rap_debug!(
932        "function {:?} has arg: {:?}, arg types: {:?}",
933        def_id,
934        args_name,
935        param_tys
936    );
937    return (args_name, param_tys);
938}
939
940/// We use a json to record known std apis' arg names.
941/// This function will search the json and return the names.
942/// Notes: If std gets updated, the json may still record old ones.
943fn get_known_std_names<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Vec<String>> {
944    let std_func_name = get_cleaned_def_path_name(tcx, def_id);
945    let json_data: serde_json::Value = get_std_api_signature_json();
946
947    if let Some(arg_info) = json_data.get(&std_func_name) {
948        if let Some(args_name) = arg_info.as_array() {
949            // set default value to arg name
950            if args_name.len() == 0 {
951                return Some(vec!["0".to_string()]);
952            }
953            // iterate and collect
954            let mut result = Vec::new();
955            for arg in args_name {
956                if let Some(sp_name) = arg.as_str() {
957                    result.push(sp_name.to_string());
958                }
959            }
960            return Some(result);
961        }
962    }
963    None
964}
965
966/// Return the Vecs of args' names and types of local functions.
967pub fn parse_local_signature(tcx: TyCtxt, def_id: DefId) -> (Vec<String>, Vec<Ty>) {
968    // 1. parse local def_id and get arg list
969    let local_def_id = def_id.as_local().unwrap();
970    let hir_body = tcx.hir_body_owned_by(local_def_id);
971    if hir_body.params.len() == 0 {
972        return (vec!["0".to_string()], Vec::new());
973    }
974    // 2. contruct the vec of param and param ty
975    let params = hir_body.params;
976    let typeck_results = tcx.typeck_body(hir_body.id());
977    let mut param_names = Vec::new();
978    let mut param_tys = Vec::new();
979    for param in params {
980        match param.pat.kind {
981            rustc_hir::PatKind::Binding(_, _, ident, _) => {
982                param_names.push(ident.name.to_string());
983                let ty = typeck_results.pat_ty(param.pat);
984                param_tys.push(ty);
985            }
986            _ => {
987                param_names.push(String::new());
988                param_tys.push(typeck_results.pat_ty(param.pat));
989            }
990        }
991    }
992    (param_names, param_tys)
993}
994
995/// return the (ident, its fields) of the expr.
996///
997/// illustrated cases :
998///    ptr	-> ("ptr", [])
999///    region.size	-> ("region", ["size"])
1000///    tuple.0.value -> ("tuple", ["0", "value"])
1001pub fn access_ident_recursive(expr: &Expr) -> Option<(String, Vec<String>)> {
1002    match expr {
1003        Expr::Path(syn::ExprPath { path, .. }) => {
1004            if path.segments.len() == 1 {
1005                rap_debug!("expr2 {:?}", expr);
1006                let ident = path.segments[0].ident.to_string();
1007                Some((ident, Vec::new()))
1008            } else {
1009                None
1010            }
1011        }
1012        // get the base and fields recursively
1013        Expr::Field(syn::ExprField { base, member, .. }) => {
1014            let (base_ident, mut fields) =
1015                if let Some((base_ident, fields)) = access_ident_recursive(base) {
1016                    (base_ident, fields)
1017                } else {
1018                    return None;
1019                };
1020            let field_name = match member {
1021                syn::Member::Named(ident) => ident.to_string(),
1022                syn::Member::Unnamed(index) => index.index.to_string(),
1023            };
1024            fields.push(field_name);
1025            Some((base_ident, fields))
1026        }
1027        _ => None,
1028    }
1029}
1030
1031/// parse expr into number.
1032pub fn parse_expr_into_number(expr: &Expr) -> Option<usize> {
1033    if let Expr::Lit(expr_lit) = expr {
1034        if let syn::Lit::Int(lit_int) = &expr_lit.lit {
1035            return lit_int.base10_parse::<usize>().ok();
1036        }
1037    }
1038    None
1039}
1040
1041/// Match a type identifier string to a concrete Rust type
1042///
1043/// This function attempts to match a given type identifier (e.g., "u32", "T", "MyStruct")
1044/// to a type in the provided parameter type list. It handles:
1045/// 1. Built-in primitive types (u32, usize, etc.)
1046/// 2. Generic type parameters (T, U, etc.)
1047/// 3. User-defined types found in the parameter list
1048///
1049/// Arguments:
1050/// - `tcx`: Type context for querying compiler information
1051/// - `type_ident`: String representing the type identifier to match
1052/// - `param_ty`: List of parameter types from the function signature
1053///
1054/// Returns:
1055/// - `Some(Ty)` if a matching type is found
1056/// - `None` if no match is found
1057pub fn match_ty_with_ident(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
1058    // 1. First check for built-in primitive types
1059    if let Some(primitive_ty) = match_primitive_type(tcx, type_ident.clone()) {
1060        return Some(primitive_ty);
1061    }
1062    // 2. Check if the identifier matches any generic type parameter
1063    return find_generic_param(tcx, def_id, type_ident.clone());
1064    // 3. Check if the identifier matches any user-defined type in the parameters
1065    // find_user_defined_type(tcx, def_id, type_ident)
1066}
1067
1068/// Match built-in primitive types from String
1069fn match_primitive_type(tcx: TyCtxt, type_ident: String) -> Option<Ty> {
1070    match type_ident.as_str() {
1071        "i8" => Some(tcx.types.i8),
1072        "i16" => Some(tcx.types.i16),
1073        "i32" => Some(tcx.types.i32),
1074        "i64" => Some(tcx.types.i64),
1075        "i128" => Some(tcx.types.i128),
1076        "isize" => Some(tcx.types.isize),
1077        "u8" => Some(tcx.types.u8),
1078        "u16" => Some(tcx.types.u16),
1079        "u32" => Some(tcx.types.u32),
1080        "u64" => Some(tcx.types.u64),
1081        "u128" => Some(tcx.types.u128),
1082        "usize" => Some(tcx.types.usize),
1083        "f16" => Some(tcx.types.f16),
1084        "f32" => Some(tcx.types.f32),
1085        "f64" => Some(tcx.types.f64),
1086        "f128" => Some(tcx.types.f128),
1087        "bool" => Some(tcx.types.bool),
1088        "char" => Some(tcx.types.char),
1089        "str" => Some(tcx.types.str_),
1090        _ => None,
1091    }
1092}
1093
1094/// Find generic type parameters in the parameter list
1095fn find_generic_param(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
1096    rap_debug!(
1097        "Searching for generic param: {} in {:?}",
1098        type_ident,
1099        def_id
1100    );
1101    let (_, param_tys) = parse_signature(tcx, def_id);
1102    rap_debug!("Function parameter types: {:?} of {:?}", param_tys, def_id);
1103    // 递归查找泛型参数
1104    for &ty in &param_tys {
1105        if let Some(found) = find_generic_in_ty(tcx, ty, &type_ident) {
1106            return Some(found);
1107        }
1108    }
1109
1110    None
1111}
1112
1113/// Iterate the args' types recursively and find the matched generic one.
1114fn find_generic_in_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, type_ident: &str) -> Option<Ty<'tcx>> {
1115    match ty.kind() {
1116        TyKind::Param(param_ty) => {
1117            if param_ty.name.as_str() == type_ident {
1118                return Some(ty);
1119            }
1120        }
1121        TyKind::RawPtr(ty, _)
1122        | TyKind::Ref(_, ty, _)
1123        | TyKind::Slice(ty)
1124        | TyKind::Array(ty, _) => {
1125            if let Some(found) = find_generic_in_ty(tcx, *ty, type_ident) {
1126                return Some(found);
1127            }
1128        }
1129        TyKind::Tuple(tys) => {
1130            for tuple_ty in tys.iter() {
1131                if let Some(found) = find_generic_in_ty(tcx, tuple_ty, type_ident) {
1132                    return Some(found);
1133                }
1134            }
1135        }
1136        TyKind::Adt(adt_def, substs) => {
1137            let name = tcx.item_name(adt_def.did()).to_string();
1138            if name == type_ident {
1139                return Some(ty);
1140            }
1141            for field in adt_def.all_fields() {
1142                let field_ty = field.ty(tcx, substs);
1143                if let Some(found) = find_generic_in_ty(tcx, field_ty, type_ident) {
1144                    return Some(found);
1145                }
1146            }
1147        }
1148        _ => {}
1149    }
1150    None
1151}
1152
1153pub fn reflect_generic<'tcx>(
1154    generic_mapping: &FxHashMap<String, Ty<'tcx>>,
1155    func_name: &str,
1156    ty: Ty<'tcx>,
1157) -> Ty<'tcx> {
1158    let mut actual_ty = ty;
1159    match ty.kind() {
1160        TyKind::Param(param_ty) => {
1161            let generic_name = param_ty.name.to_string();
1162            if let Some(actual_ty_from_map) = generic_mapping.get(&generic_name) {
1163                actual_ty = *actual_ty_from_map;
1164            }
1165        }
1166        _ => {}
1167    }
1168    rap_debug!(
1169        "peel generic ty for {:?}, actual_ty is {:?}",
1170        func_name,
1171        actual_ty
1172    );
1173    actual_ty
1174}
1175
1176// src_var = 0: for constructor
1177// src_var = 1: for methods
1178pub fn has_tainted_fields(tcx: TyCtxt, def_id: DefId, src_var: u32) -> bool {
1179    let mut dataflow_analyzer = DataFlowAnalyzer::new(tcx, false);
1180    dataflow_analyzer.build_graph(def_id);
1181
1182    let body = tcx.optimized_mir(def_id);
1183    let params = &body.args_iter().collect::<Vec<_>>();
1184    rap_info!("params {:?}", params);
1185    let self_local = Local::from(src_var);
1186
1187    let flowing_params: Vec<Local> = params
1188        .iter()
1189        .filter(|&&param_local| {
1190            dataflow_analyzer.has_flow_between(def_id, self_local, param_local)
1191                && self_local != param_local
1192        })
1193        .copied()
1194        .collect();
1195
1196    if !flowing_params.is_empty() {
1197        rap_info!(
1198            "Taint flow found from self to other parameters: {:?}",
1199            flowing_params
1200        );
1201        true
1202    } else {
1203        false
1204    }
1205}
1206
1207// 修改返回值类型为调用链的向量
1208pub fn get_all_std_unsafe_chains(tcx: TyCtxt, def_id: DefId) -> Vec<Vec<String>> {
1209    let mut results = Vec::new();
1210    let mut visited = HashSet::new(); // 避免循环调用
1211    let mut current_chain = Vec::new();
1212
1213    // 开始DFS遍历
1214    dfs_find_unsafe_chains(tcx, def_id, &mut current_chain, &mut results, &mut visited);
1215    results
1216}
1217
1218// DFS递归查找unsafe调用链
1219fn dfs_find_unsafe_chains(
1220    tcx: TyCtxt,
1221    def_id: DefId,
1222    current_chain: &mut Vec<String>,
1223    results: &mut Vec<Vec<String>>,
1224    visited: &mut HashSet<DefId>,
1225) {
1226    // 避免循环调用
1227    if visited.contains(&def_id) {
1228        return;
1229    }
1230    visited.insert(def_id);
1231
1232    let current_func_name = get_cleaned_def_path_name(tcx, def_id);
1233    current_chain.push(current_func_name.clone());
1234
1235    // 获取当前函数的所有unsafe callee
1236    let unsafe_callees = find_unsafe_callees_in_function(tcx, def_id);
1237
1238    if unsafe_callees.is_empty() {
1239        // 如果没有更多的unsafe callee,保存当前链
1240        results.push(current_chain.clone());
1241    } else {
1242        // 对每个unsafe callee继续DFS
1243        for (callee_def_id, callee_name) in unsafe_callees {
1244            dfs_find_unsafe_chains(tcx, callee_def_id, current_chain, results, visited);
1245        }
1246    }
1247
1248    // 回溯
1249    current_chain.pop();
1250    visited.remove(&def_id);
1251}
1252
1253fn find_unsafe_callees_in_function(tcx: TyCtxt, def_id: DefId) -> Vec<(DefId, String)> {
1254    let mut callees = Vec::new();
1255
1256    if let Some(body) = try_get_mir(tcx, def_id) {
1257        for bb in body.basic_blocks.iter() {
1258            if let Some(terminator) = &bb.terminator {
1259                if let Some((callee_def_id, callee_name)) = extract_unsafe_callee(tcx, terminator) {
1260                    callees.push((callee_def_id, callee_name));
1261                }
1262            }
1263        }
1264    }
1265
1266    callees
1267}
1268
1269fn extract_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Option<(DefId, String)> {
1270    if let TerminatorKind::Call { func, .. } = &terminator.kind {
1271        if let Operand::Constant(func_constant) = func {
1272            if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
1273                if check_safety(tcx, *callee_def_id) == Safety::Unsafe {
1274                    let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
1275                    return Some((*callee_def_id, func_name));
1276                }
1277            }
1278        }
1279    }
1280    None
1281}
1282
1283fn try_get_mir(tcx: TyCtxt<'_>, def_id: DefId) -> Option<&rustc_middle::mir::Body<'_>> {
1284    if tcx.is_mir_available(def_id) {
1285        Some(tcx.optimized_mir(def_id))
1286    } else {
1287        None
1288    }
1289}
1290
1291pub fn get_cleaned_def_path_name(tcx: TyCtxt<'_>, def_id: DefId) -> String {
1292    let def_id_str = format!("{:?}", def_id);
1293    let mut parts: Vec<&str> = def_id_str.split("::").collect();
1294
1295    let mut remove_first = false;
1296    if let Some(first_part) = parts.get_mut(0) {
1297        if first_part.contains("core") {
1298            *first_part = "core";
1299        } else if first_part.contains("std") {
1300            *first_part = "std";
1301        } else if first_part.contains("alloc") {
1302            *first_part = "alloc";
1303        } else {
1304            remove_first = true;
1305        }
1306    }
1307    if remove_first && !parts.is_empty() {
1308        parts.remove(0);
1309    }
1310
1311    let new_parts: Vec<String> = parts
1312        .into_iter()
1313        .filter_map(|s| {
1314            if s.contains("{") {
1315                if remove_first {
1316                    get_struct_name(tcx, def_id)
1317                } else {
1318                    None
1319                }
1320            } else {
1321                Some(s.to_string())
1322            }
1323        })
1324        .collect();
1325
1326    let mut cleaned_path = new_parts.join("::");
1327    cleaned_path = cleaned_path.trim_end_matches(')').to_string();
1328    cleaned_path
1329    // tcx.def_path_str(def_id)
1330    //     .replace("::", "_")
1331    //     .replace("<", "_")
1332    //     .replace(">", "_")
1333    //     .replace(",", "_")
1334    //     .replace(" ", "")
1335    //     .replace("__", "_")
1336}
1337
/// Print the given unsafe call chains to stdout.
///
/// Nothing is printed for an empty slice. Otherwise a header is followed by
/// one numbered section per chain, in which every function is indented two
/// spaces deeper than its caller.
pub fn print_unsafe_chains(chains: &[Vec<String>]) {
    if chains.is_empty() {
        return;
    }

    println!("==============================");
    println!("Found {} unsafe call chain(s):", chains.len());
    let mut chain_no = 0;
    for chain in chains {
        chain_no += 1;
        println!("Chain {}:", chain_no);

        // Grows by two spaces per call depth.
        let mut indent = String::new();
        for (depth, func_name) in chain.iter().enumerate() {
            // Every entry except the root gets one extra space before "->".
            let spacer = if depth == 0 { "" } else { " " };
            println!("{}{}-> {}", indent, spacer, func_name);
            indent.push_str("  ");
        }
        println!();
    }
}
1354
1355pub fn get_all_std_fns_by_rustc_public(tcx: TyCtxt) -> Vec<DefId> {
1356    let mut all_std_fn_def = Vec::new();
1357    let mut results = Vec::new();
1358    let mut core_fn_def: Vec<_> = rustc_public::find_crates("core")
1359        .iter()
1360        .flat_map(|krate| krate.fn_defs())
1361        .collect();
1362    let mut std_fn_def: Vec<_> = rustc_public::find_crates("std")
1363        .iter()
1364        .flat_map(|krate| krate.fn_defs())
1365        .collect();
1366    let mut alloc_fn_def: Vec<_> = rustc_public::find_crates("alloc")
1367        .iter()
1368        .flat_map(|krate| krate.fn_defs())
1369        .collect();
1370    all_std_fn_def.append(&mut core_fn_def);
1371    all_std_fn_def.append(&mut std_fn_def);
1372    all_std_fn_def.append(&mut alloc_fn_def);
1373
1374    for fn_def in &all_std_fn_def {
1375        let def_id = crate::def_id::to_internal(fn_def, tcx);
1376        results.push(def_id);
1377    }
1378    results
1379}
1380
1381pub fn generate_mir_cfg_dot<'tcx>(
1382    tcx: TyCtxt<'tcx>,
1383    def_id: DefId,
1384    alias_sets: &Vec<FxHashSet<usize>>,
1385) -> Result<(), std::io::Error> {
1386    let mir = tcx.optimized_mir(def_id);
1387
1388    let mut dot_content = String::new();
1389
1390    let alias_info_str = format!("Alias Sets: {:?}", alias_sets);
1391
1392    dot_content.push_str(&format!(
1393        "digraph mir_cfg_{} {{\n",
1394        get_cleaned_def_path_name(tcx, def_id)
1395    ));
1396
1397    dot_content.push_str(&format!(
1398        "    label = \"MIR CFG for {}\\n{}\\n\";\n",
1399        tcx.def_path_str(def_id),
1400        alias_info_str.replace("\"", "\\\"")
1401    ));
1402    dot_content.push_str("    labelloc = \"t\";\n");
1403    dot_content.push_str("    node [shape=box, fontname=\"Courier\", align=\"left\"];\n\n");
1404
1405    for (bb_index, bb_data) in mir.basic_blocks.iter_enumerated() {
1406        let mut lines: Vec<String> = bb_data
1407            .statements
1408            .iter()
1409            .map(|stmt| format!("{:?}", stmt))
1410            .collect();
1411
1412        let mut node_style = String::new();
1413
1414        if let Some(terminator) = &bb_data.terminator {
1415            let mut is_drop_related = false;
1416
1417            match &terminator.kind {
1418                TerminatorKind::Drop { .. } => {
1419                    is_drop_related = true;
1420                }
1421                TerminatorKind::Call { func, .. } => {
1422                    if let Operand::Constant(c) = func
1423                        && let ty::FnDef(def_id, _) = *c.ty().kind()
1424                        && is_drop_fn(def_id)
1425                    {
1426                        is_drop_related = true;
1427                    }
1428                }
1429                _ => {}
1430            }
1431
1432            if is_drop_related {
1433                node_style = ", style=\"filled\", fillcolor=\"#ffdddd\", color=\"red\"".to_string();
1434            }
1435
1436            lines.push(format!("{:?}", terminator.kind));
1437        } else {
1438            lines.push("(no terminator)".to_string());
1439        }
1440
1441        let label_content = lines.join("\\l");
1442
1443        let node_label = format!("BB{}:\\l{}\\l", bb_index.index(), label_content);
1444
1445        dot_content.push_str(&format!(
1446            "    BB{} [label=\"{}\"{}];\n",
1447            bb_index.index(),
1448            node_label.replace("\"", "\\\""),
1449            node_style
1450        ));
1451
1452        if let Some(terminator) = &bb_data.terminator {
1453            for target in terminator.successors() {
1454                let edge_label = match terminator.kind {
1455                    _ => "".to_string(),
1456                };
1457
1458                dot_content.push_str(&format!(
1459                    "    BB{} -> BB{} [label=\"{}\"];\n",
1460                    bb_index.index(),
1461                    target.index(),
1462                    edge_label
1463                ));
1464            }
1465        }
1466    }
1467    dot_content.push_str("}\n");
1468    let name = get_cleaned_def_path_name(tcx, def_id);
1469    render_dot_string(name, dot_content);
1470    rap_debug!("render dot for {:?}", def_id);
1471    Ok(())
1472}
1473
1474// Input the adt def id
1475// Return set of (mutable method def_id, fields can be modified)
1476pub fn get_all_mutable_methods(tcx: TyCtxt, src_def_id: DefId) -> HashMap<DefId, HashSet<usize>> {
1477    let mut std_results = HashMap::new();
1478    if get_type(tcx, src_def_id) == FnKind::Constructor {
1479        return std_results;
1480    }
1481    let all_std_fn_def = get_all_std_fns_by_rustc_public(tcx);
1482    let target_adt_def = get_adt_def_id_by_adt_method(tcx, src_def_id);
1483    let mut is_std = false;
1484    for &def_id in &all_std_fn_def {
1485        let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
1486        if adt_def.is_some() && adt_def == target_adt_def && src_def_id != def_id {
1487            if has_mut_self_param(tcx, def_id) {
1488                std_results.insert(def_id, HashSet::new());
1489            }
1490            is_std = true;
1491        }
1492    }
1493    if is_std {
1494        return std_results;
1495    }
1496    let mut results = HashMap::new();
1497    let public_fields = target_adt_def.map_or_else(HashSet::new, |def| get_public_fields(tcx, def));
1498    let impl_vec = target_adt_def.map_or_else(Vec::new, |def| get_impls_for_struct(tcx, def));
1499    for impl_id in impl_vec {
1500        if !matches!(tcx.def_kind(impl_id), rustc_hir::def::DefKind::Impl { .. }) {
1501            continue;
1502        }
1503        let associated_items = tcx.associated_items(impl_id);
1504        for item in associated_items.in_definition_order() {
1505            if let ty::AssocKind::Fn {
1506                name: _,
1507                has_self: _,
1508            } = item.kind
1509            {
1510                let item_def_id = item.def_id;
1511                if has_mut_self_param(tcx, item_def_id) {
1512                    let modified_fields = public_fields.clone();
1513                    results.insert(item_def_id, modified_fields);
1514                }
1515            }
1516        }
1517    }
1518    results
1519}
1520
1521pub fn get_cons(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<DefId> {
1522    let mut cons = Vec::new();
1523    if tcx.def_kind(def_id) == DefKind::Fn || get_type(tcx, def_id) == FnKind::Constructor {
1524        return cons;
1525    }
1526    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
1527        if let Some(impl_id) = assoc_item.impl_container(tcx) {
1528            // get struct ty
1529            let ty = tcx.type_of(impl_id).skip_binder();
1530            if let Some(adt_def) = ty.ty_adt_def() {
1531                let adt_def_id = adt_def.did();
1532                let impls = tcx.inherent_impls(adt_def_id);
1533                for impl_def_id in impls {
1534                    for item in tcx.associated_item_def_ids(impl_def_id) {
1535                        if (tcx.def_kind(item) == DefKind::Fn
1536                            || tcx.def_kind(item) == DefKind::AssocFn)
1537                            && get_type(tcx, *item) == FnKind::Constructor
1538                        {
1539                            cons.push(*item);
1540                        }
1541                    }
1542                }
1543            }
1544        }
1545    }
1546    cons
1547}
1548
1549pub fn append_fn_with_types(tcx: TyCtxt, def_id: DefId) -> FnInfo {
1550    FnInfo::new(def_id, check_safety(tcx, def_id), get_type(tcx, def_id))
1551}
1552pub fn search_constructor(tcx: TyCtxt, def_id: DefId) -> Vec<DefId> {
1553    let mut constructors = Vec::new();
1554    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
1555        if let Some(impl_id) = assoc_item.impl_container(tcx) {
1556            // get struct ty
1557            let ty = tcx.type_of(impl_id).skip_binder();
1558            if let Some(adt_def) = ty.ty_adt_def() {
1559                let adt_def_id = adt_def.did();
1560                let impl_vec = get_impls_for_struct(tcx, adt_def_id);
1561                for impl_id in impl_vec {
1562                    let associated_items = tcx.associated_items(impl_id);
1563                    for item in associated_items.in_definition_order() {
1564                        if let ty::AssocKind::Fn {
1565                            name: _,
1566                            has_self: _,
1567                        } = item.kind
1568                        {
1569                            let item_def_id = item.def_id;
1570                            if get_type(tcx, item_def_id) == FnKind::Constructor {
1571                                constructors.push(item_def_id);
1572                            }
1573                        }
1574                    }
1575                }
1576            }
1577        }
1578    }
1579    constructors
1580}
1581
1582pub fn get_ptr_deref_dummy_def_id(tcx: TyCtxt<'_>) -> Option<DefId> {
1583    tcx.hir_crate_items(()).free_items().find_map(|item_id| {
1584        let def_id = item_id.owner_id.to_def_id();
1585        let name = tcx.opt_item_name(def_id)?;
1586
1587        (name.as_str() == "__raw_ptr_deref_dummy").then_some(def_id)
1588    })
1589}