rapx/analysis/extract/
mod.rs

1/*
2 * This module implements the `extract unsafe-apis` and `extract std-unsafe-apis` commands.
3 * It collects all public unsafe functions from the current crate or the standard library
4 * and outputs them as JSON.
5 */
6
7use crate::analysis::utils::fn_info::{
8    check_safety, check_visibility, get_all_std_fns_by_rustc_public,
9};
10use rustc_hir::{
11    ImplItemKind, PatKind, Safety, TraitFn, TraitItemKind,
12    def::DefKind,
13    def_id::{DefId, LOCAL_CRATE},
14};
15use rustc_middle::ty::TyCtxt;
16use serde::{Deserialize, Serialize};
17
/// One parameter of an extracted function, as emitted in the JSON output.
#[derive(Debug, Serialize, Deserialize)]
pub struct ParamInfo {
    /// Parameter name. Taken from the HIR body when one is available; otherwise
    /// a positional placeholder of the form `arg{index}`. Parameters whose
    /// pattern is not a plain binding are reported as `_`.
    pub name: String,
    /// The parameter's type, rendered with its `Display` implementation.
    pub ty: String,
}
23
/// A single public unsafe function, as emitted in the JSON output.
#[derive(Debug, Serialize, Deserialize)]
pub struct UnsafeApiEntry {
    /// Path of the item's container. For local items this is
    /// `crate_name[::parent_module]`; for external items it is the fully
    /// qualified path with the trailing `::name` component removed.
    pub module: String,
    /// The item's own name (last path component).
    pub name: String,
    /// The function's parameters, in signature order.
    pub params: Vec<ParamInfo>,
    /// Text of the `# Safety` / `## Safety` section of the item's doc comment,
    /// or `None` when there is no such section or it is empty.
    pub safety_doc: Option<String>,
}
31
/// Returns true if `line` is a Markdown heading that should stop content
/// collection for a `# Safety` section at the given `level` (1 or 2).
///
/// `line` is expected to be already trimmed. A level-1 section only ends at
/// another `#` heading; any other level also ends at a `##` heading.
fn is_heading_stop(line: &str, level: usize) -> bool {
    let is_h1 = line == "#" || line.starts_with("# ");
    match level {
        1 => is_h1,
        _ => is_h1 || line == "##" || line.starts_with("## "),
    }
}

/// Returns the fence character and run length if the (trimmed) `line` starts
/// with a Markdown code fence, i.e. at least three backticks or tildes.
fn fence_marker(line: &str) -> Option<(char, usize)> {
    let ch = line.chars().next()?;
    if ch != '`' && ch != '~' {
        return None;
    }
    let run = line.chars().take_while(|&c| c == ch).count();
    (run >= 3).then_some((ch, run))
}

/// Advances the fenced-code-block state by one (trimmed) `line`.
///
/// `open` holds the marker of the currently open fence, if any. Returns true
/// when `line` is a fence delimiter or lies inside a fenced block, meaning it
/// must not be interpreted as a Markdown heading.
fn update_fence(open: &mut Option<(char, usize)>, line: &str) -> bool {
    match (*open, fence_marker(line)) {
        (None, None) => false,
        (None, Some(marker)) => {
            *open = Some(marker);
            true
        }
        // A closing fence uses the same character, is at least as long as the
        // opening fence, and carries no info string. The fence characters are
        // ASCII, so the run length is also a valid byte offset.
        (Some((ch, len)), Some((c, n)))
            if c == ch && n >= len && line[n..].trim().is_empty() =>
        {
            *open = None;
            true
        }
        (Some(_), _) => true,
    }
}

/// Extract the `# Safety` or `## Safety` section from a Rust doc comment string.
///
/// The `doc` parameter should be the concatenation of all `#[doc = "..."]` attribute
/// values joined by newlines, as returned by `attr.doc_str()`.
///
/// The section starts after the first line that is exactly `# Safety` or
/// `## Safety` and runs until the next heading of the same or a higher level.
/// Lines inside fenced code blocks are never treated as headings, so hidden
/// doctest lines such as `# let x = 1;` do not cut the section short.
///
/// Every line is trimmed individually, and leading/trailing blank lines of the
/// section are dropped.
///
/// Returns the text of the Safety section, or `None` if no Safety section is
/// present or the section is empty.
pub fn extract_safety_doc(doc: &str) -> Option<String> {
    let mut open_fence: Option<(char, usize)> = None;
    // `None` until the Safety heading is found, then the heading level (1 or 2).
    let mut level: Option<usize> = None;
    let mut section: Vec<&str> = Vec::new();

    for raw in doc.lines() {
        let line = raw.trim();
        let in_code = update_fence(&mut open_fence, line);
        match level {
            // Still searching; text inside a code block cannot be the heading.
            None if in_code => {}
            None => {
                level = match line {
                    "# Safety" => Some(1),
                    "## Safety" => Some(2),
                    _ => None,
                };
            }
            Some(lvl) => {
                if !in_code && is_heading_stop(line, lvl) {
                    break;
                }
                section.push(line);
            }
        }
    }

    // Drop surrounding blank lines; an all-blank (or missing) section is `None`.
    let first = section.iter().position(|l| !l.is_empty())?;
    let last = section.iter().rposition(|l| !l.is_empty())?;
    Some(section[first..=last].join("\n"))
}
98
/// Driver for the `extract unsafe-apis` / `extract std-unsafe-apis` commands.
pub struct ExtractUnsafeApis<'tcx> {
    /// Compiler type context used for every query this extractor performs.
    tcx: TyCtxt<'tcx>,
}
102
103impl<'tcx> ExtractUnsafeApis<'tcx> {
104    pub fn new(tcx: TyCtxt<'tcx>) -> Self {
105        Self { tcx }
106    }
107
108    /// Run extract for the current (local) crate and print JSON to stderr.
109    pub fn run_local(&self) {
110        let entries = self.collect_local();
111        match serde_json::to_string_pretty(&entries) {
112            Ok(json) => eprintln!("{}", json),
113            Err(e) => eprintln!("extract JSON serialization error: {}", e),
114        }
115    }
116
117    /// Run extract for the Rust standard library and print JSON to stderr.
118    pub fn run_std(&self) {
119        let entries = self.collect_std();
120        match serde_json::to_string_pretty(&entries) {
121            Ok(json) => eprintln!("{}", json),
122            Err(e) => eprintln!("extract: JSON serialization error: {}", e),
123        }
124    }
125
126    /// Collect doc comment text for a def_id by joining all `#[doc = "..."]` attrs.
127    fn get_doc_string(&self, def_id: DefId) -> String {
128        self.tcx
129            .get_all_attrs(def_id)
130            .iter()
131            .filter_map(|attr| attr.doc_str())
132            .map(|sym| sym.as_str().to_string())
133            .collect::<Vec<_>>()
134            .join("\n")
135    }
136
137    /// Extract the `# Safety` section from the doc comment of a def_id.
138    fn get_safety_doc(&self, def_id: DefId) -> Option<String> {
139        extract_safety_doc(&self.get_doc_string(def_id))
140    }
141
142    /// Get parameter types from the function signature.
143    fn get_params(&self, def_id: DefId) -> Vec<ParamInfo> {
144        let fn_sig = self.tcx.fn_sig(def_id).instantiate_identity();
145        let inputs = fn_sig.skip_binder().inputs();
146
147        // Try to get parameter names from HIR for local functions.
148        let param_names = self.get_hir_param_names(def_id);
149
150        inputs
151            .iter()
152            .enumerate()
153            .map(|(i, ty)| {
154                let name = param_names
155                    .get(i)
156                    .cloned()
157                    .unwrap_or_else(|| format!("arg{}", i));
158                ParamInfo {
159                    name,
160                    ty: format!("{}", ty),
161                }
162            })
163            .collect()
164    }
165
166    /// Attempt to retrieve parameter names from the HIR body for a local function.
167    fn get_hir_param_names(&self, def_id: DefId) -> Vec<String> {
168        let Some(local_def_id) = def_id.as_local() else {
169            return Vec::new();
170        };
171
172        let hir_node = self.tcx.hir_node_by_def_id(local_def_id);
173        let body_id = match hir_node {
174            rustc_hir::Node::Item(item) => {
175                if let rustc_hir::ItemKind::Fn { body, .. } = &item.kind {
176                    Some(*body)
177                } else {
178                    None
179                }
180            }
181            rustc_hir::Node::ImplItem(item) => {
182                if let ImplItemKind::Fn(_, body) = item.kind {
183                    Some(body)
184                } else {
185                    None
186                }
187            }
188            rustc_hir::Node::TraitItem(item) => {
189                if let TraitItemKind::Fn(_, TraitFn::Provided(body)) = item.kind {
190                    Some(body)
191                } else {
192                    None
193                }
194            }
195            _ => None,
196        };
197
198        if let Some(body_id) = body_id {
199            let body = self.tcx.hir_body(body_id);
200            body.params
201                .iter()
202                .map(|param| match &param.pat.kind {
203                    PatKind::Binding(_, _, ident, _) => ident.name.as_str().to_string(),
204                    _ => "_".to_string(),
205                })
206                .collect()
207        } else {
208            Vec::new()
209        }
210    }
211
212    /// Build an `UnsafeApiEntry` from a `DefId`.
213    fn make_entry(&self, def_id: DefId) -> UnsafeApiEntry {
214        let name = self.tcx.item_name(def_id).as_str().to_string();
215
216        let module = if let Some(local_def_id) = def_id.as_local() {
217            // For local items, build the module path as `crate_name[::parent_module]`.
218            let crate_name = self.tcx.crate_name(LOCAL_CRATE).as_str().to_string();
219            let mod_local = self.tcx.parent_module_from_def_id(local_def_id);
220            let parent_path = self.tcx.def_path_str(mod_local.to_def_id());
221            if parent_path.is_empty() {
222                crate_name
223            } else {
224                format!("{}::{}", crate_name, parent_path)
225            }
226        } else {
227            // For external items, derive the module by stripping the trailing `::name`
228            // component from the full qualified path.
229            let full_path = self.tcx.def_path_str(def_id);
230            if let Some(pos) = full_path.rfind("::") {
231                full_path[..pos].to_string()
232            } else {
233                full_path
234            }
235        };
236
237        UnsafeApiEntry {
238            module,
239            name,
240            params: self.get_params(def_id),
241            safety_doc: self.get_safety_doc(def_id),
242        }
243    }
244
245    /// Collect all public unsafe `fn` and `AssocFn` items in the local crate.
246    fn collect_local(&self) -> Vec<UnsafeApiEntry> {
247        let mut entries = Vec::new();
248
249        for local_def_id in self.tcx.mir_keys(()) {
250            let def_id = local_def_id.to_def_id();
251            let kind = self.tcx.def_kind(def_id);
252            if !matches!(kind, DefKind::Fn | DefKind::AssocFn) {
253                continue;
254            }
255            if !check_visibility(self.tcx, def_id) {
256                continue;
257            }
258            if check_safety(self.tcx, def_id) != Safety::Unsafe {
259                continue;
260            }
261            entries.push(self.make_entry(def_id));
262        }
263
264        entries
265    }
266
267    /// Collect all public unsafe functions from the Rust standard library.
268    fn collect_std(&self) -> Vec<UnsafeApiEntry> {
269        let mut entries = Vec::new();
270
271        let all_std_fns = get_all_std_fns_by_rustc_public(self.tcx);
272        for def_id in all_std_fns {
273            if !self.tcx.visibility(def_id).is_public() {
274                continue;
275            }
276            if check_safety(self.tcx, def_id) != Safety::Unsafe {
277                continue;
278            }
279            entries.push(self.make_entry(def_id));
280        }
281
282        entries
283    }
284}