Clay Thomas

## Motivation

`Free f a`, the free monad over a given functor `f`, is often described as "trees which branch in the shape of `f` and are leaf-labeled by `a`". What exactly does this mean? Well, the definition of `Free` is:

``````
data Free f a
  = Pure a
  | Free (f (Free f a))
``````

So if we have a functor `data Pair a = Pair a a`, then `Free Pair a` indeed represents ordinary leaf-labeled, nonempty binary trees.
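As a small illustration (a sketch; the names `example` and `leaves` are introduced here and are not part of the original code), here is one such tree together with a function that lists its leaf labels:

```haskell
-- The free monad, as defined above.
data Free f a
  = Pure a
  | Free (f (Free f a))

-- Each internal node has exactly two children.
data Pair a = Pair a a

-- A three-leaf tree:   .
--                     / \
--                    1   .
--                       / \
--                      2   3
example :: Free Pair Int
example = Free (Pair (Pure 1) (Free (Pair (Pure 2) (Pure 3))))

-- Collect the leaf labels from left to right.
leaves :: Free Pair a -> [a]
leaves (Pure a)          = [a]
leaves (Free (Pair l r)) = leaves l ++ leaves r
```

Every `Free` node holds a `Pair` of subtrees and every `Pure` node holds a label, so there is no way to write down an empty tree.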

Now, if you know about decision trees, "leaf-labeled" should catch your ear. (Binary) decision trees are trees where each node represents a yes/no feature about some observation. The leaves of the tree are labeled with distributions. To predict, you descend the tree according to an observation, and the distribution you reach at the bottom is your best guess at the distribution for your new observation.

So if we want to fit a monad `Free f a` to the task of a decision tree, it is clear that `a` should represent the distribution you are predicting. For the sake of simplicity, we will simply set `a` to `Bool` and will just guess a yes or no, as opposed to providing percentages.

Our choice of `f` is a little less clear. We need to read in some information from an observation, say of type `r`, and descend into the next level of the tree. Well, speaking of reading information, what about the reader functor `(->) r`? If we try `f = (->) r`, we get

``````
-- Pseudo-Haskell: Free specialized to f = (->) r and a = Bool
data Free ((->) r) Bool
  = Pure Bool
  | Free (r -> Free ((->) r) Bool)
``````

This looks like exactly what we want! We fix a row type `r`, and when provided new rows we can traverse through internal `Free` nodes until we reach a leaf `Pure` node. So, we define

``````
-- | A model with row type `r` and class label type `c`
type TreeM r c = Free ((->) r) c
``````
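To make the shape of such a model concrete, here is a hand-written tree and a direct way to descend it. This is only a sketch: the row type `Weather`, the model `goOutside`, and the helper `descend` are hypothetical names chosen for illustration.

```haskell
data Free f a
  = Pure a
  | Free (f (Free f a))

type TreeM r c = Free ((->) r) c

-- Hypothetical row type for illustration: (isRaining, isWeekend).
type Weather = (Bool, Bool)

-- "Go outside?": never when it rains, otherwise only on weekends.
goOutside :: TreeM Weather Bool
goOutside = Free $ \(raining, _) ->
  if raining
    then Pure False
    else Free $ \(_, weekend) -> Pure weekend

-- Descend the tree, feeding the same row to every internal node
-- until a leaf is reached.
descend :: r -> TreeM r c -> c
descend _   (Pure c) = c
descend row (Free k) = descend row (k row)
```

For example, `descend (False, True) goOutside` walks through two `Free` nodes and returns `True`, while `descend (True, True) goOutside` stops at the first split and returns `False`.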

## Recursion Combinators

The observation presented here is hardly earth-shattering, but it does come with some advantages other than "hey, a connection!". By adapting some recursion combinators (which typically act on the `Fix` data type), we can separate the recursive logic of our program from the actual computation. The goal of recursion combinators is to standardize patterns in recursion and make their implementations cleaner.

`Free f a` is strikingly similar to the type `Fix f`, the fixed point of the functor `f`. Recall

``newtype Fix f = Fix { unFix :: f (Fix f) }``

Essentially, `Free` allows us to stop our infinite, `f`-branching tree early and return a `Pure` value of type `a`. In many applications, the practical upshot is that the functors `f` you use with `Fix` are more complicated than those you use with `Free`, because with `Fix` you need to embed the notion of returning data into your base functor. (For example, we could model `Free f a` itself using `Fix` with a base functor `data Br f a r = Done a | More (f r) deriving Functor`; `Free f a` is then isomorphic to `Fix (Br f a)`.)
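The correspondence between `Free f a` and `Fix (Br f a)` can be spelled out as a pair of mutually inverse conversions. The following is a sketch; the constructor names `Done` and `More` and the functions `toFix` and `fromFix` are illustrative choices.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data Free f a
  = Pure a
  | Free (f (Free f a))

newtype Fix f = Fix { unFix :: f (Fix f) }

-- Base functor: either stop with a label, or branch once more.
data Br f a r
  = Done a
  | More (f r)
  deriving Functor

-- Replace Pure/Free by Done/More, layer by layer.
toFix :: Functor f => Free f a -> Fix (Br f a)
toFix (Pure a) = Fix (Done a)
toFix (Free u) = Fix (More (fmap toFix u))

-- The inverse direction.
fromFix :: Functor f => Fix (Br f a) -> Free f a
fromFix (Fix (Done a)) = Pure a
fromFix (Fix (More u)) = Free (fmap fromFix u)
```

Each function simply renames constructors one layer at a time, so composing them in either order gives back the tree you started with.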

For `Fix`, catamorphisms and anamorphisms are very useful recursion combinators, which respectively collapse and grow a recursive structure:

``````
-- | Tear down a recursive structure, i.e. an element of Fix f.
--   First we recursively tear down each of the subtrees in the
--   first level of the fixed point. Then we collapse the
--   last layer using the algebra directly.
cata :: Functor f => (f a -> a)    -- ^ An algebra to collapse a container to a value
                  -> Fix f         -- ^ A recursive tree of containers to start with
                  -> a
cata alg fix = alg . fmap (cata alg) . unFix $ fix

-- | Build up a recursive structure, i.e. an element of Fix f.
--   We first expand our seed by one step.
--   Then we map over the resultant container,
--   recursively expanding each value along the way.
ana :: Functor f => (a -> f a)     -- ^ A function to expand an a
                 -> a              -- ^ A seed value to start off with
                 -> Fix f
ana grow seed = Fix . fmap (ana grow) . grow $ seed
``````
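As a quick usage sketch, here are both combinators applied to a list-shaped base functor. The names `ListF`, `countdown`, `total`, and `sumTo` are hypothetical and only serve the example.

```haskell
{-# LANGUAGE DeriveFunctor #-}

newtype Fix f = Fix { unFix :: f (Fix f) }

cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg = alg . fmap (cata alg) . unFix

ana :: Functor f => (a -> f a) -> a -> Fix f
ana grow = Fix . fmap (ana grow) . grow

-- Base functor for lists of Ints: r marks the recursive position.
data ListF r = Nil | Cons Int r
  deriving Functor

-- Coalgebra: count down from n, stopping at 0.
-- (Assumes n >= 0; a negative seed would never reach 0.)
countdown :: Int -> ListF Int
countdown 0 = Nil
countdown n = Cons n (n - 1)

-- Algebra: add the head to the already-collapsed tail.
total :: ListF Int -> Int
total Nil        = 0
total (Cons x s) = x + s

-- Grow the list n, n-1, ..., 1 with ana, then sum it with cata.
sumTo :: Int -> Int
sumTo = cata total . ana countdown
```

Neither `countdown` nor `total` is recursive; all of the recursion lives in `ana` and `cata`.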

It is pretty easy to extend `cata` to work on `Free` instead of `Fix`; we just need to add a case for `Pure`:

``````
cataF :: Functor f => (f a -> a) -> Free f a -> a
cataF alg (Free u) = alg . fmap (cataF alg) $ u
cataF _   (Pure a) = a
``````
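A brief usage sketch on binary trees follows; `tree`, `sumLeaves`, and `maxLeaf` are hypothetical names for this example.

```haskell
data Free f a
  = Pure a
  | Free (f (Free f a))

cataF :: Functor f => (f a -> a) -> Free f a -> a
cataF alg (Free u) = alg . fmap (cataF alg) $ u
cataF _   (Pure a) = a

data Pair a = Pair a a

instance Functor Pair where
  fmap f (Pair x y) = Pair (f x) (f y)

tree :: Free Pair Int
tree = Free (Pair (Pure 1) (Free (Pair (Pure 2) (Pure 3))))

-- Add up all leaf labels.
sumLeaves :: Free Pair Int -> Int
sumLeaves = cataF (\(Pair l r) -> l + r)

-- Find the largest leaf label.
maxLeaf :: Free Pair Int -> Int
maxLeaf = cataF (\(Pair l r) -> max l r)
```

Note that with this signature the result type is the leaf type itself: a `Pure a` is returned unchanged, so the algebra only ever combines values of type `a`.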

It is somewhat harder to extend `ana` in a sensible way. Essentially the same code that worked for `Fix` also works for `Free`, but it always builds infinite trees and never produces a `Pure`. The input function therefore has to allow for the possibility of a `Pure` value:

``````
anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b
anaF grow seed
  = case grow seed of
      Left u  -> Free . fmap (anaF grow) $ u
      Right b -> Pure b
``````
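Here is a small sketch of `anaF` in use: it grows a binary tree from an inclusive integer range by halving the range until a single number is left. The names `split`, `rangeTree`, and `leaves` are hypothetical.

```haskell
data Free f a
  = Pure a
  | Free (f (Free f a))

anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b
anaF grow seed
  = case grow seed of
      Left u  -> Free . fmap (anaF grow) $ u
      Right b -> Pure b

data Pair a = Pair a a

instance Functor Pair where
  fmap f (Pair x y) = Pair (f x) (f y)

-- One growth step: a range of one element becomes a leaf,
-- anything larger is split into two strictly smaller ranges.
split :: (Int, Int) -> Either (Pair (Int, Int)) Int
split (lo, hi)
  | lo >= hi  = Right lo
  | otherwise = let mid = (lo + hi) `div` 2
                in Left (Pair (lo, mid) (mid + 1, hi))

rangeTree :: (Int, Int) -> Free Pair Int
rangeTree = anaF split

-- Collect the leaf labels from left to right.
leaves :: Free Pair a -> [a]
leaves (Pure a)          = [a]
leaves (Free (Pair l r)) = leaves l ++ leaves r
```

The `Right` case is what lets the tree terminate; `split` itself only describes a single level of growth.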

## The Hard Work

We start with a preamble to equip some convenient language extensions and import the needed libraries. Then we define our table data type and some simplified accessor functions.

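What follows is complete, valid Haskell that can be run on a modern GHC. You need only the code below this point, along with the definition of `Free` and our definitions of `cataF` and `anaF` above. (`Data.MultiSet` comes from the `multiset` package and `Data.Map` from `containers`.) You can also snag the code here.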

``````
{-# LANGUAGE RankNTypes
           , RelaxedPolyRec
           , DeriveFunctor
           , TupleSections
           , ScopedTypeVariables
           , UndecidableInstances
  #-}
import qualified Data.List as List
import qualified Data.MultiSet as Set
import qualified Data.Map as Map

-- | value = Bool will suffice for this code, but
--   more general Tables are certainly reasonable
data Table key value = Table
  { keys :: [key]
    -- ^ For iterating and looping purposes
  , rows :: Set.MultiSet (value, Row key value)
    -- ^ An unordered collection of Rows associated to labels
  } deriving (Show)

type Row key value = Map.Map key value

numKeys :: Table k v -> Int
numKeys = length . keys

numRows :: Table k v -> Int
numRows = Set.size . rows

-- Assume all tables are full
getKey :: Ord k => k -> Row k v -> v
getKey k row
  = maybe undefined id (Map.lookup k row)

emptyBinTable :: Table key value
emptyBinTable = Table [] Set.empty
``````

Now, the learning method of decision trees is (roughly) the following:

1. If there are no keys in the table, return a model that predicts the most common class label.
2. If there are keys, find the key that best predicts the class label.
3. Split the data into two new tables based on the value of that key.
4. Recursively grow a decision tree for the two new data sets.
5. Put the two new trees together into a model that first predicts based on the best key, then predicts based on the recursive, new trees.

The following code implements several tools we will need:

``````
-- | We score the keys by how many labels the key could get correct.
--   This method returns two values: the first is the score if the model
--   assumes a positive correlation, the second if it assumes a negative one.
scores :: Ord k => k -> Table k Bool -> (Int, Int)
scores k tab = Set.fold (indicator k) (0, 0) $ rows tab
  where
    indicator :: Ord k => k -> (Bool, Row k Bool) -> (Int, Int) -> (Int, Int)
    indicator key (label, row) (pos, neg)
      = case Map.lookup key row of
          Nothing -> (pos, neg)
          -- The Just branch is missing from the source; this is a
          -- reconstruction from the description above (assumption).
          Just val
            | val == label -> (pos + 1, neg)
            | otherwise    -> (pos, neg + 1)

-- | Loop over all the keys and find the one that predicts with highest accuracy
bestKey :: Ord k => Table k Bool -> k
bestKey tab
  = let bestScore k
          = let (pos, neg) = scores k tab
            in (k, max pos neg)
        maxScores = fmap bestScore (keys tab)
        (best, _)
          = List.maximumBy (\(_,s) (_,s') -> s `compare` s') maxScores
    in best

removeKey :: (Ord k, Ord v) => k -> Table k v -> Table k v
removeKey k tab =
  emptyBinTable
    { keys = (List.\\) (keys tab) [k]
    , rows = Set.map (\(lab,row) -> (lab, Map.delete k row)) (rows tab)
    }

-- | Split a table into two tables based on the value of one key.
--   Also remove the key from the new tables.
filterOn :: Ord k => k -> Table k Bool -> (Table k Bool, Table k Bool)
filterOn k tab
  = let kTrue = getKey k . snd -- predicate to test if key k is true for some row
        trueRows = Set.filter kTrue (rows tab)
        falseRows = Set.filter (not . kTrue) (rows tab)
    in (removeKey k tab{rows = trueRows}, removeKey k tab{rows = falseRows})

-- | Ignore all the rows and just guess a boolean based on class label
bestGuess :: Ord k => Table k Bool -> Bool
bestGuess tab
  = let nTrue = Set.fold (\(b,_) accum -> accum + fromEnum b) 0 (rows tab)
        -- count number of true labels
    in 2 * nTrue >= numRows tab
``````

These are all straightforward things that we would probably implement if we were writing this algorithm without recursion combinators.

## Applying our Recursion Combinators

Now that we have some functions to manipulate and extract information out of our tables, we are ready to learn our models and predict with them. We write the function `discriminate` to fit the type signature of `anaF`. This function accepts a table and returns one of two things:

• If the table contains only class labels, we return a single Bool that represents our best guess of the class label.

• If the table still has some keys, we return a mapping. This mapping takes in any row, and returns a split of the data based on the value of `bestKey` within the row.

Recall that `anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b` and that `type TreeM r c = Free ((->) r) c`. When we apply `anaF` to `discriminate`, it has the effect of recursing over the newly created tables, growing more models until we hit the base case of a table with only class labels.

``````
discriminate :: Ord k => Table k Bool -> Either (Row k Bool -> Table k Bool) Bool
discriminate tab
  | numKeys tab == 0
    = Right $ bestGuess tab
  | otherwise
    = let key = bestKey tab
          (trueTab, falseTab) = filterOn key tab
      in Left $ \row ->
           let bool = getKey key row
           in if bool then trueTab else falseTab

-- | Finally time to use this!
type TreeM r c = Free ((->) r) c

learn :: Ord k => Table k Bool -> TreeM (Row k Bool) Bool
learn = anaF discriminate

-- cataF :: Functor f => (f a -> a) -> Free f a -> a
predict :: Ord k => Row k Bool -> TreeM (Row k Bool) Bool -> Bool
predict row model = cataF ($ row) model
``````

Now we are done! All the work paid off with very, very short definitions for `learn` and `predict`. Keep in mind that before `learn` and `predict`, nothing we had written involved recursion at all.
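To see the whole pipeline run without installing `multiset`, here is a self-contained miniature. It is a sketch under simplifying assumptions, not the code above: rows are association lists instead of `Map`s, a table stores its rows in a plain list, keys are `String`s, and `valueOf`, `score`, `majority`, and `toy` are hypothetical helpers.

```haskell
import Data.List (maximumBy, delete)
import Data.Ord (comparing)

data Free f a = Pure a | Free (f (Free f a))

cataF :: Functor f => (f a -> a) -> Free f a -> a
cataF alg (Free u) = alg (fmap (cataF alg) u)
cataF _   (Pure a) = a

anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b
anaF grow seed = case grow seed of
  Left u  -> Free (fmap (anaF grow) u)
  Right b -> Pure b

-- Simplified stand-ins (assumptions): association-list rows, list-backed table.
type Row = [(String, Bool)]
data Table = Table { keys :: [String], rows :: [(Bool, Row)] }

-- A missing key is treated as False.
valueOf :: String -> Row -> Bool
valueOf k row = lookup k row == Just True

-- Rows a key classifies correctly under the better of the two
-- orientations (key agrees with label / key disagrees with label).
score :: Table -> String -> Int
score tab k = max agree (length (rows tab) - agree)
  where agree = length [ () | (lab, row) <- rows tab, valueOf k row == lab ]

-- Most common label, with ties going to True.
majority :: Table -> Bool
majority tab = 2 * length (filter fst (rows tab)) >= length (rows tab)

discriminate :: Table -> Either (Row -> Table) Bool
discriminate tab
  | null (keys tab) = Right (majority tab)
  | otherwise       = Left (\row -> if valueOf k row then yes else no)
  where
    k    = maximumBy (comparing (score tab)) (keys tab)
    rest = delete k (keys tab)
    yes  = Table rest [ r | r <- rows tab, valueOf k (snd r) ]
    no   = Table rest [ r | r <- rows tab, not (valueOf k (snd r)) ]

learn :: Table -> Free ((->) Row) Bool
learn = anaF discriminate

predict :: Row -> Free ((->) Row) Bool -> Bool
predict row = cataF ($ row)

-- Toy data: the label is "a AND b", one row per combination.
toy :: Table
toy = Table ["a", "b"]
  [ (a && b, [("a", a), ("b", b)]) | a <- [False, True], b <- [False, True] ]
```

On this toy table the learned model reproduces every training label; for instance `predict [("a", True), ("b", True)] (learn toy)` evaluates to `True`.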

## Extensions

By changing `discriminate`, we change our learning method. Using `anaF` adds some constraints on how we can learn our model, but still allows some freedom. Here are some possible avenues for extending and improving our models:

• We could test whether the inference gained at a given node is statistically significant, for example with a chi-squared test. If the key is not a significant predictor at that level, we can stop growing our decision tree early and return our best guess at that stage. This would help prevent overfitting.

• We could add fields to our table data type and set them when we split at each step of the recursion. For example, we could add a counter that prevents the tree from growing past a certain height (again, this combats overfitting). Alternatively, if we have some prior belief that certain variables work well together as predictors, we may want these variables to be close together in the decision tree. By storing the "splitting key" of the parent node, we could implement this in `discriminate`.

• The rows can be extended to hold non-Boolean data without changing the data type of our model very much. `discriminate` can grow the branchings with any function from rows to new training tables. If the complexity of each step of inference grows, we just need to make this branching function more complicated.

• The current model is totally deterministic. Recently, there has been some work in elegantly and efficiently adding probabilistic programming to Haskell (for example here), but these methods seem mostly suited to learning parametric models. Perhaps recursion combinators (or a similar idea) are one way to elegantly learn nonparametric models whose structure needs to be determined dynamically.

• It is not clear how to give the learning method control over multiple levels of the hierarchy. We could add information about the parent nodes, but we cannot go back and change them based on new information. A common learning method for decision trees is to grow out a few levels at once and decide between them, and this would be difficult to simulate in the current framework. With a wealth of recursion combinators out there, one may capture this idea very well.
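The height-limit idea from the list above can be sketched without touching the table type at all, by carrying a counter next to the seed instead of storing it in a field. This is one possible approach rather than the only one; `limitDepth`, `halve`, `depth`, and `shallow` are hypothetical names, and `Pair` stands in for the branching functor.

```haskell
data Free f a
  = Pure a
  | Free (f (Free f a))

anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b
anaF grow seed
  = case grow seed of
      Left u  -> Free (fmap (anaF grow) u)
      Right b -> Pure b

-- Wrap a growth step so that it stops after a fixed number of levels.
-- The seed is paired with the remaining budget of levels.
limitDepth :: Functor f
           => (a -> b)                -- fallback guess once the budget is spent
           -> (a -> Either (f a) b)   -- original growth step
           -> (Int, a) -> Either (f (Int, a)) b
limitDepth fallback grow (budget, seed)
  | budget <= 0 = Right (fallback seed)
  | otherwise   = case grow seed of
      Left u  -> Left (fmap (\s -> (budget - 1, s)) u)
      Right b -> Right b

-- Demonstration on binary trees grown by halving an integer range.
data Pair a = Pair a a

instance Functor Pair where
  fmap f (Pair x y) = Pair (f x) (f y)

halve :: (Int, Int) -> Either (Pair (Int, Int)) Int
halve (lo, hi)
  | lo >= hi  = Right lo
  | otherwise = let mid = (lo + hi) `div` 2
                in Left (Pair (lo, mid) (mid + 1, hi))

depth :: Free Pair a -> Int
depth (Pure _)          = 0
depth (Free (Pair l r)) = 1 + max (depth l) (depth r)

-- Only two levels are grown, although the range could be split further.
shallow :: Free Pair Int
shallow = anaF (limitDepth fst halve) (2, (1, 100))
```

In the setting of this post, the analogous call would look like `anaF (limitDepth bestGuess discriminate) (maxHeight, table)` for some chosen `maxHeight`, with `bestGuess` supplying the leaf when growth is cut off.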

## Appendix: A Printable Interface

It is hard to verify whether the above code is correct because we cannot (sensibly) print out functions in Haskell. The following code fixes this by providing a datatype for the branching instead of relying on functions. This provides much less flexibility in how we do branching, but allows us to print things! The following implements the exact same algorithm as above:

``````
data Branch k r = Branch
  { key :: k
  , bTrue :: r
  , bFalse :: r
  } deriving (Show, Functor)

discriminate' :: Ord k => Table k Bool -> Either (Branch k (Table k Bool)) Bool
discriminate' tab
  | numKeys tab == 0
    = Right $ bestGuess tab
  | otherwise
    = let goodKey = bestKey tab
          (keyTrueTab, keyFalseTab) = filterOn goodKey tab
      in Left $ Branch
           { key = goodKey
           , bTrue = keyTrueTab
           , bFalse = keyFalseTab
           }

learn' :: Ord k => Table k Bool -> Free (Branch k) Bool
learn' = anaF discriminate'

predict' :: Ord k => Row k Bool -> Free (Branch k) Bool -> Bool
predict' row model = cataF phi model
  where phi br
          = case getKey (key br) row of
              True  -> bTrue br
              False -> bFalse br
``````
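The code above derives `Show` for `Branch`, but printing a learned model also requires a `Show` instance for `Free` itself, which the definition of `Free` given earlier does not include. One way to supply it, sketched here as an assumption about how the printing could be set up, is a standalone deriving clause; the extra extensions `StandaloneDeriving` and `FlexibleContexts` and the hand-written `model` are not from the original code.

```haskell
{-# LANGUAGE DeriveFunctor, StandaloneDeriving, FlexibleContexts, UndecidableInstances #-}

data Free f a
  = Pure a
  | Free (f (Free f a))

-- A tree can be shown whenever its labels and one layer of branching
-- can be shown. The context mentions Free f a again, which is why
-- UndecidableInstances is required.
deriving instance (Show a, Show (f (Free f a))) => Show (Free f a)

data Branch k r = Branch
  { key :: k
  , bTrue :: r
  , bFalse :: r
  } deriving (Show, Functor)

-- A hand-written model with hypothetical keys, standing in for
-- the output of learn'.
model :: Free (Branch String) Bool
model =
  Free (Branch "rainy"
         (Pure False)
         (Free (Branch "weekend" (Pure True) (Pure False))))
```

With this instance in scope, `print model` or simply evaluating the result of `learn'` in GHCi displays the nested `Branch` records, including the key chosen at each split.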