rapx/analysis/core/api_dependency/graph/
resolve.rs

1use super::Config;
2use super::dep_edge::DepEdge;
3use super::dep_node::DepNode;
4use super::transform::TransformKind;
5use super::ty_wrapper::TyWrapper;
6use crate::analysis::core::api_dependency::ApiDependencyGraph;
7use crate::analysis::core::api_dependency::graph::std_tys;
8use crate::analysis::core::api_dependency::mono::{Mono, get_mono_complexity};
9use crate::analysis::core::api_dependency::utils::{
10    fn_requires_monomorphization, is_fuzzable_ty, ty_complexity,
11};
12use crate::analysis::core::api_dependency::visit::FnVisitor;
13use crate::analysis::core::api_dependency::{mono, utils};
14use crate::analysis::utils::def_path::path_str_def_id;
15use crate::utils::fs::rap_create_file;
16use crate::{rap_debug, rap_info, rap_trace};
17use petgraph::Direction::{self, Incoming};
18use petgraph::Graph;
19use petgraph::dot;
20use petgraph::graph::NodeIndex;
21use petgraph::visit::{EdgeRef, NodeIndexable, Visitable};
22use rand::Rng;
23use rustc_hir::def_id::DefId;
24use rustc_middle::ty::{self, GenericArgsRef, TraitRef, Ty, TyCtxt};
25use rustc_span::sym::{self, require};
26use std::collections::HashMap;
27use std::collections::HashSet;
28use std::collections::VecDeque;
29use std::hash::Hash;
30use std::io::Write;
31use std::path::Path;
32use std::time;
33
/// Upper bound on type complexity (as measured by `ty_complexity`) used by the
/// reachable-API search: it is passed as the `max_complexity` of the type
/// candidate set and is also checked against return types of monomorphic APIs.
const MAX_TY_COMPLX: usize = 5;
35
36fn add_return_type_if_reachable<'tcx>(
37    fn_did: DefId,
38    args: &[ty::GenericArg<'tcx>],
39    reachable_tys: &HashSet<TyWrapper<'tcx>>,
40    new_tys: &mut HashSet<Ty<'tcx>>,
41    tcx: TyCtxt<'tcx>,
42) -> bool {
43    let fn_sig = utils::fn_sig_with_generic_args(fn_did, args, tcx);
44    let inputs = fn_sig.inputs();
45    for input_ty in inputs {
46        if !is_fuzzable_ty(*input_ty, tcx) && !reachable_tys.contains(&TyWrapper::from(*input_ty)) {
47            return false;
48        }
49    }
50    let output_ty = fn_sig.output();
51    if !output_ty.is_unit() {
52        new_tys.insert(output_ty);
53    }
54    true
55}
56
57#[derive(Clone)]
58struct TypeCandidates<'tcx> {
59    tcx: TyCtxt<'tcx>,
60    candidates: HashSet<TyWrapper<'tcx>>,
61    max_complexity: usize,
62}
63
64impl<'tcx> TypeCandidates<'tcx> {
65    pub fn new(tcx: TyCtxt<'tcx>, max_complexity: usize) -> Self {
66        TypeCandidates {
67            tcx,
68            candidates: HashSet::new(),
69            max_complexity,
70        }
71    }
72
73    pub fn insert(&mut self, ty: Ty<'tcx>) -> bool {
74        if ty_complexity(ty) <= self.max_complexity {
75            self.candidates.insert(ty.into())
76        } else {
77            false
78        }
79    }
80
81    pub fn insert_all(&mut self, ty: Ty<'tcx>) -> bool {
82        let complexity = ty_complexity(ty);
83        self.insert_all_with_complexity(ty, complexity)
84    }
85
86    pub fn insert_all_with_complexity(&mut self, ty: Ty<'tcx>, current_cmplx: usize) -> bool {
87        if current_cmplx > self.max_complexity {
88            return false;
89        }
90
91        // add T
92        let mut changed = self.candidates.insert(ty.into());
93
94        // add &T
95        changed |= self.insert_all_with_complexity(
96            Ty::new_ref(
97                self.tcx,
98                self.tcx.lifetimes.re_erased,
99                ty,
100                ty::Mutability::Not,
101            ),
102            current_cmplx + 1,
103        );
104
105        // add &mut T
106        changed |= self.insert_all_with_complexity(
107            Ty::new_ref(
108                self.tcx,
109                self.tcx.lifetimes.re_erased,
110                ty,
111                ty::Mutability::Mut,
112            ),
113            current_cmplx + 1,
114        );
115
116        // add &[T]
117        changed |= self.insert_all_with_complexity(
118            Ty::new_ref(
119                self.tcx,
120                self.tcx.lifetimes.re_erased,
121                Ty::new_slice(self.tcx, ty),
122                ty::Mutability::Not,
123            ),
124            current_cmplx + 2,
125        );
126
127        // add &mut [T]
128        changed |= self.insert_all_with_complexity(
129            Ty::new_ref(
130                self.tcx,
131                self.tcx.lifetimes.re_erased,
132                Ty::new_slice(self.tcx, ty),
133                ty::Mutability::Mut,
134            ),
135            current_cmplx + 2,
136        );
137
138        changed
139    }
140
141    pub fn add_prelude_tys(&mut self) {
142        let tcx = self.tcx;
143
144        let primitive_tys = [
145            tcx.types.bool,
146            tcx.types.char,
147            tcx.types.f32,
148            tcx.types.i8,
149            tcx.types.u8,
150            tcx.types.i32,
151            tcx.types.u32,
152            tcx.types.i64,
153            tcx.types.u64,
154            tcx.types.isize,
155            tcx.types.usize,
156        ];
157
158        let mut prelude_tys = Vec::new();
159
160        prelude_tys.extend_from_slice(&primitive_tys);
161        // &str
162        prelude_tys.push(Ty::new_imm_ref(
163            tcx,
164            tcx.lifetimes.re_erased,
165            tcx.types.str_,
166        ));
167        // String
168        prelude_tys.push(Ty::new_adt(
169            self.tcx,
170            self.tcx.adt_def(self.tcx.lang_items().string().unwrap()),
171            ty::GenericArgs::empty(),
172        ));
173        for element_ty in &primitive_tys {
174            // Vec<T>
175            prelude_tys.push(std_tys::std_vec(*element_ty, self.tcx));
176        }
177        prelude_tys.into_iter().for_each(|ty| {
178            self.insert_all(ty);
179        });
180    }
181
182    pub fn candidates(&self) -> &HashSet<TyWrapper<'tcx>> {
183        &self.candidates
184    }
185}
186
187pub fn partion_generic_api<'tcx>(
188    all_apis: &HashSet<DefId>,
189    tcx: TyCtxt<'tcx>,
190) -> (HashSet<DefId>, HashSet<DefId>) {
191    let mut generic_api = HashSet::new();
192    let mut non_generic_api = HashSet::new();
193    for api_id in all_apis.iter() {
194        if tcx.generics_of(*api_id).requires_monomorphization(tcx) {
195            generic_api.insert(*api_id);
196        } else {
197            non_generic_api.insert(*api_id);
198        }
199    }
200    (non_generic_api, generic_api)
201}
202
203impl<'tcx> ApiDependencyGraph<'tcx> {
204    pub fn resolve_generic_api(
205        &mut self,
206        non_generic_apis: &[DefId],
207        generic_apis: &[DefId],
208        max_iteration: usize,
209    ) {
210        rap_info!("start resolving generic APIs");
211
212        // 1. Reachable generic API search
213        let generic_map = self.search_reachable_apis(non_generic_apis, generic_apis, max_iteration);
214
215        self.add_mono_apis_from_map(&generic_map);
216        self.update_transform_edges();
217
218        rap_info!("finish resolving generic APIs");
219        self.statistics().info();
220        self.dump_to_file(Path::new("api_graph_unpruned.dot"));
221
222        let reserved = self.prune_by_similarity(generic_map);
223
224        let count = self.reserve_nodes(&reserved);
225        rap_info!("remove {} nodes by pruning", count);
226    }
227
228    pub fn search_reachable_apis(
229        &mut self,
230        non_generic_apis: &[DefId],
231        generic_apis: &[DefId],
232        max_iteration: usize,
233    ) -> HashMap<DefId, HashSet<Mono<'tcx>>> {
234        let tcx = self.tcx;
235        let mut type_candidates = TypeCandidates::new(self.tcx, MAX_TY_COMPLX);
236
237        type_candidates.add_prelude_tys();
238
239        let mut generic_map: HashMap<DefId, HashSet<Mono>> = HashMap::new();
240        let mut unreachable_non_generic_api = Vec::from(non_generic_apis);
241
242        rap_debug!("[resolve_generic] non_generic_api = {unreachable_non_generic_api:?}");
243        rap_debug!("[resolve_generic] generic_api = {generic_apis:?}");
244
245        let mut num_iter = 0;
246
247        loop {
248            num_iter += 1;
249            let all_reachable_tys = type_candidates.candidates();
250            rap_info!(
251                "start iter #{num_iter}, # of reachble types = {}",
252                all_reachable_tys.len()
253            );
254
255            // dump all reachable types to files, each line output a type
256            let mut file = rap_create_file(Path::new("reachable_types.txt"), "create file fail");
257            for ty in all_reachable_tys.iter() {
258                writeln!(file, "{}", ty.ty()).unwrap();
259            }
260
261            let mut current_tys = HashSet::new();
262
263            // check whether there is any non-generic reachable in this iteration.
264            // if the api is reachable, add output type to reachble_tys,
265            // and remove it from the set.
266            unreachable_non_generic_api.retain(|fn_did| {
267                !add_return_type_if_reachable(
268                    *fn_did,
269                    ty::GenericArgs::identity_for_item(tcx, *fn_did),
270                    all_reachable_tys,
271                    &mut current_tys,
272                    tcx,
273                )
274            });
275
276            // check each generic API for new monomorphic API
277            for fn_did in generic_apis.iter() {
278                let mono_set = mono::resolve_mono_apis(*fn_did, &all_reachable_tys, tcx);
279                rap_debug!(
280                    "[search_reachable_apis] {} -> {:?}",
281                    tcx.def_path_str(*fn_did),
282                    mono_set
283                );
284                for mono in mono_set.monos {
285                    let fn_sig = utils::fn_sig_with_generic_args(*fn_did, &mono.value, tcx);
286                    let output_ty = fn_sig.output();
287                    if generic_map.entry(*fn_did).or_default().insert(mono) {
288                        if !output_ty.is_unit() && ty_complexity(output_ty) <= MAX_TY_COMPLX {
289                            current_tys.insert(output_ty);
290                        }
291                    }
292                }
293            }
294
295            let mut changed = false;
296            for ty in current_tys {
297                changed = changed | type_candidates.insert_all(ty);
298            }
299
300            if !changed {
301                rap_info!("Terminate. Reachable types unchange in this iteration.");
302                break;
303            }
304            if num_iter >= max_iteration {
305                rap_info!("Terminate. Max iteration reached.");
306                break;
307            }
308        }
309
310        let mono_cnt = generic_map.values().fold(0, |acc, monos| acc + monos.len());
311
312        rap_debug!("# reachable types: {}", type_candidates.candidates().len());
313        rap_debug!("# mono APIs: {}", mono_cnt);
314
315        generic_map
316    }
317
318    pub fn add_mono_apis_from_map(&mut self, generic_map: &HashMap<DefId, HashSet<Mono<'tcx>>>) {
319        for (fn_did, mono_set) in generic_map {
320            for mono in mono_set {
321                let args = self.tcx.mk_args(&mono.value);
322                self.add_api(*fn_did, args);
323            }
324        }
325    }
326
327    /// heuristic strategy: prioritize to reserve APIs that first arg of which is reachable.
328    /// This is based on that we want to reserve APIs that have the same Self type ASAP.
329    pub fn heuristic_select(&mut self, reserved: &mut [bool]) {
330        let mut worklist = VecDeque::new();
331        let mut visited = vec![false; self.graph.node_count()];
332        let mut impl_map: HashMap<DefId, HashSet<DefId>> = HashMap::new();
333        let mut count_map: HashMap<DefId, usize> = HashMap::new();
334
335        // traverse from start node, if a node can achieve a reserved node,
336        // this node should be reserved
337        for node in self.graph.node_indices() {
338            if self.is_start_node_index(node) {
339                rap_trace!("initial node {:?}", self.graph[node]);
340                worklist.push_back(node);
341            }
342        }
343
344        while let Some(node) = worklist.pop_front() {
345            if visited[node.index()] {
346                continue;
347            }
348            visited[node.index()] = true;
349
350            match self.graph[node] {
351                DepNode::Api(fn_did, args) => {
352                    if fn_requires_monomorphization(fn_did, self.tcx) {
353                        let impl_entry = impl_map.entry(fn_did).or_default();
354                        let count_entry = count_map.entry(fn_did).or_default();
355                        let impls = mono::get_impls(self.tcx, fn_did, args);
356                        let size = impls
357                            .iter()
358                            .fold(0, |cnt, did| cnt + (!impl_entry.contains(did)) as usize);
359                        if *count_entry == 0 || size > 0 {
360                            *count_entry += 1;
361                            impls.iter().for_each(|did| {
362                                impl_entry.insert(*did);
363                            });
364                            reserved[node.index()] = true;
365                        }
366                    }
367                    for neighbor in self.graph.neighbors(node) {
368                        worklist.push_back(neighbor);
369                    }
370                }
371                DepNode::Ty(..) => {
372                    for edge in self.graph.edges_directed(node, Direction::Outgoing) {
373                        let weight = self.graph.edge_weight(edge.id()).unwrap();
374                        if let DepEdge::Transform(_) | DepEdge::Arg { no: 0 } = weight {
375                            worklist.push_back(edge.target());
376                        }
377                    }
378                }
379            }
380
381            if reserved[node.index()] {
382                rap_debug!(
383                    "[propagate_reserved] reserve: {:?}",
384                    self.graph.node_weight(node).unwrap()
385                );
386            }
387        }
388    }
389
390    pub fn minimal_select(
391        &mut self,
392        reserved: &mut [bool],
393        generic_map: &HashMap<DefId, HashSet<Mono<'tcx>>>,
394    ) {
395        let mut rng = rand::rng();
396        let mut reserved_map: HashMap<DefId, Vec<(GenericArgsRef<'tcx>, bool)>> = HashMap::new();
397
398        // transform into reserved map
399        for (fn_did, mono_set) in generic_map {
400            let entry = reserved_map.entry(*fn_did).or_default();
401            mono_set.into_iter().for_each(|mono| {
402                let args = self.tcx.mk_args(&mono.value);
403                entry.push((args, false));
404            });
405        }
406        // add all monomorphic APIs to API Graph, but select minimal set cover to be reserved
407        for (fn_did, monos) in &mut reserved_map {
408            select_minimal_set_cover(self.tcx, *fn_did, monos, &mut rng);
409            for (args, r) in monos {
410                if *r {
411                    let idx = self.get_index(DepNode::Api(*fn_did, args)).unwrap();
412                    reserved[idx.index()] = true;
413                }
414            }
415        }
416    }
417
418    pub fn prune_by_similarity(
419        &mut self,
420        generic_map: HashMap<DefId, HashSet<Mono<'tcx>>>,
421    ) -> Vec<bool> {
422        let (estimate, total) = self.estimate_coverage_distinct();
423        rap_info!(
424            "estimate API coverage before pruning: {:.2} ({}/{})",
425            estimate as f64 / total as f64,
426            estimate,
427            total
428        );
429
430        let mut visited = vec![false; self.graph.node_count()];
431        let mut reserved = vec![false; self.graph.node_count()];
432
433        // initialize reserved
434        // all non-generic API should be reserved
435        for idx in self.graph.node_indices() {
436            if let DepNode::Api(fn_did, _) = self.graph[idx] {
437                if !utils::fn_requires_monomorphization(fn_did, self.tcx) {
438                    reserved[idx.index()] = true;
439                }
440            }
441        }
442
443        // minimal set cover strategy
444        // self.minimal_select(&mut reserved, &generic_map);
445
446        // heuristic strategy
447        self.heuristic_select(&mut reserved);
448
449        // traverse from start node, if a node can achieve a reserved node,
450        // this node should be reserved as well
451        for node in self.graph.node_indices() {
452            if !visited[node.index()] && self.is_start_node_index(node) {
453                rap_trace!("start propagate from {:?}", self.graph[node]);
454                self.propagate_reserved(node, &mut visited, &mut reserved);
455            }
456        }
457
458        for node in self.graph.node_indices() {
459            if !visited[node.index()] {
460                rap_trace!("{:?} is unvisited", self.graph[node]);
461                self.propagate_reserved(node, &mut visited, &mut reserved);
462            }
463        }
464
465        reserved
466    }
467
468    pub fn reserve_nodes(&mut self, reserved: &[bool]) -> usize {
469        let mut count = 0;
470        for idx in (0..self.graph.node_count()).rev() {
471            if !reserved[idx] {
472                self.graph
473                    .remove_node(NodeIndex::new(idx))
474                    .expect("remove should not fail");
475                count += 1;
476            }
477        }
478        self.recache();
479        count
480    }
481
482    pub fn propagate_reserved(
483        &self,
484        node: NodeIndex,
485        visited: &mut [bool],
486        reserved: &mut [bool],
487    ) -> bool {
488        visited[node.index()] = true;
489
490        match self.graph[node] {
491            // Api should be reserved if must_reserve is true,
492            // or at least one its neighbor is reserved
493            DepNode::Api(fn_did, args) => {
494                for neighbor in self.graph.neighbors(node) {
495                    if !visited[neighbor.index()] {
496                        reserved[node.index()] |=
497                            self.propagate_reserved(neighbor, visited, reserved);
498                    }
499                }
500            }
501
502            // Ty should be reserved if at least one its neighbor is reserved
503            DepNode::Ty(..) => {
504                // self.graph.edges_directed(node, dir)
505                for neighbor in self.graph.neighbors(node) {
506                    if !visited[neighbor.index()] {
507                        self.propagate_reserved(neighbor, visited, reserved);
508                    }
509                    reserved[node.index()] |= reserved[neighbor.index()]
510                }
511            }
512        }
513
514        if reserved[node.index()] {
515            rap_trace!(
516                "[propagate_reserved] reserve: {:?}",
517                self.graph.node_weight(node).unwrap()
518            );
519        }
520        reserved[node.index()]
521    }
522}
523
524fn select_minimal_set_cover<'tcx, 'a>(
525    tcx: TyCtxt<'tcx>,
526    fn_did: DefId,
527    monos: &'a mut Vec<(ty::GenericArgsRef<'tcx>, bool)>,
528    rng: &mut impl Rng,
529) {
530    rap_debug!("select minimal set for: {}", tcx.def_path_str(fn_did));
531    let mut impl_vec = Vec::new();
532    let mut cmplx_vec = Vec::new();
533    for (args, _) in monos.iter() {
534        impl_vec.push(mono::get_impls(tcx, fn_did, args));
535        cmplx_vec.push(get_mono_complexity(args));
536    }
537
538    let mut selected_cnt = 0;
539    let mut complete = HashSet::new();
540    loop {
541        let mut current_max = 0;
542        let mut current_cmplx = usize::MAX;
543        let mut idx = 0;
544        for i in 0..impl_vec.len() {
545            let size = impl_vec[i]
546                .iter()
547                .fold(0, |cnt, did| cnt + (!complete.contains(did)) as usize);
548
549            if size > current_max || (size == current_max && cmplx_vec[i] < current_cmplx) {
550                current_max = size;
551                current_cmplx = cmplx_vec[i];
552                idx = i;
553            }
554        }
555        // though maybe all impls is empty, we have to select at least one
556        if current_max == 0 && selected_cnt > 0 {
557            break;
558        }
559        selected_cnt += 1;
560        monos[idx].1 = true;
561        rap_debug!("select: {:?}", monos[idx].0);
562        impl_vec[idx].iter().for_each(|did| {
563            complete.insert(*did);
564        });
565    }
566
567    // if selected_cnt == 0 {
568    //     let idx = rng.random_range(0..impl_vec.len());
569    //     rap_debug!("random select: {:?}", monos[idx].0);
570    //     monos[idx].1 = true;
571    // }
572}