rapx/analysis/core/alias_analysis/mfp/
intraproc.rs

1use rustc_data_structures::fx::FxHashMap;
2use rustc_hir::def_id::DefId;
3use rustc_middle::{
4    mir::{
5        Body, CallReturnPlaces, Location, Operand, Place, Rvalue, Statement, StatementKind,
6        Terminator, TerminatorEdges, TerminatorKind,
7    },
8    ty::{self, Ty, TyCtxt, TypingEnv},
9};
10use rustc_mir_dataflow::{Analysis, JoinSemiLattice, fmt::DebugWithContext};
11use std::cell::RefCell;
12use std::rc::Rc;
13
14use super::super::{FnAliasMap, FnAliasPairs};
15use super::transfer;
16use crate::analysis::core::alias_analysis::default::types::is_not_drop;
17
18/// Apply a function summary to the current state
19fn apply_function_summary<'tcx>(
20    state: &mut AliasDomain,
21    destination: Place<'tcx>,
22    args: &[Operand<'tcx>],
23    summary: &FnAliasPairs,
24    place_info: &PlaceInfo<'tcx>,
25) {
26    // Convert destination to PlaceId
27    let dest_id = transfer::mir_place_to_place_id(destination);
28
29    // Build a mapping from callee's argument indices to caller's PlaceIds
30    // Index 0 is return value, indices 1+ are arguments
31    let mut actual_places = vec![dest_id.clone()];
32    for arg in args {
33        if let Some(arg_id) = transfer::operand_to_place_id(arg) {
34            actual_places.push(arg_id);
35        } else {
36            // If argument is not a place (e.g., constant), use a dummy
37            actual_places.push(PlaceId::Local(usize::MAX));
38        }
39    }
40
41    // Apply each alias pair from the summary
42    for alias_pair in summary.aliases() {
43        let left_idx = alias_pair.left_local();
44        let right_idx = alias_pair.right_local();
45
46        // Check bounds
47        if left_idx >= actual_places.len() || right_idx >= actual_places.len() {
48            continue;
49        }
50
51        // Skip if either place is a dummy (constant argument)
52        // Dummy places use usize::MAX as a sentinel value
53        if actual_places[left_idx] == PlaceId::Local(usize::MAX)
54            || actual_places[right_idx] == PlaceId::Local(usize::MAX)
55        {
56            continue;
57        }
58
59        // Get actual places with field projections
60        let mut left_place = actual_places[left_idx].clone();
61        for &field_idx in alias_pair.lhs_fields() {
62            left_place = left_place.project_field(field_idx);
63        }
64
65        let mut right_place = actual_places[right_idx].clone();
66        for &field_idx in alias_pair.rhs_fields() {
67            right_place = right_place.project_field(field_idx);
68        }
69
70        // Get indices and union
71        if let (Some(left_place_idx), Some(right_place_idx)) = (
72            place_info.get_index(&left_place),
73            place_info.get_index(&right_place),
74        ) {
75            let left_may_drop = place_info.may_drop(left_place_idx);
76            let right_may_drop = place_info.may_drop(right_place_idx);
77            if left_may_drop && right_may_drop {
78                state.union(left_place_idx, right_place_idx);
79            }
80        }
81    }
82}
83
84/// Conservative fallback for library functions without MIR
85/// Assumes return value may alias with any may_drop argument
86fn apply_conservative_alias_for_call<'tcx>(
87    state: &mut AliasDomain,
88    destination: Place<'tcx>,
89    args: &[rustc_span::source_map::Spanned<rustc_middle::mir::Operand<'tcx>>],
90    place_info: &PlaceInfo<'tcx>,
91) {
92    // Get destination place
93    let dest_id = transfer::mir_place_to_place_id(destination);
94    let dest_idx = match place_info.get_index(&dest_id) {
95        Some(idx) => idx,
96        None => {
97            return;
98        }
99    };
100
101    // Only apply if destination may_drop
102    if !place_info.may_drop(dest_idx) {
103        return;
104    }
105
106    // Union with all may_drop arguments
107    for (_i, arg) in args.iter().enumerate() {
108        if let Some(arg_id) = transfer::operand_to_place_id(&arg.node) {
109            if let Some(arg_idx) = place_info.get_index(&arg_id) {
110                if place_info.may_drop(arg_idx) {
111                    // Create conservative alias
112                    state.union(dest_idx, arg_idx);
113
114                    // Sync fields for more precision
115                    transfer::sync_fields(state, &dest_id, &arg_id, place_info);
116                }
117            }
118        }
119    }
120}
121
/// Identifier of a tracked place: a local, or a chain of field projections on a local.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PlaceId {
    /// A local variable, identified by its index (e.g. `_1`).
    Local(usize),
    /// Field `field_idx` of the place `base` (e.g. `_1.0`).
    Field {
        base: Box<PlaceId>,
        field_idx: usize,
    },
}

impl PlaceId {
    /// Return the index of the local at the bottom of the projection chain.
    pub fn root_local(&self) -> usize {
        let mut cursor = self;
        loop {
            match cursor {
                PlaceId::Local(local) => return *local,
                PlaceId::Field { base, .. } => cursor = &**base,
            }
        }
    }

    /// Build the place denoting field `field_idx` of `self`; `self` is left untouched.
    pub fn project_field(&self, field_idx: usize) -> PlaceId {
        let base = Box::new(self.clone());
        PlaceId::Field { base, field_idx }
    }

    /// Return `true` when `prefix` is `self` or one of the places `self` is projected from.
    ///
    /// For example `_1.0.1` has the prefixes `_1.0.1`, `_1.0` and `_1`, but not `_2` and
    /// not `_1.1`.
    pub fn has_prefix(&self, prefix: &PlaceId) -> bool {
        let mut cursor = self;
        loop {
            if cursor == prefix {
                return true;
            }
            match cursor {
                // Reached the root without finding `prefix`.
                PlaceId::Local(_) => return false,
                PlaceId::Field { base, .. } => cursor = &**base,
            }
        }
    }
}
164
165/// Information about all places in a function
166#[derive(Clone)]
167pub struct PlaceInfo<'tcx> {
168    /// Mapping from PlaceId to index
169    place_to_index: FxHashMap<PlaceId, usize>,
170    /// Mapping from index to PlaceId
171    index_to_place: Vec<PlaceId>,
172    /// Mapping from PlaceId to MIR Place (when available)
173    place_to_mir: FxHashMap<PlaceId, Place<'tcx>>,
174    /// Whether each place may need drop
175    may_drop: Vec<bool>,
176    /// Whether each place needs drop
177    need_drop: Vec<bool>,
178    /// Total number of places
179    num_places: usize,
180}
181
182impl<'tcx> PlaceInfo<'tcx> {
183    /// Create a new PlaceInfo with initial capacity
184    pub fn new() -> Self {
185        PlaceInfo {
186            place_to_index: FxHashMap::default(),
187            index_to_place: Vec::new(),
188            place_to_mir: FxHashMap::default(),
189            may_drop: Vec::new(),
190            need_drop: Vec::new(),
191            num_places: 0,
192        }
193    }
194
195    /// Build PlaceInfo from MIR body
196    pub fn build(tcx: TyCtxt<'tcx>, def_id: DefId, body: &'tcx Body<'tcx>) -> Self {
197        let mut info = Self::new();
198        let ty_env = TypingEnv::post_analysis(tcx, def_id);
199
200        // Register all locals first
201        for (local, local_decl) in body.local_decls.iter_enumerated() {
202            let ty = local_decl.ty;
203            let need_drop = ty.needs_drop(tcx, ty_env);
204            let may_drop = !is_not_drop(tcx, ty);
205
206            let place_id = PlaceId::Local(local.as_usize());
207            info.register_place(place_id.clone(), may_drop, need_drop);
208
209            // Create fields for this type recursively
210            info.create_fields_for_type(tcx, ty, place_id, 0, 0, ty_env);
211        }
212
213        info
214    }
215
216    /// Recursively create field PlaceIds for a type
217    fn create_fields_for_type(
218        &mut self,
219        tcx: TyCtxt<'tcx>,
220        ty: Ty<'tcx>,
221        base_place: PlaceId,
222        field_depth: usize,
223        deref_depth: usize,
224        ty_env: TypingEnv<'tcx>,
225    ) {
226        // Limit recursion depth to avoid infinite loops
227        const MAX_FIELD_DEPTH: usize = 5;
228        const MAX_DEREF_DEPTH: usize = 3;
229        if field_depth >= MAX_FIELD_DEPTH || deref_depth >= MAX_DEREF_DEPTH {
230            return;
231        }
232
233        match ty.kind() {
234            // For references, recursively create fields for the inner type
235            // This allows handling patterns like (*_1).0 where _1 is &T
236            ty::Ref(_, inner_ty, _) => {
237                self.create_fields_for_type(
238                    tcx,
239                    *inner_ty,
240                    base_place,
241                    field_depth,
242                    deref_depth + 1,
243                    ty_env,
244                );
245            }
246            // For raw pointers, also create fields for the inner type
247            ty::RawPtr(inner_ty, _) => {
248                self.create_fields_for_type(
249                    tcx,
250                    *inner_ty,
251                    base_place,
252                    field_depth,
253                    deref_depth + 1,
254                    ty_env,
255                );
256            }
257            // For ADTs (structs/enums), create fields
258            ty::Adt(adt_def, substs) => {
259                for (field_idx, field) in adt_def.all_fields().enumerate() {
260                    let field_ty = field.ty(tcx, substs);
261                    let field_place = base_place.project_field(field_idx);
262
263                    // Check if field may/need drop
264                    // Use the ty_env from the function context to avoid param-env mismatch
265                    let need_drop = field_ty.needs_drop(tcx, ty_env);
266
267                    // Special handling: when deref_depth > 0, we are creating fields for
268                    // a type accessed through a reference/pointer (e.g., (*_1).0 where _1 is &T).
269                    // In this case, even if the field type itself doesn't need drop (e.g., i32),
270                    // we should still track it for alias analysis because it represents memory
271                    // accessed through a reference.
272                    let may_drop = if deref_depth > 0 {
273                        true
274                    } else {
275                        !is_not_drop(tcx, field_ty)
276                    };
277
278                    self.register_place(field_place.clone(), may_drop, need_drop);
279
280                    // Recursively create nested fields
281                    self.create_fields_for_type(
282                        tcx,
283                        field_ty,
284                        field_place,
285                        field_depth + 1,
286                        deref_depth,
287                        ty_env,
288                    );
289                }
290            }
291            // For tuples, create fields
292            ty::Tuple(fields) => {
293                for (field_idx, field_ty) in fields.iter().enumerate() {
294                    let field_place = base_place.project_field(field_idx);
295
296                    // For tuples, we conservatively check drop requirements
297                    // Note: Tuple fields don't have a specific DefId, so we use a simpler check
298
299                    // Special handling: when deref_depth > 0, we are creating fields for
300                    // a type accessed through a reference/pointer. Even if the field type
301                    // doesn't need drop, we should track it for alias analysis.
302                    let may_drop = if deref_depth > 0 {
303                        true
304                    } else {
305                        !is_not_drop(tcx, field_ty)
306                    };
307
308                    // For need_drop, use the ty_env from the function context
309                    let need_drop = field_ty.needs_drop(tcx, ty_env);
310
311                    self.register_place(field_place.clone(), may_drop, need_drop);
312
313                    // Recursively create nested fields
314                    self.create_fields_for_type(
315                        tcx,
316                        field_ty,
317                        field_place,
318                        field_depth + 1,
319                        deref_depth,
320                        ty_env,
321                    );
322                }
323            }
324            _ => {
325                // Other types don't have explicit fields we track
326            }
327        }
328    }
329
330    /// Register a new place and return its index
331    pub fn register_place(&mut self, place_id: PlaceId, may_drop: bool, need_drop: bool) -> usize {
332        if let Some(&idx) = self.place_to_index.get(&place_id) {
333            return idx;
334        }
335
336        let idx = self.num_places;
337        self.place_to_index.insert(place_id.clone(), idx);
338        self.index_to_place.push(place_id);
339        self.may_drop.push(may_drop);
340        self.need_drop.push(need_drop);
341        self.num_places += 1;
342        idx
343    }
344
345    /// Get the index of a place
346    pub fn get_index(&self, place_id: &PlaceId) -> Option<usize> {
347        self.place_to_index.get(place_id).copied()
348    }
349
350    /// Get the PlaceId for an index
351    pub fn get_place(&self, idx: usize) -> Option<&PlaceId> {
352        self.index_to_place.get(idx)
353    }
354
355    /// Check if a place may drop
356    pub fn may_drop(&self, idx: usize) -> bool {
357        self.may_drop.get(idx).copied().unwrap_or(false)
358    }
359
360    /// Check if a place needs drop
361    pub fn need_drop(&self, idx: usize) -> bool {
362        self.need_drop.get(idx).copied().unwrap_or(false)
363    }
364
365    /// Get total number of places
366    pub fn num_places(&self) -> usize {
367        self.num_places
368    }
369
370    /// Associate a MIR place with a PlaceId
371    pub fn associate_mir_place(&mut self, place_id: PlaceId, mir_place: Place<'tcx>) {
372        self.place_to_mir.insert(place_id, mir_place);
373    }
374}
375
376/// Alias domain using Union-Find data structure
377#[derive(Clone, PartialEq, Eq, Debug)]
378pub struct AliasDomain {
379    /// Parent array for Union-Find
380    parent: Vec<usize>,
381    /// Rank for path compression
382    rank: Vec<usize>,
383}
384
385impl AliasDomain {
386    /// Create a new domain with n places
387    pub fn new(num_places: usize) -> Self {
388        AliasDomain {
389            parent: (0..num_places).collect(),
390            rank: vec![0; num_places],
391        }
392    }
393
394    /// Find the representative of a place (with path compression)
395    pub fn find(&mut self, idx: usize) -> usize {
396        if self.parent[idx] != idx {
397            self.parent[idx] = self.find(self.parent[idx]);
398        }
399        self.parent[idx]
400    }
401
402    /// Union two places (returns true if they were not already aliased)
403    pub fn union(&mut self, idx1: usize, idx2: usize) -> bool {
404        let root1 = self.find(idx1);
405        let root2 = self.find(idx2);
406
407        if root1 == root2 {
408            return false;
409        }
410
411        // Union by rank
412        if self.rank[root1] < self.rank[root2] {
413            self.parent[root1] = root2;
414        } else if self.rank[root1] > self.rank[root2] {
415            self.parent[root2] = root1;
416        } else {
417            self.parent[root2] = root1;
418            self.rank[root1] += 1;
419        }
420
421        true
422    }
423
424    /// Check if two places are aliased
425    pub fn are_aliased(&mut self, idx1: usize, idx2: usize) -> bool {
426        self.find(idx1) == self.find(idx2)
427    }
428
429    /// Remove all aliases for a place (used in kill phase)
430    /// This correctly handles the case where idx is the root of a connected component
431    pub fn remove_aliases(&mut self, idx: usize) {
432        // Find the root of the connected component containing idx
433        let root = self.find(idx);
434
435        // Collect all nodes in the same connected component
436        let mut component_nodes = Vec::new();
437        for i in 0..self.parent.len() {
438            if self.find(i) == root {
439                component_nodes.push(i);
440            }
441        }
442
443        // Remove idx from the component
444        component_nodes.retain(|&i| i != idx);
445
446        // Isolate idx
447        self.parent[idx] = idx;
448        self.rank[idx] = 0;
449
450        // Rebuild the remaining component if it's not empty
451        if !component_nodes.is_empty() {
452            // Reset all nodes in the remaining component
453            for &i in &component_nodes {
454                self.parent[i] = i;
455                self.rank[i] = 0;
456            }
457
458            // Re-union them together (excluding idx)
459            let first = component_nodes[0];
460            for &i in &component_nodes[1..] {
461                self.union(first, i);
462            }
463        }
464    }
465
466    /// Remove all aliases for a place and all its field projections
467    /// This ensures that when lv is killed, all lv.* are also killed
468    pub fn remove_aliases_with_prefix(&mut self, place_id: &PlaceId, place_info: &PlaceInfo) {
469        // Collect all place indices that have place_id as a prefix
470        let mut indices_to_remove = Vec::new();
471
472        for idx in 0..self.parent.len() {
473            if let Some(pid) = place_info.get_place(idx) {
474                if pid.has_prefix(place_id) {
475                    indices_to_remove.push(idx);
476                }
477            }
478        }
479
480        // Remove aliases for all collected indices
481        for idx in indices_to_remove {
482            self.remove_aliases(idx);
483        }
484    }
485
486    /// Get all alias pairs (for debugging/summary extraction)
487    pub fn get_all_alias_pairs(&self) -> Vec<(usize, usize)> {
488        let mut pairs = Vec::new();
489        let mut domain_clone = self.clone();
490
491        for i in 0..self.parent.len() {
492            for j in (i + 1)..self.parent.len() {
493                if domain_clone.are_aliased(i, j) {
494                    pairs.push((i, j));
495                }
496            }
497        }
498
499        pairs
500    }
501}
502
503impl JoinSemiLattice for AliasDomain {
504    fn join(&mut self, other: &Self) -> bool {
505        // Safety check: both domains must have the same size
506        // This ensures they represent the same place space
507        assert_eq!(
508            self.parent.len(),
509            other.parent.len(),
510            "AliasDomain::join: size mismatch (self: {}, other: {})",
511            self.parent.len(),
512            other.parent.len()
513        );
514
515        let mut changed = false;
516
517        // Get all alias pairs from other and union them in self
518        let pairs = other.get_all_alias_pairs();
519        for (i, j) in pairs {
520            if self.union(i, j) {
521                changed = true;
522            }
523        }
524
525        changed
526    }
527}
528
529impl DebugWithContext<FnAliasAnalyzer<'_>> for AliasDomain {}
530
531/// Intraprocedural alias analyzer
532pub struct FnAliasAnalyzer<'tcx> {
533    pub tcx: TyCtxt<'tcx>,
534    _body: &'tcx Body<'tcx>,
535    _def_id: DefId,
536    place_info: PlaceInfo<'tcx>,
537    /// Function summaries for interprocedural analysis
538    fn_summaries: Rc<RefCell<FnAliasMap>>,
539    /// (Debug) Number of BBs we have iterated through
540    pub bb_iter_cnt: RefCell<usize>,
541}
542
543impl<'tcx> FnAliasAnalyzer<'tcx> {
544    /// Create a new analyzer for a function
545    pub fn new(
546        tcx: TyCtxt<'tcx>,
547        def_id: DefId,
548        body: &'tcx Body<'tcx>,
549        fn_summaries: Rc<RefCell<FnAliasMap>>,
550    ) -> Self {
551        // Build place info by analyzing the body
552        let place_info = PlaceInfo::build(tcx, def_id, body);
553        FnAliasAnalyzer {
554            tcx,
555            _body: body,
556            _def_id: def_id,
557            place_info,
558            fn_summaries,
559            bb_iter_cnt: RefCell::new(0),
560        }
561    }
562
563    /// Get the place info
564    pub fn place_info(&self) -> &PlaceInfo<'tcx> {
565        &self.place_info
566    }
567}
568
569// Implement Analysis for FnAliasAnalyzer
570impl<'tcx> Analysis<'tcx> for FnAliasAnalyzer<'tcx> {
571    type Domain = AliasDomain;
572
573    const NAME: &'static str = "FnAliasAnalyzer";
574
575    fn bottom_value(&self, _body: &Body<'tcx>) -> Self::Domain {
576        // Bottom is no aliases
577        AliasDomain::new(self.place_info.num_places())
578    }
579
580    fn initialize_start_block(&self, _body: &Body<'tcx>, _state: &mut Self::Domain) {
581        // Entry state: no initial aliases between parameters
582    }
583
584    fn apply_primary_statement_effect(
585        &self,
586        state: &mut Self::Domain,
587        statement: &Statement<'tcx>,
588        _location: Location,
589    ) {
590        match &statement.kind {
591            StatementKind::Assign(box (lv, rvalue)) => {
592                match rvalue {
593                    // Use(operand): lv = operand
594                    Rvalue::Use(operand) => {
595                        transfer::transfer_assign(state, *lv, operand, &self.place_info);
596                    }
597                    // Ref: lv = &rv or lv = &raw rv
598                    Rvalue::Ref(_, _, rv) | Rvalue::RawPtr(_, rv) => {
599                        transfer::transfer_ref(state, *lv, *rv, &self.place_info);
600                    }
601                    // CopyForDeref: similar to ref
602                    Rvalue::CopyForDeref(rv) => {
603                        transfer::transfer_ref(state, *lv, *rv, &self.place_info);
604                    }
605                    // Cast: lv = operand as T
606                    Rvalue::Cast(_, operand, _) => {
607                        transfer::transfer_assign(state, *lv, operand, &self.place_info);
608                    }
609                    // Aggregate: lv = (operands...)
610                    Rvalue::Aggregate(_, operands) => {
611                        let operand_slice: Vec<_> = operands.iter().map(|op| op.clone()).collect();
612                        transfer::transfer_aggregate(state, *lv, &operand_slice, &self.place_info);
613                    }
614                    // ShallowInitBox: lv = ShallowInitBox(operand, T)
615                    Rvalue::ShallowInitBox(operand, _) => {
616                        transfer::transfer_assign(state, *lv, operand, &self.place_info);
617                    }
618                    // Other rvalues don't create aliases
619                    _ => {}
620                }
621            }
622            // Other statement kinds don't affect alias analysis
623            _ => {}
624        }
625    }
626
627    fn apply_primary_terminator_effect<'mir>(
628        &self,
629        state: &mut Self::Domain,
630        terminator: &'mir Terminator<'tcx>,
631        _location: Location,
632    ) -> TerminatorEdges<'mir, 'tcx> {
633        // (Debug)
634        {
635            *self.bb_iter_cnt.borrow_mut() += 1;
636        }
637        match &terminator.kind {
638            // Call: apply both kill and gen effects
639            // Note: Ideally gen effect should be in apply_call_return_effect, but that method
640            // is not being called by rustc's dataflow framework in current version.
641            // Therefore, we handle both effects here, following MOP's approach.
642            TerminatorKind::Call {
643                target,
644                destination,
645                args,
646                func,
647                ..
648            } => {
649                // Step 1: Apply kill effect for the destination
650                let operand_slice: Vec<_> = args
651                    .iter()
652                    .map(|spanned_arg| spanned_arg.node.clone())
653                    .collect();
654                transfer::transfer_call(state, *destination, &operand_slice, &self.place_info);
655
656                // Step 2: Apply gen effect - function summary or fallback
657                if let Operand::Constant(c) = func {
658                    if let ty::FnDef(callee_def_id, _) = c.ty().kind() {
659                        // Try to get the function summary
660                        let fn_summaries = self.fn_summaries.borrow();
661                        if let Some(summary) = fn_summaries.get(callee_def_id) {
662                            // Apply the function summary
663                            apply_function_summary(
664                                state,
665                                *destination,
666                                &operand_slice,
667                                summary,
668                                &self.place_info,
669                            );
670                        } else {
671                            // No summary available (e.g., library function without MIR)
672                            // Drop the borrow before calling the fallback function
673                            drop(fn_summaries);
674
675                            // Apply conservative fallback: assume return value may alias with any may_drop argument
676                            apply_conservative_alias_for_call(
677                                state,
678                                *destination,
679                                args,
680                                &self.place_info,
681                            );
682                        }
683                    } else {
684                        // FnPtr? Closure?
685                        // rap_warn!(
686                        //     "[MFP-alias] Ignoring call to {:?} because it's not a FnDef",
687                        //     c
688                        // );
689                    }
690                }
691
692                // Step 3: Return control flow edges
693                if let Some(target_bb) = target {
694                    TerminatorEdges::Single(*target_bb)
695                } else {
696                    TerminatorEdges::None
697                }
698            }
699
700            // Drop: doesn't affect alias relationships
701            TerminatorKind::Drop { target, .. } => TerminatorEdges::Single(*target),
702
703            // SwitchInt: return all possible edges
704            TerminatorKind::SwitchInt { discr, targets } => {
705                TerminatorEdges::SwitchInt { discr, targets }
706            }
707
708            // Assert: return normal edge
709            TerminatorKind::Assert { target, .. } => TerminatorEdges::Single(*target),
710
711            // Goto: return target
712            TerminatorKind::Goto { target } => TerminatorEdges::Single(*target),
713
714            // Return: no successors
715            TerminatorKind::Return => TerminatorEdges::None,
716
717            // All other terminators: assume no successors for safety
718            _ => TerminatorEdges::None,
719        }
720    }
721
722    fn apply_call_return_effect(
723        &self,
724        _state: &mut Self::Domain,
725        _block: rustc_middle::mir::BasicBlock,
726        _return_places: CallReturnPlaces<'_, 'tcx>,
727    ) {
728        // Note: This method is part of the rustc Analysis trait but is not being called
729        // by the dataflow framework in current rustc version when using iterate_to_fixpoint.
730        // The call return effect (gen effect) is instead handled directly in
731        // apply_primary_terminator_effect to ensure it is actually executed.
732        // This is consistent with how MOP analysis handles function calls.
733    }
734}