1919// PassOptions structure; see more details there.
2020//
2121
22+ #include < ranges>
23+
2224#include " ir/effects.h"
2325#include " ir/module-utils.h"
2426#include " ir/subtypes.h"
@@ -94,7 +96,7 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
9496 } else if (auto * callIndirect = curr->dynCast <CallIndirect>()) {
9597 type = callIndirect->heapType ;
9698 } else {
97- Fatal () << " Unexpected call type" ;
99+ WASM_UNREACHABLE ( " Unexpected call type" ) ;
98100 }
99101
100102 funcInfo.indirectCalledTypes .insert (type);
@@ -123,7 +125,8 @@ using CallGraphNode = std::variant<Function*, HeapType>;
123125using CallGraph =
124126 std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>;
125127
126- /* Build a call graph for indirect and direct calls.
128+ /*
129+ Build a call graph for indirect and direct calls.
127130
128131 key (caller) -> value (callee)
129132 Name -> Name : direct call
@@ -137,38 +140,46 @@ using CallGraph =
137140
138141 If we're running in an open world, we only include Name -> Name edges.
139142*/
140- CallGraph buildCallGraph (Module& module ,
143+ CallGraph buildCallGraph (const Module& module ,
141144 const std::map<Function*, FuncInfo>& funcInfos,
142145 bool closedWorld) {
143146 CallGraph callGraph;
144147
145148 std::unordered_set<HeapType> allFunctionTypes;
146149 for (const auto & [caller, callerInfo] : funcInfos) {
147150 auto & callees = callGraph[caller];
151+
152+ // Name -> Name
148153 for (Name calleeFunction : callerInfo.calledFunctions ) {
149154 callees.insert (module .getFunction (calleeFunction));
150155 }
151156
152- // In open world, just connect functions. Indirect calls are already handled
153- // by giving such functions unknown effects.
154157 if (!closedWorld) {
155158 continue ;
156159 }
157160
161+ // Name -> Type
158162 allFunctionTypes.insert (caller->type .getHeapType ());
159163 for (HeapType calleeType : callerInfo.indirectCalledTypes ) {
160164 callees.insert (calleeType);
161165 allFunctionTypes.insert (calleeType);
162166 }
167+
168+ // Type -> Name
163169 callGraph[caller->type .getHeapType ()].insert (caller);
164170 }
165171
166- SubTypes subtypes ( module );
172+ // Type -> Type
167173 for (HeapType type : allFunctionTypes) {
168- subtypes.iterSubTypes (type, [&callGraph, type](HeapType sub, auto _) {
169- callGraph[type].insert (sub);
170- return true ;
171- });
174+ // Not needed but during lookup we expect the key to exist.
175+ callGraph[type];
176+
177+ for (auto super = type.getDeclaredSuperType (); super;
178+ super = super->getDeclaredSuperType ()) {
179+ if (allFunctionTypes.contains (*super)) {
180+ callGraph[*super].insert (type);
181+ }
182+ }
172183 }
173184
174185 return callGraph;
@@ -187,6 +198,31 @@ void mergeMaybeEffects(std::optional<EffectAnalyzer>& dest,
187198 dest->mergeIn (*src);
188199}
189200
201+ template <std::ranges::common_range Range>
202+ requires std::same_as<std::ranges::range_value_t <Range>, CallGraphNode>
203+ struct CallGraphSCCs
204+ : SCCs<std::ranges::iterator_t <Range>, CallGraphSCCs<Range>> {
205+ const std::map<Function*, FuncInfo>& funcInfos;
206+ const CallGraph& callGraph;
207+ const Module& module ;
208+
209+ CallGraphSCCs (
210+ Range&& nodes,
211+ const std::map<Function*, FuncInfo>& funcInfos,
212+ const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
213+ callGraph,
214+ const Module& module )
215+ : SCCs<std::ranges::iterator_t <Range>, CallGraphSCCs<Range>>(
216+ std::ranges::begin (nodes), std::ranges::end(nodes)),
217+ funcInfos (funcInfos), callGraph(callGraph), module(module ) {}
218+
219+ void pushChildren (CallGraphNode node) {
220+ for (CallGraphNode callee : callGraph.at (node)) {
221+ this ->push (callee);
222+ }
223+ }
224+ };
225+
190226// Propagate effects from callees to callers transitively
191227// e.g. if A -> B -> C (A calls B which calls C)
192228// Then B inherits effects from C and A inherits effects from both B and C.
@@ -200,29 +236,6 @@ void propagateEffects(const Module& module,
200236 const PassOptions& passOptions,
201237 std::map<Function*, FuncInfo>& funcInfos,
202238 const CallGraph& callGraph) {
203- struct CallGraphSCCs
204- : SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs> {
205- const std::map<Function*, FuncInfo>& funcInfos;
206- const CallGraph& callGraph;
207- const Module& module ;
208-
209- CallGraphSCCs (
210- const std::vector<CallGraphNode>& nodes,
211- const std::map<Function*, FuncInfo>& funcInfos,
212- const std::unordered_map<CallGraphNode,
213- std::unordered_set<CallGraphNode>>& callGraph,
214- const Module& module )
215- : SCCs<std::vector<CallGraphNode>::const_iterator, CallGraphSCCs>(
216- nodes.begin(), nodes.end()),
217- funcInfos (funcInfos), callGraph(callGraph), module (module ) {}
218-
219- void pushChildren (CallGraphNode node) {
220- for (CallGraphNode callee : callGraph.at (node)) {
221- push (callee);
222- }
223- }
224- };
225-
226239 // We only care about Functions that are roots, not types
227240 // A type would be a root if a function exists with that type, but no-one
228241 // indirect calls the type.
@@ -231,11 +244,16 @@ void propagateEffects(const Module& module,
231244 allFuncs.push_back (func);
232245 }
233246
234- CallGraphSCCs sccs (allFuncs, funcInfos, callGraph, module );
247+ auto funcNodes = std::views::keys (callGraph) |
248+ std::views::filter ([](auto node) {
249+ return std::holds_alternative<Function*>(node);
250+ }) |
251+ std::views::common;
252+ CallGraphSCCs sccs (std::move (funcNodes), funcInfos, callGraph, module );
235253
236254 std::vector<std::optional<EffectAnalyzer>> componentEffects;
237255 // Points to an index in componentEffects
238- std::unordered_map<CallGraphNode, Index> funcComponents ;
256+ std::unordered_map<CallGraphNode, Index> nodeComponents ;
239257
240258 for (auto ccIterator : sccs) {
241259 std::optional<EffectAnalyzer>& ccEffects =
@@ -244,7 +262,7 @@ void propagateEffects(const Module& module,
244262
245263 std::vector<Function*> ccFuncs;
246264 for (CallGraphNode node : cc) {
247- funcComponents .emplace (node, componentEffects.size () - 1 );
265+ nodeComponents .emplace (node, componentEffects.size () - 1 );
248266 if (auto ** func = std::get_if<Function*>(&node)) {
249267 ccFuncs.push_back (*func);
250268 }
@@ -253,7 +271,7 @@ void propagateEffects(const Module& module,
253271 std::unordered_set<int > calleeSccs;
254272 for (CallGraphNode caller : cc) {
255273 for (CallGraphNode callee : callGraph.at (caller)) {
256- calleeSccs.insert (funcComponents .at (callee));
274+ calleeSccs.insert (nodeComponents .at (callee));
257275 }
258276 }
259277
0 commit comments