rapx/analysis/core/api_dependency/graph/
transform.rs

1use super::dep_edge::DepEdge;
2use super::{ApiDependencyGraph, DepNode, TyWrapper};
3use petgraph::graph::NodeIndex;
4use rustc_middle::ty::{self};
5use serde::Serialize;
6use std::fmt::Display;
7
8static ALL_TRANSFORMKIND: [TransformKind; 2] = [
9    TransformKind::Ref(ty::Mutability::Not),
10    TransformKind::Ref(ty::Mutability::Mut),
11    // TransformKind::Deref,
12    // TransformKind::Box,
13];
14
15#[derive(Clone, Copy, Eq, PartialEq, Debug, Hash)]
16pub enum TransformKind {
17    Ref(ty::Mutability),
18    Unwrap, // unwrap Option<T>, Result<T, E>
19}
20
21impl TransformKind {
22    pub fn all() -> &'static [TransformKind] {
23        &ALL_TRANSFORMKIND
24    }
25}
26
27impl Display for TransformKind {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        match self {
30            TransformKind::Ref(mutability) => write!(f, "{}T", mutability.ref_prefix_str()),
31            TransformKind::Unwrap => write!(f, "Unwrap"),
32        }
33    }
34}
35
36impl Serialize for TransformKind {
37    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
38    where
39        S: serde::Serializer,
40    {
41        serializer.serialize_str(&self.to_string())
42    }
43}
44
45impl<'tcx> ApiDependencyGraph<'tcx> {
46    pub fn update_transform_edges(&mut self) {
47        for node_index in self.graph.node_indices() {
48            if let DepNode::Ty(ty) = self.graph[node_index] {
49                self.add_possible_transform::<3>(ty, 0);
50            }
51        }
52    }
53
54    fn add_possible_transform<const MAX_DEPTH: usize>(
55        &mut self,
56        current_ty: TyWrapper<'tcx>,
57        depth: usize,
58    ) -> Option<NodeIndex> {
59        if depth > 0 {
60            let index = self.get_index(DepNode::Ty(current_ty));
61            if index.is_some() {
62                return index;
63            }
64        }
65
66        if depth >= MAX_DEPTH {
67            return None;
68        }
69
70        let mut ret = None;
71        for kind in TransformKind::all() {
72            let new_ty = current_ty.transform(*kind, self.tcx()); // &T or &mut T
73            if let Some(next_index) = self.add_possible_transform::<MAX_DEPTH>(new_ty, depth + 1) {
74                let current_index = self.get_or_create_index(DepNode::Ty(current_ty));
75                self.add_edge_once(current_index, next_index, DepEdge::transform(*kind));
76                ret = Some(current_index);
77            }
78        }
79        ret
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::TransformKind;
    use rustc_middle::ty;

    #[test]
    fn serialize_ref_not_matches_expected() {
        let kind = TransformKind::Ref(ty::Mutability::Not);
        let serialized = serde_json::to_string(&kind).expect("serialize TransformKind::Ref(Not)");
        assert_eq!(serialized, "\"&T\"");
    }

    #[test]
    fn serialize_ref_mut_matches_expected() {
        let kind = TransformKind::Ref(ty::Mutability::Mut);
        let serialized = serde_json::to_string(&kind).expect("serialize TransformKind::Ref(Mut)");
        assert_eq!(serialized, "\"&mut T\"");
    }

    #[test]
    fn serialize_unwrap_matches_expected() {
        let kind = TransformKind::Unwrap;
        let serialized = serde_json::to_string(&kind).expect("serialize TransformKind::Unwrap");
        assert_eq!(serialized, "\"Unwrap\"");
    }

    /// `Display` is the source of the serialized text, so pin it directly.
    #[test]
    fn display_matches_expected() {
        assert_eq!(TransformKind::Ref(ty::Mutability::Not).to_string(), "&T");
        assert_eq!(TransformKind::Ref(ty::Mutability::Mut).to_string(), "&mut T");
        assert_eq!(TransformKind::Unwrap.to_string(), "Unwrap");
    }

    /// The serialized form must be exactly the `Display` text wrapped as a JSON string.
    #[test]
    fn serialized_form_matches_display() {
        for kind in [
            TransformKind::Ref(ty::Mutability::Not),
            TransformKind::Ref(ty::Mutability::Mut),
            TransformKind::Unwrap,
        ] {
            let serialized = serde_json::to_string(&kind).expect("serialize TransformKind");
            assert_eq!(serialized, format!("\"{kind}\""));
        }
    }
}