1#![allow(warnings, unused)]
2
3use super::graph::TyWrapper;
4use super::utils::{self, fn_sig_with_generic_args};
5use crate::analysis::utils::def_path::path_str_def_id;
6use crate::{rap_debug, rap_trace};
7use rand::Rng;
8use rand::seq::SliceRandom;
9use rustc_hir::LangItem;
10use rustc_hir::def_id::DefId;
11use rustc_infer::infer::DefineOpaqueTypes;
12use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
13use rustc_infer::traits::{ImplSource, Obligation, ObligationCause};
14use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, TypeVisitableExt, TypingEnv};
15use rustc_span::DUMMY_SP;
16use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _;
17use std::collections::HashSet;
18
/// Upper bound on the number of candidate instantiations a `MonoSet` keeps
/// after a merge step; larger sets are shuffled and truncated to this size by
/// `MonoSet::random_sample` so the search stays tractable.
///
/// A plain compile-time constant, so `const` rather than `static`.
const MAX_STEP_SET_SIZE: usize = 1000;
20
21#[derive(Clone, Debug, Hash, PartialEq, Eq)]
22pub struct Mono<'tcx> {
23 pub value: Vec<ty::GenericArg<'tcx>>,
24}
25
26impl<'tcx> FromIterator<ty::GenericArg<'tcx>> for Mono<'tcx> {
27 fn from_iter<T>(iter: T) -> Self
28 where
29 T: IntoIterator<Item = ty::GenericArg<'tcx>>,
30 {
31 Mono {
32 value: iter.into_iter().collect(),
33 }
34 }
35}
36
37impl<'tcx> Mono<'tcx> {
38 pub fn new(identity: &[ty::GenericArg<'tcx>]) -> Self {
39 Mono {
40 value: Vec::from(identity),
41 }
42 }
43
44 fn has_infer_types(&self) -> bool {
45 self.value.iter().any(|arg| match arg.kind() {
46 ty::GenericArgKind::Type(ty) => ty.has_infer_types(),
47 _ => false,
48 })
49 }
50
51 fn mut_arg_at(&mut self, idx: usize) -> &mut ty::GenericArg<'tcx> {
52 &mut self.value[idx]
53 }
54
55 fn merge(&self, other: &Mono<'tcx>, tcx: TyCtxt<'tcx>) -> Option<Mono<'tcx>> {
56 assert!(self.value.len() == other.value.len());
57 let mut res = Vec::new();
58 for i in 0..self.value.len() {
59 let arg = self.value[i];
60 let other_arg = other.value[i];
61 let new_arg = if let Some(ty) = arg.as_type() {
62 let other_ty = other_arg.expect_ty();
63 if ty.is_ty_var() && other_ty.is_ty_var() {
64 arg
65 } else if ty.is_ty_var() {
66 other_arg
67 } else if other_ty.is_ty_var() {
68 arg
69 } else if utils::is_ty_eq(ty, other_ty, tcx) {
70 arg
71 } else {
72 return None;
73 }
74 } else {
75 arg
76 };
77 res.push(new_arg);
78 }
79 Some(Mono { value: res })
80 }
81
82 fn fill_unbound_var(&self, tcx: TyCtxt<'tcx>) -> Vec<Mono<'tcx>> {
83 let candidates = get_unbound_generic_candidates(tcx);
84 let mut res = vec![self.clone()];
85 rap_trace!("fill unbound: {:?}", self);
86
87 for (i, arg) in self.value.iter().enumerate() {
88 if let Some(ty) = arg.as_type() {
89 if ty.is_ty_var() {
90 let mut last = Vec::new();
91 std::mem::swap(&mut res, &mut last);
92 last.into_iter().for_each(|mono| {
93 for candidate in &candidates {
94 let mut new_mono = mono.clone();
95 *new_mono.mut_arg_at(i) = (*candidate).into();
96 res.push(new_mono);
97 }
98 });
99 }
100 }
101 }
102 res
103 }
104}
105
106#[derive(Clone, Debug, Default)]
107pub struct MonoSet<'tcx> {
108 pub monos: Vec<Mono<'tcx>>,
109}
110
111impl<'tcx> MonoSet<'tcx> {
112 pub fn all(identity: &[ty::GenericArg<'tcx>]) -> MonoSet<'tcx> {
113 MonoSet {
114 monos: vec![Mono::new(identity)],
115 }
116 }
117
118 pub fn empty() -> MonoSet<'tcx> {
119 MonoSet { monos: Vec::new() }
120 }
121
122 pub fn count(&self) -> usize {
123 self.monos.len()
124 }
125
126 pub fn at(&self, no: usize) -> &Mono<'tcx> {
127 &self.monos[no]
128 }
129
130 pub fn is_empty(&self) -> bool {
131 self.monos.is_empty()
132 }
133
134 pub fn new() -> MonoSet<'tcx> {
135 MonoSet { monos: Vec::new() }
136 }
137
138 pub fn insert(&mut self, mono: Mono<'tcx>) {
139 self.monos.push(mono);
140 }
141
142 pub fn merge(&mut self, other: &MonoSet<'tcx>, tcx: TyCtxt<'tcx>) -> MonoSet<'tcx> {
143 let mut res = MonoSet::new();
144
145 for args in self.monos.iter() {
146 for other_args in other.monos.iter() {
147 let merged = args.merge(other_args, tcx);
148 if let Some(mono) = merged {
149 res.insert(mono);
150 }
151 }
152 }
153 res
154 }
155
156 fn filter_unbound_solution(mut self) -> Self {
157 self.monos.retain(|mono| mono.has_infer_types());
158 self
159 }
160
161 fn instantiate_unbound(&self, tcx: TyCtxt<'tcx>) -> Self {
165 let mut res = MonoSet::new();
166 for mono in &self.monos {
167 let filled = mono.fill_unbound_var(tcx);
168 res.monos.extend(filled);
169 }
170 res
171 }
172
173 fn erase_region_var(&mut self, tcx: TyCtxt<'tcx>) {
174 for mono in &mut self.monos {
175 mono.value
176 .iter_mut()
177 .for_each(|arg| *arg = tcx.erase_and_anonymize_regions(*arg))
178 }
179 }
180
181 pub fn filter(mut self, f: impl Fn(&Mono<'tcx>) -> bool) -> Self {
182 self.monos.retain(|args| f(args));
183 self
184 }
185
186 pub fn random_sample<R: Rng>(&mut self, rng: &mut R) {
187 if self.monos.len() <= MAX_STEP_SET_SIZE {
188 return;
189 }
190 self.monos.shuffle(rng);
191 self.monos.truncate(MAX_STEP_SET_SIZE);
192 }
193}
194
195fn unify_ty<'tcx>(
200 lhs: Ty<'tcx>,
201 rhs: Ty<'tcx>,
202 identity: &[ty::GenericArg<'tcx>],
203 infcx: &InferCtxt<'tcx>,
204 cause: &ObligationCause<'tcx>,
205 param_env: ty::ParamEnv<'tcx>,
206) -> Option<Mono<'tcx>> {
207 infcx.probe(|_| {
209 match infcx
210 .at(cause, param_env)
211 .eq(DefineOpaqueTypes::Yes, lhs, rhs)
212 {
213 Ok(_infer_ok) => {
214 let mono = identity
216 .iter()
217 .map(|arg| match arg.kind() {
218 ty::GenericArgKind::Lifetime(region) => {
219 infcx.resolve_vars_if_possible(region).into()
220 }
221 ty::GenericArgKind::Type(ty) => infcx.resolve_vars_if_possible(ty).into(),
222 ty::GenericArgKind::Const(ct) => infcx.resolve_vars_if_possible(ct).into(),
223 })
224 .collect();
225 Some(mono)
226 }
227 Err(_e) => {
228 None
230 }
231 }
232 })
233}
234
235fn is_args_fit_trait_bound<'tcx>(
236 fn_did: DefId,
237 args: &[ty::GenericArg<'tcx>],
238 tcx: TyCtxt<'tcx>,
239) -> bool {
240 let args = tcx.mk_args(args);
241 rap_trace!(
242 "fn: {:?} args: {:?} identity: {:?}",
243 fn_did,
244 args,
245 ty::GenericArgs::identity_for_item(tcx, fn_did)
246 );
247 let infcx = tcx.infer_ctxt().build(ty::TypingMode::PostAnalysis);
248 let pred = tcx.predicates_of(fn_did);
249 let inst_pred = pred.instantiate(tcx, args);
250 let param_env = tcx.param_env(fn_did);
251 rap_trace!(
252 "[trait bound] check {}",
253 tcx.def_path_str_with_args(fn_did, args)
254 );
255
256 for pred in inst_pred.predicates.iter() {
257 let obligation = Obligation::new(
258 tcx,
259 ObligationCause::dummy(),
260 param_env,
261 pred.as_predicate(),
262 );
263 rap_trace!("[trait bound] check pred: {:?}", pred);
264
265 let res = infcx.evaluate_obligation(&obligation);
266 match res {
267 Ok(eva) => {
268 if !eva.may_apply() {
269 rap_trace!("[trait bound] check fail for {pred:?}");
270 return false;
271 }
272 }
273 Err(_) => {
274 rap_trace!("[trait bound] check fail for {pred:?}");
275 return false;
276 }
277 }
278 }
279 rap_trace!("[trait bound] check succ");
280 true
281}
282
283fn is_fn_solvable<'tcx>(fn_did: DefId, tcx: TyCtxt<'tcx>) -> bool {
284 for pred in tcx
285 .predicates_of(fn_did)
286 .instantiate_identity(tcx)
287 .predicates
288 {
289 if let Some(pred) = pred.as_trait_clause() {
290 let trait_did = pred.skip_binder().trait_ref.def_id;
291 if tcx.is_lang_item(trait_did, LangItem::Fn)
292 || tcx.is_lang_item(trait_did, LangItem::FnMut)
293 || tcx.is_lang_item(trait_did, LangItem::FnOnce)
294 {
295 return false;
296 }
297 }
298 }
299 true
300}
301
302fn get_mono_set<'tcx>(
303 fn_did: DefId,
304 available_ty: &HashSet<TyWrapper<'tcx>>,
305 tcx: TyCtxt<'tcx>,
306) -> MonoSet<'tcx> {
307 let mut rng = rand::rng();
308
309 rap_debug!("[get_mono_set] solve {}", tcx.def_path_str(fn_did));
311 let identity = ty::GenericArgs::identity_for_item(tcx, fn_did);
312 let infcx = tcx
313 .infer_ctxt()
314 .ignoring_regions()
315 .build(ty::TypingMode::PostAnalysis);
316 let param_env = tcx.param_env(fn_did);
317 let dummy_cause = ObligationCause::dummy();
318 let fresh_args = infcx.fresh_args_for_item(DUMMY_SP, fn_did);
319 let fn_sig = fn_sig_with_generic_args(fn_did, fresh_args, tcx);
321 let identity_fnsig = fn_sig_with_generic_args(fn_did, identity, tcx);
322 let generics = tcx.generics_of(fn_did);
323
324 for i in 0..fresh_args.len() {
326 rap_trace!(
327 "[get_mono_set] arg#{}: {:?} -> {:?}",
328 i,
329 generics.param_at(i, tcx).name,
330 fresh_args[i]
331 );
332 }
333
334 let mut s = MonoSet::all(&fresh_args);
335
336 rap_trace!("[get_mono_set] initialize s: {:?}", s);
337
338 for (no, input_ty) in fn_sig.inputs().iter().enumerate() {
339 if !input_ty.has_infer_types() {
340 continue;
341 }
342 rap_debug!(
343 "[get_mono_set] input_ty#{}: {}",
344 no,
345 identity_fnsig.inputs()[no]
346 );
347
348 let reachable_set = available_ty
349 .iter()
350 .fold(MonoSet::new(), |mut reachable_set, ty| {
351 if let Some(mono) = unify_ty(
352 *input_ty,
353 (*ty).into(),
354 &fresh_args,
355 &infcx,
356 &dummy_cause,
357 param_env,
358 ) {
359 reachable_set.insert(mono);
360 }
361 reachable_set
362 });
363 rap_debug!(
365 "[get_mono_set] size of s: {}, size of input: {}",
366 s.count(),
367 reachable_set.count()
368 );
369 rap_trace!("[get_mono_set] input = {:?}", reachable_set);
370 s = s.merge(&reachable_set, tcx);
371 s.random_sample(&mut rng);
372 rap_trace!("[get_mono_set] after merge s = {:?}", reachable_set);
373 }
374
375 rap_debug!(
376 "[get_mono_set] after input filter, size of s: {}",
377 s.count()
378 );
379
380 let mut res = MonoSet::new();
381
382 for mono in s.monos {
383 solve_unbound_type_generics(
384 fn_did,
385 mono,
386 &mut res,
387 &infcx,
389 &dummy_cause,
390 param_env,
391 tcx,
392 );
393 }
394
395 res.erase_region_var(tcx);
397
398 res.instantiate_unbound(tcx)
400}
401
402fn solve_unbound_type_generics<'tcx>(
403 did: DefId,
404 mono: Mono<'tcx>,
405 res: &mut MonoSet<'tcx>,
406 infcx: &InferCtxt<'tcx>,
407 cause: &ObligationCause<'tcx>,
408 param_env: ty::ParamEnv<'tcx>,
409 tcx: TyCtxt<'tcx>,
410) {
411 if !mono.has_infer_types() {
412 res.insert(mono);
413 return;
414 }
415 let args = tcx.mk_args(&mono.value);
416 let preds = tcx.predicates_of(did).instantiate(tcx, args);
417 let mut mset = MonoSet::all(args);
418 rap_debug!("[solve_unbound] did = {did:?}, mset={mset:?}");
419 for pred in preds.predicates.iter() {
420 rap_debug!("[solve_unbound] pred = {:?}", pred);
421 if let Some(trait_pred) = pred.as_trait_clause() {
422 let trait_pred = trait_pred.skip_binder();
423
424 rap_trace!("[solve_unbound] pred: {:?}", trait_pred);
425
426 let trait_def_id = trait_pred.trait_ref.def_id;
427 if tcx.is_lang_item(trait_def_id, LangItem::Sized)
429 || tcx.is_lang_item(trait_def_id, LangItem::Copy)
430 {
431 continue;
432 }
433
434 let mut p = MonoSet::new();
435
436 for impl_did in tcx.all_impls(trait_def_id)
437 {
439 let impl_trait_ref = tcx.impl_trait_ref(impl_did).skip_binder();
441
442 if !impl_did.is_local() && !impl_trait_ref.self_ty().is_primitive() {
446 continue;
447 }
448
449 if let Some(mono) = unify_trait(
450 trait_pred.trait_ref,
451 impl_trait_ref,
452 args,
453 &infcx,
454 &cause,
455 param_env,
456 tcx,
457 ) {
458 p.insert(mono);
459 }
460 }
461 mset = mset.merge(&p, tcx);
462 rap_trace!("[solve_unbound] mset: {:?}", mset);
463 }
464 }
465
466 rap_trace!("[solve_unbound] (final) mset: {:?}", mset);
467 for mono in mset.monos {
468 res.insert(mono);
469 }
470}
471
472fn unify_trait<'tcx>(
475 lhs: ty::TraitRef<'tcx>,
476 rhs: ty::TraitRef<'tcx>,
477 identity: &[ty::GenericArg<'tcx>],
478 infcx: &InferCtxt<'tcx>,
479 cause: &ObligationCause<'tcx>,
480 param_env: ty::ParamEnv<'tcx>,
481 tcx: TyCtxt<'tcx>,
482) -> Option<Mono<'tcx>> {
483 rap_trace!("[unify_trait] lhs: {:?}, rhs: {:?}", lhs, rhs);
484 if lhs.def_id != rhs.def_id {
485 return None;
486 }
487
488 assert!(lhs.args.len() == rhs.args.len());
489 let mut s = Mono::new(identity);
490 for (lhs_arg, rhs_arg) in lhs.args.iter().zip(rhs.args.iter()) {
491 if let (Some(lhs_ty), Some(rhs_ty)) = (lhs_arg.as_type(), rhs_arg.as_type()) {
492 if rhs_ty.has_infer_types() || rhs_ty.has_param() {
493 return None;
495 }
496 let mono = unify_ty(lhs_ty, rhs_ty, identity, infcx, cause, param_env)?;
497 rap_trace!("[unify_trait] unified mono: {:?}", mono);
498 s = s.merge(&mono, tcx)?;
499 }
500 }
501 Some(s)
502}
503
504pub fn resolve_mono_apis<'tcx>(
505 fn_did: DefId,
506 available_ty: &HashSet<TyWrapper<'tcx>>,
507 tcx: TyCtxt<'tcx>,
508) -> MonoSet<'tcx> {
509 if !is_fn_solvable(fn_did, tcx) {
511 return MonoSet::empty();
512 }
513
514 let ret = get_mono_set(fn_did, &available_ty, tcx);
516
517 let ret = ret.filter(|mono| {
519 is_args_fit_trait_bound(fn_did, &mono.value, tcx)
520 && mono.value.iter().all(|arg| {
521 if let Some(ty) = arg.as_type() {
522 !utils::is_ty_unstable(ty, tcx)
523 } else {
524 true
525 }
526 })
527 });
528
529 rap_debug!(
530 "[resolve_mono_apis] fn_did: {:?}, size of mono: {:?}",
531 fn_did,
532 ret.count()
533 );
534
535 ret
536}
537
538pub fn add_transform_tys<'tcx>(available_ty: &mut HashSet<TyWrapper<'tcx>>, tcx: TyCtxt<'tcx>) {
539 let mut new_tys = Vec::new();
540 available_ty.iter().for_each(|ty| {
541 new_tys.push(
542 Ty::new_ref(
543 tcx,
544 tcx.lifetimes.re_erased,
545 (*ty).into(),
546 ty::Mutability::Not,
547 )
548 .into(),
549 );
550 new_tys.push(Ty::new_ref(
551 tcx,
552 tcx.lifetimes.re_erased,
553 (*ty).into(),
554 ty::Mutability::Mut,
555 ));
556 new_tys.push(Ty::new_ref(
557 tcx,
558 tcx.lifetimes.re_erased,
559 Ty::new_slice(tcx, (*ty).into()),
560 ty::Mutability::Not,
561 ));
562 new_tys.push(Ty::new_ref(
563 tcx,
564 tcx.lifetimes.re_erased,
565 Ty::new_slice(tcx, (*ty).into()),
566 ty::Mutability::Mut,
567 ));
568 });
569
570 new_tys.into_iter().for_each(|ty| {
571 available_ty.insert(ty.into());
572 });
573}
574
575pub fn eliminate_infer_var<'tcx>(
576 fn_did: DefId,
577 args: &[ty::GenericArg<'tcx>],
578 tcx: TyCtxt<'tcx>,
579) -> Vec<ty::GenericArg<'tcx>> {
580 let mut res = Vec::new();
581 let identity = ty::GenericArgs::identity_for_item(tcx, fn_did);
582 for (i, arg) in args.iter().enumerate() {
583 if let Some(ty) = arg.as_type() {
584 if ty.is_ty_var() {
585 res.push(identity[i]);
586 } else {
587 res.push(*arg);
588 }
589 } else {
590 res.push(*arg);
591 }
592 }
593 res
594}
595
596pub fn get_unbound_generic_candidates<'tcx>(tcx: TyCtxt<'tcx>) -> Vec<ty::Ty<'tcx>> {
599 vec![
600 tcx.types.bool,
601 tcx.types.char,
602 tcx.types.u8,
603 tcx.types.i8,
604 tcx.types.i32,
605 tcx.types.u32,
606 tcx.types.f32,
609 Ty::new_imm_ref(
611 tcx,
612 tcx.lifetimes.re_erased,
613 Ty::new_slice(tcx, tcx.types.u8),
614 ),
615 Ty::new_mut_ref(
616 tcx,
617 tcx.lifetimes.re_erased,
618 Ty::new_slice(tcx, tcx.types.u8),
619 ),
620 ]
621}
622
623pub fn get_mono_complexity<'tcx>(args: &GenericArgsRef<'tcx>) -> usize {
626 args.iter().fold(0, |acc, arg| {
627 if let Some(ty) = arg.as_type() {
628 acc + utils::ty_complexity(ty)
629 } else {
630 acc
631 }
632 })
633}
634
635pub fn get_impls<'tcx>(
636 tcx: TyCtxt<'tcx>,
637 fn_did: DefId,
638 args: GenericArgsRef<'tcx>,
639) -> HashSet<DefId> {
640 rap_debug!(
641 "get impls for fn: {:?} args: {:?}",
642 tcx.def_path_str_with_args(fn_did, args),
643 args
644 );
645 let mut impls = HashSet::new();
646 let preds = tcx.predicates_of(fn_did).instantiate(tcx, args);
647 for (pred, _) in preds {
648 if let Some(trait_pred) = pred.as_trait_clause() {
649 let trait_ref: rustc_type_ir::TraitRef<TyCtxt<'tcx>> = tcx
650 .liberate_late_bound_regions(fn_did, trait_pred)
651 .trait_ref;
652
653 let res = tcx.codegen_select_candidate(
654 TypingEnv::fully_monomorphized().as_query_input(trait_ref),
655 );
656 if let Ok(source) = res {
657 match source {
658 ImplSource::UserDefined(data) => {
659 if data.impl_def_id.is_local() {
660 impls.insert(data.impl_def_id);
661 }
662 }
663 _ => {}
664 }
665 }
666 }
668 }
669 rap_trace!("fn: {:?} args: {:?} impls: {:?}", fn_did, args, impls);
670 impls
671}