Clay Thomas
Free f a, the free monad over a given functor f, is often described as "trees which branch in the shape of f and are leaf-labeled by a". What exactly does this mean? Well, the definition of Free is:
data Free f a
= Pure a
| Free (f (Free f a))
So if we have a functor data Pair a = Pair a a, we indeed have Free Pair a representing ordinary leaf-labeled, nonempty binary trees.
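For concreteness, here is such a tree written out by hand (a minimal sketch; exampleTree is an illustrative name, and we derive Functor for later use, assuming the DeriveFunctor extension):
data Pair a = Pair a a deriving Functor
-- A three-leaf binary tree: Pair nodes branch, Pure carries the leaf labels.
exampleTree :: Free Pair Int
exampleTree = Free (Pair (Pure 1) (Free (Pair (Pure 2) (Pure 3))))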
Now, if you know about decision trees, "leaf labeled" should catch your ear. (Binary) decision trees are trees where each node represents a yes/no feature about some observation. The leaves of the tree are labeled with distributions. To predict, you descend the tree according to an observation, and the distribution you reach at the bottom is your best guess at the distribution for your new observation.
So if we want to fit a monad Free f a to the task of a decision tree, it is clear that a should represent the distribution you are predicting. For the sake of simplicity, we will simply set a to Bool and just guess a yes or no, as opposed to providing percentages.
Our choice of f is a little less clear. We need to read in some information from an observation, say of type r, and descend into the next level of the tree. Well, speaking of reading information, what about the reader functor (->) r! If we try f = (->) r, we get
data Free ((->) r) Bool
= Pure Bool
| Free (r -> Free ((->) r) Bool)
This looks like exactly what we want! We fix a row type r, and when provided new rows we can traverse through internal Free nodes until we reach a leaf Pure node. So, we define
-- | A model with row type `r` and class label type `c`
type TreeM r c = Free ((->) r) c
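Just for intuition, here is a tiny model of this type written by hand (a sketch; positiveModel is a hypothetical name, with Int standing in for the row type):
-- Ask one question of the "row" and immediately return a leaf.
positiveModel :: TreeM Int Bool
positiveModel = Free (\n -> Pure (n > 0))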
The observation presented here is hardly earth-shattering, but it does come with some advantages other than "hey, a connection!". By adapting some recursion combinators (which typically act on the Fix data type), we can separate the recursive logic of our program from the actual computation. The goal of recursion combinators is to standardize patterns in recursion and make their implementations cleaner.
Free f a is strikingly similar to the type Fix f, the fixed point of the functor f. Recall
newtype Fix f = Fix { unFix :: f (Fix f) }
Essentially, Free allows us to stop our infinite, f-branching tree early and return a Pure value of type a. In many applications, the real result of this is that the functors f that you use with Fix are more complicated than those you use with Free, because with Fix you need to embed the notion of returning data into your base functor. (For example, we could model Free f a itself using Fix with a base functor newtype Br f a r = Br (Either a (f r)) deriving Functor, and get Free f a === Fix (Br f a).)
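To make that equivalence concrete, here is a sketch of the two conversions (toFix and fromFix are illustrative names; this again assumes DeriveFunctor for Br):
-- Fold Pure leaves into the Left case, Free layers into the Right case.
toFix :: Functor f => Free f a -> Fix (Br f a)
toFix (Pure a) = Fix (Br (Left a))
toFix (Free u) = Fix (Br (Right (fmap toFix u)))

fromFix :: Functor f => Fix (Br f a) -> Free f a
fromFix (Fix (Br (Left a))) = Pure a
fromFix (Fix (Br (Right u))) = Free (fmap fromFix u)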
For Fix, catamorphisms and anamorphisms are very useful recursion combinators, which respectively collapse and grow a recursive structure:
-- | Tear down a recursive structure, i.e. an element of Fix f.
-- First we recursively tear down each of the subtrees in the
-- first level of the fixed point. Then we collapse the
-- last layer using the algebra directly.
cata :: Functor f => (f a -> a) -- ^ An algebra to collapse a container to a value
     -> Fix f                   -- ^ A recursive tree of containers to start with
     -> a
cata alg fix = alg . fmap (cata alg) . unFix $ fix
-- | Build up a recursive structure, i.e. an element of Fix f.
-- We first expand our seed by one step.
-- Then we map over the resultant container,
-- recursively expanding each value along the way.
ana :: Functor f => (a -> f a) -- ^ A function to expand an a
    -> a                       -- ^ A seed value to start off with
    -> Fix f
ana grow seed = Fix . fmap (ana grow) . grow $ seed
It is pretty easy to extend cata to work on Free instead of Fix; we just need to add a case for Pure:
cataF :: Functor f => (f a -> a) -> Free f a -> a
cataF alg (Free u) = alg . fmap (cataF alg) $ u
cataF _ (Pure a) = a
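As a quick check, cataF collapses the Pair tree from earlier by summing its leaves (a minimal sketch using the hypothetical exampleTree above):
sumLeaves :: Free Pair Int -> Int
sumLeaves = cataF (\(Pair x y) -> x + y)
-- sumLeaves exampleTree evaluates to 6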
It is somewhat harder to logically extend ana. Indeed, the exact same code that worked for Fix also typechecks for Free; it just always builds infinite trees and never produces a Pure. Thus the input function has to allow for the possibility of a Pure value:
anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b
anaF grow seed
  = case grow seed of
      Left u -> Free . fmap (anaF grow) $ u
      Right b -> Pure b
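For instance, anaF can grow a complete binary tree of a given depth, stopping with a Pure leaf at depth zero (a sketch; depthTree is an illustrative name):
depthTree :: Int -> Free Pair Int
depthTree = anaF (\n -> if n <= 0 then Right n else Left (Pair (n - 1) (n - 1)))
-- depthTree 2 is a complete binary tree with four leaves, all labeled 0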
We start with a preamble to enable some convenient language extensions and import the needed libraries. Then we define our table data type and some simplified accessor functions.
What follows is complete and valid Haskell that can be run on a modern GHC. You need only the code below this point, along with our definitions of cataF and anaF above. You can also snag the code here.
{-# LANGUAGE RankNTypes
, RelaxedPolyRec
, DeriveFunctor
, TupleSections
, ScopedTypeVariables
, UndecidableInstances
#-}
import qualified Data.List as List
import qualified Data.MultiSet as Set
import qualified Data.Map as Map
import Control.Monad.Free
-- | value = Bool will suffice for this code, but
-- more general Tables are certainly reasonable
data Table key value = Table
  { keys :: [key]
    -- ^ For iterating and looping purposes
  , rows :: Set.MultiSet (value, Row key value)
    -- ^ An unordered collection of Rows associated to labels
  } deriving(Show)
type Row key value = Map.Map key value
numKeys :: Table k v -> Int
numKeys = length . keys
numRows :: Table k v -> Int
numRows = Set.size . rows
-- Assume all tables are full, i.e. every key is present in every row
getKey :: Ord k => k -> Row k v -> v
getKey k row
  = maybe (error "getKey: key missing from row") id (Map.lookup k row)
emptyBinTable :: Table key value
emptyBinTable = Table [] Set.empty
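To exercise the code below, here is a made-up toy table (exampleTable and its keys are purely illustrative; the label simply agrees with the "windy" feature):
exampleTable :: Table String Bool
exampleTable = Table
  { keys = ["windy", "sunny"]
  , rows = Set.fromList
      [ (True,  Map.fromList [("windy", True ), ("sunny", True )])
      , (True,  Map.fromList [("windy", True ), ("sunny", False)])
      , (False, Map.fromList [("windy", False), ("sunny", True )])
      , (False, Map.fromList [("windy", False), ("sunny", False)])
      ]
  }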
Now, the learning method of decision trees is (roughly) the following: find the key that best predicts the class labels on its own, split the table into sub-tables according to that key's value, and recurse on each sub-table; once no keys remain, return the majority label. The following code implements several tools we will need:
-- | We score a key by how many labels it could get correct on its own.
-- This method returns two values: the first assumes the key is positively
-- correlated with the label, the second assumes a negative correlation.
scores :: Ord k => k -> Table k Bool -> (Int, Int)
scores k tab = Set.fold (indicator k) (0, 0) $ rows tab
  where indicator :: Ord k => k -> (Bool, Row k Bool) -> (Int, Int) -> (Int, Int)
        indicator k (label, row) (pos, neg)
          = case Map.lookup k row of
              Just a  -> (pos + fromEnum (label == a), neg + fromEnum (label /= a))
              Nothing -> (pos, neg)
-- | Loop over all the keys and find the one that predicts with highest accuracy
bestKey :: Ord k => Table k Bool -> k
bestKey tab
  = let bestScore k
          = let (pos, neg) = scores k tab
            in (k, max pos neg)
        maxScores = fmap bestScore (keys tab)
        (best, _) = List.maximumBy (\(_, s) (_, s') -> s `compare` s') maxScores
    in best
removeKey :: (Ord k, Ord v) => k -> Table k v -> Table k v
removeKey k tab =
  emptyBinTable
    { keys = (List.\\) (keys tab) [k]
    , rows = Set.map (\(lab, row) -> (lab, Map.delete k row)) (rows tab)
    }
-- | Split a table into two tables based on the value of one key.
-- Also remove the key from the new tables.
filterOn :: Ord k => k -> Table k Bool -> (Table k Bool, Table k Bool)
filterOn k tab
  = let kTrue = getKey k . snd -- predicate: is key k true for this row?
        trueRows = Set.filter kTrue (rows tab)
        falseRows = Set.filter (not . kTrue) (rows tab)
    in (removeKey k tab{rows = trueRows}, removeKey k tab{rows = falseRows})
-- | Ignore all the rows and just guess a boolean based on class label
bestGuess :: Ord k => Table k Bool -> Bool
bestGuess tab
  = let nTrue = Set.fold (\(b, _) accum -> accum + fromEnum b) 0 (rows tab)
        -- count the number of True labels
    in 2 * nTrue >= numRows tab
These are all straightforward things that we would probably implement if we were writing this algorithm without recursion combinators.
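Given how the toy exampleTable was constructed, we can sanity-check these helpers in GHCi (the expected values below follow from that made-up table, so treat them as a sketch):
-- ghci> scores "windy" exampleTable
-- (4,0)
-- ghci> bestKey exampleTable
-- "windy"
-- ghci> bestGuess exampleTable
-- True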
Now that we have some functions to manipulate and extract information out of our tables, we are ready to learn our models and predict with them. We write the function discriminate to fit the type signature of anaF. This function accepts a table and returns one of two things:
If the table contains only class labels, we return a single Bool that represents our best guess of the class label.
If the table still has some keys, we return a mapping. This mapping takes in any row and returns a split of the data based on the value of bestKey within the row.
Recall that anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b and that type TreeM r c = Free ((->) r) c. When we apply anaF to discriminate, it has the effect of recursing over the newly created tables, growing more models until we hit the base case of a table with only class labels.
discriminate :: Ord k => Table k Bool -> Either (Row k Bool -> Table k Bool) Bool
discriminate tab
  | numKeys tab == 0
  = Right $ bestGuess tab
  | otherwise
  = let key = bestKey tab
        (trueTab, falseTab) = filterOn key tab
    in Left $ \row ->
         let bool = getKey key row
         in if bool then trueTab else falseTab
-- | Finally time to use this!
type TreeM r c = Free ((->) r) c
learn :: Ord k => Table k Bool -> TreeM (Row k Bool) Bool
learn = anaF discriminate
--cataF :: Functor f => (f a -> a) -> Free f a -> a
predict :: Ord k => Row k Bool -> TreeM (Row k Bool) Bool -> Bool
predict row model = cataF ($ row) model
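Putting the pieces together on the toy table (a sketch; demo is a hypothetical name):
-- Learn a model from exampleTable, then predict the label of a fresh row.
demo :: Bool
demo = predict (Map.fromList [("windy", True), ("sunny", False)])
               (learn exampleTable)
-- Since the label tracks "windy" in the toy data, this should be True.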
Now we are done! All the work paid off with very, very short definitions for learn and predict. Keep in mind that before learn and predict, nothing we had written involved recursion at all.
By changing discriminate, we change our learning method. Using anaF adds some constraints on how we can learn our model, but still allows some freedom. Here are some possible avenues for extending and improving our models:
We could test whether the inference gained at a given node is statistically significant, for example with a chi squared test. If the key is not a significant predictor at that level, we can stop growing our decision tree early and return our best guess at that stage. This would help prevent overfitting.
We could add fields to our table data type and set them when we split at each step of the recursion. For example, we could add a counter that prevents the tree from growing past a certain height (again, this combats overfitting). Alternatively, if we have some prior belief that certain variables work well together as predictors, we may want these variables to be close together in the decision tree. By storing the "splitting key" of the parent node, we could implement this in discriminate.
The rows can be extended to hold non-Boolean data without changing the data type of our model very much. discriminate can grow the branchings with any function from rows to new training tables. If the complexity of each step of inference grows, we just need to make this branching function more complicated.
The current model is totally deterministic. Recently, there has been some work in elegantly and efficiently adding probabilistic programming to Haskell (for example here), but these methods seem mostly suited to learning parametric models. Perhaps recursion combinators (or a similar idea) are one way to elegantly learn nonparametric models whose structure needs to be determined dynamically.
It is not clear how to give the learning method control over multiple levels of the hierarchy. We could add information about the parent nodes, but we cannot go back and change them based on new information. A common learning method for decision trees is to grow out a few levels at once and decide between them, and this would be difficult to simulate in the current framework. With the wealth of recursion combinators out there, one of them may capture this idea well.
It is hard to verify that the models above are correct, because we cannot (sensibly) print out functions in Haskell. The following code fixes this by providing a data type for the branching instead of relying on functions. This gives us much less flexibility in how we do branching, but allows us to print things! The following implements the exact same algorithm as above:
data Branch k r = Branch
  { key :: k
  , bTrue :: r
  , bFalse :: r
  } deriving(Show, Functor)
discriminate' :: Ord k => Table k Bool -> Either (Branch k (Table k Bool)) Bool
discriminate' tab
  | numKeys tab == 0
  = Right $ bestGuess tab
  | otherwise
  = let goodKey = bestKey tab
        (keyTrueTab, keyFalseTab) = filterOn goodKey tab
    in Left $ Branch
         { key = goodKey
         , bTrue = keyTrueTab
         , bFalse = keyFalseTab
         }
learn' :: Ord k => Table k Bool -> Free (Branch k) Bool
learn' = anaF discriminate'
predict' :: Ord k => Row k Bool -> Free (Branch k) Bool -> Bool
predict' row model = cataF phi model
  where phi br
          = case getKey (key br) row of
              True  -> bTrue br
              False -> bFalse br
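One wrinkle: depending on your version of the free library, Show for Free (Branch k) Bool may additionally require a Show1 instance for Branch k, so the most portable way to display a model is a small pretty-printer (a sketch; showModel is an illustrative name):
-- Render a learned tree as nested if-expressions.
showModel :: Show k => Free (Branch k) Bool -> String
showModel (Pure b)  = show b
showModel (Free br) = "(if " ++ show (key br)
                   ++ " then " ++ showModel (bTrue br)
                   ++ " else " ++ showModel (bFalse br) ++ ")"
On the toy exampleTable, putStrLn (showModel (learn' exampleTable)) would print a tree that splits on "windy" first.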