Decision Trees Are Free Monads Over the Reader Functor

Clay Thomas

Motivation

Free f a, the free monad over a given functor f, is often described as "trees which branch in the shape of f and are leaf-labeled by a". What exactly does this mean? Well, the definition of Free is:

data Free f a
  = Pure a
  | Free (f (Free f a))

So if we have a functor data Pair a = Pair a a, then Free Pair a indeed represents ordinary leaf-labeled, nonempty binary trees.
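To make the "leaf-labeled tree" reading concrete, here is a minimal sketch that defines its own copy of Free (so it needs only base) together with Pair. The helpers leaves and graft are names introduced here for illustration; graft follows the same recursion that >>= performs for Free, replacing every leaf with a new subtree.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- A local copy of Free, so this sketch needs only base.
data Free f a
  = Pure a
  | Free (f (Free f a))

-- Two-way branching: each internal node has exactly two children.
data Pair x = Pair x x
  deriving Functor

-- Leaves 1, 2, 3 from left to right; 2 and 3 share a parent.
example :: Free Pair Int
example = Free (Pair (Pure 1) (Free (Pair (Pure 2) (Pure 3))))

-- Collect the leaf labels from left to right.
leaves :: Free Pair a -> [a]
leaves (Pure a)          = [a]
leaves (Free (Pair l r)) = leaves l ++ leaves r

-- Replace every leaf with a whole new tree. This is the same
-- recursion that (>>=) performs for the free monad.
graft :: Functor f => Free f a -> (a -> Free f b) -> Free f b
graft (Pure a) k = k a
graft (Free u) k = Free (fmap (`graft` k) u)

main :: IO ()
main = do
  print (leaves example)
  print (leaves (graft example (\x -> Free (Pair (Pure x) (Pure (x * 10))))))
```

The second line printed is [1,10,2,20,3,30]: each original leaf x has become a small tree with leaves x and x * 10.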

Now, if you know about decision trees, "leaf-labeled" should catch your ear. (Binary) decision trees are trees where each internal node tests a yes/no feature of some observation. The leaves of the tree are labeled with distributions. To predict, you descend the tree according to a new observation's features, and the distribution you reach at the bottom is your best guess at the distribution for that observation.

So if we want to fit a monad Free f a to the task of a decision tree, it is clear that a should represent the distribution you are predicting. For the sake of simplicity, we will set a to Bool and just guess yes or no, rather than providing percentages.

Our choice of f is a little less clear. We need to read some information from an observation, say of type r, and descend into the next level of the tree. Well, speaking of reading information, what about the reader functor (->) r? If we try f = (->) r, we get

data Free ((->) r) Bool
  = Pure Bool
  | Free (r -> Free ((->) r) Bool)

This looks like exactly what we want! We fix a row type r, and when given a new row we pass it to each internal Free node in turn until we reach a Pure leaf. So, we define

-- | A model with row type `r` and class label type `c`
type TreeM r c = Free ((->) r) c
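As a hedged illustration of how such a model is used, here is a hand-built TreeM whose rows are a hypothetical pair of Bool features. The function runTree is a name introduced here: it keeps handing the same row to each internal node until it reaches a Pure leaf.

```haskell
-- A local copy of Free, so this sketch needs only base.
data Free f a
  = Pure a
  | Free (f (Free f a))

-- | A model with row type `r` and class label type `c`
type TreeM r c = Free ((->) r) c

-- Hypothetical row type: two yes/no features.
type Row = (Bool, Bool)

-- Predict True if the first feature is set; otherwise use the second.
model :: TreeM Row Bool
model = Free $ \(x, _) ->
  if x
    then Pure True
    else Free (\(_, y) -> Pure y)

-- Descend through the internal nodes, applying each one to the row,
-- until a leaf is reached.
runTree :: r -> TreeM r c -> c
runTree _   (Pure c) = c
runTree row (Free k) = runTree row (k row)

main :: IO ()
main = mapM_ (print . (`runTree` model)) [(True, False), (False, False), (False, True)]
```

This prints True, False and True. Note that the tree never inspects a feature it does not need: the second feature is only read on the x = False branch.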

Recursion Combinators

The observation presented here is hardly earth-shattering, but it does come with some advantages other than "hey, a connection!". By adapting some recursion combinators (which typically act on the Fix data type), we can separate the recursive logic of our program from the actual computation. The goal of recursion combinators is to capture standard patterns of recursion once, so that the functions built on them are cleaner.

Free f a is strikingly similar to the type Fix f, the fixed point of the functor f. Recall

newtype Fix f = Fix { unFix :: f (Fix f) }

Essentially, Free allows us to stop our infinite, f-branching tree early and return a Pure value of type a. In practice, this means the functors f used with Fix tend to be more complicated than those used with Free, because with Fix you need to embed the notion of returning data into your base functor. (For example, we could model Free f a itself using Fix with the base functor data Br f a r = BrPure a | BrFree (f r) deriving Functor, and get that Free f a is isomorphic to Fix (Br f a).)
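The isomorphism in the parenthetical can be written out directly. The sketch below assumes Br is the two-constructor sum BrPure a | BrFree (f r); toFix and fromFix are names introduced here, and a round trip through Fix should leave a tree's leaves unchanged.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Local copies of Fix and Free, so this sketch needs only base.
newtype Fix f = Fix { unFix :: f (Fix f) }

data Free f a
  = Pure a
  | Free (f (Free f a))

-- Base functor: either stop with an `a`, or branch one more layer.
data Br f a r
  = BrPure a
  | BrFree (f r)

-- Br f a is a functor in its last argument, so cata and ana apply to it.
instance Functor f => Functor (Br f a) where
  fmap _ (BrPure a) = BrPure a
  fmap g (BrFree u) = BrFree (fmap g u)

-- The two directions of the isomorphism between Free f a and Fix (Br f a).
toFix :: Functor f => Free f a -> Fix (Br f a)
toFix (Pure a) = Fix (BrPure a)
toFix (Free u) = Fix (BrFree (fmap toFix u))

fromFix :: Functor f => Fix (Br f a) -> Free f a
fromFix (Fix (BrPure a)) = Pure a
fromFix (Fix (BrFree u)) = Free (fmap fromFix u)

-- A binary tree to test the round trip on.
data Pair x = Pair x x
  deriving Functor

leaves :: Free Pair a -> [a]
leaves (Pure a)          = [a]
leaves (Free (Pair l r)) = leaves l ++ leaves r

example :: Free Pair Char
example = Free (Pair (Pure 'a') (Free (Pair (Pure 'b') (Pure 'c'))))

main :: IO ()
main = print (leaves (fromFix (toFix example)))  -- round trip keeps the leaves "abc"
```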

For Fix, catamorphisms and anamorphisms are very useful recursion combinators, which respectively collapse and grow a recursive structure:

-- | Tear down a recursive structure, i.e. an element of Fix f.
--   First we recursively tear down each of the subtrees in the 
--   first level of the fixed point. Then we collapse the 
--   last layer using the algebra directly.
cata :: Functor f => (f a -> a)    -- ^ An algebra to collapse a container to a value
                  -> Fix f         -- ^ A recursive tree of containers to start with
                  -> a
cata alg fix = alg . fmap (cata alg) . unFix $ fix

-- | Build up a recursive structure, i.e. an element of Fix f.
--   We first expand our seed by one step.
--   Then we map over the resultant container,
--   recursively expanding each value along the way.
ana :: Functor f => (a -> f a)     -- ^ A function to expand an a
                 -> a              -- ^ A seed value to start off with
                 -> Fix f
ana grow seed = Fix . fmap (ana grow) . grow $ seed
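To see the two combinators working together, here is a small sketch that restates cata and ana so it compiles alone and uses a hypothetical base functor ListF for lists of Ints. The coalgebra countdown grows the list 4, 3, 2, 1 from a seed, and the algebra sumAlg collapses it again. (Running ana and then cata like this is usually called a hylomorphism.)

```haskell
newtype Fix f = Fix { unFix :: f (Fix f) }

cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg = alg . fmap (cata alg) . unFix

ana :: Functor f => (a -> f a) -> a -> Fix f
ana grow = Fix . fmap (ana grow) . grow

-- Base functor for lists of Ints; r marks the recursive position.
data ListF r = Nil | Cons Int r

instance Functor ListF where
  fmap _ Nil        = Nil
  fmap g (Cons x r) = Cons x (g r)

-- Coalgebra: expand a seed n into one layer holding n, with seed n - 1 next.
-- Assumes n >= 0; a negative seed would never reach Nil.
countdown :: Int -> ListF Int
countdown 0 = Nil
countdown n = Cons n (n - 1)

-- Algebra: collapse one layer by adding its element to the already-summed tail.
sumAlg :: ListF Int -> Int
sumAlg Nil        = 0
sumAlg (Cons x s) = x + s

main :: IO ()
main = print (cata sumAlg (ana countdown 4))  -- 4 + 3 + 2 + 1
```

Neither countdown nor sumAlg mentions recursion; cata and ana supply all of it.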

It is pretty easy to extend cata to work on Free instead of Fix; we just need to add a case for Pure:

cataF :: Functor f => (f a -> a) -> Free f a -> a
cataF alg (Free u) = alg . fmap (cataF alg) $ u
cataF _ (Pure a) = a
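A quick sketch of cataF in use, reusing a Pair tree like the one from the Motivation section (with Free and Pair redefined locally). One consequence of the Pure case is that the algebra's result type must equal the leaf type a, because leaves are returned unchanged; this is why predict, later on, works with Bool leaves and a Bool result.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Local copies, so this sketch needs only base.
data Free f a
  = Pure a
  | Free (f (Free f a))

data Pair x = Pair x x
  deriving Functor

cataF :: Functor f => (f a -> a) -> Free f a -> a
cataF alg (Free u) = alg . fmap (cataF alg) $ u
cataF _   (Pure a) = a

tree :: Free Pair Int
tree = Free (Pair (Pure 1) (Free (Pair (Pure 2) (Pure 3))))

-- Each algebra only says how to combine two already-collapsed children;
-- cataF supplies the recursion and returns leaf labels as they are.
total, largest :: Int
total   = cataF (\(Pair l r) -> l + r) tree
largest = cataF (\(Pair l r) -> max l r) tree

main :: IO ()
main = print (total, largest)
```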

It is somewhat harder to extend ana sensibly. The exact same code that worked for Fix also typechecks for Free, but it always builds infinite trees and never produces a Pure. Thus the input function has to allow for the possibility of a Pure value:

anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b
anaF grow seed
  = case grow seed of
         Left u -> Free . fmap (anaF grow) $ u
         Right b -> Pure b
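Here is a hedged example of anaF with the Pair functor: the seed is an inclusive range of Ints, and the hypothetical function halve returns Right at a single number (a leaf) or Left with the two halves (a branch). This is the same Left/Right shape that discriminate will have later.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Local copies, so this sketch needs only base.
data Free f a
  = Pure a
  | Free (f (Free f a))

data Pair x = Pair x x
  deriving Functor

anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b
anaF grow seed =
  case grow seed of
    Left u  -> Free . fmap (anaF grow) $ u
    Right b -> Pure b

-- Split an inclusive range (lo, hi) in half until one number remains.
-- Assumes lo <= hi; otherwise the recursion never reaches a leaf.
halve :: (Int, Int) -> Either (Pair (Int, Int)) Int
halve (lo, hi)
  | lo == hi  = Right lo
  | otherwise = let mid = (lo + hi) `div` 2
                in Left (Pair (lo, mid) (mid + 1, hi))

leaves :: Free Pair a -> [a]
leaves (Pure a)          = [a]
leaves (Free (Pair l r)) = leaves l ++ leaves r

main :: IO ()
main = print (leaves (anaF halve (1, 5)))
```

The printed leaves are [1,2,3,4,5], read off a balanced tree built without any explicit recursion in halve.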

The Hard Work

We start with a preamble that enables some convenient language extensions and imports the needed libraries. Then we define our table data type and some simplified accessor functions.

What follows is complete, valid Haskell that can be run on a modern GHC, using the free, containers, and multiset packages. You need only the code below this point, along with our definitions of cataF and anaF above. You can also snag the code here.

{-# LANGUAGE DeriveFunctor
           , TupleSections
  #-}
import qualified Data.List as List
import qualified Data.MultiSet as Set
import qualified Data.Map as Map
import Control.Monad.Free


-- | value = Bool will suffice for this code, but 
-- more general Tables are certainly reasonable
data Table key value = Table
  { keys :: [key]
      -- ^ For iterating and looping purposes
  , rows :: Set.MultiSet (value, Row key value)
      -- ^ An unordered collection of Rows associated to labels
  } deriving(Show)

type Row key value = Map.Map key value

numKeys :: Table k v -> Int
numKeys = length . keys

numRows :: Table k v -> Int
numRows = Set.size . rows

-- Assumes every row has a value for every key (all tables are full)
getKey :: Ord k => k -> Row k v -> v
getKey k row = row Map.! k

emptyBinTable :: Table key value
emptyBinTable = Table [] Set.empty

Now, the learning method of decision trees is (roughly) the following:

  1. If there are no keys in the table, return a model that predicts the most common class label.
  2. If there are keys, find the key that best predicts the class label.
  3. Split the data into two new tables based on the value of that key.
  4. Recursively grow a decision tree for the two new data sets.
  5. Put the two new trees together into a model that first predicts based on the best key, then predicts based on the recursive, new trees.

The following code implements several tools we will need:


-- | We score a key by how many labels it gets correct when used to predict.
-- This returns two counts: the first assumes a positive correlation
-- (predict that the label equals the key's value), the second assumes a
-- negative correlation (predict the opposite of the key's value).
scores :: Ord k => k -> Table k Bool -> (Int, Int)
scores k tab = Set.fold (indicator k) (0, 0) $ rows tab
  where indicator :: Ord k => k -> (Bool, Row k Bool) -> (Int,Int) -> (Int,Int)
        indicator k (label, row) (pos, neg)
          = case Map.lookup k row of
                 Just a -> (pos + fromEnum (label==a), neg + fromEnum (label/=a))
                 Nothing -> (pos, neg)

-- | Loop over all the keys and find the one that predicts with highest accuracy
bestKey :: Ord k => Table k Bool -> k
bestKey tab
  = let bestScore k
          = let (pos, neg) = scores k tab
             in (k, max pos neg)
        maxScores = fmap bestScore (keys tab)
        (best, _)
          = List.maximumBy (\(_,s) (_,s') -> s `compare` s') maxScores
     in best

-- | Drop key k from the key list and from every row.
removeKey :: (Ord k, Ord v) => k -> Table k v -> Table k v
removeKey k tab =
  emptyBinTable
    { keys = (List.\\) (keys tab) [k]
    , rows = Set.map (\(lab,row) -> (lab, Map.delete k row)) (rows tab)
    }

-- | Split a table into two tables based on the value of one key.
-- Also remove the key from the new tables.
filterOn :: Ord k => k -> Table k Bool -> (Table k Bool, Table k Bool)
filterOn k tab
  = let kTrue = getKey k . snd -- predicate: is key k True in this row?
        trueRows = Set.filter kTrue (rows tab)
        falseRows = Set.filter (not . kTrue) (rows tab)
     in (removeKey k tab{rows = trueRows}, removeKey k tab{rows = falseRows})

-- | Ignore all the rows and just guess the majority class label (ties go to True)
bestGuess :: Ord k => Table k Bool -> Bool
bestGuess tab
  = let nTrue = Set.fold (\(b,_) accum -> accum + fromEnum b) 0 (rows tab)
        -- nTrue counts the True labels
     in 2 * nTrue >= numRows tab

These are all straightforward things that we would probably implement if we were writing this algorithm without recursion combinators.
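To see what scores computes, here is a small worked example on plain lists instead of a Table (a simplification for illustration; scoreList and the data are hypothetical). Each pair is (class label, value of the key in that row).

```haskell
-- (pos, neg): how many labels are matched if we predict label == key value,
-- and how many if we predict label == not (key value).
scoreList :: [(Bool, Bool)] -> (Int, Int)
scoreList = foldr step (0, 0)
  where
    step (label, a) (pos, neg) =
      (pos + fromEnum (label == a), neg + fromEnum (label /= a))

-- Hypothetical data: the key agrees with the label in three of four rows.
example :: [(Bool, Bool)]
example = [(True, True), (True, True), (False, True), (False, False)]

-- A perfectly anti-correlated key: it disagrees with the label every time.
antiExample :: [(Bool, Bool)]
antiExample = [(True, False), (False, True)]

main :: IO ()
main = do
  let (pos, neg) = scoreList example
  print (pos, neg, max pos neg)
  print (scoreList antiExample)
```

The first key scores max 3 1 = 3. The anti-correlated key scores (0, 2), and because bestKey takes max pos neg it still earns a perfect score of 2: predicting the opposite of a key is as useful as predicting the key itself.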

Applying our Recursion Combinators

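Now that we have some functions to manipulate and extract information from our tables, we are ready to learn our models and predict with them. We write the function discriminate to fit the type signature of anaF. This function accepts a table and returns one of two things:

  1. Right label, a leaf, when the table has no keys left; the label is the table's most common class label.
  2. Left f, where f takes a row and returns the half of the table (split on the best key) that the row falls into, to be grown further.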

Recall that anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b and that type TreeM r c = Free ((->) r) c. When we apply anaF to discriminate, it recurses over the newly created tables, growing more of the model until it hits the base case of a table with only class labels.

discriminate :: Ord k => Table k Bool -> Either (Row k Bool -> Table k Bool) Bool
discriminate tab
  | numKeys tab == 0
    = Right $ bestGuess tab
  | otherwise
    = let key = bestKey tab
          (trueTab, falseTab) = filterOn key tab
       in Left $ \row ->
          let bool = getKey key row
           in if bool then trueTab else falseTab

-- | Finally time to use this!
type TreeM r c = Free ((->) r) c

learn :: Ord k => Table k Bool -> TreeM (Row k Bool) Bool
learn = anaF discriminate

--cataF :: Functor f => (f a -> a) -> Free f a -> a
predict :: Ord k => Row k Bool -> TreeM (Row k Bool) Bool -> Bool
predict row model = cataF ($ row) model

Now we are done! All the work paid off with very, very short definitions for learn and predict. Keep in mind that none of the table-handling code we wrote before learn and predict uses explicit recursion; all of the recursion lives in anaF and cataF.
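As a usage sketch, here is a self-contained version of the whole pipeline that can be run with only base and containers. It makes some assumptions that differ from the code above: rows live in a plain list rather than a MultiSet, the table has no value type parameter (values are always Bool), and Free, cataF and anaF are restated locally. The andTable data set is hypothetical.

```haskell
import qualified Data.List as List
import qualified Data.Map as Map
import Data.Ord (comparing)

-- Local copies of Free, TreeM, cataF and anaF.
data Free f a
  = Pure a
  | Free (f (Free f a))

type TreeM r c = Free ((->) r) c

cataF :: Functor f => (f a -> a) -> Free f a -> a
cataF alg (Free u) = alg (fmap (cataF alg) u)
cataF _   (Pure a) = a

anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b
anaF grow seed = case grow seed of
  Left u  -> Free (fmap (anaF grow) u)
  Right b -> Pure b

-- Simplified table: keys plus labelled rows in a plain list (an assumption).
type Row k = Map.Map k Bool
data Table k = Table { keys :: [k], rows :: [(Bool, Row k)] }

-- Majority label, ties going to True.
bestGuess :: Table k -> Bool
bestGuess t = 2 * length (filter fst (rows t)) >= length (rows t)

-- Key whose value (or its negation) matches the label most often.
-- Assumes every row has every key, as in the article.
bestKey :: Ord k => Table k -> k
bestKey t = List.maximumBy (comparing score) (keys t)
  where
    score k = max agree (length (rows t) - agree)
      where
        agree = length [() | (lab, row) <- rows t, row Map.! k == lab]

-- Split on a key and drop it from both halves.
filterOn :: Ord k => k -> Table k -> (Table k, Table k)
filterOn k t = (side True, side False)
  where
    side b = Table (List.delete k (keys t))
                   [ (lab, Map.delete k row) | (lab, row) <- rows t, row Map.! k == b ]

discriminate :: Ord k => Table k -> Either (Row k -> Table k) Bool
discriminate t
  | null (keys t) = Right (bestGuess t)
  | otherwise =
      let k = bestKey t
          (tTrue, tFalse) = filterOn k t
      in Left (\row -> if row Map.! k then tTrue else tFalse)

learn :: Ord k => Table k -> TreeM (Row k) Bool
learn = anaF discriminate

predict :: Row k -> TreeM (Row k) Bool -> Bool
predict row = cataF ($ row)

-- Hypothetical training data: the label is "x AND y".
andTable :: Table String
andTable = Table ["x", "y"]
  [ (x && y, Map.fromList [("x", x), ("y", y)]) | x <- [False, True], y <- [False, True] ]

main :: IO ()
main = do
  let model = learn andTable
  -- Each line pairs the true label with the model's prediction.
  mapM_ (\(lab, row) -> print (lab, predict row model)) (rows andTable)
```

On this tiny, noise-free data set the learned tree reproduces every training label, whichever of the two tied keys it splits on first.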

Extensions

By changing discriminate, we change our learning method. Using anaF adds some constraints on how we can learn our model, but still allows some freedom. Here are some possible avenues for extending and improving our models:

Appendix: A Printable Interface

It is hard to verify that the code above is correct because we cannot (sensibly) print functions in Haskell. The following code fixes this by using a datatype for the branching instead of a function. This gives much less flexibility in how we branch, but it lets us print things! The following implements exactly the same algorithm as above:


data Branch k r = Branch
  { key :: k
  , bTrue :: r
  , bFalse :: r
  } deriving(Show, Functor)

discriminate' :: Ord k => Table k Bool -> Either (Branch k (Table k Bool)) Bool
discriminate' tab
  | numKeys tab == 0
    = Right $ bestGuess tab
  | otherwise
    = let goodKey = bestKey tab
          (keyTrueTab, keyFalseTab) = filterOn goodKey tab
       in Left $ Branch
            { key = goodKey
            , bTrue = keyTrueTab
            , bFalse = keyFalseTab
            }

learn' :: Ord k => Table k Bool -> Free (Branch k) Bool
learn' = anaF discriminate'

predict' :: Ord k => Row k Bool -> Free (Branch k) Bool -> Bool
predict' row model = cataF phi model
  where phi br
          = case getKey (key br) row of
                 True -> bTrue br
                 False -> bFalse br
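An alternative to relying on a Show instance for Free is a small renderer built from cataF itself. In this sketch, render and mapLeaves are hypothetical helpers and the tree is written by hand rather than learned. Because cataF's result type must match the leaf type, each Bool leaf is first turned into a one-line [String].

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Local copies, so this sketch needs only base.
data Free f a
  = Pure a
  | Free (f (Free f a))

data Branch k r = Branch
  { key    :: k
  , bTrue  :: r
  , bFalse :: r
  } deriving (Show, Functor)

cataF :: Functor f => (f a -> a) -> Free f a -> a
cataF alg (Free u) = alg (fmap (cataF alg) u)
cataF _   (Pure a) = a

-- Map over the leaf labels (this is fmap for Free, written out by hand).
mapLeaves :: Functor f => (a -> b) -> Free f a -> Free f b
mapLeaves g (Pure a) = Pure (g a)
mapLeaves g (Free u) = Free (fmap (mapLeaves g) u)

-- Render a tree as indented lines, one layer at a time.
render :: Show k => Free (Branch k) Bool -> String
render = unlines . cataF layer . mapLeaves (\b -> ["predict " ++ show b])
  where
    indent = map ("  " ++)
    layer (Branch k t f) =
      (show k ++ " = True:") : indent t ++ (show k ++ " = False:") : indent f

-- A hand-built tree: predict True only when both y and x hold.
example :: Free (Branch String) Bool
example =
  Free (Branch "y" (Free (Branch "x" (Pure True) (Pure False))) (Pure False))

main :: IO ()
main = putStr (render example)
```

Applied to the output of learn', the same render would print the learned splits as nested, indented branches.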