rapx/check/senryx/
generic_check.rs

1//! Generic-parameter helper for Senryx layout-sensitive checks.
2//!
3//! The current implementation maps selected generic trait bounds to a finite
4//! set of representative concrete types. Alignment and size checks use these
5//! representatives when a contract mentions a generic type.
6
7use std::collections::{HashMap, HashSet};
8
9use if_chain::if_chain;
10use rustc_hir::{ImplPolarity, ItemId, ItemKind, hir_id::OwnerId};
11use rustc_middle::ty::{FloatTy, IntTy, ParamEnv, Ty, TyCtxt, TyKind, UintTy};
12// use crate::rap_info;
13
14/// Computes representative concrete types for generic parameters.
15pub struct GenericChecker<'tcx> {
16    // tcx: TyCtxt<'tcx>,
17    trait_map: HashMap<String, HashSet<Ty<'tcx>>>,
18}
19
20impl<'tcx> GenericChecker<'tcx> {
21    /// Build a generic checker from the current type context and parameter environment.
22    pub fn new(tcx: TyCtxt<'tcx>, p_env: ParamEnv<'tcx>) -> Self {
23        let mut trait_bnd_map_for_generic: HashMap<String, HashSet<String>> = HashMap::new();
24        let mut satisfied_ty_map_for_generic: HashMap<String, HashSet<Ty<'tcx>>> = HashMap::new();
25
26        for cb in p_env.caller_bounds() {
27            // cb: Binder(TraitPredicate(<Self as trait>, ..)
28            // Focus on the trait bound applied to our generic parameter
29
30            if let Some(trait_pred) = cb.as_trait_clause() {
31                let trait_def_id = trait_pred.def_id();
32                let generic_name = trait_pred.self_ty().skip_binder().to_string();
33                let satisfied_ty_set = satisfied_ty_map_for_generic
34                    .entry(generic_name.clone())
35                    .or_insert_with(|| HashSet::new());
36                let trait_name = tcx.def_path_str(trait_def_id);
37                let trait_bnd_set = trait_bnd_map_for_generic
38                    .entry(generic_name)
39                    .or_insert_with(|| HashSet::new());
40                trait_bnd_set.insert(trait_name.clone());
41
42                // for each implementation
43                for def_id in tcx.all_impls(trait_def_id) {
44                    // impl_id: LocalDefId
45                    if !def_id.is_local() {
46                        continue;
47                    }
48                    let impl_owner_id = tcx
49                        .hir_owner_node(OwnerId {
50                            def_id: def_id.expect_local(),
51                        })
52                        .def_id();
53
54                    let item = tcx.hir_item(ItemId {
55                        owner_id: impl_owner_id,
56                    });
57                    if_chain! {
58                        if let ItemKind::Impl(impl_item) = item.kind;
59                        if let Some(trait_impl_header) = impl_item.of_trait;
60                        if trait_impl_header.polarity == ImplPolarity::Positive;
61                        if let Some(binder) = tcx.impl_opt_trait_ref(def_id);
62                        then {
63                            let trait_ref = binder.skip_binder();
64                            let impl_ty = trait_ref.self_ty();
65                            match impl_ty.kind() {
66                                TyKind::Adt(adt_def, _impl_trait_substs) => {
67                                    let adt_did = adt_def.did();
68                                    let adt_ty = tcx.type_of(adt_did).skip_binder();
69                                    // rap_info!("{} is implemented on adt({:?})", trait_name, adt_ty);
70                                    satisfied_ty_set.insert(adt_ty);
71                                },
72                                TyKind::Param(p_ty) => {
73                                    let _param_ty = p_ty.to_ty(tcx);
74                                },
75                                _ => {
76                                    // rap_info!("{} is implemented on {:?}", trait_name, impl_ty);
77                                    satisfied_ty_set.insert(impl_ty);
78                                },
79                            }
80                        }
81                    }
82                }
83
84                // handle known external trait e.g., Pod
85                if trait_name == "bytemuck::Pod" || trait_name == "plain::Plain" {
86                    let ty_bnd = Self::get_satisfied_ty_for_pod(tcx);
87                    satisfied_ty_set.extend(&ty_bnd);
88                    // rap_info!("current trait bound type set: {:?}", satisfied_ty_set);
89                }
90            }
91        }
92
93        // check trait_bnd_set
94        let std_trait_set = HashSet::from([
95            String::from("std::marker::Copy"),
96            String::from("std::clone::Clone"),
97            String::from("std::marker::Sized"),
98        ]);
99        // if all trait_bound is std::marker, then we could assume it to be arbitrary type
100        // to avoid messing up with build type manually
101        // we just clear the satisfied ty set
102        for (key, satisfied_ty_set) in &mut satisfied_ty_map_for_generic {
103            let trait_bnd_set = trait_bnd_map_for_generic
104                .entry(key.clone())
105                .or_insert_with(|| HashSet::new());
106            if trait_bnd_set.is_subset(&std_trait_set) {
107                satisfied_ty_set.clear();
108            }
109        }
110
111        // rap_info!("trait bound type map: {:?}", satisfied_ty_map_for_generic);
112
113        GenericChecker {
114            trait_map: satisfied_ty_map_for_generic,
115        }
116    }
117
118    /// Return the representative type set for each generic parameter.
119    pub fn get_satisfied_ty_map(&self) -> HashMap<String, HashSet<Ty<'tcx>>> {
120        self.trait_map.clone()
121    }
122
123    fn get_satisfied_ty_for_pod(tcx: TyCtxt<'tcx>) -> HashSet<Ty<'tcx>> {
124        let mut satisfied_ty_set_for_pod: HashSet<Ty<'tcx>> = HashSet::new();
125        // f64, u64, i8, i32, u8, i16, u16, u32, usize, i128, isize, i64, u128, f32
126        let pod_ty = [
127            tcx.mk_ty_from_kind(TyKind::Int(IntTy::Isize)),
128            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I8)),
129            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I16)),
130            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I32)),
131            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I64)),
132            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I128)),
133            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::Usize)),
134            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U8)),
135            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U16)),
136            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U32)),
137            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U64)),
138            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U128)),
139            tcx.mk_ty_from_kind(TyKind::Float(FloatTy::F32)),
140            tcx.mk_ty_from_kind(TyKind::Float(FloatTy::F64)),
141        ];
142
143        for pt in pod_ty.iter() {
144            satisfied_ty_set_for_pod.insert(*pt);
145        }
146        satisfied_ty_set_for_pod.clone()
147    }
148}