Index: head/contrib/llvm/include/llvm/Analysis/ScalarEvolution.h =================================================================== --- head/contrib/llvm/include/llvm/Analysis/ScalarEvolution.h (revision 312831) +++ head/contrib/llvm/include/llvm/Analysis/ScalarEvolution.h (revision 312832) @@ -1,1804 +1,1823 @@ //===- llvm/Analysis/ScalarEvolution.h - Scalar Evolution -------*- C++ -*-===// // // The LLVM Compiler Infrastructure // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// // // The ScalarEvolution class is an LLVM pass which can be used to analyze and // categorize scalar expressions in loops. It specializes in recognizing // general induction variables, representing them with the abstract and opaque // SCEV class. Given this analysis, trip counts of loops and other important // properties can be obtained. // // This analysis is primarily useful for induction variable substitution and // strength reduction. // //===----------------------------------------------------------------------===// #ifndef LLVM_ANALYSIS_SCALAREVOLUTION_H #define LLVM_ANALYSIS_SCALAREVOLUTION_H #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/FoldingSet.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/ValueHandle.h" #include "llvm/IR/ValueMap.h" #include "llvm/Pass.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/DataTypes.h" namespace llvm { class APInt; class AssumptionCache; class Constant; class ConstantInt; class DominatorTree; class Type; class ScalarEvolution; class DataLayout; class TargetLibraryInfo; class LLVMContext; class Operator; class SCEV; class SCEVAddRecExpr; class SCEVConstant; class SCEVExpander; class SCEVPredicate; class SCEVUnknown; class Function; template <> struct FoldingSetTrait; template <> struct FoldingSetTrait; /// This class represents an analyzed expression in the program. These are /// opaque objects that the client is not allowed to do much with directly. /// class SCEV : public FoldingSetNode { friend struct FoldingSetTrait; /// A reference to an Interned FoldingSetNodeID for this node. The /// ScalarEvolution's BumpPtrAllocator holds the data. FoldingSetNodeIDRef FastID; // The SCEV baseclass this node corresponds to const unsigned short SCEVType; protected: /// This field is initialized to zero and may be used in subclasses to store /// miscellaneous information. unsigned short SubclassData; private: SCEV(const SCEV &) = delete; void operator=(const SCEV &) = delete; public: /// NoWrapFlags are bitfield indices into SubclassData. /// /// Add and Mul expressions may have no-unsigned-wrap or /// no-signed-wrap properties, which are derived from the IR /// operator. NSW is a misnomer that we use to mean no signed overflow or /// underflow. /// /// AddRec expressions may have a no-self-wraparound property if, in /// the integer domain, abs(step) * max-iteration(loop) <= /// unsigned-max(bitwidth). This means that the recurrence will never reach /// its start value if the step is non-zero. Computing the same value on /// each iteration is not considered wrapping, and recurrences with step = 0 /// are trivially . is independent of the sign of step and the /// value the add recurrence starts with. /// /// Note that NUW and NSW are also valid properties of a recurrence, and /// either implies NW. For convenience, NW will be set for a recurrence /// whenever either NUW or NSW are set. enum NoWrapFlags { FlagAnyWrap = 0, // No guarantee. FlagNW = (1 << 0), // No self-wrap. FlagNUW = (1 << 1), // No unsigned wrap. FlagNSW = (1 << 2), // No signed wrap. NoWrapMask = (1 << 3) -1 }; explicit SCEV(const FoldingSetNodeIDRef ID, unsigned SCEVTy) : FastID(ID), SCEVType(SCEVTy), SubclassData(0) {} unsigned getSCEVType() const { return SCEVType; } /// Return the LLVM type of this SCEV expression. /// Type *getType() const; /// Return true if the expression is a constant zero. /// bool isZero() const; /// Return true if the expression is a constant one. /// bool isOne() const; /// Return true if the expression is a constant all-ones value. /// bool isAllOnesValue() const; /// Return true if the specified scev is negated, but not a constant. bool isNonConstantNegative() const; /// Print out the internal representation of this scalar to the specified /// stream. This should really only be used for debugging purposes. void print(raw_ostream &OS) const; /// This method is used for debugging. /// void dump() const; }; // Specialize FoldingSetTrait for SCEV to avoid needing to compute // temporary FoldingSetNodeID values. template<> struct FoldingSetTrait : DefaultFoldingSetTrait { static void Profile(const SCEV &X, FoldingSetNodeID& ID) { ID = X.FastID; } static bool Equals(const SCEV &X, const FoldingSetNodeID &ID, unsigned IDHash, FoldingSetNodeID &TempID) { return ID == X.FastID; } static unsigned ComputeHash(const SCEV &X, FoldingSetNodeID &TempID) { return X.FastID.ComputeHash(); } }; inline raw_ostream &operator<<(raw_ostream &OS, const SCEV &S) { S.print(OS); return OS; } /// An object of this class is returned by queries that could not be answered. /// For example, if you ask for the number of iterations of a linked-list /// traversal loop, you will get one of these. None of the standard SCEV /// operations are valid on this class, it is just a marker. struct SCEVCouldNotCompute : public SCEV { SCEVCouldNotCompute(); /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S); }; /// This class represents an assumption made using SCEV expressions which can /// be checked at run-time. class SCEVPredicate : public FoldingSetNode { friend struct FoldingSetTrait; /// A reference to an Interned FoldingSetNodeID for this node. The /// ScalarEvolution's BumpPtrAllocator holds the data. FoldingSetNodeIDRef FastID; public: enum SCEVPredicateKind { P_Union, P_Equal, P_Wrap }; protected: SCEVPredicateKind Kind; ~SCEVPredicate() = default; SCEVPredicate(const SCEVPredicate&) = default; SCEVPredicate &operator=(const SCEVPredicate&) = default; public: SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind); SCEVPredicateKind getKind() const { return Kind; } /// Returns the estimated complexity of this predicate. This is roughly /// measured in the number of run-time checks required. virtual unsigned getComplexity() const { return 1; } /// Returns true if the predicate is always true. This means that no /// assumptions were made and nothing needs to be checked at run-time. virtual bool isAlwaysTrue() const = 0; /// Returns true if this predicate implies \p N. virtual bool implies(const SCEVPredicate *N) const = 0; /// Prints a textual representation of this predicate with an indentation of /// \p Depth. virtual void print(raw_ostream &OS, unsigned Depth = 0) const = 0; /// Returns the SCEV to which this predicate applies, or nullptr if this is /// a SCEVUnionPredicate. virtual const SCEV *getExpr() const = 0; }; inline raw_ostream &operator<<(raw_ostream &OS, const SCEVPredicate &P) { P.print(OS); return OS; } // Specialize FoldingSetTrait for SCEVPredicate to avoid needing to compute // temporary FoldingSetNodeID values. template <> struct FoldingSetTrait : DefaultFoldingSetTrait { static void Profile(const SCEVPredicate &X, FoldingSetNodeID &ID) { ID = X.FastID; } static bool Equals(const SCEVPredicate &X, const FoldingSetNodeID &ID, unsigned IDHash, FoldingSetNodeID &TempID) { return ID == X.FastID; } static unsigned ComputeHash(const SCEVPredicate &X, FoldingSetNodeID &TempID) { return X.FastID.ComputeHash(); } }; /// This class represents an assumption that two SCEV expressions are equal, /// and this can be checked at run-time. We assume that the left hand side is /// a SCEVUnknown and the right hand side a constant. class SCEVEqualPredicate final : public SCEVPredicate { /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a /// constant. const SCEVUnknown *LHS; const SCEVConstant *RHS; public: SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS, const SCEVConstant *RHS); /// Implementation of the SCEVPredicate interface bool implies(const SCEVPredicate *N) const override; void print(raw_ostream &OS, unsigned Depth = 0) const override; bool isAlwaysTrue() const override; const SCEV *getExpr() const override; /// Returns the left hand side of the equality. const SCEVUnknown *getLHS() const { return LHS; } /// Returns the right hand side of the equality. const SCEVConstant *getRHS() const { return RHS; } /// Methods for support type inquiry through isa, cast, and dyn_cast: static inline bool classof(const SCEVPredicate *P) { return P->getKind() == P_Equal; } }; /// This class represents an assumption made on an AddRec expression. Given an /// affine AddRec expression {a,+,b}, we assume that it has the nssw or nusw /// flags (defined below) in the first X iterations of the loop, where X is a /// SCEV expression returned by getPredicatedBackedgeTakenCount). /// /// Note that this does not imply that X is equal to the backedge taken /// count. This means that if we have a nusw predicate for i32 {0,+,1} with a /// predicated backedge taken count of X, we only guarantee that {0,+,1} has /// nusw in the first X iterations. {0,+,1} may still wrap in the loop if we /// have more than X iterations. class SCEVWrapPredicate final : public SCEVPredicate { public: /// Similar to SCEV::NoWrapFlags, but with slightly different semantics /// for FlagNUSW. The increment is considered to be signed, and a + b /// (where b is the increment) is considered to wrap if: /// zext(a + b) != zext(a) + sext(b) /// /// If Signed is a function that takes an n-bit tuple and maps to the /// integer domain as the tuples value interpreted as twos complement, /// and Unsigned a function that takes an n-bit tuple and maps to the /// integer domain as as the base two value of input tuple, then a + b /// has IncrementNUSW iff: /// /// 0 <= Unsigned(a) + Signed(b) < 2^n /// /// The IncrementNSSW flag has identical semantics with SCEV::FlagNSW. /// /// Note that the IncrementNUSW flag is not commutative: if base + inc /// has IncrementNUSW, then inc + base doesn't neccessarily have this /// property. The reason for this is that this is used for sign/zero /// extending affine AddRec SCEV expressions when a SCEVWrapPredicate is /// assumed. A {base,+,inc} expression is already non-commutative with /// regards to base and inc, since it is interpreted as: /// (((base + inc) + inc) + inc) ... enum IncrementWrapFlags { IncrementAnyWrap = 0, // No guarantee. IncrementNUSW = (1 << 0), // No unsigned with signed increment wrap. IncrementNSSW = (1 << 1), // No signed with signed increment wrap // (equivalent with SCEV::NSW) IncrementNoWrapMask = (1 << 2) - 1 }; /// Convenient IncrementWrapFlags manipulation methods. static SCEVWrapPredicate::IncrementWrapFlags LLVM_ATTRIBUTE_UNUSED_RESULT clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags) { assert((Flags & IncrementNoWrapMask) == Flags && "Invalid flags value!"); assert((OffFlags & IncrementNoWrapMask) == OffFlags && "Invalid flags value!"); return (SCEVWrapPredicate::IncrementWrapFlags)(Flags & ~OffFlags); } static SCEVWrapPredicate::IncrementWrapFlags LLVM_ATTRIBUTE_UNUSED_RESULT maskFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, int Mask) { assert((Flags & IncrementNoWrapMask) == Flags && "Invalid flags value!"); assert((Mask & IncrementNoWrapMask) == Mask && "Invalid mask value!"); return (SCEVWrapPredicate::IncrementWrapFlags)(Flags & Mask); } static SCEVWrapPredicate::IncrementWrapFlags LLVM_ATTRIBUTE_UNUSED_RESULT setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags) { assert((Flags & IncrementNoWrapMask) == Flags && "Invalid flags value!"); assert((OnFlags & IncrementNoWrapMask) == OnFlags && "Invalid flags value!"); return (SCEVWrapPredicate::IncrementWrapFlags)(Flags | OnFlags); } /// Returns the set of SCEVWrapPredicate no wrap flags implied by a /// SCEVAddRecExpr. static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE); private: const SCEVAddRecExpr *AR; IncrementWrapFlags Flags; public: explicit SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags); /// Returns the set assumed no overflow flags. IncrementWrapFlags getFlags() const { return Flags; } /// Implementation of the SCEVPredicate interface const SCEV *getExpr() const override; bool implies(const SCEVPredicate *N) const override; void print(raw_ostream &OS, unsigned Depth = 0) const override; bool isAlwaysTrue() const override; /// Methods for support type inquiry through isa, cast, and dyn_cast: static inline bool classof(const SCEVPredicate *P) { return P->getKind() == P_Wrap; } }; /// This class represents a composition of other SCEV predicates, and is the /// class that most clients will interact with. This is equivalent to a /// logical "AND" of all the predicates in the union. class SCEVUnionPredicate final : public SCEVPredicate { private: typedef DenseMap> PredicateMap; /// Vector with references to all predicates in this union. SmallVector Preds; /// Maps SCEVs to predicates for quick look-ups. PredicateMap SCEVToPreds; public: SCEVUnionPredicate(); const SmallVectorImpl &getPredicates() const { return Preds; } /// Adds a predicate to this union. void add(const SCEVPredicate *N); /// Returns a reference to a vector containing all predicates which apply to /// \p Expr. ArrayRef getPredicatesForExpr(const SCEV *Expr); /// Implementation of the SCEVPredicate interface bool isAlwaysTrue() const override; bool implies(const SCEVPredicate *N) const override; void print(raw_ostream &OS, unsigned Depth) const override; const SCEV *getExpr() const override; /// We estimate the complexity of a union predicate as the size number of /// predicates in the union. unsigned getComplexity() const override { return Preds.size(); } /// Methods for support type inquiry through isa, cast, and dyn_cast: static inline bool classof(const SCEVPredicate *P) { return P->getKind() == P_Union; } }; /// The main scalar evolution driver. Because client code (intentionally) /// can't do much with the SCEV objects directly, they must ask this class /// for services. class ScalarEvolution { public: /// An enum describing the relationship between a SCEV and a loop. enum LoopDisposition { LoopVariant, ///< The SCEV is loop-variant (unknown). LoopInvariant, ///< The SCEV is loop-invariant. LoopComputable ///< The SCEV varies predictably with the loop. }; /// An enum describing the relationship between a SCEV and a basic block. enum BlockDisposition { DoesNotDominateBlock, ///< The SCEV does not dominate the block. DominatesBlock, ///< The SCEV dominates the block. ProperlyDominatesBlock ///< The SCEV properly dominates the block. }; /// Convenient NoWrapFlags manipulation that hides enum casts and is /// visible in the ScalarEvolution name space. static SCEV::NoWrapFlags LLVM_ATTRIBUTE_UNUSED_RESULT maskFlags(SCEV::NoWrapFlags Flags, int Mask) { return (SCEV::NoWrapFlags)(Flags & Mask); } static SCEV::NoWrapFlags LLVM_ATTRIBUTE_UNUSED_RESULT setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags) { return (SCEV::NoWrapFlags)(Flags | OnFlags); } static SCEV::NoWrapFlags LLVM_ATTRIBUTE_UNUSED_RESULT clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags) { return (SCEV::NoWrapFlags)(Flags & ~OffFlags); } private: /// A CallbackVH to arrange for ScalarEvolution to be notified whenever a /// Value is deleted. class SCEVCallbackVH final : public CallbackVH { ScalarEvolution *SE; void deleted() override; void allUsesReplacedWith(Value *New) override; public: SCEVCallbackVH(Value *V, ScalarEvolution *SE = nullptr); }; friend class SCEVCallbackVH; friend class SCEVExpander; friend class SCEVUnknown; /// The function we are analyzing. /// Function &F; /// Does the module have any calls to the llvm.experimental.guard intrinsic /// at all? If this is false, we avoid doing work that will only help if /// thare are guards present in the IR. /// bool HasGuards; /// The target library information for the target we are targeting. /// TargetLibraryInfo &TLI; /// The tracker for @llvm.assume intrinsics in this function. AssumptionCache &AC; /// The dominator tree. /// DominatorTree &DT; /// The loop information for the function we are currently analyzing. /// LoopInfo &LI; /// This SCEV is used to represent unknown trip counts and things. std::unique_ptr CouldNotCompute; /// The typedef for HasRecMap. /// typedef DenseMap HasRecMapType; /// This is a cache to record whether a SCEV contains any scAddRecExpr. HasRecMapType HasRecMap; /// The typedef for ExprValueMap. /// - typedef DenseMap> ExprValueMapType; + typedef std::pair ValueOffsetPair; + typedef DenseMap> ExprValueMapType; /// ExprValueMap -- This map records the original values from which /// the SCEV expr is generated from. + /// + /// We want to represent the mapping as SCEV -> ValueOffsetPair instead + /// of SCEV -> Value: + /// Suppose we know S1 expands to V1, and + /// S1 = S2 + C_a + /// S3 = S2 + C_b + /// where C_a and C_b are different SCEVConstants. Then we'd like to + /// expand S3 as V1 - C_a + C_b instead of expanding S2 literally. + /// It is helpful when S2 is a complex SCEV expr. + /// + /// In order to do that, we represent ExprValueMap as a mapping from + /// SCEV to ValueOffsetPair. We will save both S1->{V1, 0} and + /// S2->{V1, C_a} into the map when we create SCEV for V1. When S3 + /// is expanded, it will first expand S2 to V1 - C_a because of + /// S2->{V1, C_a} in the map, then expand S3 to V1 - C_a + C_b. + /// + /// Note: S->{V, Offset} in the ExprValueMap means S can be expanded + /// to V - Offset. ExprValueMapType ExprValueMap; /// The typedef for ValueExprMap. /// typedef DenseMap > ValueExprMapType; /// This is a cache of the values we have analyzed so far. /// ValueExprMapType ValueExprMap; /// Mark predicate values currently being processed by isImpliedCond. DenseSet PendingLoopPredicates; /// Set to true by isLoopBackedgeGuardedByCond when we're walking the set of /// conditions dominating the backedge of a loop. bool WalkingBEDominatingConds; /// Set to true by isKnownPredicateViaSplitting when we're trying to prove a /// predicate by splitting it into a set of independent predicates. bool ProvingSplitPredicate; /// Information about the number of loop iterations for which a loop exit's /// branch condition evaluates to the not-taken path. This is a temporary /// pair of exact and max expressions that are eventually summarized in /// ExitNotTakenInfo and BackedgeTakenInfo. struct ExitLimit { const SCEV *Exact; const SCEV *Max; /// A predicate union guard for this ExitLimit. The result is only /// valid if this predicate evaluates to 'true' at run-time. SCEVUnionPredicate Pred; /*implicit*/ ExitLimit(const SCEV *E) : Exact(E), Max(E) {} ExitLimit(const SCEV *E, const SCEV *M, SCEVUnionPredicate &P) : Exact(E), Max(M), Pred(P) { assert((isa(Exact) || !isa(Max)) && "Exact is not allowed to be less precise than Max"); } /// Test whether this ExitLimit contains any computed information, or /// whether it's all SCEVCouldNotCompute values. bool hasAnyInfo() const { return !isa(Exact) || !isa(Max); } /// Test whether this ExitLimit contains all information. bool hasFullInfo() const { return !isa(Exact); } }; /// Forward declaration of ExitNotTakenExtras struct ExitNotTakenExtras; /// Information about the number of times a particular loop exit may be /// reached before exiting the loop. struct ExitNotTakenInfo { AssertingVH ExitingBlock; const SCEV *ExactNotTaken; ExitNotTakenExtras *ExtraInfo; bool Complete; ExitNotTakenInfo() : ExitingBlock(nullptr), ExactNotTaken(nullptr), ExtraInfo(nullptr), Complete(true) {} ExitNotTakenInfo(BasicBlock *ExitBlock, const SCEV *Expr, ExitNotTakenExtras *Ptr) : ExitingBlock(ExitBlock), ExactNotTaken(Expr), ExtraInfo(Ptr), Complete(true) {} /// Return true if all loop exits are computable. bool isCompleteList() const { return Complete; } /// Sets the incomplete property, indicating that one of the loop exits /// doesn't have a corresponding ExitNotTakenInfo entry. void setIncomplete() { Complete = false; } /// Returns a pointer to the predicate associated with this information, /// or nullptr if this doesn't exist (meaning always true). SCEVUnionPredicate *getPred() const { if (ExtraInfo) return &ExtraInfo->Pred; return nullptr; } /// Return true if the SCEV predicate associated with this information /// is always true. bool hasAlwaysTruePred() const { return !getPred() || getPred()->isAlwaysTrue(); } /// Defines a simple forward iterator for ExitNotTakenInfo. class ExitNotTakenInfoIterator : public std::iterator { const ExitNotTakenInfo *Start; unsigned Position; public: ExitNotTakenInfoIterator(const ExitNotTakenInfo *Start, unsigned Position) : Start(Start), Position(Position) {} const ExitNotTakenInfo &operator*() const { if (Position == 0) return *Start; return Start->ExtraInfo->Exits[Position - 1]; } const ExitNotTakenInfo *operator->() const { if (Position == 0) return Start; return &Start->ExtraInfo->Exits[Position - 1]; } bool operator==(const ExitNotTakenInfoIterator &RHS) const { return Start == RHS.Start && Position == RHS.Position; } bool operator!=(const ExitNotTakenInfoIterator &RHS) const { return Start != RHS.Start || Position != RHS.Position; } ExitNotTakenInfoIterator &operator++() { // Preincrement if (!Start) return *this; unsigned Elements = Start->ExtraInfo ? Start->ExtraInfo->Exits.size() + 1 : 1; ++Position; // We've run out of elements. if (Position == Elements) { Start = nullptr; Position = 0; } return *this; } ExitNotTakenInfoIterator operator++(int) { // Postincrement ExitNotTakenInfoIterator Tmp = *this; ++*this; return Tmp; } }; /// Iterators ExitNotTakenInfoIterator begin() const { return ExitNotTakenInfoIterator(this, 0); } ExitNotTakenInfoIterator end() const { return ExitNotTakenInfoIterator(nullptr, 0); } }; /// Describes the extra information that a ExitNotTakenInfo can have. struct ExitNotTakenExtras { /// The predicate associated with the ExitNotTakenInfo struct. SCEVUnionPredicate Pred; /// The extra exits in the loop. Only the ExitNotTakenExtras structure /// pointed to by the first ExitNotTakenInfo struct (associated with the /// first loop exit) will populate this vector to prevent having /// redundant information. SmallVector Exits; }; /// A struct containing the information attached to a backedge. struct EdgeInfo { EdgeInfo(BasicBlock *Block, const SCEV *Taken, SCEVUnionPredicate &P) : ExitBlock(Block), Taken(Taken), Pred(std::move(P)) {} /// The exit basic block. BasicBlock *ExitBlock; /// The (exact) number of time we take the edge back. const SCEV *Taken; /// The SCEV predicated associated with Taken. If Pred doesn't evaluate /// to true, the information in Taken is not valid (or equivalent with /// a CouldNotCompute. SCEVUnionPredicate Pred; }; /// Information about the backedge-taken count of a loop. This currently /// includes an exact count and a maximum count. /// class BackedgeTakenInfo { /// A list of computable exits and their not-taken counts. Loops almost /// never have more than one computable exit. ExitNotTakenInfo ExitNotTaken; /// An expression indicating the least maximum backedge-taken count of the /// loop that is known, or a SCEVCouldNotCompute. This expression is only /// valid if the predicates associated with all loop exits are true. const SCEV *Max; public: BackedgeTakenInfo() : Max(nullptr) {} /// Initialize BackedgeTakenInfo from a list of exact exit counts. BackedgeTakenInfo(SmallVectorImpl &ExitCounts, bool Complete, const SCEV *MaxCount); /// Test whether this BackedgeTakenInfo contains any computed information, /// or whether it's all SCEVCouldNotCompute values. bool hasAnyInfo() const { return ExitNotTaken.ExitingBlock || !isa(Max); } /// Test whether this BackedgeTakenInfo contains complete information. bool hasFullInfo() const { return ExitNotTaken.isCompleteList(); } /// Return an expression indicating the exact backedge-taken count of the /// loop if it is known or SCEVCouldNotCompute otherwise. This is the /// number of times the loop header can be guaranteed to execute, minus /// one. /// /// If the SCEV predicate associated with the answer can be different /// from AlwaysTrue, we must add a (non null) Predicates argument. /// The SCEV predicate associated with the answer will be added to /// Predicates. A run-time check needs to be emitted for the SCEV /// predicate in order for the answer to be valid. /// /// Note that we should always know if we need to pass a predicate /// argument or not from the way the ExitCounts vector was computed. /// If we allowed SCEV predicates to be generated when populating this /// vector, this information can contain them and therefore a /// SCEVPredicate argument should be added to getExact. const SCEV *getExact(ScalarEvolution *SE, SCEVUnionPredicate *Predicates = nullptr) const; /// Return the number of times this loop exit may fall through to the back /// edge, or SCEVCouldNotCompute. The loop is guaranteed not to exit via /// this block before this number of iterations, but may exit via another /// block. const SCEV *getExact(BasicBlock *ExitingBlock, ScalarEvolution *SE) const; /// Get the max backedge taken count for the loop. const SCEV *getMax(ScalarEvolution *SE) const; /// Return true if any backedge taken count expressions refer to the given /// subexpression. bool hasOperand(const SCEV *S, ScalarEvolution *SE) const; /// Invalidate this result and free associated memory. void clear(); }; /// Cache the backedge-taken count of the loops for this function as they /// are computed. DenseMap BackedgeTakenCounts; /// Cache the predicated backedge-taken count of the loops for this /// function as they are computed. DenseMap PredicatedBackedgeTakenCounts; /// This map contains entries for all of the PHI instructions that we /// attempt to compute constant evolutions for. This allows us to avoid /// potentially expensive recomputation of these properties. An instruction /// maps to null if we are unable to compute its exit value. DenseMap ConstantEvolutionLoopExitValue; /// This map contains entries for all the expressions that we attempt to /// compute getSCEVAtScope information for, which can be expensive in /// extreme cases. DenseMap, 2> > ValuesAtScopes; /// Memoized computeLoopDisposition results. DenseMap, 2>> LoopDispositions; /// Cache for \c loopHasNoAbnormalExits. DenseMap LoopHasNoAbnormalExits; /// Returns true if \p L contains no instruction that can abnormally exit /// the loop (i.e. via throwing an exception, by terminating the thread /// cleanly or by infinite looping in a called function). Strictly /// speaking, the last one is not leaving the loop, but is identical to /// leaving the loop for reasoning about undefined behavior. bool loopHasNoAbnormalExits(const Loop *L); /// Compute a LoopDisposition value. LoopDisposition computeLoopDisposition(const SCEV *S, const Loop *L); /// Memoized computeBlockDisposition results. DenseMap< const SCEV *, SmallVector, 2>> BlockDispositions; /// Compute a BlockDisposition value. BlockDisposition computeBlockDisposition(const SCEV *S, const BasicBlock *BB); /// Memoized results from getRange DenseMap UnsignedRanges; /// Memoized results from getRange DenseMap SignedRanges; /// Used to parameterize getRange enum RangeSignHint { HINT_RANGE_UNSIGNED, HINT_RANGE_SIGNED }; /// Set the memoized range for the given SCEV. const ConstantRange &setRange(const SCEV *S, RangeSignHint Hint, const ConstantRange &CR) { DenseMap &Cache = Hint == HINT_RANGE_UNSIGNED ? UnsignedRanges : SignedRanges; auto Pair = Cache.insert({S, CR}); if (!Pair.second) Pair.first->second = CR; return Pair.first->second; } /// Determine the range for a particular SCEV. ConstantRange getRange(const SCEV *S, RangeSignHint Hint); /// Determines the range for the affine SCEVAddRecExpr {\p Start,+,\p Stop}. /// Helper for \c getRange. ConstantRange getRangeForAffineAR(const SCEV *Start, const SCEV *Stop, const SCEV *MaxBECount, unsigned BitWidth); /// Try to compute a range for the affine SCEVAddRecExpr {\p Start,+,\p /// Stop} by "factoring out" a ternary expression from the add recurrence. /// Helper called by \c getRange. ConstantRange getRangeViaFactoring(const SCEV *Start, const SCEV *Stop, const SCEV *MaxBECount, unsigned BitWidth); /// We know that there is no SCEV for the specified value. Analyze the /// expression. const SCEV *createSCEV(Value *V); /// Provide the special handling we need to analyze PHI SCEVs. const SCEV *createNodeForPHI(PHINode *PN); /// Helper function called from createNodeForPHI. const SCEV *createAddRecFromPHI(PHINode *PN); /// Helper function called from createNodeForPHI. const SCEV *createNodeFromSelectLikePHI(PHINode *PN); /// Provide special handling for a select-like instruction (currently this /// is either a select instruction or a phi node). \p I is the instruction /// being processed, and it is assumed equivalent to "Cond ? TrueVal : /// FalseVal". const SCEV *createNodeForSelectOrPHI(Instruction *I, Value *Cond, Value *TrueVal, Value *FalseVal); /// Provide the special handling we need to analyze GEP SCEVs. const SCEV *createNodeForGEP(GEPOperator *GEP); /// Implementation code for getSCEVAtScope; called at most once for each /// SCEV+Loop pair. /// const SCEV *computeSCEVAtScope(const SCEV *S, const Loop *L); /// This looks up computed SCEV values for all instructions that depend on /// the given instruction and removes them from the ValueExprMap map if they /// reference SymName. This is used during PHI resolution. void forgetSymbolicName(Instruction *I, const SCEV *SymName); /// Return the BackedgeTakenInfo for the given loop, lazily computing new /// values if the loop hasn't been analyzed yet. The returned result is /// guaranteed not to be predicated. const BackedgeTakenInfo &getBackedgeTakenInfo(const Loop *L); /// Similar to getBackedgeTakenInfo, but will add predicates as required /// with the purpose of returning complete information. const BackedgeTakenInfo &getPredicatedBackedgeTakenInfo(const Loop *L); /// Compute the number of times the specified loop will iterate. /// If AllowPredicates is set, we will create new SCEV predicates as /// necessary in order to return an exact answer. BackedgeTakenInfo computeBackedgeTakenCount(const Loop *L, bool AllowPredicates = false); /// Compute the number of times the backedge of the specified loop will /// execute if it exits via the specified block. If AllowPredicates is set, /// this call will try to use a minimal set of SCEV predicates in order to /// return an exact answer. ExitLimit computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, bool AllowPredicates = false); /// Compute the number of times the backedge of the specified loop will /// execute if its exit condition were a conditional branch of ExitCond, /// TBB, and FBB. /// /// \p ControlsExit is true if ExitCond directly controls the exit /// branch. In this case, we can assume that the loop exits only if the /// condition is true and can infer that failing to meet the condition prior /// to integer wraparound results in undefined behavior. /// /// If \p AllowPredicates is set, this call will try to use a minimal set of /// SCEV predicates in order to return an exact answer. ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB, bool ControlsExit, bool AllowPredicates = false); /// Compute the number of times the backedge of the specified loop will /// execute if its exit condition were a conditional branch of the ICmpInst /// ExitCond, TBB, and FBB. If AllowPredicates is set, this call will try /// to use a minimal set of SCEV predicates in order to return an exact /// answer. ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, BasicBlock *TBB, BasicBlock *FBB, bool IsSubExpr, bool AllowPredicates = false); /// Compute the number of times the backedge of the specified loop will /// execute if its exit condition were a switch with a single exiting case /// to ExitingBB. ExitLimit computeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBB, bool IsSubExpr); /// Given an exit condition of 'icmp op load X, cst', try to see if we can /// compute the backedge-taken count. ExitLimit computeLoadConstantCompareExitLimit(LoadInst *LI, Constant *RHS, const Loop *L, ICmpInst::Predicate p); /// Compute the exit limit of a loop that is controlled by a /// "(IV >> 1) != 0" type comparison. We cannot compute the exact trip /// count in these cases (since SCEV has no way of expressing them), but we /// can still sometimes compute an upper bound. /// /// Return an ExitLimit for a loop whose backedge is guarded by `LHS Pred /// RHS`. ExitLimit computeShiftCompareExitLimit(Value *LHS, Value *RHS, const Loop *L, ICmpInst::Predicate Pred); /// If the loop is known to execute a constant number of times (the /// condition evolves only from constants), try to evaluate a few iterations /// of the loop until we get the exit condition gets a value of ExitWhen /// (true or false). If we cannot evaluate the exit count of the loop, /// return CouldNotCompute. const SCEV *computeExitCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen); /// Return the number of times an exit condition comparing the specified /// value to zero will execute. If not computable, return CouldNotCompute. /// If AllowPredicates is set, this call will try to use a minimal set of /// SCEV predicates in order to return an exact answer. ExitLimit howFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr, bool AllowPredicates = false); /// Return the number of times an exit condition checking the specified /// value for nonzero will execute. If not computable, return /// CouldNotCompute. ExitLimit howFarToNonZero(const SCEV *V, const Loop *L); /// Return the number of times an exit condition containing the specified /// less-than comparison will execute. If not computable, return /// CouldNotCompute. /// /// \p isSigned specifies whether the less-than is signed. /// /// \p ControlsExit is true when the LHS < RHS condition directly controls /// the branch (loops exits only if condition is true). In this case, we can /// use NoWrapFlags to skip overflow checks. /// /// If \p AllowPredicates is set, this call will try to use a minimal set of /// SCEV predicates in order to return an exact answer. ExitLimit howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool isSigned, bool ControlsExit, bool AllowPredicates = false); ExitLimit howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool isSigned, bool IsSubExpr, bool AllowPredicates = false); /// Return a predecessor of BB (which may not be an immediate predecessor) /// which has exactly one successor from which BB is reachable, or null if /// no such block is found. std::pair getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the given FoundCondValue value evaluates to true. bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, Value *FoundCondValue, bool Inverse); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by FoundPred, FoundLHS, FoundRHS is /// true. bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. bool isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. bool isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. Utility function used by isImpliedCondOperands. Tries to get /// cases like "X `sgt` 0 => X - 1 `sgt` -1". bool isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS); /// Return true if the condition denoted by \p LHS \p Pred \p RHS is implied /// by a call to \c @llvm.experimental.guard in \p BB. bool isImpliedViaGuard(BasicBlock *BB, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. /// /// This routine tries to rule out certain kinds of integer overflow, and /// then tries to reason about arithmetic properties of the predicates. bool isImpliedCondOperandsViaNoOverflow(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS); /// If we know that the specified Phi is in the header of its containing /// loop, we know the loop executes a constant number of times, and the PHI /// node is just a recurrence involving constants, fold it. Constant *getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs, const Loop *L); /// Test if the given expression is known to satisfy the condition described /// by Pred and the known constant ranges of LHS and RHS. /// bool isKnownPredicateViaConstantRanges(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); /// Try to prove the condition described by "LHS Pred RHS" by ruling out /// integer overflow. /// /// For instance, this will return true for "A s< (A + C)" if C is /// positive. bool isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); /// Try to split Pred LHS RHS into logical conjunctions (and's) and try to /// prove them individually. bool isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); /// Try to match the Expr as "(L + R)". bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R, SCEV::NoWrapFlags &Flags); /// Return true if More == (Less + C), where C is a constant. This is /// intended to be used as a cheaper substitute for full SCEV subtraction. bool computeConstantDifference(const SCEV *Less, const SCEV *More, APInt &C); /// Drop memoized information computed for S. void forgetMemoizedResults(const SCEV *S); /// Return an existing SCEV for V if there is one, otherwise return nullptr. const SCEV *getExistingSCEV(Value *V); /// Return false iff given SCEV contains a SCEVUnknown with NULL value- /// pointer. bool checkValidity(const SCEV *S) const; /// Return true if `ExtendOpTy`({`Start`,+,`Step`}) can be proved to be /// equal to {`ExtendOpTy`(`Start`),+,`ExtendOpTy`(`Step`)}. This is /// equivalent to proving no signed (resp. unsigned) wrap in /// {`Start`,+,`Step`} if `ExtendOpTy` is `SCEVSignExtendExpr` /// (resp. `SCEVZeroExtendExpr`). /// template bool proveNoWrapByVaryingStart(const SCEV *Start, const SCEV *Step, const Loop *L); /// Try to prove NSW or NUW on \p AR relying on ConstantRange manipulation. SCEV::NoWrapFlags proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR); bool isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred, bool &Increasing); /// Return true if, for all loop invariant X, the predicate "LHS `Pred` X" /// is monotonically increasing or decreasing. In the former case set /// `Increasing` to true and in the latter case set `Increasing` to false. /// /// A predicate is said to be monotonically increasing if may go from being /// false to being true as the loop iterates, but never the other way /// around. A predicate is said to be monotonically decreasing if may go /// from being true to being false as the loop iterates, but never the other /// way around. bool isMonotonicPredicate(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred, bool &Increasing); /// Return SCEV no-wrap flags that can be proven based on reasoning about /// how poison produced from no-wrap flags on this value (e.g. a nuw add) /// would trigger undefined behavior on overflow. SCEV::NoWrapFlags getNoWrapFlagsFromUB(const Value *V); /// Return true if the SCEV corresponding to \p I is never poison. Proving /// this is more complex than proving that just \p I is never poison, since /// SCEV commons expressions across control flow, and you can have cases /// like: /// /// idx0 = a + b; /// ptr[idx0] = 100; /// if () { /// idx1 = a +nsw b; /// ptr[idx1] = 200; /// } /// /// where the SCEV expression (+ a b) is guaranteed to not be poison (and /// hence not sign-overflow) only if "" is true. Since both /// `idx0` and `idx1` will be mapped to the same SCEV expression, (+ a b), /// it is not okay to annotate (+ a b) with in the above example. bool isSCEVExprNeverPoison(const Instruction *I); /// This is like \c isSCEVExprNeverPoison but it specifically works for /// instructions that will get mapped to SCEV add recurrences. Return true /// if \p I will never generate poison under the assumption that \p I is an /// add recurrence on the loop \p L. bool isAddRecNeverPoison(const Instruction *I, const Loop *L); public: ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI); ~ScalarEvolution(); ScalarEvolution(ScalarEvolution &&Arg); LLVMContext &getContext() const { return F.getContext(); } /// Test if values of the given type are analyzable within the SCEV /// framework. This primarily includes integer types, and it can optionally /// include pointer types if the ScalarEvolution class has access to /// target-specific information. bool isSCEVable(Type *Ty) const; /// Return the size in bits of the specified type, for which isSCEVable must /// return true. uint64_t getTypeSizeInBits(Type *Ty) const; /// Return a type with the same bitwidth as the given type and which /// represents how SCEV will treat the given type, for which isSCEVable must /// return true. For pointer types, this is the pointer-sized integer type. Type *getEffectiveSCEVType(Type *Ty) const; /// Return true if the SCEV is a scAddRecExpr or it contains /// scAddRecExpr. The result will be cached in HasRecMap. /// bool containsAddRecurrence(const SCEV *S); /// Return the Value set from which the SCEV expr is generated. - SetVector *getSCEVValues(const SCEV *S); + SetVector *getSCEVValues(const SCEV *S); /// Erase Value from ValueExprMap and ExprValueMap. void eraseValueFromMap(Value *V); /// Return a SCEV expression for the full generality of the specified /// expression. const SCEV *getSCEV(Value *V); const SCEV *getConstant(ConstantInt *V); const SCEV *getConstant(const APInt& Val); const SCEV *getConstant(Type *Ty, uint64_t V, bool isSigned = false); const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty); const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty); const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty); const SCEV *getAnyExtendExpr(const SCEV *Op, Type *Ty); const SCEV *getAddExpr(SmallVectorImpl &Ops, SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap); const SCEV *getAddExpr(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) { SmallVector Ops = {LHS, RHS}; return getAddExpr(Ops, Flags); } const SCEV *getAddExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2, SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) { SmallVector Ops = {Op0, Op1, Op2}; return getAddExpr(Ops, Flags); } const SCEV *getMulExpr(SmallVectorImpl &Ops, SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap); const SCEV *getMulExpr(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) { SmallVector Ops = {LHS, RHS}; return getMulExpr(Ops, Flags); } const SCEV *getMulExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2, SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) { SmallVector Ops = {Op0, Op1, Op2}; return getMulExpr(Ops, Flags); } const SCEV *getUDivExpr(const SCEV *LHS, const SCEV *RHS); const SCEV *getUDivExactExpr(const SCEV *LHS, const SCEV *RHS); const SCEV *getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags); const SCEV *getAddRecExpr(SmallVectorImpl &Operands, const Loop *L, SCEV::NoWrapFlags Flags); const SCEV *getAddRecExpr(const SmallVectorImpl &Operands, const Loop *L, SCEV::NoWrapFlags Flags) { SmallVector NewOp(Operands.begin(), Operands.end()); return getAddRecExpr(NewOp, L, Flags); } /// Returns an expression for a GEP /// /// \p PointeeType The type used as the basis for the pointer arithmetics /// \p BaseExpr The expression for the pointer operand. /// \p IndexExprs The expressions for the indices. /// \p InBounds Whether the GEP is in bounds. const SCEV *getGEPExpr(Type *PointeeType, const SCEV *BaseExpr, const SmallVectorImpl &IndexExprs, bool InBounds = false); const SCEV *getSMaxExpr(const SCEV *LHS, const SCEV *RHS); const SCEV *getSMaxExpr(SmallVectorImpl &Operands); const SCEV *getUMaxExpr(const SCEV *LHS, const SCEV *RHS); const SCEV *getUMaxExpr(SmallVectorImpl &Operands); const SCEV *getSMinExpr(const SCEV *LHS, const SCEV *RHS); const SCEV *getUMinExpr(const SCEV *LHS, const SCEV *RHS); const SCEV *getUnknown(Value *V); const SCEV *getCouldNotCompute(); /// Return a SCEV for the constant 0 of a specific type. const SCEV *getZero(Type *Ty) { return getConstant(Ty, 0); } /// Return a SCEV for the constant 1 of a specific type. const SCEV *getOne(Type *Ty) { return getConstant(Ty, 1); } /// Return an expression for sizeof AllocTy that is type IntTy /// const SCEV *getSizeOfExpr(Type *IntTy, Type *AllocTy); /// Return an expression for offsetof on the given field with type IntTy /// const SCEV *getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo); /// Return the SCEV object corresponding to -V. /// const SCEV *getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap); /// Return the SCEV object corresponding to ~V. /// const SCEV *getNotSCEV(const SCEV *V); /// Return LHS-RHS. Minus is represented in SCEV as A+B*-1. const SCEV *getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap); /// Return a SCEV corresponding to a conversion of the input value to the /// specified type. If the type must be extended, it is zero extended. const SCEV *getTruncateOrZeroExtend(const SCEV *V, Type *Ty); /// Return a SCEV corresponding to a conversion of the input value to the /// specified type. If the type must be extended, it is sign extended. const SCEV *getTruncateOrSignExtend(const SCEV *V, Type *Ty); /// Return a SCEV corresponding to a conversion of the input value to the /// specified type. If the type must be extended, it is zero extended. The /// conversion must not be narrowing. const SCEV *getNoopOrZeroExtend(const SCEV *V, Type *Ty); /// Return a SCEV corresponding to a conversion of the input value to the /// specified type. If the type must be extended, it is sign extended. The /// conversion must not be narrowing. const SCEV *getNoopOrSignExtend(const SCEV *V, Type *Ty); /// Return a SCEV corresponding to a conversion of the input value to the /// specified type. If the type must be extended, it is extended with /// unspecified bits. The conversion must not be narrowing. const SCEV *getNoopOrAnyExtend(const SCEV *V, Type *Ty); /// Return a SCEV corresponding to a conversion of the input value to the /// specified type. The conversion must not be widening. const SCEV *getTruncateOrNoop(const SCEV *V, Type *Ty); /// Promote the operands to the wider of the types using zero-extension, and /// then perform a umax operation with them. const SCEV *getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS); /// Promote the operands to the wider of the types using zero-extension, and /// then perform a umin operation with them. const SCEV *getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS); /// Transitively follow the chain of pointer-type operands until reaching a /// SCEV that does not have a single pointer operand. This returns a /// SCEVUnknown pointer for well-formed pointer-type expressions, but corner /// cases do exist. const SCEV *getPointerBase(const SCEV *V); /// Return a SCEV expression for the specified value at the specified scope /// in the program. The L value specifies a loop nest to evaluate the /// expression at, where null is the top-level or a specified loop is /// immediately inside of the loop. /// /// This method can be used to compute the exit value for a variable defined /// in a loop by querying what the value will hold in the parent loop. /// /// In the case that a relevant loop exit value cannot be computed, the /// original value V is returned. const SCEV *getSCEVAtScope(const SCEV *S, const Loop *L); /// This is a convenience function which does getSCEVAtScope(getSCEV(V), L). const SCEV *getSCEVAtScope(Value *V, const Loop *L); /// Test whether entry to the loop is protected by a conditional between LHS /// and RHS. This is used to help avoid max expressions in loop trip /// counts, and to eliminate casts. bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); /// Test whether the backedge of the loop is protected by a conditional /// between LHS and RHS. This is used to to eliminate casts. bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); /// Returns the maximum trip count of the loop if it is a single-exit /// loop and we can compute a small maximum for that loop. /// /// Implemented in terms of the \c getSmallConstantTripCount overload with /// the single exiting block passed to it. See that routine for details. unsigned getSmallConstantTripCount(Loop *L); /// Returns the maximum trip count of this loop as a normal unsigned /// value. Returns 0 if the trip count is unknown or not constant. This /// "trip count" assumes that control exits via ExitingBlock. More /// precisely, it is the number of times that control may reach ExitingBlock /// before taking the branch. For loops with multiple exits, it may not be /// the number times that the loop header executes if the loop exits /// prematurely via another branch. unsigned getSmallConstantTripCount(Loop *L, BasicBlock *ExitingBlock); /// Returns the largest constant divisor of the trip count of the /// loop if it is a single-exit loop and we can compute a small maximum for /// that loop. /// /// Implemented in terms of the \c getSmallConstantTripMultiple overload with /// the single exiting block passed to it. See that routine for details. unsigned getSmallConstantTripMultiple(Loop *L); /// Returns the largest constant divisor of the trip count of this loop as a /// normal unsigned value, if possible. This means that the actual trip /// count is always a multiple of the returned value (don't forget the trip /// count could very well be zero as well!). As explained in the comments /// for getSmallConstantTripCount, this assumes that control exits the loop /// via ExitingBlock. unsigned getSmallConstantTripMultiple(Loop *L, BasicBlock *ExitingBlock); /// Get the expression for the number of loop iterations for which this loop /// is guaranteed not to exit via ExitingBlock. Otherwise return /// SCEVCouldNotCompute. const SCEV *getExitCount(Loop *L, BasicBlock *ExitingBlock); /// If the specified loop has a predictable backedge-taken count, return it, /// otherwise return a SCEVCouldNotCompute object. The backedge-taken count /// is the number of times the loop header will be branched to from within /// the loop. This is one less than the trip count of the loop, since it /// doesn't count the first iteration, when the header is branched to from /// outside the loop. /// /// Note that it is not valid to call this method on a loop without a /// loop-invariant backedge-taken count (see /// hasLoopInvariantBackedgeTakenCount). /// const SCEV *getBackedgeTakenCount(const Loop *L); /// Similar to getBackedgeTakenCount, except it will add a set of /// SCEV predicates to Predicates that are required to be true in order for /// the answer to be correct. Predicates can be checked with run-time /// checks and can be used to perform loop versioning. const SCEV *getPredicatedBackedgeTakenCount(const Loop *L, SCEVUnionPredicate &Predicates); /// Similar to getBackedgeTakenCount, except return the least SCEV value /// that is known never to be less than the actual backedge taken count. const SCEV *getMaxBackedgeTakenCount(const Loop *L); /// Return true if the specified loop has an analyzable loop-invariant /// backedge-taken count. bool hasLoopInvariantBackedgeTakenCount(const Loop *L); /// This method should be called by the client when it has changed a loop in /// a way that may effect ScalarEvolution's ability to compute a trip count, /// or if the loop is deleted. This call is potentially expensive for large /// loop bodies. void forgetLoop(const Loop *L); /// This method should be called by the client when it has changed a value /// in a way that may effect its value, or which may disconnect it from a /// def-use chain linking it to a loop. void forgetValue(Value *V); /// Called when the client has changed the disposition of values in /// this loop. /// /// We don't have a way to invalidate per-loop dispositions. Clear and /// recompute is simpler. void forgetLoopDispositions(const Loop *L) { LoopDispositions.clear(); } /// Determine the minimum number of zero bits that S is guaranteed to end in /// (at every loop iteration). It is, at the same time, the minimum number /// of times S is divisible by 2. For example, given {4,+,8} it returns 2. /// If S is guaranteed to be 0, it returns the bitwidth of S. uint32_t GetMinTrailingZeros(const SCEV *S); /// Determine the unsigned range for a particular SCEV. /// ConstantRange getUnsignedRange(const SCEV *S) { return getRange(S, HINT_RANGE_UNSIGNED); } /// Determine the signed range for a particular SCEV. /// ConstantRange getSignedRange(const SCEV *S) { return getRange(S, HINT_RANGE_SIGNED); } /// Test if the given expression is known to be negative. /// bool isKnownNegative(const SCEV *S); /// Test if the given expression is known to be positive. /// bool isKnownPositive(const SCEV *S); /// Test if the given expression is known to be non-negative. /// bool isKnownNonNegative(const SCEV *S); /// Test if the given expression is known to be non-positive. /// bool isKnownNonPositive(const SCEV *S); /// Test if the given expression is known to be non-zero. /// bool isKnownNonZero(const SCEV *S); /// Test if the given expression is known to satisfy the condition described /// by Pred, LHS, and RHS. /// bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); /// Return true if the result of the predicate LHS `Pred` RHS is loop /// invariant with respect to L. Set InvariantPred, InvariantLHS and /// InvariantLHS so that InvariantLHS `InvariantPred` InvariantRHS is the /// loop invariant form of LHS `Pred` RHS. bool isLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS, const SCEV *&InvariantRHS); /// Simplify LHS and RHS in a comparison with predicate Pred. Return true /// iff any changes were made. If the operands are provably equal or /// unequal, LHS and RHS are set to the same value and Pred is set to either /// ICMP_EQ or ICMP_NE. /// bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth = 0); /// Return the "disposition" of the given SCEV with respect to the given /// loop. LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L); /// Return true if the value of the given SCEV is unchanging in the /// specified loop. bool isLoopInvariant(const SCEV *S, const Loop *L); /// Return true if the given SCEV changes value in a known way in the /// specified loop. This property being true implies that the value is /// variant in the loop AND that we can emit an expression to compute the /// value of the expression at any particular loop iteration. bool hasComputableLoopEvolution(const SCEV *S, const Loop *L); /// Return the "disposition" of the given SCEV with respect to the given /// block. BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB); /// Return true if elements that makes up the given SCEV dominate the /// specified basic block. bool dominates(const SCEV *S, const BasicBlock *BB); /// Return true if elements that makes up the given SCEV properly dominate /// the specified basic block. bool properlyDominates(const SCEV *S, const BasicBlock *BB); /// Test whether the given SCEV has Op as a direct or indirect operand. bool hasOperand(const SCEV *S, const SCEV *Op) const; /// Return the size of an element read or written by Inst. const SCEV *getElementSize(Instruction *Inst); /// Compute the array dimensions Sizes from the set of Terms extracted from /// the memory access function of this SCEVAddRecExpr (second step of /// delinearization). void findArrayDimensions(SmallVectorImpl &Terms, SmallVectorImpl &Sizes, const SCEV *ElementSize) const; void print(raw_ostream &OS) const; void verify() const; /// Collect parametric terms occurring in step expressions (first step of /// delinearization). void collectParametricTerms(const SCEV *Expr, SmallVectorImpl &Terms); /// Return in Subscripts the access functions for each dimension in Sizes /// (third step of delinearization). void computeAccessFunctions(const SCEV *Expr, SmallVectorImpl &Subscripts, SmallVectorImpl &Sizes); /// Split this SCEVAddRecExpr into two vectors of SCEVs representing the /// subscripts and sizes of an array access. /// /// The delinearization is a 3 step process: the first two steps compute the /// sizes of each subscript and the third step computes the access functions /// for the delinearized array: /// /// 1. Find the terms in the step functions /// 2. Compute the array size /// 3. Compute the access function: divide the SCEV by the array size /// starting with the innermost dimensions found in step 2. The Quotient /// is the SCEV to be divided in the next step of the recursion. The /// Remainder is the subscript of the innermost dimension. Loop over all /// array dimensions computed in step 2. /// /// To compute a uniform array size for several memory accesses to the same /// object, one can collect in step 1 all the step terms for all the memory /// accesses, and compute in step 2 a unique array shape. This guarantees /// that the array shape will be the same across all memory accesses. /// /// FIXME: We could derive the result of steps 1 and 2 from a description of /// the array shape given in metadata. /// /// Example: /// /// A[][n][m] /// /// for i /// for j /// for k /// A[j+k][2i][5i] = /// /// The initial SCEV: /// /// A[{{{0,+,2*m+5}_i, +, n*m}_j, +, n*m}_k] /// /// 1. Find the different terms in the step functions: /// -> [2*m, 5, n*m, n*m] /// /// 2. Compute the array size: sort and unique them /// -> [n*m, 2*m, 5] /// find the GCD of all the terms = 1 /// divide by the GCD and erase constant terms /// -> [n*m, 2*m] /// GCD = m /// divide by GCD -> [n, 2] /// remove constant terms /// -> [n] /// size of the array is A[unknown][n][m] /// /// 3. Compute the access function /// a. Divide {{{0,+,2*m+5}_i, +, n*m}_j, +, n*m}_k by the innermost size m /// Quotient: {{{0,+,2}_i, +, n}_j, +, n}_k /// Remainder: {{{0,+,5}_i, +, 0}_j, +, 0}_k /// The remainder is the subscript of the innermost array dimension: [5i]. /// /// b. Divide Quotient: {{{0,+,2}_i, +, n}_j, +, n}_k by next outer size n /// Quotient: {{{0,+,0}_i, +, 1}_j, +, 1}_k /// Remainder: {{{0,+,2}_i, +, 0}_j, +, 0}_k /// The Remainder is the subscript of the next array dimension: [2i]. /// /// The subscript of the outermost dimension is the Quotient: [j+k]. /// /// Overall, we have: A[][n][m], and the access function: A[j+k][2i][5i]. void delinearize(const SCEV *Expr, SmallVectorImpl &Subscripts, SmallVectorImpl &Sizes, const SCEV *ElementSize); /// Return the DataLayout associated with the module this SCEV instance is /// operating on. const DataLayout &getDataLayout() const { return F.getParent()->getDataLayout(); } const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS, const SCEVConstant *RHS); const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags); /// Re-writes the SCEV according to the Predicates in \p A. const SCEV *rewriteUsingPredicate(const SCEV *S, const Loop *L, SCEVUnionPredicate &A); /// Tries to convert the \p S expression to an AddRec expression, /// adding additional predicates to \p Preds as required. const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SCEVUnionPredicate &Preds); private: /// Compute the backedge taken count knowing the interval difference, the /// stride and presence of the equality in the comparison. const SCEV *computeBECount(const SCEV *Delta, const SCEV *Stride, bool Equality); /// Verify if an linear IV with positive stride can overflow when in a /// less-than comparison, knowing the invariant term of the comparison, /// the stride and the knowledge of NSW/NUW flags on the recurrence. bool doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap); /// Verify if an linear IV with negative stride can overflow when in a /// greater-than comparison, knowing the invariant term of the comparison, /// the stride and the knowledge of NSW/NUW flags on the recurrence. bool doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap); private: FoldingSet UniqueSCEVs; FoldingSet UniquePreds; BumpPtrAllocator SCEVAllocator; /// The head of a linked list of all SCEVUnknown values that have been /// allocated. This is used by releaseMemory to locate them all and call /// their destructors. SCEVUnknown *FirstUnknown; }; /// Analysis pass that exposes the \c ScalarEvolution for a function. class ScalarEvolutionAnalysis : public AnalysisInfoMixin { friend AnalysisInfoMixin; static char PassID; public: typedef ScalarEvolution Result; ScalarEvolution run(Function &F, AnalysisManager &AM); }; /// Printer pass for the \c ScalarEvolutionAnalysis results. class ScalarEvolutionPrinterPass : public PassInfoMixin { raw_ostream &OS; public: explicit ScalarEvolutionPrinterPass(raw_ostream &OS) : OS(OS) {} PreservedAnalyses run(Function &F, AnalysisManager &AM); }; class ScalarEvolutionWrapperPass : public FunctionPass { std::unique_ptr SE; public: static char ID; ScalarEvolutionWrapperPass(); ScalarEvolution &getSE() { return *SE; } const ScalarEvolution &getSE() const { return *SE; } bool runOnFunction(Function &F) override; void releaseMemory() override; void getAnalysisUsage(AnalysisUsage &AU) const override; void print(raw_ostream &OS, const Module * = nullptr) const override; void verifyAnalysis() const override; }; /// An interface layer with SCEV used to manage how we see SCEV expressions /// for values in the context of existing predicates. We can add new /// predicates, but we cannot remove them. /// /// This layer has multiple purposes: /// - provides a simple interface for SCEV versioning. /// - guarantees that the order of transformations applied on a SCEV /// expression for a single Value is consistent across two different /// getSCEV calls. This means that, for example, once we've obtained /// an AddRec expression for a certain value through expression /// rewriting, we will continue to get an AddRec expression for that /// Value. /// - lowers the number of expression rewrites. class PredicatedScalarEvolution { public: PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L); const SCEVUnionPredicate &getUnionPredicate() const; /// Returns the SCEV expression of V, in the context of the current SCEV /// predicate. The order of transformations applied on the expression of V /// returned by ScalarEvolution is guaranteed to be preserved, even when /// adding new predicates. const SCEV *getSCEV(Value *V); /// Get the (predicated) backedge count for the analyzed loop. const SCEV *getBackedgeTakenCount(); /// Adds a new predicate. void addPredicate(const SCEVPredicate &Pred); /// Attempts to produce an AddRecExpr for V by adding additional SCEV /// predicates. If we can't transform the expression into an AddRecExpr we /// return nullptr and not add additional SCEV predicates to the current /// context. const SCEVAddRecExpr *getAsAddRec(Value *V); /// Proves that V doesn't overflow by adding SCEV predicate. void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags); /// Returns true if we've proved that V doesn't wrap by means of a SCEV /// predicate. bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags); /// Returns the ScalarEvolution analysis used. ScalarEvolution *getSE() const { return &SE; } /// We need to explicitly define the copy constructor because of FlagsMap. PredicatedScalarEvolution(const PredicatedScalarEvolution&); /// Print the SCEV mappings done by the Predicated Scalar Evolution. /// The printed text is indented by \p Depth. void print(raw_ostream &OS, unsigned Depth) const; private: /// Increments the version number of the predicate. This needs to be called /// every time the SCEV predicate changes. void updateGeneration(); /// Holds a SCEV and the version number of the SCEV predicate used to /// perform the rewrite of the expression. typedef std::pair RewriteEntry; /// Maps a SCEV to the rewrite result of that SCEV at a certain version /// number. If this number doesn't match the current Generation, we will /// need to do a rewrite. To preserve the transformation order of previous /// rewrites, we will rewrite the previous result instead of the original /// SCEV. DenseMap RewriteMap; /// Records what NoWrap flags we've added to a Value *. ValueMap FlagsMap; /// The ScalarEvolution analysis. ScalarEvolution &SE; /// The analyzed Loop. const Loop &L; /// The SCEVPredicate that forms our context. We will rewrite all /// expressions assuming that this predicate true. SCEVUnionPredicate Preds; /// Marks the version of the SCEV predicate used. When rewriting a SCEV /// expression we mark it with the version of the predicate. We use this to /// figure out if the predicate has changed from the last rewrite of the /// SCEV. If so, we need to perform a new rewrite. unsigned Generation; /// The backedge taken count. const SCEV *BackedgeCount; }; } #endif Index: head/contrib/llvm/include/llvm/Analysis/ScalarEvolutionExpander.h =================================================================== --- head/contrib/llvm/include/llvm/Analysis/ScalarEvolutionExpander.h (revision 312831) +++ head/contrib/llvm/include/llvm/Analysis/ScalarEvolutionExpander.h (revision 312832) @@ -1,384 +1,396 @@ //===---- llvm/Analysis/ScalarEvolutionExpander.h - SCEV Exprs --*- C++ -*-===// // // The LLVM Compiler Infrastructure // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// // // This file defines the classes used to generate code from scalar expressions. // //===----------------------------------------------------------------------===// #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPANDER_H #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPANDER_H +#include "llvm/ADT/Optional.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ScalarEvolutionNormalization.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/ValueHandle.h" #include namespace llvm { class TargetTransformInfo; /// Return true if the given expression is safe to expand in the sense that /// all materialized values are safe to speculate. bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE); /// This class uses information about analyze scalars to /// rewrite expressions in canonical form. /// /// Clients should create an instance of this class when rewriting is needed, /// and destroy it when finished to allow the release of the associated /// memory. class SCEVExpander : public SCEVVisitor { ScalarEvolution &SE; const DataLayout &DL; // New instructions receive a name to identifies them with the current pass. const char* IVName; // InsertedExpressions caches Values for reuse, so must track RAUW. std::map, TrackingVH > InsertedExpressions; // InsertedValues only flags inserted instructions so needs no RAUW. std::set > InsertedValues; std::set > InsertedPostIncValues; /// A memoization of the "relevant" loop for a given SCEV. DenseMap RelevantLoops; /// \brief Addrecs referring to any of the given loops are expanded /// in post-inc mode. For example, expanding {1,+,1} in post-inc mode /// returns the add instruction that adds one to the phi for {0,+,1}, /// as opposed to a new phi starting at 1. This is only supported in /// non-canonical mode. PostIncLoopSet PostIncLoops; /// \brief When this is non-null, addrecs expanded in the loop it indicates /// should be inserted with increments at IVIncInsertPos. const Loop *IVIncInsertLoop; /// \brief When expanding addrecs in the IVIncInsertLoop loop, insert the IV /// increment at this position. Instruction *IVIncInsertPos; /// \brief Phis that complete an IV chain. Reuse std::set > ChainedPhis; /// \brief When true, expressions are expanded in "canonical" form. In /// particular, addrecs are expanded as arithmetic based on a canonical /// induction variable. When false, expression are expanded in a more /// literal form. bool CanonicalMode; /// \brief When invoked from LSR, the expander is in "strength reduction" /// mode. The only difference is that phi's are only reused if they are /// already in "expanded" form. bool LSRMode; typedef IRBuilder BuilderType; BuilderType Builder; // RAII object that stores the current insertion point and restores it when // the object is destroyed. This includes the debug location. Duplicated // from InsertPointGuard to add SetInsertPoint() which is used to updated // InsertPointGuards stack when insert points are moved during SCEV // expansion. class SCEVInsertPointGuard { IRBuilderBase &Builder; AssertingVH Block; BasicBlock::iterator Point; DebugLoc DbgLoc; SCEVExpander *SE; SCEVInsertPointGuard(const SCEVInsertPointGuard &) = delete; SCEVInsertPointGuard &operator=(const SCEVInsertPointGuard &) = delete; public: SCEVInsertPointGuard(IRBuilderBase &B, SCEVExpander *SE) : Builder(B), Block(B.GetInsertBlock()), Point(B.GetInsertPoint()), DbgLoc(B.getCurrentDebugLocation()), SE(SE) { SE->InsertPointGuards.push_back(this); } ~SCEVInsertPointGuard() { // These guards should always created/destroyed in FIFO order since they // are used to guard lexically scoped blocks of code in // ScalarEvolutionExpander. assert(SE->InsertPointGuards.back() == this); SE->InsertPointGuards.pop_back(); Builder.restoreIP(IRBuilderBase::InsertPoint(Block, Point)); Builder.SetCurrentDebugLocation(DbgLoc); } BasicBlock::iterator GetInsertPoint() const { return Point; } void SetInsertPoint(BasicBlock::iterator I) { Point = I; } }; /// Stack of pointers to saved insert points, used to keep insert points /// consistent when instructions are moved. SmallVector InsertPointGuards; #ifndef NDEBUG const char *DebugType; #endif friend struct SCEVVisitor; public: /// \brief Construct a SCEVExpander in "canonical" mode. explicit SCEVExpander(ScalarEvolution &se, const DataLayout &DL, const char *name) : SE(se), DL(DL), IVName(name), IVIncInsertLoop(nullptr), IVIncInsertPos(nullptr), CanonicalMode(true), LSRMode(false), Builder(se.getContext(), TargetFolder(DL)) { #ifndef NDEBUG DebugType = ""; #endif } ~SCEVExpander() { // Make sure the insert point guard stack is consistent. assert(InsertPointGuards.empty()); } #ifndef NDEBUG void setDebugType(const char* s) { DebugType = s; } #endif /// \brief Erase the contents of the InsertedExpressions map so that users /// trying to expand the same expression into multiple BasicBlocks or /// different places within the same BasicBlock can do so. void clear() { InsertedExpressions.clear(); InsertedValues.clear(); InsertedPostIncValues.clear(); ChainedPhis.clear(); } /// \brief Return true for expressions that may incur non-trivial cost to /// evaluate at runtime. /// /// At is an optional parameter which specifies point in code where user is /// going to expand this expression. Sometimes this knowledge can lead to a /// more accurate cost estimation. bool isHighCostExpansion(const SCEV *Expr, Loop *L, const Instruction *At = nullptr) { SmallPtrSet Processed; return isHighCostExpansionHelper(Expr, L, At, Processed); } /// \brief This method returns the canonical induction variable of the /// specified type for the specified loop (inserting one if there is none). /// A canonical induction variable starts at zero and steps by one on each /// iteration. PHINode *getOrInsertCanonicalInductionVariable(const Loop *L, Type *Ty); /// \brief Return the induction variable increment's IV operand. Instruction *getIVIncOperand(Instruction *IncV, Instruction *InsertPos, bool allowScale); /// \brief Utility for hoisting an IV increment. bool hoistIVInc(Instruction *IncV, Instruction *InsertPos); /// \brief replace congruent phis with their most canonical /// representative. Return the number of phis eliminated. unsigned replaceCongruentIVs(Loop *L, const DominatorTree *DT, SmallVectorImpl &DeadInsts, const TargetTransformInfo *TTI = nullptr); /// \brief Insert code to directly compute the specified SCEV expression /// into the program. The inserted code is inserted into the specified /// block. Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I); /// \brief Insert code to directly compute the specified SCEV expression /// into the program. The inserted code is inserted into the SCEVExpander's /// current insertion point. If a type is specified, the result will be /// expanded to have that type, with a cast if necessary. Value *expandCodeFor(const SCEV *SH, Type *Ty = nullptr); /// \brief Generates a code sequence that evaluates this predicate. /// The inserted instructions will be at position \p Loc. /// The result will be of type i1 and will have a value of 0 when the /// predicate is false and 1 otherwise. Value *expandCodeForPredicate(const SCEVPredicate *Pred, Instruction *Loc); /// \brief A specialized variant of expandCodeForPredicate, handling the /// case when we are expanding code for a SCEVEqualPredicate. Value *expandEqualPredicate(const SCEVEqualPredicate *Pred, Instruction *Loc); /// \brief Generates code that evaluates if the \p AR expression will /// overflow. Value *generateOverflowCheck(const SCEVAddRecExpr *AR, Instruction *Loc, bool Signed); /// \brief A specialized variant of expandCodeForPredicate, handling the /// case when we are expanding code for a SCEVWrapPredicate. Value *expandWrapPredicate(const SCEVWrapPredicate *P, Instruction *Loc); /// \brief A specialized variant of expandCodeForPredicate, handling the /// case when we are expanding code for a SCEVUnionPredicate. Value *expandUnionPredicate(const SCEVUnionPredicate *Pred, Instruction *Loc); /// \brief Set the current IV increment loop and position. void setIVIncInsertPos(const Loop *L, Instruction *Pos) { assert(!CanonicalMode && "IV increment positions are not supported in CanonicalMode"); IVIncInsertLoop = L; IVIncInsertPos = Pos; } /// \brief Enable post-inc expansion for addrecs referring to the given /// loops. Post-inc expansion is only supported in non-canonical mode. void setPostInc(const PostIncLoopSet &L) { assert(!CanonicalMode && "Post-inc expansion is not supported in CanonicalMode"); PostIncLoops = L; } /// \brief Disable all post-inc expansion. void clearPostInc() { PostIncLoops.clear(); // When we change the post-inc loop set, cached expansions may no // longer be valid. InsertedPostIncValues.clear(); } /// \brief Disable the behavior of expanding expressions in canonical form /// rather than in a more literal form. Non-canonical mode is useful for /// late optimization passes. void disableCanonicalMode() { CanonicalMode = false; } void enableLSRMode() { LSRMode = true; } /// \brief Set the current insertion point. This is useful if multiple calls /// to expandCodeFor() are going to be made with the same insert point and /// the insert point may be moved during one of the expansions (e.g. if the /// insert point is not a block terminator). void setInsertPoint(Instruction *IP) { assert(IP); Builder.SetInsertPoint(IP); } /// \brief Clear the current insertion point. This is useful if the /// instruction that had been serving as the insertion point may have been /// deleted. void clearInsertPoint() { Builder.ClearInsertionPoint(); } /// \brief Return true if the specified instruction was inserted by the code /// rewriter. If so, the client should not modify the instruction. bool isInsertedInstruction(Instruction *I) const { return InsertedValues.count(I) || InsertedPostIncValues.count(I); } void setChainedPhi(PHINode *PN) { ChainedPhis.insert(PN); } - /// \brief Try to find LLVM IR value for S available at the point At. + /// Try to find existing LLVM IR value for S available at the point At. + Value *getExactExistingExpansion(const SCEV *S, const Instruction *At, + Loop *L); + + /// Try to find the ValueOffsetPair for S. The function is mainly + /// used to check whether S can be expanded cheaply. + /// If this returns a non-None value, we know we can codegen the + /// `ValueOffsetPair` into a suitable expansion identical with S + /// so that S can be expanded cheaply. /// /// L is a hint which tells in which loop to look for the suitable value. /// On success return value which is equivalent to the expanded S at point /// At. Return nullptr if value was not found. /// /// Note that this function does not perform an exhaustive search. I.e if it /// didn't find any value it does not mean that there is no such value. - Value *findExistingExpansion(const SCEV *S, const Instruction *At, Loop *L); + /// + Optional + getRelatedExistingExpansion(const SCEV *S, const Instruction *At, Loop *L); private: LLVMContext &getContext() const { return SE.getContext(); } /// \brief Recursive helper function for isHighCostExpansion. bool isHighCostExpansionHelper(const SCEV *S, Loop *L, const Instruction *At, SmallPtrSetImpl &Processed); /// \brief Insert the specified binary operator, doing a small amount /// of work to avoid inserting an obviously redundant operation. Value *InsertBinop(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS); /// \brief Arrange for there to be a cast of V to Ty at IP, reusing an /// existing cast if a suitable one exists, moving an existing cast if a /// suitable one exists but isn't in the right place, or or creating a new /// one. Value *ReuseOrCreateCast(Value *V, Type *Ty, Instruction::CastOps Op, BasicBlock::iterator IP); /// \brief Insert a cast of V to the specified type, which must be possible /// with a noop cast, doing what we can to share the casts. Value *InsertNoopCastOfTo(Value *V, Type *Ty); /// \brief Expand a SCEVAddExpr with a pointer type into a GEP /// instead of using ptrtoint+arithmetic+inttoptr. Value *expandAddToGEP(const SCEV *const *op_begin, const SCEV *const *op_end, PointerType *PTy, Type *Ty, Value *V); /// \brief Find a previous Value in ExprValueMap for expand. - Value *FindValueInExprValueMap(const SCEV *S, const Instruction *InsertPt); + ScalarEvolution::ValueOffsetPair + FindValueInExprValueMap(const SCEV *S, const Instruction *InsertPt); Value *expand(const SCEV *S); /// \brief Determine the most "relevant" loop for the given SCEV. const Loop *getRelevantLoop(const SCEV *); Value *visitConstant(const SCEVConstant *S) { return S->getValue(); } Value *visitTruncateExpr(const SCEVTruncateExpr *S); Value *visitZeroExtendExpr(const SCEVZeroExtendExpr *S); Value *visitSignExtendExpr(const SCEVSignExtendExpr *S); Value *visitAddExpr(const SCEVAddExpr *S); Value *visitMulExpr(const SCEVMulExpr *S); Value *visitUDivExpr(const SCEVUDivExpr *S); Value *visitAddRecExpr(const SCEVAddRecExpr *S); Value *visitSMaxExpr(const SCEVSMaxExpr *S); Value *visitUMaxExpr(const SCEVUMaxExpr *S); Value *visitUnknown(const SCEVUnknown *S) { return S->getValue(); } void rememberInstruction(Value *I); bool isNormalAddRecExprPHI(PHINode *PN, Instruction *IncV, const Loop *L); bool isExpandedAddRecExprPHI(PHINode *PN, Instruction *IncV, const Loop *L); Value *expandAddRecExprLiterally(const SCEVAddRecExpr *); PHINode *getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, const Loop *L, Type *ExpandTy, Type *IntTy, Type *&TruncTy, bool &InvertStep); Value *expandIVInc(PHINode *PN, Value *StepV, const Loop *L, Type *ExpandTy, Type *IntTy, bool useSubtract); void hoistBeforePos(DominatorTree *DT, Instruction *InstToHoist, Instruction *Pos, PHINode *LoopPhi); void fixupInsertPoints(Instruction *I); }; } #endif Index: head/contrib/llvm/lib/Analysis/ScalarEvolution.cpp =================================================================== --- head/contrib/llvm/lib/Analysis/ScalarEvolution.cpp (revision 312831) +++ head/contrib/llvm/lib/Analysis/ScalarEvolution.cpp (revision 312832) @@ -1,10467 +1,10509 @@ //===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===// // // The LLVM Compiler Infrastructure // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// // // This file contains the implementation of the scalar evolution analysis // engine, which is used primarily to analyze expressions involving induction // variables in loops. // // There are several aspects to this library. First is the representation of // scalar expressions, which are represented as subclasses of the SCEV class. // These classes are used to represent certain types of subexpressions that we // can handle. We only create one SCEV of a particular shape, so // pointer-comparisons for equality are legal. // // One important aspect of the SCEV objects is that they are never cyclic, even // if there is a cycle in the dataflow for an expression (ie, a PHI node). If // the PHI node is one of the idioms that we can represent (e.g., a polynomial // recurrence) then we represent it directly as a recurrence node, otherwise we // represent it as a SCEVUnknown node. // // In addition to being able to represent expressions of various types, we also // have folders that are used to build the *canonical* representation for a // particular expression. These folders are capable of using a variety of // rewrite rules to simplify the expressions. // // Once the folders are defined, we can implement the more interesting // higher-level code, such as the code that recognizes PHI nodes of various // types, computes the execution count of a loop, etc. // // TODO: We should use these routines and value representations to implement // dependence analysis! // //===----------------------------------------------------------------------===// // // There are several good references for the techniques used in this analysis. // // Chains of recurrences -- a method to expedite the evaluation // of closed-form functions // Olaf Bachmann, Paul S. Wang, Eugene V. Zima // // On computational properties of chains of recurrences // Eugene V. Zima // // Symbolic Evaluation of Chains of Recurrences for Loop Optimization // Robert A. van Engelen // // Efficient Symbolic Analysis for Optimizing Compilers // Robert A. van Engelen // // Using the chains of recurrences algebra for data dependence testing and // induction variable substitution // MS Thesis, Johnie Birch // //===----------------------------------------------------------------------===// #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Support/SaveAndRestore.h" #include using namespace llvm; #define DEBUG_TYPE "scalar-evolution" STATISTIC(NumArrayLenItCounts, "Number of trip counts computed with array length"); STATISTIC(NumTripCountsComputed, "Number of loops with predictable loop counts"); STATISTIC(NumTripCountsNotComputed, "Number of loops without predictable loop counts"); STATISTIC(NumBruteForceTripCountsComputed, "Number of loops with trip counts computed by force"); static cl::opt MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100)); // FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean. static cl::opt VerifySCEV("verify-scev", cl::desc("Verify ScalarEvolution's backedge taken counts (slow)")); static cl::opt VerifySCEVMap("verify-scev-maps", cl::desc("Verify no dangling value in ScalarEvolution's " "ExprValueMap (slow)")); //===----------------------------------------------------------------------===// // SCEV class definitions //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // Implementation of the SCEV class. // LLVM_DUMP_METHOD void SCEV::dump() const { print(dbgs()); dbgs() << '\n'; } void SCEV::print(raw_ostream &OS) const { switch (static_cast(getSCEVType())) { case scConstant: cast(this)->getValue()->printAsOperand(OS, false); return; case scTruncate: { const SCEVTruncateExpr *Trunc = cast(this); const SCEV *Op = Trunc->getOperand(); OS << "(trunc " << *Op->getType() << " " << *Op << " to " << *Trunc->getType() << ")"; return; } case scZeroExtend: { const SCEVZeroExtendExpr *ZExt = cast(this); const SCEV *Op = ZExt->getOperand(); OS << "(zext " << *Op->getType() << " " << *Op << " to " << *ZExt->getType() << ")"; return; } case scSignExtend: { const SCEVSignExtendExpr *SExt = cast(this); const SCEV *Op = SExt->getOperand(); OS << "(sext " << *Op->getType() << " " << *Op << " to " << *SExt->getType() << ")"; return; } case scAddRecExpr: { const SCEVAddRecExpr *AR = cast(this); OS << "{" << *AR->getOperand(0); for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i) OS << ",+," << *AR->getOperand(i); OS << "}<"; if (AR->hasNoUnsignedWrap()) OS << "nuw><"; if (AR->hasNoSignedWrap()) OS << "nsw><"; if (AR->hasNoSelfWrap() && !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW))) OS << "nw><"; AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false); OS << ">"; return; } case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr: { const SCEVNAryExpr *NAry = cast(this); const char *OpStr = nullptr; switch (NAry->getSCEVType()) { case scAddExpr: OpStr = " + "; break; case scMulExpr: OpStr = " * "; break; case scUMaxExpr: OpStr = " umax "; break; case scSMaxExpr: OpStr = " smax "; break; } OS << "("; for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); I != E; ++I) { OS << **I; if (std::next(I) != E) OS << OpStr; } OS << ")"; switch (NAry->getSCEVType()) { case scAddExpr: case scMulExpr: if (NAry->hasNoUnsignedWrap()) OS << ""; if (NAry->hasNoSignedWrap()) OS << ""; } return; } case scUDivExpr: { const SCEVUDivExpr *UDiv = cast(this); OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")"; return; } case scUnknown: { const SCEVUnknown *U = cast(this); Type *AllocTy; if (U->isSizeOf(AllocTy)) { OS << "sizeof(" << *AllocTy << ")"; return; } if (U->isAlignOf(AllocTy)) { OS << "alignof(" << *AllocTy << ")"; return; } Type *CTy; Constant *FieldNo; if (U->isOffsetOf(CTy, FieldNo)) { OS << "offsetof(" << *CTy << ", "; FieldNo->printAsOperand(OS, false); OS << ")"; return; } // Otherwise just print it normally. U->getValue()->printAsOperand(OS, false); return; } case scCouldNotCompute: OS << "***COULDNOTCOMPUTE***"; return; } llvm_unreachable("Unknown SCEV kind!"); } Type *SCEV::getType() const { switch (static_cast(getSCEVType())) { case scConstant: return cast(this)->getType(); case scTruncate: case scZeroExtend: case scSignExtend: return cast(this)->getType(); case scAddRecExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr: return cast(this)->getType(); case scAddExpr: return cast(this)->getType(); case scUDivExpr: return cast(this)->getType(); case scUnknown: return cast(this)->getType(); case scCouldNotCompute: llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); } llvm_unreachable("Unknown SCEV kind!"); } bool SCEV::isZero() const { if (const SCEVConstant *SC = dyn_cast(this)) return SC->getValue()->isZero(); return false; } bool SCEV::isOne() const { if (const SCEVConstant *SC = dyn_cast(this)) return SC->getValue()->isOne(); return false; } bool SCEV::isAllOnesValue() const { if (const SCEVConstant *SC = dyn_cast(this)) return SC->getValue()->isAllOnesValue(); return false; } bool SCEV::isNonConstantNegative() const { const SCEVMulExpr *Mul = dyn_cast(this); if (!Mul) return false; // If there is a constant factor, it will be first. const SCEVConstant *SC = dyn_cast(Mul->getOperand(0)); if (!SC) return false; // Return true if the value is negative, this matches things like (-42 * V). return SC->getAPInt().isNegative(); } SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {} bool SCEVCouldNotCompute::classof(const SCEV *S) { return S->getSCEVType() == scCouldNotCompute; } const SCEV *ScalarEvolution::getConstant(ConstantInt *V) { FoldingSetNodeID ID; ID.AddInteger(scConstant); ID.AddPointer(V); void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V); UniqueSCEVs.InsertNode(S, IP); return S; } const SCEV *ScalarEvolution::getConstant(const APInt &Val) { return getConstant(ConstantInt::get(getContext(), Val)); } const SCEV * ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { IntegerType *ITy = cast(getEffectiveSCEVType(Ty)); return getConstant(ConstantInt::get(ITy, V, isSigned)); } SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, unsigned SCEVTy, const SCEV *op, Type *ty) : SCEV(ID, SCEVTy), Op(op), Ty(ty) {} SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty) : SCEVCastExpr(ID, scTruncate, op, ty) { assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot truncate non-integer value!"); } SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty) : SCEVCastExpr(ID, scZeroExtend, op, ty) { assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot zero extend non-integer value!"); } SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty) : SCEVCastExpr(ID, scSignExtend, op, ty) { assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot sign extend non-integer value!"); } void SCEVUnknown::deleted() { // Clear this SCEVUnknown from various maps. SE->forgetMemoizedResults(this); // Remove this SCEVUnknown from the uniquing map. SE->UniqueSCEVs.RemoveNode(this); // Release the value. setValPtr(nullptr); } void SCEVUnknown::allUsesReplacedWith(Value *New) { // Clear this SCEVUnknown from various maps. SE->forgetMemoizedResults(this); // Remove this SCEVUnknown from the uniquing map. SE->UniqueSCEVs.RemoveNode(this); // Update this SCEVUnknown to point to the new value. This is needed // because there may still be outstanding SCEVs which still point to // this SCEVUnknown. setValPtr(New); } bool SCEVUnknown::isSizeOf(Type *&AllocTy) const { if (ConstantExpr *VCE = dyn_cast(getValue())) if (VCE->getOpcode() == Instruction::PtrToInt) if (ConstantExpr *CE = dyn_cast(VCE->getOperand(0))) if (CE->getOpcode() == Instruction::GetElementPtr && CE->getOperand(0)->isNullValue() && CE->getNumOperands() == 2) if (ConstantInt *CI = dyn_cast(CE->getOperand(1))) if (CI->isOne()) { AllocTy = cast(CE->getOperand(0)->getType()) ->getElementType(); return true; } return false; } bool SCEVUnknown::isAlignOf(Type *&AllocTy) const { if (ConstantExpr *VCE = dyn_cast(getValue())) if (VCE->getOpcode() == Instruction::PtrToInt) if (ConstantExpr *CE = dyn_cast(VCE->getOperand(0))) if (CE->getOpcode() == Instruction::GetElementPtr && CE->getOperand(0)->isNullValue()) { Type *Ty = cast(CE->getOperand(0)->getType())->getElementType(); if (StructType *STy = dyn_cast(Ty)) if (!STy->isPacked() && CE->getNumOperands() == 3 && CE->getOperand(1)->isNullValue()) { if (ConstantInt *CI = dyn_cast(CE->getOperand(2))) if (CI->isOne() && STy->getNumElements() == 2 && STy->getElementType(0)->isIntegerTy(1)) { AllocTy = STy->getElementType(1); return true; } } } return false; } bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const { if (ConstantExpr *VCE = dyn_cast(getValue())) if (VCE->getOpcode() == Instruction::PtrToInt) if (ConstantExpr *CE = dyn_cast(VCE->getOperand(0))) if (CE->getOpcode() == Instruction::GetElementPtr && CE->getNumOperands() == 3 && CE->getOperand(0)->isNullValue() && CE->getOperand(1)->isNullValue()) { Type *Ty = cast(CE->getOperand(0)->getType())->getElementType(); // Ignore vector types here so that ScalarEvolutionExpander doesn't // emit getelementptrs that index into vectors. if (Ty->isStructTy() || Ty->isArrayTy()) { CTy = Ty; FieldNo = CE->getOperand(2); return true; } } return false; } //===----------------------------------------------------------------------===// // SCEV Utilities //===----------------------------------------------------------------------===// namespace { /// SCEVComplexityCompare - Return true if the complexity of the LHS is less /// than the complexity of the RHS. This comparator is used to canonicalize /// expressions. class SCEVComplexityCompare { const LoopInfo *const LI; public: explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {} // Return true or false if LHS is less than, or at least RHS, respectively. bool operator()(const SCEV *LHS, const SCEV *RHS) const { return compare(LHS, RHS) < 0; } // Return negative, zero, or positive, if LHS is less than, equal to, or // greater than RHS, respectively. A three-way result allows recursive // comparisons to be more efficient. int compare(const SCEV *LHS, const SCEV *RHS) const { // Fast-path: SCEVs are uniqued so we can do a quick equality check. if (LHS == RHS) return 0; // Primarily, sort the SCEVs by their getSCEVType(). unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); if (LType != RType) return (int)LType - (int)RType; // Aside from the getSCEVType() ordering, the particular ordering // isn't very important except that it's beneficial to be consistent, // so that (a + b) and (b + a) don't end up as different expressions. switch (static_cast(LType)) { case scUnknown: { const SCEVUnknown *LU = cast(LHS); const SCEVUnknown *RU = cast(RHS); // Sort SCEVUnknown values with some loose heuristics. TODO: This is // not as complete as it could be. const Value *LV = LU->getValue(), *RV = RU->getValue(); // Order pointer values after integer values. This helps SCEVExpander // form GEPs. bool LIsPointer = LV->getType()->isPointerTy(), RIsPointer = RV->getType()->isPointerTy(); if (LIsPointer != RIsPointer) return (int)LIsPointer - (int)RIsPointer; // Compare getValueID values. unsigned LID = LV->getValueID(), RID = RV->getValueID(); if (LID != RID) return (int)LID - (int)RID; // Sort arguments by their position. if (const Argument *LA = dyn_cast(LV)) { const Argument *RA = cast(RV); unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); return (int)LArgNo - (int)RArgNo; } // For instructions, compare their loop depth, and their operand // count. This is pretty loose. if (const Instruction *LInst = dyn_cast(LV)) { const Instruction *RInst = cast(RV); // Compare loop depths. const BasicBlock *LParent = LInst->getParent(), *RParent = RInst->getParent(); if (LParent != RParent) { unsigned LDepth = LI->getLoopDepth(LParent), RDepth = LI->getLoopDepth(RParent); if (LDepth != RDepth) return (int)LDepth - (int)RDepth; } // Compare the number of operands. unsigned LNumOps = LInst->getNumOperands(), RNumOps = RInst->getNumOperands(); return (int)LNumOps - (int)RNumOps; } return 0; } case scConstant: { const SCEVConstant *LC = cast(LHS); const SCEVConstant *RC = cast(RHS); // Compare constant values. const APInt &LA = LC->getAPInt(); const APInt &RA = RC->getAPInt(); unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); if (LBitWidth != RBitWidth) return (int)LBitWidth - (int)RBitWidth; return LA.ult(RA) ? -1 : 1; } case scAddRecExpr: { const SCEVAddRecExpr *LA = cast(LHS); const SCEVAddRecExpr *RA = cast(RHS); // Compare addrec loop depths. const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); if (LLoop != RLoop) { unsigned LDepth = LLoop->getLoopDepth(), RDepth = RLoop->getLoopDepth(); if (LDepth != RDepth) return (int)LDepth - (int)RDepth; } // Addrec complexity grows with operand count. unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); if (LNumOps != RNumOps) return (int)LNumOps - (int)RNumOps; // Lexicographically compare. for (unsigned i = 0; i != LNumOps; ++i) { long X = compare(LA->getOperand(i), RA->getOperand(i)); if (X != 0) return X; } return 0; } case scAddExpr: case scMulExpr: case scSMaxExpr: case scUMaxExpr: { const SCEVNAryExpr *LC = cast(LHS); const SCEVNAryExpr *RC = cast(RHS); // Lexicographically compare n-ary expressions. unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); if (LNumOps != RNumOps) return (int)LNumOps - (int)RNumOps; for (unsigned i = 0; i != LNumOps; ++i) { if (i >= RNumOps) return 1; long X = compare(LC->getOperand(i), RC->getOperand(i)); if (X != 0) return X; } return (int)LNumOps - (int)RNumOps; } case scUDivExpr: { const SCEVUDivExpr *LC = cast(LHS); const SCEVUDivExpr *RC = cast(RHS); // Lexicographically compare udiv expressions. long X = compare(LC->getLHS(), RC->getLHS()); if (X != 0) return X; return compare(LC->getRHS(), RC->getRHS()); } case scTruncate: case scZeroExtend: case scSignExtend: { const SCEVCastExpr *LC = cast(LHS); const SCEVCastExpr *RC = cast(RHS); // Compare cast expressions by operand. return compare(LC->getOperand(), RC->getOperand()); } case scCouldNotCompute: llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); } llvm_unreachable("Unknown SCEV kind!"); } }; } // end anonymous namespace /// Given a list of SCEV objects, order them by their complexity, and group /// objects of the same complexity together by value. When this routine is /// finished, we know that any duplicates in the vector are consecutive and that /// complexity is monotonically increasing. /// /// Note that we go take special precautions to ensure that we get deterministic /// results from this routine. In other words, we don't want the results of /// this to depend on where the addresses of various SCEV objects happened to /// land in memory. /// static void GroupByComplexity(SmallVectorImpl &Ops, LoopInfo *LI) { if (Ops.size() < 2) return; // Noop if (Ops.size() == 2) { // This is the common case, which also happens to be trivially simple. // Special case it. const SCEV *&LHS = Ops[0], *&RHS = Ops[1]; if (SCEVComplexityCompare(LI)(RHS, LHS)) std::swap(LHS, RHS); return; } // Do the rough sort by complexity. std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI)); // Now that we are sorted by complexity, group elements of the same // complexity. Note that this is, at worst, N^2, but the vector is likely to // be extremely short in practice. Note that we take this approach because we // do not want to depend on the addresses of the objects we are grouping. for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) { const SCEV *S = Ops[i]; unsigned Complexity = S->getSCEVType(); // If there are any objects of the same complexity and same value as this // one, group them. for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) { if (Ops[j] == S) { // Found a duplicate. // Move it to immediately after i'th element. std::swap(Ops[i+1], Ops[j]); ++i; // no need to rescan it. if (i == e-2) return; // Done! } } } } // Returns the size of the SCEV S. static inline int sizeOfSCEV(const SCEV *S) { struct FindSCEVSize { int Size; FindSCEVSize() : Size(0) {} bool follow(const SCEV *S) { ++Size; // Keep looking at all operands of S. return true; } bool isDone() const { return false; } }; FindSCEVSize F; SCEVTraversal ST(F); ST.visitAll(S); return F.Size; } namespace { struct SCEVDivision : public SCEVVisitor { public: // Computes the Quotient and Remainder of the division of Numerator by // Denominator. static void divide(ScalarEvolution &SE, const SCEV *Numerator, const SCEV *Denominator, const SCEV **Quotient, const SCEV **Remainder) { assert(Numerator && Denominator && "Uninitialized SCEV"); SCEVDivision D(SE, Numerator, Denominator); // Check for the trivial case here to avoid having to check for it in the // rest of the code. if (Numerator == Denominator) { *Quotient = D.One; *Remainder = D.Zero; return; } if (Numerator->isZero()) { *Quotient = D.Zero; *Remainder = D.Zero; return; } // A simple case when N/1. The quotient is N. if (Denominator->isOne()) { *Quotient = Numerator; *Remainder = D.Zero; return; } // Split the Denominator when it is a product. if (const SCEVMulExpr *T = dyn_cast(Denominator)) { const SCEV *Q, *R; *Quotient = Numerator; for (const SCEV *Op : T->operands()) { divide(SE, *Quotient, Op, &Q, &R); *Quotient = Q; // Bail out when the Numerator is not divisible by one of the terms of // the Denominator. if (!R->isZero()) { *Quotient = D.Zero; *Remainder = Numerator; return; } } *Remainder = D.Zero; return; } D.visit(Numerator); *Quotient = D.Quotient; *Remainder = D.Remainder; } // Except in the trivial case described above, we do not know how to divide // Expr by Denominator for the following functions with empty implementation. void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} void visitUDivExpr(const SCEVUDivExpr *Numerator) {} void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} void visitUnknown(const SCEVUnknown *Numerator) {} void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} void visitConstant(const SCEVConstant *Numerator) { if (const SCEVConstant *D = dyn_cast(Denominator)) { APInt NumeratorVal = Numerator->getAPInt(); APInt DenominatorVal = D->getAPInt(); uint32_t NumeratorBW = NumeratorVal.getBitWidth(); uint32_t DenominatorBW = DenominatorVal.getBitWidth(); if (NumeratorBW > DenominatorBW) DenominatorVal = DenominatorVal.sext(NumeratorBW); else if (NumeratorBW < DenominatorBW) NumeratorVal = NumeratorVal.sext(DenominatorBW); APInt QuotientVal(NumeratorVal.getBitWidth(), 0); APInt RemainderVal(NumeratorVal.getBitWidth(), 0); APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); Quotient = SE.getConstant(QuotientVal); Remainder = SE.getConstant(RemainderVal); return; } } void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { const SCEV *StartQ, *StartR, *StepQ, *StepR; if (!Numerator->isAffine()) return cannotDivide(Numerator); divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); // Bail out if the types do not match. Type *Ty = Denominator->getType(); if (Ty != StartQ->getType() || Ty != StartR->getType() || Ty != StepQ->getType() || Ty != StepR->getType()) return cannotDivide(Numerator); Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), Numerator->getNoWrapFlags()); Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), Numerator->getNoWrapFlags()); } void visitAddExpr(const SCEVAddExpr *Numerator) { SmallVector Qs, Rs; Type *Ty = Denominator->getType(); for (const SCEV *Op : Numerator->operands()) { const SCEV *Q, *R; divide(SE, Op, Denominator, &Q, &R); // Bail out if types do not match. if (Ty != Q->getType() || Ty != R->getType()) return cannotDivide(Numerator); Qs.push_back(Q); Rs.push_back(R); } if (Qs.size() == 1) { Quotient = Qs[0]; Remainder = Rs[0]; return; } Quotient = SE.getAddExpr(Qs); Remainder = SE.getAddExpr(Rs); } void visitMulExpr(const SCEVMulExpr *Numerator) { SmallVector Qs; Type *Ty = Denominator->getType(); bool FoundDenominatorTerm = false; for (const SCEV *Op : Numerator->operands()) { // Bail out if types do not match. if (Ty != Op->getType()) return cannotDivide(Numerator); if (FoundDenominatorTerm) { Qs.push_back(Op); continue; } // Check whether Denominator divides one of the product operands. const SCEV *Q, *R; divide(SE, Op, Denominator, &Q, &R); if (!R->isZero()) { Qs.push_back(Op); continue; } // Bail out if types do not match. if (Ty != Q->getType()) return cannotDivide(Numerator); FoundDenominatorTerm = true; Qs.push_back(Q); } if (FoundDenominatorTerm) { Remainder = Zero; if (Qs.size() == 1) Quotient = Qs[0]; else Quotient = SE.getMulExpr(Qs); return; } if (!isa(Denominator)) return cannotDivide(Numerator); // The Remainder is obtained by replacing Denominator by 0 in Numerator. ValueToValueMap RewriteMap; RewriteMap[cast(Denominator)->getValue()] = cast(Zero)->getValue(); Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); if (Remainder->isZero()) { // The Quotient is obtained by replacing Denominator by 1 in Numerator. RewriteMap[cast(Denominator)->getValue()] = cast(One)->getValue(); Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); return; } // Quotient is (Numerator - Remainder) divided by Denominator. const SCEV *Q, *R; const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); // This SCEV does not seem to simplify: fail the division here. if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) return cannotDivide(Numerator); divide(SE, Diff, Denominator, &Q, &R); if (R != Zero) return cannotDivide(Numerator); Quotient = Q; } private: SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, const SCEV *Denominator) : SE(S), Denominator(Denominator) { Zero = SE.getZero(Denominator->getType()); One = SE.getOne(Denominator->getType()); // We generally do not know how to divide Expr by Denominator. We // initialize the division to a "cannot divide" state to simplify the rest // of the code. cannotDivide(Numerator); } // Convenience function for giving up on the division. We set the quotient to // be equal to zero and the remainder to be equal to the numerator. void cannotDivide(const SCEV *Numerator) { Quotient = Zero; Remainder = Numerator; } ScalarEvolution &SE; const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; }; } //===----------------------------------------------------------------------===// // Simple SCEV method implementations //===----------------------------------------------------------------------===// /// Compute BC(It, K). The result has width W. Assume, K > 0. static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy) { // Handle the simplest case efficiently. if (K == 1) return SE.getTruncateOrZeroExtend(It, ResultTy); // We are using the following formula for BC(It, K): // // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K! // // Suppose, W is the bitwidth of the return value. We must be prepared for // overflow. Hence, we must assure that the result of our computation is // equal to the accurate one modulo 2^W. Unfortunately, division isn't // safe in modular arithmetic. // // However, this code doesn't use exactly that formula; the formula it uses // is something like the following, where T is the number of factors of 2 in // K! (i.e. trailing zeros in the binary representation of K!), and ^ is // exponentiation: // // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T) // // This formula is trivially equivalent to the previous formula. However, // this formula can be implemented much more efficiently. The trick is that // K! / 2^T is odd, and exact division by an odd number *is* safe in modular // arithmetic. To do exact division in modular arithmetic, all we have // to do is multiply by the inverse. Therefore, this step can be done at // width W. // // The next issue is how to safely do the division by 2^T. The way this // is done is by doing the multiplication step at a width of at least W + T // bits. This way, the bottom W+T bits of the product are accurate. Then, // when we perform the division by 2^T (which is equivalent to a right shift // by T), the bottom W bits are accurate. Extra bits are okay; they'll get // truncated out after the division by 2^T. // // In comparison to just directly using the first formula, this technique // is much more efficient; using the first formula requires W * K bits, // but this formula less than W + K bits. Also, the first formula requires // a division step, whereas this formula only requires multiplies and shifts. // // It doesn't matter whether the subtraction step is done in the calculation // width or the input iteration count's width; if the subtraction overflows, // the result must be zero anyway. We prefer here to do it in the width of // the induction variable because it helps a lot for certain cases; CodeGen // isn't smart enough to ignore the overflow, which leads to much less // efficient code if the width of the subtraction is wider than the native // register width. // // (It's possible to not widen at all by pulling out factors of 2 before // the multiplication; for example, K=2 can be calculated as // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires // extra arithmetic, so it's not an obvious win, and it gets // much more complicated for K > 3.) // Protection from insane SCEVs; this bound is conservative, // but it probably doesn't matter. if (K > 1000) return SE.getCouldNotCompute(); unsigned W = SE.getTypeSizeInBits(ResultTy); // Calculate K! / 2^T and T; we divide out the factors of two before // multiplying for calculating K! / 2^T to avoid overflow. // Other overflow doesn't matter because we only care about the bottom // W bits of the result. APInt OddFactorial(W, 1); unsigned T = 1; for (unsigned i = 3; i <= K; ++i) { APInt Mult(W, i); unsigned TwoFactors = Mult.countTrailingZeros(); T += TwoFactors; Mult = Mult.lshr(TwoFactors); OddFactorial *= Mult; } // We need at least W + T bits for the multiplication step unsigned CalculationBits = W + T; // Calculate 2^T, at width T+W. APInt DivFactor = APInt::getOneBitSet(CalculationBits, T); // Calculate the multiplicative inverse of K! / 2^T; // this multiplication factor will perform the exact division by // K! / 2^T. APInt Mod = APInt::getSignedMinValue(W+1); APInt MultiplyFactor = OddFactorial.zext(W+1); MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod); MultiplyFactor = MultiplyFactor.trunc(W); // Calculate the product, at width T+W IntegerType *CalculationTy = IntegerType::get(SE.getContext(), CalculationBits); const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); for (unsigned i = 1; i != K; ++i) { const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i)); Dividend = SE.getMulExpr(Dividend, SE.getTruncateOrZeroExtend(S, CalculationTy)); } // Divide by 2^T const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); // Truncate the result, and divide by K! / 2^T. return SE.getMulExpr(SE.getConstant(MultiplyFactor), SE.getTruncateOrZeroExtend(DivResult, ResultTy)); } /// Return the value of this chain of recurrences at the specified iteration /// number. We can evaluate this recurrence by multiplying each element in the /// chain by the binomial coefficient corresponding to it. In other words, we /// can evaluate {A,+,B,+,C,+,D} as: /// /// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3) /// /// where BC(It, k) stands for binomial coefficient. /// const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const { const SCEV *Result = getStart(); for (unsigned i = 1, e = getNumOperands(); i != e; ++i) { // The computation is correct in the face of overflow provided that the // multiplication is performed _after_ the evaluation of the binomial // coefficient. const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType()); if (isa(Coeff)) return Coeff; Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff)); } return Result; } //===----------------------------------------------------------------------===// // SCEV Expression folder implementations //===----------------------------------------------------------------------===// const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) && "This is not a truncating conversion!"); assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); Ty = getEffectiveSCEVType(Ty); FoldingSetNodeID ID; ID.AddInteger(scTruncate); ID.AddPointer(Op); ID.AddPointer(Ty); void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast(Op)) return getConstant( cast(ConstantExpr::getTrunc(SC->getValue(), Ty))); // trunc(trunc(x)) --> trunc(x) if (const SCEVTruncateExpr *ST = dyn_cast(Op)) return getTruncateExpr(ST->getOperand(), Ty); // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing if (const SCEVSignExtendExpr *SS = dyn_cast(Op)) return getTruncateOrSignExtend(SS->getOperand(), Ty); // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing if (const SCEVZeroExtendExpr *SZ = dyn_cast(Op)) return getTruncateOrZeroExtend(SZ->getOperand(), Ty); // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can // eliminate all the truncates, or we replace other casts with truncates. if (const SCEVAddExpr *SA = dyn_cast(Op)) { SmallVector Operands; bool hasTrunc = false; for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) { const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty); if (!isa(SA->getOperand(i))) hasTrunc = isa(S); Operands.push_back(S); } if (!hasTrunc) return getAddExpr(Operands); UniqueSCEVs.FindNodeOrInsertPos(ID, IP); // Mutates IP, returns NULL. } // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can // eliminate all the truncates, or we replace other casts with truncates. if (const SCEVMulExpr *SM = dyn_cast(Op)) { SmallVector Operands; bool hasTrunc = false; for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) { const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty); if (!isa(SM->getOperand(i))) hasTrunc = isa(S); Operands.push_back(S); } if (!hasTrunc) return getMulExpr(Operands); UniqueSCEVs.FindNodeOrInsertPos(ID, IP); // Mutates IP, returns NULL. } // If the input value is a chrec scev, truncate the chrec's operands. if (const SCEVAddRecExpr *AddRec = dyn_cast(Op)) { SmallVector Operands; for (const SCEV *Op : AddRec->operands()) Operands.push_back(getTruncateExpr(Op, Ty)); return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap); } // The cast wasn't folded; create an explicit cast node. We can reuse // the existing insert position since if we get here, we won't have // made any changes which would invalidate it. SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); return S; } // Get the limit of a recurrence such that incrementing by Step cannot cause // signed overflow as long as the value of the recurrence within the // loop does not exceed this limit before incrementing. static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE) { unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); if (SE->isKnownPositive(Step)) { *Pred = ICmpInst::ICMP_SLT; return SE->getConstant(APInt::getSignedMinValue(BitWidth) - SE->getSignedRange(Step).getSignedMax()); } if (SE->isKnownNegative(Step)) { *Pred = ICmpInst::ICMP_SGT; return SE->getConstant(APInt::getSignedMaxValue(BitWidth) - SE->getSignedRange(Step).getSignedMin()); } return nullptr; } // Get the limit of a recurrence such that incrementing by Step cannot cause // unsigned overflow as long as the value of the recurrence within the loop does // not exceed this limit before incrementing. static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE) { unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); *Pred = ICmpInst::ICMP_ULT; return SE->getConstant(APInt::getMinValue(BitWidth) - SE->getUnsignedRange(Step).getUnsignedMax()); } namespace { struct ExtendOpTraitsBase { typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *); }; // Used to make code generic over signed and unsigned overflow. template struct ExtendOpTraits { // Members present: // // static const SCEV::NoWrapFlags WrapType; // // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr; // // static const SCEV *getOverflowLimitForStep(const SCEV *Step, // ICmpInst::Predicate *Pred, // ScalarEvolution *SE); }; template <> struct ExtendOpTraits : public ExtendOpTraitsBase { static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW; static const GetExtendExprTy GetExtendExpr; static const SCEV *getOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE) { return getSignedOverflowLimitForStep(Step, Pred, SE); } }; const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr; template <> struct ExtendOpTraits : public ExtendOpTraitsBase { static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW; static const GetExtendExprTy GetExtendExpr; static const SCEV *getOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE) { return getUnsignedOverflowLimitForStep(Step, Pred, SE); } }; const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr; } // The recurrence AR has been shown to have no signed/unsigned wrap or something // close to it. Typically, if we can prove NSW/NUW for AR, then we can just as // easily prove NSW/NUW for its preincrement or postincrement sibling. This // allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step + // Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the // expression "Step + sext/zext(PreIncAR)" is congruent with // "sext/zext(PostIncAR)" template static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE) { auto WrapType = ExtendOpTraits::WrapType; auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; const Loop *L = AR->getLoop(); const SCEV *Start = AR->getStart(); const SCEV *Step = AR->getStepRecurrence(*SE); // Check for a simple looking step prior to loop entry. const SCEVAddExpr *SA = dyn_cast(Start); if (!SA) return nullptr; // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV // subtraction is expensive. For this purpose, perform a quick and dirty // difference, by checking for Step in the operand list. SmallVector DiffOps; for (const SCEV *Op : SA->operands()) if (Op != Step) DiffOps.push_back(Op); if (DiffOps.size() == SA->getNumOperands()) return nullptr; // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` + // `Step`: // 1. NSW/NUW flags on the step increment. auto PreStartFlags = ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW); const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags); const SCEVAddRecExpr *PreAR = dyn_cast( SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); // "{S,+,X} is /" and "the backedge is taken at least once" implies // "S+X does not sign/unsign-overflow". // const SCEV *BECount = SE->getBackedgeTakenCount(L); if (PreAR && PreAR->getNoWrapFlags(WrapType) && !isa(BECount) && SE->isKnownPositive(BECount)) return PreStart; // 2. Direct overflow check on the step operation's expression. unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); const SCEV *OperandExtendedStart = SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy), (SE->*GetExtendExpr)(Step, WideTy)); if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) { if (PreAR && AR->getNoWrapFlags(WrapType)) { // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact. const_cast(PreAR)->setNoWrapFlags(WrapType); } return PreStart; } // 3. Loop precondition. ICmpInst::Predicate Pred; const SCEV *OverflowLimit = ExtendOpTraits::getOverflowLimitForStep(Step, &Pred, SE); if (OverflowLimit && SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) return PreStart; return nullptr; } // Get the normalized zero or sign extended expression for this AddRec's Start. template static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE) { auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; const SCEV *PreStart = getPreStartForExtend(AR, Ty, SE); if (!PreStart) return (SE->*GetExtendExpr)(AR->getStart(), Ty); return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty), (SE->*GetExtendExpr)(PreStart, Ty)); } // Try to prove away overflow by looking at "nearby" add recurrences. A // motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it // does not itself wrap then we can conclude that `{1,+,4}` is `nuw`. // // Formally: // // {S,+,X} == {S-T,+,X} + T // => Ext({S,+,X}) == Ext({S-T,+,X} + T) // // If ({S-T,+,X} + T) does not overflow ... (1) // // RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T) // // If {S-T,+,X} does not overflow ... (2) // // RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T) // == {Ext(S-T)+Ext(T),+,Ext(X)} // // If (S-T)+T does not overflow ... (3) // // RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)} // == {Ext(S),+,Ext(X)} == LHS // // Thus, if (1), (2) and (3) are true for some T, then // Ext({S,+,X}) == {Ext(S),+,Ext(X)} // // (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T) // does not overflow" restricted to the 0th iteration. Therefore we only need // to check for (1) and (2). // // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T // is `Delta` (defined below). // template bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, const SCEV *Step, const Loop *L) { auto WrapType = ExtendOpTraits::WrapType; // We restrict `Start` to a constant to prevent SCEV from spending too much // time here. It is correct (but more expensive) to continue with a // non-constant `Start` and do a general SCEV subtraction to compute // `PreStart` below. // const SCEVConstant *StartC = dyn_cast(Start); if (!StartC) return false; APInt StartAI = StartC->getAPInt(); for (unsigned Delta : {-2, -1, 1, 2}) { const SCEV *PreStart = getConstant(StartAI - Delta); FoldingSetNodeID ID; ID.AddInteger(scAddRecExpr); ID.AddPointer(PreStart); ID.AddPointer(Step); ID.AddPointer(L); void *IP = nullptr; const auto *PreAR = static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); // Give up if we don't already have the add recurrence we need because // actually constructing an add recurrence is relatively expensive. if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2) const SCEV *DeltaS = getConstant(StartC->getType(), Delta); ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; const SCEV *Limit = ExtendOpTraits::getOverflowLimitForStep( DeltaS, &Pred, this); if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1) return true; } } return false; } const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); Ty = getEffectiveSCEVType(Ty); // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast(Op)) return getConstant( cast(ConstantExpr::getZExt(SC->getValue(), Ty))); // zext(zext(x)) --> zext(x) if (const SCEVZeroExtendExpr *SZ = dyn_cast(Op)) return getZeroExtendExpr(SZ->getOperand(), Ty); // Before doing any expensive analysis, check to see if we've already // computed a SCEV for this Op and Ty. FoldingSetNodeID ID; ID.AddInteger(scZeroExtend); ID.AddPointer(Op); ID.AddPointer(Ty); void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; // zext(trunc(x)) --> zext(x) or x or trunc(x) if (const SCEVTruncateExpr *ST = dyn_cast(Op)) { // It's possible the bits taken off by the truncate were all zero bits. If // so, we should be able to simplify this further. const SCEV *X = ST->getOperand(); ConstantRange CR = getUnsignedRange(X); unsigned TruncBits = getTypeSizeInBits(ST->getType()); unsigned NewBits = getTypeSizeInBits(Ty); if (CR.truncate(TruncBits).zeroExtend(NewBits).contains( CR.zextOrTrunc(NewBits))) return getTruncateOrZeroExtend(X, Ty); } // If the input value is a chrec scev, and we can prove that the value // did not overflow the old, smaller, value, we can zero extend all of the // operands (often constants). This allows analysis of something like // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; } if (const SCEVAddRecExpr *AR = dyn_cast(Op)) if (AR->isAffine()) { const SCEV *Start = AR->getStart(); const SCEV *Step = AR->getStepRecurrence(*this); unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); if (!AR->hasNoUnsignedWrap()) { auto NewFlags = proveNoWrapViaConstantRanges(AR); const_cast(AR)->setNoWrapFlags(NewFlags); } // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. if (AR->hasNoUnsignedWrap()) return getAddRecExpr( getExtendAddRecStart(AR, Ty, this), getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are // simply not analyzable, and it covers the case where this code is // being called from within backedge-taken count analysis, such that // attempting to ask for the backedge-taken count would likely result // in infinite recursion. In the later case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. const SCEV *MaxBECount = getMaxBackedgeTakenCount(L); if (!isa(MaxBECount)) { // Manually compute the final value for AR, checking for // overflow. // Check whether the backedge-taken count can be losslessly casted to // the addrec's type. The count is always unsigned. const SCEV *CastedMaxBECount = getTruncateOrZeroExtend(MaxBECount, Start->getType()); const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); if (MaxBECount == RecastedMaxBECount) { Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no unsigned overflow. const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step); const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul), WideTy); const SCEV *WideStart = getZeroExtendExpr(Start, WideTy); const SCEV *WideMaxBECount = getZeroExtendExpr(CastedMaxBECount, WideTy); const SCEV *OperandExtendedAdd = getAddExpr(WideStart, getMulExpr(WideMaxBECount, getZeroExtendExpr(Step, WideTy))); if (ZAdd == OperandExtendedAdd) { // Cache knowledge of AR NUW, which is propagated to this AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. return getAddRecExpr( getExtendAddRecStart(AR, Ty, this), getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as signed. // This covers loops that count down. OperandExtendedAdd = getAddExpr(WideStart, getMulExpr(WideMaxBECount, getSignExtendExpr(Step, WideTy))); if (ZAdd == OperandExtendedAdd) { // Cache knowledge of AR NW, which is propagated to this AddRec. // Negative step causes unsigned wrap, but it still can't self-wrap. const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr( getExtendAddRecStart(AR, Ty, this), getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } } // Normally, in the cases we can prove no-overflow via a // backedge guarding condition, we can also compute a backedge // taken count for the loop. The exceptions are assumptions and // guards present in the loop -- SCEV is not great at exploiting // these to compute max backedge taken counts, but can still use // these to prove lack of overflow. Use this fact to avoid // doing extra work that may not pay off. if (!isa(MaxBECount) || HasGuards || !AC.assumptions().empty()) { // If the backedge is guarded by a comparison with the pre-inc // value the addrec is safe. Also, if the entry is guarded by // a comparison with the start value and the backedge is // guarded by a comparison with the post-inc value, the addrec // is safe. if (isKnownPositive(Step)) { const SCEV *N = getConstant(APInt::getMinValue(BitWidth) - getUnsignedRange(Step).getUnsignedMax()); if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) || (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) && isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR->getPostIncExpr(*this), N))) { // Cache knowledge of AR NUW, which is propagated to this // AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. return getAddRecExpr( getExtendAddRecStart(AR, Ty, this), getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } else if (isKnownNegative(Step)) { const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - getSignedRange(Step).getSignedMin()); if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) || (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) && isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR->getPostIncExpr(*this), N))) { // Cache knowledge of AR NW, which is propagated to this // AddRec. Negative step causes unsigned wrap, but it // still can't self-wrap. const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr( getExtendAddRecStart(AR, Ty, this), getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } } if (proveNoWrapByVaryingStart(Start, Step, L)) { const_cast(AR)->setNoWrapFlags(SCEV::FlagNUW); return getAddRecExpr( getExtendAddRecStart(AR, Ty, this), getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } if (auto *SA = dyn_cast(Op)) { // zext((A + B + ...)) --> (zext(A) + zext(B) + ...) if (SA->hasNoUnsignedWrap()) { // If the addition does not unsign overflow then we can, by definition, // commute the zero extension with the addition operation. SmallVector Ops; for (const auto *Op : SA->operands()) Ops.push_back(getZeroExtendExpr(Op, Ty)); return getAddExpr(Ops, SCEV::FlagNUW); } } // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); return S; } const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); Ty = getEffectiveSCEVType(Ty); // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast(Op)) return getConstant( cast(ConstantExpr::getSExt(SC->getValue(), Ty))); // sext(sext(x)) --> sext(x) if (const SCEVSignExtendExpr *SS = dyn_cast(Op)) return getSignExtendExpr(SS->getOperand(), Ty); // sext(zext(x)) --> zext(x) if (const SCEVZeroExtendExpr *SZ = dyn_cast(Op)) return getZeroExtendExpr(SZ->getOperand(), Ty); // Before doing any expensive analysis, check to see if we've already // computed a SCEV for this Op and Ty. FoldingSetNodeID ID; ID.AddInteger(scSignExtend); ID.AddPointer(Op); ID.AddPointer(Ty); void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; // sext(trunc(x)) --> sext(x) or x or trunc(x) if (const SCEVTruncateExpr *ST = dyn_cast(Op)) { // It's possible the bits taken off by the truncate were all sign bits. If // so, we should be able to simplify this further. const SCEV *X = ST->getOperand(); ConstantRange CR = getSignedRange(X); unsigned TruncBits = getTypeSizeInBits(ST->getType()); unsigned NewBits = getTypeSizeInBits(Ty); if (CR.truncate(TruncBits).signExtend(NewBits).contains( CR.sextOrTrunc(NewBits))) return getTruncateOrSignExtend(X, Ty); } // sext(C1 + (C2 * x)) --> C1 + sext(C2 * x) if C1 < C2 if (auto *SA = dyn_cast(Op)) { if (SA->getNumOperands() == 2) { auto *SC1 = dyn_cast(SA->getOperand(0)); auto *SMul = dyn_cast(SA->getOperand(1)); if (SMul && SC1) { if (auto *SC2 = dyn_cast(SMul->getOperand(0))) { const APInt &C1 = SC1->getAPInt(); const APInt &C2 = SC2->getAPInt(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) return getAddExpr(getSignExtendExpr(SC1, Ty), getSignExtendExpr(SMul, Ty)); } } } // sext((A + B + ...)) --> (sext(A) + sext(B) + ...) if (SA->hasNoSignedWrap()) { // If the addition does not sign overflow then we can, by definition, // commute the sign extension with the addition operation. SmallVector Ops; for (const auto *Op : SA->operands()) Ops.push_back(getSignExtendExpr(Op, Ty)); return getAddExpr(Ops, SCEV::FlagNSW); } } // If the input value is a chrec scev, and we can prove that the value // did not overflow the old, smaller, value, we can sign extend all of the // operands (often constants). This allows analysis of something like // this: for (signed char X = 0; X < 100; ++X) { int Y = X; } if (const SCEVAddRecExpr *AR = dyn_cast(Op)) if (AR->isAffine()) { const SCEV *Start = AR->getStart(); const SCEV *Step = AR->getStepRecurrence(*this); unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); if (!AR->hasNoSignedWrap()) { auto NewFlags = proveNoWrapViaConstantRanges(AR); const_cast(AR)->setNoWrapFlags(NewFlags); } // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. if (AR->hasNoSignedWrap()) return getAddRecExpr( getExtendAddRecStart(AR, Ty, this), getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are // simply not analyzable, and it covers the case where this code is // being called from within backedge-taken count analysis, such that // attempting to ask for the backedge-taken count would likely result // in infinite recursion. In the later case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. const SCEV *MaxBECount = getMaxBackedgeTakenCount(L); if (!isa(MaxBECount)) { // Manually compute the final value for AR, checking for // overflow. // Check whether the backedge-taken count can be losslessly casted to // the addrec's type. The count is always unsigned. const SCEV *CastedMaxBECount = getTruncateOrZeroExtend(MaxBECount, Start->getType()); const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); if (MaxBECount == RecastedMaxBECount) { Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no signed overflow. const SCEV *SMul = getMulExpr(CastedMaxBECount, Step); const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul), WideTy); const SCEV *WideStart = getSignExtendExpr(Start, WideTy); const SCEV *WideMaxBECount = getZeroExtendExpr(CastedMaxBECount, WideTy); const SCEV *OperandExtendedAdd = getAddExpr(WideStart, getMulExpr(WideMaxBECount, getSignExtendExpr(Step, WideTy))); if (SAdd == OperandExtendedAdd) { // Cache knowledge of AR NSW, which is propagated to this AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); // Return the expression with the addrec on the outside. return getAddRecExpr( getExtendAddRecStart(AR, Ty, this), getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as unsigned. // This covers loops that count up with an unsigned step. OperandExtendedAdd = getAddExpr(WideStart, getMulExpr(WideMaxBECount, getZeroExtendExpr(Step, WideTy))); if (SAdd == OperandExtendedAdd) { // If AR wraps around then // // abs(Step) * MaxBECount > unsigned-max(AR->getType()) // => SAdd != OperandExtendedAdd // // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=> // (SAdd == OperandExtendedAdd => AR is NW) const_cast(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr( getExtendAddRecStart(AR, Ty, this), getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } } // Normally, in the cases we can prove no-overflow via a // backedge guarding condition, we can also compute a backedge // taken count for the loop. The exceptions are assumptions and // guards present in the loop -- SCEV is not great at exploiting // these to compute max backedge taken counts, but can still use // these to prove lack of overflow. Use this fact to avoid // doing extra work that may not pay off. if (!isa(MaxBECount) || HasGuards || !AC.assumptions().empty()) { // If the backedge is guarded by a comparison with the pre-inc // value the addrec is safe. Also, if the entry is guarded by // a comparison with the start value and the backedge is // guarded by a comparison with the post-inc value, the addrec // is safe. ICmpInst::Predicate Pred; const SCEV *OverflowLimit = getSignedOverflowLimitForStep(Step, &Pred, this); if (OverflowLimit && (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) || (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) && isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this), OverflowLimit)))) { // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec. const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); return getAddRecExpr( getExtendAddRecStart(AR, Ty, this), getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } // If Start and Step are constants, check if we can apply this // transformation: // sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2 auto *SC1 = dyn_cast(Start); auto *SC2 = dyn_cast(Step); if (SC1 && SC2) { const APInt &C1 = SC1->getAPInt(); const APInt &C2 = SC2->getAPInt(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) { Start = getSignExtendExpr(Start, Ty); const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L, AR->getNoWrapFlags()); return getAddExpr(Start, getSignExtendExpr(NewAR, Ty)); } } if (proveNoWrapByVaryingStart(Start, Step, L)) { const_cast(AR)->setNoWrapFlags(SCEV::FlagNSW); return getAddRecExpr( getExtendAddRecStart(AR, Ty, this), getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } // If the input value is provably positive and we could not simplify // away the sext build a zext instead. if (isKnownNonNegative(Op)) return getZeroExtendExpr(Op, Ty); // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); return S; } /// getAnyExtendExpr - Return a SCEV for the given operand extended with /// unspecified bits out to the given type. /// const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); Ty = getEffectiveSCEVType(Ty); // Sign-extend negative constants. if (const SCEVConstant *SC = dyn_cast(Op)) if (SC->getAPInt().isNegative()) return getSignExtendExpr(Op, Ty); // Peel off a truncate cast. if (const SCEVTruncateExpr *T = dyn_cast(Op)) { const SCEV *NewOp = T->getOperand(); if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty)) return getAnyExtendExpr(NewOp, Ty); return getTruncateOrNoop(NewOp, Ty); } // Next try a zext cast. If the cast is folded, use it. const SCEV *ZExt = getZeroExtendExpr(Op, Ty); if (!isa(ZExt)) return ZExt; // Next try a sext cast. If the cast is folded, use it. const SCEV *SExt = getSignExtendExpr(Op, Ty); if (!isa(SExt)) return SExt; // Force the cast to be folded into the operands of an addrec. if (const SCEVAddRecExpr *AR = dyn_cast(Op)) { SmallVector Ops; for (const SCEV *Op : AR->operands()) Ops.push_back(getAnyExtendExpr(Op, Ty)); return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW); } // If the expression is obviously signed, use the sext cast value. if (isa(Op)) return SExt; // Absent any other information, use the zext cast value. return ZExt; } /// Process the given Ops list, which is a list of operands to be added under /// the given scale, update the given map. This is a helper function for /// getAddRecExpr. As an example of what it does, given a sequence of operands /// that would form an add expression like this: /// /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r) /// /// where A and B are constants, update the map with these values: /// /// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0) /// /// and add 13 + A*B*29 to AccumulatedConstant. /// This will allow getAddRecExpr to produce this: /// /// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B) /// /// This form often exposes folding opportunities that are hidden in /// the original operand list. /// /// Return true iff it appears that any interesting folding opportunities /// may be exposed. This helps getAddRecExpr short-circuit extra work in /// the common case where no interesting opportunities are present, and /// is also used as a check to avoid infinite recursion. /// static bool CollectAddOperandsWithScales(DenseMap &M, SmallVectorImpl &NewOps, APInt &AccumulatedConstant, const SCEV *const *Ops, size_t NumOperands, const APInt &Scale, ScalarEvolution &SE) { bool Interesting = false; // Iterate over the add operands. They are sorted, with constants first. unsigned i = 0; while (const SCEVConstant *C = dyn_cast(Ops[i])) { ++i; // Pull a buried constant out to the outside. if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero()) Interesting = true; AccumulatedConstant += Scale * C->getAPInt(); } // Next comes everything else. We're especially interested in multiplies // here, but they're in the middle, so just visit the rest with one loop. for (; i != NumOperands; ++i) { const SCEVMulExpr *Mul = dyn_cast(Ops[i]); if (Mul && isa(Mul->getOperand(0))) { APInt NewScale = Scale * cast(Mul->getOperand(0))->getAPInt(); if (Mul->getNumOperands() == 2 && isa(Mul->getOperand(1))) { // A multiplication of a constant with another add; recurse. const SCEVAddExpr *Add = cast(Mul->getOperand(1)); Interesting |= CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, Add->op_begin(), Add->getNumOperands(), NewScale, SE); } else { // A multiplication of a constant with some other value. Update // the map. SmallVector MulOps(Mul->op_begin()+1, Mul->op_end()); const SCEV *Key = SE.getMulExpr(MulOps); auto Pair = M.insert({Key, NewScale}); if (Pair.second) { NewOps.push_back(Pair.first->first); } else { Pair.first->second += NewScale; // The map already had an entry for this value, which may indicate // a folding opportunity. Interesting = true; } } } else { // An ordinary operand. Update the map. std::pair::iterator, bool> Pair = M.insert({Ops[i], Scale}); if (Pair.second) { NewOps.push_back(Pair.first->first); } else { Pair.first->second += Scale; // The map already had an entry for this value, which may indicate // a folding opportunity. Interesting = true; } } } return Interesting; } // We're trying to construct a SCEV of type `Type' with `Ops' as operands and // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of // can't-overflow flags for the operation if possible. static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const SmallVectorImpl &Ops, SCEV::NoWrapFlags Flags) { using namespace std::placeholders; typedef OverflowingBinaryOperator OBO; bool CanAnalyze = Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr; (void)CanAnalyze; assert(CanAnalyze && "don't call from other places!"); int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; SCEV::NoWrapFlags SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. auto IsKnownNonNegative = [&](const SCEV *S) { return SE->isKnownNonNegative(S); }; if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative)) Flags = ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); if (SignOrUnsignWrap != SignOrUnsignMask && Type == scAddExpr && Ops.size() == 2 && isa(Ops[0])) { // (A + C) --> (A + C) if the addition does not sign overflow // (A + C) --> (A + C) if the addition does not unsign overflow const APInt &C = cast(Ops[0])->getAPInt(); if (!(SignOrUnsignWrap & SCEV::FlagNSW)) { auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( Instruction::Add, C, OBO::NoSignedWrap); if (NSWRegion.contains(SE->getSignedRange(Ops[1]))) Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); } if (!(SignOrUnsignWrap & SCEV::FlagNUW)) { auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( Instruction::Add, C, OBO::NoUnsignedWrap); if (NUWRegion.contains(SE->getUnsignedRange(Ops[1]))) Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); } } return Flags; } /// Get a canonical add expression, or something simpler if possible. const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, SCEV::NoWrapFlags Flags) { assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty add!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); for (unsigned i = 1, e = Ops.size(); i != e; ++i) assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && "SCEVAddExpr operand types don't match!"); #endif // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, &LI); Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags); // If there are any constants, fold them together. unsigned Idx = 0; if (const SCEVConstant *LHSC = dyn_cast(Ops[0])) { ++Idx; assert(Idx < Ops.size()); while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt()); if (Ops.size() == 2) return Ops[0]; Ops.erase(Ops.begin()+1); // Erase the folded element LHSC = cast(Ops[0]); } // If we are left with a constant zero being added, strip it off. if (LHSC->getValue()->isZero()) { Ops.erase(Ops.begin()); --Idx; } if (Ops.size() == 1) return Ops[0]; } // Okay, check to see if the same value occurs in the operand list more than // once. If so, merge them together into an multiply expression. Since we // sorted the list, these values are required to be adjacent. Type *Ty = Ops[0]->getType(); bool FoundMatch = false; for (unsigned i = 0, e = Ops.size(); i != e-1; ++i) if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2 // Scan ahead to count how many equal operands there are. unsigned Count = 2; while (i+Count != e && Ops[i+Count] == Ops[i]) ++Count; // Merge the values into a multiply. const SCEV *Scale = getConstant(Ty, Count); const SCEV *Mul = getMulExpr(Scale, Ops[i]); if (Ops.size() == Count) return Mul; Ops[i] = Mul; Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count); --i; e -= Count - 1; FoundMatch = true; } if (FoundMatch) return getAddExpr(Ops, Flags); // Check for truncates. If all the operands are truncated from the same // type, see if factoring out the truncate would permit the result to be // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n) // if the contents of the resulting outer trunc fold to something simple. for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { const SCEVTruncateExpr *Trunc = cast(Ops[Idx]); Type *DstType = Trunc->getType(); Type *SrcType = Trunc->getOperand()->getType(); SmallVector LargeOps; bool Ok = true; // Check all the operands to see if they can be represented in the // source type of the truncate. for (unsigned i = 0, e = Ops.size(); i != e; ++i) { if (const SCEVTruncateExpr *T = dyn_cast(Ops[i])) { if (T->getOperand()->getType() != SrcType) { Ok = false; break; } LargeOps.push_back(T->getOperand()); } else if (const SCEVConstant *C = dyn_cast(Ops[i])) { LargeOps.push_back(getAnyExtendExpr(C, SrcType)); } else if (const SCEVMulExpr *M = dyn_cast(Ops[i])) { SmallVector LargeMulOps; for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) { if (const SCEVTruncateExpr *T = dyn_cast(M->getOperand(j))) { if (T->getOperand()->getType() != SrcType) { Ok = false; break; } LargeMulOps.push_back(T->getOperand()); } else if (const auto *C = dyn_cast(M->getOperand(j))) { LargeMulOps.push_back(getAnyExtendExpr(C, SrcType)); } else { Ok = false; break; } } if (Ok) LargeOps.push_back(getMulExpr(LargeMulOps)); } else { Ok = false; break; } } if (Ok) { // Evaluate the expression in the larger type. const SCEV *Fold = getAddExpr(LargeOps, Flags); // If it folds to something simple, use it. Otherwise, don't. if (isa(Fold) || isa(Fold)) return getTruncateExpr(Fold, DstType); } } // Skip past any other cast SCEVs. while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr) ++Idx; // If there are add operands they would be next. if (Idx < Ops.size()) { bool DeletedAdd = false; while (const SCEVAddExpr *Add = dyn_cast(Ops[Idx])) { // If we have an add, expand the add operands onto the end of the operands // list. Ops.erase(Ops.begin()+Idx); Ops.append(Add->op_begin(), Add->op_end()); DeletedAdd = true; } // If we deleted at least one add, we added operands to the end of the list, // and they are not necessarily sorted. Recurse to resort and resimplify // any operands we just acquired. if (DeletedAdd) return getAddExpr(Ops); } // Skip over the add expression until we get to a multiply. while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr) ++Idx; // Check to see if there are any folding opportunities present with // operands multiplied by constant values. if (Idx < Ops.size() && isa(Ops[Idx])) { uint64_t BitWidth = getTypeSizeInBits(Ty); DenseMap M; SmallVector NewOps; APInt AccumulatedConstant(BitWidth, 0); if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, Ops.data(), Ops.size(), APInt(BitWidth, 1), *this)) { struct APIntCompare { bool operator()(const APInt &LHS, const APInt &RHS) const { return LHS.ult(RHS); } }; // Some interesting folding opportunity is present, so its worthwhile to // re-generate the operands list. Group the operands by constant scale, // to avoid multiplying by the same constant scale multiple times. std::map, APIntCompare> MulOpLists; for (const SCEV *NewOp : NewOps) MulOpLists[M.find(NewOp)->second].push_back(NewOp); // Re-generate the operands list. Ops.clear(); if (AccumulatedConstant != 0) Ops.push_back(getConstant(AccumulatedConstant)); for (auto &MulOp : MulOpLists) if (MulOp.first != 0) Ops.push_back(getMulExpr(getConstant(MulOp.first), getAddExpr(MulOp.second))); if (Ops.empty()) return getZero(Ty); if (Ops.size() == 1) return Ops[0]; return getAddExpr(Ops); } } // If we are adding something to a multiply expression, make sure the // something is not already an operand of the multiply. If so, merge it into // the multiply. for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { const SCEVMulExpr *Mul = cast(Ops[Idx]); for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) { const SCEV *MulOpSCEV = Mul->getOperand(MulOp); if (isa(MulOpSCEV)) continue; for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) if (MulOpSCEV == Ops[AddOp]) { // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1)) const SCEV *InnerMul = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { // If the multiply has more than two operands, we must get the // Y*Z term. SmallVector MulOps(Mul->op_begin(), Mul->op_begin()+MulOp); MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); InnerMul = getMulExpr(MulOps); } const SCEV *One = getOne(Ty); const SCEV *AddOne = getAddExpr(One, InnerMul); const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV); if (Ops.size() == 2) return OuterMul; if (AddOp < Idx) { Ops.erase(Ops.begin()+AddOp); Ops.erase(Ops.begin()+Idx-1); } else { Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+AddOp-1); } Ops.push_back(OuterMul); return getAddExpr(Ops); } // Check this multiply against other multiplies being added together. for (unsigned OtherMulIdx = Idx+1; OtherMulIdx < Ops.size() && isa(Ops[OtherMulIdx]); ++OtherMulIdx) { const SCEVMulExpr *OtherMul = cast(Ops[OtherMulIdx]); // If MulOp occurs in OtherMul, we can fold the two multiplies // together. for (unsigned OMulOp = 0, e = OtherMul->getNumOperands(); OMulOp != e; ++OMulOp) if (OtherMul->getOperand(OMulOp) == MulOpSCEV) { // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E)) const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { SmallVector MulOps(Mul->op_begin(), Mul->op_begin()+MulOp); MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); InnerMul1 = getMulExpr(MulOps); } const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0); if (OtherMul->getNumOperands() != 2) { SmallVector MulOps(OtherMul->op_begin(), OtherMul->op_begin()+OMulOp); MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end()); InnerMul2 = getMulExpr(MulOps); } const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2); const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); if (Ops.size() == 2) return OuterMul; Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+OtherMulIdx-1); Ops.push_back(OuterMul); return getAddExpr(Ops); } } } } // If there are any add recurrences in the operands list, see if any other // added values are loop invariant. If so, we can fold them into the // recurrence. while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr) ++Idx; // Scan over all recurrences, trying to fold loop invariants into them. for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { // Scan all of the other operands to this add and add them to the vector if // they are loop invariant w.r.t. the recurrence. SmallVector LIOps; const SCEVAddRecExpr *AddRec = cast(Ops[Idx]); const Loop *AddRecLoop = AddRec->getLoop(); for (unsigned i = 0, e = Ops.size(); i != e; ++i) if (isLoopInvariant(Ops[i], AddRecLoop)) { LIOps.push_back(Ops[i]); Ops.erase(Ops.begin()+i); --i; --e; } // If we found some loop invariants, fold them into the recurrence. if (!LIOps.empty()) { // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} LIOps.push_back(AddRec->getStart()); SmallVector AddRecOps(AddRec->op_begin(), AddRec->op_end()); // This follows from the fact that the no-wrap flags on the outer add // expression are applicable on the 0th iteration, when the add recurrence // will be equal to its start value. AddRecOps[0] = getAddExpr(LIOps, Flags); // Build the new addrec. Propagate the NUW and NSW flags if both the // outer add and the inner addrec are guaranteed to have no overflow. // Always propagate NW. Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW)); const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; // Otherwise, add the folded AddRec by the non-invariant parts. for (unsigned i = 0;; ++i) if (Ops[i] == AddRec) { Ops[i] = NewRec; break; } return getAddExpr(Ops); } // Okay, if there weren't any loop invariants to be folded, check to see if // there are multiple AddRec's with the same loop induction variable being // added together. If so, we can fold them. for (unsigned OtherIdx = Idx+1; OtherIdx < Ops.size() && isa(Ops[OtherIdx]); ++OtherIdx) if (AddRecLoop == cast(Ops[OtherIdx])->getLoop()) { // Other + {A,+,B} + {C,+,D} --> Other + {A+C,+,B+D} SmallVector AddRecOps(AddRec->op_begin(), AddRec->op_end()); for (; OtherIdx != Ops.size() && isa(Ops[OtherIdx]); ++OtherIdx) if (const auto *OtherAddRec = dyn_cast(Ops[OtherIdx])) if (OtherAddRec->getLoop() == AddRecLoop) { for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) { if (i >= AddRecOps.size()) { AddRecOps.append(OtherAddRec->op_begin()+i, OtherAddRec->op_end()); break; } AddRecOps[i] = getAddExpr(AddRecOps[i], OtherAddRec->getOperand(i)); } Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; } // Step size has changed, so we cannot guarantee no self-wraparound. Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap); return getAddExpr(Ops); } // Otherwise couldn't fold anything into this recurrence. Move onto the // next one. } // Okay, it looks like we really DO need an add expr. Check to see if we // already have one, otherwise create a new one. FoldingSetNodeID ID; ID.AddInteger(scAddExpr); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = nullptr; SCEVAddExpr *S = static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { const SCEV **O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); } S->setNoWrapFlags(Flags); return S; } static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) { uint64_t k = i*j; if (j > 1 && k / j != i) Overflow = true; return k; } /// Compute the result of "n choose k", the binomial coefficient. If an /// intermediate computation overflows, Overflow will be set and the return will /// be garbage. Overflow is not cleared on absence of overflow. static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) { // We use the multiplicative formula: // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 . // At each iteration, we take the n-th term of the numeral and divide by the // (k-n)th term of the denominator. This division will always produce an // integral result, and helps reduce the chance of overflow in the // intermediate computations. However, we can still overflow even when the // final result would fit. if (n == 0 || n == k) return 1; if (k > n) return 0; if (k > n/2) k = n-k; uint64_t r = 1; for (uint64_t i = 1; i <= k; ++i) { r = umul_ov(r, n-(i-1), Overflow); r /= i; } return r; } /// Determine if any of the operands in this SCEV are a constant or if /// any of the add or multiply expressions in this SCEV contain a constant. static bool containsConstantSomewhere(const SCEV *StartExpr) { SmallVector Ops; Ops.push_back(StartExpr); while (!Ops.empty()) { const SCEV *CurrentExpr = Ops.pop_back_val(); if (isa(*CurrentExpr)) return true; if (isa(*CurrentExpr) || isa(*CurrentExpr)) { const auto *CurrentNAry = cast(CurrentExpr); Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end()); } } return false; } /// Get a canonical multiply expression, or something simpler if possible. const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, SCEV::NoWrapFlags Flags) { assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty mul!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); for (unsigned i = 1, e = Ops.size(); i != e; ++i) assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && "SCEVMulExpr operand types don't match!"); #endif // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, &LI); Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); // If there are any constants, fold them together. unsigned Idx = 0; if (const SCEVConstant *LHSC = dyn_cast(Ops[0])) { // C1*(C2+V) -> C1*C2 + C1*V if (Ops.size() == 2) if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) // If any of Add's ops are Adds or Muls with a constant, // apply this transformation as well. if (Add->getNumOperands() == 2) if (containsConstantSomewhere(Add)) return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), getMulExpr(LHSC, Add->getOperand(1))); ++Idx; while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! ConstantInt *Fold = ConstantInt::get(getContext(), LHSC->getAPInt() * RHSC->getAPInt()); Ops[0] = getConstant(Fold); Ops.erase(Ops.begin()+1); // Erase the folded element if (Ops.size() == 1) return Ops[0]; LHSC = cast(Ops[0]); } // If we are left with a constant one being multiplied, strip it off. if (cast(Ops[0])->getValue()->equalsInt(1)) { Ops.erase(Ops.begin()); --Idx; } else if (cast(Ops[0])->getValue()->isZero()) { // If we have a multiply of zero, it will always be zero. return Ops[0]; } else if (Ops[0]->isAllOnesValue()) { // If we have a mul by -1 of an add, try distributing the -1 among the // add operands. if (Ops.size() == 2) { if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) { SmallVector NewOps; bool AnyFolded = false; for (const SCEV *AddOp : Add->operands()) { const SCEV *Mul = getMulExpr(Ops[0], AddOp); if (!isa(Mul)) AnyFolded = true; NewOps.push_back(Mul); } if (AnyFolded) return getAddExpr(NewOps); } else if (const auto *AddRec = dyn_cast(Ops[1])) { // Negation preserves a recurrence's no self-wrap property. SmallVector Operands; for (const SCEV *AddRecOp : AddRec->operands()) Operands.push_back(getMulExpr(Ops[0], AddRecOp)); return getAddRecExpr(Operands, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW)); } } } if (Ops.size() == 1) return Ops[0]; } // Skip over the add expression until we get to a multiply. while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr) ++Idx; // If there are mul operands inline them all into this expression. if (Idx < Ops.size()) { bool DeletedMul = false; while (const SCEVMulExpr *Mul = dyn_cast(Ops[Idx])) { // If we have an mul, expand the mul operands onto the end of the operands // list. Ops.erase(Ops.begin()+Idx); Ops.append(Mul->op_begin(), Mul->op_end()); DeletedMul = true; } // If we deleted at least one mul, we added operands to the end of the list, // and they are not necessarily sorted. Recurse to resort and resimplify // any operands we just acquired. if (DeletedMul) return getMulExpr(Ops); } // If there are any add recurrences in the operands list, see if any other // added values are loop invariant. If so, we can fold them into the // recurrence. while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr) ++Idx; // Scan over all recurrences, trying to fold loop invariants into them. for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { // Scan all of the other operands to this mul and add them to the vector if // they are loop invariant w.r.t. the recurrence. SmallVector LIOps; const SCEVAddRecExpr *AddRec = cast(Ops[Idx]); const Loop *AddRecLoop = AddRec->getLoop(); for (unsigned i = 0, e = Ops.size(); i != e; ++i) if (isLoopInvariant(Ops[i], AddRecLoop)) { LIOps.push_back(Ops[i]); Ops.erase(Ops.begin()+i); --i; --e; } // If we found some loop invariants, fold them into the recurrence. if (!LIOps.empty()) { // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} SmallVector NewOps; NewOps.reserve(AddRec->getNumOperands()); const SCEV *Scale = getMulExpr(LIOps); for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i))); // Build the new addrec. Propagate the NUW and NSW flags if both the // outer mul and the inner addrec are guaranteed to have no overflow. // // No self-wrap cannot be guaranteed after changing the step size, but // will be inferred if either NUW or NSW is true. Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW)); const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; // Otherwise, multiply the folded AddRec by the non-invariant parts. for (unsigned i = 0;; ++i) if (Ops[i] == AddRec) { Ops[i] = NewRec; break; } return getMulExpr(Ops); } // Okay, if there weren't any loop invariants to be folded, check to see if // there are multiple AddRec's with the same loop induction variable being // multiplied together. If so, we can fold them. // {A1,+,A2,+,...,+,An} * {B1,+,B2,+,...,+,Bn} // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z // ]]],+,...up to x=2n}. // Note that the arguments to choose() are always integers with values // known at compile time, never SCEV objects. // // The implementation avoids pointless extra computations when the two // addrec's are of different length (mathematically, it's equivalent to // an infinite stream of zeros on the right). bool OpsModified = false; for (unsigned OtherIdx = Idx+1; OtherIdx != Ops.size() && isa(Ops[OtherIdx]); ++OtherIdx) { const SCEVAddRecExpr *OtherAddRec = dyn_cast(Ops[OtherIdx]); if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) continue; bool Overflow = false; Type *Ty = AddRec->getType(); bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; SmallVector AddRecOps; for (int x = 0, xe = AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { const SCEV *Term = getZero(Ty); for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); z < ze && !Overflow; ++z) { uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); uint64_t Coeff; if (LargerThan64Bits) Coeff = umul_ov(Coeff1, Coeff2, Overflow); else Coeff = Coeff1*Coeff2; const SCEV *CoeffTerm = getConstant(Ty, Coeff); const SCEV *Term1 = AddRec->getOperand(y-z); const SCEV *Term2 = OtherAddRec->getOperand(z); Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2)); } } AddRecOps.push_back(Term); } if (!Overflow) { const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), SCEV::FlagAnyWrap); if (Ops.size() == 2) return NewAddRec; Ops[Idx] = NewAddRec; Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; OpsModified = true; AddRec = dyn_cast(NewAddRec); if (!AddRec) break; } } if (OpsModified) return getMulExpr(Ops); // Otherwise couldn't fold anything into this recurrence. Move onto the // next one. } // Okay, it looks like we really DO need an mul expr. Check to see if we // already have one, otherwise create a new one. FoldingSetNodeID ID; ID.AddInteger(scMulExpr); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = nullptr; SCEVMulExpr *S = static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { const SCEV **O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); } S->setNoWrapFlags(Flags); return S; } /// Get a canonical unsigned division expression, or something simpler if /// possible. const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, const SCEV *RHS) { assert(getEffectiveSCEVType(LHS->getType()) == getEffectiveSCEVType(RHS->getType()) && "SCEVUDivExpr operand types don't match!"); if (const SCEVConstant *RHSC = dyn_cast(RHS)) { if (RHSC->getValue()->equalsInt(1)) return LHS; // X udiv 1 --> x // If the denominator is zero, the result of the udiv is undefined. Don't // try to analyze it, because the resolution chosen here may differ from // the resolution chosen in other parts of the compiler. if (!RHSC->getValue()->isZero()) { // Determine if the division can be folded into the operands of // its operands. // TODO: Generalize this to non-constants by using known-bits information. Type *Ty = LHS->getType(); unsigned LZ = RHSC->getAPInt().countLeadingZeros(); unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1; // For non-power-of-two values, effectively round the value up to the // nearest power of two. if (!RHSC->getAPInt().isPowerOf2()) ++MaxShiftAmt; IntegerType *ExtTy = IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt); if (const SCEVAddRecExpr *AR = dyn_cast(LHS)) if (const SCEVConstant *Step = dyn_cast(AR->getStepRecurrence(*this))) { // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded. const APInt &StepInt = Step->getAPInt(); const APInt &DivInt = RHSC->getAPInt(); if (!StepInt.urem(DivInt) && getZeroExtendExpr(AR, ExtTy) == getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), getZeroExtendExpr(Step, ExtTy), AR->getLoop(), SCEV::FlagAnyWrap)) { SmallVector Operands; for (const SCEV *Op : AR->operands()) Operands.push_back(getUDivExpr(Op, RHS)); return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW); } /// Get a canonical UDivExpr for a recurrence. /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0. // We can currently only fold X%N if X is constant. const SCEVConstant *StartC = dyn_cast(AR->getStart()); if (StartC && !DivInt.urem(StepInt) && getZeroExtendExpr(AR, ExtTy) == getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), getZeroExtendExpr(Step, ExtTy), AR->getLoop(), SCEV::FlagAnyWrap)) { const APInt &StartInt = StartC->getAPInt(); const APInt &StartRem = StartInt.urem(StepInt); if (StartRem != 0) LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step, AR->getLoop(), SCEV::FlagNW); } } // (A*B)/C --> A*(B/C) if safe and B/C can be folded. if (const SCEVMulExpr *M = dyn_cast(LHS)) { SmallVector Operands; for (const SCEV *Op : M->operands()) Operands.push_back(getZeroExtendExpr(Op, ExtTy)); if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) // Find an operand that's safely divisible. for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) { const SCEV *Op = M->getOperand(i); const SCEV *Div = getUDivExpr(Op, RHSC); if (!isa(Div) && getMulExpr(Div, RHSC) == Op) { Operands = SmallVector(M->op_begin(), M->op_end()); Operands[i] = Div; return getMulExpr(Operands); } } } // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded. if (const SCEVAddExpr *A = dyn_cast(LHS)) { SmallVector Operands; for (const SCEV *Op : A->operands()) Operands.push_back(getZeroExtendExpr(Op, ExtTy)); if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) { Operands.clear(); for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { const SCEV *Op = getUDivExpr(A->getOperand(i), RHS); if (isa(Op) || getMulExpr(Op, RHS) != A->getOperand(i)) break; Operands.push_back(Op); } if (Operands.size() == A->getNumOperands()) return getAddExpr(Operands); } } // Fold if both operands are constant. if (const SCEVConstant *LHSC = dyn_cast(LHS)) { Constant *LHSCV = LHSC->getValue(); Constant *RHSCV = RHSC->getValue(); return getConstant(cast(ConstantExpr::getUDiv(LHSCV, RHSCV))); } } } FoldingSetNodeID ID; ID.AddInteger(scUDivExpr); ID.AddPointer(LHS); ID.AddPointer(RHS); void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), LHS, RHS); UniqueSCEVs.InsertNode(S, IP); return S; } static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { APInt A = C1->getAPInt().abs(); APInt B = C2->getAPInt().abs(); uint32_t ABW = A.getBitWidth(); uint32_t BBW = B.getBitWidth(); if (ABW > BBW) B = B.zext(ABW); else if (ABW < BBW) A = A.zext(BBW); return APIntOps::GreatestCommonDivisor(A, B); } /// Get a canonical unsigned division expression, or something simpler if /// possible. There is no representation for an exact udiv in SCEV IR, but we /// can attempt to remove factors from the LHS and RHS. We can't do this when /// it's not exact because the udiv may be clearing bits. const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, const SCEV *RHS) { // TODO: we could try to find factors in all sorts of things, but for now we // just deal with u/exact (multiply, constant). See SCEVDivision towards the // end of this file for inspiration. const SCEVMulExpr *Mul = dyn_cast(LHS); if (!Mul) return getUDivExpr(LHS, RHS); if (const SCEVConstant *RHSCst = dyn_cast(RHS)) { // If the mulexpr multiplies by a constant, then that constant must be the // first element of the mulexpr. if (const auto *LHSCst = dyn_cast(Mul->getOperand(0))) { if (LHSCst == RHSCst) { SmallVector Operands; Operands.append(Mul->op_begin() + 1, Mul->op_end()); return getMulExpr(Operands); } // We can't just assume that LHSCst divides RHSCst cleanly, it could be // that there's a factor provided by one of the other terms. We need to // check. APInt Factor = gcd(LHSCst, RHSCst); if (!Factor.isIntN(1)) { LHSCst = cast(getConstant(LHSCst->getAPInt().udiv(Factor))); RHSCst = cast(getConstant(RHSCst->getAPInt().udiv(Factor))); SmallVector Operands; Operands.push_back(LHSCst); Operands.append(Mul->op_begin() + 1, Mul->op_end()); LHS = getMulExpr(Operands); RHS = RHSCst; Mul = dyn_cast(LHS); if (!Mul) return getUDivExactExpr(LHS, RHS); } } } for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) { if (Mul->getOperand(i) == RHS) { SmallVector Operands; Operands.append(Mul->op_begin(), Mul->op_begin() + i); Operands.append(Mul->op_begin() + i + 1, Mul->op_end()); return getMulExpr(Operands); } } return getUDivExpr(LHS, RHS); } /// Get an add recurrence expression for the specified loop. Simplify the /// expression as much as possible. const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags) { SmallVector Operands; Operands.push_back(Start); if (const SCEVAddRecExpr *StepChrec = dyn_cast(Step)) if (StepChrec->getLoop() == L) { Operands.append(StepChrec->op_begin(), StepChrec->op_end()); return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW)); } Operands.push_back(Step); return getAddRecExpr(Operands, L, Flags); } /// Get an add recurrence expression for the specified loop. Simplify the /// expression as much as possible. const SCEV * ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, const Loop *L, SCEV::NoWrapFlags Flags) { if (Operands.size() == 1) return Operands[0]; #ifndef NDEBUG Type *ETy = getEffectiveSCEVType(Operands[0]->getType()); for (unsigned i = 1, e = Operands.size(); i != e; ++i) assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy && "SCEVAddRecExpr operand types don't match!"); for (unsigned i = 0, e = Operands.size(); i != e; ++i) assert(isLoopInvariant(Operands[i], L) && "SCEVAddRecExpr operand is not loop-invariant!"); #endif if (Operands.back()->isZero()) { Operands.pop_back(); return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X } // It's tempting to want to call getMaxBackedgeTakenCount count here and // use that information to infer NUW and NSW flags. However, computing a // BE count requires calling getAddRecExpr, so we may not yet have a // meaningful BE count at this point (and if we don't, we'd be stuck // with a SCEVCouldNotCompute as the cached BE count). Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); // Canonicalize nested AddRecs in by nesting them in order of loop depth. if (const SCEVAddRecExpr *NestedAR = dyn_cast(Operands[0])) { const Loop *NestedLoop = NestedAR->getLoop(); if (L->contains(NestedLoop) ? (L->getLoopDepth() < NestedLoop->getLoopDepth()) : (!NestedLoop->contains(L) && DT.dominates(L->getHeader(), NestedLoop->getHeader()))) { SmallVector NestedOperands(NestedAR->op_begin(), NestedAR->op_end()); Operands[0] = NestedAR->getStart(); // AddRecs require their operands be loop-invariant with respect to their // loops. Don't perform this transformation if it would break this // requirement. bool AllInvariant = all_of( Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); }); if (AllInvariant) { // Create a recurrence for the outer loop with the same step size. // // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the // inner recurrence has the same property. SCEV::NoWrapFlags OuterFlags = maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags()); NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags); AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) { return isLoopInvariant(Op, NestedLoop); }); if (AllInvariant) { // Ok, both add recurrences are valid after the transformation. // // The inner recurrence keeps its NW flag but only keeps NUW/NSW if // the outer recurrence has the same property. SCEV::NoWrapFlags InnerFlags = maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags); return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags); } } // Reset Operands to its original state. Operands[0] = NestedAR; } } // Okay, it looks like we really DO need an addrec expr. Check to see if we // already have one, otherwise create a new one. FoldingSetNodeID ID; ID.AddInteger(scAddRecExpr); for (unsigned i = 0, e = Operands.size(); i != e; ++i) ID.AddPointer(Operands[i]); ID.AddPointer(L); void *IP = nullptr; SCEVAddRecExpr *S = static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { const SCEV **O = SCEVAllocator.Allocate(Operands.size()); std::uninitialized_copy(Operands.begin(), Operands.end(), O); S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Operands.size(), L); UniqueSCEVs.InsertNode(S, IP); } S->setNoWrapFlags(Flags); return S; } const SCEV * ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr, const SmallVectorImpl &IndexExprs, bool InBounds) { // getSCEV(Base)->getType() has the same address space as Base->getType() // because SCEV::getType() preserves the address space. Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType()); // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP // instruction to its SCEV, because the Instruction may be guarded by control // flow and the no-overflow bits may not be valid for the expression in any // context. This can be fixed similarly to how these flags are handled for // adds. SCEV::NoWrapFlags Wrap = InBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap; const SCEV *TotalOffset = getZero(IntPtrTy); // The address space is unimportant. The first thing we do on CurTy is getting // its element type. Type *CurTy = PointerType::getUnqual(PointeeType); for (const SCEV *IndexExpr : IndexExprs) { // Compute the (potentially symbolic) offset in bytes for this index. if (StructType *STy = dyn_cast(CurTy)) { // For a struct, add the member offset. ConstantInt *Index = cast(IndexExpr)->getValue(); unsigned FieldNo = Index->getZExtValue(); const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo); // Add the field offset to the running total offset. TotalOffset = getAddExpr(TotalOffset, FieldOffset); // Update CurTy to the type of the field at Index. CurTy = STy->getTypeAtIndex(Index); } else { // Update CurTy to its element type. CurTy = cast(CurTy)->getElementType(); // For an array, add the element offset, explicitly scaled. const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, CurTy); // Getelementptr indices are signed. IndexExpr = getTruncateOrSignExtend(IndexExpr, IntPtrTy); // Multiply the index by the element size to compute the element offset. const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, Wrap); // Add the element offset to the running total offset. TotalOffset = getAddExpr(TotalOffset, LocalOffset); } } // Add the total offset from all the GEP indices to the base. return getAddExpr(BaseExpr, TotalOffset, Wrap); } const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) { SmallVector Ops = {LHS, RHS}; return getSMaxExpr(Ops); } const SCEV * ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { assert(!Ops.empty() && "Cannot get empty smax!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); for (unsigned i = 1, e = Ops.size(); i != e; ++i) assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && "SCEVSMaxExpr operand types don't match!"); #endif // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, &LI); // If there are any constants, fold them together. unsigned Idx = 0; if (const SCEVConstant *LHSC = dyn_cast(Ops[0])) { ++Idx; assert(Idx < Ops.size()); while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! ConstantInt *Fold = ConstantInt::get( getContext(), APIntOps::smax(LHSC->getAPInt(), RHSC->getAPInt())); Ops[0] = getConstant(Fold); Ops.erase(Ops.begin()+1); // Erase the folded element if (Ops.size() == 1) return Ops[0]; LHSC = cast(Ops[0]); } // If we are left with a constant minimum-int, strip it off. if (cast(Ops[0])->getValue()->isMinValue(true)) { Ops.erase(Ops.begin()); --Idx; } else if (cast(Ops[0])->getValue()->isMaxValue(true)) { // If we have an smax with a constant maximum-int, it will always be // maximum-int. return Ops[0]; } if (Ops.size() == 1) return Ops[0]; } // Find the first SMax while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr) ++Idx; // Check to see if one of the operands is an SMax. If so, expand its operands // onto our operand list, and recurse to simplify. if (Idx < Ops.size()) { bool DeletedSMax = false; while (const SCEVSMaxExpr *SMax = dyn_cast(Ops[Idx])) { Ops.erase(Ops.begin()+Idx); Ops.append(SMax->op_begin(), SMax->op_end()); DeletedSMax = true; } if (DeletedSMax) return getSMaxExpr(Ops); } // Okay, check to see if the same value occurs in the operand list twice. If // so, delete one. Since we sorted the list, these values are required to // be adjacent. for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) // X smax Y smax Y --> X smax Y // X smax Y --> X, if X is always greater than Y if (Ops[i] == Ops[i+1] || isKnownPredicate(ICmpInst::ICMP_SGE, Ops[i], Ops[i+1])) { Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2); --i; --e; } else if (isKnownPredicate(ICmpInst::ICMP_SLE, Ops[i], Ops[i+1])) { Ops.erase(Ops.begin()+i, Ops.begin()+i+1); --i; --e; } if (Ops.size() == 1) return Ops[0]; assert(!Ops.empty() && "Reduced smax down to nothing!"); // Okay, it looks like we really DO need an smax expr. Check to see if we // already have one, otherwise create a new one. FoldingSetNodeID ID; ID.AddInteger(scSMaxExpr); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; const SCEV **O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); return S; } const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) { SmallVector Ops = {LHS, RHS}; return getUMaxExpr(Ops); } const SCEV * ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { assert(!Ops.empty() && "Cannot get empty umax!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); for (unsigned i = 1, e = Ops.size(); i != e; ++i) assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && "SCEVUMaxExpr operand types don't match!"); #endif // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, &LI); // If there are any constants, fold them together. unsigned Idx = 0; if (const SCEVConstant *LHSC = dyn_cast(Ops[0])) { ++Idx; assert(Idx < Ops.size()); while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! ConstantInt *Fold = ConstantInt::get( getContext(), APIntOps::umax(LHSC->getAPInt(), RHSC->getAPInt())); Ops[0] = getConstant(Fold); Ops.erase(Ops.begin()+1); // Erase the folded element if (Ops.size() == 1) return Ops[0]; LHSC = cast(Ops[0]); } // If we are left with a constant minimum-int, strip it off. if (cast(Ops[0])->getValue()->isMinValue(false)) { Ops.erase(Ops.begin()); --Idx; } else if (cast(Ops[0])->getValue()->isMaxValue(false)) { // If we have an umax with a constant maximum-int, it will always be // maximum-int. return Ops[0]; } if (Ops.size() == 1) return Ops[0]; } // Find the first UMax while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr) ++Idx; // Check to see if one of the operands is a UMax. If so, expand its operands // onto our operand list, and recurse to simplify. if (Idx < Ops.size()) { bool DeletedUMax = false; while (const SCEVUMaxExpr *UMax = dyn_cast(Ops[Idx])) { Ops.erase(Ops.begin()+Idx); Ops.append(UMax->op_begin(), UMax->op_end()); DeletedUMax = true; } if (DeletedUMax) return getUMaxExpr(Ops); } // Okay, check to see if the same value occurs in the operand list twice. If // so, delete one. Since we sorted the list, these values are required to // be adjacent. for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) // X umax Y umax Y --> X umax Y // X umax Y --> X, if X is always greater than Y if (Ops[i] == Ops[i+1] || isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) { Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2); --i; --e; } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) { Ops.erase(Ops.begin()+i, Ops.begin()+i+1); --i; --e; } if (Ops.size() == 1) return Ops[0]; assert(!Ops.empty() && "Reduced umax down to nothing!"); // Okay, it looks like we really DO need a umax expr. Check to see if we // already have one, otherwise create a new one. FoldingSetNodeID ID; ID.AddInteger(scUMaxExpr); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; const SCEV **O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); return S; } const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS, const SCEV *RHS) { // ~smax(~x, ~y) == smin(x, y). return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS))); } const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS) { // ~umax(~x, ~y) == umin(x, y) return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS))); } const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { // We can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); } const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo) { // We can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. return getConstant( IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo)); } const SCEV *ScalarEvolution::getUnknown(Value *V) { // Don't attempt to do anything other than create a SCEVUnknown object // here. createSCEV only calls getUnknown after checking for all other // interesting possibilities, and any other code that calls getUnknown // is doing so in order to hide a value from SCEV canonicalization. FoldingSetNodeID ID; ID.AddInteger(scUnknown); ID.AddPointer(V); void *IP = nullptr; if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) { assert(cast(S)->getValue() == V && "Stale SCEVUnknown in uniquing map!"); return S; } SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this, FirstUnknown); FirstUnknown = cast(S); UniqueSCEVs.InsertNode(S, IP); return S; } //===----------------------------------------------------------------------===// // Basic SCEV Analysis and PHI Idiom Recognition Code // /// Test if values of the given type are analyzable within the SCEV /// framework. This primarily includes integer types, and it can optionally /// include pointer types if the ScalarEvolution class has access to /// target-specific information. bool ScalarEvolution::isSCEVable(Type *Ty) const { // Integers and pointers are always SCEVable. return Ty->isIntegerTy() || Ty->isPointerTy(); } /// Return the size in bits of the specified type, for which isSCEVable must /// return true. uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); return getDataLayout().getTypeSizeInBits(Ty); } /// Return a type with the same bitwidth as the given type and which represents /// how SCEV will treat the given type, for which isSCEVable must return /// true. For pointer types, this is the pointer-sized integer type. Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); if (Ty->isIntegerTy()) return Ty; // The only other support type is pointer. assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!"); return getDataLayout().getIntPtrType(Ty); } const SCEV *ScalarEvolution::getCouldNotCompute() { return CouldNotCompute.get(); } bool ScalarEvolution::checkValidity(const SCEV *S) const { // Helper class working with SCEVTraversal to figure out if a SCEV contains // a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne // is set iff if find such SCEVUnknown. // struct FindInvalidSCEVUnknown { bool FindOne; FindInvalidSCEVUnknown() { FindOne = false; } bool follow(const SCEV *S) { switch (static_cast(S->getSCEVType())) { case scConstant: return false; case scUnknown: if (!cast(S)->getValue()) FindOne = true; return false; default: return true; } } bool isDone() const { return FindOne; } }; FindInvalidSCEVUnknown F; SCEVTraversal ST(F); ST.visitAll(S); return !F.FindOne; } namespace { // Helper class working with SCEVTraversal to figure out if a SCEV contains // a sub SCEV of scAddRecExpr type. FindInvalidSCEVUnknown::FoundOne is set // iff if such sub scAddRecExpr type SCEV is found. struct FindAddRecurrence { bool FoundOne; FindAddRecurrence() : FoundOne(false) {} bool follow(const SCEV *S) { switch (static_cast(S->getSCEVType())) { case scAddRecExpr: FoundOne = true; case scConstant: case scUnknown: case scCouldNotCompute: return false; default: return true; } } bool isDone() const { return FoundOne; } }; } bool ScalarEvolution::containsAddRecurrence(const SCEV *S) { HasRecMapType::iterator I = HasRecMap.find_as(S); if (I != HasRecMap.end()) return I->second; FindAddRecurrence F; SCEVTraversal ST(F); ST.visitAll(S); HasRecMap.insert({S, F.FoundOne}); return F.FoundOne; } -/// Return the Value set from S. -SetVector *ScalarEvolution::getSCEVValues(const SCEV *S) { +/// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}. +/// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an +/// offset I, then return {S', I}, else return {\p S, nullptr}. +static std::pair splitAddExpr(const SCEV *S) { + const auto *Add = dyn_cast(S); + if (!Add) + return {S, nullptr}; + + if (Add->getNumOperands() != 2) + return {S, nullptr}; + + auto *ConstOp = dyn_cast(Add->getOperand(0)); + if (!ConstOp) + return {S, nullptr}; + + return {Add->getOperand(1), ConstOp->getValue()}; +} + +/// Return the ValueOffsetPair set for \p S. \p S can be represented +/// by the value and offset from any ValueOffsetPair in the set. +SetVector * +ScalarEvolution::getSCEVValues(const SCEV *S) { ExprValueMapType::iterator SI = ExprValueMap.find_as(S); if (SI == ExprValueMap.end()) return nullptr; #ifndef NDEBUG if (VerifySCEVMap) { // Check there is no dangling Value in the set returned. for (const auto &VE : SI->second) - assert(ValueExprMap.count(VE)); + assert(ValueExprMap.count(VE.first)); } #endif return &SI->second; } -/// Erase Value from ValueExprMap and ExprValueMap. If ValueExprMap.erase(V) is -/// not used together with forgetMemoizedResults(S), eraseValueFromMap should be -/// used instead to ensure whenever V->S is removed from ValueExprMap, V is also -/// removed from the set of ExprValueMap[S]. +/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V) +/// cannot be used separately. eraseValueFromMap should be used to remove +/// V from ValueExprMap and ExprValueMap at the same time. void ScalarEvolution::eraseValueFromMap(Value *V) { ValueExprMapType::iterator I = ValueExprMap.find_as(V); if (I != ValueExprMap.end()) { const SCEV *S = I->second; - SetVector *SV = getSCEVValues(S); - // Remove V from the set of ExprValueMap[S] - if (SV) - SV->remove(V); + // Remove {V, 0} from the set of ExprValueMap[S] + if (SetVector *SV = getSCEVValues(S)) + SV->remove({V, nullptr}); + + // Remove {V, Offset} from the set of ExprValueMap[Stripped] + const SCEV *Stripped; + ConstantInt *Offset; + std::tie(Stripped, Offset) = splitAddExpr(S); + if (Offset != nullptr) { + if (SetVector *SV = getSCEVValues(Stripped)) + SV->remove({V, Offset}); + } ValueExprMap.erase(V); } } /// Return an existing SCEV if it exists, otherwise analyze the expression and /// create a new one. const SCEV *ScalarEvolution::getSCEV(Value *V) { assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); const SCEV *S = getExistingSCEV(V); if (S == nullptr) { S = createSCEV(V); // During PHI resolution, it is possible to create two SCEVs for the same // V, so it is needed to double check whether V->S is inserted into - // ValueExprMap before insert S->V into ExprValueMap. + // ValueExprMap before insert S->{V, 0} into ExprValueMap. std::pair Pair = ValueExprMap.insert({SCEVCallbackVH(V, this), S}); - if (Pair.second) - ExprValueMap[S].insert(V); + if (Pair.second) { + ExprValueMap[S].insert({V, nullptr}); + + // If S == Stripped + Offset, add Stripped -> {V, Offset} into + // ExprValueMap. + const SCEV *Stripped = S; + ConstantInt *Offset = nullptr; + std::tie(Stripped, Offset) = splitAddExpr(S); + // If stripped is SCEVUnknown, don't bother to save + // Stripped -> {V, offset}. It doesn't simplify and sometimes even + // increase the complexity of the expansion code. + // If V is GetElementPtrInst, don't save Stripped -> {V, offset} + // because it may generate add/sub instead of GEP in SCEV expansion. + if (Offset != nullptr && !isa(Stripped) && + !isa(V)) + ExprValueMap[Stripped].insert({V, Offset}); + } } return S; } const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); ValueExprMapType::iterator I = ValueExprMap.find_as(V); if (I != ValueExprMap.end()) { const SCEV *S = I->second; if (checkValidity(S)) return S; + eraseValueFromMap(V); forgetMemoizedResults(S); - ValueExprMap.erase(I); } return nullptr; } /// Return a SCEV corresponding to -V = -1*V /// const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags) { if (const SCEVConstant *VC = dyn_cast(V)) return getConstant( cast(ConstantExpr::getNeg(VC->getValue()))); Type *Ty = V->getType(); Ty = getEffectiveSCEVType(Ty); return getMulExpr( V, getConstant(cast(Constant::getAllOnesValue(Ty))), Flags); } /// Return a SCEV corresponding to ~V = -1-V const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { if (const SCEVConstant *VC = dyn_cast(V)) return getConstant( cast(ConstantExpr::getNot(VC->getValue()))); Type *Ty = V->getType(); Ty = getEffectiveSCEVType(Ty); const SCEV *AllOnes = getConstant(cast(Constant::getAllOnesValue(Ty))); return getMinusSCEV(AllOnes, V); } const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags) { // Fast path: X - X --> 0. if (LHS == RHS) return getZero(LHS->getType()); // We represent LHS - RHS as LHS + (-1)*RHS. This transformation // makes it so that we cannot make much use of NUW. auto AddFlags = SCEV::FlagAnyWrap; const bool RHSIsNotMinSigned = !getSignedRange(RHS).getSignedMin().isMinSignedValue(); if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) { // Let M be the minimum representable signed value. Then (-1)*RHS // signed-wraps if and only if RHS is M. That can happen even for // a NSW subtraction because e.g. (-1)*M signed-wraps even though // -1 - M does not. So to transfer NSW from LHS - RHS to LHS + // (-1)*RHS, we need to prove that RHS != M. // // If LHS is non-negative and we know that LHS - RHS does not // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap // either by proving that RHS > M or that LHS >= 0. if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) { AddFlags = SCEV::FlagNSW; } } // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS - // RHS is NSW and LHS >= 0. // // The difficulty here is that the NSW flag may have been proven // relative to a loop that is to be found in a recurrence in LHS and // not in RHS. Applying NSW to (-1)*M may then let the NSW have a // larger scope than intended. auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap; return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags); } const SCEV * ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot truncate or zero extend with non-integer arguments!"); if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) return V; // No conversion if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty)) return getTruncateExpr(V, Ty); return getZeroExtendExpr(V, Ty); } const SCEV * ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot truncate or zero extend with non-integer arguments!"); if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) return V; // No conversion if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty)) return getTruncateExpr(V, Ty); return getSignExtendExpr(V, Ty); } const SCEV * ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot noop or zero extend with non-integer arguments!"); assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && "getNoopOrZeroExtend cannot truncate!"); if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) return V; // No conversion return getZeroExtendExpr(V, Ty); } const SCEV * ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot noop or sign extend with non-integer arguments!"); assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && "getNoopOrSignExtend cannot truncate!"); if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) return V; // No conversion return getSignExtendExpr(V, Ty); } const SCEV * ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot noop or any extend with non-integer arguments!"); assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && "getNoopOrAnyExtend cannot truncate!"); if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) return V; // No conversion return getAnyExtendExpr(V, Ty); } const SCEV * ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && (Ty->isIntegerTy() || Ty->isPointerTy()) && "Cannot truncate or noop with non-integer arguments!"); assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) && "getTruncateOrNoop cannot extend!"); if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) return V; // No conversion return getTruncateExpr(V, Ty); } const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS) { const SCEV *PromotedLHS = LHS; const SCEV *PromotedRHS = RHS; if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType())) PromotedRHS = getZeroExtendExpr(RHS, LHS->getType()); else PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType()); return getUMaxExpr(PromotedLHS, PromotedRHS); } const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS) { const SCEV *PromotedLHS = LHS; const SCEV *PromotedRHS = RHS; if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType())) PromotedRHS = getZeroExtendExpr(RHS, LHS->getType()); else PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType()); return getUMinExpr(PromotedLHS, PromotedRHS); } const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { // A pointer operand may evaluate to a nonpointer expression, such as null. if (!V->getType()->isPointerTy()) return V; if (const SCEVCastExpr *Cast = dyn_cast(V)) { return getPointerBase(Cast->getOperand()); } else if (const SCEVNAryExpr *NAry = dyn_cast(V)) { const SCEV *PtrOp = nullptr; for (const SCEV *NAryOp : NAry->operands()) { if (NAryOp->getType()->isPointerTy()) { // Cannot find the base of an expression with multiple pointer operands. if (PtrOp) return V; PtrOp = NAryOp; } } if (!PtrOp) return V; return getPointerBase(PtrOp); } return V; } /// Push users of the given Instruction onto the given Worklist. static void PushDefUseChildren(Instruction *I, SmallVectorImpl &Worklist) { // Push the def-use children onto the Worklist stack. for (User *U : I->users()) Worklist.push_back(cast(U)); } void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) { SmallVector Worklist; PushDefUseChildren(PN, Worklist); SmallPtrSet Visited; Visited.insert(PN); while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); if (!Visited.insert(I).second) continue; auto It = ValueExprMap.find_as(static_cast(I)); if (It != ValueExprMap.end()) { const SCEV *Old = It->second; // Short-circuit the def-use traversal if the symbolic name // ceases to appear in expressions. if (Old != SymName && !hasOperand(Old, SymName)) continue; // SCEVUnknown for a PHI either means that it has an unrecognized // structure, it's a PHI that's in the progress of being computed // by createNodeForPHI, or it's a single-value PHI. In the first case, // additional loop trip count information isn't going to change anything. // In the second case, createNodeForPHI will perform the necessary // updates on its own when it gets to that point. In the third, we do // want to forget the SCEVUnknown. if (!isa(I) || !isa(Old) || (I != PN && Old == SymName)) { + eraseValueFromMap(It->first); forgetMemoizedResults(Old); - ValueExprMap.erase(It); } } PushDefUseChildren(I, Worklist); } } namespace { class SCEVInitRewriter : public SCEVRewriteVisitor { public: static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) { SCEVInitRewriter Rewriter(L, SE); const SCEV *Result = Rewriter.visit(S); return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); } SCEVInitRewriter(const Loop *L, ScalarEvolution &SE) : SCEVRewriteVisitor(SE), L(L), Valid(true) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) Valid = false; return Expr; } const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { // Only allow AddRecExprs for this loop. if (Expr->getLoop() == L) return Expr->getStart(); Valid = false; return Expr; } bool isValid() { return Valid; } private: const Loop *L; bool Valid; }; class SCEVShiftRewriter : public SCEVRewriteVisitor { public: static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) { SCEVShiftRewriter Rewriter(L, SE); const SCEV *Result = Rewriter.visit(S); return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); } SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE) : SCEVRewriteVisitor(SE), L(L), Valid(true) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { // Only allow AddRecExprs for this loop. if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) Valid = false; return Expr; } const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { if (Expr->getLoop() == L && Expr->isAffine()) return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE)); Valid = false; return Expr; } bool isValid() { return Valid; } private: const Loop *L; bool Valid; }; } // end anonymous namespace SCEV::NoWrapFlags ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) { if (!AR->isAffine()) return SCEV::FlagAnyWrap; typedef OverflowingBinaryOperator OBO; SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap; if (!AR->hasNoSignedWrap()) { ConstantRange AddRecRange = getSignedRange(AR); ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this)); auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( Instruction::Add, IncRange, OBO::NoSignedWrap); if (NSWRegion.contains(AddRecRange)) Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW); } if (!AR->hasNoUnsignedWrap()) { ConstantRange AddRecRange = getUnsignedRange(AR); ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this)); auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( Instruction::Add, IncRange, OBO::NoUnsignedWrap); if (NUWRegion.contains(AddRecRange)) Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW); } return Result; } namespace { /// Represents an abstract binary operation. This may exist as a /// normal instruction or constant expression, or may have been /// derived from an expression tree. struct BinaryOp { unsigned Opcode; Value *LHS; Value *RHS; bool IsNSW; bool IsNUW; /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or /// constant expression. Operator *Op; explicit BinaryOp(Operator *Op) : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)), IsNSW(false), IsNUW(false), Op(Op) { if (auto *OBO = dyn_cast(Op)) { IsNSW = OBO->hasNoSignedWrap(); IsNUW = OBO->hasNoUnsignedWrap(); } } explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false, bool IsNUW = false) : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW), Op(nullptr) {} }; } /// Try to map \p V into a BinaryOp, and return \c None on failure. static Optional MatchBinaryOp(Value *V, DominatorTree &DT) { auto *Op = dyn_cast(V); if (!Op) return None; // Implementation detail: all the cleverness here should happen without // creating new SCEV expressions -- our caller knowns tricks to avoid creating // SCEV expressions when possible, and we should not break that. switch (Op->getOpcode()) { case Instruction::Add: case Instruction::Sub: case Instruction::Mul: case Instruction::UDiv: case Instruction::And: case Instruction::Or: case Instruction::AShr: case Instruction::Shl: return BinaryOp(Op); case Instruction::Xor: if (auto *RHSC = dyn_cast(Op->getOperand(1))) // If the RHS of the xor is a signbit, then this is just an add. // Instcombine turns add of signbit into xor as a strength reduction step. if (RHSC->getValue().isSignBit()) return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1)); return BinaryOp(Op); case Instruction::LShr: // Turn logical shift right of a constant into a unsigned divide. if (ConstantInt *SA = dyn_cast(Op->getOperand(1))) { uint32_t BitWidth = cast(Op->getType())->getBitWidth(); // If the shift count is not less than the bitwidth, the result of // the shift is undefined. Don't try to analyze it, because the // resolution chosen here may differ from the resolution chosen in // other parts of the compiler. if (SA->getValue().ult(BitWidth)) { Constant *X = ConstantInt::get(SA->getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue())); return BinaryOp(Instruction::UDiv, Op->getOperand(0), X); } } return BinaryOp(Op); case Instruction::ExtractValue: { auto *EVI = cast(Op); if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0) break; auto *CI = dyn_cast(EVI->getAggregateOperand()); if (!CI) break; if (auto *F = CI->getCalledFunction()) switch (F->getIntrinsicID()) { case Intrinsic::sadd_with_overflow: case Intrinsic::uadd_with_overflow: { if (!isOverflowIntrinsicNoWrap(cast(CI), DT)) return BinaryOp(Instruction::Add, CI->getArgOperand(0), CI->getArgOperand(1)); // Now that we know that all uses of the arithmetic-result component of // CI are guarded by the overflow check, we can go ahead and pretend // that the arithmetic is non-overflowing. if (F->getIntrinsicID() == Intrinsic::sadd_with_overflow) return BinaryOp(Instruction::Add, CI->getArgOperand(0), CI->getArgOperand(1), /* IsNSW = */ true, /* IsNUW = */ false); else return BinaryOp(Instruction::Add, CI->getArgOperand(0), CI->getArgOperand(1), /* IsNSW = */ false, /* IsNUW*/ true); } case Intrinsic::ssub_with_overflow: case Intrinsic::usub_with_overflow: return BinaryOp(Instruction::Sub, CI->getArgOperand(0), CI->getArgOperand(1)); case Intrinsic::smul_with_overflow: case Intrinsic::umul_with_overflow: return BinaryOp(Instruction::Mul, CI->getArgOperand(0), CI->getArgOperand(1)); default: break; } } default: break; } return None; } const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const Loop *L = LI.getLoopFor(PN->getParent()); if (!L || L->getHeader() != PN->getParent()) return nullptr; // The loop may have multiple entrances or multiple exits; we can analyze // this phi as an addrec if it has a unique entry value and a unique // backedge value. Value *BEValueV = nullptr, *StartValueV = nullptr; for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { Value *V = PN->getIncomingValue(i); if (L->contains(PN->getIncomingBlock(i))) { if (!BEValueV) { BEValueV = V; } else if (BEValueV != V) { BEValueV = nullptr; break; } } else if (!StartValueV) { StartValueV = V; } else if (StartValueV != V) { StartValueV = nullptr; break; } } if (BEValueV && StartValueV) { // While we are analyzing this PHI node, handle its value symbolically. const SCEV *SymbolicName = getUnknown(PN); assert(ValueExprMap.find_as(PN) == ValueExprMap.end() && "PHI node already processed?"); ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName}); // Using this symbolic name for the PHI, analyze the value coming around // the back-edge. const SCEV *BEValue = getSCEV(BEValueV); // NOTE: If BEValue is loop invariant, we know that the PHI node just // has a special value for the first iteration of the loop. // If the value coming around the backedge is an add with the symbolic // value we just inserted, then we found a simple induction variable! if (const SCEVAddExpr *Add = dyn_cast(BEValue)) { // If there is a single occurrence of the symbolic value, replace it // with a recurrence. unsigned FoundIndex = Add->getNumOperands(); for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) if (Add->getOperand(i) == SymbolicName) if (FoundIndex == e) { FoundIndex = i; break; } if (FoundIndex != Add->getNumOperands()) { // Create an add with everything but the specified operand. SmallVector Ops; for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) if (i != FoundIndex) Ops.push_back(Add->getOperand(i)); const SCEV *Accum = getAddExpr(Ops); // This is not a valid addrec if the step amount is varying each // loop iteration, but is not itself an addrec in this loop. if (isLoopInvariant(Accum, L) || (isa(Accum) && cast(Accum)->getLoop() == L)) { SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; if (auto BO = MatchBinaryOp(BEValueV, DT)) { if (BO->Opcode == Instruction::Add && BO->LHS == PN) { if (BO->IsNUW) Flags = setFlags(Flags, SCEV::FlagNUW); if (BO->IsNSW) Flags = setFlags(Flags, SCEV::FlagNSW); } } else if (GEPOperator *GEP = dyn_cast(BEValueV)) { // If the increment is an inbounds GEP, then we know the address // space cannot be wrapped around. We cannot make any guarantee // about signed or unsigned overflow because pointers are // unsigned but we may have a negative index from the base // pointer. We can guarantee that no unsigned wrap occurs if the // indices form a positive value. if (GEP->isInBounds() && GEP->getOperand(0) == PN) { Flags = setFlags(Flags, SCEV::FlagNW); const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) Flags = setFlags(Flags, SCEV::FlagNUW); } // We cannot transfer nuw and nsw flags from subtraction // operations -- sub nuw X, Y is not the same as add nuw X, -Y // for instance. } const SCEV *StartVal = getSCEV(StartValueV); const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the // entries for the scalars that use the symbolic expression. forgetSymbolicName(PN, SymbolicName); ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; // We can add Flags to the post-inc expression only if we // know that it us *undefined behavior* for BEValueV to // overflow. if (auto *BEInst = dyn_cast(BEValueV)) if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L)) (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); return PHISCEV; } } } else { // Otherwise, this could be a loop like this: // i = 0; for (j = 1; ..; ++j) { .... i = j; } // In this case, j = {1,+,1} and BEValue is j. // Because the other in-value of i (0) fits the evolution of BEValue // i really is an addrec evolution. // // We can generalize this saying that i is the shifted value of BEValue // by one iteration: // PHI(f(0), f({1,+,1})) --> f({0,+,1}) const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this); if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute()) { const SCEV *StartVal = getSCEV(StartValueV); if (Start == StartVal) { // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the // entries for the scalars that use the symbolic expression. forgetSymbolicName(PN, SymbolicName); ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted; return Shifted; } } } // Remove the temporary PHI node SCEV that has been inserted while intending // to create an AddRecExpr for this PHI node. We can not keep this temporary // as it will prevent later (possibly simpler) SCEV expressions to be added // to the ValueExprMap. - ValueExprMap.erase(PN); + eraseValueFromMap(PN); } return nullptr; } // Checks if the SCEV S is available at BB. S is considered available at BB // if S can be materialized at BB without introducing a fault. static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S, BasicBlock *BB) { struct CheckAvailable { bool TraversalDone = false; bool Available = true; const Loop *L = nullptr; // The loop BB is in (can be nullptr) BasicBlock *BB = nullptr; DominatorTree &DT; CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT) : L(L), BB(BB), DT(DT) {} bool setUnavailable() { TraversalDone = true; Available = false; return false; } bool follow(const SCEV *S) { switch (S->getSCEVType()) { case scConstant: case scTruncate: case scZeroExtend: case scSignExtend: case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr: // These expressions are available if their operand(s) is/are. return true; case scAddRecExpr: { // We allow add recurrences that are on the loop BB is in, or some // outer loop. This guarantees availability because the value of the // add recurrence at BB is simply the "current" value of the induction // variable. We can relax this in the future; for instance an add // recurrence on a sibling dominating loop is also available at BB. const auto *ARLoop = cast(S)->getLoop(); if (L && (ARLoop == L || ARLoop->contains(L))) return true; return setUnavailable(); } case scUnknown: { // For SCEVUnknown, we check for simple dominance. const auto *SU = cast(S); Value *V = SU->getValue(); if (isa(V)) return false; if (isa(V) && DT.dominates(cast(V), BB)) return false; return setUnavailable(); } case scUDivExpr: case scCouldNotCompute: // We do not try to smart about these at all. return setUnavailable(); } llvm_unreachable("switch should be fully covered!"); } bool isDone() { return TraversalDone; } }; CheckAvailable CA(L, BB, DT); SCEVTraversal ST(CA); ST.visitAll(S); return CA.Available; } // Try to match a control flow sequence that branches out at BI and merges back // at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful // match. static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS) { C = BI->getCondition(); BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0)); BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1)); if (!LeftEdge.isSingleEdge()) return false; assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()"); Use &LeftUse = Merge->getOperandUse(0); Use &RightUse = Merge->getOperandUse(1); if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) { LHS = LeftUse; RHS = RightUse; return true; } if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) { LHS = RightUse; RHS = LeftUse; return true; } return false; } const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { if (PN->getNumIncomingValues() == 2) { const Loop *L = LI.getLoopFor(PN->getParent()); // We don't want to break LCSSA, even in a SCEV expression tree. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) if (LI.getLoopFor(PN->getIncomingBlock(i)) != L) return nullptr; // Try to match // // br %cond, label %left, label %right // left: // br label %merge // right: // br label %merge // merge: // V = phi [ %x, %left ], [ %y, %right ] // // as "select %cond, %x, %y" BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock(); assert(IDom && "At least the entry block should dominate PN"); auto *BI = dyn_cast(IDom->getTerminator()); Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr; if (BI && BI->isConditional() && BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) && IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) && IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent())) return createNodeForSelectOrPHI(PN, Cond, LHS, RHS); } return nullptr; } const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { if (const SCEV *S = createAddRecFromPHI(PN)) return S; if (const SCEV *S = createNodeFromSelectLikePHI(PN)) return S; // If the PHI has a single incoming value, follow that value, unless the // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. if (Value *V = SimplifyInstruction(PN, getDataLayout(), &TLI, &DT, &AC)) if (LI.replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); // If it's not a loop phi, we can't handle it yet. return getUnknown(PN); } const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I, Value *Cond, Value *TrueVal, Value *FalseVal) { // Handle "constant" branch or select. This can occur for instance when a // loop pass transforms an inner loop and moves on to process the outer loop. if (auto *CI = dyn_cast(Cond)) return getSCEV(CI->isOne() ? TrueVal : FalseVal); // Try to match some simple smax or umax patterns. auto *ICI = dyn_cast(Cond); if (!ICI) return getUnknown(I); Value *LHS = ICI->getOperand(0); Value *RHS = ICI->getOperand(1); switch (ICI->getPredicate()) { case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: std::swap(LHS, RHS); // fall through case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: // a >s b ? a+x : b+x -> smax(a, b)+x // a >s b ? b+x : a+x -> smin(a, b)+x if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) { const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType()); const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType()); const SCEV *LA = getSCEV(TrueVal); const SCEV *RA = getSCEV(FalseVal); const SCEV *LDiff = getMinusSCEV(LA, LS); const SCEV *RDiff = getMinusSCEV(RA, RS); if (LDiff == RDiff) return getAddExpr(getSMaxExpr(LS, RS), LDiff); LDiff = getMinusSCEV(LA, RS); RDiff = getMinusSCEV(RA, LS); if (LDiff == RDiff) return getAddExpr(getSMinExpr(LS, RS), LDiff); } break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: std::swap(LHS, RHS); // fall through case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: // a >u b ? a+x : b+x -> umax(a, b)+x // a >u b ? b+x : a+x -> umin(a, b)+x if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) { const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType()); const SCEV *LA = getSCEV(TrueVal); const SCEV *RA = getSCEV(FalseVal); const SCEV *LDiff = getMinusSCEV(LA, LS); const SCEV *RDiff = getMinusSCEV(RA, RS); if (LDiff == RDiff) return getAddExpr(getUMaxExpr(LS, RS), LDiff); LDiff = getMinusSCEV(LA, RS); RDiff = getMinusSCEV(RA, LS); if (LDiff == RDiff) return getAddExpr(getUMinExpr(LS, RS), LDiff); } break; case ICmpInst::ICMP_NE: // n != 0 ? n+x : 1+x -> umax(n, 1)+x if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) && isa(RHS) && cast(RHS)->isZero()) { const SCEV *One = getOne(I->getType()); const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); const SCEV *LA = getSCEV(TrueVal); const SCEV *RA = getSCEV(FalseVal); const SCEV *LDiff = getMinusSCEV(LA, LS); const SCEV *RDiff = getMinusSCEV(RA, One); if (LDiff == RDiff) return getAddExpr(getUMaxExpr(One, LS), LDiff); } break; case ICmpInst::ICMP_EQ: // n == 0 ? 1+x : n+x -> umax(n, 1)+x if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) && isa(RHS) && cast(RHS)->isZero()) { const SCEV *One = getOne(I->getType()); const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); const SCEV *LA = getSCEV(TrueVal); const SCEV *RA = getSCEV(FalseVal); const SCEV *LDiff = getMinusSCEV(LA, One); const SCEV *RDiff = getMinusSCEV(RA, LS); if (LDiff == RDiff) return getAddExpr(getUMaxExpr(One, LS), LDiff); } break; default: break; } return getUnknown(I); } /// Expand GEP instructions into add and multiply operations. This allows them /// to be analyzed by regular SCEV code. const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { // Don't attempt to analyze GEPs over unsized objects. if (!GEP->getSourceElementType()->isSized()) return getUnknown(GEP); SmallVector IndexExprs; for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) IndexExprs.push_back(getSCEV(*Index)); return getGEPExpr(GEP->getSourceElementType(), getSCEV(GEP->getPointerOperand()), IndexExprs, GEP->isInBounds()); } uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { if (const SCEVConstant *C = dyn_cast(S)) return C->getAPInt().countTrailingZeros(); if (const SCEVTruncateExpr *T = dyn_cast(S)) return std::min(GetMinTrailingZeros(T->getOperand()), (uint32_t)getTypeSizeInBits(T->getType())); if (const SCEVZeroExtendExpr *E = dyn_cast(S)) { uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? getTypeSizeInBits(E->getType()) : OpRes; } if (const SCEVSignExtendExpr *E = dyn_cast(S)) { uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? getTypeSizeInBits(E->getType()) : OpRes; } if (const SCEVAddExpr *A = dyn_cast(S)) { // The result is the min of all operands results. uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); return MinOpRes; } if (const SCEVMulExpr *M = dyn_cast(S)) { // The result is the sum of all operands results. uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); uint32_t BitWidth = getTypeSizeInBits(M->getType()); for (unsigned i = 1, e = M->getNumOperands(); SumOpRes != BitWidth && i != e; ++i) SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth); return SumOpRes; } if (const SCEVAddRecExpr *A = dyn_cast(S)) { // The result is the min of all operands results. uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); return MinOpRes; } if (const SCEVSMaxExpr *M = dyn_cast(S)) { // The result is the min of all operands results. uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); return MinOpRes; } if (const SCEVUMaxExpr *M = dyn_cast(S)) { // The result is the min of all operands results. uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); return MinOpRes; } if (const SCEVUnknown *U = dyn_cast(S)) { // For a SCEVUnknown, ask ValueTracking. unsigned BitWidth = getTypeSizeInBits(U->getType()); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); computeKnownBits(U->getValue(), Zeros, Ones, getDataLayout(), 0, &AC, nullptr, &DT); return Zeros.countTrailingOnes(); } // SCEVUDivExpr return 0; } /// Helper method to assign a range to V from metadata present in the IR. static Optional GetRangeFromMetadata(Value *V) { if (Instruction *I = dyn_cast(V)) if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) return getConstantRangeFromMetadata(*MD); return None; } /// Determine the range for a particular SCEV. If SignHint is /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges /// with a "cleaner" unsigned (resp. signed) representation. ConstantRange ScalarEvolution::getRange(const SCEV *S, ScalarEvolution::RangeSignHint SignHint) { DenseMap &Cache = SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges : SignedRanges; // See if we've computed this range already. DenseMap::iterator I = Cache.find(S); if (I != Cache.end()) return I->second; if (const SCEVConstant *C = dyn_cast(S)) return setRange(C, SignHint, ConstantRange(C->getAPInt())); unsigned BitWidth = getTypeSizeInBits(S->getType()); ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); // If the value has known zeros, the maximum value will have those known zeros // as well. uint32_t TZ = GetMinTrailingZeros(S); if (TZ != 0) { if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) ConservativeResult = ConstantRange(APInt::getMinValue(BitWidth), APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1); else ConservativeResult = ConstantRange( APInt::getSignedMinValue(BitWidth), APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); } if (const SCEVAddExpr *Add = dyn_cast(S)) { ConstantRange X = getRange(Add->getOperand(0), SignHint); for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) X = X.add(getRange(Add->getOperand(i), SignHint)); return setRange(Add, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVMulExpr *Mul = dyn_cast(S)) { ConstantRange X = getRange(Mul->getOperand(0), SignHint); for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) X = X.multiply(getRange(Mul->getOperand(i), SignHint)); return setRange(Mul, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVSMaxExpr *SMax = dyn_cast(S)) { ConstantRange X = getRange(SMax->getOperand(0), SignHint); for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) X = X.smax(getRange(SMax->getOperand(i), SignHint)); return setRange(SMax, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVUMaxExpr *UMax = dyn_cast(S)) { ConstantRange X = getRange(UMax->getOperand(0), SignHint); for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) X = X.umax(getRange(UMax->getOperand(i), SignHint)); return setRange(UMax, SignHint, ConservativeResult.intersectWith(X)); } if (const SCEVUDivExpr *UDiv = dyn_cast(S)) { ConstantRange X = getRange(UDiv->getLHS(), SignHint); ConstantRange Y = getRange(UDiv->getRHS(), SignHint); return setRange(UDiv, SignHint, ConservativeResult.intersectWith(X.udiv(Y))); } if (const SCEVZeroExtendExpr *ZExt = dyn_cast(S)) { ConstantRange X = getRange(ZExt->getOperand(), SignHint); return setRange(ZExt, SignHint, ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); } if (const SCEVSignExtendExpr *SExt = dyn_cast(S)) { ConstantRange X = getRange(SExt->getOperand(), SignHint); return setRange(SExt, SignHint, ConservativeResult.intersectWith(X.signExtend(BitWidth))); } if (const SCEVTruncateExpr *Trunc = dyn_cast(S)) { ConstantRange X = getRange(Trunc->getOperand(), SignHint); return setRange(Trunc, SignHint, ConservativeResult.intersectWith(X.truncate(BitWidth))); } if (const SCEVAddRecExpr *AddRec = dyn_cast(S)) { // If there's no unsigned wrap, the value will never be less than its // initial value. if (AddRec->hasNoUnsignedWrap()) if (const SCEVConstant *C = dyn_cast(AddRec->getStart())) if (!C->getValue()->isZero()) ConservativeResult = ConservativeResult.intersectWith( ConstantRange(C->getAPInt(), APInt(BitWidth, 0))); // If there's no signed wrap, and all the operands have the same sign or // zero, the value won't ever change sign. if (AddRec->hasNoSignedWrap()) { bool AllNonNeg = true; bool AllNonPos = true; for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false; if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false; } if (AllNonNeg) ConservativeResult = ConservativeResult.intersectWith( ConstantRange(APInt(BitWidth, 0), APInt::getSignedMinValue(BitWidth))); else if (AllNonPos) ConservativeResult = ConservativeResult.intersectWith( ConstantRange(APInt::getSignedMinValue(BitWidth), APInt(BitWidth, 1))); } // TODO: non-affine addrec if (AddRec->isAffine()) { const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa(MaxBECount) && getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { auto RangeFromAffine = getRangeForAffineAR( AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, BitWidth); if (!RangeFromAffine.isFullSet()) ConservativeResult = ConservativeResult.intersectWith(RangeFromAffine); auto RangeFromFactoring = getRangeViaFactoring( AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, BitWidth); if (!RangeFromFactoring.isFullSet()) ConservativeResult = ConservativeResult.intersectWith(RangeFromFactoring); } } return setRange(AddRec, SignHint, ConservativeResult); } if (const SCEVUnknown *U = dyn_cast(S)) { // Check if the IR explicitly contains !range metadata. Optional MDRange = GetRangeFromMetadata(U->getValue()); if (MDRange.hasValue()) ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue()); // Split here to avoid paying the compile-time cost of calling both // computeKnownBits and ComputeNumSignBits. This restriction can be lifted // if needed. const DataLayout &DL = getDataLayout(); if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { // For a SCEVUnknown, ask ValueTracking. APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, &AC, nullptr, &DT); if (Ones != ~Zeros + 1) ConservativeResult = ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)); } else { assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED && "generalize as needed!"); unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT); if (NS > 1) ConservativeResult = ConservativeResult.intersectWith( ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1)); } return setRange(U, SignHint, ConservativeResult); } return setRange(S, SignHint, ConservativeResult); } ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, const SCEV *Step, const SCEV *MaxBECount, unsigned BitWidth) { assert(!isa(MaxBECount) && getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && "Precondition!"); ConstantRange Result(BitWidth, /* isFullSet = */ true); // Check for overflow. This must be done with ConstantRange arithmetic // because we could be called from within the ScalarEvolution overflow // checking code. MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType()); ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); ConstantRange ZExtMaxBECountRange = MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1); ConstantRange StepSRange = getSignedRange(Step); ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1); ConstantRange StartURange = getUnsignedRange(Start); ConstantRange EndURange = StartURange.add(MaxBECountRange.multiply(StepSRange)); // Check for unsigned overflow. ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2 + 1); ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1); if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == ZExtEndURange) { APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), EndURange.getUnsignedMin()); APInt Max = APIntOps::umax(StartURange.getUnsignedMax(), EndURange.getUnsignedMax()); bool IsFullRange = Min.isMinValue() && Max.isMaxValue(); if (!IsFullRange) Result = Result.intersectWith(ConstantRange(Min, Max + 1)); } ConstantRange StartSRange = getSignedRange(Start); ConstantRange EndSRange = StartSRange.add(MaxBECountRange.multiply(StepSRange)); // Check for signed overflow. This must be done with ConstantRange // arithmetic because we could be called from within the ScalarEvolution // overflow checking code. ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2 + 1); ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1); if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == SExtEndSRange) { APInt Min = APIntOps::smin(StartSRange.getSignedMin(), EndSRange.getSignedMin()); APInt Max = APIntOps::smax(StartSRange.getSignedMax(), EndSRange.getSignedMax()); bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue(); if (!IsFullRange) Result = Result.intersectWith(ConstantRange(Min, Max + 1)); } return Result; } ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, const SCEV *Step, const SCEV *MaxBECount, unsigned BitWidth) { // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q}) // == RangeOf({A,+,P}) union RangeOf({B,+,Q}) struct SelectPattern { Value *Condition = nullptr; APInt TrueValue; APInt FalseValue; explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth, const SCEV *S) { Optional CastOp; APInt Offset(BitWidth, 0); assert(SE.getTypeSizeInBits(S->getType()) == BitWidth && "Should be!"); // Peel off a constant offset: if (auto *SA = dyn_cast(S)) { // In the future we could consider being smarter here and handle // {Start+Step,+,Step} too. if (SA->getNumOperands() != 2 || !isa(SA->getOperand(0))) return; Offset = cast(SA->getOperand(0))->getAPInt(); S = SA->getOperand(1); } // Peel off a cast operation if (auto *SCast = dyn_cast(S)) { CastOp = SCast->getSCEVType(); S = SCast->getOperand(); } using namespace llvm::PatternMatch; auto *SU = dyn_cast(S); const APInt *TrueVal, *FalseVal; if (!SU || !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal), m_APInt(FalseVal)))) { Condition = nullptr; return; } TrueValue = *TrueVal; FalseValue = *FalseVal; // Re-apply the cast we peeled off earlier if (CastOp.hasValue()) switch (*CastOp) { default: llvm_unreachable("Unknown SCEV cast type!"); case scTruncate: TrueValue = TrueValue.trunc(BitWidth); FalseValue = FalseValue.trunc(BitWidth); break; case scZeroExtend: TrueValue = TrueValue.zext(BitWidth); FalseValue = FalseValue.zext(BitWidth); break; case scSignExtend: TrueValue = TrueValue.sext(BitWidth); FalseValue = FalseValue.sext(BitWidth); break; } // Re-apply the constant offset we peeled off earlier TrueValue += Offset; FalseValue += Offset; } bool isRecognized() { return Condition != nullptr; } }; SelectPattern StartPattern(*this, BitWidth, Start); if (!StartPattern.isRecognized()) return ConstantRange(BitWidth, /* isFullSet = */ true); SelectPattern StepPattern(*this, BitWidth, Step); if (!StepPattern.isRecognized()) return ConstantRange(BitWidth, /* isFullSet = */ true); if (StartPattern.Condition != StepPattern.Condition) { // We don't handle this case today; but we could, by considering four // possibilities below instead of two. I'm not sure if there are cases where // that will help over what getRange already does, though. return ConstantRange(BitWidth, /* isFullSet = */ true); } // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to // construct arbitrary general SCEV expressions here. This function is called // from deep in the call stack, and calling getSCEV (on a sext instruction, // say) can end up caching a suboptimal value. // FIXME: without the explicit `this` receiver below, MSVC errors out with // C2352 and C2512 (otherwise it isn't needed). const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue); const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue); const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue); const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue); ConstantRange TrueRange = this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth); ConstantRange FalseRange = this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth); return TrueRange.unionWith(FalseRange); } SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { if (isa(V)) return SCEV::FlagAnyWrap; const BinaryOperator *BinOp = cast(V); // Return early if there are no flags to propagate to the SCEV. SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; if (BinOp->hasNoUnsignedWrap()) Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); if (BinOp->hasNoSignedWrap()) Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); if (Flags == SCEV::FlagAnyWrap) return SCEV::FlagAnyWrap; return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap; } bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) { // Here we check that I is in the header of the innermost loop containing I, // since we only deal with instructions in the loop header. The actual loop we // need to check later will come from an add recurrence, but getting that // requires computing the SCEV of the operands, which can be expensive. This // check we can do cheaply to rule out some cases early. Loop *InnermostContainingLoop = LI.getLoopFor(I->getParent()); if (InnermostContainingLoop == nullptr || InnermostContainingLoop->getHeader() != I->getParent()) return false; // Only proceed if we can prove that I does not yield poison. if (!isKnownNotFullPoison(I)) return false; // At this point we know that if I is executed, then it does not wrap // according to at least one of NSW or NUW. If I is not executed, then we do // not know if the calculation that I represents would wrap. Multiple // instructions can map to the same SCEV. If we apply NSW or NUW from I to // the SCEV, we must guarantee no wrapping for that SCEV also when it is // derived from other instructions that map to the same SCEV. We cannot make // that guarantee for cases where I is not executed. So we need to find the // loop that I is considered in relation to and prove that I is executed for // every iteration of that loop. That implies that the value that I // calculates does not wrap anywhere in the loop, so then we can apply the // flags to the SCEV. // // We check isLoopInvariant to disambiguate in case we are adding recurrences // from different loops, so that we know which loop to prove that I is // executed in. for (unsigned OpIndex = 0; OpIndex < I->getNumOperands(); ++OpIndex) { // I could be an extractvalue from a call to an overflow intrinsic. // TODO: We can do better here in some cases. if (!isSCEVable(I->getOperand(OpIndex)->getType())) return false; const SCEV *Op = getSCEV(I->getOperand(OpIndex)); if (auto *AddRec = dyn_cast(Op)) { bool AllOtherOpsLoopInvariant = true; for (unsigned OtherOpIndex = 0; OtherOpIndex < I->getNumOperands(); ++OtherOpIndex) { if (OtherOpIndex != OpIndex) { const SCEV *OtherOp = getSCEV(I->getOperand(OtherOpIndex)); if (!isLoopInvariant(OtherOp, AddRec->getLoop())) { AllOtherOpsLoopInvariant = false; break; } } } if (AllOtherOpsLoopInvariant && isGuaranteedToExecuteForEveryIteration(I, AddRec->getLoop())) return true; } } return false; } bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) { // If we know that \c I can never be poison period, then that's enough. if (isSCEVExprNeverPoison(I)) return true; // For an add recurrence specifically, we assume that infinite loops without // side effects are undefined behavior, and then reason as follows: // // If the add recurrence is poison in any iteration, it is poison on all // future iterations (since incrementing poison yields poison). If the result // of the add recurrence is fed into the loop latch condition and the loop // does not contain any throws or exiting blocks other than the latch, we now // have the ability to "choose" whether the backedge is taken or not (by // choosing a sufficiently evil value for the poison feeding into the branch) // for every iteration including and after the one in which \p I first became // poison. There are two possibilities (let's call the iteration in which \p // I first became poison as K): // // 1. In the set of iterations including and after K, the loop body executes // no side effects. In this case executing the backege an infinte number // of times will yield undefined behavior. // // 2. In the set of iterations including and after K, the loop body executes // at least one side effect. In this case, that specific instance of side // effect is control dependent on poison, which also yields undefined // behavior. auto *ExitingBB = L->getExitingBlock(); auto *LatchBB = L->getLoopLatch(); if (!ExitingBB || !LatchBB || ExitingBB != LatchBB) return false; SmallPtrSet Pushed; SmallVector PoisonStack; // We start by assuming \c I, the post-inc add recurrence, is poison. Only // things that are known to be fully poison under that assumption go on the // PoisonStack. Pushed.insert(I); PoisonStack.push_back(I); bool LatchControlDependentOnPoison = false; while (!PoisonStack.empty() && !LatchControlDependentOnPoison) { const Instruction *Poison = PoisonStack.pop_back_val(); for (auto *PoisonUser : Poison->users()) { if (propagatesFullPoison(cast(PoisonUser))) { if (Pushed.insert(cast(PoisonUser)).second) PoisonStack.push_back(cast(PoisonUser)); } else if (auto *BI = dyn_cast(PoisonUser)) { assert(BI->isConditional() && "Only possibility!"); if (BI->getParent() == LatchBB) { LatchControlDependentOnPoison = true; break; } } } } return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L); } bool ScalarEvolution::loopHasNoAbnormalExits(const Loop *L) { auto Itr = LoopHasNoAbnormalExits.find(L); if (Itr == LoopHasNoAbnormalExits.end()) { auto NoAbnormalExitInBB = [&](BasicBlock *BB) { return all_of(*BB, [](Instruction &I) { return isGuaranteedToTransferExecutionToSuccessor(&I); }); }; auto InsertPair = LoopHasNoAbnormalExits.insert( {L, all_of(L->getBlocks(), NoAbnormalExitInBB)}); assert(InsertPair.second && "We just checked!"); Itr = InsertPair.first; } return Itr->second; } const SCEV *ScalarEvolution::createSCEV(Value *V) { if (!isSCEVable(V->getType())) return getUnknown(V); if (Instruction *I = dyn_cast(V)) { // Don't attempt to analyze instructions in blocks that aren't // reachable. Such instructions don't matter, and they aren't required // to obey basic rules for definitions dominating uses which this // analysis depends on. if (!DT.isReachableFromEntry(I->getParent())) return getUnknown(V); } else if (ConstantInt *CI = dyn_cast(V)) return getConstant(CI); else if (isa(V)) return getZero(V->getType()); else if (GlobalAlias *GA = dyn_cast(V)) return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee()); else if (!isa(V)) return getUnknown(V); Operator *U = cast(V); if (auto BO = MatchBinaryOp(U, DT)) { switch (BO->Opcode) { case Instruction::Add: { // The simple thing to do would be to just call getSCEV on both operands // and call getAddExpr with the result. However if we're looking at a // bunch of things all added together, this can be quite inefficient, // because it leads to N-1 getAddExpr calls for N ultimate operands. // Instead, gather up all the operands and make a single getAddExpr call. // LLVM IR canonical form means we need only traverse the left operands. SmallVector AddOps; do { if (BO->Op) { if (auto *OpSCEV = getExistingSCEV(BO->Op)) { AddOps.push_back(OpSCEV); break; } // If a NUW or NSW flag can be applied to the SCEV for this // addition, then compute the SCEV for this addition by itself // with a separate call to getAddExpr. We need to do that // instead of pushing the operands of the addition onto AddOps, // since the flags are only known to apply to this particular // addition - they may not apply to other additions that can be // formed with operands from AddOps. const SCEV *RHS = getSCEV(BO->RHS); SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); if (Flags != SCEV::FlagAnyWrap) { const SCEV *LHS = getSCEV(BO->LHS); if (BO->Opcode == Instruction::Sub) AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); else AddOps.push_back(getAddExpr(LHS, RHS, Flags)); break; } } if (BO->Opcode == Instruction::Sub) AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS))); else AddOps.push_back(getSCEV(BO->RHS)); auto NewBO = MatchBinaryOp(BO->LHS, DT); if (!NewBO || (NewBO->Opcode != Instruction::Add && NewBO->Opcode != Instruction::Sub)) { AddOps.push_back(getSCEV(BO->LHS)); break; } BO = NewBO; } while (true); return getAddExpr(AddOps); } case Instruction::Mul: { SmallVector MulOps; do { if (BO->Op) { if (auto *OpSCEV = getExistingSCEV(BO->Op)) { MulOps.push_back(OpSCEV); break; } SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); if (Flags != SCEV::FlagAnyWrap) { MulOps.push_back( getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags)); break; } } MulOps.push_back(getSCEV(BO->RHS)); auto NewBO = MatchBinaryOp(BO->LHS, DT); if (!NewBO || NewBO->Opcode != Instruction::Mul) { MulOps.push_back(getSCEV(BO->LHS)); break; } BO = NewBO; } while (true); return getMulExpr(MulOps); } case Instruction::UDiv: return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); case Instruction::Sub: { SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; if (BO->Op) Flags = getNoWrapFlagsFromUB(BO->Op); return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags); } case Instruction::And: // For an expression like x&255 that merely masks off the high bits, // use zext(trunc(x)) as the SCEV expression. if (ConstantInt *CI = dyn_cast(BO->RHS)) { if (CI->isNullValue()) return getSCEV(BO->RHS); if (CI->isAllOnesValue()) return getSCEV(BO->LHS); const APInt &A = CI->getValue(); // Instcombine's ShrinkDemandedConstant may strip bits out of // constants, obscuring what would otherwise be a low-bits mask. // Use computeKnownBits to compute what ShrinkDemandedConstant // knew about to reconstruct a low-bits mask value. unsigned LZ = A.countLeadingZeros(); unsigned TZ = A.countTrailingZeros(); unsigned BitWidth = A.getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); computeKnownBits(BO->LHS, KnownZero, KnownOne, getDataLayout(), 0, &AC, nullptr, &DT); APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) { const SCEV *MulCount = getConstant(ConstantInt::get( getContext(), APInt::getOneBitSet(BitWidth, TZ))); return getMulExpr( getZeroExtendExpr( getTruncateExpr( getUDivExactExpr(getSCEV(BO->LHS), MulCount), IntegerType::get(getContext(), BitWidth - LZ - TZ)), BO->LHS->getType()), MulCount); } } break; case Instruction::Or: // If the RHS of the Or is a constant, we may have something like: // X*4+1 which got turned into X*4|1. Handle this as an Add so loop // optimizations will transparently handle this case. // // In order for this transformation to be safe, the LHS must be of the // form X*(2^n) and the Or constant must be less than 2^n. if (ConstantInt *CI = dyn_cast(BO->RHS)) { const SCEV *LHS = getSCEV(BO->LHS); const APInt &CIVal = CI->getValue(); if (GetMinTrailingZeros(LHS) >= (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { // Build a plain add SCEV. const SCEV *S = getAddExpr(LHS, getSCEV(CI)); // If the LHS of the add was an addrec and it has no-wrap flags, // transfer the no-wrap flags, since an or won't introduce a wrap. if (const SCEVAddRecExpr *NewAR = dyn_cast(S)) { const SCEVAddRecExpr *OldAR = cast(LHS); const_cast(NewAR)->setNoWrapFlags( OldAR->getNoWrapFlags()); } return S; } } break; case Instruction::Xor: if (ConstantInt *CI = dyn_cast(BO->RHS)) { // If the RHS of xor is -1, then this is a not operation. if (CI->isAllOnesValue()) return getNotSCEV(getSCEV(BO->LHS)); // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. // This is a variant of the check for xor with -1, and it handles // the case where instcombine has trimmed non-demanded bits out // of an xor with -1. if (auto *LBO = dyn_cast(BO->LHS)) if (ConstantInt *LCI = dyn_cast(LBO->getOperand(1))) if (LBO->getOpcode() == Instruction::And && LCI->getValue() == CI->getValue()) if (const SCEVZeroExtendExpr *Z = dyn_cast(getSCEV(BO->LHS))) { Type *UTy = BO->LHS->getType(); const SCEV *Z0 = Z->getOperand(); Type *Z0Ty = Z0->getType(); unsigned Z0TySize = getTypeSizeInBits(Z0Ty); // If C is a low-bits mask, the zero extend is serving to // mask off the high bits. Complement the operand and // re-apply the zext. if (APIntOps::isMask(Z0TySize, CI->getValue())) return getZeroExtendExpr(getNotSCEV(Z0), UTy); // If C is a single bit, it may be in the sign-bit position // before the zero-extend. In this case, represent the xor // using an add, which is equivalent, and re-apply the zext. APInt Trunc = CI->getValue().trunc(Z0TySize); if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && Trunc.isSignBit()) return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), UTy); } } break; case Instruction::Shl: // Turn shift left of a constant amount into a multiply. if (ConstantInt *SA = dyn_cast(BO->RHS)) { uint32_t BitWidth = cast(SA->getType())->getBitWidth(); // If the shift count is not less than the bitwidth, the result of // the shift is undefined. Don't try to analyze it, because the // resolution chosen here may differ from the resolution chosen in // other parts of the compiler. if (SA->getValue().uge(BitWidth)) break; // It is currently not resolved how to interpret NSW for left // shift by BitWidth - 1, so we avoid applying flags in that // case. Remove this check (or this comment) once the situation // is resolved. See // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html // and http://reviews.llvm.org/D8890 . auto Flags = SCEV::FlagAnyWrap; if (BO->Op && SA->getValue().ult(BitWidth - 1)) Flags = getNoWrapFlagsFromUB(BO->Op); Constant *X = ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue())); return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags); } break; case Instruction::AShr: // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. if (ConstantInt *CI = dyn_cast(BO->RHS)) if (Operator *L = dyn_cast(BO->LHS)) if (L->getOpcode() == Instruction::Shl && L->getOperand(1) == BO->RHS) { uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType()); // If the shift count is not less than the bitwidth, the result of // the shift is undefined. Don't try to analyze it, because the // resolution chosen here may differ from the resolution chosen in // other parts of the compiler. if (CI->getValue().uge(BitWidth)) break; uint64_t Amt = BitWidth - CI->getZExtValue(); if (Amt == BitWidth) return getSCEV(L->getOperand(0)); // shift by zero --> noop return getSignExtendExpr( getTruncateExpr(getSCEV(L->getOperand(0)), IntegerType::get(getContext(), Amt)), BO->LHS->getType()); } break; } } switch (U->getOpcode()) { case Instruction::Trunc: return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); case Instruction::ZExt: return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType()); case Instruction::SExt: return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType()); case Instruction::BitCast: // BitCasts are no-op casts so we just eliminate the cast. if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) return getSCEV(U->getOperand(0)); break; // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can // lead to pointer expressions which cannot safely be expanded to GEPs, // because ScalarEvolution doesn't respect the GEP aliasing rules when // simplifying integer expressions. case Instruction::GetElementPtr: return createNodeForGEP(cast(U)); case Instruction::PHI: return createNodeForPHI(cast(U)); case Instruction::Select: // U can also be a select constant expr, which let fall through. Since // createNodeForSelect only works for a condition that is an `ICmpInst`, and // constant expressions cannot have instructions as operands, we'd have // returned getUnknown for a select constant expressions anyway. if (isa(U)) return createNodeForSelectOrPHI(cast(U), U->getOperand(0), U->getOperand(1), U->getOperand(2)); break; case Instruction::Call: case Instruction::Invoke: if (Value *RV = CallSite(U).getReturnedArgOperand()) return getSCEV(RV); break; } return getUnknown(V); } //===----------------------------------------------------------------------===// // Iteration Count Computation Code // unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { if (BasicBlock *ExitingBB = L->getExitingBlock()) return getSmallConstantTripCount(L, ExitingBB); // No trip count information for multiple exits. return 0; } unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); assert(L->isLoopExiting(ExitingBlock) && "Exiting block must actually branch out of the loop!"); const SCEVConstant *ExitCount = dyn_cast(getExitCount(L, ExitingBlock)); if (!ExitCount) return 0; ConstantInt *ExitConst = ExitCount->getValue(); // Guard against huge trip counts. if (ExitConst->getValue().getActiveBits() > 32) return 0; // In case of integer overflow, this returns 0, which is correct. return ((unsigned)ExitConst->getZExtValue()) + 1; } unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { if (BasicBlock *ExitingBB = L->getExitingBlock()) return getSmallConstantTripMultiple(L, ExitingBB); // No trip multiple information for multiple exits. return 0; } /// Returns the largest constant divisor of the trip count of this loop as a /// normal unsigned value, if possible. This means that the actual trip count is /// always a multiple of the returned value (don't forget the trip count could /// very well be zero as well!). /// /// Returns 1 if the trip count is unknown or not guaranteed to be the /// multiple of a constant (which is also the case if the trip count is simply /// constant, use getSmallConstantTripCount for that case), Will also return 1 /// if the trip count is very large (>= 2^32). /// /// As explained in the comments for getSmallConstantTripCount, this assumes /// that control exits the loop via ExitingBlock. unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L, BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); assert(L->isLoopExiting(ExitingBlock) && "Exiting block must actually branch out of the loop!"); const SCEV *ExitCount = getExitCount(L, ExitingBlock); if (ExitCount == getCouldNotCompute()) return 1; // Get the trip count from the BE count by adding 1. const SCEV *TCMul = getAddExpr(ExitCount, getOne(ExitCount->getType())); // FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt // to factor simple cases. if (const SCEVMulExpr *Mul = dyn_cast(TCMul)) TCMul = Mul->getOperand(0); const SCEVConstant *MulC = dyn_cast(TCMul); if (!MulC) return 1; ConstantInt *Result = MulC->getValue(); // Guard against huge trip counts (this requires checking // for zero to handle the case where the trip count == -1 and the // addition wraps). if (!Result || Result->getValue().getActiveBits() > 32 || Result->getValue().getActiveBits() == 0) return 1; return (unsigned)Result->getZExtValue(); } /// Get the expression for the number of loop iterations for which this loop is /// guaranteed not to exit via ExitingBlock. Otherwise return /// SCEVCouldNotCompute. const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) { return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); } const SCEV * ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L, SCEVUnionPredicate &Preds) { return getPredicatedBackedgeTakenInfo(L).getExact(this, &Preds); } const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) { return getBackedgeTakenInfo(L).getExact(this); } /// Similar to getBackedgeTakenCount, except return the least SCEV value that is /// known never to be less than the actual backedge taken count. const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) { return getBackedgeTakenInfo(L).getMax(this); } /// Push PHI nodes in the header of the given loop onto the given Worklist. static void PushLoopPHIs(const Loop *L, SmallVectorImpl &Worklist) { BasicBlock *Header = L->getHeader(); // Push all Loop-header PHIs onto the Worklist stack. for (BasicBlock::iterator I = Header->begin(); PHINode *PN = dyn_cast(I); ++I) Worklist.push_back(PN); } const ScalarEvolution::BackedgeTakenInfo & ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) { auto &BTI = getBackedgeTakenInfo(L); if (BTI.hasFullInfo()) return BTI; auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()}); if (!Pair.second) return Pair.first->second; BackedgeTakenInfo Result = computeBackedgeTakenCount(L, /*AllowPredicates=*/true); return PredicatedBackedgeTakenCounts.find(L)->second = Result; } const ScalarEvolution::BackedgeTakenInfo & ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // Initially insert an invalid entry for this loop. If the insertion // succeeds, proceed to actually compute a backedge-taken count and // update the value. The temporary CouldNotCompute value tells SCEV // code elsewhere that it shouldn't attempt to request a new // backedge-taken count, which could result in infinite recursion. std::pair::iterator, bool> Pair = BackedgeTakenCounts.insert({L, BackedgeTakenInfo()}); if (!Pair.second) return Pair.first->second; // computeBackedgeTakenCount may allocate memory for its result. Inserting it // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result // must be cleared in this scope. BackedgeTakenInfo Result = computeBackedgeTakenCount(L); if (Result.getExact(this) != getCouldNotCompute()) { assert(isLoopInvariant(Result.getExact(this), L) && isLoopInvariant(Result.getMax(this), L) && "Computed backedge-taken count isn't loop invariant for loop!"); ++NumTripCountsComputed; } else if (Result.getMax(this) == getCouldNotCompute() && isa(L->getHeader()->begin())) { // Only count loops that have phi nodes as not being computable. ++NumTripCountsNotComputed; } // Now that we know more about the trip count for this loop, forget any // existing SCEV values for PHI nodes in this loop since they are only // conservative estimates made without the benefit of trip count // information. This is similar to the code in forgetLoop, except that // it handles SCEVUnknown PHI nodes specially. if (Result.hasAnyInfo()) { SmallVector Worklist; PushLoopPHIs(L, Worklist); SmallPtrSet Visited; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); if (!Visited.insert(I).second) continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); if (It != ValueExprMap.end()) { const SCEV *Old = It->second; // SCEVUnknown for a PHI either means that it has an unrecognized // structure, or it's a PHI that's in the progress of being computed // by createNodeForPHI. In the former case, additional loop trip // count information isn't going to change anything. In the later // case, createNodeForPHI will perform the necessary updates on its // own when it gets to that point. if (!isa(I) || !isa(Old)) { + eraseValueFromMap(It->first); forgetMemoizedResults(Old); - ValueExprMap.erase(It); } if (PHINode *PN = dyn_cast(I)) ConstantEvolutionLoopExitValue.erase(PN); } PushDefUseChildren(I, Worklist); } } // Re-lookup the insert position, since the call to // computeBackedgeTakenCount above could result in a // recusive call to getBackedgeTakenInfo (on a different // loop), which would invalidate the iterator computed // earlier. return BackedgeTakenCounts.find(L)->second = Result; } void ScalarEvolution::forgetLoop(const Loop *L) { // Drop any stored trip count value. auto RemoveLoopFromBackedgeMap = [L](DenseMap &Map) { auto BTCPos = Map.find(L); if (BTCPos != Map.end()) { BTCPos->second.clear(); Map.erase(BTCPos); } }; RemoveLoopFromBackedgeMap(BackedgeTakenCounts); RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts); // Drop information about expressions based on loop-header PHIs. SmallVector Worklist; PushLoopPHIs(L, Worklist); SmallPtrSet Visited; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); if (!Visited.insert(I).second) continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); if (It != ValueExprMap.end()) { + eraseValueFromMap(It->first); forgetMemoizedResults(It->second); - ValueExprMap.erase(It); if (PHINode *PN = dyn_cast(I)) ConstantEvolutionLoopExitValue.erase(PN); } PushDefUseChildren(I, Worklist); } // Forget all contained loops too, to avoid dangling entries in the // ValuesAtScopes map. for (Loop *I : *L) forgetLoop(I); LoopHasNoAbnormalExits.erase(L); } void ScalarEvolution::forgetValue(Value *V) { Instruction *I = dyn_cast(V); if (!I) return; // Drop information about expressions based on loop-header PHIs. SmallVector Worklist; Worklist.push_back(I); SmallPtrSet Visited; while (!Worklist.empty()) { I = Worklist.pop_back_val(); if (!Visited.insert(I).second) continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast(I)); if (It != ValueExprMap.end()) { + eraseValueFromMap(It->first); forgetMemoizedResults(It->second); - ValueExprMap.erase(It); if (PHINode *PN = dyn_cast(I)) ConstantEvolutionLoopExitValue.erase(PN); } PushDefUseChildren(I, Worklist); } } /// Get the exact loop backedge taken count considering all loop exits. A /// computable result can only be returned for loops with a single exit. /// Returning the minimum taken count among all exits is incorrect because one /// of the loop's exit limit's may have been skipped. howFarToZero assumes that /// the limit of each loop test is never skipped. This is a valid assumption as /// long as the loop exits via that test. For precise results, it is the /// caller's responsibility to specify the relevant loop exit using /// getExact(ExitingBlock, SE). const SCEV * ScalarEvolution::BackedgeTakenInfo::getExact( ScalarEvolution *SE, SCEVUnionPredicate *Preds) const { // If any exits were not computable, the loop is not computable. if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute(); // We need exactly one computable exit. if (!ExitNotTaken.ExitingBlock) return SE->getCouldNotCompute(); assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info"); const SCEV *BECount = nullptr; for (auto &ENT : ExitNotTaken) { assert(ENT.ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV"); if (!BECount) BECount = ENT.ExactNotTaken; else if (BECount != ENT.ExactNotTaken) return SE->getCouldNotCompute(); if (Preds && ENT.getPred()) Preds->add(ENT.getPred()); assert((Preds || ENT.hasAlwaysTruePred()) && "Predicate should be always true!"); } assert(BECount && "Invalid not taken count for loop exit"); return BECount; } /// Get the exact not taken count for this loop exit. const SCEV * ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock, ScalarEvolution *SE) const { for (auto &ENT : ExitNotTaken) if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePred()) return ENT.ExactNotTaken; return SE->getCouldNotCompute(); } /// getMax - Get the max backedge taken count for the loop. const SCEV * ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const { for (auto &ENT : ExitNotTaken) if (!ENT.hasAlwaysTruePred()) return SE->getCouldNotCompute(); return Max ? Max : SE->getCouldNotCompute(); } bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, ScalarEvolution *SE) const { if (Max && Max != SE->getCouldNotCompute() && SE->hasOperand(Max, S)) return true; if (!ExitNotTaken.ExitingBlock) return false; for (auto &ENT : ExitNotTaken) if (ENT.ExactNotTaken != SE->getCouldNotCompute() && SE->hasOperand(ENT.ExactNotTaken, S)) return true; return false; } /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each /// computable exit into a persistent ExitNotTakenInfo array. ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( SmallVectorImpl &ExitCounts, bool Complete, const SCEV *MaxCount) : Max(MaxCount) { if (!Complete) ExitNotTaken.setIncomplete(); unsigned NumExits = ExitCounts.size(); if (NumExits == 0) return; ExitNotTaken.ExitingBlock = ExitCounts[0].ExitBlock; ExitNotTaken.ExactNotTaken = ExitCounts[0].Taken; // Determine the number of ExitNotTakenExtras structures that we need. unsigned ExtraInfoSize = 0; if (NumExits > 1) ExtraInfoSize = 1 + std::count_if(std::next(ExitCounts.begin()), ExitCounts.end(), [](EdgeInfo &Entry) { return !Entry.Pred.isAlwaysTrue(); }); else if (!ExitCounts[0].Pred.isAlwaysTrue()) ExtraInfoSize = 1; ExitNotTakenExtras *ENT = nullptr; // Allocate the ExitNotTakenExtras structures and initialize the first // element (ExitNotTaken). if (ExtraInfoSize > 0) { ENT = new ExitNotTakenExtras[ExtraInfoSize]; ExitNotTaken.ExtraInfo = &ENT[0]; *ExitNotTaken.getPred() = std::move(ExitCounts[0].Pred); } if (NumExits == 1) return; assert(ENT && "ExitNotTakenExtras is NULL while having more than one exit"); auto &Exits = ExitNotTaken.ExtraInfo->Exits; // Handle the rare case of multiple computable exits. for (unsigned i = 1, PredPos = 1; i < NumExits; ++i) { ExitNotTakenExtras *Ptr = nullptr; if (!ExitCounts[i].Pred.isAlwaysTrue()) { Ptr = &ENT[PredPos++]; Ptr->Pred = std::move(ExitCounts[i].Pred); } Exits.emplace_back(ExitCounts[i].ExitBlock, ExitCounts[i].Taken, Ptr); } } /// Invalidate this result and free the ExitNotTakenInfo array. void ScalarEvolution::BackedgeTakenInfo::clear() { ExitNotTaken.ExitingBlock = nullptr; ExitNotTaken.ExactNotTaken = nullptr; delete[] ExitNotTaken.ExtraInfo; } /// Compute the number of times the backedge of the specified loop will execute. ScalarEvolution::BackedgeTakenInfo ScalarEvolution::computeBackedgeTakenCount(const Loop *L, bool AllowPredicates) { SmallVector ExitingBlocks; L->getExitingBlocks(ExitingBlocks); SmallVector ExitCounts; bool CouldComputeBECount = true; BasicBlock *Latch = L->getLoopLatch(); // may be NULL. const SCEV *MustExitMaxBECount = nullptr; const SCEV *MayExitMaxBECount = nullptr; // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts // and compute maxBECount. // Do a union of all the predicates here. for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { BasicBlock *ExitBB = ExitingBlocks[i]; ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates); assert((AllowPredicates || EL.Pred.isAlwaysTrue()) && "Predicated exit limit when predicates are not allowed!"); // 1. For each exit that can be computed, add an entry to ExitCounts. // CouldComputeBECount is true only if all exits can be computed. if (EL.Exact == getCouldNotCompute()) // We couldn't compute an exact value for this exit, so // we won't be able to compute an exact value for the loop. CouldComputeBECount = false; else ExitCounts.emplace_back(EdgeInfo(ExitBB, EL.Exact, EL.Pred)); // 2. Derive the loop's MaxBECount from each exit's max number of // non-exiting iterations. Partition the loop exits into two kinds: // LoopMustExits and LoopMayExits. // // If the exit dominates the loop latch, it is a LoopMustExit otherwise it // is a LoopMayExit. If any computable LoopMustExit is found, then // MaxBECount is the minimum EL.Max of computable LoopMustExits. Otherwise, // MaxBECount is conservatively the maximum EL.Max, where CouldNotCompute is // considered greater than any computable EL.Max. if (EL.Max != getCouldNotCompute() && Latch && DT.dominates(ExitBB, Latch)) { if (!MustExitMaxBECount) MustExitMaxBECount = EL.Max; else { MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount, EL.Max); } } else if (MayExitMaxBECount != getCouldNotCompute()) { if (!MayExitMaxBECount || EL.Max == getCouldNotCompute()) MayExitMaxBECount = EL.Max; else { MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.Max); } } } const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount : (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount); } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, bool AllowPredicates) { // Okay, we've chosen an exiting block. See what condition causes us to exit // at this block and remember the exit block and whether all other targets // lead to the loop header. bool MustExecuteLoopHeader = true; BasicBlock *Exit = nullptr; for (auto *SBB : successors(ExitingBlock)) if (!L->contains(SBB)) { if (Exit) // Multiple exit successors. return getCouldNotCompute(); Exit = SBB; } else if (SBB != L->getHeader()) { MustExecuteLoopHeader = false; } // At this point, we know we have a conditional branch that determines whether // the loop is exited. However, we don't know if the branch is executed each // time through the loop. If not, then the execution count of the branch will // not be equal to the trip count of the loop. // // Currently we check for this by checking to see if the Exit branch goes to // the loop header. If so, we know it will always execute the same number of // times as the loop. We also handle the case where the exit block *is* the // loop header. This is common for un-rotated loops. // // If both of those tests fail, walk up the unique predecessor chain to the // header, stopping if there is an edge that doesn't exit the loop. If the // header is reached, the execution count of the branch will be equal to the // trip count of the loop. // // More extensive analysis could be done to handle more cases here. // if (!MustExecuteLoopHeader && ExitingBlock != L->getHeader()) { // The simple checks failed, try climbing the unique predecessor chain // up to the header. bool Ok = false; for (BasicBlock *BB = ExitingBlock; BB; ) { BasicBlock *Pred = BB->getUniquePredecessor(); if (!Pred) return getCouldNotCompute(); TerminatorInst *PredTerm = Pred->getTerminator(); for (const BasicBlock *PredSucc : PredTerm->successors()) { if (PredSucc == BB) continue; // If the predecessor has a successor that isn't BB and isn't // outside the loop, assume the worst. if (L->contains(PredSucc)) return getCouldNotCompute(); } if (Pred == L->getHeader()) { Ok = true; break; } BB = Pred; } if (!Ok) return getCouldNotCompute(); } bool IsOnlyExit = (L->getExitingBlock() != nullptr); TerminatorInst *Term = ExitingBlock->getTerminator(); if (BranchInst *BI = dyn_cast(Term)) { assert(BI->isConditional() && "If unconditional, it can't be in loop!"); // Proceed to the next level to examine the exit condition expression. return computeExitLimitFromCond( L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1), /*ControlsExit=*/IsOnlyExit, AllowPredicates); } if (SwitchInst *SI = dyn_cast(Term)) return computeExitLimitFromSingleExitSwitch(L, SI, Exit, /*ControlsExit=*/IsOnlyExit); return getCouldNotCompute(); } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB, bool ControlsExit, bool AllowPredicates) { // Check if the controlling expression for this loop is an And or Or. if (BinaryOperator *BO = dyn_cast(ExitCond)) { if (BO->getOpcode() == Instruction::And) { // Recurse on the operands of the and. bool EitherMayExit = L->contains(TBB); ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, ControlsExit && !EitherMayExit, AllowPredicates); ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, ControlsExit && !EitherMayExit, AllowPredicates); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (EitherMayExit) { // Both conditions must be true for the loop to continue executing. // Choose the less conservative count. if (EL0.Exact == getCouldNotCompute() || EL1.Exact == getCouldNotCompute()) BECount = getCouldNotCompute(); else BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact); if (EL0.Max == getCouldNotCompute()) MaxBECount = EL1.Max; else if (EL1.Max == getCouldNotCompute()) MaxBECount = EL0.Max; else MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); } else { // Both conditions must be true at the same time for the loop to exit. // For now, be conservative. assert(L->contains(FBB) && "Loop block has no successor in loop!"); if (EL0.Max == EL1.Max) MaxBECount = EL0.Max; if (EL0.Exact == EL1.Exact) BECount = EL0.Exact; } SCEVUnionPredicate NP; NP.add(&EL0.Pred); NP.add(&EL1.Pred); // There are cases (e.g. PR26207) where computeExitLimitFromCond is able // to be more aggressive when computing BECount than when computing // MaxBECount. In these cases it is possible for EL0.Exact and EL1.Exact // to match, but for EL0.Max and EL1.Max to not. if (isa(MaxBECount) && !isa(BECount)) MaxBECount = BECount; return ExitLimit(BECount, MaxBECount, NP); } if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. bool EitherMayExit = L->contains(FBB); ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, ControlsExit && !EitherMayExit, AllowPredicates); ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, ControlsExit && !EitherMayExit, AllowPredicates); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (EitherMayExit) { // Both conditions must be false for the loop to continue executing. // Choose the less conservative count. if (EL0.Exact == getCouldNotCompute() || EL1.Exact == getCouldNotCompute()) BECount = getCouldNotCompute(); else BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact); if (EL0.Max == getCouldNotCompute()) MaxBECount = EL1.Max; else if (EL1.Max == getCouldNotCompute()) MaxBECount = EL0.Max; else MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); } else { // Both conditions must be false at the same time for the loop to exit. // For now, be conservative. assert(L->contains(TBB) && "Loop block has no successor in loop!"); if (EL0.Max == EL1.Max) MaxBECount = EL0.Max; if (EL0.Exact == EL1.Exact) BECount = EL0.Exact; } SCEVUnionPredicate NP; NP.add(&EL0.Pred); NP.add(&EL1.Pred); return ExitLimit(BECount, MaxBECount, NP); } } // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. if (ICmpInst *ExitCondICmp = dyn_cast(ExitCond)) { ExitLimit EL = computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); if (EL.hasFullInfo() || !AllowPredicates) return EL; // Try again, but use SCEV predicates this time. return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit, /*AllowPredicates=*/true); } // Check for a constant condition. These are normally stripped out by // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to // preserve the CFG and is temporarily leaving constant conditions // in place. if (ConstantInt *CI = dyn_cast(ExitCond)) { if (L->contains(FBB) == !CI->getZExtValue()) // The backedge is always taken. return getCouldNotCompute(); else // The backedge is never taken. return getZero(CI->getType()); } // If it's not an integer or pointer comparison then compute it the hard way. return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, BasicBlock *TBB, BasicBlock *FBB, bool ControlsExit, bool AllowPredicates) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Cond; if (!L->contains(FBB)) Cond = ExitCond->getPredicate(); else Cond = ExitCond->getInversePredicate(); // Handle common loops like: for (X = "string"; *X; ++X) if (LoadInst *LI = dyn_cast(ExitCond->getOperand(0))) if (Constant *RHS = dyn_cast(ExitCond->getOperand(1))) { ExitLimit ItCnt = computeLoadConstantCompareExitLimit(LI, RHS, L, Cond); if (ItCnt.hasAnyInfo()) return ItCnt; } const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); // Try to evaluate any dependencies out of the loop. LHS = getSCEVAtScope(LHS, L); RHS = getSCEVAtScope(RHS, L); // At this point, we would like to compute how many iterations of the // loop the predicate will return true for these inputs. if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) { // If there is a loop-invariant, force it into the RHS. std::swap(LHS, RHS); Cond = ICmpInst::getSwappedPredicate(Cond); } // Simplify the operands before analyzing them. (void)SimplifyICmpOperands(Cond, LHS, RHS); // If we have a comparison of a chrec against a constant, try to use value // ranges to answer this query. if (const SCEVConstant *RHSC = dyn_cast(RHS)) if (const SCEVAddRecExpr *AddRec = dyn_cast(LHS)) if (AddRec->getLoop() == L) { // Form the constant range. ConstantRange CompRange( ICmpInst::makeConstantRange(Cond, RHSC->getAPInt())); const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); if (!isa(Ret)) return Ret; } switch (Cond) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit, AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_EQ: { // while (X == Y) // Convert to: while (X-Y == 0) ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = Cond == ICmpInst::ICMP_SLT; ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: { // while (X > Y) bool IsSigned = Cond == ICmpInst::ICMP_SGT; ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit, AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } default: break; } auto *ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); if (!isa(ExhaustiveCount)) return ExhaustiveCount; return computeShiftCompareExitLimit(ExitCond->getOperand(0), ExitCond->getOperand(1), L, Cond); } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock, bool ControlsExit) { assert(!L->contains(ExitingBlock) && "Not an exiting block!"); // Give up if the exit is the default dest of a switch. if (Switch->getDefaultDest() == ExitingBlock) return getCouldNotCompute(); assert(L->contains(Switch->getDefaultDest()) && "Default case must not exit the loop!"); const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L); const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); // while (X != Y) --> while (X-Y != 0) ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); if (EL.hasAnyInfo()) return EL; return getCouldNotCompute(); } static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE) { const SCEV *InVal = SE.getConstant(C); const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE); assert(isa(Val) && "Evaluation of SCEV at constant didn't fold correctly?"); return cast(Val)->getValue(); } /// Given an exit condition of 'icmp op load X, cst', try to see if we can /// compute the backedge execution count. ScalarEvolution::ExitLimit ScalarEvolution::computeLoadConstantCompareExitLimit( LoadInst *LI, Constant *RHS, const Loop *L, ICmpInst::Predicate predicate) { if (LI->isVolatile()) return getCouldNotCompute(); // Check to see if the loaded pointer is a getelementptr of a global. // TODO: Use SCEV instead of manually grubbing with GEPs. GetElementPtrInst *GEP = dyn_cast(LI->getOperand(0)); if (!GEP) return getCouldNotCompute(); // Make sure that it is really a constant global we are gepping, with an // initializer, and make sure the first IDX is really 0. GlobalVariable *GV = dyn_cast(GEP->getOperand(0)); if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() || GEP->getNumOperands() < 3 || !isa(GEP->getOperand(1)) || !cast(GEP->getOperand(1))->isNullValue()) return getCouldNotCompute(); // Okay, we allow one non-constant index into the GEP instruction. Value *VarIdx = nullptr; std::vector Indexes; unsigned VarIdxNum = 0; for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i) if (ConstantInt *CI = dyn_cast(GEP->getOperand(i))) { Indexes.push_back(CI); } else if (!isa(GEP->getOperand(i))) { if (VarIdx) return getCouldNotCompute(); // Multiple non-constant idx's. VarIdx = GEP->getOperand(i); VarIdxNum = i-2; Indexes.push_back(nullptr); } // Loop-invariant loads may be a byproduct of loop optimization. Skip them. if (!VarIdx) return getCouldNotCompute(); // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant. // Check to see if X is a loop variant variable value now. const SCEV *Idx = getSCEV(VarIdx); Idx = getSCEVAtScope(Idx, L); // We can only recognize very limited forms of loop index expressions, in // particular, only affine AddRec's like {C1,+,C2}. const SCEVAddRecExpr *IdxExpr = dyn_cast(Idx); if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) || !isa(IdxExpr->getOperand(0)) || !isa(IdxExpr->getOperand(1))) return getCouldNotCompute(); unsigned MaxSteps = MaxBruteForceIterations; for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) { ConstantInt *ItCst = ConstantInt::get( cast(IdxExpr->getType()), IterationNum); ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this); // Form the GEP offset. Indexes[VarIdxNum] = Val; Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(), Indexes); if (!Result) break; // Cannot compute! // Evaluate the condition for this iteration. Result = ConstantExpr::getICmp(predicate, Result, RHS); if (!isa(Result)) break; // Couldn't decide for sure if (cast(Result)->getValue().isMinValue()) { ++NumArrayLenItCounts; return getConstant(ItCst); // Found terminating iteration! } } return getCouldNotCompute(); } ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) { ConstantInt *RHS = dyn_cast(RHSV); if (!RHS) return getCouldNotCompute(); const BasicBlock *Latch = L->getLoopLatch(); if (!Latch) return getCouldNotCompute(); const BasicBlock *Predecessor = L->getLoopPredecessor(); if (!Predecessor) return getCouldNotCompute(); // Return true if V is of the form "LHS `shift_op` ". // Return LHS in OutLHS and shift_opt in OutOpCode. auto MatchPositiveShift = [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) { using namespace PatternMatch; ConstantInt *ShiftAmt; if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) OutOpCode = Instruction::LShr; else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) OutOpCode = Instruction::AShr; else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) OutOpCode = Instruction::Shl; else return false; return ShiftAmt->getValue().isStrictlyPositive(); }; // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in // // loop: // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ] // %iv.shifted = lshr i32 %iv, // // Return true on a succesful match. Return the corresponding PHI node (%iv // above) in PNOut and the opcode of the shift operation in OpCodeOut. auto MatchShiftRecurrence = [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) { Optional PostShiftOpCode; { Instruction::BinaryOps OpC; Value *V; // If we encounter a shift instruction, "peel off" the shift operation, // and remember that we did so. Later when we inspect %iv's backedge // value, we will make sure that the backedge value uses the same // operation. // // Note: the peeled shift operation does not have to be the same // instruction as the one feeding into the PHI's backedge value. We only // really care about it being the same *kind* of shift instruction -- // that's all that is required for our later inferences to hold. if (MatchPositiveShift(LHS, V, OpC)) { PostShiftOpCode = OpC; LHS = V; } } PNOut = dyn_cast(LHS); if (!PNOut || PNOut->getParent() != L->getHeader()) return false; Value *BEValue = PNOut->getIncomingValueForBlock(Latch); Value *OpLHS; return // The backedge value for the PHI node must be a shift by a positive // amount MatchPositiveShift(BEValue, OpLHS, OpCodeOut) && // of the PHI node itself OpLHS == PNOut && // and the kind of shift should be match the kind of shift we peeled // off, if any. (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut); }; PHINode *PN; Instruction::BinaryOps OpCode; if (!MatchShiftRecurrence(LHS, PN, OpCode)) return getCouldNotCompute(); const DataLayout &DL = getDataLayout(); // The key rationale for this optimization is that for some kinds of shift // recurrences, the value of the recurrence "stabilizes" to either 0 or -1 // within a finite number of iterations. If the condition guarding the // backedge (in the sense that the backedge is taken if the condition is true) // is false for the value the shift recurrence stabilizes to, then we know // that the backedge is taken only a finite number of times. ConstantInt *StableValue = nullptr; switch (OpCode) { default: llvm_unreachable("Impossible case!"); case Instruction::AShr: { // {K,ashr,} stabilizes to signum(K) in at most // bitwidth(K) iterations. Value *FirstValue = PN->getIncomingValueForBlock(Predecessor); bool KnownZero, KnownOne; ComputeSignBit(FirstValue, KnownZero, KnownOne, DL, 0, nullptr, Predecessor->getTerminator(), &DT); auto *Ty = cast(RHS->getType()); if (KnownZero) StableValue = ConstantInt::get(Ty, 0); else if (KnownOne) StableValue = ConstantInt::get(Ty, -1, true); else return getCouldNotCompute(); break; } case Instruction::LShr: case Instruction::Shl: // Both {K,lshr,} and {K,shl,} // stabilize to 0 in at most bitwidth(K) iterations. StableValue = ConstantInt::get(cast(RHS->getType()), 0); break; } auto *Result = ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI); assert(Result->getType()->isIntegerTy(1) && "Otherwise cannot be an operand to a branch instruction"); if (Result->isZeroValue()) { unsigned BitWidth = getTypeSizeInBits(RHS->getType()); const SCEV *UpperBound = getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); SCEVUnionPredicate P; return ExitLimit(getCouldNotCompute(), UpperBound, P); } return getCouldNotCompute(); } /// Return true if we can constant fold an instruction of the specified type, /// assuming that all operands were constants. static bool CanConstantFold(const Instruction *I) { if (isa(I) || isa(I) || isa(I) || isa(I) || isa(I) || isa(I)) return true; if (const CallInst *CI = dyn_cast(I)) if (const Function *F = CI->getCalledFunction()) return canConstantFoldCallTo(F); return false; } /// Determine whether this instruction can constant evolve within this loop /// assuming its operands can all constant evolve. static bool canConstantEvolve(Instruction *I, const Loop *L) { // An instruction outside of the loop can't be derived from a loop PHI. if (!L->contains(I)) return false; if (isa(I)) { // We don't currently keep track of the control flow needed to evaluate // PHIs, so we cannot handle PHIs inside of loops. return L->getHeader() == I->getParent(); } // If we won't be able to constant fold this expression even if the operands // are constants, bail early. return CanConstantFold(I); } /// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by /// recursing through each instruction operand until reaching a loop header phi. static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap &PHIMap) { // Otherwise, we can evaluate this instruction if all of its operands are // constant or derived from a PHI node themselves. PHINode *PHI = nullptr; for (Value *Op : UseInst->operands()) { if (isa(Op)) continue; Instruction *OpInst = dyn_cast(Op); if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr; PHINode *P = dyn_cast(OpInst); if (!P) // If this operand is already visited, reuse the prior result. // We may have P != PHI if this is the deepest point at which the // inconsistent paths meet. P = PHIMap.lookup(OpInst); if (!P) { // Recurse and memoize the results, whether a phi is found or not. // This recursive call invalidates pointers into PHIMap. P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap); PHIMap[OpInst] = P; } if (!P) return nullptr; // Not evolving from PHI if (PHI && PHI != P) return nullptr; // Evolving from multiple different PHIs. PHI = P; } // This is a expression evolving from a constant PHI! return PHI; } /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node /// in the loop that V is derived from. We allow arbitrary operations along the /// way, but the operands of an operation must either be constants or a value /// derived from a constant PHI. If this expression does not fit with these /// constraints, return null. static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { Instruction *I = dyn_cast(V); if (!I || !canConstantEvolve(I, L)) return nullptr; if (PHINode *PN = dyn_cast(I)) return PN; // Record non-constant instructions contained by the loop. DenseMap PHIMap; return getConstantEvolvingPHIOperands(I, L, PHIMap); } /// EvaluateExpression - Given an expression that passes the /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node /// in the loop has the value PHIVal. If we can't fold this expression for some /// reason, return null. static Constant *EvaluateExpression(Value *V, const Loop *L, DenseMap &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI) { // Convenient constant check, but redundant for recursive calls. if (Constant *C = dyn_cast(V)) return C; Instruction *I = dyn_cast(V); if (!I) return nullptr; if (Constant *C = Vals.lookup(I)) return C; // An instruction inside the loop depends on a value outside the loop that we // weren't given a mapping for, or a value such as a call inside the loop. if (!canConstantEvolve(I, L)) return nullptr; // An unmapped PHI can be due to a branch or another loop inside this loop, // or due to this not being the initial iteration through a loop where we // couldn't compute the evolution of this particular PHI last time. if (isa(I)) return nullptr; std::vector Operands(I->getNumOperands()); for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { Instruction *Operand = dyn_cast(I->getOperand(i)); if (!Operand) { Operands[i] = dyn_cast(I->getOperand(i)); if (!Operands[i]) return nullptr; continue; } Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI); Vals[Operand] = C; if (!C) return nullptr; Operands[i] = C; } if (CmpInst *CI = dyn_cast(I)) return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], Operands[1], DL, TLI); if (LoadInst *LI = dyn_cast(I)) { if (!LI->isVolatile()) return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); } return ConstantFoldInstOperands(I, Operands, DL, TLI); } // If every incoming value to PN except the one for BB is a specific Constant, // return that, else return nullptr. static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) { Constant *IncomingVal = nullptr; for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { if (PN->getIncomingBlock(i) == BB) continue; auto *CurrentVal = dyn_cast(PN->getIncomingValue(i)); if (!CurrentVal) return nullptr; if (IncomingVal != CurrentVal) { if (IncomingVal) return nullptr; IncomingVal = CurrentVal; } } return IncomingVal; } /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is /// in the header of its containing loop, we know the loop executes a /// constant number of times, and the PHI node is just a recurrence /// involving constants, fold it. Constant * ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, const APInt &BEs, const Loop *L) { auto I = ConstantEvolutionLoopExitValue.find(PN); if (I != ConstantEvolutionLoopExitValue.end()) return I->second; if (BEs.ugt(MaxBruteForceIterations)) return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it. Constant *&RetVal = ConstantEvolutionLoopExitValue[PN]; DenseMap CurrentIterVals; BasicBlock *Header = L->getHeader(); assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!"); BasicBlock *Latch = L->getLoopLatch(); if (!Latch) return nullptr; for (auto &I : *Header) { PHINode *PHI = dyn_cast(&I); if (!PHI) break; auto *StartCST = getOtherIncomingValue(PHI, Latch); if (!StartCST) continue; CurrentIterVals[PHI] = StartCST; } if (!CurrentIterVals.count(PN)) return RetVal = nullptr; Value *BEValue = PN->getIncomingValueForBlock(Latch); // Execute the loop symbolically to determine the exit value. if (BEs.getActiveBits() >= 32) return RetVal = nullptr; // More than 2^32-1 iterations?? Not doing it! unsigned NumIterations = BEs.getZExtValue(); // must be in range unsigned IterationNum = 0; const DataLayout &DL = getDataLayout(); for (; ; ++IterationNum) { if (IterationNum == NumIterations) return RetVal = CurrentIterVals[PN]; // Got exit value! // Compute the value of the PHIs for the next iteration. // EvaluateExpression adds non-phi values to the CurrentIterVals map. DenseMap NextIterVals; Constant *NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); if (!NextPHI) return nullptr; // Couldn't evaluate! NextIterVals[PN] = NextPHI; bool StoppedEvolving = NextPHI == CurrentIterVals[PN]; // Also evaluate the other PHI nodes. However, we don't get to stop if we // cease to be able to evaluate one of them or if they stop evolving, // because that doesn't necessarily prevent us from computing PN. SmallVector, 8> PHIsToCompute; for (const auto &I : CurrentIterVals) { PHINode *PHI = dyn_cast(I.first); if (!PHI || PHI == PN || PHI->getParent() != Header) continue; PHIsToCompute.emplace_back(PHI, I.second); } // We use two distinct loops because EvaluateExpression may invalidate any // iterators into CurrentIterVals. for (const auto &I : PHIsToCompute) { PHINode *PHI = I.first; Constant *&NextPHI = NextIterVals[PHI]; if (!NextPHI) { // Not already computed. Value *BEValue = PHI->getIncomingValueForBlock(Latch); NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); } if (NextPHI != I.second) StoppedEvolving = false; } // If all entries in CurrentIterVals == NextIterVals then we can stop // iterating, the loop can't continue to change. if (StoppedEvolving) return RetVal = CurrentIterVals[PN]; CurrentIterVals.swap(NextIterVals); } } const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) { PHINode *PN = getConstantEvolvingPHI(Cond, L); if (!PN) return getCouldNotCompute(); // If the loop is canonicalized, the PHI will have exactly two entries. // That's the only form we support here. if (PN->getNumIncomingValues() != 2) return getCouldNotCompute(); DenseMap CurrentIterVals; BasicBlock *Header = L->getHeader(); assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!"); BasicBlock *Latch = L->getLoopLatch(); assert(Latch && "Should follow from NumIncomingValues == 2!"); for (auto &I : *Header) { PHINode *PHI = dyn_cast(&I); if (!PHI) break; auto *StartCST = getOtherIncomingValue(PHI, Latch); if (!StartCST) continue; CurrentIterVals[PHI] = StartCST; } if (!CurrentIterVals.count(PN)) return getCouldNotCompute(); // Okay, we find a PHI node that defines the trip count of this loop. Execute // the loop symbolically to determine when the condition gets a value of // "ExitWhen". unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis. const DataLayout &DL = getDataLayout(); for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){ auto *CondVal = dyn_cast_or_null( EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI)); // Couldn't symbolically evaluate. if (!CondVal) return getCouldNotCompute(); if (CondVal->getValue() == uint64_t(ExitWhen)) { ++NumBruteForceTripCountsComputed; return getConstant(Type::getInt32Ty(getContext()), IterationNum); } // Update all the PHI nodes for the next iteration. DenseMap NextIterVals; // Create a list of which PHIs we need to compute. We want to do this before // calling EvaluateExpression on them because that may invalidate iterators // into CurrentIterVals. SmallVector PHIsToCompute; for (const auto &I : CurrentIterVals) { PHINode *PHI = dyn_cast(I.first); if (!PHI || PHI->getParent() != Header) continue; PHIsToCompute.push_back(PHI); } for (PHINode *PHI : PHIsToCompute) { Constant *&NextPHI = NextIterVals[PHI]; if (NextPHI) continue; // Already computed! Value *BEValue = PHI->getIncomingValueForBlock(Latch); NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); } CurrentIterVals.swap(NextIterVals); } // Too many iterations were needed to evaluate. return getCouldNotCompute(); } const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { SmallVector, 2> &Values = ValuesAtScopes[V]; // Check to see if we've folded this expression at this loop before. for (auto &LS : Values) if (LS.first == L) return LS.second ? LS.second : V; Values.emplace_back(L, nullptr); // Otherwise compute it. const SCEV *C = computeSCEVAtScope(V, L); for (auto &LS : reverse(ValuesAtScopes[V])) if (LS.first == L) { LS.second = C; break; } return C; } /// This builds up a Constant using the ConstantExpr interface. That way, we /// will return Constants for objects which aren't represented by a /// SCEVConstant, because SCEVConstant is restricted to ConstantInt. /// Returns NULL if the SCEV isn't representable as a Constant. static Constant *BuildConstantFromSCEV(const SCEV *V) { switch (static_cast(V->getSCEVType())) { case scCouldNotCompute: case scAddRecExpr: break; case scConstant: return cast(V)->getValue(); case scUnknown: return dyn_cast(cast(V)->getValue()); case scSignExtend: { const SCEVSignExtendExpr *SS = cast(V); if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand())) return ConstantExpr::getSExt(CastOp, SS->getType()); break; } case scZeroExtend: { const SCEVZeroExtendExpr *SZ = cast(V); if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand())) return ConstantExpr::getZExt(CastOp, SZ->getType()); break; } case scTruncate: { const SCEVTruncateExpr *ST = cast(V); if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand())) return ConstantExpr::getTrunc(CastOp, ST->getType()); break; } case scAddExpr: { const SCEVAddExpr *SA = cast(V); if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) { if (PointerType *PTy = dyn_cast(C->getType())) { unsigned AS = PTy->getAddressSpace(); Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); C = ConstantExpr::getBitCast(C, DestPtrTy); } for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) { Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i)); if (!C2) return nullptr; // First pointer! if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) { unsigned AS = C2->getType()->getPointerAddressSpace(); std::swap(C, C2); Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); // The offsets have been converted to bytes. We can add bytes to an // i8* by GEP with the byte count in the first index. C = ConstantExpr::getBitCast(C, DestPtrTy); } // Don't bother trying to sum two pointers. We probably can't // statically compute a load that results from it anyway. if (C2->getType()->isPointerTy()) return nullptr; if (PointerType *PTy = dyn_cast(C->getType())) { if (PTy->getElementType()->isStructTy()) C2 = ConstantExpr::getIntegerCast( C2, Type::getInt32Ty(C->getContext()), true); C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2); } else C = ConstantExpr::getAdd(C, C2); } return C; } break; } case scMulExpr: { const SCEVMulExpr *SM = cast(V); if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) { // Don't bother with pointers at all. if (C->getType()->isPointerTy()) return nullptr; for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) { Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i)); if (!C2 || C2->getType()->isPointerTy()) return nullptr; C = ConstantExpr::getMul(C, C2); } return C; } break; } case scUDivExpr: { const SCEVUDivExpr *SU = cast(V); if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS())) if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS())) if (LHS->getType() == RHS->getType()) return ConstantExpr::getUDiv(LHS, RHS); break; } case scSMaxExpr: case scUMaxExpr: break; // TODO: smax, umax. } return nullptr; } const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { if (isa(V)) return V; // If this instruction is evolved from a constant-evolving PHI, compute the // exit value from the loop without using SCEVs. if (const SCEVUnknown *SU = dyn_cast(V)) { if (Instruction *I = dyn_cast(SU->getValue())) { const Loop *LI = this->LI[I->getParent()]; if (LI && LI->getParentLoop() == L) // Looking for loop exit value. if (PHINode *PN = dyn_cast(I)) if (PN->getParent() == LI->getHeader()) { // Okay, there is no closed form solution for the PHI node. Check // to see if the loop that contains it has a known backedge-taken // count. If so, we may be able to force computation of the exit // value. const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI); if (const SCEVConstant *BTCC = dyn_cast(BackedgeTakenCount)) { // Okay, we know how many times the containing loop executes. If // this is a constant evolving PHI node, get the final value at // the specified iteration number. Constant *RV = getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), LI); if (RV) return getSCEV(RV); } } // Okay, this is an expression that we cannot symbolically evaluate // into a SCEV. Check to see if it's possible to symbolically evaluate // the arguments into constants, and if so, try to constant propagate the // result. This is particularly useful for computing loop exit values. if (CanConstantFold(I)) { SmallVector Operands; bool MadeImprovement = false; for (Value *Op : I->operands()) { if (Constant *C = dyn_cast(Op)) { Operands.push_back(C); continue; } // If any of the operands is non-constant and if they are // non-integer and non-pointer, don't even try to analyze them // with scev techniques. if (!isSCEVable(Op->getType())) return V; const SCEV *OrigV = getSCEV(Op); const SCEV *OpV = getSCEVAtScope(OrigV, L); MadeImprovement |= OrigV != OpV; Constant *C = BuildConstantFromSCEV(OpV); if (!C) return V; if (C->getType() != Op->getType()) C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false, Op->getType(), false), C, Op->getType()); Operands.push_back(C); } // Check to see if getSCEVAtScope actually made an improvement. if (MadeImprovement) { Constant *C = nullptr; const DataLayout &DL = getDataLayout(); if (const CmpInst *CI = dyn_cast(I)) C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], Operands[1], DL, &TLI); else if (const LoadInst *LI = dyn_cast(I)) { if (!LI->isVolatile()) C = ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); } else C = ConstantFoldInstOperands(I, Operands, DL, &TLI); if (!C) return V; return getSCEV(C); } } } // This is some other type of SCEVUnknown, just return it. return V; } if (const SCEVCommutativeExpr *Comm = dyn_cast(V)) { // Avoid performing the look-up in the common case where the specified // expression has no loop-variant portions. for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) { const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); if (OpAtScope != Comm->getOperand(i)) { // Okay, at least one of these operands is loop variant but might be // foldable. Build a new instance of the folded commutative expression. SmallVector NewOps(Comm->op_begin(), Comm->op_begin()+i); NewOps.push_back(OpAtScope); for (++i; i != e; ++i) { OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); NewOps.push_back(OpAtScope); } if (isa(Comm)) return getAddExpr(NewOps); if (isa(Comm)) return getMulExpr(NewOps); if (isa(Comm)) return getSMaxExpr(NewOps); if (isa(Comm)) return getUMaxExpr(NewOps); llvm_unreachable("Unknown commutative SCEV type!"); } } // If we got here, all operands are loop invariant. return Comm; } if (const SCEVUDivExpr *Div = dyn_cast(V)) { const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L); const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L); if (LHS == Div->getLHS() && RHS == Div->getRHS()) return Div; // must be loop invariant return getUDivExpr(LHS, RHS); } // If this is a loop recurrence for a loop that does not contain L, then we // are dealing with the final value computed by the loop. if (const SCEVAddRecExpr *AddRec = dyn_cast(V)) { // First, attempt to evaluate each operand. // Avoid performing the look-up in the common case where the specified // expression has no loop-variant portions. for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L); if (OpAtScope == AddRec->getOperand(i)) continue; // Okay, at least one of these operands is loop variant but might be // foldable. Build a new instance of the folded commutative expression. SmallVector NewOps(AddRec->op_begin(), AddRec->op_begin()+i); NewOps.push_back(OpAtScope); for (++i; i != e; ++i) NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L)); const SCEV *FoldedRec = getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW)); AddRec = dyn_cast(FoldedRec); // The addrec may be folded to a nonrecurrence, for example, if the // induction variable is multiplied by zero after constant folding. Go // ahead and return the folded value. if (!AddRec) return FoldedRec; break; } // If the scope is outside the addrec's loop, evaluate it by using the // loop exit value of the addrec. if (!AddRec->getLoop()->contains(L)) { // To evaluate this recurrence, we need to know how many times the AddRec // loop iterates. Compute this now. const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); if (BackedgeTakenCount == getCouldNotCompute()) return AddRec; // Then, evaluate the AddRec. return AddRec->evaluateAtIteration(BackedgeTakenCount, *this); } return AddRec; } if (const SCEVZeroExtendExpr *Cast = dyn_cast(V)) { const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); if (Op == Cast->getOperand()) return Cast; // must be loop invariant return getZeroExtendExpr(Op, Cast->getType()); } if (const SCEVSignExtendExpr *Cast = dyn_cast(V)) { const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); if (Op == Cast->getOperand()) return Cast; // must be loop invariant return getSignExtendExpr(Op, Cast->getType()); } if (const SCEVTruncateExpr *Cast = dyn_cast(V)) { const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); if (Op == Cast->getOperand()) return Cast; // must be loop invariant return getTruncateExpr(Op, Cast->getType()); } llvm_unreachable("Unknown SCEV type!"); } const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { return getSCEVAtScope(getSCEV(V), L); } /// Finds the minimum unsigned root of the following equation: /// /// A * X = B (mod N) /// /// where N = 2^BW and BW is the common bit width of A and B. The signedness of /// A and B isn't important. /// /// If the equation does not have a solution, SCEVCouldNotCompute is returned. static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, ScalarEvolution &SE) { uint32_t BW = A.getBitWidth(); assert(BW == B.getBitWidth() && "Bit widths must be the same."); assert(A != 0 && "A must be non-zero."); // 1. D = gcd(A, N) // // The gcd of A and N may have only one prime factor: 2. The number of // trailing zeros in A is its multiplicity uint32_t Mult2 = A.countTrailingZeros(); // D = 2^Mult2 // 2. Check if B is divisible by D. // // B is divisible by D if and only if the multiplicity of prime factor 2 for B // is not less than multiplicity of this prime factor for D. if (B.countTrailingZeros() < Mult2) return SE.getCouldNotCompute(); // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic // modulo (N / D). // // (N / D) may need BW+1 bits in its representation. Hence, we'll use this // bit width during computations. APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D APInt Mod(BW + 1, 0); Mod.setBit(BW - Mult2); // Mod = N / D APInt I = AD.multiplicativeInverse(Mod); // 4. Compute the minimum unsigned root of the equation: // I * (B / D) mod (N / D) APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod); // The result is guaranteed to be less than 2^BW so we may truncate it to BW // bits. return SE.getConstant(Result.trunc(BW)); } /// Find the roots of the quadratic equation for the given quadratic chrec /// {L,+,M,+,N}. This returns either the two roots (which might be the same) or /// two SCEVCouldNotCompute objects. /// static Optional> SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); const SCEVConstant *LC = dyn_cast(AddRec->getOperand(0)); const SCEVConstant *MC = dyn_cast(AddRec->getOperand(1)); const SCEVConstant *NC = dyn_cast(AddRec->getOperand(2)); // We currently can only solve this if the coefficients are constants. if (!LC || !MC || !NC) return None; uint32_t BitWidth = LC->getAPInt().getBitWidth(); const APInt &L = LC->getAPInt(); const APInt &M = MC->getAPInt(); const APInt &N = NC->getAPInt(); APInt Two(BitWidth, 2); APInt Four(BitWidth, 4); { using namespace APIntOps; const APInt& C = L; // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C // The B coefficient is M-N/2 APInt B(M); B -= sdiv(N,Two); // The A coefficient is N/2 APInt A(N.sdiv(Two)); // Compute the B^2-4ac term. APInt SqrtTerm(B); SqrtTerm *= B; SqrtTerm -= Four * (A * C); if (SqrtTerm.isNegative()) { // The loop is provably infinite. return None; } // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest // integer value or else APInt::sqrt() will assert. APInt SqrtVal(SqrtTerm.sqrt()); // Compute the two solutions for the quadratic formula. // The divisions must be performed as signed divisions. APInt NegB(-B); APInt TwoA(A << 1); if (TwoA.isMinValue()) return None; LLVMContext &Context = SE.getContext(); ConstantInt *Solution1 = ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); ConstantInt *Solution2 = ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); return std::make_pair(cast(SE.getConstant(Solution1)), cast(SE.getConstant(Solution2))); } // end APIntOps namespace } ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, bool AllowPredicates) { // This is only used for loops with a "x != y" exit test. The exit condition // is now expressed as a single expression, V = x-y. So the exit test is // effectively V != 0. We know and take advantage of the fact that this // expression only being used in a comparison by zero context. SCEVUnionPredicate P; // If the value is a constant if (const SCEVConstant *C = dyn_cast(V)) { // If the value is already zero, the branch will execute zero times. if (C->getValue()->isZero()) return C; return getCouldNotCompute(); // Otherwise it will loop infinitely. } const SCEVAddRecExpr *AddRec = dyn_cast(V); if (!AddRec && AllowPredicates) // Try to make this an AddRec using runtime tests, in the first X // iterations of this loop, where X is the SCEV expression found by the // algorithm below. AddRec = convertSCEVToAddRecWithPredicates(V, L, P); if (!AddRec || AddRec->getLoop() != L) return getCouldNotCompute(); // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of // the quadratic equation to solve it. if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) { if (auto Roots = SolveQuadraticEquation(AddRec, *this)) { const SCEVConstant *R1 = Roots->first; const SCEVConstant *R2 = Roots->second; // Pick the smallest positive root value. if (ConstantInt *CB = dyn_cast(ConstantExpr::getICmp( CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { if (!CB->getZExtValue()) std::swap(R1, R2); // R1 is the minimum root now. // We can only use this value if the chrec ends up with an exact zero // value at this index. When solving for "X*X != 5", for example, we // should not accept a root of 2. const SCEV *Val = AddRec->evaluateAtIteration(R1, *this); if (Val->isZero()) return ExitLimit(R1, R1, P); // We found a quadratic root! } } return getCouldNotCompute(); } // Otherwise we can only handle this if it is affine. if (!AddRec->isAffine()) return getCouldNotCompute(); // If this is an affine expression, the execution count of this branch is // the minimum unsigned root of the following equation: // // Start + Step*N = 0 (mod 2^BW) // // equivalent to: // // Step*N = -Start (mod 2^BW) // // where BW is the common bit width of Start and Step. // Get the initial value for the loop. const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); // For now we handle only constant steps. // // TODO: Handle a nonconstant Step given AddRec. If the // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step. // We have not yet seen any such cases. const SCEVConstant *StepC = dyn_cast(Step); if (!StepC || StepC->getValue()->equalsInt(0)) return getCouldNotCompute(); // For positive steps (counting up until unsigned overflow): // N = -Start/Step (as unsigned) // For negative steps (counting down to zero): // N = Start/-Step // First compute the unsigned distance from zero in the direction of Step. bool CountDown = StepC->getAPInt().isNegative(); const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start); // Handle unitary steps, which cannot wraparound. // 1*N = -Start; -1*N = Start (mod 2^BW), so: // N = Distance (as unsigned) if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) { ConstantRange CR = getUnsignedRange(Start); const SCEV *MaxBECount; if (!CountDown && CR.getUnsignedMin().isMinValue()) // When counting up, the worst starting value is 1, not 0. MaxBECount = CR.getUnsignedMax().isMinValue() ? getConstant(APInt::getMinValue(CR.getBitWidth())) : getConstant(APInt::getMaxValue(CR.getBitWidth())); else MaxBECount = getConstant(CountDown ? CR.getUnsignedMax() : -CR.getUnsignedMin()); return ExitLimit(Distance, MaxBECount, P); } // As a special case, handle the instance where Step is a positive power of // two. In this case, determining whether Step divides Distance evenly can be // done by counting and comparing the number of trailing zeros of Step and // Distance. if (!CountDown) { const APInt &StepV = StepC->getAPInt(); // StepV.isPowerOf2() returns true if StepV is an positive power of two. It // also returns true if StepV is maximally negative (eg, INT_MIN), but that // case is not handled as this code is guarded by !CountDown. if (StepV.isPowerOf2() && GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros()) { // Here we've constrained the equation to be of the form // // 2^(N + k) * Distance' = (StepV == 2^N) * X (mod 2^W) ... (0) // // where we're operating on a W bit wide integer domain and k is // non-negative. The smallest unsigned solution for X is the trip count. // // (0) is equivalent to: // // 2^(N + k) * Distance' - 2^N * X = L * 2^W // <=> 2^N(2^k * Distance' - X) = L * 2^(W - N) * 2^N // <=> 2^k * Distance' - X = L * 2^(W - N) // <=> 2^k * Distance' = L * 2^(W - N) + X ... (1) // // The smallest X satisfying (1) is unsigned remainder of dividing the LHS // by 2^(W - N). // // <=> X = 2^k * Distance' URem 2^(W - N) ... (2) // // E.g. say we're solving // // 2 * Val = 2 * X (in i8) ... (3) // // then from (2), we get X = Val URem i8 128 (k = 0 in this case). // // Note: It is tempting to solve (3) by setting X = Val, but Val is not // necessarily the smallest unsigned value of X that satisfies (3). // E.g. if Val is i8 -127 then the smallest value of X that satisfies (3) // is i8 1, not i8 -127 const auto *ModuloResult = getUDivExactExpr(Distance, Step); // Since SCEV does not have a URem node, we construct one using a truncate // and a zero extend. unsigned NarrowWidth = StepV.getBitWidth() - StepV.countTrailingZeros(); auto *NarrowTy = IntegerType::get(getContext(), NarrowWidth); auto *WideTy = Distance->getType(); const SCEV *Limit = getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); return ExitLimit(Limit, Limit, P); } } // If the condition controls loop exit (the loop exits only if the expression // is true) and the addition is no-wrap we can use unsigned divide to // compute the backedge count. In this case, the step may not divide the // distance, but we don't care because if the condition is "missed" the loop // will have undefined behavior due to wrapping. if (ControlsExit && AddRec->hasNoSelfWrap() && loopHasNoAbnormalExits(AddRec->getLoop())) { const SCEV *Exact = getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); return ExitLimit(Exact, Exact, P); } // Then, try to solve the above equation provided that Start is constant. if (const SCEVConstant *StartC = dyn_cast(Start)) { const SCEV *E = SolveLinEquationWithOverflow( StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this); return ExitLimit(E, E, P); } return getCouldNotCompute(); } ScalarEvolution::ExitLimit ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) { // Loops that look like: while (X == 0) are very strange indeed. We don't // handle them yet except for the trivial case. This could be expanded in the // future as needed. // If the value is a constant, check to see if it is known to be non-zero // already. If so, the backedge will execute zero times. if (const SCEVConstant *C = dyn_cast(V)) { if (!C->getValue()->isNullValue()) return getZero(C->getType()); return getCouldNotCompute(); // Otherwise it will loop infinitely. } // We could implement others, but I really doubt anyone writes loops like // this, and if they did, they would already be constant folded. return getCouldNotCompute(); } std::pair ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { // If the block has a unique predecessor, then there is no path from the // predecessor to the block that does not go through the direct edge // from the predecessor to the block. if (BasicBlock *Pred = BB->getSinglePredecessor()) return {Pred, BB}; // A loop's header is defined to be a block that dominates the loop. // If the header has a unique predecessor outside the loop, it must be // a block that has exactly one successor that can reach the loop. if (Loop *L = LI.getLoopFor(BB)) return {L->getLoopPredecessor(), L->getHeader()}; return {nullptr, nullptr}; } /// SCEV structural equivalence is usually sufficient for testing whether two /// expressions are equal, however for the purposes of looking for a condition /// guarding a loop, it can be useful to be a little more general, since a /// front-end may have replicated the controlling expression. /// static bool HasSameValue(const SCEV *A, const SCEV *B) { // Quick check to see if they are the same SCEV. if (A == B) return true; auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) { // Not all instructions that are "identical" compute the same value. For // instance, two distinct alloca instructions allocating the same type are // identical and do not read memory; but compute distinct values. return A->isIdenticalTo(B) && (isa(A) || isa(A)); }; // Otherwise, if they're both SCEVUnknown, it's possible that they hold // two different instructions with the same value. Check for this case. if (const SCEVUnknown *AU = dyn_cast(A)) if (const SCEVUnknown *BU = dyn_cast(B)) if (const Instruction *AI = dyn_cast(AU->getValue())) if (const Instruction *BI = dyn_cast(BU->getValue())) if (ComputesEqualValues(AI, BI)) return true; // Otherwise assume they may have a different value. return false; } bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth) { bool Changed = false; // If we hit the max recursion limit bail out. if (Depth >= 3) return false; // Canonicalize a constant to the right side. if (const SCEVConstant *LHSC = dyn_cast(LHS)) { // Check for both operands constant. if (const SCEVConstant *RHSC = dyn_cast(RHS)) { if (ConstantExpr::getICmp(Pred, LHSC->getValue(), RHSC->getValue())->isNullValue()) goto trivially_false; else goto trivially_true; } // Otherwise swap the operands to put the constant on the right. std::swap(LHS, RHS); Pred = ICmpInst::getSwappedPredicate(Pred); Changed = true; } // If we're comparing an addrec with a value which is loop-invariant in the // addrec's loop, put the addrec on the left. Also make a dominance check, // as both operands could be addrecs loop-invariant in each other's loop. if (const SCEVAddRecExpr *AR = dyn_cast(RHS)) { const Loop *L = AR->getLoop(); if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) { std::swap(LHS, RHS); Pred = ICmpInst::getSwappedPredicate(Pred); Changed = true; } } // If there's a constant operand, canonicalize comparisons with boundary // cases, and canonicalize *-or-equal comparisons to regular comparisons. if (const SCEVConstant *RC = dyn_cast(RHS)) { const APInt &RA = RC->getAPInt(); switch (Pred) { default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_NE: // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b. if (!RA) if (const SCEVAddExpr *AE = dyn_cast(LHS)) if (const SCEVMulExpr *ME = dyn_cast(AE->getOperand(0))) if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) { RHS = AE->getOperand(1); LHS = ME->getOperand(1); Changed = true; } break; case ICmpInst::ICMP_UGE: if ((RA - 1).isMinValue()) { Pred = ICmpInst::ICMP_NE; RHS = getConstant(RA - 1); Changed = true; break; } if (RA.isMaxValue()) { Pred = ICmpInst::ICMP_EQ; Changed = true; break; } if (RA.isMinValue()) goto trivially_true; Pred = ICmpInst::ICMP_UGT; RHS = getConstant(RA - 1); Changed = true; break; case ICmpInst::ICMP_ULE: if ((RA + 1).isMaxValue()) { Pred = ICmpInst::ICMP_NE; RHS = getConstant(RA + 1); Changed = true; break; } if (RA.isMinValue()) { Pred = ICmpInst::ICMP_EQ; Changed = true; break; } if (RA.isMaxValue()) goto trivially_true; Pred = ICmpInst::ICMP_ULT; RHS = getConstant(RA + 1); Changed = true; break; case ICmpInst::ICMP_SGE: if ((RA - 1).isMinSignedValue()) { Pred = ICmpInst::ICMP_NE; RHS = getConstant(RA - 1); Changed = true; break; } if (RA.isMaxSignedValue()) { Pred = ICmpInst::ICMP_EQ; Changed = true; break; } if (RA.isMinSignedValue()) goto trivially_true; Pred = ICmpInst::ICMP_SGT; RHS = getConstant(RA - 1); Changed = true; break; case ICmpInst::ICMP_SLE: if ((RA + 1).isMaxSignedValue()) { Pred = ICmpInst::ICMP_NE; RHS = getConstant(RA + 1); Changed = true; break; } if (RA.isMinSignedValue()) { Pred = ICmpInst::ICMP_EQ; Changed = true; break; } if (RA.isMaxSignedValue()) goto trivially_true; Pred = ICmpInst::ICMP_SLT; RHS = getConstant(RA + 1); Changed = true; break; case ICmpInst::ICMP_UGT: if (RA.isMinValue()) { Pred = ICmpInst::ICMP_NE; Changed = true; break; } if ((RA + 1).isMaxValue()) { Pred = ICmpInst::ICMP_EQ; RHS = getConstant(RA + 1); Changed = true; break; } if (RA.isMaxValue()) goto trivially_false; break; case ICmpInst::ICMP_ULT: if (RA.isMaxValue()) { Pred = ICmpInst::ICMP_NE; Changed = true; break; } if ((RA - 1).isMinValue()) { Pred = ICmpInst::ICMP_EQ; RHS = getConstant(RA - 1); Changed = true; break; } if (RA.isMinValue()) goto trivially_false; break; case ICmpInst::ICMP_SGT: if (RA.isMinSignedValue()) { Pred = ICmpInst::ICMP_NE; Changed = true; break; } if ((RA + 1).isMaxSignedValue()) { Pred = ICmpInst::ICMP_EQ; RHS = getConstant(RA + 1); Changed = true; break; } if (RA.isMaxSignedValue()) goto trivially_false; break; case ICmpInst::ICMP_SLT: if (RA.isMaxSignedValue()) { Pred = ICmpInst::ICMP_NE; Changed = true; break; } if ((RA - 1).isMinSignedValue()) { Pred = ICmpInst::ICMP_EQ; RHS = getConstant(RA - 1); Changed = true; break; } if (RA.isMinSignedValue()) goto trivially_false; break; } } // Check for obvious equality. if (HasSameValue(LHS, RHS)) { if (ICmpInst::isTrueWhenEqual(Pred)) goto trivially_true; if (ICmpInst::isFalseWhenEqual(Pred)) goto trivially_false; } // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by // adding or subtracting 1 from one of the operands. switch (Pred) { case ICmpInst::ICMP_SLE: if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) { RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, SCEV::FlagNSW); Pred = ICmpInst::ICMP_SLT; Changed = true; } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) { LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS, SCEV::FlagNSW); Pred = ICmpInst::ICMP_SLT; Changed = true; } break; case ICmpInst::ICMP_SGE: if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) { RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, SCEV::FlagNSW); Pred = ICmpInst::ICMP_SGT; Changed = true; } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) { LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, SCEV::FlagNSW); Pred = ICmpInst::ICMP_SGT; Changed = true; } break; case ICmpInst::ICMP_ULE: if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) { RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, SCEV::FlagNUW); Pred = ICmpInst::ICMP_ULT; Changed = true; } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) { LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS); Pred = ICmpInst::ICMP_ULT; Changed = true; } break; case ICmpInst::ICMP_UGE: if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) { RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS); Pred = ICmpInst::ICMP_UGT; Changed = true; } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) { LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, SCEV::FlagNUW); Pred = ICmpInst::ICMP_UGT; Changed = true; } break; default: break; } // TODO: More simplifications are possible here. // Recursively simplify until we either hit a recursion limit or nothing // changes. if (Changed) return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1); return Changed; trivially_true: // Return 0 == 0. LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); Pred = ICmpInst::ICMP_EQ; return true; trivially_false: // Return 0 != 0. LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); Pred = ICmpInst::ICMP_NE; return true; } bool ScalarEvolution::isKnownNegative(const SCEV *S) { return getSignedRange(S).getSignedMax().isNegative(); } bool ScalarEvolution::isKnownPositive(const SCEV *S) { return getSignedRange(S).getSignedMin().isStrictlyPositive(); } bool ScalarEvolution::isKnownNonNegative(const SCEV *S) { return !getSignedRange(S).getSignedMin().isNegative(); } bool ScalarEvolution::isKnownNonPositive(const SCEV *S) { return !getSignedRange(S).getSignedMax().isStrictlyPositive(); } bool ScalarEvolution::isKnownNonZero(const SCEV *S) { return isKnownNegative(S) || isKnownPositive(S); } bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { // Canonicalize the inputs first. (void)SimplifyICmpOperands(Pred, LHS, RHS); // If LHS or RHS is an addrec, check to see if the condition is true in // every iteration of the loop. // If LHS and RHS are both addrec, both conditions must be true in // every iteration of the loop. const SCEVAddRecExpr *LAR = dyn_cast(LHS); const SCEVAddRecExpr *RAR = dyn_cast(RHS); bool LeftGuarded = false; bool RightGuarded = false; if (LAR) { const Loop *L = LAR->getLoop(); if (isLoopEntryGuardedByCond(L, Pred, LAR->getStart(), RHS) && isLoopBackedgeGuardedByCond(L, Pred, LAR->getPostIncExpr(*this), RHS)) { if (!RAR) return true; LeftGuarded = true; } } if (RAR) { const Loop *L = RAR->getLoop(); if (isLoopEntryGuardedByCond(L, Pred, LHS, RAR->getStart()) && isLoopBackedgeGuardedByCond(L, Pred, LHS, RAR->getPostIncExpr(*this))) { if (!LAR) return true; RightGuarded = true; } } if (LeftGuarded && RightGuarded) return true; if (isKnownPredicateViaSplitting(Pred, LHS, RHS)) return true; // Otherwise see what can be done with known constant ranges. return isKnownPredicateViaConstantRanges(Pred, LHS, RHS); } bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred, bool &Increasing) { bool Result = isMonotonicPredicateImpl(LHS, Pred, Increasing); #ifndef NDEBUG // Verify an invariant: inverting the predicate should turn a monotonically // increasing change to a monotonically decreasing one, and vice versa. bool IncreasingSwapped; bool ResultSwapped = isMonotonicPredicateImpl( LHS, ICmpInst::getSwappedPredicate(Pred), IncreasingSwapped); assert(Result == ResultSwapped && "should be able to analyze both!"); if (ResultSwapped) assert(Increasing == !IncreasingSwapped && "monotonicity should flip as we flip the predicate"); #endif return Result; } bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred, bool &Increasing) { // A zero step value for LHS means the induction variable is essentially a // loop invariant value. We don't really depend on the predicate actually // flipping from false to true (for increasing predicates, and the other way // around for decreasing predicates), all we care about is that *if* the // predicate changes then it only changes from false to true. // // A zero step value in itself is not very useful, but there may be places // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be // as general as possible. switch (Pred) { default: return false; // Conservative answer case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: if (!LHS->hasNoUnsignedWrap()) return false; Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE; return true; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: { if (!LHS->hasNoSignedWrap()) return false; const SCEV *Step = LHS->getStepRecurrence(*this); if (isKnownNonNegative(Step)) { Increasing = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE; return true; } if (isKnownNonPositive(Step)) { Increasing = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE; return true; } return false; } } llvm_unreachable("switch has default clause!"); } bool ScalarEvolution::isLoopInvariantPredicate( ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS, const SCEV *&InvariantRHS) { // If there is a loop-invariant, force it into the RHS, otherwise bail out. if (!isLoopInvariant(RHS, L)) { if (!isLoopInvariant(LHS, L)) return false; std::swap(LHS, RHS); Pred = ICmpInst::getSwappedPredicate(Pred); } const SCEVAddRecExpr *ArLHS = dyn_cast(LHS); if (!ArLHS || ArLHS->getLoop() != L) return false; bool Increasing; if (!isMonotonicPredicate(ArLHS, Pred, Increasing)) return false; // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to // true as the loop iterates, and the backedge is control dependent on // "ArLHS `Pred` RHS" == true then we can reason as follows: // // * if the predicate was false in the first iteration then the predicate // is never evaluated again, since the loop exits without taking the // backedge. // * if the predicate was true in the first iteration then it will // continue to be true for all future iterations since it is // monotonically increasing. // // For both the above possibilities, we can replace the loop varying // predicate with its value on the first iteration of the loop (which is // loop invariant). // // A similar reasoning applies for a monotonically decreasing predicate, by // replacing true with false and false with true in the above two bullets. auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred); if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS)) return false; InvariantPred = Pred; InvariantLHS = ArLHS->getStart(); InvariantRHS = RHS; return true; } bool ScalarEvolution::isKnownPredicateViaConstantRanges( ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { if (HasSameValue(LHS, RHS)) return ICmpInst::isTrueWhenEqual(Pred); // This code is split out from isKnownPredicate because it is called from // within isLoopEntryGuardedByCond. auto CheckRanges = [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) { return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS) .contains(RangeLHS); }; // The check at the top of the function catches the case where the values are // known to be equal. if (Pred == CmpInst::ICMP_EQ) return false; if (Pred == CmpInst::ICMP_NE) return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) || CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)) || isKnownNonZero(getMinusSCEV(LHS, RHS)); if (CmpInst::isSigned(Pred)) return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)); return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)); } bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { // Match Result to (X + Y) where Y is a constant integer. // Return Y via OutY. auto MatchBinaryAddToConst = [this](const SCEV *Result, const SCEV *X, APInt &OutY, SCEV::NoWrapFlags ExpectedFlags) { const SCEV *NonConstOp, *ConstOp; SCEV::NoWrapFlags FlagsPresent; if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) || !isa(ConstOp) || NonConstOp != X) return false; OutY = cast(ConstOp)->getAPInt(); return (FlagsPresent & ExpectedFlags) == ExpectedFlags; }; APInt C; switch (Pred) { default: break; case ICmpInst::ICMP_SGE: std::swap(LHS, RHS); case ICmpInst::ICMP_SLE: // X s<= (X + C) if C >= 0 if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative()) return true; // (X + C) s<= X if C <= 0 if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && !C.isStrictlyPositive()) return true; break; case ICmpInst::ICMP_SGT: std::swap(LHS, RHS); case ICmpInst::ICMP_SLT: // X s< (X + C) if C > 0 if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isStrictlyPositive()) return true; // (X + C) s< X if C < 0 if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative()) return true; break; } return false; } bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate) return false; // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on // the stack can result in exponential time complexity. SaveAndRestore Restore(ProvingSplitPredicate, true); // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L // // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use // isKnownPredicate. isKnownPredicate is more powerful, but also more // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the // interesting cases seen in practice. We can consider "upgrading" L >= 0 to // use isKnownPredicate later if needed. return isKnownNonNegative(RHS) && isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) && isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS); } bool ScalarEvolution::isImpliedViaGuard(BasicBlock *BB, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { // No need to even try if we know the module has no guards. if (!HasGuards) return false; return any_of(*BB, [&](Instruction &I) { using namespace llvm::PatternMatch; Value *Condition; return match(&I, m_Intrinsic( m_Value(Condition))) && isImpliedCond(Pred, LHS, RHS, Condition, false); }); } /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is /// protected by a conditional between LHS and RHS. This is used to /// to eliminate casts. bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { // Interpret a null as meaning no loop, where there is obviously no guard // (interprocedural conditions notwithstanding). if (!L) return true; if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS)) return true; BasicBlock *Latch = L->getLoopLatch(); if (!Latch) return false; BranchInst *LoopContinuePredicate = dyn_cast(Latch->getTerminator()); if (LoopContinuePredicate && LoopContinuePredicate->isConditional() && isImpliedCond(Pred, LHS, RHS, LoopContinuePredicate->getCondition(), LoopContinuePredicate->getSuccessor(0) != L->getHeader())) return true; // We don't want more than one activation of the following loops on the stack // -- that can lead to O(n!) time complexity. if (WalkingBEDominatingConds) return false; SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true); // See if we can exploit a trip count to prove the predicate. const auto &BETakenInfo = getBackedgeTakenInfo(L); const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this); if (LatchBECount != getCouldNotCompute()) { // We know that Latch branches back to the loop header exactly // LatchBECount times. This means the backdege condition at Latch is // equivalent to "{0,+,1} u< LatchBECount". Type *Ty = LatchBECount->getType(); auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW); const SCEV *LoopCounter = getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags); if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter, LatchBECount)) return true; } // Check conditions due to any @llvm.assume intrinsics. for (auto &AssumeVH : AC.assumptions()) { if (!AssumeVH) continue; auto *CI = cast(AssumeVH); if (!DT.dominates(CI, Latch->getTerminator())) continue; if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) return true; } // If the loop is not reachable from the entry block, we risk running into an // infinite loop as we walk up into the dom tree. These loops do not matter // anyway, so we just return a conservative answer when we see them. if (!DT.isReachableFromEntry(L->getHeader())) return false; if (isImpliedViaGuard(Latch, Pred, LHS, RHS)) return true; for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()]; DTN != HeaderDTN; DTN = DTN->getIDom()) { assert(DTN && "should reach the loop header before reaching the root!"); BasicBlock *BB = DTN->getBlock(); if (isImpliedViaGuard(BB, Pred, LHS, RHS)) return true; BasicBlock *PBB = BB->getSinglePredecessor(); if (!PBB) continue; BranchInst *ContinuePredicate = dyn_cast(PBB->getTerminator()); if (!ContinuePredicate || !ContinuePredicate->isConditional()) continue; Value *Condition = ContinuePredicate->getCondition(); // If we have an edge `E` within the loop body that dominates the only // latch, the condition guarding `E` also guards the backedge. This // reasoning works only for loops with a single latch. BasicBlockEdge DominatingEdge(PBB, BB); if (DominatingEdge.isSingleEdge()) { // We're constructively (and conservatively) enumerating edges within the // loop body that dominate the latch. The dominator tree better agree // with us on this: assert(DT.dominates(DominatingEdge, Latch) && "should be!"); if (isImpliedCond(Pred, LHS, RHS, Condition, BB != ContinuePredicate->getSuccessor(0))) return true; } } return false; } bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { // Interpret a null as meaning no loop, where there is obviously no guard // (interprocedural conditions notwithstanding). if (!L) return false; if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS)) return true; // Starting at the loop predecessor, climb up the predecessor chain, as long // as there are predecessors that can be found that have unique successors // leading to the original header. for (std::pair Pair(L->getLoopPredecessor(), L->getHeader()); Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) { if (isImpliedViaGuard(Pair.first, Pred, LHS, RHS)) return true; BranchInst *LoopEntryPredicate = dyn_cast(Pair.first->getTerminator()); if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional()) continue; if (isImpliedCond(Pred, LHS, RHS, LoopEntryPredicate->getCondition(), LoopEntryPredicate->getSuccessor(0) != Pair.second)) return true; } // Check conditions due to any @llvm.assume intrinsics. for (auto &AssumeVH : AC.assumptions()) { if (!AssumeVH) continue; auto *CI = cast(AssumeVH); if (!DT.dominates(CI, L->getHeader())) continue; if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) return true; } return false; } namespace { /// RAII wrapper to prevent recursive application of isImpliedCond. /// ScalarEvolution's PendingLoopPredicates set must be empty unless we are /// currently evaluating isImpliedCond. struct MarkPendingLoopPredicate { Value *Cond; DenseSet &LoopPreds; bool Pending; MarkPendingLoopPredicate(Value *C, DenseSet &LP) : Cond(C), LoopPreds(LP) { Pending = !LoopPreds.insert(Cond).second; } ~MarkPendingLoopPredicate() { if (!Pending) LoopPreds.erase(Cond); } }; } // end anonymous namespace bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, Value *FoundCondValue, bool Inverse) { MarkPendingLoopPredicate Mark(FoundCondValue, PendingLoopPredicates); if (Mark.Pending) return false; // Recursively handle And and Or conditions. if (BinaryOperator *BO = dyn_cast(FoundCondValue)) { if (BO->getOpcode() == Instruction::And) { if (!Inverse) return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) || isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse); } else if (BO->getOpcode() == Instruction::Or) { if (Inverse) return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) || isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse); } } ICmpInst *ICI = dyn_cast(FoundCondValue); if (!ICI) return false; // Now that we found a conditional branch that dominates the loop or controls // the loop latch. Check to see if it is the comparison we are looking for. ICmpInst::Predicate FoundPred; if (Inverse) FoundPred = ICI->getInversePredicate(); else FoundPred = ICI->getPredicate(); const SCEV *FoundLHS = getSCEV(ICI->getOperand(0)); const SCEV *FoundRHS = getSCEV(ICI->getOperand(1)); return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS); } bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS) { // Balance the types. if (getTypeSizeInBits(LHS->getType()) < getTypeSizeInBits(FoundLHS->getType())) { if (CmpInst::isSigned(Pred)) { LHS = getSignExtendExpr(LHS, FoundLHS->getType()); RHS = getSignExtendExpr(RHS, FoundLHS->getType()); } else { LHS = getZeroExtendExpr(LHS, FoundLHS->getType()); RHS = getZeroExtendExpr(RHS, FoundLHS->getType()); } } else if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(FoundLHS->getType())) { if (CmpInst::isSigned(FoundPred)) { FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType()); FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType()); } else { FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType()); FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType()); } } // Canonicalize the query to match the way instcombine will have // canonicalized the comparison. if (SimplifyICmpOperands(Pred, LHS, RHS)) if (LHS == RHS) return CmpInst::isTrueWhenEqual(Pred); if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS)) if (FoundLHS == FoundRHS) return CmpInst::isFalseWhenEqual(FoundPred); // Check to see if we can make the LHS or RHS match. if (LHS == FoundRHS || RHS == FoundLHS) { if (isa(RHS)) { std::swap(FoundLHS, FoundRHS); FoundPred = ICmpInst::getSwappedPredicate(FoundPred); } else { std::swap(LHS, RHS); Pred = ICmpInst::getSwappedPredicate(Pred); } } // Check whether the found predicate is the same as the desired predicate. if (FoundPred == Pred) return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS); // Check whether swapping the found predicate makes it the same as the // desired predicate. if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) { if (isa(RHS)) return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS); else return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), RHS, LHS, FoundLHS, FoundRHS); } // Unsigned comparison is the same as signed comparison when both the operands // are non-negative. if (CmpInst::isUnsigned(FoundPred) && CmpInst::getSignedPredicate(FoundPred) == Pred && isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS); // Check if we can make progress by sharpening ranges. if (FoundPred == ICmpInst::ICMP_NE && (isa(FoundLHS) || isa(FoundRHS))) { const SCEVConstant *C = nullptr; const SCEV *V = nullptr; if (isa(FoundLHS)) { C = cast(FoundLHS); V = FoundRHS; } else { C = cast(FoundRHS); V = FoundLHS; } // The guarding predicate tells us that C != V. If the known range // of V is [C, t), we can sharpen the range to [C + 1, t). The // range we consider has to correspond to same signedness as the // predicate we're interested in folding. APInt Min = ICmpInst::isSigned(Pred) ? getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin(); if (Min == C->getAPInt()) { // Given (V >= Min && V != Min) we conclude V >= (Min + 1). // This is true even if (Min + 1) wraps around -- in case of // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)). APInt SharperMin = Min + 1; switch (Pred) { case ICmpInst::ICMP_SGE: case ICmpInst::ICMP_UGE: // We know V `Pred` SharperMin. If this implies LHS `Pred` // RHS, we're done. if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin))) return true; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: // We know from the range information that (V `Pred` Min || // V == Min). We know from the guarding condition that !(V // == Min). This gives us // // V `Pred` Min || V == Min && !(V == Min) // => V `Pred` Min // // If V `Pred` Min implies LHS `Pred` RHS, we're done. if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min))) return true; default: // No change break; } } } // Check whether the actual condition is beyond sufficient. if (FoundPred == ICmpInst::ICMP_EQ) if (ICmpInst::isTrueWhenEqual(Pred)) if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS)) return true; if (Pred == ICmpInst::ICMP_NE) if (!ICmpInst::isTrueWhenEqual(FoundPred)) if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS)) return true; // Otherwise assume the worst. return false; } bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R, SCEV::NoWrapFlags &Flags) { const auto *AE = dyn_cast(Expr); if (!AE || AE->getNumOperands() != 2) return false; L = AE->getOperand(0); R = AE->getOperand(1); Flags = AE->getNoWrapFlags(); return true; } bool ScalarEvolution::computeConstantDifference(const SCEV *Less, const SCEV *More, APInt &C) { // We avoid subtracting expressions here because this function is usually // fairly deep in the call stack (i.e. is called many times). if (isa(Less) && isa(More)) { const auto *LAR = cast(Less); const auto *MAR = cast(More); if (LAR->getLoop() != MAR->getLoop()) return false; // We look at affine expressions only; not for correctness but to keep // getStepRecurrence cheap. if (!LAR->isAffine() || !MAR->isAffine()) return false; if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this)) return false; Less = LAR->getStart(); More = MAR->getStart(); // fall through } if (isa(Less) && isa(More)) { const auto &M = cast(More)->getAPInt(); const auto &L = cast(Less)->getAPInt(); C = M - L; return true; } const SCEV *L, *R; SCEV::NoWrapFlags Flags; if (splitBinaryAdd(Less, L, R, Flags)) if (const auto *LC = dyn_cast(L)) if (R == More) { C = -(LC->getAPInt()); return true; } if (splitBinaryAdd(More, L, R, Flags)) if (const auto *LC = dyn_cast(L)) if (R == Less) { C = LC->getAPInt(); return true; } return false; } bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT) return false; const auto *AddRecLHS = dyn_cast(LHS); if (!AddRecLHS) return false; const auto *AddRecFoundLHS = dyn_cast(FoundLHS); if (!AddRecFoundLHS) return false; // We'd like to let SCEV reason about control dependencies, so we constrain // both the inequalities to be about add recurrences on the same loop. This // way we can use isLoopEntryGuardedByCond later. const Loop *L = AddRecFoundLHS->getLoop(); if (L != AddRecLHS->getLoop()) return false; // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1) // // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C) // ... (2) // // Informal proof for (2), assuming (1) [*]: // // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**] // // Then // // FoundLHS s< FoundRHS s< INT_MIN - C // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ] // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ] // <=> (FoundLHS + INT_MIN + C + INT_MIN) s< // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ] // <=> FoundLHS + C s< FoundRHS + C // // [*]: (1) can be proved by ruling out overflow. // // [**]: This can be proved by analyzing all the four possibilities: // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and // (A s>= 0, B s>= 0). // // Note: // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C" // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS + // C)". APInt LDiff, RDiff; if (!computeConstantDifference(FoundLHS, LHS, LDiff) || !computeConstantDifference(FoundRHS, RHS, RDiff) || LDiff != RDiff) return false; if (LDiff == 0) return true; APInt FoundRHSLimit; if (Pred == CmpInst::ICMP_ULT) { FoundRHSLimit = -RDiff; } else { assert(Pred == CmpInst::ICMP_SLT && "Checked above!"); FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - RDiff; } // Try to prove (1) or (2), as needed. return isLoopEntryGuardedByCond(L, Pred, FoundRHS, getConstant(FoundRHSLimit)); } bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS)) return true; if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS)) return true; return isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS) || // ~x < ~y --> x > y isImpliedCondOperandsHelper(Pred, LHS, RHS, getNotSCEV(FoundRHS), getNotSCEV(FoundLHS)); } /// If Expr computes ~A, return A else return nullptr static const SCEV *MatchNotExpr(const SCEV *Expr) { const SCEVAddExpr *Add = dyn_cast(Expr); if (!Add || Add->getNumOperands() != 2 || !Add->getOperand(0)->isAllOnesValue()) return nullptr; const SCEVMulExpr *AddRHS = dyn_cast(Add->getOperand(1)); if (!AddRHS || AddRHS->getNumOperands() != 2 || !AddRHS->getOperand(0)->isAllOnesValue()) return nullptr; return AddRHS->getOperand(1); } /// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values? template static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr, const SCEV *Candidate) { const MaxExprType *MaxExpr = dyn_cast(MaybeMaxExpr); if (!MaxExpr) return false; return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end(); } /// Is MaybeMinExpr an SMin or UMin of Candidate and some other values? template static bool IsMinConsistingOf(ScalarEvolution &SE, const SCEV *MaybeMinExpr, const SCEV *Candidate) { const SCEV *MaybeMaxExpr = MatchNotExpr(MaybeMinExpr); if (!MaybeMaxExpr) return false; return IsMaxConsistingOf(MaybeMaxExpr, SE.getNotSCEV(Candidate)); } static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { // If both sides are affine addrecs for the same loop, with equal // steps, and we know the recurrences don't wrap, then we only // need to check the predicate on the starting values. if (!ICmpInst::isRelational(Pred)) return false; const SCEVAddRecExpr *LAR = dyn_cast(LHS); if (!LAR) return false; const SCEVAddRecExpr *RAR = dyn_cast(RHS); if (!RAR) return false; if (LAR->getLoop() != RAR->getLoop()) return false; if (!LAR->isAffine() || !RAR->isAffine()) return false; if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE)) return false; SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ? SCEV::FlagNSW : SCEV::FlagNUW; if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW)) return false; return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart()); } /// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max /// expression? static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { switch (Pred) { default: return false; case ICmpInst::ICMP_SGE: std::swap(LHS, RHS); // fall through case ICmpInst::ICMP_SLE: return // min(A, ...) <= A IsMinConsistingOf(SE, LHS, RHS) || // A <= max(A, ...) IsMaxConsistingOf(RHS, LHS); case ICmpInst::ICMP_UGE: std::swap(LHS, RHS); // fall through case ICmpInst::ICMP_ULE: return // min(A, ...) <= A IsMinConsistingOf(SE, LHS, RHS) || // A <= max(A, ...) IsMaxConsistingOf(RHS, LHS); } llvm_unreachable("covered switch fell through?!"); } bool ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { auto IsKnownPredicateFull = [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || isKnownPredicateViaNoOverflow(Pred, LHS, RHS); }; switch (Pred) { default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_NE: if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) && IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) && IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) && IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) && IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS)) return true; break; } return false; } bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { if (!isa(RHS) || !isa(FoundRHS)) // The restriction on `FoundRHS` be lifted easily -- it exists only to // reduce the compile time impact of this optimization. return false; const SCEVAddExpr *AddLHS = dyn_cast(LHS); if (!AddLHS || AddLHS->getOperand(1) != FoundLHS || !isa(AddLHS->getOperand(0))) return false; APInt ConstFoundRHS = cast(FoundRHS)->getAPInt(); // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the // antecedent "`FoundLHS` `Pred` `FoundRHS`". ConstantRange FoundLHSRange = ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS); // Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range // for `LHS`: APInt Addend = cast(AddLHS->getOperand(0))->getAPInt(); ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend)); // We can also compute the range of values for `LHS` that satisfy the // consequent, "`LHS` `Pred` `RHS`": APInt ConstRHS = cast(RHS)->getAPInt(); ConstantRange SatisfyingLHSRange = ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS); // The antecedent implies the consequent if every value of `LHS` that // satisfies the antecedent also satisfies the consequent. return SatisfyingLHSRange.contains(LHSRange); } bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap) { if (NoWrap) return false; unsigned BitWidth = getTypeSizeInBits(RHS->getType()); const SCEV *One = getOne(Stride->getType()); if (IsSigned) { APInt MaxRHS = getSignedRange(RHS).getSignedMax(); APInt MaxValue = APInt::getSignedMaxValue(BitWidth); APInt MaxStrideMinusOne = getSignedRange(getMinusSCEV(Stride, One)) .getSignedMax(); // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow! return (MaxValue - MaxStrideMinusOne).slt(MaxRHS); } APInt MaxRHS = getUnsignedRange(RHS).getUnsignedMax(); APInt MaxValue = APInt::getMaxValue(BitWidth); APInt MaxStrideMinusOne = getUnsignedRange(getMinusSCEV(Stride, One)) .getUnsignedMax(); // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow! return (MaxValue - MaxStrideMinusOne).ult(MaxRHS); } bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap) { if (NoWrap) return false; unsigned BitWidth = getTypeSizeInBits(RHS->getType()); const SCEV *One = getOne(Stride->getType()); if (IsSigned) { APInt MinRHS = getSignedRange(RHS).getSignedMin(); APInt MinValue = APInt::getSignedMinValue(BitWidth); APInt MaxStrideMinusOne = getSignedRange(getMinusSCEV(Stride, One)) .getSignedMax(); // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow! return (MinValue + MaxStrideMinusOne).sgt(MinRHS); } APInt MinRHS = getUnsignedRange(RHS).getUnsignedMin(); APInt MinValue = APInt::getMinValue(BitWidth); APInt MaxStrideMinusOne = getUnsignedRange(getMinusSCEV(Stride, One)) .getUnsignedMax(); // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow! return (MinValue + MaxStrideMinusOne).ugt(MinRHS); } const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, bool Equality) { const SCEV *One = getOne(Step->getType()); Delta = Equality ? getAddExpr(Delta, Step) : getAddExpr(Delta, getMinusSCEV(Step, One)); return getUDivExpr(Delta, Step); } ScalarEvolution::ExitLimit ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, bool ControlsExit, bool AllowPredicates) { SCEVUnionPredicate P; // We handle only IV < Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); const SCEVAddRecExpr *IV = dyn_cast(LHS); if (!IV && AllowPredicates) // Try to make this an AddRec using runtime tests, in the first X // iterations of this loop, where X is the SCEV expression found by the // algorithm below. IV = convertSCEVToAddRecWithPredicates(LHS, L, P); // Avoid weird loops if (!IV || IV->getLoop() != L || !IV->isAffine()) return getCouldNotCompute(); bool NoWrap = ControlsExit && IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); const SCEV *Stride = IV->getStepRecurrence(*this); // Avoid negative or zero stride values if (!isKnownPositive(Stride)) return getCouldNotCompute(); // Avoid proven overflow cases: this will ensure that the backedge taken count // will not generate any unsigned overflow. Relaxed no-overflow conditions // exploit NoWrapFlags, allowing to optimize in presence of undefined // behaviors like the case of C language. if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap)) return getCouldNotCompute(); ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; const SCEV *Start = IV->getStart(); const SCEV *End = RHS; if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin() : getUnsignedRange(Start).getUnsignedMin(); APInt MinStride = IsSigned ? getSignedRange(Stride).getSignedMin() : getUnsignedRange(Stride).getUnsignedMin(); unsigned BitWidth = getTypeSizeInBits(LHS->getType()); APInt Limit = IsSigned ? APInt::getSignedMaxValue(BitWidth) - (MinStride - 1) : APInt::getMaxValue(BitWidth) - (MinStride - 1); // Although End can be a MAX expression we estimate MaxEnd considering only // the case End = RHS. This is safe because in the other case (End - Start) // is zero, leading to a zero maximum backedge taken count. APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRange(RHS).getSignedMax(), Limit) : APIntOps::umin(getUnsignedRange(RHS).getUnsignedMax(), Limit); const SCEV *MaxBECount; if (isa(BECount)) MaxBECount = BECount; else MaxBECount = computeBECount(getConstant(MaxEnd - MinStart), getConstant(MinStride), false); if (isa(MaxBECount)) MaxBECount = BECount; return ExitLimit(BECount, MaxBECount, P); } ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, bool ControlsExit, bool AllowPredicates) { SCEVUnionPredicate P; // We handle only IV > Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); const SCEVAddRecExpr *IV = dyn_cast(LHS); if (!IV && AllowPredicates) // Try to make this an AddRec using runtime tests, in the first X // iterations of this loop, where X is the SCEV expression found by the // algorithm below. IV = convertSCEVToAddRecWithPredicates(LHS, L, P); // Avoid weird loops if (!IV || IV->getLoop() != L || !IV->isAffine()) return getCouldNotCompute(); bool NoWrap = ControlsExit && IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this)); // Avoid negative or zero stride values if (!isKnownPositive(Stride)) return getCouldNotCompute(); // Avoid proven overflow cases: this will ensure that the backedge taken count // will not generate any unsigned overflow. Relaxed no-overflow conditions // exploit NoWrapFlags, allowing to optimize in presence of undefined // behaviors like the case of C language. if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap)) return getCouldNotCompute(); ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; const SCEV *Start = IV->getStart(); const SCEV *End = RHS; if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start); const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false); APInt MaxStart = IsSigned ? getSignedRange(Start).getSignedMax() : getUnsignedRange(Start).getUnsignedMax(); APInt MinStride = IsSigned ? getSignedRange(Stride).getSignedMin() : getUnsignedRange(Stride).getUnsignedMin(); unsigned BitWidth = getTypeSizeInBits(LHS->getType()); APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1) : APInt::getMinValue(BitWidth) + (MinStride - 1); // Although End can be a MIN expression we estimate MinEnd considering only // the case End = RHS. This is safe because in the other case (Start - End) // is zero, leading to a zero maximum backedge taken count. APInt MinEnd = IsSigned ? APIntOps::smax(getSignedRange(RHS).getSignedMin(), Limit) : APIntOps::umax(getUnsignedRange(RHS).getUnsignedMin(), Limit); const SCEV *MaxBECount = getCouldNotCompute(); if (isa(BECount)) MaxBECount = BECount; else MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), getConstant(MinStride), false); if (isa(MaxBECount)) MaxBECount = BECount; return ExitLimit(BECount, MaxBECount, P); } const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const { if (Range.isFullSet()) // Infinite loop. return SE.getCouldNotCompute(); // If the start is a non-zero constant, shift the range to simplify things. if (const SCEVConstant *SC = dyn_cast(getStart())) if (!SC->getValue()->isZero()) { SmallVector Operands(op_begin(), op_end()); Operands[0] = SE.getZero(SC->getType()); const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(), getNoWrapFlags(FlagNW)); if (const auto *ShiftedAddRec = dyn_cast(Shifted)) return ShiftedAddRec->getNumIterationsInRange( Range.subtract(SC->getAPInt()), SE); // This is strange and shouldn't happen. return SE.getCouldNotCompute(); } // The only time we can solve this is when we have all constant indices. // Otherwise, we cannot determine the overflow conditions. if (any_of(operands(), [](const SCEV *Op) { return !isa(Op); })) return SE.getCouldNotCompute(); // Okay at this point we know that all elements of the chrec are constants and // that the start element is zero. // First check to see if the range contains zero. If not, the first // iteration exits. unsigned BitWidth = SE.getTypeSizeInBits(getType()); if (!Range.contains(APInt(BitWidth, 0))) return SE.getZero(getType()); if (isAffine()) { // If this is an affine expression then we have this situation: // Solve {0,+,A} in Range === Ax in Range // We know that zero is in the range. If A is positive then we know that // the upper value of the range must be the first possible exit value. // If A is negative then the lower of the range is the last possible loop // value. Also note that we already checked for a full range. APInt One(BitWidth,1); APInt A = cast(getOperand(1))->getAPInt(); APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower(); // The exit value should be (End+A)/A. APInt ExitVal = (End + A).udiv(A); ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal); // Evaluate at the exit value. If we really did fall out of the valid // range, then we computed our trip count, otherwise wrap around or other // things must have happened. ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE); if (Range.contains(Val->getValue())) return SE.getCouldNotCompute(); // Something strange happened // Ensure that the previous value is in the range. This is a sanity check. assert(Range.contains( EvaluateConstantChrecAtConstant(this, ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) && "Linear scev computation is off in a bad way!"); return SE.getConstant(ExitValue); } else if (isQuadratic()) { // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the // quadratic equation to solve it. To do this, we must frame our problem in // terms of figuring out when zero is crossed, instead of when // Range.getUpper() is crossed. SmallVector NewOps(op_begin(), op_end()); NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), // getNoWrapFlags(FlagNW) FlagAnyWrap); // Next, solve the constructed addrec if (auto Roots = SolveQuadraticEquation(cast(NewAddRec), SE)) { const SCEVConstant *R1 = Roots->first; const SCEVConstant *R2 = Roots->second; // Pick the smallest positive root value. if (ConstantInt *CB = dyn_cast(ConstantExpr::getICmp( ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { if (!CB->getZExtValue()) std::swap(R1, R2); // R1 is the minimum root now. // Make sure the root is not off by one. The returned iteration should // not be in the range, but the previous one should be. When solving // for "X*X < 5", for example, we should not return a root of 2. ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this, R1->getValue(), SE); if (Range.contains(R1Val->getValue())) { // The next iteration must be out of the range... ConstantInt *NextVal = ConstantInt::get(SE.getContext(), R1->getAPInt() + 1); R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); if (!Range.contains(R1Val->getValue())) return SE.getConstant(NextVal); return SE.getCouldNotCompute(); // Something strange happened } // If R1 was not in the range, then it is a good return value. Make // sure that R1-1 WAS in the range though, just in case. ConstantInt *NextVal = ConstantInt::get(SE.getContext(), R1->getAPInt() - 1); R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); if (Range.contains(R1Val->getValue())) return R1; return SE.getCouldNotCompute(); // Something strange happened } } } return SE.getCouldNotCompute(); } namespace { struct FindUndefs { bool Found; FindUndefs() : Found(false) {} bool follow(const SCEV *S) { if (const SCEVUnknown *C = dyn_cast(S)) { if (isa(C->getValue())) Found = true; } else if (const SCEVConstant *C = dyn_cast(S)) { if (isa(C->getValue())) Found = true; } // Keep looking if we haven't found it yet. return !Found; } bool isDone() const { // Stop recursion if we have found an undef. return Found; } }; } // Return true when S contains at least an undef value. static inline bool containsUndefs(const SCEV *S) { FindUndefs F; SCEVTraversal ST(F); ST.visitAll(S); return F.Found; } namespace { // Collect all steps of SCEV expressions. struct SCEVCollectStrides { ScalarEvolution &SE; SmallVectorImpl &Strides; SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl &S) : SE(SE), Strides(S) {} bool follow(const SCEV *S) { if (const SCEVAddRecExpr *AR = dyn_cast(S)) Strides.push_back(AR->getStepRecurrence(SE)); return true; } bool isDone() const { return false; } }; // Collect all SCEVUnknown and SCEVMulExpr expressions. struct SCEVCollectTerms { SmallVectorImpl &Terms; SCEVCollectTerms(SmallVectorImpl &T) : Terms(T) {} bool follow(const SCEV *S) { if (isa(S) || isa(S)) { if (!containsUndefs(S)) Terms.push_back(S); // Stop recursion: once we collected a term, do not walk its operands. return false; } // Keep looking. return true; } bool isDone() const { return false; } }; // Check if a SCEV contains an AddRecExpr. struct SCEVHasAddRec { bool &ContainsAddRec; SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) { ContainsAddRec = false; } bool follow(const SCEV *S) { if (isa(S)) { ContainsAddRec = true; // Stop recursion: once we collected a term, do not walk its operands. return false; } // Keep looking. return true; } bool isDone() const { return false; } }; // Find factors that are multiplied with an expression that (possibly as a // subexpression) contains an AddRecExpr. In the expression: // // 8 * (100 + %p * %q * (%a + {0, +, 1}_loop)) // // "%p * %q" are factors multiplied by the expression "(%a + {0, +, 1}_loop)" // that contains the AddRec {0, +, 1}_loop. %p * %q are likely to be array size // parameters as they form a product with an induction variable. // // This collector expects all array size parameters to be in the same MulExpr. // It might be necessary to later add support for collecting parameters that are // spread over different nested MulExpr. struct SCEVCollectAddRecMultiplies { SmallVectorImpl &Terms; ScalarEvolution &SE; SCEVCollectAddRecMultiplies(SmallVectorImpl &T, ScalarEvolution &SE) : Terms(T), SE(SE) {} bool follow(const SCEV *S) { if (auto *Mul = dyn_cast(S)) { bool HasAddRec = false; SmallVector Operands; for (auto Op : Mul->operands()) { if (isa(Op)) { Operands.push_back(Op); } else { bool ContainsAddRec; SCEVHasAddRec ContiansAddRec(ContainsAddRec); visitAll(Op, ContiansAddRec); HasAddRec |= ContainsAddRec; } } if (Operands.size() == 0) return true; if (!HasAddRec) return false; Terms.push_back(SE.getMulExpr(Operands)); // Stop recursion: once we collected a term, do not walk its operands. return false; } // Keep looking. return true; } bool isDone() const { return false; } }; } /// Find parametric terms in this SCEVAddRecExpr. We first for parameters in /// two places: /// 1) The strides of AddRec expressions. /// 2) Unknowns that are multiplied with AddRec expressions. void ScalarEvolution::collectParametricTerms(const SCEV *Expr, SmallVectorImpl &Terms) { SmallVector Strides; SCEVCollectStrides StrideCollector(*this, Strides); visitAll(Expr, StrideCollector); DEBUG({ dbgs() << "Strides:\n"; for (const SCEV *S : Strides) dbgs() << *S << "\n"; }); for (const SCEV *S : Strides) { SCEVCollectTerms TermCollector(Terms); visitAll(S, TermCollector); } DEBUG({ dbgs() << "Terms:\n"; for (const SCEV *T : Terms) dbgs() << *T << "\n"; }); SCEVCollectAddRecMultiplies MulCollector(Terms, *this); visitAll(Expr, MulCollector); } static bool findArrayDimensionsRec(ScalarEvolution &SE, SmallVectorImpl &Terms, SmallVectorImpl &Sizes) { int Last = Terms.size() - 1; const SCEV *Step = Terms[Last]; // End of recursion. if (Last == 0) { if (const SCEVMulExpr *M = dyn_cast(Step)) { SmallVector Qs; for (const SCEV *Op : M->operands()) if (!isa(Op)) Qs.push_back(Op); Step = SE.getMulExpr(Qs); } Sizes.push_back(Step); return true; } for (const SCEV *&Term : Terms) { // Normalize the terms before the next call to findArrayDimensionsRec. const SCEV *Q, *R; SCEVDivision::divide(SE, Term, Step, &Q, &R); // Bail out when GCD does not evenly divide one of the terms. if (!R->isZero()) return false; Term = Q; } // Remove all SCEVConstants. Terms.erase(std::remove_if(Terms.begin(), Terms.end(), [](const SCEV *E) { return isa(E); }), Terms.end()); if (Terms.size() > 0) if (!findArrayDimensionsRec(SE, Terms, Sizes)) return false; Sizes.push_back(Step); return true; } // Returns true when S contains at least a SCEVUnknown parameter. static inline bool containsParameters(const SCEV *S) { struct FindParameter { bool FoundParameter; FindParameter() : FoundParameter(false) {} bool follow(const SCEV *S) { if (isa(S)) { FoundParameter = true; // Stop recursion: we found a parameter. return false; } // Keep looking. return true; } bool isDone() const { // Stop recursion if we have found a parameter. return FoundParameter; } }; FindParameter F; SCEVTraversal ST(F); ST.visitAll(S); return F.FoundParameter; } // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter. static inline bool containsParameters(SmallVectorImpl &Terms) { for (const SCEV *T : Terms) if (containsParameters(T)) return true; return false; } // Return the number of product terms in S. static inline int numberOfTerms(const SCEV *S) { if (const SCEVMulExpr *Expr = dyn_cast(S)) return Expr->getNumOperands(); return 1; } static const SCEV *removeConstantFactors(ScalarEvolution &SE, const SCEV *T) { if (isa(T)) return nullptr; if (isa(T)) return T; if (const SCEVMulExpr *M = dyn_cast(T)) { SmallVector Factors; for (const SCEV *Op : M->operands()) if (!isa(Op)) Factors.push_back(Op); return SE.getMulExpr(Factors); } return T; } /// Return the size of an element read or written by Inst. const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) { Type *Ty; if (StoreInst *Store = dyn_cast(Inst)) Ty = Store->getValueOperand()->getType(); else if (LoadInst *Load = dyn_cast(Inst)) Ty = Load->getType(); else return nullptr; Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty)); return getSizeOfExpr(ETy, Ty); } void ScalarEvolution::findArrayDimensions(SmallVectorImpl &Terms, SmallVectorImpl &Sizes, const SCEV *ElementSize) const { if (Terms.size() < 1 || !ElementSize) return; // Early return when Terms do not contain parameters: we do not delinearize // non parametric SCEVs. if (!containsParameters(Terms)) return; DEBUG({ dbgs() << "Terms:\n"; for (const SCEV *T : Terms) dbgs() << *T << "\n"; }); // Remove duplicates. std::sort(Terms.begin(), Terms.end()); Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end()); // Put larger terms first. std::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) { return numberOfTerms(LHS) > numberOfTerms(RHS); }); ScalarEvolution &SE = *const_cast(this); // Try to divide all terms by the element size. If term is not divisible by // element size, proceed with the original term. for (const SCEV *&Term : Terms) { const SCEV *Q, *R; SCEVDivision::divide(SE, Term, ElementSize, &Q, &R); if (!Q->isZero()) Term = Q; } SmallVector NewTerms; // Remove constant factors. for (const SCEV *T : Terms) if (const SCEV *NewT = removeConstantFactors(SE, T)) NewTerms.push_back(NewT); DEBUG({ dbgs() << "Terms after sorting:\n"; for (const SCEV *T : NewTerms) dbgs() << *T << "\n"; }); if (NewTerms.empty() || !findArrayDimensionsRec(SE, NewTerms, Sizes)) { Sizes.clear(); return; } // The last element to be pushed into Sizes is the size of an element. Sizes.push_back(ElementSize); DEBUG({ dbgs() << "Sizes:\n"; for (const SCEV *S : Sizes) dbgs() << *S << "\n"; }); } void ScalarEvolution::computeAccessFunctions( const SCEV *Expr, SmallVectorImpl &Subscripts, SmallVectorImpl &Sizes) { // Early exit in case this SCEV is not an affine multivariate function. if (Sizes.empty()) return; if (auto *AR = dyn_cast(Expr)) if (!AR->isAffine()) return; const SCEV *Res = Expr; int Last = Sizes.size() - 1; for (int i = Last; i >= 0; i--) { const SCEV *Q, *R; SCEVDivision::divide(*this, Res, Sizes[i], &Q, &R); DEBUG({ dbgs() << "Res: " << *Res << "\n"; dbgs() << "Sizes[i]: " << *Sizes[i] << "\n"; dbgs() << "Res divided by Sizes[i]:\n"; dbgs() << "Quotient: " << *Q << "\n"; dbgs() << "Remainder: " << *R << "\n"; }); Res = Q; // Do not record the last subscript corresponding to the size of elements in // the array. if (i == Last) { // Bail out if the remainder is too complex. if (isa(R)) { Subscripts.clear(); Sizes.clear(); return; } continue; } // Record the access function for the current subscript. Subscripts.push_back(R); } // Also push in last position the remainder of the last division: it will be // the access function of the innermost dimension. Subscripts.push_back(Res); std::reverse(Subscripts.begin(), Subscripts.end()); DEBUG({ dbgs() << "Subscripts:\n"; for (const SCEV *S : Subscripts) dbgs() << *S << "\n"; }); } /// Splits the SCEV into two vectors of SCEVs representing the subscripts and /// sizes of an array access. Returns the remainder of the delinearization that /// is the offset start of the array. The SCEV->delinearize algorithm computes /// the multiples of SCEV coefficients: that is a pattern matching of sub /// expressions in the stride and base of a SCEV corresponding to the /// computation of a GCD (greatest common divisor) of base and stride. When /// SCEV->delinearize fails, it returns the SCEV unchanged. /// /// For example: when analyzing the memory access A[i][j][k] in this loop nest /// /// void foo(long n, long m, long o, double A[n][m][o]) { /// /// for (long i = 0; i < n; i++) /// for (long j = 0; j < m; j++) /// for (long k = 0; k < o; k++) /// A[i][j][k] = 1.0; /// } /// /// the delinearization input is the following AddRec SCEV: /// /// AddRec: {{{%A,+,(8 * %m * %o)}<%for.i>,+,(8 * %o)}<%for.j>,+,8}<%for.k> /// /// From this SCEV, we are able to say that the base offset of the access is %A /// because it appears as an offset that does not divide any of the strides in /// the loops: /// /// CHECK: Base offset: %A /// /// and then SCEV->delinearize determines the size of some of the dimensions of /// the array as these are the multiples by which the strides are happening: /// /// CHECK: ArrayDecl[UnknownSize][%m][%o] with elements of sizeof(double) bytes. /// /// Note that the outermost dimension remains of UnknownSize because there are /// no strides that would help identifying the size of the last dimension: when /// the array has been statically allocated, one could compute the size of that /// dimension by dividing the overall size of the array by the size of the known /// dimensions: %m * %o * 8. /// /// Finally delinearize provides the access functions for the array reference /// that does correspond to A[i][j][k] of the above C testcase: /// /// CHECK: ArrayRef[{0,+,1}<%for.i>][{0,+,1}<%for.j>][{0,+,1}<%for.k>] /// /// The testcases are checking the output of a function pass: /// DelinearizationPass that walks through all loads and stores of a function /// asking for the SCEV of the memory access with respect to all enclosing /// loops, calling SCEV->delinearize on that and printing the results. void ScalarEvolution::delinearize(const SCEV *Expr, SmallVectorImpl &Subscripts, SmallVectorImpl &Sizes, const SCEV *ElementSize) { // First step: collect parametric terms. SmallVector Terms; collectParametricTerms(Expr, Terms); if (Terms.empty()) return; // Second step: find subscript sizes. findArrayDimensions(Terms, Sizes, ElementSize); if (Sizes.empty()) return; // Third step: compute the access functions for each subscript. computeAccessFunctions(Expr, Subscripts, Sizes); if (Subscripts.empty()) return; DEBUG({ dbgs() << "succeeded to delinearize " << *Expr << "\n"; dbgs() << "ArrayDecl[UnknownSize]"; for (const SCEV *S : Sizes) dbgs() << "[" << *S << "]"; dbgs() << "\nArrayRef"; for (const SCEV *S : Subscripts) dbgs() << "[" << *S << "]"; dbgs() << "\n"; }); } //===----------------------------------------------------------------------===// // SCEVCallbackVH Class Implementation //===----------------------------------------------------------------------===// void ScalarEvolution::SCEVCallbackVH::deleted() { assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); if (PHINode *PN = dyn_cast(getValPtr())) SE->ConstantEvolutionLoopExitValue.erase(PN); SE->eraseValueFromMap(getValPtr()); // this now dangles! } void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) { assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); // Forget all the expressions associated with users of the old value, // so that future queries will recompute the expressions using the new // value. Value *Old = getValPtr(); SmallVector Worklist(Old->user_begin(), Old->user_end()); SmallPtrSet Visited; while (!Worklist.empty()) { User *U = Worklist.pop_back_val(); // Deleting the Old value will cause this to dangle. Postpone // that until everything else is done. if (U == Old) continue; if (!Visited.insert(U).second) continue; if (PHINode *PN = dyn_cast(U)) SE->ConstantEvolutionLoopExitValue.erase(PN); SE->eraseValueFromMap(U); Worklist.insert(Worklist.end(), U->user_begin(), U->user_end()); } // Delete the Old value. if (PHINode *PN = dyn_cast(Old)) SE->ConstantEvolutionLoopExitValue.erase(PN); SE->eraseValueFromMap(Old); // this now dangles! } ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) : CallbackVH(V), SE(se) {} //===----------------------------------------------------------------------===// // ScalarEvolution Class Implementation //===----------------------------------------------------------------------===// ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI) : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI), CouldNotCompute(new SCEVCouldNotCompute()), WalkingBEDominatingConds(false), ProvingSplitPredicate(false), ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64), FirstUnknown(nullptr) { // To use guards for proving predicates, we need to scan every instruction in // relevant basic blocks, and not just terminators. Doing this is a waste of // time if the IR does not actually contain any calls to // @llvm.experimental.guard, so do a quick check and remember this beforehand. // // This pessimizes the case where a pass that preserves ScalarEvolution wants // to _add_ guards to the module when there weren't any before, and wants // ScalarEvolution to optimize based on those guards. For now we prefer to be // efficient in lieu of being smart in that rather obscure case. auto *GuardDecl = F.getParent()->getFunction( Intrinsic::getName(Intrinsic::experimental_guard)); HasGuards = GuardDecl && !GuardDecl->use_empty(); } ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)), ValueExprMap(std::move(Arg.ValueExprMap)), WalkingBEDominatingConds(false), ProvingSplitPredicate(false), BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), PredicatedBackedgeTakenCounts( std::move(Arg.PredicatedBackedgeTakenCounts)), ConstantEvolutionLoopExitValue( std::move(Arg.ConstantEvolutionLoopExitValue)), ValuesAtScopes(std::move(Arg.ValuesAtScopes)), LoopDispositions(std::move(Arg.LoopDispositions)), BlockDispositions(std::move(Arg.BlockDispositions)), UnsignedRanges(std::move(Arg.UnsignedRanges)), SignedRanges(std::move(Arg.SignedRanges)), UniqueSCEVs(std::move(Arg.UniqueSCEVs)), UniquePreds(std::move(Arg.UniquePreds)), SCEVAllocator(std::move(Arg.SCEVAllocator)), FirstUnknown(Arg.FirstUnknown) { Arg.FirstUnknown = nullptr; } ScalarEvolution::~ScalarEvolution() { // Iterate through all the SCEVUnknown instances and call their // destructors, so that they release their references to their values. for (SCEVUnknown *U = FirstUnknown; U;) { SCEVUnknown *Tmp = U; U = U->Next; Tmp->~SCEVUnknown(); } FirstUnknown = nullptr; ExprValueMap.clear(); ValueExprMap.clear(); HasRecMap.clear(); // Free any extra memory created for ExitNotTakenInfo in the unlikely event // that a loop had multiple computable exits. for (auto &BTCI : BackedgeTakenCounts) BTCI.second.clear(); for (auto &BTCI : PredicatedBackedgeTakenCounts) BTCI.second.clear(); assert(PendingLoopPredicates.empty() && "isImpliedCond garbage"); assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!"); assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!"); } bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) { return !isa(getBackedgeTakenCount(L)); } static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L) { // Print all inner loops first for (Loop *I : *L) PrintLoopInfo(OS, SE, I); OS << "Loop "; L->getHeader()->printAsOperand(OS, /*PrintType=*/false); OS << ": "; SmallVector ExitBlocks; L->getExitBlocks(ExitBlocks); if (ExitBlocks.size() != 1) OS << " "; if (SE->hasLoopInvariantBackedgeTakenCount(L)) { OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L); } else { OS << "Unpredictable backedge-taken count. "; } OS << "\n" "Loop "; L->getHeader()->printAsOperand(OS, /*PrintType=*/false); OS << ": "; if (!isa(SE->getMaxBackedgeTakenCount(L))) { OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L); } else { OS << "Unpredictable max backedge-taken count. "; } OS << "\n" "Loop "; L->getHeader()->printAsOperand(OS, /*PrintType=*/false); OS << ": "; SCEVUnionPredicate Pred; auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred); if (!isa(PBT)) { OS << "Predicated backedge-taken count is " << *PBT << "\n"; OS << " Predicates:\n"; Pred.print(OS, 4); } else { OS << "Unpredictable predicated backedge-taken count. "; } OS << "\n"; } static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) { switch (LD) { case ScalarEvolution::LoopVariant: return "Variant"; case ScalarEvolution::LoopInvariant: return "Invariant"; case ScalarEvolution::LoopComputable: return "Computable"; } llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!"); } void ScalarEvolution::print(raw_ostream &OS) const { // ScalarEvolution's implementation of the print method is to print // out SCEV values of all instructions that are interesting. Doing // this potentially causes it to create new SCEV objects though, // which technically conflicts with the const qualifier. This isn't // observable from outside the class though, so casting away the // const isn't dangerous. ScalarEvolution &SE = *const_cast(this); OS << "Classifying expressions for: "; F.printAsOperand(OS, /*PrintType=*/false); OS << "\n"; for (Instruction &I : instructions(F)) if (isSCEVable(I.getType()) && !isa(I)) { OS << I << '\n'; OS << " --> "; const SCEV *SV = SE.getSCEV(&I); SV->print(OS); if (!isa(SV)) { OS << " U: "; SE.getUnsignedRange(SV).print(OS); OS << " S: "; SE.getSignedRange(SV).print(OS); } const Loop *L = LI.getLoopFor(I.getParent()); const SCEV *AtUse = SE.getSCEVAtScope(SV, L); if (AtUse != SV) { OS << " --> "; AtUse->print(OS); if (!isa(AtUse)) { OS << " U: "; SE.getUnsignedRange(AtUse).print(OS); OS << " S: "; SE.getSignedRange(AtUse).print(OS); } } if (L) { OS << "\t\t" "Exits: "; const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop()); if (!SE.isLoopInvariant(ExitValue, L)) { OS << "<>"; } else { OS << *ExitValue; } bool First = true; for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) { if (First) { OS << "\t\t" "LoopDispositions: { "; First = false; } else { OS << ", "; } Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false); OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter)); } for (auto *InnerL : depth_first(L)) { if (InnerL == L) continue; if (First) { OS << "\t\t" "LoopDispositions: { "; First = false; } else { OS << ", "; } InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false); OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL)); } OS << " }"; } OS << "\n"; } OS << "Determining loop execution counts for: "; F.printAsOperand(OS, /*PrintType=*/false); OS << "\n"; for (Loop *I : LI) PrintLoopInfo(OS, &SE, I); } ScalarEvolution::LoopDisposition ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) { auto &Values = LoopDispositions[S]; for (auto &V : Values) { if (V.getPointer() == L) return V.getInt(); } Values.emplace_back(L, LoopVariant); LoopDisposition D = computeLoopDisposition(S, L); auto &Values2 = LoopDispositions[S]; for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { if (V.getPointer() == L) { V.setInt(D); break; } } return D; } ScalarEvolution::LoopDisposition ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { switch (static_cast(S->getSCEVType())) { case scConstant: return LoopInvariant; case scTruncate: case scZeroExtend: case scSignExtend: return getLoopDisposition(cast(S)->getOperand(), L); case scAddRecExpr: { const SCEVAddRecExpr *AR = cast(S); // If L is the addrec's loop, it's computable. if (AR->getLoop() == L) return LoopComputable; // Add recurrences are never invariant in the function-body (null loop). if (!L) return LoopVariant; // This recurrence is variant w.r.t. L if L contains AR's loop. if (L->contains(AR->getLoop())) return LoopVariant; // This recurrence is invariant w.r.t. L if AR's loop contains L. if (AR->getLoop()->contains(L)) return LoopInvariant; // This recurrence is variant w.r.t. L if any of its operands // are variant. for (auto *Op : AR->operands()) if (!isLoopInvariant(Op, L)) return LoopVariant; // Otherwise it's loop-invariant. return LoopInvariant; } case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr: { bool HasVarying = false; for (auto *Op : cast(S)->operands()) { LoopDisposition D = getLoopDisposition(Op, L); if (D == LoopVariant) return LoopVariant; if (D == LoopComputable) HasVarying = true; } return HasVarying ? LoopComputable : LoopInvariant; } case scUDivExpr: { const SCEVUDivExpr *UDiv = cast(S); LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L); if (LD == LoopVariant) return LoopVariant; LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L); if (RD == LoopVariant) return LoopVariant; return (LD == LoopInvariant && RD == LoopInvariant) ? LoopInvariant : LoopComputable; } case scUnknown: // All non-instruction values are loop invariant. All instructions are loop // invariant if they are not contained in the specified loop. // Instructions are never considered invariant in the function body // (null loop) because they are defined within the "loop". if (auto *I = dyn_cast(cast(S)->getValue())) return (L && !L->contains(I)) ? LoopInvariant : LoopVariant; return LoopInvariant; case scCouldNotCompute: llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); } llvm_unreachable("Unknown SCEV kind!"); } bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) { return getLoopDisposition(S, L) == LoopInvariant; } bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) { return getLoopDisposition(S, L) == LoopComputable; } ScalarEvolution::BlockDisposition ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) { auto &Values = BlockDispositions[S]; for (auto &V : Values) { if (V.getPointer() == BB) return V.getInt(); } Values.emplace_back(BB, DoesNotDominateBlock); BlockDisposition D = computeBlockDisposition(S, BB); auto &Values2 = BlockDispositions[S]; for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { if (V.getPointer() == BB) { V.setInt(D); break; } } return D; } ScalarEvolution::BlockDisposition ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { switch (static_cast(S->getSCEVType())) { case scConstant: return ProperlyDominatesBlock; case scTruncate: case scZeroExtend: case scSignExtend: return getBlockDisposition(cast(S)->getOperand(), BB); case scAddRecExpr: { // This uses a "dominates" query instead of "properly dominates" query // to test for proper dominance too, because the instruction which // produces the addrec's value is a PHI, and a PHI effectively properly // dominates its entire containing block. const SCEVAddRecExpr *AR = cast(S); if (!DT.dominates(AR->getLoop()->getHeader(), BB)) return DoesNotDominateBlock; } // FALL THROUGH into SCEVNAryExpr handling. case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr: { const SCEVNAryExpr *NAry = cast(S); bool Proper = true; for (const SCEV *NAryOp : NAry->operands()) { BlockDisposition D = getBlockDisposition(NAryOp, BB); if (D == DoesNotDominateBlock) return DoesNotDominateBlock; if (D == DominatesBlock) Proper = false; } return Proper ? ProperlyDominatesBlock : DominatesBlock; } case scUDivExpr: { const SCEVUDivExpr *UDiv = cast(S); const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS(); BlockDisposition LD = getBlockDisposition(LHS, BB); if (LD == DoesNotDominateBlock) return DoesNotDominateBlock; BlockDisposition RD = getBlockDisposition(RHS, BB); if (RD == DoesNotDominateBlock) return DoesNotDominateBlock; return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ? ProperlyDominatesBlock : DominatesBlock; } case scUnknown: if (Instruction *I = dyn_cast(cast(S)->getValue())) { if (I->getParent() == BB) return DominatesBlock; if (DT.properlyDominates(I->getParent(), BB)) return ProperlyDominatesBlock; return DoesNotDominateBlock; } return ProperlyDominatesBlock; case scCouldNotCompute: llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); } llvm_unreachable("Unknown SCEV kind!"); } bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) { return getBlockDisposition(S, BB) >= DominatesBlock; } bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) { return getBlockDisposition(S, BB) == ProperlyDominatesBlock; } bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const { // Search for a SCEV expression node within an expression tree. // Implements SCEVTraversal::Visitor. struct SCEVSearch { const SCEV *Node; bool IsFound; SCEVSearch(const SCEV *N): Node(N), IsFound(false) {} bool follow(const SCEV *S) { IsFound |= (S == Node); return !IsFound; } bool isDone() const { return IsFound; } }; SCEVSearch Search(Op); visitAll(S, Search); return Search.IsFound; } void ScalarEvolution::forgetMemoizedResults(const SCEV *S) { ValuesAtScopes.erase(S); LoopDispositions.erase(S); BlockDispositions.erase(S); UnsignedRanges.erase(S); SignedRanges.erase(S); ExprValueMap.erase(S); HasRecMap.erase(S); auto RemoveSCEVFromBackedgeMap = [S, this](DenseMap &Map) { for (auto I = Map.begin(), E = Map.end(); I != E;) { BackedgeTakenInfo &BEInfo = I->second; if (BEInfo.hasOperand(S, this)) { BEInfo.clear(); Map.erase(I++); } else ++I; } }; RemoveSCEVFromBackedgeMap(BackedgeTakenCounts); RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts); } typedef DenseMap VerifyMap; /// replaceSubString - Replaces all occurrences of From in Str with To. static void replaceSubString(std::string &Str, StringRef From, StringRef To) { size_t Pos = 0; while ((Pos = Str.find(From, Pos)) != std::string::npos) { Str.replace(Pos, From.size(), To.data(), To.size()); Pos += To.size(); } } /// getLoopBackedgeTakenCounts - Helper method for verifyAnalysis. static void getLoopBackedgeTakenCounts(Loop *L, VerifyMap &Map, ScalarEvolution &SE) { std::string &S = Map[L]; if (S.empty()) { raw_string_ostream OS(S); SE.getBackedgeTakenCount(L)->print(OS); // false and 0 are semantically equivalent. This can happen in dead loops. replaceSubString(OS.str(), "false", "0"); // Remove wrap flags, their use in SCEV is highly fragile. // FIXME: Remove this when SCEV gets smarter about them. replaceSubString(OS.str(), "", ""); replaceSubString(OS.str(), "", ""); replaceSubString(OS.str(), "", ""); } for (auto *R : reverse(*L)) getLoopBackedgeTakenCounts(R, Map, SE); // recurse. } void ScalarEvolution::verify() const { ScalarEvolution &SE = *const_cast(this); // Gather stringified backedge taken counts for all loops using SCEV's caches. // FIXME: It would be much better to store actual values instead of strings, // but SCEV pointers will change if we drop the caches. VerifyMap BackedgeDumpsOld, BackedgeDumpsNew; for (LoopInfo::reverse_iterator I = LI.rbegin(), E = LI.rend(); I != E; ++I) getLoopBackedgeTakenCounts(*I, BackedgeDumpsOld, SE); // Gather stringified backedge taken counts for all loops using a fresh // ScalarEvolution object. ScalarEvolution SE2(F, TLI, AC, DT, LI); for (LoopInfo::reverse_iterator I = LI.rbegin(), E = LI.rend(); I != E; ++I) getLoopBackedgeTakenCounts(*I, BackedgeDumpsNew, SE2); // Now compare whether they're the same with and without caches. This allows // verifying that no pass changed the cache. assert(BackedgeDumpsOld.size() == BackedgeDumpsNew.size() && "New loops suddenly appeared!"); for (VerifyMap::iterator OldI = BackedgeDumpsOld.begin(), OldE = BackedgeDumpsOld.end(), NewI = BackedgeDumpsNew.begin(); OldI != OldE; ++OldI, ++NewI) { assert(OldI->first == NewI->first && "Loop order changed!"); // Compare the stringified SCEVs. We don't care if undef backedgetaken count // changes. // FIXME: We currently ignore SCEV changes from/to CouldNotCompute. This // means that a pass is buggy or SCEV has to learn a new pattern but is // usually not harmful. if (OldI->second != NewI->second && OldI->second.find("undef") == std::string::npos && NewI->second.find("undef") == std::string::npos && OldI->second != "***COULDNOTCOMPUTE***" && NewI->second != "***COULDNOTCOMPUTE***") { dbgs() << "SCEVValidator: SCEV for loop '" << OldI->first->getHeader()->getName() << "' changed from '" << OldI->second << "' to '" << NewI->second << "'!\n"; std::abort(); } } // TODO: Verify more things. } char ScalarEvolutionAnalysis::PassID; ScalarEvolution ScalarEvolutionAnalysis::run(Function &F, AnalysisManager &AM) { return ScalarEvolution(F, AM.getResult(F), AM.getResult(F), AM.getResult(F), AM.getResult(F)); } PreservedAnalyses ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager &AM) { AM.getResult(F).print(OS); return PreservedAnalyses::all(); } INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution", "Scalar Evolution Analysis", false, true) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution", "Scalar Evolution Analysis", false, true) char ScalarEvolutionWrapperPass::ID = 0; ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) { initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry()); } bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) { SE.reset(new ScalarEvolution( F, getAnalysis().getTLI(), getAnalysis().getAssumptionCache(F), getAnalysis().getDomTree(), getAnalysis().getLoopInfo())); return false; } void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); } void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const { SE->print(OS); } void ScalarEvolutionWrapperPass::verifyAnalysis() const { if (!VerifySCEV) return; SE->verify(); } void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequiredTransitive(); AU.addRequiredTransitive(); AU.addRequiredTransitive(); AU.addRequiredTransitive(); } const SCEVPredicate * ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS, const SCEVConstant *RHS) { FoldingSetNodeID ID; // Unique this node based on the arguments ID.AddInteger(SCEVPredicate::P_Equal); ID.AddPointer(LHS); ID.AddPointer(RHS); void *IP = nullptr; if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) return S; SCEVEqualPredicate *Eq = new (SCEVAllocator) SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS); UniquePreds.InsertNode(Eq, IP); return Eq; } const SCEVPredicate *ScalarEvolution::getWrapPredicate( const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { FoldingSetNodeID ID; // Unique this node based on the arguments ID.AddInteger(SCEVPredicate::P_Wrap); ID.AddPointer(AR); ID.AddInteger(AddedFlags); void *IP = nullptr; if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) return S; auto *OF = new (SCEVAllocator) SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags); UniquePreds.InsertNode(OF, IP); return OF; } namespace { class SCEVPredicateRewriter : public SCEVRewriteVisitor { public: // Rewrites \p S in the context of a loop L and the predicate A. // If Assume is true, rewrite is free to add further predicates to A // such that the result will be an AddRecExpr. static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, SCEVUnionPredicate &A, bool Assume) { SCEVPredicateRewriter Rewriter(L, SE, A, Assume); return Rewriter.visit(S); } SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, SCEVUnionPredicate &P, bool Assume) : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { auto ExprPreds = P.getPredicatesForExpr(Expr); for (auto *Pred : ExprPreds) if (const auto *IPred = dyn_cast(Pred)) if (IPred->getLHS() == Expr) return IPred->getRHS(); return Expr; } const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { const SCEV *Operand = visit(Expr->getOperand()); const SCEVAddRecExpr *AR = dyn_cast(Operand); if (AR && AR->getLoop() == L && AR->isAffine()) { // This couldn't be folded because the operand didn't have the nuw // flag. Add the nusw flag as an assumption that we could make. const SCEV *Step = AR->getStepRecurrence(SE); Type *Ty = Expr->getType(); if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW)) return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty), SE.getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } return SE.getZeroExtendExpr(Operand, Expr->getType()); } const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { const SCEV *Operand = visit(Expr->getOperand()); const SCEVAddRecExpr *AR = dyn_cast(Operand); if (AR && AR->getLoop() == L && AR->isAffine()) { // This couldn't be folded because the operand didn't have the nsw // flag. Add the nssw flag as an assumption that we could make. const SCEV *Step = AR->getStepRecurrence(SE); Type *Ty = Expr->getType(); if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW)) return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty), SE.getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } return SE.getSignExtendExpr(Operand, Expr->getType()); } private: bool addOverflowAssumption(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { auto *A = SE.getWrapPredicate(AR, AddedFlags); if (!Assume) { // Check if we've already made this assumption. if (P.implies(A)) return true; return false; } P.add(A); return true; } SCEVUnionPredicate &P; const Loop *L; bool Assume; }; } // end anonymous namespace const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L, SCEVUnionPredicate &Preds) { return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false); } const SCEVAddRecExpr * ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SCEVUnionPredicate &Preds) { SCEVUnionPredicate TransformPreds; S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true); auto *AddRec = dyn_cast(S); if (!AddRec) return nullptr; // Since the transformation was successful, we can now transfer the SCEV // predicates. Preds.add(&TransformPreds); return AddRec; } /// SCEV predicates SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind) : FastID(ID), Kind(Kind) {} SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS, const SCEVConstant *RHS) : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {} bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const { const auto *Op = dyn_cast(N); if (!Op) return false; return Op->LHS == LHS && Op->RHS == RHS; } bool SCEVEqualPredicate::isAlwaysTrue() const { return false; } const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; } void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; } SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags) : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {} const SCEV *SCEVWrapPredicate::getExpr() const { return AR; } bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { const auto *Op = dyn_cast(N); return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags; } bool SCEVWrapPredicate::isAlwaysTrue() const { SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags(); IncrementWrapFlags IFlags = Flags; if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags) IFlags = clearFlags(IFlags, IncrementNSSW); return IFlags == IncrementAnyWrap; } void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const { OS.indent(Depth) << *getExpr() << " Added Flags: "; if (SCEVWrapPredicate::IncrementNUSW & getFlags()) OS << ""; if (SCEVWrapPredicate::IncrementNSSW & getFlags()) OS << ""; OS << "\n"; } SCEVWrapPredicate::IncrementWrapFlags SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE) { IncrementWrapFlags ImpliedFlags = IncrementAnyWrap; SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags(); // We can safely transfer the NSW flag as NSSW. if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags) ImpliedFlags = IncrementNSSW; if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) { // If the increment is positive, the SCEV NUW flag will also imply the // WrapPredicate NUSW flag. if (const auto *Step = dyn_cast(AR->getStepRecurrence(SE))) if (Step->getValue()->getValue().isNonNegative()) ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW); } return ImpliedFlags; } /// Union predicates don't get cached so create a dummy set ID for it. SCEVUnionPredicate::SCEVUnionPredicate() : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {} bool SCEVUnionPredicate::isAlwaysTrue() const { return all_of(Preds, [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); } ArrayRef SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) { auto I = SCEVToPreds.find(Expr); if (I == SCEVToPreds.end()) return ArrayRef(); return I->second; } bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { if (const auto *Set = dyn_cast(N)) return all_of(Set->Preds, [this](const SCEVPredicate *I) { return this->implies(I); }); auto ScevPredsIt = SCEVToPreds.find(N->getExpr()); if (ScevPredsIt == SCEVToPreds.end()) return false; auto &SCEVPreds = ScevPredsIt->second; return any_of(SCEVPreds, [N](const SCEVPredicate *I) { return I->implies(N); }); } const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; } void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { for (auto Pred : Preds) Pred->print(OS, Depth); } void SCEVUnionPredicate::add(const SCEVPredicate *N) { if (const auto *Set = dyn_cast(N)) { for (auto Pred : Set->Preds) add(Pred); return; } if (implies(N)) return; const SCEV *Key = N->getExpr(); assert(Key && "Only SCEVUnionPredicate doesn't have an " " associated expression!"); SCEVToPreds[Key].push_back(N); Preds.push_back(N); } PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L) : SE(SE), L(L), Generation(0), BackedgeCount(nullptr) {} const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { const SCEV *Expr = SE.getSCEV(V); RewriteEntry &Entry = RewriteMap[Expr]; // If we already have an entry and the version matches, return it. if (Entry.second && Generation == Entry.first) return Entry.second; // We found an entry but it's stale. Rewrite the stale entry // acording to the current predicate. if (Entry.second) Expr = Entry.second; const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds); Entry = {Generation, NewSCEV}; return NewSCEV; } const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() { if (!BackedgeCount) { SCEVUnionPredicate BackedgePred; BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred); addPredicate(BackedgePred); } return BackedgeCount; } void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) { if (Preds.implies(&Pred)) return; Preds.add(&Pred); updateGeneration(); } const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const { return Preds; } void PredicatedScalarEvolution::updateGeneration() { // If the generation number wrapped recompute everything. if (++Generation == 0) { for (auto &II : RewriteMap) { const SCEV *Rewritten = II.second.second; II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)}; } } } void PredicatedScalarEvolution::setNoOverflow( Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { const SCEV *Expr = getSCEV(V); const auto *AR = cast(Expr); auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE); // Clear the statically implied flags. Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags); addPredicate(*SE.getWrapPredicate(AR, Flags)); auto II = FlagsMap.insert({V, Flags}); if (!II.second) II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second); } bool PredicatedScalarEvolution::hasNoOverflow( Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { const SCEV *Expr = getSCEV(V); const auto *AR = cast(Expr); Flags = SCEVWrapPredicate::clearFlags( Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE)); auto II = FlagsMap.find(V); if (II != FlagsMap.end()) Flags = SCEVWrapPredicate::clearFlags(Flags, II->second); return Flags == SCEVWrapPredicate::IncrementAnyWrap; } const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { const SCEV *Expr = this->getSCEV(V); auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds); if (!New) return nullptr; updateGeneration(); RewriteMap[SE.getSCEV(V)] = {Generation, New}; return New; } PredicatedScalarEvolution::PredicatedScalarEvolution( const PredicatedScalarEvolution &Init) : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds), Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { for (const auto &I : Init.FlagsMap) FlagsMap.insert(I); } void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const { // For each block. for (auto *BB : L.getBlocks()) for (auto &I : *BB) { if (!SE.isSCEVable(I.getType())) continue; auto *Expr = SE.getSCEV(&I); auto II = RewriteMap.find(Expr); if (II == RewriteMap.end()) continue; // Don't print things that are not interesting. if (II->second.second == Expr) continue; OS.indent(Depth) << "[PSE]" << I << ":\n"; OS.indent(Depth + 2) << *Expr << "\n"; OS.indent(Depth + 2) << "--> " << *II->second.second << "\n"; } } Index: head/contrib/llvm/lib/Analysis/ScalarEvolutionExpander.cpp =================================================================== --- head/contrib/llvm/lib/Analysis/ScalarEvolutionExpander.cpp (revision 312831) +++ head/contrib/llvm/lib/Analysis/ScalarEvolutionExpander.cpp (revision 312832) @@ -1,2210 +1,2244 @@ //===- ScalarEvolutionExpander.cpp - Scalar Evolution Analysis ------------===// // // The LLVM Compiler Infrastructure // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// // // This file contains the implementation of the scalar evolution expander, // which is used to generate the code corresponding to a given scalar evolution // expression. // //===----------------------------------------------------------------------===// #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; using namespace PatternMatch; /// ReuseOrCreateCast - Arrange for there to be a cast of V to Ty at IP, /// reusing an existing cast if a suitable one exists, moving an existing /// cast if a suitable one exists but isn't in the right place, or /// creating a new one. Value *SCEVExpander::ReuseOrCreateCast(Value *V, Type *Ty, Instruction::CastOps Op, BasicBlock::iterator IP) { // This function must be called with the builder having a valid insertion // point. It doesn't need to be the actual IP where the uses of the returned // cast will be added, but it must dominate such IP. // We use this precondition to produce a cast that will dominate all its // uses. In particular, this is crucial for the case where the builder's // insertion point *is* the point where we were asked to put the cast. // Since we don't know the builder's insertion point is actually // where the uses will be added (only that it dominates it), we are // not allowed to move it. BasicBlock::iterator BIP = Builder.GetInsertPoint(); Instruction *Ret = nullptr; // Check to see if there is already a cast! for (User *U : V->users()) if (U->getType() == Ty) if (CastInst *CI = dyn_cast(U)) if (CI->getOpcode() == Op) { // If the cast isn't where we want it, create a new cast at IP. // Likewise, do not reuse a cast at BIP because it must dominate // instructions that might be inserted before BIP. if (BasicBlock::iterator(CI) != IP || BIP == IP) { // Create a new cast, and leave the old cast in place in case // it is being used as an insert point. Clear its operand // so that it doesn't hold anything live. Ret = CastInst::Create(Op, V, Ty, "", &*IP); Ret->takeName(CI); CI->replaceAllUsesWith(Ret); CI->setOperand(0, UndefValue::get(V->getType())); break; } Ret = CI; break; } // Create a new cast. if (!Ret) Ret = CastInst::Create(Op, V, Ty, V->getName(), &*IP); // We assert at the end of the function since IP might point to an // instruction with different dominance properties than a cast // (an invoke for example) and not dominate BIP (but the cast does). assert(SE.DT.dominates(Ret, &*BIP)); rememberInstruction(Ret); return Ret; } static BasicBlock::iterator findInsertPointAfter(Instruction *I, BasicBlock *MustDominate) { BasicBlock::iterator IP = ++I->getIterator(); if (auto *II = dyn_cast(I)) IP = II->getNormalDest()->begin(); while (isa(IP)) ++IP; if (isa(IP) || isa(IP)) { ++IP; } else if (isa(IP)) { IP = MustDominate->getFirstInsertionPt(); } else { assert(!IP->isEHPad() && "unexpected eh pad!"); } return IP; } /// InsertNoopCastOfTo - Insert a cast of V to the specified type, /// which must be possible with a noop cast, doing what we can to share /// the casts. Value *SCEVExpander::InsertNoopCastOfTo(Value *V, Type *Ty) { Instruction::CastOps Op = CastInst::getCastOpcode(V, false, Ty, false); assert((Op == Instruction::BitCast || Op == Instruction::PtrToInt || Op == Instruction::IntToPtr) && "InsertNoopCastOfTo cannot perform non-noop casts!"); assert(SE.getTypeSizeInBits(V->getType()) == SE.getTypeSizeInBits(Ty) && "InsertNoopCastOfTo cannot change sizes!"); // Short-circuit unnecessary bitcasts. if (Op == Instruction::BitCast) { if (V->getType() == Ty) return V; if (CastInst *CI = dyn_cast(V)) { if (CI->getOperand(0)->getType() == Ty) return CI->getOperand(0); } } // Short-circuit unnecessary inttoptr<->ptrtoint casts. if ((Op == Instruction::PtrToInt || Op == Instruction::IntToPtr) && SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(V->getType())) { if (CastInst *CI = dyn_cast(V)) if ((CI->getOpcode() == Instruction::PtrToInt || CI->getOpcode() == Instruction::IntToPtr) && SE.getTypeSizeInBits(CI->getType()) == SE.getTypeSizeInBits(CI->getOperand(0)->getType())) return CI->getOperand(0); if (ConstantExpr *CE = dyn_cast(V)) if ((CE->getOpcode() == Instruction::PtrToInt || CE->getOpcode() == Instruction::IntToPtr) && SE.getTypeSizeInBits(CE->getType()) == SE.getTypeSizeInBits(CE->getOperand(0)->getType())) return CE->getOperand(0); } // Fold a cast of a constant. if (Constant *C = dyn_cast(V)) return ConstantExpr::getCast(Op, C, Ty); // Cast the argument at the beginning of the entry block, after // any bitcasts of other arguments. if (Argument *A = dyn_cast(V)) { BasicBlock::iterator IP = A->getParent()->getEntryBlock().begin(); while ((isa(IP) && isa(cast(IP)->getOperand(0)) && cast(IP)->getOperand(0) != A) || isa(IP)) ++IP; return ReuseOrCreateCast(A, Ty, Op, IP); } // Cast the instruction immediately after the instruction. Instruction *I = cast(V); BasicBlock::iterator IP = findInsertPointAfter(I, Builder.GetInsertBlock()); return ReuseOrCreateCast(I, Ty, Op, IP); } /// InsertBinop - Insert the specified binary operator, doing a small amount /// of work to avoid inserting an obviously redundant operation. Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS) { // Fold a binop with constant operands. if (Constant *CLHS = dyn_cast(LHS)) if (Constant *CRHS = dyn_cast(RHS)) return ConstantExpr::get(Opcode, CLHS, CRHS); // Do a quick scan to see if we have this binop nearby. If so, reuse it. unsigned ScanLimit = 6; BasicBlock::iterator BlockBegin = Builder.GetInsertBlock()->begin(); // Scanning starts from the last instruction before the insertion point. BasicBlock::iterator IP = Builder.GetInsertPoint(); if (IP != BlockBegin) { --IP; for (; ScanLimit; --IP, --ScanLimit) { // Don't count dbg.value against the ScanLimit, to avoid perturbing the // generated code. if (isa(IP)) ScanLimit++; if (IP->getOpcode() == (unsigned)Opcode && IP->getOperand(0) == LHS && IP->getOperand(1) == RHS) return &*IP; if (IP == BlockBegin) break; } } // Save the original insertion point so we can restore it when we're done. DebugLoc Loc = Builder.GetInsertPoint()->getDebugLoc(); SCEVInsertPointGuard Guard(Builder, this); // Move the insertion point out of as many loops as we can. while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { if (!L->isLoopInvariant(LHS) || !L->isLoopInvariant(RHS)) break; BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) break; // Ok, move up a level. Builder.SetInsertPoint(Preheader->getTerminator()); } // If we haven't found this binop, insert it. Instruction *BO = cast(Builder.CreateBinOp(Opcode, LHS, RHS)); BO->setDebugLoc(Loc); rememberInstruction(BO); return BO; } /// FactorOutConstant - Test if S is divisible by Factor, using signed /// division. If so, update S with Factor divided out and return true. /// S need not be evenly divisible if a reasonable remainder can be /// computed. /// TODO: When ScalarEvolution gets a SCEVSDivExpr, this can be made /// unnecessary; in its place, just signed-divide Ops[i] by the scale and /// check to see if the divide was folded. static bool FactorOutConstant(const SCEV *&S, const SCEV *&Remainder, const SCEV *Factor, ScalarEvolution &SE, const DataLayout &DL) { // Everything is divisible by one. if (Factor->isOne()) return true; // x/x == 1. if (S == Factor) { S = SE.getConstant(S->getType(), 1); return true; } // For a Constant, check for a multiple of the given factor. if (const SCEVConstant *C = dyn_cast(S)) { // 0/x == 0. if (C->isZero()) return true; // Check for divisibility. if (const SCEVConstant *FC = dyn_cast(Factor)) { ConstantInt *CI = ConstantInt::get(SE.getContext(), C->getAPInt().sdiv(FC->getAPInt())); // If the quotient is zero and the remainder is non-zero, reject // the value at this scale. It will be considered for subsequent // smaller scales. if (!CI->isZero()) { const SCEV *Div = SE.getConstant(CI); S = Div; Remainder = SE.getAddExpr( Remainder, SE.getConstant(C->getAPInt().srem(FC->getAPInt()))); return true; } } } // In a Mul, check if there is a constant operand which is a multiple // of the given factor. if (const SCEVMulExpr *M = dyn_cast(S)) { // Size is known, check if there is a constant operand which is a multiple // of the given factor. If so, we can factor it. const SCEVConstant *FC = cast(Factor); if (const SCEVConstant *C = dyn_cast(M->getOperand(0))) if (!C->getAPInt().srem(FC->getAPInt())) { SmallVector NewMulOps(M->op_begin(), M->op_end()); NewMulOps[0] = SE.getConstant(C->getAPInt().sdiv(FC->getAPInt())); S = SE.getMulExpr(NewMulOps); return true; } } // In an AddRec, check if both start and step are divisible. if (const SCEVAddRecExpr *A = dyn_cast(S)) { const SCEV *Step = A->getStepRecurrence(SE); const SCEV *StepRem = SE.getConstant(Step->getType(), 0); if (!FactorOutConstant(Step, StepRem, Factor, SE, DL)) return false; if (!StepRem->isZero()) return false; const SCEV *Start = A->getStart(); if (!FactorOutConstant(Start, Remainder, Factor, SE, DL)) return false; S = SE.getAddRecExpr(Start, Step, A->getLoop(), A->getNoWrapFlags(SCEV::FlagNW)); return true; } return false; } /// SimplifyAddOperands - Sort and simplify a list of add operands. NumAddRecs /// is the number of SCEVAddRecExprs present, which are kept at the end of /// the list. /// static void SimplifyAddOperands(SmallVectorImpl &Ops, Type *Ty, ScalarEvolution &SE) { unsigned NumAddRecs = 0; for (unsigned i = Ops.size(); i > 0 && isa(Ops[i-1]); --i) ++NumAddRecs; // Group Ops into non-addrecs and addrecs. SmallVector NoAddRecs(Ops.begin(), Ops.end() - NumAddRecs); SmallVector AddRecs(Ops.end() - NumAddRecs, Ops.end()); // Let ScalarEvolution sort and simplify the non-addrecs list. const SCEV *Sum = NoAddRecs.empty() ? SE.getConstant(Ty, 0) : SE.getAddExpr(NoAddRecs); // If it returned an add, use the operands. Otherwise it simplified // the sum into a single value, so just use that. Ops.clear(); if (const SCEVAddExpr *Add = dyn_cast(Sum)) Ops.append(Add->op_begin(), Add->op_end()); else if (!Sum->isZero()) Ops.push_back(Sum); // Then append the addrecs. Ops.append(AddRecs.begin(), AddRecs.end()); } /// SplitAddRecs - Flatten a list of add operands, moving addrec start values /// out to the top level. For example, convert {a + b,+,c} to a, b, {0,+,d}. /// This helps expose more opportunities for folding parts of the expressions /// into GEP indices. /// static void SplitAddRecs(SmallVectorImpl &Ops, Type *Ty, ScalarEvolution &SE) { // Find the addrecs. SmallVector AddRecs; for (unsigned i = 0, e = Ops.size(); i != e; ++i) while (const SCEVAddRecExpr *A = dyn_cast(Ops[i])) { const SCEV *Start = A->getStart(); if (Start->isZero()) break; const SCEV *Zero = SE.getConstant(Ty, 0); AddRecs.push_back(SE.getAddRecExpr(Zero, A->getStepRecurrence(SE), A->getLoop(), A->getNoWrapFlags(SCEV::FlagNW))); if (const SCEVAddExpr *Add = dyn_cast(Start)) { Ops[i] = Zero; Ops.append(Add->op_begin(), Add->op_end()); e += Add->getNumOperands(); } else { Ops[i] = Start; } } if (!AddRecs.empty()) { // Add the addrecs onto the end of the list. Ops.append(AddRecs.begin(), AddRecs.end()); // Resort the operand list, moving any constants to the front. SimplifyAddOperands(Ops, Ty, SE); } } /// expandAddToGEP - Expand an addition expression with a pointer type into /// a GEP instead of using ptrtoint+arithmetic+inttoptr. This helps /// BasicAliasAnalysis and other passes analyze the result. See the rules /// for getelementptr vs. inttoptr in /// http://llvm.org/docs/LangRef.html#pointeraliasing /// for details. /// /// Design note: The correctness of using getelementptr here depends on /// ScalarEvolution not recognizing inttoptr and ptrtoint operators, as /// they may introduce pointer arithmetic which may not be safely converted /// into getelementptr. /// /// Design note: It might seem desirable for this function to be more /// loop-aware. If some of the indices are loop-invariant while others /// aren't, it might seem desirable to emit multiple GEPs, keeping the /// loop-invariant portions of the overall computation outside the loop. /// However, there are a few reasons this is not done here. Hoisting simple /// arithmetic is a low-level optimization that often isn't very /// important until late in the optimization process. In fact, passes /// like InstructionCombining will combine GEPs, even if it means /// pushing loop-invariant computation down into loops, so even if the /// GEPs were split here, the work would quickly be undone. The /// LoopStrengthReduction pass, which is usually run quite late (and /// after the last InstructionCombining pass), takes care of hoisting /// loop-invariant portions of expressions, after considering what /// can be folded using target addressing modes. /// Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin, const SCEV *const *op_end, PointerType *PTy, Type *Ty, Value *V) { Type *OriginalElTy = PTy->getElementType(); Type *ElTy = OriginalElTy; SmallVector GepIndices; SmallVector Ops(op_begin, op_end); bool AnyNonZeroIndices = false; // Split AddRecs up into parts as either of the parts may be usable // without the other. SplitAddRecs(Ops, Ty, SE); Type *IntPtrTy = DL.getIntPtrType(PTy); // Descend down the pointer's type and attempt to convert the other // operands into GEP indices, at each level. The first index in a GEP // indexes into the array implied by the pointer operand; the rest of // the indices index into the element or field type selected by the // preceding index. for (;;) { // If the scale size is not 0, attempt to factor out a scale for // array indexing. SmallVector ScaledOps; if (ElTy->isSized()) { const SCEV *ElSize = SE.getSizeOfExpr(IntPtrTy, ElTy); if (!ElSize->isZero()) { SmallVector NewOps; for (const SCEV *Op : Ops) { const SCEV *Remainder = SE.getConstant(Ty, 0); if (FactorOutConstant(Op, Remainder, ElSize, SE, DL)) { // Op now has ElSize factored out. ScaledOps.push_back(Op); if (!Remainder->isZero()) NewOps.push_back(Remainder); AnyNonZeroIndices = true; } else { // The operand was not divisible, so add it to the list of operands // we'll scan next iteration. NewOps.push_back(Op); } } // If we made any changes, update Ops. if (!ScaledOps.empty()) { Ops = NewOps; SimplifyAddOperands(Ops, Ty, SE); } } } // Record the scaled array index for this level of the type. If // we didn't find any operands that could be factored, tentatively // assume that element zero was selected (since the zero offset // would obviously be folded away). Value *Scaled = ScaledOps.empty() ? Constant::getNullValue(Ty) : expandCodeFor(SE.getAddExpr(ScaledOps), Ty); GepIndices.push_back(Scaled); // Collect struct field index operands. while (StructType *STy = dyn_cast(ElTy)) { bool FoundFieldNo = false; // An empty struct has no fields. if (STy->getNumElements() == 0) break; // Field offsets are known. See if a constant offset falls within any of // the struct fields. if (Ops.empty()) break; if (const SCEVConstant *C = dyn_cast(Ops[0])) if (SE.getTypeSizeInBits(C->getType()) <= 64) { const StructLayout &SL = *DL.getStructLayout(STy); uint64_t FullOffset = C->getValue()->getZExtValue(); if (FullOffset < SL.getSizeInBytes()) { unsigned ElIdx = SL.getElementContainingOffset(FullOffset); GepIndices.push_back( ConstantInt::get(Type::getInt32Ty(Ty->getContext()), ElIdx)); ElTy = STy->getTypeAtIndex(ElIdx); Ops[0] = SE.getConstant(Ty, FullOffset - SL.getElementOffset(ElIdx)); AnyNonZeroIndices = true; FoundFieldNo = true; } } // If no struct field offsets were found, tentatively assume that // field zero was selected (since the zero offset would obviously // be folded away). if (!FoundFieldNo) { ElTy = STy->getTypeAtIndex(0u); GepIndices.push_back( Constant::getNullValue(Type::getInt32Ty(Ty->getContext()))); } } if (ArrayType *ATy = dyn_cast(ElTy)) ElTy = ATy->getElementType(); else break; } // If none of the operands were convertible to proper GEP indices, cast // the base to i8* and do an ugly getelementptr with that. It's still // better than ptrtoint+arithmetic+inttoptr at least. if (!AnyNonZeroIndices) { // Cast the base to i8*. V = InsertNoopCastOfTo(V, Type::getInt8PtrTy(Ty->getContext(), PTy->getAddressSpace())); assert(!isa(V) || SE.DT.dominates(cast(V), &*Builder.GetInsertPoint())); // Expand the operands for a plain byte offset. Value *Idx = expandCodeFor(SE.getAddExpr(Ops), Ty); // Fold a GEP with constant operands. if (Constant *CLHS = dyn_cast(V)) if (Constant *CRHS = dyn_cast(Idx)) return ConstantExpr::getGetElementPtr(Type::getInt8Ty(Ty->getContext()), CLHS, CRHS); // Do a quick scan to see if we have this GEP nearby. If so, reuse it. unsigned ScanLimit = 6; BasicBlock::iterator BlockBegin = Builder.GetInsertBlock()->begin(); // Scanning starts from the last instruction before the insertion point. BasicBlock::iterator IP = Builder.GetInsertPoint(); if (IP != BlockBegin) { --IP; for (; ScanLimit; --IP, --ScanLimit) { // Don't count dbg.value against the ScanLimit, to avoid perturbing the // generated code. if (isa(IP)) ScanLimit++; if (IP->getOpcode() == Instruction::GetElementPtr && IP->getOperand(0) == V && IP->getOperand(1) == Idx) return &*IP; if (IP == BlockBegin) break; } } // Save the original insertion point so we can restore it when we're done. SCEVInsertPointGuard Guard(Builder, this); // Move the insertion point out of as many loops as we can. while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { if (!L->isLoopInvariant(V) || !L->isLoopInvariant(Idx)) break; BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) break; // Ok, move up a level. Builder.SetInsertPoint(Preheader->getTerminator()); } // Emit a GEP. Value *GEP = Builder.CreateGEP(Builder.getInt8Ty(), V, Idx, "uglygep"); rememberInstruction(GEP); return GEP; } { SCEVInsertPointGuard Guard(Builder, this); // Move the insertion point out of as many loops as we can. while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { if (!L->isLoopInvariant(V)) break; bool AnyIndexNotLoopInvariant = std::any_of(GepIndices.begin(), GepIndices.end(), [L](Value *Op) { return !L->isLoopInvariant(Op); }); if (AnyIndexNotLoopInvariant) break; BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) break; // Ok, move up a level. Builder.SetInsertPoint(Preheader->getTerminator()); } // Insert a pretty getelementptr. Note that this GEP is not marked inbounds, // because ScalarEvolution may have changed the address arithmetic to // compute a value which is beyond the end of the allocated object. Value *Casted = V; if (V->getType() != PTy) Casted = InsertNoopCastOfTo(Casted, PTy); Value *GEP = Builder.CreateGEP(OriginalElTy, Casted, GepIndices, "scevgep"); Ops.push_back(SE.getUnknown(GEP)); rememberInstruction(GEP); } return expand(SE.getAddExpr(Ops)); } /// PickMostRelevantLoop - Given two loops pick the one that's most relevant for /// SCEV expansion. If they are nested, this is the most nested. If they are /// neighboring, pick the later. static const Loop *PickMostRelevantLoop(const Loop *A, const Loop *B, DominatorTree &DT) { if (!A) return B; if (!B) return A; if (A->contains(B)) return B; if (B->contains(A)) return A; if (DT.dominates(A->getHeader(), B->getHeader())) return B; if (DT.dominates(B->getHeader(), A->getHeader())) return A; return A; // Arbitrarily break the tie. } /// getRelevantLoop - Get the most relevant loop associated with the given /// expression, according to PickMostRelevantLoop. const Loop *SCEVExpander::getRelevantLoop(const SCEV *S) { // Test whether we've already computed the most relevant loop for this SCEV. auto Pair = RelevantLoops.insert(std::make_pair(S, nullptr)); if (!Pair.second) return Pair.first->second; if (isa(S)) // A constant has no relevant loops. return nullptr; if (const SCEVUnknown *U = dyn_cast(S)) { if (const Instruction *I = dyn_cast(U->getValue())) return Pair.first->second = SE.LI.getLoopFor(I->getParent()); // A non-instruction has no relevant loops. return nullptr; } if (const SCEVNAryExpr *N = dyn_cast(S)) { const Loop *L = nullptr; if (const SCEVAddRecExpr *AR = dyn_cast(S)) L = AR->getLoop(); for (const SCEV *Op : N->operands()) L = PickMostRelevantLoop(L, getRelevantLoop(Op), SE.DT); return RelevantLoops[N] = L; } if (const SCEVCastExpr *C = dyn_cast(S)) { const Loop *Result = getRelevantLoop(C->getOperand()); return RelevantLoops[C] = Result; } if (const SCEVUDivExpr *D = dyn_cast(S)) { const Loop *Result = PickMostRelevantLoop( getRelevantLoop(D->getLHS()), getRelevantLoop(D->getRHS()), SE.DT); return RelevantLoops[D] = Result; } llvm_unreachable("Unexpected SCEV type!"); } namespace { /// LoopCompare - Compare loops by PickMostRelevantLoop. class LoopCompare { DominatorTree &DT; public: explicit LoopCompare(DominatorTree &dt) : DT(dt) {} bool operator()(std::pair LHS, std::pair RHS) const { // Keep pointer operands sorted at the end. if (LHS.second->getType()->isPointerTy() != RHS.second->getType()->isPointerTy()) return LHS.second->getType()->isPointerTy(); // Compare loops with PickMostRelevantLoop. if (LHS.first != RHS.first) return PickMostRelevantLoop(LHS.first, RHS.first, DT) != LHS.first; // If one operand is a non-constant negative and the other is not, // put the non-constant negative on the right so that a sub can // be used instead of a negate and add. if (LHS.second->isNonConstantNegative()) { if (!RHS.second->isNonConstantNegative()) return false; } else if (RHS.second->isNonConstantNegative()) return true; // Otherwise they are equivalent according to this comparison. return false; } }; } Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); // Collect all the add operands in a loop, along with their associated loops. // Iterate in reverse so that constants are emitted last, all else equal, and // so that pointer operands are inserted first, which the code below relies on // to form more involved GEPs. SmallVector, 8> OpsAndLoops; for (std::reverse_iterator I(S->op_end()), E(S->op_begin()); I != E; ++I) OpsAndLoops.push_back(std::make_pair(getRelevantLoop(*I), *I)); // Sort by loop. Use a stable sort so that constants follow non-constants and // pointer operands precede non-pointer operands. std::stable_sort(OpsAndLoops.begin(), OpsAndLoops.end(), LoopCompare(SE.DT)); // Emit instructions to add all the operands. Hoist as much as possible // out of loops, and form meaningful getelementptrs where possible. Value *Sum = nullptr; for (auto I = OpsAndLoops.begin(), E = OpsAndLoops.end(); I != E;) { const Loop *CurLoop = I->first; const SCEV *Op = I->second; if (!Sum) { // This is the first operand. Just expand it. Sum = expand(Op); ++I; } else if (PointerType *PTy = dyn_cast(Sum->getType())) { // The running sum expression is a pointer. Try to form a getelementptr // at this level with that as the base. SmallVector NewOps; for (; I != E && I->first == CurLoop; ++I) { // If the operand is SCEVUnknown and not instructions, peek through // it, to enable more of it to be folded into the GEP. const SCEV *X = I->second; if (const SCEVUnknown *U = dyn_cast(X)) if (!isa(U->getValue())) X = SE.getSCEV(U->getValue()); NewOps.push_back(X); } Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, Sum); } else if (PointerType *PTy = dyn_cast(Op->getType())) { // The running sum is an integer, and there's a pointer at this level. // Try to form a getelementptr. If the running sum is instructions, // use a SCEVUnknown to avoid re-analyzing them. SmallVector NewOps; NewOps.push_back(isa(Sum) ? SE.getUnknown(Sum) : SE.getSCEV(Sum)); for (++I; I != E && I->first == CurLoop; ++I) NewOps.push_back(I->second); Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, expand(Op)); } else if (Op->isNonConstantNegative()) { // Instead of doing a negate and add, just do a subtract. Value *W = expandCodeFor(SE.getNegativeSCEV(Op), Ty); Sum = InsertNoopCastOfTo(Sum, Ty); Sum = InsertBinop(Instruction::Sub, Sum, W); ++I; } else { // A simple add. Value *W = expandCodeFor(Op, Ty); Sum = InsertNoopCastOfTo(Sum, Ty); // Canonicalize a constant to the RHS. if (isa(Sum)) std::swap(Sum, W); Sum = InsertBinop(Instruction::Add, Sum, W); ++I; } } return Sum; } Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); // Collect all the mul operands in a loop, along with their associated loops. // Iterate in reverse so that constants are emitted last, all else equal. SmallVector, 8> OpsAndLoops; for (std::reverse_iterator I(S->op_end()), E(S->op_begin()); I != E; ++I) OpsAndLoops.push_back(std::make_pair(getRelevantLoop(*I), *I)); // Sort by loop. Use a stable sort so that constants follow non-constants. std::stable_sort(OpsAndLoops.begin(), OpsAndLoops.end(), LoopCompare(SE.DT)); // Emit instructions to mul all the operands. Hoist as much as possible // out of loops. Value *Prod = nullptr; for (const auto &I : OpsAndLoops) { const SCEV *Op = I.second; if (!Prod) { // This is the first operand. Just expand it. Prod = expand(Op); } else if (Op->isAllOnesValue()) { // Instead of doing a multiply by negative one, just do a negate. Prod = InsertNoopCastOfTo(Prod, Ty); Prod = InsertBinop(Instruction::Sub, Constant::getNullValue(Ty), Prod); } else { // A simple mul. Value *W = expandCodeFor(Op, Ty); Prod = InsertNoopCastOfTo(Prod, Ty); // Canonicalize a constant to the RHS. if (isa(Prod)) std::swap(Prod, W); const APInt *RHS; if (match(W, m_Power2(RHS))) { // Canonicalize Prod*(1<isVectorTy() && "vector types are not SCEVable"); Prod = InsertBinop(Instruction::Shl, Prod, ConstantInt::get(Ty, RHS->logBase2())); } else { Prod = InsertBinop(Instruction::Mul, Prod, W); } } } return Prod; } Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); Value *LHS = expandCodeFor(S->getLHS(), Ty); if (const SCEVConstant *SC = dyn_cast(S->getRHS())) { const APInt &RHS = SC->getAPInt(); if (RHS.isPowerOf2()) return InsertBinop(Instruction::LShr, LHS, ConstantInt::get(Ty, RHS.logBase2())); } Value *RHS = expandCodeFor(S->getRHS(), Ty); return InsertBinop(Instruction::UDiv, LHS, RHS); } /// Move parts of Base into Rest to leave Base with the minimal /// expression that provides a pointer operand suitable for a /// GEP expansion. static void ExposePointerBase(const SCEV *&Base, const SCEV *&Rest, ScalarEvolution &SE) { while (const SCEVAddRecExpr *A = dyn_cast(Base)) { Base = A->getStart(); Rest = SE.getAddExpr(Rest, SE.getAddRecExpr(SE.getConstant(A->getType(), 0), A->getStepRecurrence(SE), A->getLoop(), A->getNoWrapFlags(SCEV::FlagNW))); } if (const SCEVAddExpr *A = dyn_cast(Base)) { Base = A->getOperand(A->getNumOperands()-1); SmallVector NewAddOps(A->op_begin(), A->op_end()); NewAddOps.back() = Rest; Rest = SE.getAddExpr(NewAddOps); ExposePointerBase(Base, Rest, SE); } } /// Determine if this is a well-behaved chain of instructions leading back to /// the PHI. If so, it may be reused by expanded expressions. bool SCEVExpander::isNormalAddRecExprPHI(PHINode *PN, Instruction *IncV, const Loop *L) { if (IncV->getNumOperands() == 0 || isa(IncV) || (isa(IncV) && !isa(IncV))) return false; // If any of the operands don't dominate the insert position, bail. // Addrec operands are always loop-invariant, so this can only happen // if there are instructions which haven't been hoisted. if (L == IVIncInsertLoop) { for (User::op_iterator OI = IncV->op_begin()+1, OE = IncV->op_end(); OI != OE; ++OI) if (Instruction *OInst = dyn_cast(OI)) if (!SE.DT.dominates(OInst, IVIncInsertPos)) return false; } // Advance to the next instruction. IncV = dyn_cast(IncV->getOperand(0)); if (!IncV) return false; if (IncV->mayHaveSideEffects()) return false; if (IncV != PN) return true; return isNormalAddRecExprPHI(PN, IncV, L); } /// getIVIncOperand returns an induction variable increment's induction /// variable operand. /// /// If allowScale is set, any type of GEP is allowed as long as the nonIV /// operands dominate InsertPos. /// /// If allowScale is not set, ensure that a GEP increment conforms to one of the /// simple patterns generated by getAddRecExprPHILiterally and /// expandAddtoGEP. If the pattern isn't recognized, return NULL. Instruction *SCEVExpander::getIVIncOperand(Instruction *IncV, Instruction *InsertPos, bool allowScale) { if (IncV == InsertPos) return nullptr; switch (IncV->getOpcode()) { default: return nullptr; // Check for a simple Add/Sub or GEP of a loop invariant step. case Instruction::Add: case Instruction::Sub: { Instruction *OInst = dyn_cast(IncV->getOperand(1)); if (!OInst || SE.DT.dominates(OInst, InsertPos)) return dyn_cast(IncV->getOperand(0)); return nullptr; } case Instruction::BitCast: return dyn_cast(IncV->getOperand(0)); case Instruction::GetElementPtr: for (auto I = IncV->op_begin() + 1, E = IncV->op_end(); I != E; ++I) { if (isa(*I)) continue; if (Instruction *OInst = dyn_cast(*I)) { if (!SE.DT.dominates(OInst, InsertPos)) return nullptr; } if (allowScale) { // allow any kind of GEP as long as it can be hoisted. continue; } // This must be a pointer addition of constants (pretty), which is already // handled, or some number of address-size elements (ugly). Ugly geps // have 2 operands. i1* is used by the expander to represent an // address-size element. if (IncV->getNumOperands() != 2) return nullptr; unsigned AS = cast(IncV->getType())->getAddressSpace(); if (IncV->getType() != Type::getInt1PtrTy(SE.getContext(), AS) && IncV->getType() != Type::getInt8PtrTy(SE.getContext(), AS)) return nullptr; break; } return dyn_cast(IncV->getOperand(0)); } } /// If the insert point of the current builder or any of the builders on the /// stack of saved builders has 'I' as its insert point, update it to point to /// the instruction after 'I'. This is intended to be used when the instruction /// 'I' is being moved. If this fixup is not done and 'I' is moved to a /// different block, the inconsistent insert point (with a mismatched /// Instruction and Block) can lead to an instruction being inserted in a block /// other than its parent. void SCEVExpander::fixupInsertPoints(Instruction *I) { BasicBlock::iterator It(*I); BasicBlock::iterator NewInsertPt = std::next(It); if (Builder.GetInsertPoint() == It) Builder.SetInsertPoint(&*NewInsertPt); for (auto *InsertPtGuard : InsertPointGuards) if (InsertPtGuard->GetInsertPoint() == It) InsertPtGuard->SetInsertPoint(NewInsertPt); } /// hoistStep - Attempt to hoist a simple IV increment above InsertPos to make /// it available to other uses in this loop. Recursively hoist any operands, /// until we reach a value that dominates InsertPos. bool SCEVExpander::hoistIVInc(Instruction *IncV, Instruction *InsertPos) { if (SE.DT.dominates(IncV, InsertPos)) return true; // InsertPos must itself dominate IncV so that IncV's new position satisfies // its existing users. if (isa(InsertPos) || !SE.DT.dominates(InsertPos->getParent(), IncV->getParent())) return false; if (!SE.LI.movementPreservesLCSSAForm(IncV, InsertPos)) return false; // Check that the chain of IV operands leading back to Phi can be hoisted. SmallVector IVIncs; for(;;) { Instruction *Oper = getIVIncOperand(IncV, InsertPos, /*allowScale*/true); if (!Oper) return false; // IncV is safe to hoist. IVIncs.push_back(IncV); IncV = Oper; if (SE.DT.dominates(IncV, InsertPos)) break; } for (auto I = IVIncs.rbegin(), E = IVIncs.rend(); I != E; ++I) { fixupInsertPoints(*I); (*I)->moveBefore(InsertPos); } return true; } /// Determine if this cyclic phi is in a form that would have been generated by /// LSR. We don't care if the phi was actually expanded in this pass, as long /// as it is in a low-cost form, for example, no implied multiplication. This /// should match any patterns generated by getAddRecExprPHILiterally and /// expandAddtoGEP. bool SCEVExpander::isExpandedAddRecExprPHI(PHINode *PN, Instruction *IncV, const Loop *L) { for(Instruction *IVOper = IncV; (IVOper = getIVIncOperand(IVOper, L->getLoopPreheader()->getTerminator(), /*allowScale=*/false));) { if (IVOper == PN) return true; } return false; } /// expandIVInc - Expand an IV increment at Builder's current InsertPos. /// Typically this is the LatchBlock terminator or IVIncInsertPos, but we may /// need to materialize IV increments elsewhere to handle difficult situations. Value *SCEVExpander::expandIVInc(PHINode *PN, Value *StepV, const Loop *L, Type *ExpandTy, Type *IntTy, bool useSubtract) { Value *IncV; // If the PHI is a pointer, use a GEP, otherwise use an add or sub. if (ExpandTy->isPointerTy()) { PointerType *GEPPtrTy = cast(ExpandTy); // If the step isn't constant, don't use an implicitly scaled GEP, because // that would require a multiply inside the loop. if (!isa(StepV)) GEPPtrTy = PointerType::get(Type::getInt1Ty(SE.getContext()), GEPPtrTy->getAddressSpace()); const SCEV *const StepArray[1] = { SE.getSCEV(StepV) }; IncV = expandAddToGEP(StepArray, StepArray+1, GEPPtrTy, IntTy, PN); if (IncV->getType() != PN->getType()) { IncV = Builder.CreateBitCast(IncV, PN->getType()); rememberInstruction(IncV); } } else { IncV = useSubtract ? Builder.CreateSub(PN, StepV, Twine(IVName) + ".iv.next") : Builder.CreateAdd(PN, StepV, Twine(IVName) + ".iv.next"); rememberInstruction(IncV); } return IncV; } /// \brief Hoist the addrec instruction chain rooted in the loop phi above the /// position. This routine assumes that this is possible (has been checked). void SCEVExpander::hoistBeforePos(DominatorTree *DT, Instruction *InstToHoist, Instruction *Pos, PHINode *LoopPhi) { do { if (DT->dominates(InstToHoist, Pos)) break; // Make sure the increment is where we want it. But don't move it // down past a potential existing post-inc user. fixupInsertPoints(InstToHoist); InstToHoist->moveBefore(Pos); Pos = InstToHoist; InstToHoist = cast(InstToHoist->getOperand(0)); } while (InstToHoist != LoopPhi); } /// \brief Check whether we can cheaply express the requested SCEV in terms of /// the available PHI SCEV by truncation and/or inversion of the step. static bool canBeCheaplyTransformed(ScalarEvolution &SE, const SCEVAddRecExpr *Phi, const SCEVAddRecExpr *Requested, bool &InvertStep) { Type *PhiTy = SE.getEffectiveSCEVType(Phi->getType()); Type *RequestedTy = SE.getEffectiveSCEVType(Requested->getType()); if (RequestedTy->getIntegerBitWidth() > PhiTy->getIntegerBitWidth()) return false; // Try truncate it if necessary. Phi = dyn_cast(SE.getTruncateOrNoop(Phi, RequestedTy)); if (!Phi) return false; // Check whether truncation will help. if (Phi == Requested) { InvertStep = false; return true; } // Check whether inverting will help: {R,+,-1} == R - {0,+,1}. if (SE.getAddExpr(Requested->getStart(), SE.getNegativeSCEV(Requested)) == Phi) { InvertStep = true; return true; } return false; } static bool IsIncrementNSW(ScalarEvolution &SE, const SCEVAddRecExpr *AR) { if (!isa(AR->getType())) return false; unsigned BitWidth = cast(AR->getType())->getBitWidth(); Type *WideTy = IntegerType::get(AR->getType()->getContext(), BitWidth * 2); const SCEV *Step = AR->getStepRecurrence(SE); const SCEV *OpAfterExtend = SE.getAddExpr(SE.getSignExtendExpr(Step, WideTy), SE.getSignExtendExpr(AR, WideTy)); const SCEV *ExtendAfterOp = SE.getSignExtendExpr(SE.getAddExpr(AR, Step), WideTy); return ExtendAfterOp == OpAfterExtend; } static bool IsIncrementNUW(ScalarEvolution &SE, const SCEVAddRecExpr *AR) { if (!isa(AR->getType())) return false; unsigned BitWidth = cast(AR->getType())->getBitWidth(); Type *WideTy = IntegerType::get(AR->getType()->getContext(), BitWidth * 2); const SCEV *Step = AR->getStepRecurrence(SE); const SCEV *OpAfterExtend = SE.getAddExpr(SE.getZeroExtendExpr(Step, WideTy), SE.getZeroExtendExpr(AR, WideTy)); const SCEV *ExtendAfterOp = SE.getZeroExtendExpr(SE.getAddExpr(AR, Step), WideTy); return ExtendAfterOp == OpAfterExtend; } /// getAddRecExprPHILiterally - Helper for expandAddRecExprLiterally. Expand /// the base addrec, which is the addrec without any non-loop-dominating /// values, and return the PHI. PHINode * SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, const Loop *L, Type *ExpandTy, Type *IntTy, Type *&TruncTy, bool &InvertStep) { assert((!IVIncInsertLoop||IVIncInsertPos) && "Uninitialized insert position"); // Reuse a previously-inserted PHI, if present. BasicBlock *LatchBlock = L->getLoopLatch(); if (LatchBlock) { PHINode *AddRecPhiMatch = nullptr; Instruction *IncV = nullptr; TruncTy = nullptr; InvertStep = false; // Only try partially matching scevs that need truncation and/or // step-inversion if we know this loop is outside the current loop. bool TryNonMatchingSCEV = IVIncInsertLoop && SE.DT.properlyDominates(LatchBlock, IVIncInsertLoop->getHeader()); for (auto &I : *L->getHeader()) { auto *PN = dyn_cast(&I); if (!PN || !SE.isSCEVable(PN->getType())) continue; const SCEVAddRecExpr *PhiSCEV = dyn_cast(SE.getSCEV(PN)); if (!PhiSCEV) continue; bool IsMatchingSCEV = PhiSCEV == Normalized; // We only handle truncation and inversion of phi recurrences for the // expanded expression if the expanded expression's loop dominates the // loop we insert to. Check now, so we can bail out early. if (!IsMatchingSCEV && !TryNonMatchingSCEV) continue; Instruction *TempIncV = cast(PN->getIncomingValueForBlock(LatchBlock)); // Check whether we can reuse this PHI node. if (LSRMode) { if (!isExpandedAddRecExprPHI(PN, TempIncV, L)) continue; if (L == IVIncInsertLoop && !hoistIVInc(TempIncV, IVIncInsertPos)) continue; } else { if (!isNormalAddRecExprPHI(PN, TempIncV, L)) continue; } // Stop if we have found an exact match SCEV. if (IsMatchingSCEV) { IncV = TempIncV; TruncTy = nullptr; InvertStep = false; AddRecPhiMatch = PN; break; } // Try whether the phi can be translated into the requested form // (truncated and/or offset by a constant). if ((!TruncTy || InvertStep) && canBeCheaplyTransformed(SE, PhiSCEV, Normalized, InvertStep)) { // Record the phi node. But don't stop we might find an exact match // later. AddRecPhiMatch = PN; IncV = TempIncV; TruncTy = SE.getEffectiveSCEVType(Normalized->getType()); } } if (AddRecPhiMatch) { // Potentially, move the increment. We have made sure in // isExpandedAddRecExprPHI or hoistIVInc that this is possible. if (L == IVIncInsertLoop) hoistBeforePos(&SE.DT, IncV, IVIncInsertPos, AddRecPhiMatch); // Ok, the add recurrence looks usable. // Remember this PHI, even in post-inc mode. InsertedValues.insert(AddRecPhiMatch); // Remember the increment. rememberInstruction(IncV); return AddRecPhiMatch; } } // Save the original insertion point so we can restore it when we're done. SCEVInsertPointGuard Guard(Builder, this); // Another AddRec may need to be recursively expanded below. For example, if // this AddRec is quadratic, the StepV may itself be an AddRec in this // loop. Remove this loop from the PostIncLoops set before expanding such // AddRecs. Otherwise, we cannot find a valid position for the step // (i.e. StepV can never dominate its loop header). Ideally, we could do // SavedIncLoops.swap(PostIncLoops), but we generally have a single element, // so it's not worth implementing SmallPtrSet::swap. PostIncLoopSet SavedPostIncLoops = PostIncLoops; PostIncLoops.clear(); // Expand code for the start value. Value *StartV = expandCodeFor(Normalized->getStart(), ExpandTy, &L->getHeader()->front()); // StartV must be hoisted into L's preheader to dominate the new phi. assert(!isa(StartV) || SE.DT.properlyDominates(cast(StartV)->getParent(), L->getHeader())); // Expand code for the step value. Do this before creating the PHI so that PHI // reuse code doesn't see an incomplete PHI. const SCEV *Step = Normalized->getStepRecurrence(SE); // If the stride is negative, insert a sub instead of an add for the increment // (unless it's a constant, because subtracts of constants are canonicalized // to adds). bool useSubtract = !ExpandTy->isPointerTy() && Step->isNonConstantNegative(); if (useSubtract) Step = SE.getNegativeSCEV(Step); // Expand the step somewhere that dominates the loop header. Value *StepV = expandCodeFor(Step, IntTy, &L->getHeader()->front()); // The no-wrap behavior proved by IsIncrement(NUW|NSW) is only applicable if // we actually do emit an addition. It does not apply if we emit a // subtraction. bool IncrementIsNUW = !useSubtract && IsIncrementNUW(SE, Normalized); bool IncrementIsNSW = !useSubtract && IsIncrementNSW(SE, Normalized); // Create the PHI. BasicBlock *Header = L->getHeader(); Builder.SetInsertPoint(Header, Header->begin()); pred_iterator HPB = pred_begin(Header), HPE = pred_end(Header); PHINode *PN = Builder.CreatePHI(ExpandTy, std::distance(HPB, HPE), Twine(IVName) + ".iv"); rememberInstruction(PN); // Create the step instructions and populate the PHI. for (pred_iterator HPI = HPB; HPI != HPE; ++HPI) { BasicBlock *Pred = *HPI; // Add a start value. if (!L->contains(Pred)) { PN->addIncoming(StartV, Pred); continue; } // Create a step value and add it to the PHI. // If IVIncInsertLoop is non-null and equal to the addrec's loop, insert the // instructions at IVIncInsertPos. Instruction *InsertPos = L == IVIncInsertLoop ? IVIncInsertPos : Pred->getTerminator(); Builder.SetInsertPoint(InsertPos); Value *IncV = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract); if (isa(IncV)) { if (IncrementIsNUW) cast(IncV)->setHasNoUnsignedWrap(); if (IncrementIsNSW) cast(IncV)->setHasNoSignedWrap(); } PN->addIncoming(IncV, Pred); } // After expanding subexpressions, restore the PostIncLoops set so the caller // can ensure that IVIncrement dominates the current uses. PostIncLoops = SavedPostIncLoops; // Remember this PHI, even in post-inc mode. InsertedValues.insert(PN); return PN; } Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { Type *STy = S->getType(); Type *IntTy = SE.getEffectiveSCEVType(STy); const Loop *L = S->getLoop(); // Determine a normalized form of this expression, which is the expression // before any post-inc adjustment is made. const SCEVAddRecExpr *Normalized = S; if (PostIncLoops.count(L)) { PostIncLoopSet Loops; Loops.insert(L); Normalized = cast(TransformForPostIncUse( Normalize, S, nullptr, nullptr, Loops, SE, SE.DT)); } // Strip off any non-loop-dominating component from the addrec start. const SCEV *Start = Normalized->getStart(); const SCEV *PostLoopOffset = nullptr; if (!SE.properlyDominates(Start, L->getHeader())) { PostLoopOffset = Start; Start = SE.getConstant(Normalized->getType(), 0); Normalized = cast( SE.getAddRecExpr(Start, Normalized->getStepRecurrence(SE), Normalized->getLoop(), Normalized->getNoWrapFlags(SCEV::FlagNW))); } // Strip off any non-loop-dominating component from the addrec step. const SCEV *Step = Normalized->getStepRecurrence(SE); const SCEV *PostLoopScale = nullptr; if (!SE.dominates(Step, L->getHeader())) { PostLoopScale = Step; Step = SE.getConstant(Normalized->getType(), 1); if (!Start->isZero()) { // The normalization below assumes that Start is constant zero, so if // it isn't re-associate Start to PostLoopOffset. assert(!PostLoopOffset && "Start not-null but PostLoopOffset set?"); PostLoopOffset = Start; Start = SE.getConstant(Normalized->getType(), 0); } Normalized = cast(SE.getAddRecExpr( Start, Step, Normalized->getLoop(), Normalized->getNoWrapFlags(SCEV::FlagNW))); } // Expand the core addrec. If we need post-loop scaling, force it to // expand to an integer type to avoid the need for additional casting. Type *ExpandTy = PostLoopScale ? IntTy : STy; // In some cases, we decide to reuse an existing phi node but need to truncate // it and/or invert the step. Type *TruncTy = nullptr; bool InvertStep = false; PHINode *PN = getAddRecExprPHILiterally(Normalized, L, ExpandTy, IntTy, TruncTy, InvertStep); // Accommodate post-inc mode, if necessary. Value *Result; if (!PostIncLoops.count(L)) Result = PN; else { // In PostInc mode, use the post-incremented value. BasicBlock *LatchBlock = L->getLoopLatch(); assert(LatchBlock && "PostInc mode requires a unique loop latch!"); Result = PN->getIncomingValueForBlock(LatchBlock); // For an expansion to use the postinc form, the client must call // expandCodeFor with an InsertPoint that is either outside the PostIncLoop // or dominated by IVIncInsertPos. if (isa(Result) && !SE.DT.dominates(cast(Result), &*Builder.GetInsertPoint())) { // The induction variable's postinc expansion does not dominate this use. // IVUsers tries to prevent this case, so it is rare. However, it can // happen when an IVUser outside the loop is not dominated by the latch // block. Adjusting IVIncInsertPos before expansion begins cannot handle // all cases. Consider a phi outide whose operand is replaced during // expansion with the value of the postinc user. Without fundamentally // changing the way postinc users are tracked, the only remedy is // inserting an extra IV increment. StepV might fold into PostLoopOffset, // but hopefully expandCodeFor handles that. bool useSubtract = !ExpandTy->isPointerTy() && Step->isNonConstantNegative(); if (useSubtract) Step = SE.getNegativeSCEV(Step); Value *StepV; { // Expand the step somewhere that dominates the loop header. SCEVInsertPointGuard Guard(Builder, this); StepV = expandCodeFor(Step, IntTy, &L->getHeader()->front()); } Result = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract); } } // We have decided to reuse an induction variable of a dominating loop. Apply // truncation and/or invertion of the step. if (TruncTy) { Type *ResTy = Result->getType(); // Normalize the result type. if (ResTy != SE.getEffectiveSCEVType(ResTy)) Result = InsertNoopCastOfTo(Result, SE.getEffectiveSCEVType(ResTy)); // Truncate the result. if (TruncTy != Result->getType()) { Result = Builder.CreateTrunc(Result, TruncTy); rememberInstruction(Result); } // Invert the result. if (InvertStep) { Result = Builder.CreateSub(expandCodeFor(Normalized->getStart(), TruncTy), Result); rememberInstruction(Result); } } // Re-apply any non-loop-dominating scale. if (PostLoopScale) { assert(S->isAffine() && "Can't linearly scale non-affine recurrences."); Result = InsertNoopCastOfTo(Result, IntTy); Result = Builder.CreateMul(Result, expandCodeFor(PostLoopScale, IntTy)); rememberInstruction(Result); } // Re-apply any non-loop-dominating offset. if (PostLoopOffset) { if (PointerType *PTy = dyn_cast(ExpandTy)) { const SCEV *const OffsetArray[1] = { PostLoopOffset }; Result = expandAddToGEP(OffsetArray, OffsetArray+1, PTy, IntTy, Result); } else { Result = InsertNoopCastOfTo(Result, IntTy); Result = Builder.CreateAdd(Result, expandCodeFor(PostLoopOffset, IntTy)); rememberInstruction(Result); } } return Result; } Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { if (!CanonicalMode) return expandAddRecExprLiterally(S); Type *Ty = SE.getEffectiveSCEVType(S->getType()); const Loop *L = S->getLoop(); // First check for an existing canonical IV in a suitable type. PHINode *CanonicalIV = nullptr; if (PHINode *PN = L->getCanonicalInductionVariable()) if (SE.getTypeSizeInBits(PN->getType()) >= SE.getTypeSizeInBits(Ty)) CanonicalIV = PN; // Rewrite an AddRec in terms of the canonical induction variable, if // its type is more narrow. if (CanonicalIV && SE.getTypeSizeInBits(CanonicalIV->getType()) > SE.getTypeSizeInBits(Ty)) { SmallVector NewOps(S->getNumOperands()); for (unsigned i = 0, e = S->getNumOperands(); i != e; ++i) NewOps[i] = SE.getAnyExtendExpr(S->op_begin()[i], CanonicalIV->getType()); Value *V = expand(SE.getAddRecExpr(NewOps, S->getLoop(), S->getNoWrapFlags(SCEV::FlagNW))); BasicBlock::iterator NewInsertPt = findInsertPointAfter(cast(V), Builder.GetInsertBlock()); V = expandCodeFor(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr, &*NewInsertPt); return V; } // {X,+,F} --> X + {0,+,F} if (!S->getStart()->isZero()) { SmallVector NewOps(S->op_begin(), S->op_end()); NewOps[0] = SE.getConstant(Ty, 0); const SCEV *Rest = SE.getAddRecExpr(NewOps, L, S->getNoWrapFlags(SCEV::FlagNW)); // Turn things like ptrtoint+arithmetic+inttoptr into GEP. See the // comments on expandAddToGEP for details. const SCEV *Base = S->getStart(); const SCEV *RestArray[1] = { Rest }; // Dig into the expression to find the pointer base for a GEP. ExposePointerBase(Base, RestArray[0], SE); // If we found a pointer, expand the AddRec with a GEP. if (PointerType *PTy = dyn_cast(Base->getType())) { // Make sure the Base isn't something exotic, such as a multiplied // or divided pointer value. In those cases, the result type isn't // actually a pointer type. if (!isa(Base) && !isa(Base)) { Value *StartV = expand(Base); assert(StartV->getType() == PTy && "Pointer type mismatch for GEP!"); return expandAddToGEP(RestArray, RestArray+1, PTy, Ty, StartV); } } // Just do a normal add. Pre-expand the operands to suppress folding. // // The LHS and RHS values are factored out of the expand call to make the // output independent of the argument evaluation order. const SCEV *AddExprLHS = SE.getUnknown(expand(S->getStart())); const SCEV *AddExprRHS = SE.getUnknown(expand(Rest)); return expand(SE.getAddExpr(AddExprLHS, AddExprRHS)); } // If we don't yet have a canonical IV, create one. if (!CanonicalIV) { // Create and insert the PHI node for the induction variable in the // specified loop. BasicBlock *Header = L->getHeader(); pred_iterator HPB = pred_begin(Header), HPE = pred_end(Header); CanonicalIV = PHINode::Create(Ty, std::distance(HPB, HPE), "indvar", &Header->front()); rememberInstruction(CanonicalIV); SmallSet PredSeen; Constant *One = ConstantInt::get(Ty, 1); for (pred_iterator HPI = HPB; HPI != HPE; ++HPI) { BasicBlock *HP = *HPI; if (!PredSeen.insert(HP).second) { // There must be an incoming value for each predecessor, even the // duplicates! CanonicalIV->addIncoming(CanonicalIV->getIncomingValueForBlock(HP), HP); continue; } if (L->contains(HP)) { // Insert a unit add instruction right before the terminator // corresponding to the back-edge. Instruction *Add = BinaryOperator::CreateAdd(CanonicalIV, One, "indvar.next", HP->getTerminator()); Add->setDebugLoc(HP->getTerminator()->getDebugLoc()); rememberInstruction(Add); CanonicalIV->addIncoming(Add, HP); } else { CanonicalIV->addIncoming(Constant::getNullValue(Ty), HP); } } } // {0,+,1} --> Insert a canonical induction variable into the loop! if (S->isAffine() && S->getOperand(1)->isOne()) { assert(Ty == SE.getEffectiveSCEVType(CanonicalIV->getType()) && "IVs with types different from the canonical IV should " "already have been handled!"); return CanonicalIV; } // {0,+,F} --> {0,+,1} * F // If this is a simple linear addrec, emit it now as a special case. if (S->isAffine()) // {0,+,F} --> i*F return expand(SE.getTruncateOrNoop( SE.getMulExpr(SE.getUnknown(CanonicalIV), SE.getNoopOrAnyExtend(S->getOperand(1), CanonicalIV->getType())), Ty)); // If this is a chain of recurrences, turn it into a closed form, using the // folders, then expandCodeFor the closed form. This allows the folders to // simplify the expression without having to build a bunch of special code // into this folder. const SCEV *IH = SE.getUnknown(CanonicalIV); // Get I as a "symbolic" SCEV. // Promote S up to the canonical IV type, if the cast is foldable. const SCEV *NewS = S; const SCEV *Ext = SE.getNoopOrAnyExtend(S, CanonicalIV->getType()); if (isa(Ext)) NewS = Ext; const SCEV *V = cast(NewS)->evaluateAtIteration(IH, SE); //cerr << "Evaluated: " << *this << "\n to: " << *V << "\n"; // Truncate the result down to the original type, if needed. const SCEV *T = SE.getTruncateOrNoop(V, Ty); return expand(T); } Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); Value *V = expandCodeFor(S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType())); Value *I = Builder.CreateTrunc(V, Ty); rememberInstruction(I); return I; } Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); Value *V = expandCodeFor(S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType())); Value *I = Builder.CreateZExt(V, Ty); rememberInstruction(I); return I; } Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); Value *V = expandCodeFor(S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType())); Value *I = Builder.CreateSExt(V, Ty); rememberInstruction(I); return I; } Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands()-2; i >= 0; --i) { // In the case of mixed integer and pointer types, do the // rest of the comparisons as integer. if (S->getOperand(i)->getType() != Ty) { Ty = SE.getEffectiveSCEVType(Ty); LHS = InsertNoopCastOfTo(LHS, Ty); } Value *RHS = expandCodeFor(S->getOperand(i), Ty); Value *ICmp = Builder.CreateICmpSGT(LHS, RHS); rememberInstruction(ICmp); Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smax"); rememberInstruction(Sel); LHS = Sel; } // In the case of mixed integer and pointer types, cast the // final result back to the pointer type. if (LHS->getType() != S->getType()) LHS = InsertNoopCastOfTo(LHS, S->getType()); return LHS; } Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands()-2; i >= 0; --i) { // In the case of mixed integer and pointer types, do the // rest of the comparisons as integer. if (S->getOperand(i)->getType() != Ty) { Ty = SE.getEffectiveSCEVType(Ty); LHS = InsertNoopCastOfTo(LHS, Ty); } Value *RHS = expandCodeFor(S->getOperand(i), Ty); Value *ICmp = Builder.CreateICmpUGT(LHS, RHS); rememberInstruction(ICmp); Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umax"); rememberInstruction(Sel); LHS = Sel; } // In the case of mixed integer and pointer types, cast the // final result back to the pointer type. if (LHS->getType() != S->getType()) LHS = InsertNoopCastOfTo(LHS, S->getType()); return LHS; } Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty, Instruction *IP) { setInsertPoint(IP); return expandCodeFor(SH, Ty); } Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) { // Expand the code for this SCEV. Value *V = expand(SH); if (Ty) { assert(SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(SH->getType()) && "non-trivial casts should be done with the SCEVs directly!"); V = InsertNoopCastOfTo(V, Ty); } return V; } -Value *SCEVExpander::FindValueInExprValueMap(const SCEV *S, - const Instruction *InsertPt) { - SetVector *Set = SE.getSCEVValues(S); +ScalarEvolution::ValueOffsetPair +SCEVExpander::FindValueInExprValueMap(const SCEV *S, + const Instruction *InsertPt) { + SetVector *Set = SE.getSCEVValues(S); // If the expansion is not in CanonicalMode, and the SCEV contains any // sub scAddRecExpr type SCEV, it is required to expand the SCEV literally. if (CanonicalMode || !SE.containsAddRecurrence(S)) { // If S is scConstant, it may be worse to reuse an existing Value. if (S->getSCEVType() != scConstant && Set) { // Choose a Value from the set which dominates the insertPt. // insertPt should be inside the Value's parent loop so as not to break // the LCSSA form. - for (auto const &Ent : *Set) { + for (auto const &VOPair : *Set) { + Value *V = VOPair.first; + ConstantInt *Offset = VOPair.second; Instruction *EntInst = nullptr; - if (Ent && isa(Ent) && - (EntInst = cast(Ent)) && - S->getType() == Ent->getType() && + if (V && isa(V) && (EntInst = cast(V)) && + S->getType() == V->getType() && EntInst->getFunction() == InsertPt->getFunction() && SE.DT.dominates(EntInst, InsertPt) && (SE.LI.getLoopFor(EntInst->getParent()) == nullptr || - SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt))) { - return Ent; - } + SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt))) + return {V, Offset}; } } } - return nullptr; + return {nullptr, nullptr}; } // The expansion of SCEV will either reuse a previous Value in ExprValueMap, // or expand the SCEV literally. Specifically, if the expansion is in LSRMode, // and the SCEV contains any sub scAddRecExpr type SCEV, it will be expanded // literally, to prevent LSR's transformed SCEV from being reverted. Otherwise, // the expansion will try to reuse Value from ExprValueMap, and only when it // fails, expand the SCEV literally. Value *SCEVExpander::expand(const SCEV *S) { // Compute an insertion point for this SCEV object. Hoist the instructions // as far out in the loop nest as possible. Instruction *InsertPt = &*Builder.GetInsertPoint(); for (Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock());; L = L->getParentLoop()) if (SE.isLoopInvariant(S, L)) { if (!L) break; if (BasicBlock *Preheader = L->getLoopPreheader()) InsertPt = Preheader->getTerminator(); else { // LSR sets the insertion point for AddRec start/step values to the // block start to simplify value reuse, even though it's an invalid // position. SCEVExpander must correct for this in all cases. InsertPt = &*L->getHeader()->getFirstInsertionPt(); } } else { // If the SCEV is computable at this level, insert it into the header // after the PHIs (and after any other instructions that we've inserted // there) so that it is guaranteed to dominate any user inside the loop. if (L && SE.hasComputableLoopEvolution(S, L) && !PostIncLoops.count(L)) InsertPt = &*L->getHeader()->getFirstInsertionPt(); while (InsertPt->getIterator() != Builder.GetInsertPoint() && (isInsertedInstruction(InsertPt) || isa(InsertPt))) { InsertPt = &*std::next(InsertPt->getIterator()); } break; } // Check to see if we already expanded this here. auto I = InsertedExpressions.find(std::make_pair(S, InsertPt)); if (I != InsertedExpressions.end()) return I->second; SCEVInsertPointGuard Guard(Builder, this); Builder.SetInsertPoint(InsertPt); // Expand the expression into instructions. - Value *V = FindValueInExprValueMap(S, InsertPt); + ScalarEvolution::ValueOffsetPair VO = FindValueInExprValueMap(S, InsertPt); + Value *V = VO.first; if (!V) V = visit(S); - + else if (VO.second) { + if (PointerType *Vty = dyn_cast(V->getType())) { + Type *Ety = Vty->getPointerElementType(); + int64_t Offset = VO.second->getSExtValue(); + int64_t ESize = SE.getTypeSizeInBits(Ety); + if ((Offset * 8) % ESize == 0) { + ConstantInt *Idx = + ConstantInt::getSigned(VO.second->getType(), -(Offset * 8) / ESize); + V = Builder.CreateGEP(Ety, V, Idx, "scevgep"); + } else { + ConstantInt *Idx = + ConstantInt::getSigned(VO.second->getType(), -Offset); + unsigned AS = Vty->getAddressSpace(); + V = Builder.CreateBitCast(V, Type::getInt8PtrTy(SE.getContext(), AS)); + V = Builder.CreateGEP(Type::getInt8Ty(SE.getContext()), V, Idx, + "uglygep"); + V = Builder.CreateBitCast(V, Vty); + } + } else { + V = Builder.CreateSub(V, VO.second); + } + } // Remember the expanded value for this SCEV at this location. // // This is independent of PostIncLoops. The mapped value simply materializes // the expression at this insertion point. If the mapped value happened to be // a postinc expansion, it could be reused by a non-postinc user, but only if // its insertion point was already at the head of the loop. InsertedExpressions[std::make_pair(S, InsertPt)] = V; return V; } void SCEVExpander::rememberInstruction(Value *I) { if (!PostIncLoops.empty()) InsertedPostIncValues.insert(I); else InsertedValues.insert(I); } /// getOrInsertCanonicalInductionVariable - This method returns the /// canonical induction variable of the specified type for the specified /// loop (inserting one if there is none). A canonical induction variable /// starts at zero and steps by one on each iteration. PHINode * SCEVExpander::getOrInsertCanonicalInductionVariable(const Loop *L, Type *Ty) { assert(Ty->isIntegerTy() && "Can only insert integer induction variables!"); // Build a SCEV for {0,+,1}. // Conservatively use FlagAnyWrap for now. const SCEV *H = SE.getAddRecExpr(SE.getConstant(Ty, 0), SE.getConstant(Ty, 1), L, SCEV::FlagAnyWrap); // Emit code for it. SCEVInsertPointGuard Guard(Builder, this); PHINode *V = cast(expandCodeFor(H, nullptr, &L->getHeader()->front())); return V; } /// replaceCongruentIVs - Check for congruent phis in this loop header and /// replace them with their most canonical representative. Return the number of /// phis eliminated. /// /// This does not depend on any SCEVExpander state but should be used in /// the same context that SCEVExpander is used. unsigned SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, SmallVectorImpl &DeadInsts, const TargetTransformInfo *TTI) { // Find integer phis in order of increasing width. SmallVector Phis; for (auto &I : *L->getHeader()) { if (auto *PN = dyn_cast(&I)) Phis.push_back(PN); else break; } if (TTI) std::sort(Phis.begin(), Phis.end(), [](Value *LHS, Value *RHS) { // Put pointers at the back and make sure pointer < pointer = false. if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) return RHS->getType()->isIntegerTy() && !LHS->getType()->isIntegerTy(); return RHS->getType()->getPrimitiveSizeInBits() < LHS->getType()->getPrimitiveSizeInBits(); }); unsigned NumElim = 0; DenseMap ExprToIVMap; // Process phis from wide to narrow. Map wide phis to their truncation // so narrow phis can reuse them. for (PHINode *Phi : Phis) { auto SimplifyPHINode = [&](PHINode *PN) -> Value * { if (Value *V = SimplifyInstruction(PN, DL, &SE.TLI, &SE.DT, &SE.AC)) return V; if (!SE.isSCEVable(PN->getType())) return nullptr; auto *Const = dyn_cast(SE.getSCEV(PN)); if (!Const) return nullptr; return Const->getValue(); }; // Fold constant phis. They may be congruent to other constant phis and // would confuse the logic below that expects proper IVs. if (Value *V = SimplifyPHINode(Phi)) { if (V->getType() != Phi->getType()) continue; Phi->replaceAllUsesWith(V); DeadInsts.emplace_back(Phi); ++NumElim; DEBUG_WITH_TYPE(DebugType, dbgs() << "INDVARS: Eliminated constant iv: " << *Phi << '\n'); continue; } if (!SE.isSCEVable(Phi->getType())) continue; PHINode *&OrigPhiRef = ExprToIVMap[SE.getSCEV(Phi)]; if (!OrigPhiRef) { OrigPhiRef = Phi; if (Phi->getType()->isIntegerTy() && TTI && TTI->isTruncateFree(Phi->getType(), Phis.back()->getType())) { // This phi can be freely truncated to the narrowest phi type. Map the // truncated expression to it so it will be reused for narrow types. const SCEV *TruncExpr = SE.getTruncateExpr(SE.getSCEV(Phi), Phis.back()->getType()); ExprToIVMap[TruncExpr] = Phi; } continue; } // Replacing a pointer phi with an integer phi or vice-versa doesn't make // sense. if (OrigPhiRef->getType()->isPointerTy() != Phi->getType()->isPointerTy()) continue; if (BasicBlock *LatchBlock = L->getLoopLatch()) { Instruction *OrigInc = dyn_cast( OrigPhiRef->getIncomingValueForBlock(LatchBlock)); Instruction *IsomorphicInc = dyn_cast(Phi->getIncomingValueForBlock(LatchBlock)); if (OrigInc && IsomorphicInc) { // If this phi has the same width but is more canonical, replace the // original with it. As part of the "more canonical" determination, // respect a prior decision to use an IV chain. if (OrigPhiRef->getType() == Phi->getType() && !(ChainedPhis.count(Phi) || isExpandedAddRecExprPHI(OrigPhiRef, OrigInc, L)) && (ChainedPhis.count(Phi) || isExpandedAddRecExprPHI(Phi, IsomorphicInc, L))) { std::swap(OrigPhiRef, Phi); std::swap(OrigInc, IsomorphicInc); } // Replacing the congruent phi is sufficient because acyclic // redundancy elimination, CSE/GVN, should handle the // rest. However, once SCEV proves that a phi is congruent, // it's often the head of an IV user cycle that is isomorphic // with the original phi. It's worth eagerly cleaning up the // common case of a single IV increment so that DeleteDeadPHIs // can remove cycles that had postinc uses. const SCEV *TruncExpr = SE.getTruncateOrNoop(SE.getSCEV(OrigInc), IsomorphicInc->getType()); if (OrigInc != IsomorphicInc && TruncExpr == SE.getSCEV(IsomorphicInc) && SE.LI.replacementPreservesLCSSAForm(IsomorphicInc, OrigInc) && hoistIVInc(OrigInc, IsomorphicInc)) { DEBUG_WITH_TYPE(DebugType, dbgs() << "INDVARS: Eliminated congruent iv.inc: " << *IsomorphicInc << '\n'); Value *NewInc = OrigInc; if (OrigInc->getType() != IsomorphicInc->getType()) { Instruction *IP = nullptr; if (PHINode *PN = dyn_cast(OrigInc)) IP = &*PN->getParent()->getFirstInsertionPt(); else IP = OrigInc->getNextNode(); IRBuilder<> Builder(IP); Builder.SetCurrentDebugLocation(IsomorphicInc->getDebugLoc()); NewInc = Builder.CreateTruncOrBitCast( OrigInc, IsomorphicInc->getType(), IVName); } IsomorphicInc->replaceAllUsesWith(NewInc); DeadInsts.emplace_back(IsomorphicInc); } } } DEBUG_WITH_TYPE(DebugType, dbgs() << "INDVARS: Eliminated congruent iv: " << *Phi << '\n'); ++NumElim; Value *NewIV = OrigPhiRef; if (OrigPhiRef->getType() != Phi->getType()) { IRBuilder<> Builder(&*L->getHeader()->getFirstInsertionPt()); Builder.SetCurrentDebugLocation(Phi->getDebugLoc()); NewIV = Builder.CreateTruncOrBitCast(OrigPhiRef, Phi->getType(), IVName); } Phi->replaceAllUsesWith(NewIV); DeadInsts.emplace_back(Phi); } return NumElim; } -Value *SCEVExpander::findExistingExpansion(const SCEV *S, - const Instruction *At, Loop *L) { +Value *SCEVExpander::getExactExistingExpansion(const SCEV *S, + const Instruction *At, Loop *L) { + Optional VO = + getRelatedExistingExpansion(S, At, L); + if (VO && VO.getValue().second == nullptr) + return VO.getValue().first; + return nullptr; +} + +Optional +SCEVExpander::getRelatedExistingExpansion(const SCEV *S, const Instruction *At, + Loop *L) { using namespace llvm::PatternMatch; SmallVector ExitingBlocks; L->getExitingBlocks(ExitingBlocks); // Look for suitable value in simple conditions at the loop exits. for (BasicBlock *BB : ExitingBlocks) { ICmpInst::Predicate Pred; Instruction *LHS, *RHS; BasicBlock *TrueBB, *FalseBB; if (!match(BB->getTerminator(), m_Br(m_ICmp(Pred, m_Instruction(LHS), m_Instruction(RHS)), TrueBB, FalseBB))) continue; if (SE.getSCEV(LHS) == S && SE.DT.dominates(LHS, At)) - return LHS; + return ScalarEvolution::ValueOffsetPair(LHS, nullptr); if (SE.getSCEV(RHS) == S && SE.DT.dominates(RHS, At)) - return RHS; + return ScalarEvolution::ValueOffsetPair(RHS, nullptr); } // Use expand's logic which is used for reusing a previous Value in // ExprValueMap. - if (Value *Val = FindValueInExprValueMap(S, At)) - return Val; + ScalarEvolution::ValueOffsetPair VO = FindValueInExprValueMap(S, At); + if (VO.first) + return VO; // There is potential to make this significantly smarter, but this simple // heuristic already gets some interesting cases. // Can not find suitable value. - return nullptr; + return None; } bool SCEVExpander::isHighCostExpansionHelper( const SCEV *S, Loop *L, const Instruction *At, SmallPtrSetImpl &Processed) { // If we can find an existing value for this scev avaliable at the point "At" // then consider the expression cheap. - if (At && findExistingExpansion(S, At, L) != nullptr) + if (At && getRelatedExistingExpansion(S, At, L)) return false; // Zero/One operand expressions switch (S->getSCEVType()) { case scUnknown: case scConstant: return false; case scTruncate: return isHighCostExpansionHelper(cast(S)->getOperand(), L, At, Processed); case scZeroExtend: return isHighCostExpansionHelper(cast(S)->getOperand(), L, At, Processed); case scSignExtend: return isHighCostExpansionHelper(cast(S)->getOperand(), L, At, Processed); } if (!Processed.insert(S).second) return false; if (auto *UDivExpr = dyn_cast(S)) { // If the divisor is a power of two and the SCEV type fits in a native // integer, consider the division cheap irrespective of whether it occurs in // the user code since it can be lowered into a right shift. if (auto *SC = dyn_cast(UDivExpr->getRHS())) if (SC->getAPInt().isPowerOf2()) { const DataLayout &DL = L->getHeader()->getParent()->getParent()->getDataLayout(); unsigned Width = cast(UDivExpr->getType())->getBitWidth(); return DL.isIllegalInteger(Width); } // UDivExpr is very likely a UDiv that ScalarEvolution's HowFarToZero or // HowManyLessThans produced to compute a precise expression, rather than a // UDiv from the user's code. If we can't find a UDiv in the code with some // simple searching, assume the former consider UDivExpr expensive to // compute. BasicBlock *ExitingBB = L->getExitingBlock(); if (!ExitingBB) return true; // At the beginning of this function we already tried to find existing value // for plain 'S'. Now try to lookup 'S + 1' since it is common pattern // involving division. This is just a simple search heuristic. if (!At) At = &ExitingBB->back(); - if (!findExistingExpansion( + if (!getRelatedExistingExpansion( SE.getAddExpr(S, SE.getConstant(S->getType(), 1)), At, L)) return true; } // HowManyLessThans uses a Max expression whenever the loop is not guarded by // the exit condition. if (isa(S) || isa(S)) return true; // Recurse past nary expressions, which commonly occur in the // BackedgeTakenCount. They may already exist in program code, and if not, // they are not too expensive rematerialize. if (const SCEVNAryExpr *NAry = dyn_cast(S)) { for (auto *Op : NAry->operands()) if (isHighCostExpansionHelper(Op, L, At, Processed)) return true; } // If we haven't recognized an expensive SCEV pattern, assume it's an // expression produced by program code. return false; } Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred, Instruction *IP) { assert(IP); switch (Pred->getKind()) { case SCEVPredicate::P_Union: return expandUnionPredicate(cast(Pred), IP); case SCEVPredicate::P_Equal: return expandEqualPredicate(cast(Pred), IP); case SCEVPredicate::P_Wrap: { auto *AddRecPred = cast(Pred); return expandWrapPredicate(AddRecPred, IP); } } llvm_unreachable("Unknown SCEV predicate type"); } Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred, Instruction *IP) { Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP); Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP); Builder.SetInsertPoint(IP); auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check"); return I; } Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, Instruction *Loc, bool Signed) { assert(AR->isAffine() && "Cannot generate RT check for " "non-affine expression"); SCEVUnionPredicate Pred; const SCEV *ExitCount = SE.getPredicatedBackedgeTakenCount(AR->getLoop(), Pred); assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count"); const SCEV *Step = AR->getStepRecurrence(SE); const SCEV *Start = AR->getStart(); unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType()); unsigned DstBits = SE.getTypeSizeInBits(AR->getType()); // The expression {Start,+,Step} has nusw/nssw if // Step < 0, Start - |Step| * Backedge <= Start // Step >= 0, Start + |Step| * Backedge > Start // and |Step| * Backedge doesn't unsigned overflow. IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits); Builder.SetInsertPoint(Loc); Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc); IntegerType *Ty = IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(AR->getType())); Value *StepValue = expandCodeFor(Step, Ty, Loc); Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc); Value *StartValue = expandCodeFor(Start, Ty, Loc); ConstantInt *Zero = ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits)); Builder.SetInsertPoint(Loc); // Compute |Step| Value *StepCompare = Builder.CreateICmp(ICmpInst::ICMP_SLT, StepValue, Zero); Value *AbsStep = Builder.CreateSelect(StepCompare, NegStepValue, StepValue); // Get the backedge taken count and truncate or extended to the AR type. Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty); auto *MulF = Intrinsic::getDeclaration(Loc->getModule(), Intrinsic::umul_with_overflow, Ty); // Compute |Step| * Backedge CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul"); Value *MulV = Builder.CreateExtractValue(Mul, 0, "mul.result"); Value *OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow"); // Compute: // Start + |Step| * Backedge < Start // Start - |Step| * Backedge > Start Value *Add = Builder.CreateAdd(StartValue, MulV); Value *Sub = Builder.CreateSub(StartValue, MulV); Value *EndCompareGT = Builder.CreateICmp( Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, Sub, StartValue); Value *EndCompareLT = Builder.CreateICmp( Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, Add, StartValue); // Select the answer based on the sign of Step. Value *EndCheck = Builder.CreateSelect(StepCompare, EndCompareGT, EndCompareLT); // If the backedge taken count type is larger than the AR type, // check that we don't drop any bits by truncating it. If we are // droping bits, then we have overflow (unless the step is zero). if (SE.getTypeSizeInBits(CountTy) > SE.getTypeSizeInBits(Ty)) { auto MaxVal = APInt::getMaxValue(DstBits).zext(SrcBits); auto *BackedgeCheck = Builder.CreateICmp(ICmpInst::ICMP_UGT, TripCountVal, ConstantInt::get(Loc->getContext(), MaxVal)); BackedgeCheck = Builder.CreateAnd( BackedgeCheck, Builder.CreateICmp(ICmpInst::ICMP_NE, StepValue, Zero)); EndCheck = Builder.CreateOr(EndCheck, BackedgeCheck); } EndCheck = Builder.CreateOr(EndCheck, OfMul); return EndCheck; } Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred, Instruction *IP) { const auto *A = cast(Pred->getExpr()); Value *NSSWCheck = nullptr, *NUSWCheck = nullptr; // Add a check for NUSW if (Pred->getFlags() & SCEVWrapPredicate::IncrementNUSW) NUSWCheck = generateOverflowCheck(A, IP, false); // Add a check for NSSW if (Pred->getFlags() & SCEVWrapPredicate::IncrementNSSW) NSSWCheck = generateOverflowCheck(A, IP, true); if (NUSWCheck && NSSWCheck) return Builder.CreateOr(NUSWCheck, NSSWCheck); if (NUSWCheck) return NUSWCheck; if (NSSWCheck) return NSSWCheck; return ConstantInt::getFalse(IP->getContext()); } Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union, Instruction *IP) { auto *BoolType = IntegerType::get(IP->getContext(), 1); Value *Check = ConstantInt::getNullValue(BoolType); // Loop over all checks in this set. for (auto Pred : Union->getPredicates()) { auto *NextCheck = expandCodeForPredicate(Pred, IP); Builder.SetInsertPoint(IP); Check = Builder.CreateOr(Check, NextCheck); } return Check; } namespace { // Search for a SCEV subexpression that is not safe to expand. Any expression // that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely // UDiv expressions. We don't know if the UDiv is derived from an IR divide // instruction, but the important thing is that we prove the denominator is // nonzero before expansion. // // IVUsers already checks that IV-derived expressions are safe. So this check is // only needed when the expression includes some subexpression that is not IV // derived. // // Currently, we only allow division by a nonzero constant here. If this is // inadequate, we could easily allow division by SCEVUnknown by using // ValueTracking to check isKnownNonZero(). // // We cannot generally expand recurrences unless the step dominates the loop // header. The expander handles the special case of affine recurrences by // scaling the recurrence outside the loop, but this technique isn't generally // applicable. Expanding a nested recurrence outside a loop requires computing // binomial coefficients. This could be done, but the recurrence has to be in a // perfectly reduced form, which can't be guaranteed. struct SCEVFindUnsafe { ScalarEvolution &SE; bool IsUnsafe; SCEVFindUnsafe(ScalarEvolution &se): SE(se), IsUnsafe(false) {} bool follow(const SCEV *S) { if (const SCEVUDivExpr *D = dyn_cast(S)) { const SCEVConstant *SC = dyn_cast(D->getRHS()); if (!SC || SC->getValue()->isZero()) { IsUnsafe = true; return false; } } if (const SCEVAddRecExpr *AR = dyn_cast(S)) { const SCEV *Step = AR->getStepRecurrence(SE); if (!AR->isAffine() && !SE.dominates(Step, AR->getLoop()->getHeader())) { IsUnsafe = true; return false; } } return true; } bool isDone() const { return IsUnsafe; } }; } namespace llvm { bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE) { SCEVFindUnsafe Search(SE); visitAll(S, Search); return !Search.IsUnsafe; } } Index: head/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp =================================================================== --- head/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp (revision 312831) +++ head/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp (revision 312832) @@ -1,2288 +1,2288 @@ //===- IndVarSimplify.cpp - Induction Variable Elimination ----------------===// // // The LLVM Compiler Infrastructure // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// // // This transformation analyzes and transforms the induction variables (and // computations derived from them) into simpler forms suitable for subsequent // analysis and transformation. // // If the trip count of a loop is computable, this pass also makes the following // changes: // 1. The exit condition for the loop is canonicalized to compare the // induction value against the exit value. This turns loops like: // 'for (i = 7; i*i < 1000; ++i)' into 'for (i = 0; i != 25; ++i)' // 2. Any use outside of the loop of an expression derived from the indvar // is changed to compute the derived value outside of the loop, eliminating // the dependence on the exit value of the induction variable. If the only // purpose of the loop is to compute the exit value of some derived // expression, this transformation will make the loop dead. // //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/IndVarSimplify.h" #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" using namespace llvm; #define DEBUG_TYPE "indvars" STATISTIC(NumWidened , "Number of indvars widened"); STATISTIC(NumReplaced , "Number of exit values replaced"); STATISTIC(NumLFTR , "Number of loop exit tests replaced"); STATISTIC(NumElimExt , "Number of IV sign/zero extends eliminated"); STATISTIC(NumElimIV , "Number of congruent IVs eliminated"); // Trip count verification can be enabled by default under NDEBUG if we // implement a strong expression equivalence checker in SCEV. Until then, we // use the verify-indvars flag, which may assert in some cases. static cl::opt VerifyIndvars( "verify-indvars", cl::Hidden, cl::desc("Verify the ScalarEvolution result after running indvars")); enum ReplaceExitVal { NeverRepl, OnlyCheapRepl, AlwaysRepl }; static cl::opt ReplaceExitValue( "replexitval", cl::Hidden, cl::init(OnlyCheapRepl), cl::desc("Choose the strategy to replace exit value in IndVarSimplify"), cl::values(clEnumValN(NeverRepl, "never", "never replace exit value"), clEnumValN(OnlyCheapRepl, "cheap", "only replace exit value when the cost is cheap"), clEnumValN(AlwaysRepl, "always", "always replace exit value whenever possible"), clEnumValEnd)); namespace { struct RewritePhi; class IndVarSimplify { LoopInfo *LI; ScalarEvolution *SE; DominatorTree *DT; const DataLayout &DL; TargetLibraryInfo *TLI; const TargetTransformInfo *TTI; SmallVector DeadInsts; bool Changed = false; bool isValidRewrite(Value *FromVal, Value *ToVal); void handleFloatingPointIV(Loop *L, PHINode *PH); void rewriteNonIntegerIVs(Loop *L); void simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI); bool canLoopBeDeleted(Loop *L, SmallVector &RewritePhiSet); void rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter); void rewriteFirstIterationLoopExitValues(Loop *L); Value *linearFunctionTestReplace(Loop *L, const SCEV *BackedgeTakenCount, PHINode *IndVar, SCEVExpander &Rewriter); void sinkUnusedInvariants(Loop *L); Value *expandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S, Loop *L, Instruction *InsertPt, Type *Ty); public: IndVarSimplify(LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, const DataLayout &DL, TargetLibraryInfo *TLI, TargetTransformInfo *TTI) : LI(LI), SE(SE), DT(DT), DL(DL), TLI(TLI), TTI(TTI) {} bool run(Loop *L); }; } /// Return true if the SCEV expansion generated by the rewriter can replace the /// original value. SCEV guarantees that it produces the same value, but the way /// it is produced may be illegal IR. Ideally, this function will only be /// called for verification. bool IndVarSimplify::isValidRewrite(Value *FromVal, Value *ToVal) { // If an SCEV expression subsumed multiple pointers, its expansion could // reassociate the GEP changing the base pointer. This is illegal because the // final address produced by a GEP chain must be inbounds relative to its // underlying object. Otherwise basic alias analysis, among other things, // could fail in a dangerous way. Ultimately, SCEV will be improved to avoid // producing an expression involving multiple pointers. Until then, we must // bail out here. // // Retrieve the pointer operand of the GEP. Don't use GetUnderlyingObject // because it understands lcssa phis while SCEV does not. Value *FromPtr = FromVal; Value *ToPtr = ToVal; if (auto *GEP = dyn_cast(FromVal)) { FromPtr = GEP->getPointerOperand(); } if (auto *GEP = dyn_cast(ToVal)) { ToPtr = GEP->getPointerOperand(); } if (FromPtr != FromVal || ToPtr != ToVal) { // Quickly check the common case if (FromPtr == ToPtr) return true; // SCEV may have rewritten an expression that produces the GEP's pointer // operand. That's ok as long as the pointer operand has the same base // pointer. Unlike GetUnderlyingObject(), getPointerBase() will find the // base of a recurrence. This handles the case in which SCEV expansion // converts a pointer type recurrence into a nonrecurrent pointer base // indexed by an integer recurrence. // If the GEP base pointer is a vector of pointers, abort. if (!FromPtr->getType()->isPointerTy() || !ToPtr->getType()->isPointerTy()) return false; const SCEV *FromBase = SE->getPointerBase(SE->getSCEV(FromPtr)); const SCEV *ToBase = SE->getPointerBase(SE->getSCEV(ToPtr)); if (FromBase == ToBase) return true; DEBUG(dbgs() << "INDVARS: GEP rewrite bail out " << *FromBase << " != " << *ToBase << "\n"); return false; } return true; } /// Determine the insertion point for this user. By default, insert immediately /// before the user. SCEVExpander or LICM will hoist loop invariants out of the /// loop. For PHI nodes, there may be multiple uses, so compute the nearest /// common dominator for the incoming blocks. static Instruction *getInsertPointForUses(Instruction *User, Value *Def, DominatorTree *DT, LoopInfo *LI) { PHINode *PHI = dyn_cast(User); if (!PHI) return User; Instruction *InsertPt = nullptr; for (unsigned i = 0, e = PHI->getNumIncomingValues(); i != e; ++i) { if (PHI->getIncomingValue(i) != Def) continue; BasicBlock *InsertBB = PHI->getIncomingBlock(i); if (!InsertPt) { InsertPt = InsertBB->getTerminator(); continue; } InsertBB = DT->findNearestCommonDominator(InsertPt->getParent(), InsertBB); InsertPt = InsertBB->getTerminator(); } assert(InsertPt && "Missing phi operand"); auto *DefI = dyn_cast(Def); if (!DefI) return InsertPt; assert(DT->dominates(DefI, InsertPt) && "def does not dominate all uses"); auto *L = LI->getLoopFor(DefI->getParent()); assert(!L || L->contains(LI->getLoopFor(InsertPt->getParent()))); for (auto *DTN = (*DT)[InsertPt->getParent()]; DTN; DTN = DTN->getIDom()) if (LI->getLoopFor(DTN->getBlock()) == L) return DTN->getBlock()->getTerminator(); llvm_unreachable("DefI dominates InsertPt!"); } //===----------------------------------------------------------------------===// // rewriteNonIntegerIVs and helpers. Prefer integer IVs. //===----------------------------------------------------------------------===// /// Convert APF to an integer, if possible. static bool ConvertToSInt(const APFloat &APF, int64_t &IntVal) { bool isExact = false; // See if we can convert this to an int64_t uint64_t UIntVal; if (APF.convertToInteger(&UIntVal, 64, true, APFloat::rmTowardZero, &isExact) != APFloat::opOK || !isExact) return false; IntVal = UIntVal; return true; } /// If the loop has floating induction variable then insert corresponding /// integer induction variable if possible. /// For example, /// for(double i = 0; i < 10000; ++i) /// bar(i) /// is converted into /// for(int i = 0; i < 10000; ++i) /// bar((double)i); /// void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0)); unsigned BackEdge = IncomingEdge^1; // Check incoming value. auto *InitValueVal = dyn_cast(PN->getIncomingValue(IncomingEdge)); int64_t InitValue; if (!InitValueVal || !ConvertToSInt(InitValueVal->getValueAPF(), InitValue)) return; // Check IV increment. Reject this PN if increment operation is not // an add or increment value can not be represented by an integer. auto *Incr = dyn_cast(PN->getIncomingValue(BackEdge)); if (Incr == nullptr || Incr->getOpcode() != Instruction::FAdd) return; // If this is not an add of the PHI with a constantfp, or if the constant fp // is not an integer, bail out. ConstantFP *IncValueVal = dyn_cast(Incr->getOperand(1)); int64_t IncValue; if (IncValueVal == nullptr || Incr->getOperand(0) != PN || !ConvertToSInt(IncValueVal->getValueAPF(), IncValue)) return; // Check Incr uses. One user is PN and the other user is an exit condition // used by the conditional terminator. Value::user_iterator IncrUse = Incr->user_begin(); Instruction *U1 = cast(*IncrUse++); if (IncrUse == Incr->user_end()) return; Instruction *U2 = cast(*IncrUse++); if (IncrUse != Incr->user_end()) return; // Find exit condition, which is an fcmp. If it doesn't exist, or if it isn't // only used by a branch, we can't transform it. FCmpInst *Compare = dyn_cast(U1); if (!Compare) Compare = dyn_cast(U2); if (!Compare || !Compare->hasOneUse() || !isa(Compare->user_back())) return; BranchInst *TheBr = cast(Compare->user_back()); // We need to verify that the branch actually controls the iteration count // of the loop. If not, the new IV can overflow and no one will notice. // The branch block must be in the loop and one of the successors must be out // of the loop. assert(TheBr->isConditional() && "Can't use fcmp if not conditional"); if (!L->contains(TheBr->getParent()) || (L->contains(TheBr->getSuccessor(0)) && L->contains(TheBr->getSuccessor(1)))) return; // If it isn't a comparison with an integer-as-fp (the exit value), we can't // transform it. ConstantFP *ExitValueVal = dyn_cast(Compare->getOperand(1)); int64_t ExitValue; if (ExitValueVal == nullptr || !ConvertToSInt(ExitValueVal->getValueAPF(), ExitValue)) return; // Find new predicate for integer comparison. CmpInst::Predicate NewPred = CmpInst::BAD_ICMP_PREDICATE; switch (Compare->getPredicate()) { default: return; // Unknown comparison. case CmpInst::FCMP_OEQ: case CmpInst::FCMP_UEQ: NewPred = CmpInst::ICMP_EQ; break; case CmpInst::FCMP_ONE: case CmpInst::FCMP_UNE: NewPred = CmpInst::ICMP_NE; break; case CmpInst::FCMP_OGT: case CmpInst::FCMP_UGT: NewPred = CmpInst::ICMP_SGT; break; case CmpInst::FCMP_OGE: case CmpInst::FCMP_UGE: NewPred = CmpInst::ICMP_SGE; break; case CmpInst::FCMP_OLT: case CmpInst::FCMP_ULT: NewPred = CmpInst::ICMP_SLT; break; case CmpInst::FCMP_OLE: case CmpInst::FCMP_ULE: NewPred = CmpInst::ICMP_SLE; break; } // We convert the floating point induction variable to a signed i32 value if // we can. This is only safe if the comparison will not overflow in a way // that won't be trapped by the integer equivalent operations. Check for this // now. // TODO: We could use i64 if it is native and the range requires it. // The start/stride/exit values must all fit in signed i32. if (!isInt<32>(InitValue) || !isInt<32>(IncValue) || !isInt<32>(ExitValue)) return; // If not actually striding (add x, 0.0), avoid touching the code. if (IncValue == 0) return; // Positive and negative strides have different safety conditions. if (IncValue > 0) { // If we have a positive stride, we require the init to be less than the // exit value. if (InitValue >= ExitValue) return; uint32_t Range = uint32_t(ExitValue-InitValue); // Check for infinite loop, either: // while (i <= Exit) or until (i > Exit) if (NewPred == CmpInst::ICMP_SLE || NewPred == CmpInst::ICMP_SGT) { if (++Range == 0) return; // Range overflows. } unsigned Leftover = Range % uint32_t(IncValue); // If this is an equality comparison, we require that the strided value // exactly land on the exit value, otherwise the IV condition will wrap // around and do things the fp IV wouldn't. if ((NewPred == CmpInst::ICMP_EQ || NewPred == CmpInst::ICMP_NE) && Leftover != 0) return; // If the stride would wrap around the i32 before exiting, we can't // transform the IV. if (Leftover != 0 && int32_t(ExitValue+IncValue) < ExitValue) return; } else { // If we have a negative stride, we require the init to be greater than the // exit value. if (InitValue <= ExitValue) return; uint32_t Range = uint32_t(InitValue-ExitValue); // Check for infinite loop, either: // while (i >= Exit) or until (i < Exit) if (NewPred == CmpInst::ICMP_SGE || NewPred == CmpInst::ICMP_SLT) { if (++Range == 0) return; // Range overflows. } unsigned Leftover = Range % uint32_t(-IncValue); // If this is an equality comparison, we require that the strided value // exactly land on the exit value, otherwise the IV condition will wrap // around and do things the fp IV wouldn't. if ((NewPred == CmpInst::ICMP_EQ || NewPred == CmpInst::ICMP_NE) && Leftover != 0) return; // If the stride would wrap around the i32 before exiting, we can't // transform the IV. if (Leftover != 0 && int32_t(ExitValue+IncValue) > ExitValue) return; } IntegerType *Int32Ty = Type::getInt32Ty(PN->getContext()); // Insert new integer induction variable. PHINode *NewPHI = PHINode::Create(Int32Ty, 2, PN->getName()+".int", PN); NewPHI->addIncoming(ConstantInt::get(Int32Ty, InitValue), PN->getIncomingBlock(IncomingEdge)); Value *NewAdd = BinaryOperator::CreateAdd(NewPHI, ConstantInt::get(Int32Ty, IncValue), Incr->getName()+".int", Incr); NewPHI->addIncoming(NewAdd, PN->getIncomingBlock(BackEdge)); ICmpInst *NewCompare = new ICmpInst(TheBr, NewPred, NewAdd, ConstantInt::get(Int32Ty, ExitValue), Compare->getName()); // In the following deletions, PN may become dead and may be deleted. // Use a WeakVH to observe whether this happens. WeakVH WeakPH = PN; // Delete the old floating point exit comparison. The branch starts using the // new comparison. NewCompare->takeName(Compare); Compare->replaceAllUsesWith(NewCompare); RecursivelyDeleteTriviallyDeadInstructions(Compare, TLI); // Delete the old floating point increment. Incr->replaceAllUsesWith(UndefValue::get(Incr->getType())); RecursivelyDeleteTriviallyDeadInstructions(Incr, TLI); // If the FP induction variable still has uses, this is because something else // in the loop uses its value. In order to canonicalize the induction // variable, we chose to eliminate the IV and rewrite it in terms of an // int->fp cast. // // We give preference to sitofp over uitofp because it is faster on most // platforms. if (WeakPH) { Value *Conv = new SIToFPInst(NewPHI, PN->getType(), "indvar.conv", &*PN->getParent()->getFirstInsertionPt()); PN->replaceAllUsesWith(Conv); RecursivelyDeleteTriviallyDeadInstructions(PN, TLI); } Changed = true; } void IndVarSimplify::rewriteNonIntegerIVs(Loop *L) { // First step. Check to see if there are any floating-point recurrences. // If there are, change them into integer recurrences, permitting analysis by // the SCEV routines. // BasicBlock *Header = L->getHeader(); SmallVector PHIs; for (BasicBlock::iterator I = Header->begin(); PHINode *PN = dyn_cast(I); ++I) PHIs.push_back(PN); for (unsigned i = 0, e = PHIs.size(); i != e; ++i) if (PHINode *PN = dyn_cast_or_null(&*PHIs[i])) handleFloatingPointIV(L, PN); // If the loop previously had floating-point IV, ScalarEvolution // may not have been able to compute a trip count. Now that we've done some // re-writing, the trip count may be computable. if (Changed) SE->forgetLoop(L); } namespace { // Collect information about PHI nodes which can be transformed in // rewriteLoopExitValues. struct RewritePhi { PHINode *PN; unsigned Ith; // Ith incoming value. Value *Val; // Exit value after expansion. bool HighCost; // High Cost when expansion. RewritePhi(PHINode *P, unsigned I, Value *V, bool H) : PN(P), Ith(I), Val(V), HighCost(H) {} }; } Value *IndVarSimplify::expandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S, Loop *L, Instruction *InsertPt, Type *ResultTy) { // Before expanding S into an expensive LLVM expression, see if we can use an // already existing value as the expansion for S. - if (Value *ExistingValue = Rewriter.findExistingExpansion(S, InsertPt, L)) + if (Value *ExistingValue = Rewriter.getExactExistingExpansion(S, InsertPt, L)) if (ExistingValue->getType() == ResultTy) return ExistingValue; // We didn't find anything, fall back to using SCEVExpander. return Rewriter.expandCodeFor(S, ResultTy, InsertPt); } //===----------------------------------------------------------------------===// // rewriteLoopExitValues - Optimize IV users outside the loop. // As a side effect, reduces the amount of IV processing within the loop. //===----------------------------------------------------------------------===// /// Check to see if this loop has a computable loop-invariant execution count. /// If so, this means that we can compute the final value of any expressions /// that are recurrent in the loop, and substitute the exit values from the loop /// into any instructions outside of the loop that use the final values of the /// current expressions. /// /// This is mostly redundant with the regular IndVarSimplify activities that /// happen later, except that it's more powerful in some cases, because it's /// able to brute-force evaluate arbitrary instructions as long as they have /// constant operands at the beginning of the loop. void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { // Check a pre-condition. assert(L->isRecursivelyLCSSAForm(*DT) && "Indvars did not preserve LCSSA!"); SmallVector ExitBlocks; L->getUniqueExitBlocks(ExitBlocks); SmallVector RewritePhiSet; // Find all values that are computed inside the loop, but used outside of it. // Because of LCSSA, these values will only occur in LCSSA PHI Nodes. Scan // the exit blocks of the loop to find them. for (BasicBlock *ExitBB : ExitBlocks) { // If there are no PHI nodes in this exit block, then no values defined // inside the loop are used on this path, skip it. PHINode *PN = dyn_cast(ExitBB->begin()); if (!PN) continue; unsigned NumPreds = PN->getNumIncomingValues(); // Iterate over all of the PHI nodes. BasicBlock::iterator BBI = ExitBB->begin(); while ((PN = dyn_cast(BBI++))) { if (PN->use_empty()) continue; // dead use, don't replace it if (!SE->isSCEVable(PN->getType())) continue; // It's necessary to tell ScalarEvolution about this explicitly so that // it can walk the def-use list and forget all SCEVs, as it may not be // watching the PHI itself. Once the new exit value is in place, there // may not be a def-use connection between the loop and every instruction // which got a SCEVAddRecExpr for that loop. SE->forgetValue(PN); // Iterate over all of the values in all the PHI nodes. for (unsigned i = 0; i != NumPreds; ++i) { // If the value being merged in is not integer or is not defined // in the loop, skip it. Value *InVal = PN->getIncomingValue(i); if (!isa(InVal)) continue; // If this pred is for a subloop, not L itself, skip it. if (LI->getLoopFor(PN->getIncomingBlock(i)) != L) continue; // The Block is in a subloop, skip it. // Check that InVal is defined in the loop. Instruction *Inst = cast(InVal); if (!L->contains(Inst)) continue; // Okay, this instruction has a user outside of the current loop // and varies predictably *inside* the loop. Evaluate the value it // contains when the loop exits, if possible. const SCEV *ExitValue = SE->getSCEVAtScope(Inst, L->getParentLoop()); if (!SE->isLoopInvariant(ExitValue, L) || !isSafeToExpand(ExitValue, *SE)) continue; // Computing the value outside of the loop brings no benefit if : // - it is definitely used inside the loop in a way which can not be // optimized away. // - no use outside of the loop can take advantage of hoisting the // computation out of the loop if (ExitValue->getSCEVType()>=scMulExpr) { unsigned NumHardInternalUses = 0; unsigned NumSoftExternalUses = 0; unsigned NumUses = 0; for (auto IB = Inst->user_begin(), IE = Inst->user_end(); IB != IE && NumUses <= 6; ++IB) { Instruction *UseInstr = cast(*IB); unsigned Opc = UseInstr->getOpcode(); NumUses++; if (L->contains(UseInstr)) { if (Opc == Instruction::Call || Opc == Instruction::Ret) NumHardInternalUses++; } else { if (Opc == Instruction::PHI) { // Do not count the Phi as a use. LCSSA may have inserted // plenty of trivial ones. NumUses--; for (auto PB = UseInstr->user_begin(), PE = UseInstr->user_end(); PB != PE && NumUses <= 6; ++PB, ++NumUses) { unsigned PhiOpc = cast(*PB)->getOpcode(); if (PhiOpc != Instruction::Call && PhiOpc != Instruction::Ret) NumSoftExternalUses++; } continue; } if (Opc != Instruction::Call && Opc != Instruction::Ret) NumSoftExternalUses++; } } if (NumUses <= 6 && NumHardInternalUses && !NumSoftExternalUses) continue; } bool HighCost = Rewriter.isHighCostExpansion(ExitValue, L, Inst); Value *ExitVal = expandSCEVIfNeeded(Rewriter, ExitValue, L, Inst, PN->getType()); DEBUG(dbgs() << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal << '\n' << " LoopVal = " << *Inst << "\n"); if (!isValidRewrite(Inst, ExitVal)) { DeadInsts.push_back(ExitVal); continue; } // Collect all the candidate PHINodes to be rewritten. RewritePhiSet.emplace_back(PN, i, ExitVal, HighCost); } } } bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet); // Transformation. for (const RewritePhi &Phi : RewritePhiSet) { PHINode *PN = Phi.PN; Value *ExitVal = Phi.Val; // Only do the rewrite when the ExitValue can be expanded cheaply. // If LoopCanBeDel is true, rewrite exit value aggressively. if (ReplaceExitValue == OnlyCheapRepl && !LoopCanBeDel && Phi.HighCost) { DeadInsts.push_back(ExitVal); continue; } Changed = true; ++NumReplaced; Instruction *Inst = cast(PN->getIncomingValue(Phi.Ith)); PN->setIncomingValue(Phi.Ith, ExitVal); // If this instruction is dead now, delete it. Don't do it now to avoid // invalidating iterators. if (isInstructionTriviallyDead(Inst, TLI)) DeadInsts.push_back(Inst); // Replace PN with ExitVal if that is legal and does not break LCSSA. if (PN->getNumIncomingValues() == 1 && LI->replacementPreservesLCSSAForm(PN, ExitVal)) { PN->replaceAllUsesWith(ExitVal); PN->eraseFromParent(); } } // The insertion point instruction may have been deleted; clear it out // so that the rewriter doesn't trip over it later. Rewriter.clearInsertPoint(); } //===---------------------------------------------------------------------===// // rewriteFirstIterationLoopExitValues: Rewrite loop exit values if we know // they will exit at the first iteration. //===---------------------------------------------------------------------===// /// Check to see if this loop has loop invariant conditions which lead to loop /// exits. If so, we know that if the exit path is taken, it is at the first /// loop iteration. This lets us predict exit values of PHI nodes that live in /// loop header. void IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { // Verify the input to the pass is already in LCSSA form. assert(L->isLCSSAForm(*DT)); SmallVector ExitBlocks; L->getUniqueExitBlocks(ExitBlocks); auto *LoopHeader = L->getHeader(); assert(LoopHeader && "Invalid loop"); for (auto *ExitBB : ExitBlocks) { BasicBlock::iterator BBI = ExitBB->begin(); // If there are no more PHI nodes in this exit block, then no more // values defined inside the loop are used on this path. while (auto *PN = dyn_cast(BBI++)) { for (unsigned IncomingValIdx = 0, E = PN->getNumIncomingValues(); IncomingValIdx != E; ++IncomingValIdx) { auto *IncomingBB = PN->getIncomingBlock(IncomingValIdx); // We currently only support loop exits from loop header. If the // incoming block is not loop header, we need to recursively check // all conditions starting from loop header are loop invariants. // Additional support might be added in the future. if (IncomingBB != LoopHeader) continue; // Get condition that leads to the exit path. auto *TermInst = IncomingBB->getTerminator(); Value *Cond = nullptr; if (auto *BI = dyn_cast(TermInst)) { // Must be a conditional branch, otherwise the block // should not be in the loop. Cond = BI->getCondition(); } else if (auto *SI = dyn_cast(TermInst)) Cond = SI->getCondition(); else continue; if (!L->isLoopInvariant(Cond)) continue; auto *ExitVal = dyn_cast(PN->getIncomingValue(IncomingValIdx)); // Only deal with PHIs. if (!ExitVal) continue; // If ExitVal is a PHI on the loop header, then we know its // value along this exit because the exit can only be taken // on the first iteration. auto *LoopPreheader = L->getLoopPreheader(); assert(LoopPreheader && "Invalid loop"); int PreheaderIdx = ExitVal->getBasicBlockIndex(LoopPreheader); if (PreheaderIdx != -1) { assert(ExitVal->getParent() == LoopHeader && "ExitVal must be in loop header"); PN->setIncomingValue(IncomingValIdx, ExitVal->getIncomingValue(PreheaderIdx)); } } } } } /// Check whether it is possible to delete the loop after rewriting exit /// value. If it is possible, ignore ReplaceExitValue and do rewriting /// aggressively. bool IndVarSimplify::canLoopBeDeleted( Loop *L, SmallVector &RewritePhiSet) { BasicBlock *Preheader = L->getLoopPreheader(); // If there is no preheader, the loop will not be deleted. if (!Preheader) return false; // In LoopDeletion pass Loop can be deleted when ExitingBlocks.size() > 1. // We obviate multiple ExitingBlocks case for simplicity. // TODO: If we see testcase with multiple ExitingBlocks can be deleted // after exit value rewriting, we can enhance the logic here. SmallVector ExitingBlocks; L->getExitingBlocks(ExitingBlocks); SmallVector ExitBlocks; L->getUniqueExitBlocks(ExitBlocks); if (ExitBlocks.size() > 1 || ExitingBlocks.size() > 1) return false; BasicBlock *ExitBlock = ExitBlocks[0]; BasicBlock::iterator BI = ExitBlock->begin(); while (PHINode *P = dyn_cast(BI)) { Value *Incoming = P->getIncomingValueForBlock(ExitingBlocks[0]); // If the Incoming value of P is found in RewritePhiSet, we know it // could be rewritten to use a loop invariant value in transformation // phase later. Skip it in the loop invariant check below. bool found = false; for (const RewritePhi &Phi : RewritePhiSet) { unsigned i = Phi.Ith; if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) { found = true; break; } } Instruction *I; if (!found && (I = dyn_cast(Incoming))) if (!L->hasLoopInvariantOperands(I)) return false; ++BI; } for (auto *BB : L->blocks()) if (any_of(*BB, [](Instruction &I) { return I.mayHaveSideEffects(); })) return false; return true; } //===----------------------------------------------------------------------===// // IV Widening - Extend the width of an IV to cover its widest uses. //===----------------------------------------------------------------------===// namespace { // Collect information about induction variables that are used by sign/zero // extend operations. This information is recorded by CollectExtend and provides // the input to WidenIV. struct WideIVInfo { PHINode *NarrowIV = nullptr; Type *WidestNativeType = nullptr; // Widest integer type created [sz]ext bool IsSigned = false; // Was a sext user seen before a zext? }; } /// Update information about the induction variable that is extended by this /// sign or zero extend operation. This is used to determine the final width of /// the IV before actually widening it. static void visitIVCast(CastInst *Cast, WideIVInfo &WI, ScalarEvolution *SE, const TargetTransformInfo *TTI) { bool IsSigned = Cast->getOpcode() == Instruction::SExt; if (!IsSigned && Cast->getOpcode() != Instruction::ZExt) return; Type *Ty = Cast->getType(); uint64_t Width = SE->getTypeSizeInBits(Ty); if (!Cast->getModule()->getDataLayout().isLegalInteger(Width)) return; // Check that `Cast` actually extends the induction variable (we rely on this // later). This takes care of cases where `Cast` is extending a truncation of // the narrow induction variable, and thus can end up being narrower than the // "narrow" induction variable. uint64_t NarrowIVWidth = SE->getTypeSizeInBits(WI.NarrowIV->getType()); if (NarrowIVWidth >= Width) return; // Cast is either an sext or zext up to this point. // We should not widen an indvar if arithmetics on the wider indvar are more // expensive than those on the narrower indvar. We check only the cost of ADD // because at least an ADD is required to increment the induction variable. We // could compute more comprehensively the cost of all instructions on the // induction variable when necessary. if (TTI && TTI->getArithmeticInstrCost(Instruction::Add, Ty) > TTI->getArithmeticInstrCost(Instruction::Add, Cast->getOperand(0)->getType())) { return; } if (!WI.WidestNativeType) { WI.WidestNativeType = SE->getEffectiveSCEVType(Ty); WI.IsSigned = IsSigned; return; } // We extend the IV to satisfy the sign of its first user, arbitrarily. if (WI.IsSigned != IsSigned) return; if (Width > SE->getTypeSizeInBits(WI.WidestNativeType)) WI.WidestNativeType = SE->getEffectiveSCEVType(Ty); } namespace { /// Record a link in the Narrow IV def-use chain along with the WideIV that /// computes the same value as the Narrow IV def. This avoids caching Use* /// pointers. struct NarrowIVDefUse { Instruction *NarrowDef = nullptr; Instruction *NarrowUse = nullptr; Instruction *WideDef = nullptr; // True if the narrow def is never negative. Tracking this information lets // us use a sign extension instead of a zero extension or vice versa, when // profitable and legal. bool NeverNegative = false; NarrowIVDefUse(Instruction *ND, Instruction *NU, Instruction *WD, bool NeverNegative) : NarrowDef(ND), NarrowUse(NU), WideDef(WD), NeverNegative(NeverNegative) {} }; /// The goal of this transform is to remove sign and zero extends without /// creating any new induction variables. To do this, it creates a new phi of /// the wider type and redirects all users, either removing extends or inserting /// truncs whenever we stop propagating the type. /// class WidenIV { // Parameters PHINode *OrigPhi; Type *WideType; bool IsSigned; // Context LoopInfo *LI; Loop *L; ScalarEvolution *SE; DominatorTree *DT; // Result PHINode *WidePhi; Instruction *WideInc; const SCEV *WideIncExpr; SmallVectorImpl &DeadInsts; SmallPtrSet Widened; SmallVector NarrowIVUsers; public: WidenIV(const WideIVInfo &WI, LoopInfo *LInfo, ScalarEvolution *SEv, DominatorTree *DTree, SmallVectorImpl &DI) : OrigPhi(WI.NarrowIV), WideType(WI.WidestNativeType), IsSigned(WI.IsSigned), LI(LInfo), L(LI->getLoopFor(OrigPhi->getParent())), SE(SEv), DT(DTree), WidePhi(nullptr), WideInc(nullptr), WideIncExpr(nullptr), DeadInsts(DI) { assert(L->getHeader() == OrigPhi->getParent() && "Phi must be an IV"); } PHINode *createWideIV(SCEVExpander &Rewriter); protected: Value *createExtendInst(Value *NarrowOper, Type *WideType, bool IsSigned, Instruction *Use); Instruction *cloneIVUser(NarrowIVDefUse DU, const SCEVAddRecExpr *WideAR); Instruction *cloneArithmeticIVUser(NarrowIVDefUse DU, const SCEVAddRecExpr *WideAR); Instruction *cloneBitwiseIVUser(NarrowIVDefUse DU); const SCEVAddRecExpr *getWideRecurrence(Instruction *NarrowUse); const SCEVAddRecExpr* getExtendedOperandRecurrence(NarrowIVDefUse DU); const SCEV *getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS, unsigned OpCode) const; Instruction *widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter); bool widenLoopCompare(NarrowIVDefUse DU); void pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef); }; } // anonymous namespace /// Perform a quick domtree based check for loop invariance assuming that V is /// used within the loop. LoopInfo::isLoopInvariant() seems gratuitous for this /// purpose. static bool isLoopInvariant(Value *V, const Loop *L, const DominatorTree *DT) { Instruction *Inst = dyn_cast(V); if (!Inst) return true; return DT->properlyDominates(Inst->getParent(), L->getHeader()); } Value *WidenIV::createExtendInst(Value *NarrowOper, Type *WideType, bool IsSigned, Instruction *Use) { // Set the debug location and conservative insertion point. IRBuilder<> Builder(Use); // Hoist the insertion point into loop preheaders as far as possible. for (const Loop *L = LI->getLoopFor(Use->getParent()); L && L->getLoopPreheader() && isLoopInvariant(NarrowOper, L, DT); L = L->getParentLoop()) Builder.SetInsertPoint(L->getLoopPreheader()->getTerminator()); return IsSigned ? Builder.CreateSExt(NarrowOper, WideType) : Builder.CreateZExt(NarrowOper, WideType); } /// Instantiate a wide operation to replace a narrow operation. This only needs /// to handle operations that can evaluation to SCEVAddRec. It can safely return /// 0 for any operation we decide not to clone. Instruction *WidenIV::cloneIVUser(NarrowIVDefUse DU, const SCEVAddRecExpr *WideAR) { unsigned Opcode = DU.NarrowUse->getOpcode(); switch (Opcode) { default: return nullptr; case Instruction::Add: case Instruction::Mul: case Instruction::UDiv: case Instruction::Sub: return cloneArithmeticIVUser(DU, WideAR); case Instruction::And: case Instruction::Or: case Instruction::Xor: case Instruction::Shl: case Instruction::LShr: case Instruction::AShr: return cloneBitwiseIVUser(DU); } } Instruction *WidenIV::cloneBitwiseIVUser(NarrowIVDefUse DU) { Instruction *NarrowUse = DU.NarrowUse; Instruction *NarrowDef = DU.NarrowDef; Instruction *WideDef = DU.WideDef; DEBUG(dbgs() << "Cloning bitwise IVUser: " << *NarrowUse << "\n"); // Replace NarrowDef operands with WideDef. Otherwise, we don't know anything // about the narrow operand yet so must insert a [sz]ext. It is probably loop // invariant and will be folded or hoisted. If it actually comes from a // widened IV, it should be removed during a future call to widenIVUse. Value *LHS = (NarrowUse->getOperand(0) == NarrowDef) ? WideDef : createExtendInst(NarrowUse->getOperand(0), WideType, IsSigned, NarrowUse); Value *RHS = (NarrowUse->getOperand(1) == NarrowDef) ? WideDef : createExtendInst(NarrowUse->getOperand(1), WideType, IsSigned, NarrowUse); auto *NarrowBO = cast(NarrowUse); auto *WideBO = BinaryOperator::Create(NarrowBO->getOpcode(), LHS, RHS, NarrowBO->getName()); IRBuilder<> Builder(NarrowUse); Builder.Insert(WideBO); WideBO->copyIRFlags(NarrowBO); return WideBO; } Instruction *WidenIV::cloneArithmeticIVUser(NarrowIVDefUse DU, const SCEVAddRecExpr *WideAR) { Instruction *NarrowUse = DU.NarrowUse; Instruction *NarrowDef = DU.NarrowDef; Instruction *WideDef = DU.WideDef; DEBUG(dbgs() << "Cloning arithmetic IVUser: " << *NarrowUse << "\n"); unsigned IVOpIdx = (NarrowUse->getOperand(0) == NarrowDef) ? 0 : 1; // We're trying to find X such that // // Widen(NarrowDef `op` NonIVNarrowDef) == WideAR == WideDef `op.wide` X // // We guess two solutions to X, sext(NonIVNarrowDef) and zext(NonIVNarrowDef), // and check using SCEV if any of them are correct. // Returns true if extending NonIVNarrowDef according to `SignExt` is a // correct solution to X. auto GuessNonIVOperand = [&](bool SignExt) { const SCEV *WideLHS; const SCEV *WideRHS; auto GetExtend = [this, SignExt](const SCEV *S, Type *Ty) { if (SignExt) return SE->getSignExtendExpr(S, Ty); return SE->getZeroExtendExpr(S, Ty); }; if (IVOpIdx == 0) { WideLHS = SE->getSCEV(WideDef); const SCEV *NarrowRHS = SE->getSCEV(NarrowUse->getOperand(1)); WideRHS = GetExtend(NarrowRHS, WideType); } else { const SCEV *NarrowLHS = SE->getSCEV(NarrowUse->getOperand(0)); WideLHS = GetExtend(NarrowLHS, WideType); WideRHS = SE->getSCEV(WideDef); } // WideUse is "WideDef `op.wide` X" as described in the comment. const SCEV *WideUse = nullptr; switch (NarrowUse->getOpcode()) { default: llvm_unreachable("No other possibility!"); case Instruction::Add: WideUse = SE->getAddExpr(WideLHS, WideRHS); break; case Instruction::Mul: WideUse = SE->getMulExpr(WideLHS, WideRHS); break; case Instruction::UDiv: WideUse = SE->getUDivExpr(WideLHS, WideRHS); break; case Instruction::Sub: WideUse = SE->getMinusSCEV(WideLHS, WideRHS); break; } return WideUse == WideAR; }; bool SignExtend = IsSigned; if (!GuessNonIVOperand(SignExtend)) { SignExtend = !SignExtend; if (!GuessNonIVOperand(SignExtend)) return nullptr; } Value *LHS = (NarrowUse->getOperand(0) == NarrowDef) ? WideDef : createExtendInst(NarrowUse->getOperand(0), WideType, SignExtend, NarrowUse); Value *RHS = (NarrowUse->getOperand(1) == NarrowDef) ? WideDef : createExtendInst(NarrowUse->getOperand(1), WideType, SignExtend, NarrowUse); auto *NarrowBO = cast(NarrowUse); auto *WideBO = BinaryOperator::Create(NarrowBO->getOpcode(), LHS, RHS, NarrowBO->getName()); IRBuilder<> Builder(NarrowUse); Builder.Insert(WideBO); WideBO->copyIRFlags(NarrowBO); return WideBO; } const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS, unsigned OpCode) const { if (OpCode == Instruction::Add) return SE->getAddExpr(LHS, RHS); if (OpCode == Instruction::Sub) return SE->getMinusSCEV(LHS, RHS); if (OpCode == Instruction::Mul) return SE->getMulExpr(LHS, RHS); llvm_unreachable("Unsupported opcode."); } /// No-wrap operations can transfer sign extension of their result to their /// operands. Generate the SCEV value for the widened operation without /// actually modifying the IR yet. If the expression after extending the /// operands is an AddRec for this loop, return it. const SCEVAddRecExpr* WidenIV::getExtendedOperandRecurrence(NarrowIVDefUse DU) { // Handle the common case of add const unsigned OpCode = DU.NarrowUse->getOpcode(); // Only Add/Sub/Mul instructions supported yet. if (OpCode != Instruction::Add && OpCode != Instruction::Sub && OpCode != Instruction::Mul) return nullptr; // One operand (NarrowDef) has already been extended to WideDef. Now determine // if extending the other will lead to a recurrence. const unsigned ExtendOperIdx = DU.NarrowUse->getOperand(0) == DU.NarrowDef ? 1 : 0; assert(DU.NarrowUse->getOperand(1-ExtendOperIdx) == DU.NarrowDef && "bad DU"); const SCEV *ExtendOperExpr = nullptr; const OverflowingBinaryOperator *OBO = cast(DU.NarrowUse); if (IsSigned && OBO->hasNoSignedWrap()) ExtendOperExpr = SE->getSignExtendExpr( SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx)), WideType); else if(!IsSigned && OBO->hasNoUnsignedWrap()) ExtendOperExpr = SE->getZeroExtendExpr( SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx)), WideType); else return nullptr; // When creating this SCEV expr, don't apply the current operations NSW or NUW // flags. This instruction may be guarded by control flow that the no-wrap // behavior depends on. Non-control-equivalent instructions can be mapped to // the same SCEV expression, and it would be incorrect to transfer NSW/NUW // semantics to those operations. const SCEV *lhs = SE->getSCEV(DU.WideDef); const SCEV *rhs = ExtendOperExpr; // Let's swap operands to the initial order for the case of non-commutative // operations, like SUB. See PR21014. if (ExtendOperIdx == 0) std::swap(lhs, rhs); const SCEVAddRecExpr *AddRec = dyn_cast(getSCEVByOpCode(lhs, rhs, OpCode)); if (!AddRec || AddRec->getLoop() != L) return nullptr; return AddRec; } /// Is this instruction potentially interesting for further simplification after /// widening it's type? In other words, can the extend be safely hoisted out of /// the loop with SCEV reducing the value to a recurrence on the same loop. If /// so, return the sign or zero extended recurrence. Otherwise return NULL. const SCEVAddRecExpr *WidenIV::getWideRecurrence(Instruction *NarrowUse) { if (!SE->isSCEVable(NarrowUse->getType())) return nullptr; const SCEV *NarrowExpr = SE->getSCEV(NarrowUse); if (SE->getTypeSizeInBits(NarrowExpr->getType()) >= SE->getTypeSizeInBits(WideType)) { // NarrowUse implicitly widens its operand. e.g. a gep with a narrow // index. So don't follow this use. return nullptr; } const SCEV *WideExpr = IsSigned ? SE->getSignExtendExpr(NarrowExpr, WideType) : SE->getZeroExtendExpr(NarrowExpr, WideType); const SCEVAddRecExpr *AddRec = dyn_cast(WideExpr); if (!AddRec || AddRec->getLoop() != L) return nullptr; return AddRec; } /// This IV user cannot be widen. Replace this use of the original narrow IV /// with a truncation of the new wide IV to isolate and eliminate the narrow IV. static void truncateIVUse(NarrowIVDefUse DU, DominatorTree *DT, LoopInfo *LI) { DEBUG(dbgs() << "INDVARS: Truncate IV " << *DU.WideDef << " for user " << *DU.NarrowUse << "\n"); IRBuilder<> Builder( getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT, LI)); Value *Trunc = Builder.CreateTrunc(DU.WideDef, DU.NarrowDef->getType()); DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, Trunc); } /// If the narrow use is a compare instruction, then widen the compare // (and possibly the other operand). The extend operation is hoisted into the // loop preheader as far as possible. bool WidenIV::widenLoopCompare(NarrowIVDefUse DU) { ICmpInst *Cmp = dyn_cast(DU.NarrowUse); if (!Cmp) return false; // We can legally widen the comparison in the following two cases: // // - The signedness of the IV extension and comparison match // // - The narrow IV is always positive (and thus its sign extension is equal // to its zero extension). For instance, let's say we're zero extending // %narrow for the following use // // icmp slt i32 %narrow, %val ... (A) // // and %narrow is always positive. Then // // (A) == icmp slt i32 sext(%narrow), sext(%val) // == icmp slt i32 zext(%narrow), sext(%val) if (!(DU.NeverNegative || IsSigned == Cmp->isSigned())) return false; Value *Op = Cmp->getOperand(Cmp->getOperand(0) == DU.NarrowDef ? 1 : 0); unsigned CastWidth = SE->getTypeSizeInBits(Op->getType()); unsigned IVWidth = SE->getTypeSizeInBits(WideType); assert (CastWidth <= IVWidth && "Unexpected width while widening compare."); // Widen the compare instruction. IRBuilder<> Builder( getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT, LI)); DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, DU.WideDef); // Widen the other operand of the compare, if necessary. if (CastWidth < IVWidth) { Value *ExtOp = createExtendInst(Op, WideType, Cmp->isSigned(), Cmp); DU.NarrowUse->replaceUsesOfWith(Op, ExtOp); } return true; } /// Determine whether an individual user of the narrow IV can be widened. If so, /// return the wide clone of the user. Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { // Stop traversing the def-use chain at inner-loop phis or post-loop phis. if (PHINode *UsePhi = dyn_cast(DU.NarrowUse)) { if (LI->getLoopFor(UsePhi->getParent()) != L) { // For LCSSA phis, sink the truncate outside the loop. // After SimplifyCFG most loop exit targets have a single predecessor. // Otherwise fall back to a truncate within the loop. if (UsePhi->getNumOperands() != 1) truncateIVUse(DU, DT, LI); else { // Widening the PHI requires us to insert a trunc. The logical place // for this trunc is in the same BB as the PHI. This is not possible if // the BB is terminated by a catchswitch. if (isa(UsePhi->getParent()->getTerminator())) return nullptr; PHINode *WidePhi = PHINode::Create(DU.WideDef->getType(), 1, UsePhi->getName() + ".wide", UsePhi); WidePhi->addIncoming(DU.WideDef, UsePhi->getIncomingBlock(0)); IRBuilder<> Builder(&*WidePhi->getParent()->getFirstInsertionPt()); Value *Trunc = Builder.CreateTrunc(WidePhi, DU.NarrowDef->getType()); UsePhi->replaceAllUsesWith(Trunc); DeadInsts.emplace_back(UsePhi); DEBUG(dbgs() << "INDVARS: Widen lcssa phi " << *UsePhi << " to " << *WidePhi << "\n"); } return nullptr; } } // Our raison d'etre! Eliminate sign and zero extension. if (IsSigned ? isa(DU.NarrowUse) : isa(DU.NarrowUse)) { Value *NewDef = DU.WideDef; if (DU.NarrowUse->getType() != WideType) { unsigned CastWidth = SE->getTypeSizeInBits(DU.NarrowUse->getType()); unsigned IVWidth = SE->getTypeSizeInBits(WideType); if (CastWidth < IVWidth) { // The cast isn't as wide as the IV, so insert a Trunc. IRBuilder<> Builder(DU.NarrowUse); NewDef = Builder.CreateTrunc(DU.WideDef, DU.NarrowUse->getType()); } else { // A wider extend was hidden behind a narrower one. This may induce // another round of IV widening in which the intermediate IV becomes // dead. It should be very rare. DEBUG(dbgs() << "INDVARS: New IV " << *WidePhi << " not wide enough to subsume " << *DU.NarrowUse << "\n"); DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, DU.WideDef); NewDef = DU.NarrowUse; } } if (NewDef != DU.NarrowUse) { DEBUG(dbgs() << "INDVARS: eliminating " << *DU.NarrowUse << " replaced by " << *DU.WideDef << "\n"); ++NumElimExt; DU.NarrowUse->replaceAllUsesWith(NewDef); DeadInsts.emplace_back(DU.NarrowUse); } // Now that the extend is gone, we want to expose it's uses for potential // further simplification. We don't need to directly inform SimplifyIVUsers // of the new users, because their parent IV will be processed later as a // new loop phi. If we preserved IVUsers analysis, we would also want to // push the uses of WideDef here. // No further widening is needed. The deceased [sz]ext had done it for us. return nullptr; } // Does this user itself evaluate to a recurrence after widening? const SCEVAddRecExpr *WideAddRec = getWideRecurrence(DU.NarrowUse); if (!WideAddRec) WideAddRec = getExtendedOperandRecurrence(DU); if (!WideAddRec) { // If use is a loop condition, try to promote the condition instead of // truncating the IV first. if (widenLoopCompare(DU)) return nullptr; // This user does not evaluate to a recurence after widening, so don't // follow it. Instead insert a Trunc to kill off the original use, // eventually isolating the original narrow IV so it can be removed. truncateIVUse(DU, DT, LI); return nullptr; } // Assume block terminators cannot evaluate to a recurrence. We can't to // insert a Trunc after a terminator if there happens to be a critical edge. assert(DU.NarrowUse != DU.NarrowUse->getParent()->getTerminator() && "SCEV is not expected to evaluate a block terminator"); // Reuse the IV increment that SCEVExpander created as long as it dominates // NarrowUse. Instruction *WideUse = nullptr; if (WideAddRec == WideIncExpr && Rewriter.hoistIVInc(WideInc, DU.NarrowUse)) WideUse = WideInc; else { WideUse = cloneIVUser(DU, WideAddRec); if (!WideUse) return nullptr; } // Evaluation of WideAddRec ensured that the narrow expression could be // extended outside the loop without overflow. This suggests that the wide use // evaluates to the same expression as the extended narrow use, but doesn't // absolutely guarantee it. Hence the following failsafe check. In rare cases // where it fails, we simply throw away the newly created wide use. if (WideAddRec != SE->getSCEV(WideUse)) { DEBUG(dbgs() << "Wide use expression mismatch: " << *WideUse << ": " << *SE->getSCEV(WideUse) << " != " << *WideAddRec << "\n"); DeadInsts.emplace_back(WideUse); return nullptr; } // Returning WideUse pushes it on the worklist. return WideUse; } /// Add eligible users of NarrowDef to NarrowIVUsers. /// void WidenIV::pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef) { const SCEV *NarrowSCEV = SE->getSCEV(NarrowDef); bool NeverNegative = SE->isKnownPredicate(ICmpInst::ICMP_SGE, NarrowSCEV, SE->getConstant(NarrowSCEV->getType(), 0)); for (User *U : NarrowDef->users()) { Instruction *NarrowUser = cast(U); // Handle data flow merges and bizarre phi cycles. if (!Widened.insert(NarrowUser).second) continue; NarrowIVUsers.emplace_back(NarrowDef, NarrowUser, WideDef, NeverNegative); } } /// Process a single induction variable. First use the SCEVExpander to create a /// wide induction variable that evaluates to the same recurrence as the /// original narrow IV. Then use a worklist to forward traverse the narrow IV's /// def-use chain. After widenIVUse has processed all interesting IV users, the /// narrow IV will be isolated for removal by DeleteDeadPHIs. /// /// It would be simpler to delete uses as they are processed, but we must avoid /// invalidating SCEV expressions. /// PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { // Is this phi an induction variable? const SCEVAddRecExpr *AddRec = dyn_cast(SE->getSCEV(OrigPhi)); if (!AddRec) return nullptr; // Widen the induction variable expression. const SCEV *WideIVExpr = IsSigned ? SE->getSignExtendExpr(AddRec, WideType) : SE->getZeroExtendExpr(AddRec, WideType); assert(SE->getEffectiveSCEVType(WideIVExpr->getType()) == WideType && "Expect the new IV expression to preserve its type"); // Can the IV be extended outside the loop without overflow? AddRec = dyn_cast(WideIVExpr); if (!AddRec || AddRec->getLoop() != L) return nullptr; // An AddRec must have loop-invariant operands. Since this AddRec is // materialized by a loop header phi, the expression cannot have any post-loop // operands, so they must dominate the loop header. assert( SE->properlyDominates(AddRec->getStart(), L->getHeader()) && SE->properlyDominates(AddRec->getStepRecurrence(*SE), L->getHeader()) && "Loop header phi recurrence inputs do not dominate the loop"); // The rewriter provides a value for the desired IV expression. This may // either find an existing phi or materialize a new one. Either way, we // expect a well-formed cyclic phi-with-increments. i.e. any operand not part // of the phi-SCC dominates the loop entry. Instruction *InsertPt = &L->getHeader()->front(); WidePhi = cast(Rewriter.expandCodeFor(AddRec, WideType, InsertPt)); // Remembering the WideIV increment generated by SCEVExpander allows // widenIVUse to reuse it when widening the narrow IV's increment. We don't // employ a general reuse mechanism because the call above is the only call to // SCEVExpander. Henceforth, we produce 1-to-1 narrow to wide uses. if (BasicBlock *LatchBlock = L->getLoopLatch()) { WideInc = cast(WidePhi->getIncomingValueForBlock(LatchBlock)); WideIncExpr = SE->getSCEV(WideInc); } DEBUG(dbgs() << "Wide IV: " << *WidePhi << "\n"); ++NumWidened; // Traverse the def-use chain using a worklist starting at the original IV. assert(Widened.empty() && NarrowIVUsers.empty() && "expect initial state" ); Widened.insert(OrigPhi); pushNarrowIVUsers(OrigPhi, WidePhi); while (!NarrowIVUsers.empty()) { NarrowIVDefUse DU = NarrowIVUsers.pop_back_val(); // Process a def-use edge. This may replace the use, so don't hold a // use_iterator across it. Instruction *WideUse = widenIVUse(DU, Rewriter); // Follow all def-use edges from the previous narrow use. if (WideUse) pushNarrowIVUsers(DU.NarrowUse, WideUse); // widenIVUse may have removed the def-use edge. if (DU.NarrowDef->use_empty()) DeadInsts.emplace_back(DU.NarrowDef); } return WidePhi; } //===----------------------------------------------------------------------===// // Live IV Reduction - Minimize IVs live across the loop. //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // Simplification of IV users based on SCEV evaluation. //===----------------------------------------------------------------------===// namespace { class IndVarSimplifyVisitor : public IVVisitor { ScalarEvolution *SE; const TargetTransformInfo *TTI; PHINode *IVPhi; public: WideIVInfo WI; IndVarSimplifyVisitor(PHINode *IV, ScalarEvolution *SCEV, const TargetTransformInfo *TTI, const DominatorTree *DTree) : SE(SCEV), TTI(TTI), IVPhi(IV) { DT = DTree; WI.NarrowIV = IVPhi; } // Implement the interface used by simplifyUsersOfIV. void visitCast(CastInst *Cast) override { visitIVCast(Cast, WI, SE, TTI); } }; } /// Iteratively perform simplification on a worklist of IV users. Each /// successive simplification may push more users which may themselves be /// candidates for simplification. /// /// Sign/Zero extend elimination is interleaved with IV simplification. /// void IndVarSimplify::simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI) { SmallVector WideIVs; SmallVector LoopPhis; for (BasicBlock::iterator I = L->getHeader()->begin(); isa(I); ++I) { LoopPhis.push_back(cast(I)); } // Each round of simplification iterates through the SimplifyIVUsers worklist // for all current phis, then determines whether any IVs can be // widened. Widening adds new phis to LoopPhis, inducing another round of // simplification on the wide IVs. while (!LoopPhis.empty()) { // Evaluate as many IV expressions as possible before widening any IVs. This // forces SCEV to set no-wrap flags before evaluating sign/zero // extension. The first time SCEV attempts to normalize sign/zero extension, // the result becomes final. So for the most predictable results, we delay // evaluation of sign/zero extend evaluation until needed, and avoid running // other SCEV based analysis prior to simplifyAndExtend. do { PHINode *CurrIV = LoopPhis.pop_back_val(); // Information about sign/zero extensions of CurrIV. IndVarSimplifyVisitor Visitor(CurrIV, SE, TTI, DT); Changed |= simplifyUsersOfIV(CurrIV, SE, DT, LI, DeadInsts, &Visitor); if (Visitor.WI.WidestNativeType) { WideIVs.push_back(Visitor.WI); } } while(!LoopPhis.empty()); for (; !WideIVs.empty(); WideIVs.pop_back()) { WidenIV Widener(WideIVs.back(), LI, SE, DT, DeadInsts); if (PHINode *WidePhi = Widener.createWideIV(Rewriter)) { Changed = true; LoopPhis.push_back(WidePhi); } } } } //===----------------------------------------------------------------------===// // linearFunctionTestReplace and its kin. Rewrite the loop exit condition. //===----------------------------------------------------------------------===// /// Return true if this loop's backedge taken count expression can be safely and /// cheaply expanded into an instruction sequence that can be used by /// linearFunctionTestReplace. /// /// TODO: This fails for pointer-type loop counters with greater than one byte /// strides, consequently preventing LFTR from running. For the purpose of LFTR /// we could skip this check in the case that the LFTR loop counter (chosen by /// FindLoopCounter) is also pointer type. Instead, we could directly convert /// the loop test to an inequality test by checking the target data's alignment /// of element types (given that the initial pointer value originates from or is /// used by ABI constrained operation, as opposed to inttoptr/ptrtoint). /// However, we don't yet have a strong motivation for converting loop tests /// into inequality tests. static bool canExpandBackedgeTakenCount(Loop *L, ScalarEvolution *SE, SCEVExpander &Rewriter) { const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); if (isa(BackedgeTakenCount) || BackedgeTakenCount->isZero()) return false; if (!L->getExitingBlock()) return false; // Can't rewrite non-branch yet. if (!isa(L->getExitingBlock()->getTerminator())) return false; if (Rewriter.isHighCostExpansion(BackedgeTakenCount, L)) return false; return true; } /// Return the loop header phi IFF IncV adds a loop invariant value to the phi. static PHINode *getLoopPhiForCounter(Value *IncV, Loop *L, DominatorTree *DT) { Instruction *IncI = dyn_cast(IncV); if (!IncI) return nullptr; switch (IncI->getOpcode()) { case Instruction::Add: case Instruction::Sub: break; case Instruction::GetElementPtr: // An IV counter must preserve its type. if (IncI->getNumOperands() == 2) break; default: return nullptr; } PHINode *Phi = dyn_cast(IncI->getOperand(0)); if (Phi && Phi->getParent() == L->getHeader()) { if (isLoopInvariant(IncI->getOperand(1), L, DT)) return Phi; return nullptr; } if (IncI->getOpcode() == Instruction::GetElementPtr) return nullptr; // Allow add/sub to be commuted. Phi = dyn_cast(IncI->getOperand(1)); if (Phi && Phi->getParent() == L->getHeader()) { if (isLoopInvariant(IncI->getOperand(0), L, DT)) return Phi; } return nullptr; } /// Return the compare guarding the loop latch, or NULL for unrecognized tests. static ICmpInst *getLoopTest(Loop *L) { assert(L->getExitingBlock() && "expected loop exit"); BasicBlock *LatchBlock = L->getLoopLatch(); // Don't bother with LFTR if the loop is not properly simplified. if (!LatchBlock) return nullptr; BranchInst *BI = dyn_cast(L->getExitingBlock()->getTerminator()); assert(BI && "expected exit branch"); return dyn_cast(BI->getCondition()); } /// linearFunctionTestReplace policy. Return true unless we can show that the /// current exit test is already sufficiently canonical. static bool needsLFTR(Loop *L, DominatorTree *DT) { // Do LFTR to simplify the exit condition to an ICMP. ICmpInst *Cond = getLoopTest(L); if (!Cond) return true; // Do LFTR to simplify the exit ICMP to EQ/NE ICmpInst::Predicate Pred = Cond->getPredicate(); if (Pred != ICmpInst::ICMP_NE && Pred != ICmpInst::ICMP_EQ) return true; // Look for a loop invariant RHS Value *LHS = Cond->getOperand(0); Value *RHS = Cond->getOperand(1); if (!isLoopInvariant(RHS, L, DT)) { if (!isLoopInvariant(LHS, L, DT)) return true; std::swap(LHS, RHS); } // Look for a simple IV counter LHS PHINode *Phi = dyn_cast(LHS); if (!Phi) Phi = getLoopPhiForCounter(LHS, L, DT); if (!Phi) return true; // Do LFTR if PHI node is defined in the loop, but is *not* a counter. int Idx = Phi->getBasicBlockIndex(L->getLoopLatch()); if (Idx < 0) return true; // Do LFTR if the exit condition's IV is *not* a simple counter. Value *IncV = Phi->getIncomingValue(Idx); return Phi != getLoopPhiForCounter(IncV, L, DT); } /// Recursive helper for hasConcreteDef(). Unfortunately, this currently boils /// down to checking that all operands are constant and listing instructions /// that may hide undef. static bool hasConcreteDefImpl(Value *V, SmallPtrSetImpl &Visited, unsigned Depth) { if (isa(V)) return !isa(V); if (Depth >= 6) return false; // Conservatively handle non-constant non-instructions. For example, Arguments // may be undef. Instruction *I = dyn_cast(V); if (!I) return false; // Load and return values may be undef. if(I->mayReadFromMemory() || isa(I) || isa(I)) return false; // Optimistically handle other instructions. for (Value *Op : I->operands()) { if (!Visited.insert(Op).second) continue; if (!hasConcreteDefImpl(Op, Visited, Depth+1)) return false; } return true; } /// Return true if the given value is concrete. We must prove that undef can /// never reach it. /// /// TODO: If we decide that this is a good approach to checking for undef, we /// may factor it into a common location. static bool hasConcreteDef(Value *V) { SmallPtrSet Visited; Visited.insert(V); return hasConcreteDefImpl(V, Visited, 0); } /// Return true if this IV has any uses other than the (soon to be rewritten) /// loop exit test. static bool AlmostDeadIV(PHINode *Phi, BasicBlock *LatchBlock, Value *Cond) { int LatchIdx = Phi->getBasicBlockIndex(LatchBlock); Value *IncV = Phi->getIncomingValue(LatchIdx); for (User *U : Phi->users()) if (U != Cond && U != IncV) return false; for (User *U : IncV->users()) if (U != Cond && U != Phi) return false; return true; } /// Find an affine IV in canonical form. /// /// BECount may be an i8* pointer type. The pointer difference is already /// valid count without scaling the address stride, so it remains a pointer /// expression as far as SCEV is concerned. /// /// Currently only valid for LFTR. See the comments on hasConcreteDef below. /// /// FIXME: Accept -1 stride and set IVLimit = IVInit - BECount /// /// FIXME: Accept non-unit stride as long as SCEV can reduce BECount * Stride. /// This is difficult in general for SCEV because of potential overflow. But we /// could at least handle constant BECounts. static PHINode *FindLoopCounter(Loop *L, const SCEV *BECount, ScalarEvolution *SE, DominatorTree *DT) { uint64_t BCWidth = SE->getTypeSizeInBits(BECount->getType()); Value *Cond = cast(L->getExitingBlock()->getTerminator())->getCondition(); // Loop over all of the PHI nodes, looking for a simple counter. PHINode *BestPhi = nullptr; const SCEV *BestInit = nullptr; BasicBlock *LatchBlock = L->getLoopLatch(); assert(LatchBlock && "needsLFTR should guarantee a loop latch"); const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); for (BasicBlock::iterator I = L->getHeader()->begin(); isa(I); ++I) { PHINode *Phi = cast(I); if (!SE->isSCEVable(Phi->getType())) continue; // Avoid comparing an integer IV against a pointer Limit. if (BECount->getType()->isPointerTy() && !Phi->getType()->isPointerTy()) continue; const SCEVAddRecExpr *AR = dyn_cast(SE->getSCEV(Phi)); if (!AR || AR->getLoop() != L || !AR->isAffine()) continue; // AR may be a pointer type, while BECount is an integer type. // AR may be wider than BECount. With eq/ne tests overflow is immaterial. // AR may not be a narrower type, or we may never exit. uint64_t PhiWidth = SE->getTypeSizeInBits(AR->getType()); if (PhiWidth < BCWidth || !DL.isLegalInteger(PhiWidth)) continue; const SCEV *Step = dyn_cast(AR->getStepRecurrence(*SE)); if (!Step || !Step->isOne()) continue; int LatchIdx = Phi->getBasicBlockIndex(LatchBlock); Value *IncV = Phi->getIncomingValue(LatchIdx); if (getLoopPhiForCounter(IncV, L, DT) != Phi) continue; // Avoid reusing a potentially undef value to compute other values that may // have originally had a concrete definition. if (!hasConcreteDef(Phi)) { // We explicitly allow unknown phis as long as they are already used by // the loop test. In this case we assume that performing LFTR could not // increase the number of undef users. if (ICmpInst *Cond = getLoopTest(L)) { if (Phi != getLoopPhiForCounter(Cond->getOperand(0), L, DT) && Phi != getLoopPhiForCounter(Cond->getOperand(1), L, DT)) { continue; } } } const SCEV *Init = AR->getStart(); if (BestPhi && !AlmostDeadIV(BestPhi, LatchBlock, Cond)) { // Don't force a live loop counter if another IV can be used. if (AlmostDeadIV(Phi, LatchBlock, Cond)) continue; // Prefer to count-from-zero. This is a more "canonical" counter form. It // also prefers integer to pointer IVs. if (BestInit->isZero() != Init->isZero()) { if (BestInit->isZero()) continue; } // If two IVs both count from zero or both count from nonzero then the // narrower is likely a dead phi that has been widened. Use the wider phi // to allow the other to be eliminated. else if (PhiWidth <= SE->getTypeSizeInBits(BestPhi->getType())) continue; } BestPhi = Phi; BestInit = Init; } return BestPhi; } /// Help linearFunctionTestReplace by generating a value that holds the RHS of /// the new loop test. static Value *genLoopLimit(PHINode *IndVar, const SCEV *IVCount, Loop *L, SCEVExpander &Rewriter, ScalarEvolution *SE) { const SCEVAddRecExpr *AR = dyn_cast(SE->getSCEV(IndVar)); assert(AR && AR->getLoop() == L && AR->isAffine() && "bad loop counter"); const SCEV *IVInit = AR->getStart(); // IVInit may be a pointer while IVCount is an integer when FindLoopCounter // finds a valid pointer IV. Sign extend BECount in order to materialize a // GEP. Avoid running SCEVExpander on a new pointer value, instead reusing // the existing GEPs whenever possible. if (IndVar->getType()->isPointerTy() && !IVCount->getType()->isPointerTy()) { // IVOffset will be the new GEP offset that is interpreted by GEP as a // signed value. IVCount on the other hand represents the loop trip count, // which is an unsigned value. FindLoopCounter only allows induction // variables that have a positive unit stride of one. This means we don't // have to handle the case of negative offsets (yet) and just need to zero // extend IVCount. Type *OfsTy = SE->getEffectiveSCEVType(IVInit->getType()); const SCEV *IVOffset = SE->getTruncateOrZeroExtend(IVCount, OfsTy); // Expand the code for the iteration count. assert(SE->isLoopInvariant(IVOffset, L) && "Computed iteration count is not loop invariant!"); BranchInst *BI = cast(L->getExitingBlock()->getTerminator()); Value *GEPOffset = Rewriter.expandCodeFor(IVOffset, OfsTy, BI); Value *GEPBase = IndVar->getIncomingValueForBlock(L->getLoopPreheader()); assert(AR->getStart() == SE->getSCEV(GEPBase) && "bad loop counter"); // We could handle pointer IVs other than i8*, but we need to compensate for // gep index scaling. See canExpandBackedgeTakenCount comments. assert(SE->getSizeOfExpr(IntegerType::getInt64Ty(IndVar->getContext()), cast(GEPBase->getType()) ->getElementType())->isOne() && "unit stride pointer IV must be i8*"); IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); return Builder.CreateGEP(nullptr, GEPBase, GEPOffset, "lftr.limit"); } else { // In any other case, convert both IVInit and IVCount to integers before // comparing. This may result in SCEV expension of pointers, but in practice // SCEV will fold the pointer arithmetic away as such: // BECount = (IVEnd - IVInit - 1) => IVLimit = IVInit (postinc). // // Valid Cases: (1) both integers is most common; (2) both may be pointers // for simple memset-style loops. // // IVInit integer and IVCount pointer would only occur if a canonical IV // were generated on top of case #2, which is not expected. const SCEV *IVLimit = nullptr; // For unit stride, IVCount = Start + BECount with 2's complement overflow. // For non-zero Start, compute IVCount here. if (AR->getStart()->isZero()) IVLimit = IVCount; else { assert(AR->getStepRecurrence(*SE)->isOne() && "only handles unit stride"); const SCEV *IVInit = AR->getStart(); // For integer IVs, truncate the IV before computing IVInit + BECount. if (SE->getTypeSizeInBits(IVInit->getType()) > SE->getTypeSizeInBits(IVCount->getType())) IVInit = SE->getTruncateExpr(IVInit, IVCount->getType()); IVLimit = SE->getAddExpr(IVInit, IVCount); } // Expand the code for the iteration count. BranchInst *BI = cast(L->getExitingBlock()->getTerminator()); IRBuilder<> Builder(BI); assert(SE->isLoopInvariant(IVLimit, L) && "Computed iteration count is not loop invariant!"); // Ensure that we generate the same type as IndVar, or a smaller integer // type. In the presence of null pointer values, we have an integer type // SCEV expression (IVInit) for a pointer type IV value (IndVar). Type *LimitTy = IVCount->getType()->isPointerTy() ? IndVar->getType() : IVCount->getType(); return Rewriter.expandCodeFor(IVLimit, LimitTy, BI); } } /// This method rewrites the exit condition of the loop to be a canonical != /// comparison against the incremented loop induction variable. This pass is /// able to rewrite the exit tests of any loop where the SCEV analysis can /// determine a loop-invariant trip count of the loop, which is actually a much /// broader range than just linear tests. Value *IndVarSimplify:: linearFunctionTestReplace(Loop *L, const SCEV *BackedgeTakenCount, PHINode *IndVar, SCEVExpander &Rewriter) { assert(canExpandBackedgeTakenCount(L, SE, Rewriter) && "precondition"); // Initialize CmpIndVar and IVCount to their preincremented values. Value *CmpIndVar = IndVar; const SCEV *IVCount = BackedgeTakenCount; // If the exiting block is the same as the backedge block, we prefer to // compare against the post-incremented value, otherwise we must compare // against the preincremented value. if (L->getExitingBlock() == L->getLoopLatch()) { // Add one to the "backedge-taken" count to get the trip count. // This addition may overflow, which is valid as long as the comparison is // truncated to BackedgeTakenCount->getType(). IVCount = SE->getAddExpr(BackedgeTakenCount, SE->getOne(BackedgeTakenCount->getType())); // The BackedgeTaken expression contains the number of times that the // backedge branches to the loop header. This is one less than the // number of times the loop executes, so use the incremented indvar. CmpIndVar = IndVar->getIncomingValueForBlock(L->getExitingBlock()); } Value *ExitCnt = genLoopLimit(IndVar, IVCount, L, Rewriter, SE); assert(ExitCnt->getType()->isPointerTy() == IndVar->getType()->isPointerTy() && "genLoopLimit missed a cast"); // Insert a new icmp_ne or icmp_eq instruction before the branch. BranchInst *BI = cast(L->getExitingBlock()->getTerminator()); ICmpInst::Predicate P; if (L->contains(BI->getSuccessor(0))) P = ICmpInst::ICMP_NE; else P = ICmpInst::ICMP_EQ; DEBUG(dbgs() << "INDVARS: Rewriting loop exit condition to:\n" << " LHS:" << *CmpIndVar << '\n' << " op:\t" << (P == ICmpInst::ICMP_NE ? "!=" : "==") << "\n" << " RHS:\t" << *ExitCnt << "\n" << " IVCount:\t" << *IVCount << "\n"); IRBuilder<> Builder(BI); // LFTR can ignore IV overflow and truncate to the width of // BECount. This avoids materializing the add(zext(add)) expression. unsigned CmpIndVarSize = SE->getTypeSizeInBits(CmpIndVar->getType()); unsigned ExitCntSize = SE->getTypeSizeInBits(ExitCnt->getType()); if (CmpIndVarSize > ExitCntSize) { const SCEVAddRecExpr *AR = cast(SE->getSCEV(IndVar)); const SCEV *ARStart = AR->getStart(); const SCEV *ARStep = AR->getStepRecurrence(*SE); // For constant IVCount, avoid truncation. if (isa(ARStart) && isa(IVCount)) { const APInt &Start = cast(ARStart)->getAPInt(); APInt Count = cast(IVCount)->getAPInt(); // Note that the post-inc value of BackedgeTakenCount may have overflowed // above such that IVCount is now zero. if (IVCount != BackedgeTakenCount && Count == 0) { Count = APInt::getMaxValue(Count.getBitWidth()).zext(CmpIndVarSize); ++Count; } else Count = Count.zext(CmpIndVarSize); APInt NewLimit; if (cast(ARStep)->getValue()->isNegative()) NewLimit = Start - Count; else NewLimit = Start + Count; ExitCnt = ConstantInt::get(CmpIndVar->getType(), NewLimit); DEBUG(dbgs() << " Widen RHS:\t" << *ExitCnt << "\n"); } else { CmpIndVar = Builder.CreateTrunc(CmpIndVar, ExitCnt->getType(), "lftr.wideiv"); } } Value *Cond = Builder.CreateICmp(P, CmpIndVar, ExitCnt, "exitcond"); Value *OrigCond = BI->getCondition(); // It's tempting to use replaceAllUsesWith here to fully replace the old // comparison, but that's not immediately safe, since users of the old // comparison may not be dominated by the new comparison. Instead, just // update the branch to use the new comparison; in the common case this // will make old comparison dead. BI->setCondition(Cond); DeadInsts.push_back(OrigCond); ++NumLFTR; Changed = true; return Cond; } //===----------------------------------------------------------------------===// // sinkUnusedInvariants. A late subpass to cleanup loop preheaders. //===----------------------------------------------------------------------===// /// If there's a single exit block, sink any loop-invariant values that /// were defined in the preheader but not used inside the loop into the /// exit block to reduce register pressure in the loop. void IndVarSimplify::sinkUnusedInvariants(Loop *L) { BasicBlock *ExitBlock = L->getExitBlock(); if (!ExitBlock) return; BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) return; Instruction *InsertPt = &*ExitBlock->getFirstInsertionPt(); BasicBlock::iterator I(Preheader->getTerminator()); while (I != Preheader->begin()) { --I; // New instructions were inserted at the end of the preheader. if (isa(I)) break; // Don't move instructions which might have side effects, since the side // effects need to complete before instructions inside the loop. Also don't // move instructions which might read memory, since the loop may modify // memory. Note that it's okay if the instruction might have undefined // behavior: LoopSimplify guarantees that the preheader dominates the exit // block. if (I->mayHaveSideEffects() || I->mayReadFromMemory()) continue; // Skip debug info intrinsics. if (isa(I)) continue; // Skip eh pad instructions. if (I->isEHPad()) continue; // Don't sink alloca: we never want to sink static alloca's out of the // entry block, and correctly sinking dynamic alloca's requires // checks for stacksave/stackrestore intrinsics. // FIXME: Refactor this check somehow? if (isa(I)) continue; // Determine if there is a use in or before the loop (direct or // otherwise). bool UsedInLoop = false; for (Use &U : I->uses()) { Instruction *User = cast(U.getUser()); BasicBlock *UseBB = User->getParent(); if (PHINode *P = dyn_cast(User)) { unsigned i = PHINode::getIncomingValueNumForOperand(U.getOperandNo()); UseBB = P->getIncomingBlock(i); } if (UseBB == Preheader || L->contains(UseBB)) { UsedInLoop = true; break; } } // If there is, the def must remain in the preheader. if (UsedInLoop) continue; // Otherwise, sink it to the exit block. Instruction *ToMove = &*I; bool Done = false; if (I != Preheader->begin()) { // Skip debug info intrinsics. do { --I; } while (isa(I) && I != Preheader->begin()); if (isa(I) && I == Preheader->begin()) Done = true; } else { Done = true; } ToMove->moveBefore(InsertPt); if (Done) break; InsertPt = ToMove; } } //===----------------------------------------------------------------------===// // IndVarSimplify driver. Manage several subpasses of IV simplification. //===----------------------------------------------------------------------===// bool IndVarSimplify::run(Loop *L) { // We need (and expect!) the incoming loop to be in LCSSA. assert(L->isRecursivelyLCSSAForm(*DT) && "LCSSA required to run indvars!"); // If LoopSimplify form is not available, stay out of trouble. Some notes: // - LSR currently only supports LoopSimplify-form loops. Indvars' // canonicalization can be a pessimization without LSR to "clean up" // afterwards. // - We depend on having a preheader; in particular, // Loop::getCanonicalInductionVariable only supports loops with preheaders, // and we're in trouble if we can't find the induction variable even when // we've manually inserted one. if (!L->isLoopSimplifyForm()) return false; // If there are any floating-point recurrences, attempt to // transform them to use integer recurrences. rewriteNonIntegerIVs(L); const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); // Create a rewriter object which we'll use to transform the code with. SCEVExpander Rewriter(*SE, DL, "indvars"); #ifndef NDEBUG Rewriter.setDebugType(DEBUG_TYPE); #endif // Eliminate redundant IV users. // // Simplification works best when run before other consumers of SCEV. We // attempt to avoid evaluating SCEVs for sign/zero extend operations until // other expressions involving loop IVs have been evaluated. This helps SCEV // set no-wrap flags before normalizing sign/zero extension. Rewriter.disableCanonicalMode(); simplifyAndExtend(L, Rewriter, LI); // Check to see if this loop has a computable loop-invariant execution count. // If so, this means that we can compute the final value of any expressions // that are recurrent in the loop, and substitute the exit values from the // loop into any instructions outside of the loop that use the final values of // the current expressions. // if (ReplaceExitValue != NeverRepl && !isa(BackedgeTakenCount)) rewriteLoopExitValues(L, Rewriter); // Eliminate redundant IV cycles. NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts); // If we have a trip count expression, rewrite the loop's exit condition // using it. We can currently only handle loops with a single exit. if (canExpandBackedgeTakenCount(L, SE, Rewriter) && needsLFTR(L, DT)) { PHINode *IndVar = FindLoopCounter(L, BackedgeTakenCount, SE, DT); if (IndVar) { // Check preconditions for proper SCEVExpander operation. SCEV does not // express SCEVExpander's dependencies, such as LoopSimplify. Instead any // pass that uses the SCEVExpander must do it. This does not work well for // loop passes because SCEVExpander makes assumptions about all loops, // while LoopPassManager only forces the current loop to be simplified. // // FIXME: SCEV expansion has no way to bail out, so the caller must // explicitly check any assumptions made by SCEV. Brittle. const SCEVAddRecExpr *AR = dyn_cast(BackedgeTakenCount); if (!AR || AR->getLoop()->getLoopPreheader()) (void)linearFunctionTestReplace(L, BackedgeTakenCount, IndVar, Rewriter); } } // Clear the rewriter cache, because values that are in the rewriter's cache // can be deleted in the loop below, causing the AssertingVH in the cache to // trigger. Rewriter.clear(); // Now that we're done iterating through lists, clean up any instructions // which are now dead. while (!DeadInsts.empty()) if (Instruction *Inst = dyn_cast_or_null(DeadInsts.pop_back_val())) RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI); // The Rewriter may not be used from this point on. // Loop-invariant instructions in the preheader that aren't used in the // loop may be sunk below the loop to reduce register pressure. sinkUnusedInvariants(L); // rewriteFirstIterationLoopExitValues does not rely on the computation of // trip count and therefore can further simplify exit values in addition to // rewriteLoopExitValues. rewriteFirstIterationLoopExitValues(L); // Clean up dead instructions. Changed |= DeleteDeadPHIs(L->getHeader(), TLI); // Check a post-condition. assert(L->isRecursivelyLCSSAForm(*DT) && "Indvars did not preserve LCSSA!"); // Verify that LFTR, and any other change have not interfered with SCEV's // ability to compute trip count. #ifndef NDEBUG if (VerifyIndvars && !isa(BackedgeTakenCount)) { SE->forgetLoop(L); const SCEV *NewBECount = SE->getBackedgeTakenCount(L); if (SE->getTypeSizeInBits(BackedgeTakenCount->getType()) < SE->getTypeSizeInBits(NewBECount->getType())) NewBECount = SE->getTruncateOrNoop(NewBECount, BackedgeTakenCount->getType()); else BackedgeTakenCount = SE->getTruncateOrNoop(BackedgeTakenCount, NewBECount->getType()); assert(BackedgeTakenCount == NewBECount && "indvars must preserve SCEV"); } #endif return Changed; } PreservedAnalyses IndVarSimplifyPass::run(Loop &L, AnalysisManager &AM) { auto &FAM = AM.getResult(L).getManager(); Function *F = L.getHeader()->getParent(); const DataLayout &DL = F->getParent()->getDataLayout(); auto *LI = FAM.getCachedResult(*F); auto *SE = FAM.getCachedResult(*F); auto *DT = FAM.getCachedResult(*F); assert((LI && SE && DT) && "Analyses required for indvarsimplify not available!"); // Optional analyses. auto *TTI = FAM.getCachedResult(*F); auto *TLI = FAM.getCachedResult(*F); IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI); if (!IVS.run(&L)) return PreservedAnalyses::all(); // FIXME: This should also 'preserve the CFG'. return getLoopPassPreservedAnalyses(); } namespace { struct IndVarSimplifyLegacyPass : public LoopPass { static char ID; // Pass identification, replacement for typeid IndVarSimplifyLegacyPass() : LoopPass(ID) { initializeIndVarSimplifyLegacyPassPass(*PassRegistry::getPassRegistry()); } bool runOnLoop(Loop *L, LPPassManager &LPM) override { if (skipLoop(L)) return false; auto *LI = &getAnalysis().getLoopInfo(); auto *SE = &getAnalysis().getSE(); auto *DT = &getAnalysis().getDomTree(); auto *TLIP = getAnalysisIfAvailable(); auto *TLI = TLIP ? &TLIP->getTLI() : nullptr; auto *TTIP = getAnalysisIfAvailable(); auto *TTI = TTIP ? &TTIP->getTTI(*L->getHeader()->getParent()) : nullptr; const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI); return IVS.run(L); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); getLoopAnalysisUsage(AU); } }; } char IndVarSimplifyLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(IndVarSimplifyLegacyPass, "indvars", "Induction Variable Simplification", false, false) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_END(IndVarSimplifyLegacyPass, "indvars", "Induction Variable Simplification", false, false) Pass *llvm::createIndVarSimplifyPass() { return new IndVarSimplifyLegacyPass(); }