{- | Junction Trees
Generic tree data structures (such as Data.Tree) are poorly suited to message-passing algorithms, so junction trees
use a dedicated representation.
-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Bayes.FactorElimination.JTree(
      IsCluster(..)
    , Cluster(..)
    , JTree(..)
    , JunctionTree(..)
    , Sep
    , setFactors
    , distribute
    , collect
    , fromCluster
    , changeEvidence
    , nodeIsMemberOfTree
    , singletonTree
    , addNode
    , addSeparator
    , leaves
    , nodeValue
    , NodeValue(..)
    , SeparatorValue(..)
    , downMessage
    , upMessage
    , nodeParent
    , nodeChildren
    , traverseTree
    , separatorChild
    , treeNodes
    , treeValues
    , displayTreeValues
    , Action(..)
    ) where

import qualified Data.Map as Map
-- Strict variant of the same Map type; used where a value-strict insert is wanted
-- (replacement for the removed Data.Map.insertWith').
import qualified Data.Map.Strict as StrictMap
import qualified Data.Tree as Tree
import Data.Maybe(fromJust,mapMaybe)
import qualified Data.Set as Set
import Data.Monoid
import Data.List((\\),intersect,partition,foldl',minimumBy,nub)
import Bayes.PrivateTypes
import Bayes.Factor
import Bayes
import Data.Function(on)
import Bayes.VariableElimination(marginal)
import Debug.Trace

-- | Debugging helper: prints the label followed by the shown value (via 'trace')
-- and returns the value unchanged.
debug :: Show a => String -> a -> a
debug s a = trace (s ++ " " ++ show a ++ "\n") a

-- | Message sent from a child towards the root (always present once set).
type UpMessage a = a
-- | Message sent from the root towards a child ('Nothing' until distributed).
type DownMessage a = Maybe a

-- | Separator value
data SeparatorValue a = SeparatorValue !(UpMessage a) !(DownMessage a)
                      | EmptySeparator -- ^ Used to track the progress in the collect phase
                      deriving(Eq)

instance Show a => Show (SeparatorValue a) where
    show EmptySeparator = ""
    show (SeparatorValue u Nothing) = "u(" ++ show u ++ ")"
    show (SeparatorValue u (Just d)) = "u(" ++ show u ++ ") d(" ++ show d ++ ")"

type FactorValues a = [a]
type EvidenceValues a = [a]

-- | Node value: the vertex, the factors attached to the node and the evidence attached to the node.
data NodeValue a = NodeValue !Vertex !(FactorValues a) !(EvidenceValues a) deriving(Eq)

instance Show a => Show (NodeValue a) where
    show (NodeValue _ f e) = "f(" ++ show f ++ ") e(" ++ show e ++ ")"

-- | Unique identifier of a separator.
newtype Sep = Sep Int deriving(Eq,Ord,Show,Num)

-- | Junction tree.
-- 'c' is the node / separator identifier (for instance a set of 'DV').
-- 'f' is the type of the values stored in nodes and separators.
-- Clusters are unique, so the cluster value is also the cluster key.
-- Separators are not unique: two different separators can have the same
-- cluster. So, separator uniqueness is enforced with a number ('Sep').
data JTree c f = JTree {
      -- | Root node of the tree
      root :: !c
      -- | Leaves of the tree
    , leavesSet :: !(Set.Set c)
      -- | The children of a node are separators
    , childrenMap :: !(Map.Map c [Sep])
      -- | Parent of a node
    , parentMap :: !(Map.Map c Sep)
      -- | Parent of a separator
    , separatorParentMap :: !(Map.Map Sep c)
      -- | The child of a separator is a node
    , separatorChildMap :: !(Map.Map Sep c)
      -- | Values for nodes and separators
    , nodeValueMap :: !(Map.Map c (NodeValue f))
    , separatorValueMap :: !(Map.Map Sep (SeparatorValue f))
      -- | Last separator key allocated by 'addSeparator'
    , separatorCurrentKey :: !Sep
      -- | Cluster associated with each separator
    , separatorClusterMap :: !(Map.Map Sep c)
    } deriving(Eq)

-- | Create a singleton tree with just one root node
singletonTree :: Ord c => c -> Vertex -> [f] -> [f] -> JTree c f
singletonTree r rootVertex factorValue evidenceValue =
    addNode r rootVertex factorValue evidenceValue emptyTree
  where
    emptyTree = JTree { root = r
                      , leavesSet = Set.empty
                      , childrenMap = Map.empty
                      , parentMap = Map.empty
                      , separatorParentMap = Map.empty
                      , separatorChildMap = Map.empty
                      , nodeValueMap = Map.empty
                      , separatorValueMap = Map.empty
                      , separatorCurrentKey = Sep 0
                      , separatorClusterMap = Map.empty
                      }

-- | Remove all evidence from the tree by emptying the evidence list of every node
-- (equivalent to resetting all evidence to 1).
resetEvidences :: Factor f => JTree c f -> JTree c f
resetEvidences t = t { nodeValueMap = Map.map resetNodeEvidence (nodeValueMap t) }
  where
    resetNodeEvidence (NodeValue v f _) = NodeValue v f []

-- | Get the cluster for a separator (fails with 'fromJust' if the separator is unknown)
separatorCluster :: JTree c a -> Sep -> c
separatorCluster t s = fromJust $ Map.lookup s (separatorClusterMap t)

-- | Leaves of the tree
leaves :: JTree c a -> [c]
leaves = Set.toList . leavesSet

-- | All nodes of the tree
treeNodes :: JTree c a -> [c]
treeNodes = Map.keys . nodeValueMap

-- | All nodes of the tree with their values
treeValues :: JTree c f -> [(c, NodeValue f)]
treeValues = Map.toList . nodeValueMap

-- | Value of a node (fails with 'fromJust' if the node is unknown)
nodeValue :: Ord c => JTree c a -> c -> NodeValue a
nodeValue t e = fromJust $ Map.lookup e (nodeValueMap t)

-- | Change the value of a node
setNodeValue :: Ord c => c -> NodeValue a -> JTree c a -> JTree c a
setNodeValue c v t = t { nodeValueMap = Map.insert c v (nodeValueMap t) }

-- | Parent separator of a node ('Nothing' for the root)
nodeParent :: Ord c => JTree c a -> c -> Maybe Sep
nodeParent t e = Map.lookup e (parentMap t)

-- | Value of a separator (fails with 'fromJust' if the separator is unknown)
separatorValue :: Ord c => JTree c a -> Sep -> SeparatorValue a
separatorValue t e = fromJust $ Map.lookup e (separatorValueMap t)

-- | Parent of a separator
separatorParent :: Ord c => JTree c a -> Sep -> c
separatorParent t e = fromJust $ Map.lookup e (separatorParentMap t)

-- | UpMessage for a separator node. Calls 'error' if the separator is still empty.
upMessage :: Ord c => JTree c a -> Sep -> a
upMessage t c =
    case separatorValue t c of
        SeparatorValue up _ -> up
        _ -> error "Trying to get an up message on an empty seperator ! Should never occur !"

-- | DownMessage for a separator node ('Nothing' if not yet distributed).
-- Calls 'error' if the separator is still empty.
downMessage :: Ord c => JTree c a -> Sep -> Maybe a
downMessage t c =
    case separatorValue t c of
        SeparatorValue _ down -> down
        _ -> error "Trying to get a down message on an empty separator ! Should never occur !"

-- | Return the separator children of a node (empty for a leaf or an unknown node)
nodeChildren :: Ord c => JTree c a -> c -> [Sep]
nodeChildren t e = Map.findWithDefault [] e (childrenMap t)

-- | Return the child of a separator
separatorChild :: Ord c => JTree c a -> Sep -> c
separatorChild t e = fromJust $ Map.lookup e (separatorChildMap t)

-- | Check if a node is member of the tree
nodeIsMemberOfTree :: Ord c => c -> JTree c a -> Bool
nodeIsMemberOfTree c t = Map.member c (nodeValueMap t)

-- | Add a separator between two nodes.
-- The nodes MUST already be in the tree.
-- The new separator gets the next free key, starts as 'EmptySeparator',
-- is prepended to the origin's children and becomes the destination's parent.
-- The origin node is no longer a leaf.
addSeparator :: (Ord c)
             => c         -- ^ Origin node
             -> c         -- ^ Separator value
             -> c         -- ^ Destination node
             -> JTree c a -- ^ Current tree
             -> JTree c a -- ^ Modified tree
addSeparator node sepCluster dest t =
    let newSep = separatorCurrentKey t + 1
    in
    t { -- Value-strict insert, same semantics as the removed Map.insertWith'
        childrenMap = StrictMap.insertWith (++) node [newSep] (childrenMap t)
      , separatorChildMap = Map.insert newSep dest (separatorChildMap t)
      , separatorValueMap = Map.insert newSep EmptySeparator (separatorValueMap t)
      , separatorClusterMap = Map.insert newSep sepCluster (separatorClusterMap t)
      , leavesSet = Set.delete node (leavesSet t)
      , parentMap = Map.insert dest newSep (parentMap t)
      , separatorParentMap = Map.insert newSep node (separatorParentMap t)
      , separatorCurrentKey = newSep
      }

-- | Add a new node. It is registered as a leaf until a separator is added below it.
addNode :: (Ord c)
        => c         -- ^ Node
        -> Vertex
        -> [a]       -- ^ Factor value
        -> [a]       -- ^ Evidence value
        -> JTree c a
        -> JTree c a
addNode node vertex factorValue evidenceValue t =
    t { nodeValueMap = Map.insert node (NodeValue vertex factorValue evidenceValue) (nodeValueMap t)
      , leavesSet = Set.insert node (leavesSet t)
      }

-- | Update the up message of a separator, keeping any existing down message
updateUpMessage :: Ord c
                => Maybe Sep -- ^ Separator node to update (if any : none for root node)
                -> a         -- ^ New value
                -> JTree c a -- ^ Old tree
                -> JTree c a
updateUpMessage Nothing _ t = t
updateUpMessage (Just sep) newval t =
    let newSepValue = case separatorValue t sep of
            EmptySeparator -> SeparatorValue newval Nothing
            SeparatorValue _ down -> SeparatorValue newval down
    in
    t { separatorValueMap = Map.insert sep newSepValue (separatorValueMap t) }

-- | Update the down message of a separator.
-- The separator must already hold an up message; otherwise the stored value is an 'error'.
updateDownMessage :: Ord c
                  => Sep       -- ^ Separator node to update
                  -> a         -- ^ New value
                  -> JTree c a -- ^ Old tree
                  -> JTree c a
updateDownMessage sep newval t =
    let newSepValue = case separatorValue t sep of
            EmptySeparator -> error "Can't set a down message on an empty separator"
            SeparatorValue up _ -> SeparatorValue up (Just newval)
    in
    t { separatorValueMap = Map.insert sep newSepValue (separatorValueMap t) }

{-
Message passing algorithms
-}
-- | Message construction for value type 'f' and cluster type 'c'
-- (the functional dependency makes the value type determine the cluster type).
class Message f c | f -> c where
    -- | Build a new message from the received messages, the value of the
    -- sending node and the cluster of the destination separator.
    newMessage :: [f] -> NodeValue f -> c -> f

-- | True once a separator holds at least an up message
separatorInitialized :: SeparatorValue a -> Bool
separatorInitialized sv =
    case sv of
        EmptySeparator -> False
        _ -> True

-- | True when every listed separator has received a message (vacuously True for no separators)
allSeparatorsHaveReceivedAMessage :: Ord c
                                  => JTree c a -- ^ Tree
                                  -> [Sep]     -- ^ Separators
                                  -> Bool
allSeparatorsHaveReceivedAMessage tree seps =
    all (\s -> separatorInitialized (separatorValue tree s)) seps

-- | Send an up message from a node to its parent separator, but only once
-- every child separator of the node has received a message.
-- The root has no parent separator, so it leaves the tree unchanged.
updateUpSeparator :: (Message a c, Ord c)
                  => JTree c a
                  -> c         -- ^ Node generating the new upMessage
                  -> JTree c a
updateUpSeparator tree node
    | not (allSeparatorsHaveReceivedAMessage tree childSeps) = tree
    | otherwise = maybe tree sendUp (nodeParent tree node)
  where
    childSeps = nodeChildren tree node
    sendUp parentSep =
        let outgoing = newMessage (map (upMessage tree) childSeps)
                                  (nodeValue tree node)
                                  (separatorCluster tree parentSep)
        in updateUpMessage (Just parentSep) outgoing tree

-- | Send a down message from a node to one of its child separators.
-- The message combines the up messages of the other children and, when
-- present, the down message received from the node's own parent separator.
updateDownSeparator :: (Message a c, Ord c)
                    => c         -- ^ Node generating the message
                    -> JTree c a
                    -> Sep       -- ^ Child receiving the message
                    -> JTree c a
updateDownSeparator node tree child = updateDownMessage child outgoing tree
  where
    fromSiblings = map (upMessage tree) (nodeChildren tree node \\ [child])
    fromParent = nodeParent tree node >>= downMessage tree
    received = maybe fromSiblings (: fromSiblings) fromParent
    outgoing = newMessage received (nodeValue tree node) (separatorCluster tree child)

-- | Deduplicate a list, returning its elements in ascending order
unique :: Ord c => [c] -> [c]
unique xs = Set.toAscList (Set.fromList xs)

-- | Collect message taking into account that the tree depth may be different for different leaves.
-- Starts from the leaves; a node whose children are not all ready is retried
-- when a deeper branch later reaches it.
collect :: (Ord c, Message a c) => JTree c a -> JTree c a
collect tree = _collectNodes (leaves tree) tree

-- | Move one level up: from separators to their (deduplicated) parent nodes
_collectSeparators :: (Ord c, Message a c)
                   => [Sep]
                   -> JTree c a -- ^ Tree
                   -> JTree c a -- ^ Modified tree
_collectSeparators seps tree =
    _collectNodes (unique [separatorParent tree s | s <- seps]) tree

-- | Send up messages from the given nodes, then continue with their parent separators
_collectNodes :: (Ord c, Message a c)
              => [c]
              -> JTree c a -- ^ Tree
              -> JTree c a -- ^ Modified tree
_collectNodes [] tree = tree
_collectNodes nodes tree =
    _collectSeparators (mapMaybe (nodeParent tree) nodes) (foldl' updateUpSeparator tree nodes)

-- | Distribute messages from the root down to every node
distribute :: (Ord c, Message a c) => JTree c a -> JTree c a
distribute tree = _distributeNodes tree (root tree)

-- | Continue the distribution into the child node of a separator
_distributeSeparators :: (Ord c, Message a c)
                      => JTree c a
                      -> Sep       -- ^ Destination of the distribute
                      -> JTree c a
_distributeSeparators tree sep = _distributeNodes tree (separatorChild tree sep)

-- | Send down messages to all child separators of a node, then recurse into them
_distributeNodes :: (Ord c, Message a c)
                 => JTree c a
                 -> c         -- ^ Destination of the distribute
                 -> JTree c a
_distributeNodes tree node = foldl' _distributeSeparators withDownMessages childSeps
  where
    childSeps = nodeChildren tree node
    withDownMessages = foldl' (updateDownSeparator node) tree childSeps

{-
Factors and evidence modifications
-}
-- | This class is used to check if evidence or a factor is relevant
-- for a cluster
class IsCluster c where
    -- | Evidence contained in the cluster
    overlappingEvidence :: c -> [DVI Int] -> [DVI Int]
    -- | Cluster variables
    clusterVariables :: c -> [DV]
    -- | Intersection of two clusters
    mkSeparator :: c -> c -> c

instance IsCluster [DV] where
    overlappingEvidence c e = filter (\x -> instantiationVariable x `elem` c) e
    clusterVariables = id
    mkSeparator = intersect

-- | Decision returned by the action given to 'traverseTree' for each visited node.
data Action s a = Skip !s            -- ^ Keep the node value, visit the children
                | ModifyAndStop !s !a -- ^ Replace the node value, do not visit the children
                | Modify !s !a       -- ^ Replace the node value, visit the children
                | Stop !s            -- ^ Keep the node value, do not visit the children

-- | Traverse a tree and modify it.
-- The traversal is depth-first from the root. The state returned by the action
-- is threaded through the traversal and returned with the final tree.
-- 'Stop' and 'ModifyAndStop' only prune the current subtree: sibling subtrees
-- are still visited.
traverseTree :: Ord c
             => (s -> c -> NodeValue f -> Action s (NodeValue f)) -- ^ Modification function
             -> s         -- ^ Current state
             -> JTree c f -- ^ Input tree
             -> (JTree c f, s)
traverseTree action state t = _traverseTreeNodes action (t, state) (root t)

-- | Visit the child node of a separator
_traverseTreeSeparators :: Ord c
                        => (s -> c -> NodeValue f -> Action s (NodeValue f))
                        -> (JTree c f, s)
                        -> Sep
                        -> (JTree c f, s)
_traverseTreeSeparators action (t, state) current =
    _traverseTreeNodes action (t, state) (separatorChild t current)

-- | Apply the action to a node and, depending on its result, to its subtree
_traverseTreeNodes :: Ord c
                   => (s -> c -> NodeValue f -> Action s (NodeValue f))
                   -> (JTree c f, s)
                   -> c
                   -> (JTree c f, s)
_traverseTreeNodes action (t, state) current =
    case action state current (nodeValue t current) of
        Stop newState -> (t, newState)
        -- Return the state produced by the action, consistently with 'Stop'
        -- (it was previously discarded in favour of the old state).
        ModifyAndStop newState newValue -> (setNodeValue current newValue t, newState)
        Skip newState -> foldl' (_traverseTreeSeparators action) (t, newState) (nodeChildren t current)
        Modify newState newValue ->
            let newTree = setNodeValue current newValue t
            in foldl' (_traverseTreeSeparators action) (newTree, newState) (nodeChildren newTree current)

-- | Map a function over every node value, giving it access to the node cluster
mapWithCluster :: Ord c => (c -> NodeValue f -> NodeValue f) -> JTree c f -> JTree c f
mapWithCluster f t = t { nodeValueMap = Map.mapWithKey f (nodeValueMap t) }

-- | Attach each factor to the smallest cluster (by product of the variable
-- dimensions) containing all the factor's variables, merging it into the node
-- value with the given function.
-- Calls 'error' naming the factor if no cluster contains all its variables.
updateTreeValues :: (Factor f, IsCluster c, Ord c, Show c, Show f)
                 => (f -> NodeValue f -> NodeValue f)
                 -> [f]
                 -> JTree c f
                 -> JTree c f
updateTreeValues change factors t = foldl' addFactor t factors
  where
    -- The set of nodes does not change while factors are added
    allNodes = treeNodes t
    factorIncludedInCluster f c = all (`elem` clusterVariables c) (factorVariables f)
    clusterSize a = product . map (fromIntegral . dimension) . clusterVariables $ a :: Integer
    addFactor tree newFactor =
        case filter (newFactor `factorIncludedInCluster`) allNodes of
            [] -> error ("updateTreeValues: no cluster contains all the variables of the factor " ++ show newFactor)
            covering ->
                let minimumCluster = minimumBy (compare `on` clusterSize) covering
                    clusterValue = nodeValue tree minimumCluster
                in setNodeValue minimumCluster (change newFactor clusterValue) tree

-- | Set the factors in the tree: every vertex value of the network is
-- prepended to the factor list of its smallest covering cluster.
setFactors :: (Graph g, Factor f, IsCluster c, Ord c, Show c, Show f)
           => BayesianNetwork g f
           -> JTree c f
           -> JTree c f
setFactors g t = updateTreeValues addFactorToNode (allVertexValues g) t
  where
    addFactorToNode f (NodeValue v oldf e) = NodeValue v (f : oldf) e

-- | Change evidence in the network.
-- All separators are emptied and previous evidence is removed, the new evidence
-- factors are attached to their clusters, then messages are collected and distributed.
changeEvidence :: (IsCluster c, Ord c, Factor f, Message f c, Show c, Show f)
               => [DVI Int] -- ^ Evidence
               -> JTree c f
               -> JTree c f
changeEvidence e t =
    distribute . collect . updateTreeValues addEvidenceToNode evidences . resetEvidences $ clearedTree
  where
    evidences = map factorFromInstantiation e
    addEvidenceToNode newe (NodeValue v f olde) = NodeValue v f (newe : olde)
    clearedTree = t { separatorValueMap = Map.map (const EmptySeparator) (separatorValueMap t) }

-- | Cluster of discrete variables.
-- Discrete variables instead of vertices are needed because the
-- factors are using 'DV' and we need to find
-- which factors must be contained in a given cluster.
newtype Cluster = Cluster (Set.Set DV) deriving(Eq,Ord)

instance IsCluster Cluster where
    overlappingEvidence c = overlappingEvidence (fromCluster c)
    clusterVariables c = clusterVariables (fromCluster c)
    mkSeparator (Cluster a) (Cluster b) = Cluster (Set.intersection a b)

instance Show Cluster where
    show (Cluster s) = show . Set.toList $ s

-- | Variables of a cluster, in ascending order
fromCluster :: Cluster -> [DV]
fromCluster (Cluster s) = Set.toList s

-- | A message towards a separator is the marginal of the node factors, node
-- evidence and incoming messages, keeping only the separator variables.
instance Factor f => Message f Cluster where
    newMessage input (NodeValue _ f e) dv =
        let allFactors = f ++ e ++ input
            variablesToKeep = fromCluster dv
            variablesToRemove = nub (concatMap factorVariables allFactors) \\ variablesToKeep
        in marginal allFactors variablesToRemove variablesToKeep []

type JunctionTree f = JTree Cluster f
{-
Show instance used to display the structure of the tree
(without the values)
-}
data NodeKind c = N !c | S !c

-- | Label of a node in the drawn tree: the name, optionally followed by "=" and the shown value
label :: Show a => Bool -> String -> a -> String
label withValue name value
    | withValue = name ++ "=" ++ show value
    | otherwise = name

-- | Convert the JTree into a tree of string
-- using the cluster.
toTree :: (Ord c, Show c, Show a)
       => Bool -- ^ True if the data must be displayed
       -> JTree c a
       -> Tree.Tree String
toTree withValues tree =
    Tree.Node (label withValues (show top) (nodeValue tree top))
              (_toTreeSeparators withValues tree (nodeChildren tree top))
  where
    top = root tree

-- | Render nodes; the children of a node are its separators
_toTreeNodes :: (Ord c, Show c, Show a) => Bool -> JTree c a -> [c] -> [Tree.Tree String]
_toTreeNodes withValues tree = map renderNode
  where
    renderNode n =
        Tree.Node (label withValues (show n) (nodeValue tree n))
                  (_toTreeSeparators withValues tree (nodeChildren tree n))

-- | Render separators as "<cluster>"; the child of a separator is a single node
_toTreeSeparators :: (Ord c, Show c, Show a) => Bool -> JTree c a -> [Sep] -> [Tree.Tree String]
_toTreeSeparators withValues tree = map renderSeparator
  where
    renderSeparator s =
        Tree.Node (label withValues ("<" ++ show (separatorCluster tree s) ++ ">") (separatorValue tree s))
                  (_toTreeNodes withValues tree [separatorChild tree s])

instance (Ord c, Show c, Show a) => Show (JTree c a) where
    show tree = Tree.drawTree (toTree False tree)

-- | Draw the tree, including node and separator values when the flag is True
displayTree :: (Ord c, Show c, Show a) => Bool -> JTree c a -> String
displayTree withValues tree = Tree.drawTree (toTree withValues tree)

-- | Display the tree values: for each node, its cluster, factors and evidence
displayTreeValues :: (Show f, Show c) => JTree c f -> IO ()
displayTreeValues tree = mapM_ showEntry (treeValues tree)
  where
    showEntry (cluster, NodeValue _ factors evidence) = do
        print cluster
        putStrLn "FACTOR"
        print factors
        putStrLn "EVIDENCE"
        print evidence
        putStrLn "------"