@@ -101,9 +101,9 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
101101 end
102102end
103103
104- # even when the allocation contains an uninitialized field, we try an extra effort to check
105- # if this load at `idx` have any "safe" `setfield!` calls that define the field
106- function has_safe_def (
104+ # even when the allocation contains an uninitialized field, we make an extra effort to
105+ # check if all loads have "safe" `setfield!` calls that define the uninitialized field
106+ function has_safe_def_for_undef_field (
107107 ir:: IRCode , domtree:: DomTree , allblocks:: Vector{Int} , du:: SSADefUse ,
108108 newidx:: Int , idx:: Int )
109109 def, _, _ = find_def_for_use (ir, domtree, allblocks, du, idx)
@@ -208,14 +208,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
208208end
209209
210210function simple_walk_constraint (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
211- @nospecialize (typeconstraint))
212- callback = function (@nospecialize (pi ), @nospecialize (idx))
213- if isa (pi , PiNode)
214- typeconstraint = typeintersect (typeconstraint, widenconst (pi . typ))
211+ @nospecialize (typeconstraint), @nospecialize (callback = nothing ) )
212+ newcallback = function (@nospecialize (x ), @nospecialize (idx))
213+ if isa (x , PiNode)
214+ typeconstraint = typeintersect (typeconstraint, widenconst (x . typ))
215215 end
216+ callback === nothing || callback (x, idx)
216217 return false
217218 end
218- def = simple_walk (compact, defssa, callback )
219+ def = simple_walk (compact, defssa, newcallback )
219220 return Pair {Any, Any} (def, typeconstraint)
220221end
221222
225226Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
226227(pruning those leaves ruled out by path conditions).
227228"""
228- function walk_to_defs (compact:: IncrementalCompact , @nospecialize (defssa), @nospecialize (typeconstraint))
229+ function walk_to_defs (compact:: IncrementalCompact ,
230+ @nospecialize (defssa), @nospecialize (typeconstraint),
231+ @nospecialize (callback = nothing ))
229232 visited_phinodes = AnySSAValue[]
230233 isa (defssa, AnySSAValue) || return Any[defssa], visited_phinodes
231234 def = compact[defssa]
@@ -261,7 +264,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
261264 val = OldSSAValue (val. id)
262265 end
263266 if isa (val, AnySSAValue)
264- new_def, new_constraint = simple_walk_constraint (compact, val, typeconstraint)
267+ new_def, new_constraint = simple_walk_constraint (compact, val, typeconstraint, callback )
265268 if isa (new_def, AnySSAValue)
266269 if ! haskey (visited_constraints, new_def)
267270 push! (worklist_defs, new_def)
@@ -722,10 +725,10 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
722725 continue
723726 end
724727 if defuses === nothing
725- defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
728+ defuses = IdDict {Int, Tuple{SPCSet, SSADefUse, PhiDefs }} ()
726729 end
727- mid, defuse = get! (defuses, defidx) do
728- SPCSet (), SSADefUse ()
730+ mid, defuse, phidefs = get! (defuses, defidx) do
731+ SPCSet (), SSADefUse (), PhiDefs ( nothing )
729732 end
730733 push! (defuse. ccall_preserve_uses, idx)
731734 union! (mid, intermediaries)
@@ -780,16 +783,29 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
780783 # Mutable stuff here
781784 isa (def, SSAValue) || continue
782785 if defuses === nothing
783- defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
786+ defuses = IdDict {Int, Tuple{SPCSet, SSADefUse, PhiDefs }} ()
784787 end
785- mid, defuse = get! (defuses, def. id) do
786- SPCSet (), SSADefUse ()
788+ mid, defuse, phidefs = get! (defuses, def. id) do
789+ SPCSet (), SSADefUse (), PhiDefs ( nothing )
787790 end
788791 if is_setfield
789792 push! (defuse. defs, idx)
790793 else
791794 push! (defuse. uses, idx)
792795 end
796+ defval = compact[def]
797+ if isa (defval, PhiNode)
798+ phicallback = function (@nospecialize (x), @nospecialize (ssa))
799+ push! (intermediaries, ssa. id)
800+ return false
801+ end
802+ defs, _ = walk_to_defs (compact, def, struct_typ, phicallback)
803+ if _any (@nospecialize (d)-> ! isa (d, SSAValue), defs)
804+ delete! (defuses, def. id)
805+ continue
806+ end
807+ phidefs[] = Int[(def:: SSAValue ). id for def in defs]
808+ end
793809 union! (mid, intermediaries)
794810 end
795811 continue
@@ -849,43 +865,73 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
849865 end
850866end
851867
868+ # TODO :
869+ # - run mutable SROA on the same IR as when we collect information about mutable allocations
870+ # - simplify and improve the eliminability check below using an escape analysis
871+
# `PhiDefs` is a mutable cell holding either `nothing` or a vector of SSA statement indices.
# `sroa_pass!` creates it as `PhiDefs(nothing)` and, when the tracked definition is a
# `PhiNode`, stores the ids of the definitions found by `walk_to_defs` into it.
# `sroa_mutables!` reads the cell and, when it is not `nothing`, treats those indices as
# the predecessors of the ϕ-node.
872+ const PhiDefs = RefValue{Union{Nothing,Vector{Int}}}
873+
852874function sroa_mutables! (ir:: IRCode ,
853- defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} ,
875+ defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs }} , used_ssas:: Vector{Int} ,
854876 nested_loads:: NestedLoads )
855877 domtree = nothing # initialization of domtree is delayed to avoid the expensive computation in many cases
856878 nested_mloads = NestedLoads () # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
857879 any_eliminated = false
880+ eliminable_defs = nothing # tracks eliminable "definitions" if initialized
858881 # NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
859- for (idx, (intermediaries, defuse)) in sort! (collect (defuses); by= first, rev= true )
882+ for (idx, (intermediaries, defuse, phidefs )) in sort! (collect (defuses); by= first, rev= true )
860883 intermediaries = collect (intermediaries)
884+ phidefs = phidefs[]
861885 # Check if there are any uses we did not account for. If so, the variable
862886 # escapes and we cannot eliminate the allocation. This works, because we're guaranteed
863887 # not to include any intermediaries that have dead uses. As a result, missing uses will only ever
864888 # show up in the nuses_total count.
865- nleaves = length (defuse. uses) + length (defuse. defs) + length (defuse. ccall_preserve_uses)
889+ nleaves = count_leaves (defuse)
890+ if phidefs != = nothing
891+ # if this defines ϕ, we also track leaves of all its predecessors
892+ # FIXME this doesn't work when any predecessor is used by another ϕ-node
893+ for pidx in phidefs
894+ haskey (defuses, pidx) || continue
895+ pdefuse = defuses[pidx][2 ]
896+ nleaves += count_leaves (pdefuse)
897+ end
898+ end
866899 nuses = 0
867900 for idx in intermediaries
868901 nuses += used_ssas[idx]
869902 end
870- nuses_total = used_ssas[idx] + nuses - length (intermediaries)
903+ nuses -= length (intermediaries)
904+ nuses_total = used_ssas[idx] + nuses
905+ if phidefs != = nothing
906+ for pidx in phidefs
907+ # NOTE we don't need to account for intermediates for this predecessor here,
908+ # since they are already included in intermediates of this ϕ-node
909+ # FIXME this doesn't work when any predecessor is used by another ϕ-node
910+ nuses_total += used_ssas[pidx] - 1 # substract usage count from ϕ-node itself
911+ end
912+ end
871913 nleaves == nuses_total || continue
872914 # Find the type for this allocation
873915 defexpr = ir[SSAValue (idx)]
874- isa (defexpr, Expr) || continue
875- if ! isexpr (defexpr, :new )
876- if is_known_call (defexpr, getfield, ir)
877- val = defexpr. args[2 ]
878- if isa (val, SSAValue)
879- struct_typ = unwrap_unionall (widenconst (argextype (val, ir)))
880- if ismutabletype (struct_typ)
881- record_nested_load! (nested_mloads, idx)
882- end
916+ if isa (defexpr, Expr)
917+ @assert phidefs === nothing
918+ if ! isexpr (defexpr, :new )
919+ maybe_record_nested_load! (nested_mloads, ir, idx)
920+ continue
921+ end
922+ elseif isa (defexpr, PhiNode)
923+ phidefs === nothing && continue
924+ for pidx in phidefs
925+ pexpr = ir[SSAValue (pidx)]
926+ if ! isexpr (pexpr, :new )
927+ maybe_record_nested_load! (nested_mloads, ir, pidx)
928+ @goto skip
883929 end
884930 end
931+ else
885932 continue
886933 end
887- newidx = idx
888- typ = ir. stmts[newidx][:type ]
934+ typ = ir. stmts[idx][:type ]
889935 if isa (typ, UnionAll)
890936 typ = unwrap_unionall (typ)
891937 end
@@ -897,25 +943,29 @@ function sroa_mutables!(ir::IRCode,
897943 fielddefuse = SSADefUse[SSADefUse () for _ = 1 : fieldcount (typ)]
898944 all_forwarded = true
899945 for use in defuse. uses
900- stmt = ir[ SSAValue (use)] # == `getfield` call
901- # We may have discovered above that this use is dead
902- # after the getfield elim of immutables. In that case,
903- # it would have been deleted. That's fine, just ignore
904- # the use in that case.
905- if stmt === nothing
946+ eliminable = check_use_eliminability! (fielddefuse, ir, use, typ)
947+ if eliminable === nothing
948+ # We may have discovered above that this use is dead
949+ # after the getfield elim of immutables. In that case,
950+ # it would have been deleted. That's fine, just ignore
951+ # the use in that case.
906952 all_forwarded = false
907953 continue
954+ elseif ! eliminable
955+ @goto skip
908956 end
909- field = try_compute_fieldidx_stmt (ir, stmt:: Expr , typ)
910- field === nothing && @goto skip
911- push! (fielddefuse[field]. uses, use)
912957 end
913958 for def in defuse. defs
914- stmt = ir[SSAValue (def)]:: Expr # == `setfield!` call
915- field = try_compute_fieldidx_stmt (ir, stmt, typ)
916- field === nothing && @goto skip
917- isconst (typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error
918- push! (fielddefuse[field]. defs, def)
959+ check_def_eliminability! (fielddefuse, ir, def, typ) || @goto skip
960+ end
961+ if phidefs != = nothing
962+ for pidx in phidefs
963+ haskey (defuses, pidx) || continue
964+ pdefuse = defuses[pidx][2 ]
965+ for pdef in pdefuse. defs
966+ check_def_eliminability! (fielddefuse, ir, pdef, typ) || @goto skip
967+ end
968+ end
919969 end
920970 # Check that the defexpr has defined values for all the fields
921971 # we're accessing. In the future, we may want to relax this,
@@ -926,7 +976,13 @@ function sroa_mutables!(ir::IRCode,
926976 for fidx in 1 : ndefuse
927977 du = fielddefuse[fidx]
928978 isempty (du. uses) && continue
929- push! (du. defs, newidx)
979+ if phidefs === nothing
980+ push! (du. defs, idx)
981+ else
982+ for pidx in phidefs
983+ push! (du. defs, pidx)
984+ end
985+ end
930986 ldu = compute_live_ins (ir. cfg, du)
931987 if isempty (ldu. live_in_bbs)
932988 phiblocks = Int[]
@@ -936,10 +992,24 @@ function sroa_mutables!(ir::IRCode,
936992 end
937993 allblocks = sort (vcat (phiblocks, ldu. def_bbs))
938994 blocks[fidx] = phiblocks, allblocks
939- if fidx + 1 > length (defexpr. args)
940- for use in du. uses
995+ if phidefs != = nothing
996+ # check if all predecessors have safe definitions
997+ for pidx in phidefs
998+ newexpr = ir[SSAValue (pidx)]:: Expr # == new(...)
999+ if fidx + 1 > length (newexpr. args) # this field can be undefined
1000+ domtree === nothing && (@timeit " domtree 2" domtree = construct_domtree (ir. cfg. blocks))
1001+ for use in du. uses
1002+ has_safe_def_for_undef_field (ir, domtree, allblocks, du, pidx, use) || @goto skip
1003+ end
1004+ end
1005+ end
1006+ else
1007+ newexpr = defexpr:: Expr # == new(...)
1008+ if fidx + 1 > length (newexpr. args) # this field can be undefined
9411009 domtree === nothing && (@timeit " domtree 2" domtree = construct_domtree (ir. cfg. blocks))
942- has_safe_def (ir, domtree, allblocks, du, newidx, use) || @goto skip
1010+ for use in du. uses
1011+ has_safe_def_for_undef_field (ir, domtree, allblocks, du, idx, use) || @goto skip
1012+ end
9431013 end
9441014 end
9451015 end
@@ -984,28 +1054,79 @@ function sroa_mutables!(ir::IRCode,
9841054 end
9851055 end
9861056 end
987- for stmt in du. defs
988- stmt == newidx && continue
989- ir[SSAValue (stmt)] = nothing
1057+ eliminable_defs === nothing && (eliminable_defs = SPCSet ())
1058+ for def in du. defs
1059+ push! (eliminable_defs, def)
1060+ end
1061+ if phidefs != = nothing
1062+ # record ϕ-node itself eliminable here, since we didn't include it in `du.defs`
1063+ # we also modify usage counts of its predecessors so that their SROA may work
1064+ # in succeeding iteration
1065+ push! (eliminable_defs, idx)
1066+ for pidx in phidefs
1067+ used_ssas[pidx] -= 1
1068+ end
9901069 end
9911070 end
9921071 preserve_uses === nothing && continue
9931072 if all_forwarded
9941073 # this means all ccall preserves have been replaced with forwarded loads
9951074 # so we can potentially eliminate the allocation, otherwise we must preserve
9961075 # the whole allocation.
997- push! (intermediaries, newidx )
1076+ push! (intermediaries, idx )
9981077 end
9991078 # Insert the new preserves
10001079 for (use, new_preserves) in preserve_uses
10011080 ir[SSAValue (use)] = form_new_preserves (ir[SSAValue (use)]:: Expr , intermediaries, new_preserves)
10021081 end
1003-
10041082 @label skip
10051083 end
1084+ # now eliminate "definitions" (i.e. allocations, ϕ-nodes, and `setfield!` calls)
1085+ # that should have no usage at this moment
1086+ if eliminable_defs != = nothing
1087+ for idx in eliminable_defs
1088+ ir[SSAValue (idx)] = nothing
1089+ end
1090+ end
10061091 return any_eliminated ? sroa_pass! (compact! (ir), false ) : ir
10071092end
10081093
# Count every leaf recorded in `defuse`: its uses, its defs, and its `ccall` preserve uses.
# `sroa_mutables!` compares this total against the SSA use counts to detect uses that were
# not accounted for.
function count_leaves(defuse::SSADefUse)
    nuses = length(defuse.uses)
    ndefs = length(defuse.defs)
    npreserves = length(defuse.ccall_preserve_uses)
    return nuses + ndefs + npreserves
end
1096+
# Record the statement at `idx` in `nested_mloads` when it is a known `getfield` call whose
# object argument is an `SSAValue` of a mutable type; otherwise do nothing.
# Returns `nothing` when nothing is recorded, and the result of `record_nested_load!`
# otherwise.
function maybe_record_nested_load!(nested_mloads::NestedLoads, ir::IRCode, idx::Int)
    stmt = ir[SSAValue(idx)]
    is_known_call(stmt, getfield, ir) || return nothing
    obj = stmt.args[2] # the object being loaded from
    obj isa SSAValue || return nothing
    objtyp = unwrap_unionall(widenconst(argextype(obj, ir)))
    ismutabletype(objtyp) || return nothing
    return record_nested_load!(nested_mloads, idx)
end
1109+
# Classify the use (a `getfield` call) at `useidx` against `struct_typ`:
# - returns `nothing` when the statement has already been deleted (it is `nothing`)
# - returns `false` when the accessed field index cannot be computed
# - otherwise registers `useidx` as a use of that field in `fielddefuse` and returns `true`
function check_use_eliminability!(fielddefuse::Vector{SSADefUse},
                                  ir::IRCode, useidx::Int, struct_typ::DataType)
    load = ir[SSAValue(useidx)]
    if load === nothing
        return nothing
    end
    fidx = try_compute_fieldidx_stmt(ir, load::Expr, struct_typ)
    if fidx === nothing
        return false
    else
        push!(fielddefuse[fidx].uses, useidx)
        return true
    end
end
1119+
# Classify the definition (a `setfield!` call) at `defidx` against `struct_typ`:
# - returns `false` when the assigned field index cannot be computed, or when the field is
#   `const` (such a `setfield!` is an attempt to mutate a const field, which must error)
# - otherwise registers `defidx` as a def of that field in `fielddefuse` and returns `true`
function check_def_eliminability!(fielddefuse::Vector{SSADefUse},
                                  ir::IRCode, defidx::Int, struct_typ::DataType)
    store = ir[SSAValue(defidx)]::Expr
    fidx = try_compute_fieldidx_stmt(ir, store, struct_typ)
    if fidx === nothing || isconst(struct_typ, fidx)
        return false
    end
    push!(fielddefuse[fidx].defs, defidx)
    return true
end
1129+
10091130function form_new_preserves (origex:: Expr , intermediates:: Vector{Int} , new_preserves:: Vector{Any} )
10101131 newex = Expr (:foreigncall )
10111132 nccallargs = length (origex. args[3 ]:: SimpleVector )
0 commit comments