{-# OPTIONS_GHC -Wall #-}
{-# Language TypeOperators #-}
{-# Language TypeFamilies #-}
{-# Language FlexibleInstances #-}

-- | Conversion of symbolic 'Expr' inputs and outputs into an explicit
-- expression graph ('FunGraph'), plus small utilities for inspecting it.
module Dvda.FunGraph ( FunGraph
                     , ToFunGraph
                     , NumT
                     , (:*)(..)
                     , MVS(..)
                     , toFunGraph
                     , countNodes
                     , fgInputs
                     , fgOutputs
                     , fgLookupGExpr
                     , fgReified
                     , topSort
--                     , fgGraph
                     , nodelistToFunGraph
                     , exprsToFunGraph
                     ) where

import Control.Applicative
import Data.Foldable ( Foldable )
import qualified Data.Foldable as F
import qualified Data.Graph as Graph
import Data.Hashable ( Hashable )
import qualified Data.HashSet as HS
import Data.Traversable ( Traversable )
import qualified Data.Traversable as T

import Dvda.Expr
import Dvda.Reify ( ReifyGraph(..), reifyGraphs )

-- | An expression graph together with its inputs, outputs and the lookup
-- functions produced when the graph was built (see 'nodelistToFunGraph').
data FunGraph a = FunGraph
  { fgGraph          :: Graph.Graph
    -- ^ 'Data.Graph' graph with one vertex per reified node
  , fgInputs         :: [MVS (GExpr a Int)]
    -- ^ input nodes, grouped in the shapes the user supplied them in
  , fgOutputs        :: [MVS Int]
    -- ^ output node keys, grouped in the shapes the user supplied them in
  , fgReified        :: [(Int, GExpr a Int)]
    -- ^ the reified @(key, node)@ list the graph was built from
  , fgLookupGExpr    :: Int -> Maybe (GExpr a Int)
    -- ^ look up a node by key; 'Nothing' if the key is not in the graph
  , fgVertexFromKey  :: Int -> Maybe Int
    -- ^ map a node key to its 'Data.Graph' vertex, if present
  , fgNodeFromVertex :: Int -> (GExpr a Int, Int, [Int])
    -- ^ map a vertex back to @(node, key, keys of the node's parents)@
  }

-- | Shows the inputs, the outputs and the raw 'Data.Graph' adjacency.
instance Show a => Show (FunGraph a) where
  show fg = "FunGraph\ninputs:\n" ++ show (fgInputs fg)
            ++ "\noutputs:\n" ++ show (fgOutputs fg)
            ++ "\ngraph:\n" ++ show (fgGraph fg)

-- | A matrix, vector or scalar of values.
data MVS a = Mat [[a]] | Vec [a] | Sca a deriving Show

-- | Maps over every element, preserving the shape.
instance Functor MVS where
  fmap f (Sca x) = Sca (f x)
  fmap f (Vec xs) = Vec (map f xs)
  fmap f (Mat xs) = Mat (map (map f) xs)

-- | Folds over the elements in order; a 'Mat' is folded as the
-- concatenation of its inner lists.
instance Foldable MVS where
  foldr f x0 (Sca x) = foldr f x0 [x]
  foldr f x0 (Vec xs) = foldr f x0 xs
  foldr f x0 (Mat xs) = foldr f x0 (concat xs)

-- | Traverses the elements in the same order as the 'Foldable' instance.
instance Traversable MVS where
  traverse f (Sca x) = Sca <$> f x
  traverse f (Vec xs) = Vec <$> T.traverse f xs
  traverse f (Mat xs) = Mat <$> T.traverse (T.traverse f) xs

-- | Values that can be flattened into a list of 'MVS'-shaped expressions.
-- 'NumT' is the element type of the contained 'Expr's.
class ToFunGraph a where
  type NumT a
  toMVSList :: a -> [MVS (Expr (NumT a))]

-- | A single expression is a scalar.
instance ToFunGraph (Expr a) where
  type NumT (Expr a) = a
  toMVSList x = [Sca x]

-- | A list of expressions is a vector.
instance ToFunGraph [Expr a] where
  type NumT [Expr a] = NumT (Expr a)
  toMVSList x = [Vec x]

-- | A list of lists of expressions is a matrix.
instance ToFunGraph [[Expr a]] where
  type NumT [[Expr a]] = NumT [Expr a]
  toMVSList x = [Mat x]

-- | Heterogeneous pair used to combine several 'ToFunGraph' values that
-- share the same 'NumT'. Right-associative, so @a :* b :* c@ nests to the
-- right.
data a :* b = a :* b deriving Show
infixr 6 :*

-- | A pair flattens to the left component's shapes followed by the right's.
instance (ToFunGraph a, ToFunGraph b, NumT a ~ NumT b) => ToFunGraph (a :* b) where
  type NumT (a :* b) = NumT a
  toMVSList (x :* y) = toMVSList x ++ toMVSList y
-- | Find the graph inputs the user did not supply. These are the 'GSym' nodes
-- in the reified node list @gr@ whose symbol is not among the user-supplied
-- input expressions @exprs@.
--
-- Calls 'error', showing the offending expression, if any user-supplied
-- input is not an 'ESym'. The result order is that of 'HS.toList'.
detectMissingInputs :: (Eq a, Hashable a, Show a) => [MVS (Expr a)] -> [(Int, GExpr a Int)] -> [GExpr a Int]
detectMissingInputs exprs gr = HS.toList $ HS.difference allGraphInputs allUserInputs
  where
    -- symbols the user supplied, converted to graph nodes so they can be
    -- compared with the nodes found in the graph
    allUserInputs = HS.fromList $ map toGSym (concatMap F.toList exprs)

    -- report the input expression itself (not any accumulated state) when
    -- something other than a plain symbol was supplied
    toGSym (ESym name) = GSym name
    toGSym e = error $ "detectMissingInputs given non-ESym input \"" ++ show e ++ "\""

    -- every symbol node that occurs anywhere in the reified graph
    allGraphInputs = HS.fromList [ GSym name | (_, GSym name) <- gr ]

-- | Return every input symbol (such as @ESym "x"@) that is supplied more than
-- once across all of the input shapes. Each duplicated symbol is listed once,
-- in 'HS.toList' order.
--
-- This function only reports duplicates. Turning a non-empty result into an
-- error is up to the caller ('mvsToFunGraph' does so). It does call 'error'
-- if any input is not an 'ESym'.
findConflictingInputs :: (Eq a, Hashable a, Show a) => [MVS (Expr a)] -> [Expr a]
findConflictingInputs exprs = HS.toList redundant
  where
    -- strict left fold: lazy 'foldl' would build a thunk chain over all inputs
    redundant = snd $ F.foldl' step (HS.empty, HS.empty) (concatMap F.toList exprs)

    -- accumulator: (symbols seen so far, symbols seen more than once)
    step (knownExprs, redundantExprs) expr@(ESym _)
      | HS.member expr knownExprs = (knownExprs, HS.insert expr redundantExprs)
      | otherwise = (HS.insert expr knownExprs, redundantExprs)
    step _ e = error $ "findConflictingInputs saw non-ESym input \"" ++ show e ++ "\""
-- | Build a 'FunGraph' from inputs and outputs given as 'ToFunGraph' values.
-- These are heterogeneous collections of @Expr a@, combined with ':*'.
--
-- The outputs are reified into an expression graph with 'reifyGraphs', and
-- the inputs are converted to 'GSym' nodes. The result is then checked:
--
-- * every symbol that occurs in the graph must be supplied as an input;
-- * no input may be supplied twice;
-- * every input must be a plain 'ESym'.
--
-- A violation is reported with 'error' when the returned IO action runs.
--
-- Reification is based on StableNames, which are non-deterministic. This
-- function may therefore return graphs with more or fewer common
-- subexpressions eliminated. If CSE is then performed on the graph, the
-- result is deterministic.
toFunGraph :: (Eq a, Hashable a, Show a, ToFunGraph b, ToFunGraph c, NumT b ~ a, NumT c ~ a)
           => b -> c -> IO (FunGraph a)
toFunGraph inputs outputs = mvsToFunGraph (toMVSList inputs) (toMVSList outputs)

-- | 'toFunGraph' on already-flattened input and output shapes.
mvsToFunGraph :: (Eq a, Hashable a, Show a) => [MVS (Expr a)] -> [MVS (Expr a)] -> IO (FunGraph a)
mvsToFunGraph inputMVSExprs outputMVSExprs = do
  -- reify the outputs
  (ReifyGraph rgr, outputMVSIndices) <- reifyGraphs outputMVSExprs
  let inputMVSGExprs = map (fmap toGSym) inputMVSExprs
      toGSym (ESym name) = GSym name
      toGSym x = error $ "ERROR: mvsToFunGraph given non-ESym input \"" ++ show x ++ "\""
  -- Choose the resulting action here, inside IO. Validation errors then
  -- surface when this action runs, rather than later, whenever some pure
  -- code happens to force the returned FunGraph.
  case (detectMissingInputs inputMVSExprs rgr, findConflictingInputs inputMVSExprs) of
    ([], []) -> return (nodelistToFunGraph rgr inputMVSGExprs outputMVSIndices)
    (xs, []) -> error $ "mvsToFunGraph found inputs that were not provided by the user: " ++ show xs
    (_, xs)  -> error $ "mvsToFunGraph found idential inputs set more than once: " ++ show xs

-- | Assemble a 'FunGraph' from a reified @(key, node)@ list and the input and
-- output shapes. No validation is done here; 'mvsToFunGraph' validates
-- before calling it.
--
-- The graph gets one 'Data.Graph' vertex per list entry, with edges from
-- each node to the keys returned by 'getParents'.
nodelistToFunGraph :: [(Int, GExpr a Int)] -> [MVS (GExpr a Int)] -> [MVS Int] -> FunGraph a
nodelistToFunGraph rgr inputMVSIndices outputMVSIndices =
  FunGraph { fgGraph = gr
           , fgInputs = inputMVSIndices
           , fgOutputs = outputMVSIndices
           , fgLookupGExpr = lookupG
           , fgReified = rgr
           , fgVertexFromKey = lookupKey
           , fgNodeFromVertex = lookupVertex
           }
  where
    (gr, lookupVertex, lookupKey) =
      Graph.graphFromEdges [ (gexpr, k, getParents gexpr) | (k, gexpr) <- rgr ]

    -- key -> vertex (if present) -> node
    lookupG k = nodeExpr . lookupVertex <$> lookupKey k
    nodeExpr (g, _, _) = g

---------------------------------- utilities -----------------------------

-- | Number of vertices in the underlying 'Data.Graph' graph.
countNodes :: FunGraph a -> Int
countNodes = length . Graph.vertices . fgGraph

-- | Keys of all nodes, in the order given by 'Graph.topSort' on the
-- underlying graph. Edges run from each node to its 'getParents' keys.
topSort :: FunGraph a -> [Int]
topSort fg = map (nodeKey . fgNodeFromVertex fg) (Graph.topSort (fgGraph fg))
  where
    nodeKey (_, k, _) = k
-- | Make a 'FunGraph' out of outputs alone, detecting the inputs
-- automatically. Every symbol that 'foldExpr' encounters as an 'ESym' in any
-- of the outputs becomes an input.
--
-- Inputs and outputs are each handed to 'toFunGraph' as a single vector. The
-- order of the detected inputs is that of 'HS.toList'.
exprsToFunGraph :: (Eq a, Show a, Hashable a) => [Expr a] -> IO (FunGraph a)
exprsToFunGraph outputs = toFunGraph detectedInputs outputs
  where
    detectedInputs = map ESym (HS.toList symbolSet)

    -- fold each output expression into one shared set of symbols
    symbolSet = foldr (\expr syms -> foldExpr addSym syms expr) HS.empty outputs

    addSym (ESym s) syms = HS.insert s syms
    addSym _ syms = syms