rapx/analysis/core/ssa_transform/SSATransformer.rs

1#![allow(unused_imports)]
2#![allow(unused_variables)]
3#![allow(dead_code)]
4
5use rustc_data_structures::graph::dominators::Dominators;
6use rustc_data_structures::graph::{Predecessors, dominators};
7use rustc_driver::args;
8use rustc_hir::def_id::DefId;
9use rustc_hir::def_id::{CRATE_DEF_INDEX, CrateNum, DefIndex, LOCAL_CRATE, LocalDefId};
10use rustc_middle::mir::*;
11use rustc_middle::{
12    mir::{Body, Local, Location, visit::Visitor},
13    ty::TyCtxt,
14};
15use rustc_span::symbol::Symbol;
16use std::collections::{HashMap, HashSet};
/// Zero-sized unit type named after the phi-node placeholder.
///
/// NOTE(review): `SSATransformer::find_phi_placeholder` searches external crates for an
/// item with this same name; confirm whether this local definition is itself used.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct PhiPlaceholder;
18pub struct SSATransformer<'tcx> {
19    pub tcx: TyCtxt<'tcx>,
20    pub body: Body<'tcx>,
21    pub cfg: HashMap<BasicBlock, Vec<BasicBlock>>,
22    pub dominators: Dominators<BasicBlock>,
23    pub dom_tree: HashMap<BasicBlock, Vec<BasicBlock>>,
24    pub df: HashMap<BasicBlock, HashSet<BasicBlock>>,
25    pub local_assign_blocks: HashMap<Local, HashSet<BasicBlock>>,
26    pub reaching_def: HashMap<Local, Option<Local>>,
27    pub local_index: usize,
28    pub local_defination_block: HashMap<Local, BasicBlock>,
29    pub skipped: HashSet<usize>,
30    pub phi_index: HashMap<Location, usize>,
31    pub phi_def_id: DefId,
32    pub essa_def_id: DefId,
33    pub ref_local_map: HashMap<Local, Local>,
34    pub places_map: HashMap<Place<'tcx>, HashSet<Place<'tcx>>>,
35    pub ssa_locals_map: HashMap<Place<'tcx>, HashSet<Place<'tcx>>>,
36}
37
38impl<'tcx> SSATransformer<'tcx> {
39    fn find_phi_placeholder(tcx: TyCtxt<'_>, crate_name: &str) -> Option<DefId> {
40        let sym_crate = Symbol::intern(crate_name);
41        let krate = tcx
42            .crates(())
43            .iter()
44            .find(|&&c| tcx.crate_name(c) == sym_crate)?;
45        let root_def_id = DefId {
46            krate: *krate,
47            index: CRATE_DEF_INDEX,
48        };
49        // print!("Phid\n");
50
51        for item in tcx.module_children(root_def_id) {
52            // println!("Module child: {:?}", item.ident.name.as_str());
53
54            if item.ident.name.as_str() == "PhiPlaceholder" {
55                if let Some(def_id) = item.res.opt_def_id() {
56                    return Some(def_id);
57                }
58            }
59        }
60        // print!("Phid\n");
61        return Some(root_def_id);
62    }
63    pub fn new(
64        tcx: TyCtxt<'tcx>,
65        body: &Body<'tcx>,
66        ssa_def_id: DefId,
67        essa_def_id: DefId,
68        arg_count: usize,
69    ) -> Self {
70        let cfg: HashMap<BasicBlock, Vec<BasicBlock>> = Self::extract_cfg_from_predecessors(&body);
71
72        let dominators: Dominators<BasicBlock> = body.basic_blocks.dominators().clone();
73
74        let dom_tree: HashMap<BasicBlock, Vec<BasicBlock>> = Self::construct_dominance_tree(&body);
75
76        let df: HashMap<BasicBlock, HashSet<BasicBlock>> =
77            Self::compute_dominance_frontier(&body, &dom_tree);
78
79        let local_assign_blocks: HashMap<Local, HashSet<BasicBlock>> =
80            Self::map_locals_to_assign_blocks(&body);
81        let local_defination_block: HashMap<Local, BasicBlock> =
82            Self::map_locals_to_definition_block(&body);
83        let len = body.local_decls.len() as usize;
84        let mut skipped = HashSet::new();
85        if len > 0 {
86            skipped.extend(arg_count + 1..len + 1);
87            // skipped.insert(0); // Skip the return place
88        }
89
90        SSATransformer {
91            tcx,
92            body: body.clone(),
93            cfg,
94            dominators,
95            dom_tree,
96            df,
97            local_assign_blocks,
98            reaching_def: HashMap::default(),
99            local_index: len,
100            local_defination_block: local_defination_block,
101            skipped: skipped,
102            phi_index: HashMap::default(),
103            phi_def_id: ssa_def_id,
104            essa_def_id: essa_def_id,
105            ref_local_map: HashMap::default(),
106            places_map: HashMap::default(),
107            ssa_locals_map: HashMap::default(),
108        }
109    }
110
111    pub fn return_body_ref(&self) -> &Body<'tcx> {
112        &self.body
113    }
114
115    fn map_locals_to_definition_block(body: &Body) -> HashMap<Local, BasicBlock> {
116        let mut local_to_block_map: HashMap<Local, BasicBlock> = HashMap::new();
117
118        for (bb, block_data) in body.basic_blocks.iter_enumerated() {
119            for statement in &block_data.statements {
120                match &statement.kind {
121                    StatementKind::Assign(box (place, _)) => {
122                        if let Some(local) = place.as_local() {
123                            if local.as_u32() == 0 {
124                                continue; // Skip the return place
125                            }
126                            local_to_block_map.entry(local).or_insert(bb);
127                        }
128                    }
129                    _ => {}
130                }
131            }
132            if let Some(terminator) = &block_data.terminator {
133                match &terminator.kind {
134                    TerminatorKind::Call { destination, .. } => {
135                        if let Some(local) = destination.as_local() {
136                            if local.as_u32() == 0 {
137                                continue; // Skip the return place
138                            }
139                            local_to_block_map.entry(local).or_insert(bb);
140                        }
141                    }
142                    _ => {}
143                }
144            }
145        }
146
147        local_to_block_map
148    }
149    pub fn depth_first_search_preorder(
150        dom_tree: &HashMap<BasicBlock, Vec<BasicBlock>>,
151        root: BasicBlock,
152    ) -> Vec<BasicBlock> {
153        let mut visited: HashSet<BasicBlock> = HashSet::new();
154        let mut preorder = Vec::new();
155
156        fn dfs(
157            node: BasicBlock,
158            dom_tree: &HashMap<BasicBlock, Vec<BasicBlock>>,
159            visited: &mut HashSet<BasicBlock>,
160            preorder: &mut Vec<BasicBlock>,
161        ) {
162            if visited.insert(node) {
163                preorder.push(node);
164
165                if let Some(children) = dom_tree.get(&node) {
166                    for &child in children {
167                        dfs(child, dom_tree, visited, preorder);
168                    }
169                }
170            }
171        }
172
173        dfs(root, dom_tree, &mut visited, &mut preorder);
174        preorder
175    }
176    pub fn depth_first_search_postorder(
177        dom_tree: &HashMap<BasicBlock, Vec<BasicBlock>>,
178        root: &BasicBlock,
179    ) -> Vec<BasicBlock> {
180        let mut visited: HashSet<BasicBlock> = HashSet::new();
181        let mut postorder = Vec::new();
182
183        fn dfs(
184            node: BasicBlock,
185            dom_tree: &HashMap<BasicBlock, Vec<BasicBlock>>,
186            visited: &mut HashSet<BasicBlock>,
187            postorder: &mut Vec<BasicBlock>,
188        ) {
189            if visited.insert(node) {
190                if let Some(children) = dom_tree.get(&node) {
191                    for &child in children {
192                        dfs(child, dom_tree, visited, postorder);
193                    }
194                }
195                postorder.push(node);
196            }
197        }
198
199        dfs(*root, dom_tree, &mut visited, &mut postorder);
200        postorder
201    }
202
203    fn map_locals_to_assign_blocks(body: &Body) -> HashMap<Local, HashSet<BasicBlock>> {
204        let mut local_to_blocks: HashMap<Local, HashSet<BasicBlock>> = HashMap::new();
205
206        for (bb, data) in body.basic_blocks.iter_enumerated() {
207            for stmt in &data.statements {
208                if let StatementKind::Assign(box (place, _)) = &stmt.kind {
209                    let local = place.local;
210                    if local.as_u32() == 0 {
211                        continue; // Skip the return place
212                    }
213                    local_to_blocks
214                        .entry(local)
215                        .or_insert_with(HashSet::new)
216                        .insert(bb);
217                }
218            }
219        }
220        for arg in body.args_iter() {
221            local_to_blocks
222                .entry(arg)
223                .or_insert_with(HashSet::new)
224                .insert(BasicBlock::from_u32(0)); // Assuming arg block is 0
225        }
226        local_to_blocks
227    }
228    fn construct_dominance_tree(body: &Body<'_>) -> HashMap<BasicBlock, Vec<BasicBlock>> {
229        let mut dom_tree: HashMap<BasicBlock, Vec<BasicBlock>> = HashMap::new();
230        let dominators = body.basic_blocks.dominators();
231        for (block, _) in body.basic_blocks.iter_enumerated() {
232            if let Some(idom) = dominators.immediate_dominator(block) {
233                dom_tree.entry(idom).or_default().push(block);
234            }
235        }
236
237        dom_tree
238    }
239    fn compute_dominance_frontier(
240        body: &Body<'_>,
241        dom_tree: &HashMap<BasicBlock, Vec<BasicBlock>>,
242    ) -> HashMap<BasicBlock, HashSet<BasicBlock>> {
243        let mut dominance_frontier: HashMap<BasicBlock, HashSet<BasicBlock>> = HashMap::new();
244        let dominators = body.basic_blocks.dominators();
245        let predecessors = body.basic_blocks.predecessors();
246        for (block, _) in body.basic_blocks.iter_enumerated() {
247            dominance_frontier.entry(block).or_default();
248        }
249
250        for (block, _) in body.basic_blocks.iter_enumerated() {
251            if predecessors[block].len() > 1 {
252                let preds = body.basic_blocks.predecessors()[block].clone();
253
254                for &pred in &preds {
255                    let mut runner = pred;
256                    while runner != dominators.immediate_dominator(block).unwrap() {
257                        dominance_frontier.entry(runner).or_default().insert(block);
258                        runner = dominators.immediate_dominator(runner).unwrap();
259                    }
260                }
261            }
262        }
263
264        dominance_frontier
265    }
266    fn extract_cfg_from_predecessors(body: &Body<'_>) -> HashMap<BasicBlock, Vec<BasicBlock>> {
267        let mut cfg: HashMap<BasicBlock, Vec<BasicBlock>> = HashMap::new();
268
269        for (block, _) in body.basic_blocks.iter_enumerated() {
270            for &predecessor in body.basic_blocks.predecessors()[block].iter() {
271                cfg.entry(predecessor).or_default().push(block);
272            }
273        }
274
275        cfg
276    }
277    fn print_dominance_tree(
278        dom_tree: &HashMap<BasicBlock, Vec<BasicBlock>>,
279        current: BasicBlock,
280        depth: usize,
281    ) {
282        if let Some(children) = dom_tree.get(&current) {
283            for &child in children {
284                Self::print_dominance_tree(dom_tree, child, depth + 1);
285            }
286        }
287    }
288
289    pub fn is_phi_statement(&self, statement: &Statement<'tcx>) -> bool {
290        if let StatementKind::Assign(box (_, rvalue)) = &statement.kind {
291            if let Rvalue::Aggregate(box aggregate_kind, _) = rvalue {
292                if let AggregateKind::Adt(def_id, ..) = aggregate_kind {
293                    return *def_id == self.phi_def_id;
294                }
295            }
296        }
297        false
298    }
299
300    pub fn is_essa_statement(&self, statement: &Statement<'tcx>) -> bool {
301        if let StatementKind::Assign(box (_, rvalue)) = &statement.kind {
302            if let Rvalue::Aggregate(box aggregate_kind, _) = rvalue {
303                if let AggregateKind::Adt(def_id, ..) = aggregate_kind {
304                    return *def_id == self.essa_def_id;
305                }
306            }
307        }
308        false
309    }
310    pub fn get_essa_source_block(&self, statement: &Statement<'tcx>) -> Option<BasicBlock> {
311        if !self.is_essa_statement(statement) {
312            return None;
313        }
314
315        if let StatementKind::Assign(box (_, Rvalue::Aggregate(_, operands))) = &statement.kind {
316            if let Some(last_op) = operands.into_iter().last() {
317                if let Operand::Constant(box ConstOperand { const_: c, .. }) = last_op {
318                    if let Some(val) = self.try_const_to_usize(c) {
319                        return Some(BasicBlock::from_usize(val as usize));
320                    }
321                }
322            }
323        }
324        None
325    }
326
327    fn try_const_to_usize(&self, c: &Const<'tcx>) -> Option<u64> {
328        if let Some(scalar_int) = c.try_to_scalar_int() {
329            let size = scalar_int.size();
330            if let Ok(bits) = scalar_int.try_to_bits(size) {
331                return Some(bits as u64);
332            }
333        }
334        None
335    }
336}