{- Tock: a compiler for parallel languages Copyright (C) 2007, 2008, 2009 University of Kent This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . -} module ArrayUsageCheck ( BackgroundKnowledge(..), BK, canonicalise, checkArrayUsage, findRepSolutions, FlattenedExp(..), fmapFlattenedExp, makeEquations, makeExpSet, ModuloCase(..), onlyConst, showFlattenedExp, VarMap) where import Control.Monad.Error import Control.Monad.Reader import Control.Monad.State import Data.Array.IArray import qualified Data.Foldable as F import Data.Generics (Data, Typeable) import Data.Int import Data.List import qualified Data.Map as Map import Data.Maybe import qualified Data.Set as Set import qualified Data.Traversable as T import qualified AST as A import CompState import Data.Generics.Alloy.Schemes import Errors import Metadata import Omega import OrdAST() import Pass import ShowCode import Types import UsageCheckUtils import Utils -- Each list is a possible set of background knowledge mapping vars to a list -- of constraints. 
-- So it is a disjunction of maps from variables to conjunctions.
type BK = [Map.Map Var [BackgroundKnowledge]]

-- | Background knowledge in equation form.  It has the same disjunctive shape
-- as 'BK', but each variable maps either to an error message or to the
-- (equalities, inequalities) pair that describes its constraints.
type BK' = [Map.Map Var (Either String (EqualityProblem, InequalityProblem))]

-- | Given a list of replicators, and a set of background knowledge for each
-- access inside the replicator, checks if there are any solutions for a
-- combination of the normal replicator constraints and the given background
-- knowledge (pairing each set against each other, applying one set to the
-- replicator, and the other to the mirror of the replicator).
--
-- Returns Nothing if there are no solutions, or a String with a
-- counter-example if there are solutions.  If the equations cannot be formed
-- at all, this calls 'error' with the underlying reason.
findRepSolutions :: (CSMR m, MonadIO m) => [(A.Name, A.Replicator)] -> [BK] -> m (Maybe String)
findRepSolutions reps bks
  -- To get the right comparison, we create a SeqItems with all the accesses.
  -- Because they are inside a PAR replicator, they will all get compared to each
  -- other with one set of BK applied to i and one applied to i', but they will
  -- never be compared to each other just with the constraints on i (which is not
  -- what we are checking here).  We set the dummy array accesses to all be zero,
  -- which means they can overlap -- but only if there is also a solution to the
  -- replicator background knowledge, which is what this function is trying to
  -- determine.
  = do cs <- getCompState
       case flip runReaderT cs $ makeEquations (addReps dummyAccesses) maxInt of
         -- makeEquations failed: report its reason rather than a generic message.
         Left err -> error $ "findRepSolutions: could not form equations: " ++ err
         Right problems -> do
           probs <- formatProblems [(vm, prob) | (_,vm,prob) <- problems]
           debug $ "Problems in findRepSolutions:\n" ++ probs
           -- Number every problem, keeping only those that have a solution:
           case [(i, sol) | (i :: Integer, p) <- zip [0..] problems, Just sol <- [solve p]] of
             [] -> return Nothing -- No solutions, safe
             xs -> liftM (Just . unlines) $ mapM format xs
  where
    -- One dummy write of index 0 for each set of background knowledge.
    dummyAccesses = SeqItems [(bk, [makeConstant emptyMeta 0], []) | bk <- bks]

    -- The bound handed to makeEquations.  Every dummy access is 0, so any large
    -- bound will do; use the largest (assumed 32-bit) INT.
    maxInt = makeConstant emptyMeta $ fromInteger $ toInteger (maxBound :: Int32)

    -- Formats one numbered solution as "#i: <solution>".
    format (i, (_, varMapping, vm, _)) = formatSolution varMapping vm >>* (("#" ++ show i ++ ": ") ++)

    -- Wraps the items in one RepParItem per replicator (the first replicator in
    -- the list ends up innermost).
    addReps = flip (foldl $ flip RepParItem) reps

-- | A check-pass that checks the given ParItems (usually generated from a
-- control-flow graph) for any overlapping array indices.
--
-- Arrays whose name carries the given 'NameAttr', and arrays declared within
-- the checked items themselves, are skipped.  The pass dies (via 'dieP') on
-- the first array whose indexes could overlap, or whose indexes cannot be
-- turned into equations.
checkArrayUsage :: forall m. (Die m, CSMR m, MonadIO m) => NameAttr -> (Meta, ParItems (BK, UsageLabel)) -> m ()
checkArrayUsage sharedAttr (m,p)
  = do indexes <- groupArrayIndexes $ fmap (transformPair id nodeVars) p
       let filteredIndexes = Map.toList $ Map.filter ((>= 1) . length . map (\(_,w,r) -> w++r) . F.toList) indexes
       debug $ "checkArrayUsage: " ++ show m ++ ", " ++ show (length filteredIndexes)
       mapM_ (checkIndexes m) filteredIndexes
  where
    -- The name declared by a node, if the node is a ScopeIn.
    getDecl :: UsageLabel -> Maybe String
    getDecl = join . fmap getScopeIn . nodeDecl
      where
        getScopeIn (ScopeIn _ n) = Just n
        getScopeIn _ = Nothing

    -- Takes a ParItems Vars, and returns a map from array-variable-name to a
    -- list of writes and a list of reads for that array.
    -- Returns (array name, list of written-to indexes, list of read-from indexes).
    -- Variables that are "used" (as well as written) count as writes.
    groupArrayIndexes :: ParItems (BK, Vars) -> m (Map.Map (String, Maybe A.Direction) (ParItems (BK, [A.Expression], [A.Expression])))
    groupArrayIndexes = liftM filterByKey . T.mapM (\(bk,vs) -> do
        w <- makeList $ (Map.keysSet $ writtenVars vs) `Set.union` (usedVars vs)
        r <- makeList $ readVars vs
        return $ zipMap (mergeAccesses bk) w r)
      where
        -- Pairs up the writes and reads for one array, treating a missing
        -- side as empty.
        mergeAccesses :: b -> Maybe [a] -> Maybe [a] -> Maybe (b, [a],[a])
        mergeAccesses k x y = Just (k, fromMaybe [] x, fromMaybe [] y)

    -- Turns a set of variables into a map (from array-name to list of index-expressions)
    makeList :: Set.Set Var -> m (Map.Map (String, Maybe A.Direction) [A.Expression])
    makeList vs = do
      indexes <- concatMapM getArrayIndex $ Set.toList vs
      return $ Map.fromListWith (++) indexes

    -- Lifts a map (from array-name to expression-lists) inside a ParItems to
    -- being a map (from array-name to ParItems of expression lists).  Items
    -- that do not mention an array get an empty entry for it.
    filterByKey :: ParItems (Map.Map (String, Maybe A.Direction) (BK, [A.Expression], [A.Expression]))
                -> Map.Map (String, Maybe A.Direction) (ParItems (BK, [A.Expression], [A.Expression]))
    filterByKey items = Map.fromList $ map trans keys
      where
        keys :: [(String, Maybe A.Direction)]
        keys = concatMap Map.keys $ flattenParItems items

        trans :: (String, Maybe A.Direction) -> ((String, Maybe A.Direction), ParItems (BK, [A.Expression], [A.Expression]))
        trans k = (k, fmap (Map.findWithDefault ([], [], []) k) items)

    -- Gets the (array-name, indexes) from a Var.
    -- TODO this is quite hacky, and doesn't yet deal with slices and so on.
    -- A subscripted (undirected) channel array is recorded under both the
    -- input and the output direction.
    getArrayIndex :: Var -> m [((String, Maybe A.Direction), [A.Expression])]
    getArrayIndex (Var v@(A.SubscriptedVariable _ (A.Subscript _ _ e) (A.Variable _ n))) = do
      t <- astTypeOf v
      let dirs = case t of
            A.Chan {} -> [Just A.DirInput, Just A.DirOutput]
            _ -> [Nothing]
      return [((A.nameName n, d), [e]) | d <- dirs]
    getArrayIndex (Var (A.SubscriptedVariable _ (A.Subscript _ _ e) (A.DirectedVariable _ dir (A.Variable _ n))))
      = return [((A.nameName n, Just dir), [e])]
    getArrayIndex (Var (A.DirectedVariable _ dir (A.SubscriptedVariable _ (A.Subscript _ _ e) (A.Variable _ n))))
      = return [((A.nameName n, Just dir), [e])]
    getArrayIndex _ = return []

    -- Checks the given ParItems of writes and reads against each other.  The
    -- String (array-name) and Meta are only used for printing out error messages.
    checkIndexes :: Meta -> ((String, Maybe A.Direction), ParItems (BK, [A.Expression], [A.Expression])) -> m ()
    checkIndexes m ((arrName, _), indexes) = do
      sharedNames <- getCompState >>* csNameAttr
      let declNames = [x | Just x <- fmap (getDecl . snd) $ flattenParItems p]
      -- Skip arrays marked with the shared attribute, and arrays declared
      -- inside the items being checked:
      when (fmap (Set.member sharedAttr) (Map.lookup arrName sharedNames) /= Just True
              && arrName `notElem` declNames) $ do
        userArrName <- getRealName (A.Name undefined arrName)
        arrType <- astTypeOf (A.Name undefined arrName) >>= resolveUserType m
        arrLength <- case arrType of
          A.Array (A.Dimension d:_) _ -> return d
          -- Unknown dimension, use the maximum value for a (assumed 32-bit for INT) integer:
          A.Array (A.UnknownDimension:_) _ -> return $ makeConstant m $ fromInteger $ toInteger (maxBound :: Int32)
          -- It's not an array:
          _ -> dieP m $ "Cannot usage check array \"" ++ userArrName ++ "\"; found to be of type: " ++ show arrType
        cs <- getCompState
        case runReaderT (makeEquations indexes arrLength) cs of
          Left err -> dieP m $ "Could not work with array indexes for array \"" ++ userArrName ++ "\": " ++ err
          Right [] -> return () -- No problems to work with
          Right problems -> do
            probs <- formatProblems [(vm, prob) | (_,vm,prob) <- problems]
            debug $ "Problems in checkArrayUsage" ++ show m ++ ":\n" ++ probs
            case mapMaybe solve problems of
              -- No solutions; no worries!
              [] -> return ()
              -- Report the first solvable problem as a counter-example:
              (((lx,ly),varMapping,vm,_):_) -> do
                sol <- formatSolution varMapping vm
                cx <- showCode (fst lx)
                cy <- showCode (fst ly)
                dieP m $ "Indexes of array \"" ++ userArrName ++ "\" "
                  ++ "(\"" ++ cx ++ "\" and \"" ++ cy ++ "\") could overlap"
                  ++ if sol /= "" then " when: " ++ sol else ""

    -- TODO this is surely defined elsewhere already?
    getRealName :: A.Name -> m String
    getRealName n = lookupName n >>* A.ndOrigName

-- | Formats a list of problems for debugging, numbering each one ("#0 :",
-- "#1 :", ...) and indenting the continuation lines of each problem.
formatProblems :: CSMR m => [(VarMap, (EqualityProblem, InequalityProblem))] -> m String
formatProblems probs = do
  formatted <- mapM (uncurry formatProblem) probs
  return $ concat [addNum i (lines p) | (p, i) <- zip formatted [0..]]
  where
    addNum :: Int -> [String] -> String
    addNum _ [] = ""
    addNum i (p:ps) = unlines $ ("#" ++ show i ++ (if length (show i) == 1 then " :" else ":") ++ p) : map (" " ++) ps

-- | Formats an entire problem ready to print it out half-legibly for debugging purposes.
-- Inequalities whose variable coefficients are all non-positive are negated
-- and shown with "<=" rather than ">=".
formatProblem :: forall m. CSMR m => VarMap -> (EqualityProblem, InequalityProblem) -> m String
formatProblem varToIndex (eq, ineq) = do
  feqs <- mapM (showWithConst "=") eq
  fineqs <- mapM (\e -> if allNegative e then showWithConst "<=" (negateAll e) else showWithConst ">=" e) ineq
  return $ unlines $ feqs ++ fineqs
  where
    -- Returns true if all the variable coefficients are non-positive (ignoring
    -- the constant term at index 0)
    allNegative :: Array CoeffIndex Integer -> Bool
    allNegative = all (<= 0) . drop 1 . elems

    negateAll :: Array CoeffIndex Integer -> Array CoeffIndex Integer
    negateAll = amap negate

    -- The constant term is moved to the right-hand side of the operator.
    showWithConst :: String -> Array CoeffIndex Integer -> m String
    showWithConst op item = do
      text <- showEq item
      return $ (if text == "" then "0" else text) ++ " " ++ op ++ " " ++ show (negate $ item ! 0)

    showEq :: Array CoeffIndex Integer -> m String
    showEq = liftM (joinWith " + ") . mapM showItem . filter ((/= 0) . snd) . drop 1 . assocs

    -- Shows one coefficient together with the item it belongs to.
    showItem :: (CoeffIndex, Integer) -> m String
    showItem (n, a) = case find ((== n) . snd) $ Map.assocs varToIndex of
      Just (fe, _) -> showFlattenedExp showCode fe >>* (\s -> showScale s a)
      Nothing -> return ""

-- | Solves the problem and munges the arguments and results into a useful order
solve :: (labels,vm,(EqualityProblem,InequalityProblem)) -> Maybe (labels,vm,VariableMapping,(EqualityProblem,InequalityProblem))
solve (ls,vm,(eq,ineq)) = case solveProblem eq ineq of
  Nothing -> Nothing
  Just vm' -> Just (ls,vm,vm',(eq,ineq))

-- | Formats a solution (not a problem, just the solution) ready to print it out for the user.
-- Each item is shown either with its exact value or with its bounds;
-- items for which the solution has no entry are left out.
formatSolution :: (CSMR m, Monad m) => VarMap -> VariableMapping -> m String
formatSolution varToIndex vm = do
  names <- mapM valOfVar $ Map.assocs varToIndex
  return $ joinWith " , " $ catMaybes names
  where
    indexToVar = flip lookup $ map revPair $ Map.assocs varToIndex

    -- Index 0 is the constant term; zero coefficients are dropped.
    indexToVar' (0, x) = Just (Nothing, x)
    indexToVar' (_, 0) = Nothing
    indexToVar' (i, x) = case indexToVar i of
      Just v -> Just (Just v, x)
      Nothing -> Nothing

    indexToConst = getCounterEqs vm

    showWithCoeff' (Nothing, n) = return $ show n
    showWithCoeff' (Just v, n) = liftM (mult ++) $ showFlattenedExp showCode v
      where
        mult = case n of
          1 -> ""
          -1 -> "-"
          k -> show k ++ "*"

    showWithCoeff xs = liftM (joinWith " + ") $ mapM showWithCoeff' xs

    valOfVar (varExp,k) = case Map.lookup k indexToConst of
      Nothing -> return Nothing
      Just (Left (n, low, high)) -> do
        varExp' <- showWithCoeff' (Just varExp, n)
        low' <- mapM showWithCoeff $ map (mapMaybe indexToVar') low
        high' <- mapM showWithCoeff $ map (mapMaybe indexToVar') high
        return $ Just $ formatBounds (++ " <= ") low' ++ varExp' ++ formatBounds (" <= " ++) high'
      Just (Right val) -> do
        varExp' <- showFlattenedExp showCode varExp
        return $ Just $ varExp' ++ " = " ++ show val

    formatBounds _ [] = ""
    formatBounds f [b] = f b
    formatBounds f bs = f $ "(" ++ joinWith "," bs ++ ")"

-- | Shows the items of a set, separated by " + ".
showFlattenedExpSet :: Monad m => (A.Expression -> m String) -> Set.Set FlattenedExp -> m String
showFlattenedExpSet showExp s = liftM (intercalate " + ") $ mapM (showFlattenedExp showExp) $ Set.toList s

-- Shows a FlattenedExp legibly by looking up real names for variables, and formatting things.
-- The output for things involving modulo might be a bit odd, but there isn't really anything
-- much that can be done about that.  Sub-indexed (primed) variables get one "'" per sub-index.
showFlattenedExp :: Monad m => (A.Expression -> m String) -> FlattenedExp -> m String
showFlattenedExp _ (Const n) = return $ show n
showFlattenedExp showExp (Scale n (e,vi)) = do
  vn' <- showExp e >>* (++ replicate vi '\'')
  return $ showScale vn' n
showFlattenedExp showExp (Modulo n top bottom) = do
  top' <- showFlattenedExpSet showExp top
  bottom' <- showFlattenedExpSet showExp bottom
  case onlyConst (Set.toList bottom) of
    Just _ -> return $ showScale ("(" ++ top' ++ " / " ++ bottom' ++ ")") (-n)
    Nothing -> return $ showScale ("((" ++ top' ++ " REM " ++ bottom' ++ ") - " ++ top' ++ ")") n
showFlattenedExp showExp (Divide n top bottom) = do
  top' <- showFlattenedExpSet showExp top
  bottom' <- showFlattenedExpSet showExp bottom
  return $ showScale ("(" ++ top' ++ " / " ++ bottom' ++ ")") n

-- | Prefixes a shown item with its coefficient: nothing for 1, "-" for -1,
-- and "n*" otherwise.
showScale :: String -> Integer -> String
showScale s n = case n of
  1 -> s
  -1 -> "-" ++ s
  _ -> (show n) ++ "*" ++ s

-- | A type for inside makeEquations:
data FlattenedExp
  = Const Integer
    -- ^ A constant
  | Scale Integer (A.Expression, Int)
    -- ^ A variable and coefficient.  The first argument is the coefficient.
    -- The second part of the pair is for sub-indexing (or "priming") variables.
    -- For example, replication is done by checking the replicated variable "i"
    -- against a sub-indexed (with "1") version (denoted "i'").  The sub-index
    -- is what differentiates i from i', given that they are technically the
    -- same A.Variable
  | Modulo Integer (Set.Set FlattenedExp) (Set.Set FlattenedExp)
    -- ^ A modulo, with a coefficient\/scale and given top and bottom (in that order)
  | Divide Integer (Set.Set FlattenedExp) (Set.Set FlattenedExp)
    -- ^ An integer division, with a coefficient\/scale and the given top and bottom (in that order)

instance Eq FlattenedExp where
  a == b = EQ == compare a b

-- | A straightforward comparison for FlattenedExp that compares while ignoring
-- the value of a const @(Const 3 == Const 5)@ and the value of a scale
-- @(Scale 1 (v,0)) == (Scale 3 (v,0))@, although note that @(Scale 1 (v,0)) \/= (Scale 1 (v,1))@.
-- Constructors are ordered Const < Scale < Modulo < Divide.
instance Ord FlattenedExp where
  compare (Const _) (Const _) = EQ
  compare (Const _) _ = LT
  compare _ (Const _) = GT
  compare (Scale _ (lv,li)) (Scale _ (rv,ri)) = combineCompare [compare lv rv, compare li ri]
  compare (Scale {}) _ = LT
  compare _ (Scale {}) = GT
  compare (Modulo _ ltop lbottom) (Modulo _ rtop rbottom)
    = combineCompare [compare ltop rtop, compare lbottom rbottom]
  compare (Modulo {}) _ = LT
  compare _ (Modulo {}) = GT
  compare (Divide _ ltop lbottom) (Divide _ rtop rbottom)
    = combineCompare [compare ltop rtop, compare lbottom rbottom]

-- | Checks if an expression list contains only constants.  Returns Just (the
-- aggregate constant) if so, otherwise returns Nothing.
onlyConst :: [FlattenedExp] -> Maybe Integer
onlyConst [] = Just 0
onlyConst (Const n : es) = fmap (n +) (onlyConst es)
onlyConst _ = Nothing

-- | Applies the given function to every expression embedded in a
-- 'FlattenedExp', including those inside modulo and division terms.
fmapFlattenedExp :: (A.Expression -> A.Expression) -> FlattenedExp -> FlattenedExp
fmapFlattenedExp _ x@(Const _) = x
fmapFlattenedExp f (Scale n (e, i)) = Scale n (f e, i)
fmapFlattenedExp f (Modulo n top bottom) = Modulo n (Set.map (fmapFlattenedExp f) top) (Set.map (fmapFlattenedExp f) bottom)
fmapFlattenedExp f (Divide n top bottom) = Divide n (Set.map (fmapFlattenedExp f) top) (Set.map (fmapFlattenedExp f) bottom)

-- | A data type representing an array access.  Each triple is (index,
-- extra-equalities, extra-inequalities).
--
-- * Items within a Group are never paired with each other, but each can be
--   paired with every other access.
--
-- * With a Replicated, each item in the left branch can be paired with each
--   item in the right branch.  Each item in the left branch can be paired with
--   each other, and each item in the left branch can be paired with all other
--   items.
data ArrayAccess label =
  Group [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
  | Replicated [ArrayAccess label] [ArrayAccess label]

-- | A simple data type for denoting whether an array access is a read or a write
data ArrayAccessType = AAWrite | AARead

-- | Transforms the ParItems (from the control-flow graph) into the more suitable ArrayAccess
-- data type used by this array usage checker.
--
-- The supplied function receives the enclosing replicators, each flagged
-- False for the normal side and True for the mirror side.  The second result
-- lists the names of all replicators encountered, outermost first.
parItemToArrayAccessM :: Monad m =>
  ( [((A.Name, A.Replicator), Bool)] ->
    a ->
    m [(label, ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
  ) ->
  ParItems a ->
  m ([ArrayAccess label], [A.Name])
parItemToArrayAccessM f (SeqItems xs)
  -- All the sequential items together form one single Group, so they are
  -- never paired against each other:
  = do accesses <- concatMapM (f []) xs
       return ([Group accesses], [])
parItemToArrayAccessM f (ParItems ps)
  = liftM (transformPair concat concat . unzip) $ mapM (parItemToArrayAccessM f) ps
parItemToArrayAccessM f (RepParItem rep p)
  -- The body is processed twice: once normally, and once with the replicator
  -- marked as mirrored.
  = do (normal, otherReps) <- parItemToArrayAccessM (\reps -> f ((rep,False):reps)) p
       mirror <- liftM fst $ parItemToArrayAccessM (\reps -> f ((rep,True):reps)) p
       return ([Replicated normal mirror], fst rep : otherReps)

-- | Turns a list of expressions (which may contain many constants, or duplicated variables)
-- into a set of expressions with at most one constant term, and at most one appearance
-- of any variable, or distinct modulo\/division of variables.  Constants are
-- summed, and coefficients of the same (sub-indexed) variable are added.
-- If there is any problem (specifically, repeated modulo or division items)
-- an error will be returned instead.
makeExpSet :: forall m. MonadError String m => [FlattenedExp] -> m (Set.Set FlattenedExp)
makeExpSet = foldM makeExpSet' Set.empty
  where
    makeExpSet' :: Set.Set FlattenedExp -> FlattenedExp -> m (Set.Set FlattenedExp)
    makeExpSet' accum (Const n) = return $ insertMerging (addConst n) (Const n) accum
    makeExpSet' accum (Scale n v) = return $ insertMerging (addScale n v) (Scale n v) accum
    makeExpSet' accum m@(Modulo {})
      | Set.member m accum = throwError "Cannot have repeated REM items in an expression"
      | otherwise = return $ Set.insert m accum
    makeExpSet' accum d@(Divide {})
      | Set.member d accum = throwError "Cannot have repeated (/) items in an expression"
      | otherwise = return $ Set.insert d accum

    -- Rebuilds the set, letting the merge function combine the new item with
    -- a matching existing item.  If no item matched, the new item is inserted
    -- unchanged.
    insertMerging :: (FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)) -> FlattenedExp -> Set.Set FlattenedExp -> Set.Set FlattenedExp
    insertMerging f e s = case Set.fold step (Set.empty, False) s of
        (s', True) -> s'
        _ -> Set.insert e s
      where
        -- The flag records whether a merge has happened so far.  It must be
        -- carried through unchanged by non-matching items.  Resetting it would
        -- lose a merge whenever the matching item is not the last one visited
        -- by the fold, and the old term would then be replaced instead of
        -- combined (e.g. "x + 1 + x" would become "x + 1").
        step :: FlattenedExp -> (Set.Set FlattenedExp, Bool) -> (Set.Set FlattenedExp, Bool)
        step x (acc, merged) = case f x acc of
          Just acc' -> (acc', True)
          Nothing -> (Set.insert x acc, merged)

    addConst :: Integer -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)
    addConst x (Const n) s = Just $ Set.insert (Const (n + x)) s
    addConst _ _ _ = Nothing

    addScale :: Integer -> (A.Expression,Int) -> FlattenedExp -> Set.Set FlattenedExp -> Maybe (Set.Set FlattenedExp)
    addScale x (lv,li) (Scale n (rv,ri)) s
      | (EQ == compare lv rv) && (li == ri) = Just $ Set.insert (Scale (x + n) (rv,ri)) s
      | otherwise = Nothing
    addScale _ _ _ _ = Nothing

-- | A map from an item (a FlattenedExp, which may be a variable, or modulo\/divide item) to its coefficient in the problem.
type VarMap = Map.Map FlattenedExp CoeffIndex

-- | Background knowledge about a problem; either an equality or an inequality.
-- 'RepBoundsIncl' gives inclusive lower and upper bounds for a variable.
data BackgroundKnowledge
  = Equal A.Expression A.Expression
  | LessThanOrEqual A.Expression A.Expression
  | RepBoundsIncl A.Variable A.Expression A.Expression
  deriving (Typeable, Data)

instance Show BackgroundKnowledge where
  show (Equal e e') = showOccam e ++ " = " ++ showOccam e'
  show (LessThanOrEqual e e') = showOccam e ++ " <= " ++ showOccam e'
  show (RepBoundsIncl v e e') = showOccam e ++ " <= " ++ showOccam v ++ " <= " ++ showOccam e'

-- | The names relate to the equations given in my Omega Test presentation.
-- X is the top, Y is the bottom, A is the other var (x REM y = x + a).
-- The first three cases (which have no Y component) are for a constant
-- divisor; all the remaining cases are for a variable divisor.
data ModuloCase =
    XZero | XPos | XNeg
  | XPosYPosAZero | XPosYPosANonZero
  | XPosYNegAZero | XPosYNegANonZero
  | XNegYPosAZero | XNegYPosANonZero
  | XNegYNegAZero | XNegYNegANonZero
  deriving (Show, Eq, Ord)

-- | The monad used while building equations: a 'VarMap' is threaded as state
-- over read-only access to the 'CompState', and failure is a 'String'.
type BKM = StateT VarMap (ReaderT CompState (Either String))

-- | Transforms background knowledge into problems.  The given function is
-- applied to each flattened side (the bounds of a 'RepBoundsIncl' included).
-- For 'RepBoundsIncl', the bounds are stated for both the plain variable and
-- its primed (sub-index 1) version.
-- TODO allow modulo in background knowledge
transformBK :: ([FlattenedExp] -> [FlattenedExp]) -> BackgroundKnowledge -> BKM (EqualityProblem,InequalityProblem)
transformBK f bk = case bk of
    Equal eL eR -> do
      lhs <- side eL
      rhs <- side eR
      -- eL = eR becomes eL - eR = 0
      return ([addEq lhs (amap negate rhs)], [])
    LessThanOrEqual eL eR -> do
      lhs <- side eL
      rhs <- side eR
      -- eL <= eR implies eR - eL >= 0
      return ([], [addEq (amap negate lhs) rhs])
    RepBoundsIncl v low high -> do
      lo <- makeSingleEq f low "background knowledge, lower bound"
      hi <- makeSingleEq f high "background knowledge, upper bound"
      plain <- repVarEq v 0
      primed <- repVarEq v 1
      -- v <= high implies high - v >= 0; low <= v implies v - low >= 0.
      -- Each is stated for the plain and for the primed variable.
      return ([], [ addEq (amap negate plain) hi
                  , addEq (amap negate primed) hi
                  , addEq (amap negate lo) plain
                  , addEq (amap negate lo) primed
                  ])
  where
    -- One side of an (in)equality.
    side e = makeSingleEq f e "background knowledge"

    -- The equation for the variable with the given sub-index.
    repVarEq v idx = makeEquation v ([], id) (error "Irrelevant type") [Scale 1 (A.ExprVariable emptyMeta v, idx)]
                       >>= getSingleAccessItem ("Modulo or divide impossible")

-- | Transforms every piece of background knowledge in turn and joins all the
-- resulting problems together.
transformBKList :: ([FlattenedExp] -> [FlattenedExp]) -> [BackgroundKnowledge] -> BKM (EqualityProblem,InequalityProblem)
transformBKList f bk = liftM (foldl accumProblem ([], [])) (mapM (transformBK f) bk)

-- | Turns a single expression into an equation-item.
-- An error is given if the resulting expression is anything complicated
-- (for example, modulo or divide).
--
-- The flattened expression is passed through the given function before it is
-- turned into an equation.  The String describes the role of the expression
-- and is used only in error messages.
makeSingleEq :: ([FlattenedExp] -> [FlattenedExp]) -> A.Expression -> String -> BKM EqualityConstraintEquation
makeSingleEq f e desc = do
  flat <- lift (flatten e)
  access <- makeEquation e ([] {- TODO -}, f) (error $ "Type is irrelevant for " ++ desc) (f flat)
  getSingleAccessItem ("Modulo or Divide not allowed in " ++ desc ++ "(while processing: " ++ showOccam e ++ ")") access

-- | A helper function for joining two problems (delegates to 'concatPair').
accumProblem :: (EqualityProblem,InequalityProblem) -> (EqualityProblem,InequalityProblem) -> (EqualityProblem,InequalityProblem)
accumProblem p q = concatPair p q

-- | Given a ParItems of (background knowledge, written, read) expressions and
-- an expression giving the array length (indexes are bounded below by 0 and
-- above by length - 1), returns either an error (typically because the
-- expressions can't be handled) or a list of problems.  Each problem is a
-- set of equalities and inequalities, labelled with the pair of accesses it
-- compares, together with the mapping from (unique, munged) variable name to
-- variable-index in the equations.
--
-- The general strategy is as follows.
-- For every array index (here termed an "access"), we transform it into
-- the usual @[FlattenedExp]@ using the flatten function.  Then we also transform
-- any access that is in the mirror-side of a Replicated item into its mirrored version
-- where each i is changed into i\'.  This is done by using @vi=(variable "i",0)@
-- (in @Scale _ vi@) for the plain (normal) version, and @vi=(variable "i",1)@
-- for the prime (mirror) version.
--
-- Then the equations have bounds added.  The rules are fairly simple; if
-- any of the transformed EqualityConstraintEquation (or related equalities or inequalities) representing an access
-- have a non-zero i (and\/or i\'), the bound for that variable is added.
-- So for example, an expression like i = i\' + 3 would have the bounds for
-- both i and i\' added (which would be near-identical, e.g. 1 <= i <= 6 and
-- 1 <= i\' <= 6).
-- We have to check the equalities and inequalities because
-- when processing modulo, for the i REM y == 0 option, i will not appear in
-- the index itself (which will be 0) but will appear in the surrounding
-- constraints, and we still want to add the replication bounds.
--
-- The remainder of the work (correctly pairing equations) is done by
-- squareAndPair.
--
-- TODO probably want to take this into the PassM monad at some point, to use the Meta in the error message
makeEquations :: ParItems (BK, [A.Expression], [A.Expression]) -> A.Expression -> ReaderT CompState (Either String) [(((A.Expression, [ModuloCase]), (A.Expression, [ModuloCase])), VarMap, (EqualityProblem, InequalityProblem))]
makeEquations accesses bound
  -- Two layers of state are used.  The outer runStateT threads the VarMap
  -- (starting empty) through the BKM computations.  The inner runStateT
  -- collects the (plain, prime) coefficient-index pairs that
  -- mirrorFlaggedVar records for each mirrored replicator variable; these
  -- are de-duplicated with nub.
  = do ((v,h,repVarIndexes, allReps),s) <- (flip runStateT) Map.empty $ do
         ((accesses', allReps),repVars) <- flip runStateT [] $ parItemToArrayAccessM mkEq accesses
         high <- makeSingleEq id bound "upper bound"
         return (accesses', high, nub repVars, allReps)
       -- The index bounds: an all-zero equation (same shape as h) for the
       -- lower bound, and h adjusted by -1 (presumably h - 1) for the upper.
       lift $ squareAndPair (lookupBK allReps) (\(x,y,_) -> (x,y)) repVarIndexes s v (amap (const 0) h, addConstant (-1) h)

  where
    -- Finds the background-knowledge problems relevant to one access.  In each
    -- disjunct of the access's BK', only the entries for variables occurring
    -- in the access expression, or for replicator variables, are kept; their
    -- problems are joined together.  An error in any kept entry makes the
    -- whole result an error.  Disjuncts with no equalities and no
    -- inequalities are dropped.
    lookupBK :: [A.Name] -> (A.Expression, [ModuloCase], BK') -> Either String [(EqualityProblem, InequalityProblem)]
    lookupBK reps (e,_,bk) = liftM (filter (\x -> (not $ null $ fst x) || (not $ null $ snd x))) $
      mapM (foldl (liftM2 accumProblem) (return ([],[])) . map snd . filter (\(v,_) -> v `elem` vs || v `elem` reps') . Map.toList) bk
      where
        reps' :: [Var]
        reps' = map (Var . A.Variable emptyMeta) reps

        -- Every variable occurring anywhere in the access expression.
        vs :: [Var]
        vs = map Var $ listifyDepth (const True :: A.Variable -> Bool) e

    -- | A front-end to the setIndexVar' function: applies it to every item of
    -- the list.
    setIndexVar :: A.Variable -> Int -> [FlattenedExp] -> [FlattenedExp]
    setIndexVar tv ti = map (setIndexVar' tv ti)

    -- | Sets the sub-index of the specified variable throughout the expression
    -- (including inside modulo and division terms).  Any other term is
    -- returned unchanged.
    setIndexVar' :: A.Variable -> Int -> FlattenedExp -> FlattenedExp
    setIndexVar' tv ti s@(Scale n (v,_))
      | EQ == compare (A.ExprVariable emptyMeta tv) v = Scale n (v,ti)
      | otherwise = s
    setIndexVar' tv ti (Modulo n top bottom) = Modulo n top' bottom'
      where
        top' = Set.map (setIndexVar' tv ti) top
        bottom' = Set.map (setIndexVar' tv ti) bottom
    setIndexVar' tv ti (Divide n top bottom) = Divide n top' bottom'
      where
        top' = Set.map (setIndexVar' tv ti) top
        bottom' = Set.map (setIndexVar' tv ti) bottom
    setIndexVar' _ _ e = e

    -- | Given a list of replicators (marked enabled\/disabled by a flag), the writes and reads,
    -- turns them into a single list of accesses with all the relevant information.  The writes and reads
    -- can be grouped together because they are differentiated by the ArrayAccessType in the result
    mkEq :: [((A.Name, A.Replicator), Bool)] -> (BK, [A.Expression], [A.Expression]) -> StateT [(CoeffIndex, CoeffIndex)] BKM [((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
    mkEq reps (bk, ws, rs)
      -- NOTE(review): repVarEqs is passed to mkEq' but not used there.  The
      -- makeRepVarEq calls appear to matter only for their BKM state effects
      -- (presumably registering the replicator expressions in the VarMap).
      -- Confirm before removing them.
      = do repVarEqs <- mapM (liftF makeRepVarEq) reps
           concatMapM (mkEq' repVarEqs) (ws' ++ rs')
      where
        ws' = zip (repeat AAWrite) ws
        rs' = zip (repeat AARead) rs

        -- Equations for a replicator's start and its last value (start + count - 1).
        -- NOTE(review): only the A.For replicator form is matched here.
        makeRepVarEq :: ((A.Name, A.Replicator), Bool) -> BKM (A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)
        makeRepVarEq ((varName, A.For m from for _), _)
          = do from' <- makeSingleEq id from "replication start"
               upper <- makeSingleEq id (subExprsInt (addExprsInt for from) (makeConstant m 1)) "replication count"
               return (A.Variable m varName, from', upper)

        -- Flattens one access, primes the variables of any mirrored
        -- replicators, and turns it into equations.  makeEquation must give
        -- back a Group here; anything else is reported as an error.
        mkEq' :: [(A.Variable, EqualityConstraintEquation, EqualityConstraintEquation)] -> (ArrayAccessType, A.Expression) -> StateT [(CoeffIndex,CoeffIndex)] BKM [((A.Expression, [ModuloCase], BK'), ArrayAccessType, (EqualityConstraintEquation, EqualityProblem, InequalityProblem))]
        mkEq' repVarEqs (aat, e)
          = do f <- lift . lift $ flatten e
               mirrorFunc <- liftM foldFuncs $ mapM mirrorFlaggedVar reps
               g <- lift $ makeEquation e (bk, mirrorFunc) aat (mirrorFunc f)
               case g of
                 Group g' -> return g'
                 _ -> throwError "Replicated group found unexpectedly"

        -- | Turns all instances of the variable from the given replicator into their primed version in the given expression.
        -- For a mirrored replicator, it also records the (plain, prime)
        -- coefficient indexes of the variable in the collected state.
        mirrorFlaggedVar :: ((A.Name, A.Replicator),Bool) -> StateT [(CoeffIndex,CoeffIndex)] BKM ([FlattenedExp] -> [FlattenedExp])
        mirrorFlaggedVar (_,False) = return id
        mirrorFlaggedVar ((varName, A.For m from for _), True)
          = do varIndexes <- lift $ seqPair (varIndex (Scale 1 (A.ExprVariable emptyMeta var,0)), varIndex (Scale 1 (A.ExprVariable emptyMeta var,1)))
               modify (varIndexes :)
               return $ setIndexVar var 1
          where
            var = A.Variable m varName

-- Dying inside this monad simply becomes a Left carrying the message.
instance Die (ReaderT CompState (Either String)) where
  dieReport (_, s) = throwError s

-- Note that in all these functions, the divisor should always be positive!

-- | Puts an expression into a canonical form.  A call of the default "+" or
-- "*" operator is flattened into all of its operands (nested calls of the same
-- operator are gathered), which are canonicalised and sorted.  They are then
-- rebuilt as a left-nested chain of calls, so that reordered sums and products
-- compare equal.  Other function calls just have their arguments
-- canonicalised; any other expression is returned as-is.
canonicalise :: forall m. (CSMR m, Die m) => A.Expression -> m A.Expression
canonicalise e@(A.FunctionCall m n es)
  = do mOp <- functionOperator n
       ts <- mapM astTypeOf es
       case (mOp, fmap (\op -> A.nameName n == occamDefaultOperator op ts) mOp) of
         (Just op, Just True) | op == "+" || op == "*" ->
           liftM (foldl1 (\a b -> A.FunctionCall m n [a, b]) . sort) $ gatherTerms n e
         _ -> mapM canonicalise es >>* A.FunctionCall m n
  where
    -- Collects the operands of nested calls to the same function name n.
    gatherTerms :: A.Name -> A.Expression -> m [A.Expression]
    gatherTerms n (A.FunctionCall _ n' es)
      | n == n' = concatMapM (gatherTerms n) es
    gatherTerms _ e = canonicalise e >>* singleton
canonicalise e = return e

instance CSMR (ReaderT CompState (Either String)) where
  getCompState = ask

-- | Turns an expression into a list of terms (to be summed).  Integer literals
-- become constants.  Binary +, -, *, REM (\\) and / are broken down into
-- terms; division is only broken down when the divisor is constant.  Anything
-- else is kept, canonicalised, as a single opaque term with coefficient 1.
flatten :: A.Expression -> ReaderT CompState (Either String) [FlattenedExp]
-- NOTE(review): read is partial; this assumes IntLiteral holds a plain decimal string.
flatten (A.Literal _ _ (A.IntLiteral _ n)) = return [Const (read n)]
flatten e@(A.FunctionCall m fn [lhs, rhs])
  = do mOp <- builtInOperator fn
       case mOp of
         Just "+" -> combine' (flatten lhs) (flatten rhs)
         Just "-" -> combine' (flatten lhs) (mapM (scale (-1)) =<< flatten rhs)
         Just "*" -> multiplyOut' (flatten lhs) (flatten rhs)
         Just "\\" -> liftM2L (Modulo 1) (flatten lhs) (flatten rhs)
         Just "/" -> do
           rhs' <- flatten rhs
           case onlyConst rhs' of
             Just _ -> liftM2L (Divide 1) (flatten lhs) (return rhs')
             -- Can't deal with variable divisors, leave expression as-is:
             Nothing -> do
               e' <- canonicalise e
               return [Scale 1 (e',0)]
         _ -> do
           e' <- canonicalise e
           return [Scale 1 (e',0)]
  where
    -- Builds a single modulo/division term from two flattened operands, each
    -- converted into a set with makeExpSet.
    liftM2L :: MonadError String m => (Set.Set FlattenedExp -> Set.Set FlattenedExp -> c) -> m [FlattenedExp] -> m [FlattenedExp] -> m [c]
    liftM2L f x y = liftM singleton $ liftM2 f (x >>= makeExpSet) (y >>= makeExpSet)

    multiplyOut' :: (Die m, CSMR m, MonadError String m ) => m [FlattenedExp] -> m [FlattenedExp] -> m [FlattenedExp]
    multiplyOut' x y = join $ liftM2 multiplyOut x y

    -- Multiplies two sums of terms by multiplying every pair of terms.  A
    -- constant just scales the other term.  Two non-constant terms are turned
    -- back into expressions, multiplied, and kept as one opaque term.
    multiplyOut :: forall m. (Die m, CSMR m, MonadError String m) => [FlattenedExp] -> [FlattenedExp] -> m [FlattenedExp]
    multiplyOut lhs rhs = mapM (uncurry mult) pairs
      where
        pairs = product2 (lhs,rhs)

        mult :: FlattenedExp -> FlattenedExp -> m FlattenedExp
        mult (Const x) e = scale x e
        mult e (Const x) = scale x e
        mult lhs rhs
          = do lhs' <- backToEq lhs
               rhs' <- backToEq rhs
               e <- mulExprs lhs' rhs' >>= canonicalise
               return $ (Scale 1 (e, 0))

        -- Multiplies an expression by a coefficient (a no-op for 1).
        backScale :: Integer -> A.Expression -> m A.Expression
        backScale 1 e = return e
        backScale n e = do
          t <- astTypeOf e
          mulExprs (makeConstant' emptyMeta t n) e >>= canonicalise

        -- Turns a term back into an expression.
        -- NOTE(review): there is no equation for a Scale with a non-zero
        -- sub-index; flatten only ever produces sub-index 0.
        backToEq :: FlattenedExp -> m A.Expression
        backToEq (Const c) = return $ makeConstant emptyMeta (fromInteger c)
        backToEq (Scale n (e,0)) = backScale n e
        backToEq (Modulo n t b)
          | Set.null t || Set.null b = throwError "Modulo had empty top or bottom"
          | otherwise = do
              t' <- mapM backToEq $ Set.toList t
              b' <- mapM backToEq $ Set.toList b
              t'' <- foldM1 addExprs t'
              b'' <- foldM1 addExprs b'
              remExprs t'' b'' >>= backScale n
        backToEq (Divide n t b)
          | Set.null t || Set.null b = throwError "Divide had empty top or bottom"
          | otherwise = do
              t' <- mapM backToEq $ Set.toList t
              b' <- mapM backToEq $ Set.toList b
              t'' <- foldM1 addExprs t'
              b'' <- foldM1 addExprs b'
              divExprs t'' b'' >>= backScale n

    -- | Scales a flattened expression by the given integer scaling.
    scale :: Monad m => Integer -> FlattenedExp -> m FlattenedExp
    scale sc (Const n) = return $ Const (n * sc)
    scale sc (Scale n v) = return $ Scale (n * sc) v
    scale sc (Modulo n t b) = return $ Modulo (n * sc) t b
    scale sc (Divide n t b) = return $ Divide (n * sc) t b

    -- | An easy way of applying combine to two monadic returns
    combine' :: Monad m => m [FlattenedExp] -> m [FlattenedExp] -> m [FlattenedExp]
    combine' = liftM2 combine

    -- | Combines (adds) two flattened expressions; a sum is simply the
    -- concatenation of the two term lists.
    combine :: [FlattenedExp] -> [FlattenedExp] -> [FlattenedExp]
    combine = (++)
flatten e = do
  e' <- canonicalise e
  return [Scale 1 (e',0)]

-- | The "square" refers to making all equations the length of the longest
-- one, and the pair refers to pairing each in a list of array accesses (e.g.
-- [0, 5, i + 2]) into all possible pairings ([0 == 5, 0 == i + 2, 5 == i + 2])
--
-- There are two complications to this function.
--
-- Firstly, the array accesses are not actually given in a plain list, but
-- in groups (as 'ArrayAccess' items).  This is because for things like modulo, there are
-- groups of possible accesses that should not be paired against each other.
-- For example, you may have something like [0,x,-x] as the three possible
-- options for a modulo.  You want to pair the accesses against other accesses
-- (e.g. y + 6), but not against each other.  So the arguments are passed in
-- in groups: [[0,x,-x],[y + 6]] and groups are paired against each other,
-- but not against themselves.  This all refers to the fifth argument to the
-- function (the list of 'ArrayAccess' items).  Each item is actually a triple of (item, equalities, inequalities)
-- because the modulo aspect adds additional constraints.
--
-- The other complication comes from replicated variables.
-- The third argument is a list of (plain,prime) coefficient indexes
-- that effectively labels the indexes related to replicated variables.
-- squareAndPair does two things with this information:
--   1. It discards all equations that feature only the prime version of
--   a variable.  You might have passed in the accesses as [[i],[i'],[3]].
--   (Altering the grouping would not be able to solve this particular problem)
--   The pairings generated would be [i == i', i == 3, i' == 3].  But the
--   last two are in effect identical.  Therefore we drop the i' prime
--   version, because it has i' but not i.  In contrast, the first item
--   (i == i') is retained because it features both i and i'.
--   2.
For every equation that features both i and i', it adds -- the inequality "i <= i' - 1". Because all possible combinations of -- accesses are examined, in the case of [i,i + 1,i', i' + 1], the pairing -- will produce both "i = i' + 1" and "i + 1 = i'" so there is no need -- to vary the inequality itself. squareAndPair :: (label -> Either String [(EqualityProblem, InequalityProblem)]) -> (label -> labelStripped) -> [(CoeffIndex, CoeffIndex)] -> VarMap -> [ArrayAccess label] -> (EqualityConstraintEquation, EqualityConstraintEquation) -> Either String [((labelStripped, labelStripped), VarMap, (EqualityProblem, InequalityProblem))] squareAndPair lookupBK strip repVars s v lh = concatMapM id [let f ((bkEqA, bkIneqA), (bkEqB, bkIneqB)) = (transformPair strip strip labels, s, squareEquations (nub (bkEqA ++ bkEqB) ++ eq, nub (bkIneqA ++ bkIneqB) ++ ineq ++ concat (applyAll (eq,ineq) (map addExtra repVars)))) bk = case liftM2 (curry product2) (liftM atLeastOne $ lookupBK (fst labels)) (liftM atLeastOne $ lookupBK (snd labels)) of Right [] -> Right [(([],[]),([],[]))] -- No BK xs -> xs in bk >>* map f | (labels, eq,ineq) <- pairEqsAndBounds v lh ,and (map (primeImpliesPlain (eq,ineq)) repVars) ] where atLeastOne :: [(EqualityProblem, InequalityProblem)] -> [(EqualityProblem, InequalityProblem)] atLeastOne [] = [([], [])] atLeastOne xs = xs itemPresent :: CoeffIndex -> [Array CoeffIndex Integer] -> Bool itemPresent x = any (\a -> arrayLookupWithDefault 0 a x /= 0) primeImpliesPlain :: (EqualityProblem,InequalityProblem) -> (CoeffIndex,CoeffIndex) -> Bool primeImpliesPlain (eq,ineq) (plain,prime) = if itemPresent prime (eq ++ ineq) -- There are primes, check all the plains are present: then itemPresent plain (eq ++ ineq) -- No prime, therefore fine: else True addExtra :: (CoeffIndex, CoeffIndex) -> (EqualityProblem,InequalityProblem) -> InequalityProblem addExtra (plain,prime) (eq, ineq) -- prime >= plain + 1 (prime - plain - 1 >= 0) = [mapToArray $ Map.fromList [(prime,1), 
(plain,-1), (0, -1)]] getSingleAccessItem :: MonadError String m => String -> ArrayAccess label -> m EqualityConstraintEquation getSingleAccessItem _ (Group [(_,_,(acc,_,_))]) = return acc getSingleAccessItem err _ = throwError err -- | Odd helper function for getting\/asserting the first item of a triple from a singleton list inside a monad transformer (!) getSingleItem :: MonadError String m => String -> [(a,b,c)] -> m a getSingleItem _ [(item,_,_)] = return item getSingleItem err _ = throwError err -- | Finds the index associated with a particular variable; either by finding an existing index -- or allocating a new one. varIndex :: FlattenedExp -> BKM Int varIndex (Scale _ (e,vi)) = do st <- get let (st',ind) = case Map.lookup (Scale 1 (e,vi)) st of Just val -> (st,val) Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in (Map.insert (Scale 1 (e,vi)) newId st, newId) put st' return ind varIndex mod@(Modulo _ top bottom) = do st <- get let (st',ind) = case Map.lookup mod st of Just val -> (st,val) Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in (Map.insert mod newId st, newId) put st' return ind varIndex div@(Divide _ top bottom) = do st <- get let (st',ind) = case Map.lookup div st of Just val -> (st,val) Nothing -> let newId = (1 + (maximum $ 0 : Map.elems st)) in (Map.insert div newId st, newId) put st' return ind -- | Pairs all possible combinations of the list of equations. pairEqsAndBounds :: [ArrayAccess label] -> (EqualityConstraintEquation, EqualityConstraintEquation) -> [((label,label),EqualityProblem, InequalityProblem)] pairEqsAndBounds items bounds = (concatMap (uncurry pairEqs) . 
allPairs) items ++ concatMap pairRep items where pairEqs :: ArrayAccess label -> ArrayAccess label -> [((label,label),EqualityProblem, InequalityProblem)] pairEqs (Group accs) (Group accs') = mapMaybe (uncurry pairEqs'') $ product2 (accs,accs') pairEqs (Replicated rA rB) lacc = concatMap (pairEqs lacc) rA pairEqs lacc (Replicated rA rB) = concatMap (pairEqs lacc) rA -- Used to pair the items of a single instance of PAR replication with each other pairRep :: ArrayAccess label -> [((label,label),EqualityProblem, InequalityProblem)] pairRep (Replicated rA rB) = concatMap (uncurry pairEqs) (product2 (rA,rB)) ++ concatMap (uncurry pairEqs) (allPairs rA) pairRep _ = [] pairEqs'' :: (label, ArrayAccessType,(EqualityConstraintEquation, EqualityProblem, InequalityProblem)) -> (label, ArrayAccessType,(EqualityConstraintEquation, EqualityProblem, InequalityProblem)) -> Maybe ((label,label), EqualityProblem, InequalityProblem) pairEqs'' (lx,x,x') (ly,y,y') = case pairEqs' (x,x') (y,y') of Just (eq,ineq) -> Just ((lx,ly),eq,ineq) Nothing -> Nothing pairEqs' :: (ArrayAccessType,(EqualityConstraintEquation, EqualityProblem, InequalityProblem)) -> (ArrayAccessType,(EqualityConstraintEquation, EqualityProblem, InequalityProblem)) -> Maybe (EqualityProblem, InequalityProblem) pairEqs' (AARead,_) (AARead,_) = Nothing pairEqs' (_,(ex,eqX,ineqX)) (_,(ey,eqY,ineqY)) = Just ([arrayZipWith' 0 (-) ex ey] ++ eqX ++ eqY, ineqX ++ ineqY ++ getIneqs bounds [ex,ey]) addEq :: EqualityConstraintEquation -> EqualityConstraintEquation -> EqualityConstraintEquation addEq = arrayZipWith' 0 (+) -- | Given a (low,high) bound (typically: array dimensions), and a list of equations ex, -- forms the possible inequalities: -- * ex >= low -- * ex <= high getIneqs :: (EqualityConstraintEquation, EqualityConstraintEquation) -> [EqualityConstraintEquation] -> [InequalityConstraintEquation] getIneqs (low, high) = concatMap getLH where -- eq >= low => eq - low >= 0 -- eq <= high => high - eq >= 0 getLH :: 
EqualityConstraintEquation -> [InequalityConstraintEquation] getLH eq = [eq `addEq` (amap negate low),high `addEq` amap negate eq] justState :: Error e => StateT s (ReaderT r (Either e)) a -> StateT s (ReaderT r (Either e)) (Either e a) justState m = do st <- get r <- ask let (x, st') = case runReaderT (runStateT m st) r of Left err -> (Left err, st) Right (x, st') -> (Right x, st') put st' return x -- | Given an expression, forms equations (and accompanying additional equation-sets) and returns it makeEquation :: label -> (BK, [FlattenedExp] -> [FlattenedExp]) -> ArrayAccessType -> [FlattenedExp] -> BKM (ArrayAccess (label,[ModuloCase], BK')) makeEquation l (bk, bkF) t summedItems = do eqs <- process summedItems bk' <- mapM (mapMapM (justState . transformBKList bkF)) bk let eqs' = map (transformQuad id mapToArray (map mapToArray) (map mapToArray)) eqs :: [([ModuloCase], EqualityConstraintEquation, EqualityProblem, InequalityProblem)] return $ Group [((l,c,bk'),t,(e0,e1,e2)) | (c,e0,e1,e2) <- eqs'] where process :: [FlattenedExp] -> BKM [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] process = foldM makeEquation' empty makeEquation' :: [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] -> FlattenedExp -> BKM [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] makeEquation' m (Const n) = return $ add (0,n) m makeEquation' m sc@(Scale n v) = varIndex sc >>* (\ind -> add (ind, n) m) makeEquation' m mod@(Modulo n top bottom) = do top' <- process (Set.toList top) >>* map (\(_,a,b,c) -> (a,b,c)) top'' <- getSingleItem "Modulo or divide not allowed in the numerator of Modulo" top' bottom' <- process (Set.toList bottom) >>* map (\(_,a,b,c) -> (a,b,c)) modIndex <- varIndex mod case onlyConst (Set.toList bottom) of Just bottomConst -> let add_x_plus_my = zipMap plus top'' . 
zipMap plus (Map.fromList [(modIndex, abs bottomConst)]) in -- Adds n*(x + my) let add_n_x_plus_my = zipMap plus (Map.map (*n) top'') . zipMap plus (Map.fromList [(modIndex, n * abs bottomConst)]) in return $ -- The zero option (x = 0, x REM y = 0): ( map (transformQuad (++ [XZero]) id (++ [top'']) id) m) ++ -- The top-is-positive option: ( map (transformQuad (++ [XPos]) add_n_x_plus_my id (++ -- x >= 1 [zipMap plus (Map.fromList [(0,-1)]) top'' -- m <= 0 ,Map.fromList [(modIndex,-1)] -- x + my + 1 - |y| <= 0 ,Map.map negate $ add_x_plus_my $ Map.fromList [(0,1 - abs bottomConst)] -- x + my >= 0 ,add_x_plus_my $ Map.empty]) ) m) ++ -- The top-is-negative option: ( map (transformQuad (++ [XNeg]) add_n_x_plus_my id (++ -- x <= -1 [add' (0,-1) $ Map.map negate top'' -- m >= 0 ,Map.fromList [(modIndex,1)] -- x + my - 1 + |y| >= 0 ,add_x_plus_my $ Map.fromList [(0,abs bottomConst - 1)] -- x + my <= 0 ,Map.map negate $ add_x_plus_my Map.empty]) ) m) _ -> do bottom'' <- getSingleItem "Modulo or divide not allowed in the divisor of Modulo" bottom' return $ -- The zero option (x = 0, x REM y = 0): (map (transformQuad (++ [XZero]) id (++ [top'']) id) m) -- The rest: ++ twinItems True True n (top'', modIndex) bottom'' ++ twinItems True False n (top'', modIndex) bottom'' ++ twinItems False True n (top'', modIndex) bottom'' ++ twinItems False False n (top'', modIndex) bottom'' where -- Each pair for modulo (variable divisor) depending on signs of x and y (in x REM y): twinItems :: Bool -> Bool -> Integer -> (Map.Map Int Integer,Int) -> Map.Map Int Integer -> [([ModuloCase], Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] twinItems xPos yPos n (top,modIndex) bottom = (map (transformQuad (++ [findCase xPos yPos False]) (zipMap plus $ Map.map (*n) top) id (++ [xEquation] ++ [xLowerBound False] ++ [xUpperBound False])) m) ++ (map (transformQuad (++ [findCase xPos yPos True]) (zipMap plus (Map.map (*n) top) . 
add' (modIndex, n)) id (++ [xEquation] ++ [xLowerBound True] ++ [xUpperBound True] -- We want to add the bounds for a and y as follows: -- xPos yPos | Equation -- T T | -y - a >= 0 -- T F | y - a >= 0 -- F T | a - y >= 0 -- F F | a + y >= 0 -- Therefore the sign of a is (not xPos), the sign of y is (not yPos) ++ [add' (modIndex,if xPos then -1 else 1) (signEq (not yPos) bottom)])) m) where -- x >= 1 or x <= -1 (rearranged: -1 + x >= 0 or -1 - x >= 0) xEquation = add' (0,-1) (signEq xPos top) -- We include (x [+ a] >= 0 or x [+ a] <= 0) even though they are redundant in some cases (addA = False): xLowerBound addA = signEq xPos $ (if addA then add' (modIndex,1) else id) top -- We want to add the bounds as follows: -- xPos yPos | Equation -- T T | y - 1 - x - a >= 0 -- T F | -y - 1 - x - a >= 0 -- F T | x + a - 1 + y >= 0 -- F F | x + a - y - 1 >= 0 -- Therefore the sign of y in the equation is yPos, the sign of x and a is (not xPos) xUpperBound addA = add' (0,-1) $ zipMap plus (signEq (not xPos) ((if addA then add' (modIndex,1) else id) top)) (signEq yPos bottom) signEq sign eq = if sign then eq else Map.map negate eq findCase xPos yPos aNonZero = case (xPos, yPos, aNonZero) of (True , True , True ) -> XPosYPosANonZero (True , True , False) -> XPosYPosAZero (True , False, True ) -> XPosYNegANonZero (True , False, False) -> XPosYNegAZero (False, True , True ) -> XNegYPosANonZero (False, True , False) -> XNegYPosAZero (False, False, True ) -> XNegYNegANonZero (False, False, False) -> XNegYNegAZero makeEquation' m div@(Divide n top bottom) = do top' <- process (Set.toList top) >>* map (\(_,a,b,c) -> (a,b,c)) top'' <- getSingleItem "Modulo or Divide not allowed in the numerator of Divide" top' bottom' <- process (Set.toList bottom) >>* map (\(_,a,b,c) -> (a,b,c)) divIndex <- varIndex div case onlyConst (Set.toList bottom) of Just bottomConst -> let add_m :: Map.Map Int Integer -> Map.Map Int Integer add_m = zipMap plus (Map.fromList [(divIndex,n)]) add_x_minus_my = 
zipMap plus top'' . zipMap plus (Map.fromList [(divIndex,-bottomConst)]) in return $ -- The zero option (x = 0, x REM y = 0): ( map (transformQuad (++ [XZero]) id (++ [top'']) id) m) ++ -- The top-is-positive option: ( map (transformQuad (++ [XPos]) add_m id (++ -- x >= 1 [zipMap plus (Map.fromList [(0,-1)]) top'' -- m >= 0 if y positive -- m <= 0 (i.e. -m >= 0) if y negative ,Map.fromList [(divIndex, signum bottomConst)] -- x + my + 1 - y <= 0 if y positive -- x + my - 1 - y >= 0 if y negative ,(if (bottomConst > 0) then Map.map negate else id) $ add_x_minus_my $ Map.fromList [(0,signum bottomConst - bottomConst)] -- x + my >= 0 if y positive -- x + my <= 0 if negative ,(if (bottomConst > 0) then id else Map.map negate) $ add_x_minus_my $ Map.empty]) ) m) ++ -- The top-is-negative option: ( map (transformQuad (++ [XNeg]) add_m id (++ -- x <= -1 [add' (0,-1) $ Map.map negate top'' -- m <= 0 if y positive -- m >= 0 if y negative ,Map.fromList [(divIndex, - signum bottomConst)] -- x + my - 1 + y >= 0 if y positive -- x + my + 1 + y <= 0 if y negative ,(if (bottomConst > 0) then id else Map.map negate) $ add_x_minus_my $ Map.fromList [(0,bottomConst - signum bottomConst)] -- x + my <= 0 if y positive -- x + my >= 0 if y negative ,(if (bottomConst > 0) then Map.map negate else id) $ add_x_minus_my Map.empty]) ) m) _ -> throwError "Variables in divisor not supported by usage checker" empty :: [([ModuloCase],Map.Map Int Integer,[Map.Map Int Integer], [Map.Map Int Integer])] empty = [([],Map.empty,[],[])] plus :: Num n => Maybe n -> Maybe n -> Maybe n plus x y = Just $ (fromMaybe 0 x) + (fromMaybe 0 y) add' :: (Int,Integer) -> Map.Map Int Integer -> Map.Map Int Integer add' (m,n) = Map.insertWith (+) m n add :: (Int,Integer) -> [(z,Map.Map Int Integer,a,b)] -> [(z,Map.Map Int Integer,a,b)] add (m,n) = map $ (\(a,b,c,d) -> (a,(Map.insertWith (+) m n) b,c,d)) -- | Converts a map to an array. Any missing elements in the middle of the bounds are given the value zero. 
-- Could probably be moved to Utils mapToArray :: (IArray a v, Num v, Num k, Ord k, Ix k) => Map.Map k v -> a k v mapToArray m = accumArray (+) 0 (0, highest') . Map.assocs $ m where highest' = maximum $ 0 : Map.keys m -- | Given a pair of equation sets, makes all the equations in the lists be the length -- of the longest equation. All missing elements are of course given value zero. squareEquations :: ([Array CoeffIndex Integer],[Array CoeffIndex Integer]) -> ([Array CoeffIndex Integer],[Array CoeffIndex Integer]) squareEquations (eqs,ineqs) = uncurry transformPair (mkPair $ map $ makeArraySize (0,highest) 0) (eqs,ineqs) where highest = maximum $ 0 : (concatMap indices $ eqs ++ ineqs)