rapx/analysis/core/api_dependency/
mono.rs

1#![allow(warnings, unused)]
2
3use super::graph::TyWrapper;
4use super::utils::{self, fn_sig_with_generic_args};
5use crate::analysis::utils::def_path::path_str_def_id;
6use crate::{rap_debug, rap_trace};
7use rand::Rng;
8use rand::seq::SliceRandom;
9use rustc_hir::LangItem;
10use rustc_hir::def_id::DefId;
11use rustc_infer::infer::DefineOpaqueTypes;
12use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
13use rustc_infer::traits::{ImplSource, Obligation, ObligationCause};
14use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, TypeVisitableExt, TypingEnv};
15use rustc_span::DUMMY_SP;
16use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _;
17use std::collections::HashSet;
18
/// Upper bound on the number of monos that `MonoSet::random_sample` keeps:
/// a larger set is shuffled and truncated to this size.
static MAX_STEP_SET_SIZE: usize = 1000;
20
/// One candidate list of generic arguments.
///
/// Type entries may still be inference variables; `has_infer_types` reports
/// whether any such entry remains.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct Mono<'tcx> {
    /// The generic arguments, stored in order.
    pub value: Vec<ty::GenericArg<'tcx>>,
}
25
26impl<'tcx> FromIterator<ty::GenericArg<'tcx>> for Mono<'tcx> {
27    fn from_iter<T>(iter: T) -> Self
28    where
29        T: IntoIterator<Item = ty::GenericArg<'tcx>>,
30    {
31        Mono {
32            value: iter.into_iter().collect(),
33        }
34    }
35}
36
37impl<'tcx> Mono<'tcx> {
38    pub fn new(identity: &[ty::GenericArg<'tcx>]) -> Self {
39        Mono {
40            value: Vec::from(identity),
41        }
42    }
43
44    fn has_infer_types(&self) -> bool {
45        self.value.iter().any(|arg| match arg.kind() {
46            ty::GenericArgKind::Type(ty) => ty.has_infer_types(),
47            _ => false,
48        })
49    }
50
51    fn mut_arg_at(&mut self, idx: usize) -> &mut ty::GenericArg<'tcx> {
52        &mut self.value[idx]
53    }
54
55    fn merge(&self, other: &Mono<'tcx>, tcx: TyCtxt<'tcx>) -> Option<Mono<'tcx>> {
56        assert!(self.value.len() == other.value.len());
57        let mut res = Vec::new();
58        for i in 0..self.value.len() {
59            let arg = self.value[i];
60            let other_arg = other.value[i];
61            let new_arg = if let Some(ty) = arg.as_type() {
62                let other_ty = other_arg.expect_ty();
63                if ty.is_ty_var() && other_ty.is_ty_var() {
64                    arg
65                } else if ty.is_ty_var() {
66                    other_arg
67                } else if other_ty.is_ty_var() {
68                    arg
69                } else if utils::is_ty_eq(ty, other_ty, tcx) {
70                    arg
71                } else {
72                    return None;
73                }
74            } else {
75                arg
76            };
77            res.push(new_arg);
78        }
79        Some(Mono { value: res })
80    }
81
82    fn fill_unbound_var(&self, tcx: TyCtxt<'tcx>) -> Vec<Mono<'tcx>> {
83        let candidates = get_unbound_generic_candidates(tcx);
84        let mut res = vec![self.clone()];
85        rap_trace!("fill unbound: {:?}", self);
86
87        for (i, arg) in self.value.iter().enumerate() {
88            if let Some(ty) = arg.as_type() {
89                if ty.is_ty_var() {
90                    let mut last = Vec::new();
91                    std::mem::swap(&mut res, &mut last);
92                    last.into_iter().for_each(|mono| {
93                        for candidate in &candidates {
94                            let mut new_mono = mono.clone();
95                            *new_mono.mut_arg_at(i) = (*candidate).into();
96                            res.push(new_mono);
97                        }
98                    });
99                }
100            }
101        }
102        res
103    }
104}
105
/// A collection of `Mono` candidates, kept in insertion order.
/// Backed by a `Vec`; `insert` appends without checking for duplicates.
#[derive(Clone, Debug, Default)]
pub struct MonoSet<'tcx> {
    /// The candidate monos.
    pub monos: Vec<Mono<'tcx>>,
}
110
111impl<'tcx> MonoSet<'tcx> {
112    pub fn all(identity: &[ty::GenericArg<'tcx>]) -> MonoSet<'tcx> {
113        MonoSet {
114            monos: vec![Mono::new(identity)],
115        }
116    }
117
118    pub fn empty() -> MonoSet<'tcx> {
119        MonoSet { monos: Vec::new() }
120    }
121
122    pub fn count(&self) -> usize {
123        self.monos.len()
124    }
125
126    pub fn at(&self, no: usize) -> &Mono<'tcx> {
127        &self.monos[no]
128    }
129
130    pub fn is_empty(&self) -> bool {
131        self.monos.is_empty()
132    }
133
134    pub fn new() -> MonoSet<'tcx> {
135        MonoSet { monos: Vec::new() }
136    }
137
138    pub fn insert(&mut self, mono: Mono<'tcx>) {
139        self.monos.push(mono);
140    }
141
142    pub fn merge(&mut self, other: &MonoSet<'tcx>, tcx: TyCtxt<'tcx>) -> MonoSet<'tcx> {
143        let mut res = MonoSet::new();
144
145        for args in self.monos.iter() {
146            for other_args in other.monos.iter() {
147                let merged = args.merge(other_args, tcx);
148                if let Some(mono) = merged {
149                    res.insert(mono);
150                }
151            }
152        }
153        res
154    }
155
156    fn filter_unbound_solution(mut self) -> Self {
157        self.monos.retain(|mono| mono.has_infer_types());
158        self
159    }
160
161    // if the unbound generic type is still exist (this could happen
162    // if `T` has no trait bounds at all)
163    // we substitute the unbound generic type with predefined type candidates
164    fn instantiate_unbound(&self, tcx: TyCtxt<'tcx>) -> Self {
165        let mut res = MonoSet::new();
166        for mono in &self.monos {
167            let filled = mono.fill_unbound_var(tcx);
168            res.monos.extend(filled);
169        }
170        res
171    }
172
173    fn erase_region_var(&mut self, tcx: TyCtxt<'tcx>) {
174        for mono in &mut self.monos {
175            mono.value
176                .iter_mut()
177                .for_each(|arg| *arg = tcx.erase_and_anonymize_regions(*arg))
178        }
179    }
180
181    pub fn filter(mut self, f: impl Fn(&Mono<'tcx>) -> bool) -> Self {
182        self.monos.retain(|args| f(args));
183        self
184    }
185
186    pub fn random_sample<R: Rng>(&mut self, rng: &mut R) {
187        if self.monos.len() <= MAX_STEP_SET_SIZE {
188            return;
189        }
190        self.monos.shuffle(rng);
191        self.monos.truncate(MAX_STEP_SET_SIZE);
192    }
193}
194
195/// try to unfiy lhs = rhs,
196/// e.g.,
197/// try_unify(Vec<T>, Vec<i32>, ...) = Some(i32)
198/// try_unify(Vec<T>, i32, ...) = None
199fn unify_ty<'tcx>(
200    lhs: Ty<'tcx>,
201    rhs: Ty<'tcx>,
202    identity: &[ty::GenericArg<'tcx>],
203    infcx: &InferCtxt<'tcx>,
204    cause: &ObligationCause<'tcx>,
205    param_env: ty::ParamEnv<'tcx>,
206) -> Option<Mono<'tcx>> {
207    // rap_info!("check {} = {}", lhs, rhs);
208    infcx.probe(|_| {
209        match infcx
210            .at(cause, param_env)
211            .eq(DefineOpaqueTypes::Yes, lhs, rhs)
212        {
213            Ok(_infer_ok) => {
214                // rap_trace!("[infer_ok] {} = {} : {:?}", lhs, rhs, infer_ok);
215                let mono = identity
216                    .iter()
217                    .map(|arg| match arg.kind() {
218                        ty::GenericArgKind::Lifetime(region) => {
219                            infcx.resolve_vars_if_possible(region).into()
220                        }
221                        ty::GenericArgKind::Type(ty) => infcx.resolve_vars_if_possible(ty).into(),
222                        ty::GenericArgKind::Const(ct) => infcx.resolve_vars_if_possible(ct).into(),
223                    })
224                    .collect();
225                Some(mono)
226            }
227            Err(_e) => {
228                // rap_trace!("[infer_err] {} = {} : {:?}", lhs, rhs, e);
229                None
230            }
231        }
232    })
233}
234
235fn is_args_fit_trait_bound<'tcx>(
236    fn_did: DefId,
237    args: &[ty::GenericArg<'tcx>],
238    tcx: TyCtxt<'tcx>,
239) -> bool {
240    let args = tcx.mk_args(args);
241    rap_trace!(
242        "fn: {:?} args: {:?} identity: {:?}",
243        fn_did,
244        args,
245        ty::GenericArgs::identity_for_item(tcx, fn_did)
246    );
247    let infcx = tcx.infer_ctxt().build(ty::TypingMode::PostAnalysis);
248    let pred = tcx.predicates_of(fn_did);
249    let inst_pred = pred.instantiate(tcx, args);
250    let param_env = tcx.param_env(fn_did);
251    rap_trace!(
252        "[trait bound] check {}",
253        tcx.def_path_str_with_args(fn_did, args)
254    );
255
256    for pred in inst_pred.predicates.iter() {
257        let obligation = Obligation::new(
258            tcx,
259            ObligationCause::dummy(),
260            param_env,
261            pred.as_predicate(),
262        );
263        rap_trace!("[trait bound] check pred: {:?}", pred);
264
265        let res = infcx.evaluate_obligation(&obligation);
266        match res {
267            Ok(eva) => {
268                if !eva.may_apply() {
269                    rap_trace!("[trait bound] check fail for {pred:?}");
270                    return false;
271                }
272            }
273            Err(_) => {
274                rap_trace!("[trait bound] check fail for {pred:?}");
275                return false;
276            }
277        }
278    }
279    rap_trace!("[trait bound] check succ");
280    true
281}
282
283fn is_fn_solvable<'tcx>(fn_did: DefId, tcx: TyCtxt<'tcx>) -> bool {
284    for pred in tcx
285        .predicates_of(fn_did)
286        .instantiate_identity(tcx)
287        .predicates
288    {
289        if let Some(pred) = pred.as_trait_clause() {
290            let trait_did = pred.skip_binder().trait_ref.def_id;
291            if tcx.is_lang_item(trait_did, LangItem::Fn)
292                || tcx.is_lang_item(trait_did, LangItem::FnMut)
293                || tcx.is_lang_item(trait_did, LangItem::FnOnce)
294            {
295                return false;
296            }
297        }
298    }
299    true
300}
301
302fn get_mono_set<'tcx>(
303    fn_did: DefId,
304    available_ty: &HashSet<TyWrapper<'tcx>>,
305    tcx: TyCtxt<'tcx>,
306) -> MonoSet<'tcx> {
307    let mut rng = rand::rng();
308
309    // sample from reachable types
310    rap_debug!("[get_mono_set] solve {}", tcx.def_path_str(fn_did));
311    let identity = ty::GenericArgs::identity_for_item(tcx, fn_did);
312    let infcx = tcx
313        .infer_ctxt()
314        .ignoring_regions()
315        .build(ty::TypingMode::PostAnalysis);
316    let param_env = tcx.param_env(fn_did);
317    let dummy_cause = ObligationCause::dummy();
318    let fresh_args = infcx.fresh_args_for_item(DUMMY_SP, fn_did);
319    // this replace generic types in fn_sig to infer var, e.g. fn(Vec<T>, i32) => fn(Vec<?0>, i32)
320    let fn_sig = fn_sig_with_generic_args(fn_did, fresh_args, tcx);
321    let identity_fnsig = fn_sig_with_generic_args(fn_did, identity, tcx);
322    let generics = tcx.generics_of(fn_did);
323
324    // print fresh_args for debugging
325    for i in 0..fresh_args.len() {
326        rap_trace!(
327            "[get_mono_set] arg#{}: {:?} -> {:?}",
328            i,
329            generics.param_at(i, tcx).name,
330            fresh_args[i]
331        );
332    }
333
334    let mut s = MonoSet::all(&fresh_args);
335
336    rap_trace!("[get_mono_set] initialize s: {:?}", s);
337
338    for (no, input_ty) in fn_sig.inputs().iter().enumerate() {
339        if !input_ty.has_infer_types() {
340            continue;
341        }
342        rap_debug!(
343            "[get_mono_set] input_ty#{}: {}",
344            no,
345            identity_fnsig.inputs()[no]
346        );
347
348        let reachable_set = available_ty
349            .iter()
350            .fold(MonoSet::new(), |mut reachable_set, ty| {
351                if let Some(mono) = unify_ty(
352                    *input_ty,
353                    (*ty).into(),
354                    &fresh_args,
355                    &infcx,
356                    &dummy_cause,
357                    param_env,
358                ) {
359                    reachable_set.insert(mono);
360                }
361                reachable_set
362            });
363        // reachable_set.random_sample(&mut rng);
364        rap_debug!(
365            "[get_mono_set] size of s: {}, size of input: {}",
366            s.count(),
367            reachable_set.count()
368        );
369        rap_trace!("[get_mono_set] input = {:?}", reachable_set);
370        s = s.merge(&reachable_set, tcx);
371        s.random_sample(&mut rng);
372        rap_trace!("[get_mono_set] after merge s = {:?}", reachable_set);
373    }
374
375    rap_debug!(
376        "[get_mono_set] after input filter, size of s: {}",
377        s.count()
378    );
379
380    let mut res = MonoSet::new();
381
382    for mono in s.monos {
383        solve_unbound_type_generics(
384            fn_did,
385            mono,
386            &mut res,
387            // &fresh_args,
388            &infcx,
389            &dummy_cause,
390            param_env,
391            tcx,
392        );
393    }
394
395    // erase infer region var
396    res.erase_region_var(tcx);
397
398    // if there is still unbound generic type, we try to instantiate it with predefined candidates
399    res.instantiate_unbound(tcx)
400}
401
402fn solve_unbound_type_generics<'tcx>(
403    did: DefId,
404    mono: Mono<'tcx>,
405    res: &mut MonoSet<'tcx>,
406    infcx: &InferCtxt<'tcx>,
407    cause: &ObligationCause<'tcx>,
408    param_env: ty::ParamEnv<'tcx>,
409    tcx: TyCtxt<'tcx>,
410) {
411    if !mono.has_infer_types() {
412        res.insert(mono);
413        return;
414    }
415    let args = tcx.mk_args(&mono.value);
416    let preds = tcx.predicates_of(did).instantiate(tcx, args);
417    let mut mset = MonoSet::all(args);
418    rap_debug!("[solve_unbound] did = {did:?}, mset={mset:?}");
419    for pred in preds.predicates.iter() {
420        rap_debug!("[solve_unbound] pred = {:?}", pred);
421        if let Some(trait_pred) = pred.as_trait_clause() {
422            let trait_pred = trait_pred.skip_binder();
423
424            rap_trace!("[solve_unbound] pred: {:?}", trait_pred);
425
426            let trait_def_id = trait_pred.trait_ref.def_id;
427            // ignore Sized trait
428            if tcx.is_lang_item(trait_def_id, LangItem::Sized)
429                || tcx.is_lang_item(trait_def_id, LangItem::Copy)
430            {
431                continue;
432            }
433
434            let mut p = MonoSet::new();
435
436            for impl_did in tcx.all_impls(trait_def_id)
437            // .chain(tcx.inherent_impls(trait_def_id).iter().map(|did| *did))
438            {
439                // format: <arg0 as Trait<arg1, arg2>>
440                let impl_trait_ref = tcx.impl_trait_ref(impl_did).skip_binder();
441
442                // filter irrelevant implementation. We only consider implementation that:
443                // 1. it is local
444                // 2. it is not local, but its' self_ty is a primitive
445                if !impl_did.is_local() && !impl_trait_ref.self_ty().is_primitive() {
446                    continue;
447                }
448
449                if let Some(mono) = unify_trait(
450                    trait_pred.trait_ref,
451                    impl_trait_ref,
452                    args,
453                    &infcx,
454                    &cause,
455                    param_env,
456                    tcx,
457                ) {
458                    p.insert(mono);
459                }
460            }
461            mset = mset.merge(&p, tcx);
462            rap_trace!("[solve_unbound] mset: {:?}", mset);
463        }
464    }
465
466    rap_trace!("[solve_unbound] (final) mset: {:?}", mset);
467    for mono in mset.monos {
468        res.insert(mono);
469    }
470}
471
472/// only handle the case that rhs does not have any infer types
473/// e.g., `<T as Into<U>> == <Foo as Into<Bar>> => Some(T=Foo, U=Bar))`
474fn unify_trait<'tcx>(
475    lhs: ty::TraitRef<'tcx>,
476    rhs: ty::TraitRef<'tcx>,
477    identity: &[ty::GenericArg<'tcx>],
478    infcx: &InferCtxt<'tcx>,
479    cause: &ObligationCause<'tcx>,
480    param_env: ty::ParamEnv<'tcx>,
481    tcx: TyCtxt<'tcx>,
482) -> Option<Mono<'tcx>> {
483    rap_trace!("[unify_trait] lhs: {:?}, rhs: {:?}", lhs, rhs);
484    if lhs.def_id != rhs.def_id {
485        return None;
486    }
487
488    assert!(lhs.args.len() == rhs.args.len());
489    let mut s = Mono::new(identity);
490    for (lhs_arg, rhs_arg) in lhs.args.iter().zip(rhs.args.iter()) {
491        if let (Some(lhs_ty), Some(rhs_ty)) = (lhs_arg.as_type(), rhs_arg.as_type()) {
492            if rhs_ty.has_infer_types() || rhs_ty.has_param() {
493                // if rhs has infer types, we cannot unify it with lhs
494                return None;
495            }
496            let mono = unify_ty(lhs_ty, rhs_ty, identity, infcx, cause, param_env)?;
497            rap_trace!("[unify_trait] unified mono: {:?}", mono);
498            s = s.merge(&mono, tcx)?;
499        }
500    }
501    Some(s)
502}
503
504pub fn resolve_mono_apis<'tcx>(
505    fn_did: DefId,
506    available_ty: &HashSet<TyWrapper<'tcx>>,
507    tcx: TyCtxt<'tcx>,
508) -> MonoSet<'tcx> {
509    // 1. check solvable condition
510    if !is_fn_solvable(fn_did, tcx) {
511        return MonoSet::empty();
512    }
513
514    // 2. get mono set from available types
515    let ret = get_mono_set(fn_did, &available_ty, tcx);
516
517    // 3. check trait bound & ty is stable
518    let ret = ret.filter(|mono| {
519        is_args_fit_trait_bound(fn_did, &mono.value, tcx)
520            && mono.value.iter().all(|arg| {
521                if let Some(ty) = arg.as_type() {
522                    !utils::is_ty_unstable(ty, tcx)
523                } else {
524                    true
525                }
526            })
527    });
528
529    rap_debug!(
530        "[resolve_mono_apis] fn_did: {:?}, size of mono: {:?}",
531        fn_did,
532        ret.count()
533    );
534
535    ret
536}
537
538pub fn add_transform_tys<'tcx>(available_ty: &mut HashSet<TyWrapper<'tcx>>, tcx: TyCtxt<'tcx>) {
539    let mut new_tys = Vec::new();
540    available_ty.iter().for_each(|ty| {
541        new_tys.push(
542            Ty::new_ref(
543                tcx,
544                tcx.lifetimes.re_erased,
545                (*ty).into(),
546                ty::Mutability::Not,
547            )
548            .into(),
549        );
550        new_tys.push(Ty::new_ref(
551            tcx,
552            tcx.lifetimes.re_erased,
553            (*ty).into(),
554            ty::Mutability::Mut,
555        ));
556        new_tys.push(Ty::new_ref(
557            tcx,
558            tcx.lifetimes.re_erased,
559            Ty::new_slice(tcx, (*ty).into()),
560            ty::Mutability::Not,
561        ));
562        new_tys.push(Ty::new_ref(
563            tcx,
564            tcx.lifetimes.re_erased,
565            Ty::new_slice(tcx, (*ty).into()),
566            ty::Mutability::Mut,
567        ));
568    });
569
570    new_tys.into_iter().for_each(|ty| {
571        available_ty.insert(ty.into());
572    });
573}
574
575pub fn eliminate_infer_var<'tcx>(
576    fn_did: DefId,
577    args: &[ty::GenericArg<'tcx>],
578    tcx: TyCtxt<'tcx>,
579) -> Vec<ty::GenericArg<'tcx>> {
580    let mut res = Vec::new();
581    let identity = ty::GenericArgs::identity_for_item(tcx, fn_did);
582    for (i, arg) in args.iter().enumerate() {
583        if let Some(ty) = arg.as_type() {
584            if ty.is_ty_var() {
585                res.push(identity[i]);
586            } else {
587                res.push(*arg);
588            }
589        } else {
590            res.push(*arg);
591        }
592    }
593    res
594}
595
596/// if type parameter is unbound, e.g., `T` in `fn foo<T>()`,
597/// we use some predefined types to substitute it
598pub fn get_unbound_generic_candidates<'tcx>(tcx: TyCtxt<'tcx>) -> Vec<ty::Ty<'tcx>> {
599    vec![
600        tcx.types.bool,
601        tcx.types.char,
602        tcx.types.u8,
603        tcx.types.i8,
604        tcx.types.i32,
605        tcx.types.u32,
606        // tcx.types.i64,
607        // tcx.types.u64,
608        tcx.types.f32,
609        // tcx.types.f64,
610        Ty::new_imm_ref(
611            tcx,
612            tcx.lifetimes.re_erased,
613            Ty::new_slice(tcx, tcx.types.u8),
614        ),
615        Ty::new_mut_ref(
616            tcx,
617            tcx.lifetimes.re_erased,
618            Ty::new_slice(tcx, tcx.types.u8),
619        ),
620    ]
621}
622
623// calculate the complexity of monomorphic solution,
624// complexity = sum of complexity of each type argument
625pub fn get_mono_complexity<'tcx>(args: &GenericArgsRef<'tcx>) -> usize {
626    args.iter().fold(0, |acc, arg| {
627        if let Some(ty) = arg.as_type() {
628            acc + utils::ty_complexity(ty)
629        } else {
630            acc
631        }
632    })
633}
634
635pub fn get_impls<'tcx>(
636    tcx: TyCtxt<'tcx>,
637    fn_did: DefId,
638    args: GenericArgsRef<'tcx>,
639) -> HashSet<DefId> {
640    rap_debug!(
641        "get impls for fn: {:?} args: {:?}",
642        tcx.def_path_str_with_args(fn_did, args),
643        args
644    );
645    let mut impls = HashSet::new();
646    let preds = tcx.predicates_of(fn_did).instantiate(tcx, args);
647    for (pred, _) in preds {
648        if let Some(trait_pred) = pred.as_trait_clause() {
649            let trait_ref: rustc_type_ir::TraitRef<TyCtxt<'tcx>> = tcx
650                .liberate_late_bound_regions(fn_did, trait_pred)
651                .trait_ref;
652
653            let res = tcx.codegen_select_candidate(
654                TypingEnv::fully_monomorphized().as_query_input(trait_ref),
655            );
656            if let Ok(source) = res {
657                match source {
658                    ImplSource::UserDefined(data) => {
659                        if data.impl_def_id.is_local() {
660                            impls.insert(data.impl_def_id);
661                        }
662                    }
663                    _ => {}
664                }
665            }
666            // rap_debug!("{:?} => {:?}", trait_ref, res);
667        }
668    }
669    rap_trace!("fn: {:?} args: {:?} impls: {:?}", fn_did, args, impls);
670    impls
671}