# Source code for mxnet.gluon.block

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable= arguments-differ, too-many-lines
"""Base container class for all neural network models."""
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']

import threading
import copy
import warnings
import re
from collections import OrderedDict

from ..base import mx_real_t
from .. import symbol, ndarray, initializer
from ..symbol import Symbol
from ..ndarray import NDArray
from .. import name as _name
from .parameter import Parameter, ParameterDict, DeferredInitializationError
from .utils import _indent, _brief_print_list, HookHandle


class _BlockScope(object):
    """Scope for collecting child `Block` s."""
    _current = threading.local()

    def __init__(self, block):
        self._block = block
        self._counter = {}
        self._old_scope = None
        self._name_scope = None

    @staticmethod
    def create(prefix, params, hint):
        """Creates prefix and params for new `Block`."""
        current = getattr(_BlockScope._current, "value", None)
        if current is None:
            # Outside any scope: derive names from the global NameManager.
            if prefix is None:
                if not hasattr(_name.NameManager._current, "value"):
                    _name.NameManager._current.value = _name.NameManager()
                prefix = _name.NameManager._current.value.get(None, hint) + '_'
            if params is None:
                params = ParameterDict(prefix)
            else:
                params = ParameterDict(params.prefix, params)
            return prefix, params

        # Inside a parent scope: count per-hint to get unique child prefixes.
        if prefix is None:
            count = current._counter.get(hint, 0)
            prefix = '%s%d_' % (hint, count)
            current._counter[hint] = count + 1
        if params is None:
            parent = current._block.params
            params = ParameterDict(parent.prefix + prefix, parent._shared)
        else:
            params = ParameterDict(params.prefix, params)
        return current._block.prefix + prefix, params

    def __enter__(self):
        if self._block._empty_prefix:
            return self
        self._old_scope = getattr(_BlockScope._current, "value", None)
        _BlockScope._current.value = self
        self._name_scope = _name.Prefix(self._block.prefix)
        self._name_scope.__enter__()
        return self

    def __exit__(self, ptype, value, trace):
        if self._block._empty_prefix:
            return
        self._name_scope.__exit__(ptype, value, trace)
        self._name_scope = None
        _BlockScope._current.value = self._old_scope


def _flatten(args, inout_str):
    """Flatten a nested list/tuple of NDArray/Symbol into a flat list plus a
    format descriptor that lets `_regroup` rebuild the original nesting."""
    if isinstance(args, NDArray):
        return [args], int(0)
    if isinstance(args, Symbol):
        # A multi-output Symbol records its output count in the format.
        length = len(args.list_outputs())
        length = length if length > 1 else 0
        return [args], int(length)

    assert isinstance(args, (list, tuple)), \
        "HybridBlock %s must be (nested) list of Symbol or NDArray, " \
        "but got %s of type %s" % (inout_str, str(args), str(type(args)))
    flat = []
    fmts = []
    for i in args:
        arg, fmt = _flatten(i, inout_str)
        flat.extend(arg)
        fmts.append(fmt)
    return flat, fmts


def _regroup(args, fmt):
    """Inverse of `_flatten`: regroup a flat list according to `fmt`,
    returning (regrouped, remaining)."""
    if isinstance(fmt, int):
        if fmt == 0:
            return args[0], args[1:]
        return args[:fmt], args[fmt:]

    assert isinstance(args, (list, tuple)), \
        "HybridBlock output must be (nested) list of Symbol or NDArray, " \
        "but got %s of type %s" % (str(args), str(type(args)))
    ret = []
    for i in fmt:
        res, args = _regroup(args, i)
        ret.append(res)
    return ret, args

class Block(object):
    """Base class for all neural network layers and models. Your models should
    subclass this class.

    :py:class:`Block` can be nested recursively in a tree structure. You can create and
    assign child :py:class:`Block` as regular attributes::

        from mxnet.gluon import Block, nn
        from mxnet import ndarray as F

        class Model(Block):
            def __init__(self, **kwargs):
                super(Model, self).__init__(**kwargs)
                # use name_scope to give child Blocks appropriate names.
                with self.name_scope():
                    self.dense0 = nn.Dense(20)
                    self.dense1 = nn.Dense(20)

            def forward(self, x):
                x = F.relu(self.dense0(x))
                return F.relu(self.dense1(x))

        model = Model()
        model.initialize(ctx=mx.cpu(0))
        model(F.zeros((10, 10), ctx=mx.cpu(0)))


    Child :py:class:`Block` assigned this way will be registered and :py:meth:`collect_params`
    will collect their Parameters recursively. You can also manually register
    child blocks with :py:meth:`register_child`.

    Parameters
    ----------
    prefix : str
        Prefix acts like a name space. All children blocks created in parent block's
        :py:meth:`name_scope` will have parent block's prefix in their name.
        Please refer to
        `naming tutorial <http://mxnet.incubator.apache.org/tutorials/gluon/naming.html>`_
        for more info on prefix and naming.
    params : ParameterDict or None
        :py:class:`ParameterDict` for sharing weights with the new :py:class:`Block`. For example,
        if you want ``dense1`` to share ``dense0``'s weights, you can do::

            dense0 = nn.Dense(20)
            dense1 = nn.Dense(20, params=dense0.collect_params())
    """
    def __init__(self, prefix=None, params=None):
        # An explicit empty prefix ('') disables the automatic naming scope.
        self._empty_prefix = prefix == ''
        self._prefix, self._params = _BlockScope.create(prefix, params, self._alias())
        self._name = self._prefix[:-1] if self._prefix.endswith('_') else self._prefix
        self._scope = _BlockScope(self)
        self._children = OrderedDict()
        self._reg_params = {}
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()

    def __repr__(self):
        s = '{name}(\n{modstr}\n)'
        modstr = '\n'.join(['  ({key}): {block}'.format(key=key,
                                                        block=_indent(block.__repr__(), 2))
                            for key, block in self.__dict__.items()
                            if isinstance(block, Block)])
        return s.format(name=self.__class__.__name__, modstr=modstr)

    def __setattr__(self, name, value):
        """Registers parameters."""
        if hasattr(self, name):
            existing = getattr(self, name)
            if isinstance(existing, (Parameter, Block)) and not isinstance(value, type(existing)):
                # NOTE: a space was missing between the two concatenated string
                # literals, producing "...{type2}is not allowed." — fixed here.
                raise TypeError('Changing attribute type for {name} from {type1} to {type2} '
                                'is not allowed.'.format(name=name,
                                                         type1=type(existing),
                                                         type2=type(value)))

        if isinstance(value, Block):
            self.register_child(value, name)
        elif isinstance(value, Parameter):
            assert name not in self._reg_params, \
                "Overriding Parameter attribute %s is not allowed. " \
                "If you want to share parameters between blocks, please set " \
                "'params' at Block construction instead."
            self._reg_params[name] = value

        super(Block, self).__setattr__(name, value)

    def _check_container_with_block(self):
        # Warn about Blocks hidden inside plain containers (list/tuple/dict
        # attributes): they are not auto-registered and would be silently
        # excluded from collect_params().
        children = set(self._children.values())

        def _find_unregistered_block_in_container(data):
            # Find whether a nested container structure contains Blocks
            if isinstance(data, (list, tuple)):
                for ele in data:
                    if _find_unregistered_block_in_container(ele):
                        return True
                return False
            elif isinstance(data, dict):
                for _, v in data.items():
                    if _find_unregistered_block_in_container(v):
                        return True
                return False
            elif isinstance(data, Block):
                return not data in children
            else:
                return False

        for k, v in self.__dict__.items():
            if isinstance(v, (list, tuple, dict)) and not (k.startswith('__') or k == '_children'):
                if _find_unregistered_block_in_container(v):
                    warnings.warn('"{name}" is an unregistered container with Blocks. '
                                  'Note that Blocks inside the list, tuple or dict will not be '
                                  'registered automatically. Make sure to register them using '
                                  'register_child() or switching to '
                                  'nn.Sequential/nn.HybridSequential instead. '
                                  .format(name=self.__class__.__name__ + "." + k), stacklevel=3)

    def _alias(self):
        # Lower-cased class name is the hint used for auto-generated prefixes.
        return self.__class__.__name__.lower()

    @property
    def prefix(self):
        """Prefix of this :py:class:`Block`."""
        return self._prefix

    @property
    def name(self):
        """Name of this :py:class:`Block`, without '_' in the end."""
        return self._name

[docs]defname_scope(self):"""Returns a name space object managing a child :py:class:`Block` and parameter names. Should be used within a ``with`` statement:: with self.name_scope(): self.dense = nn.Dense(20) Please refer to `naming tutorial <http://mxnet.incubator.apache.org/tutorials/gluon/naming.html>`_ for more info on prefix and naming. """returnself._scope

@propertydefparams(self):"""Returns this :py:class:`Block`'s parameter dictionary (does not include its children's parameters)."""returnself._params

[docs]defcollect_params(self,select=None):"""Returns a :py:class:`ParameterDict` containing this :py:class:`Block` and all of its children's Parameters(default), also can returns the select :py:class:`ParameterDict` which match some given regular expressions. For example, collect the specified parameters in ['conv1_weight', 'conv1_bias', 'fc_weight', 'fc_bias']:: model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias') or collect all parameters whose names end with 'weight' or 'bias', this can be done using regular expressions:: model.collect_params('.*weight|.*bias') Parameters ---------- select : str regular expressions Returns ------- The selected :py:class:`ParameterDict` """# We need to check here because blocks inside containers are not supported.self._check_container_with_block()ret=ParameterDict(self._params.prefix)ifnotselect:ret.update(self.params)else:pattern=re.compile(select)ret.update({name:valueforname,valueinself.params.items()ifpattern.match(name)})forcldinself._children.values():ret.update(cld.collect_params(select=select))returnret

[docs]defsave_parameters(self,filename):"""Save parameters to file. Saved parameters can only be loaded with `load_parameters`. Note that this method only saves parameters, not model structure. If you want to save model structures, please use :py:meth:`HybridBlock.export`. Parameters ---------- filename : str Path to file. References ---------- `Saving and Loading Gluon Models \ <https://mxnet.incubator.apache.org/tutorials/gluon/save_load_params.html>`_ """params=self._collect_params_with_prefix()arg_dict={key:val._reduce()forkey,valinparams.items()}ndarray.save(filename,arg_dict)

[docs]defsave_params(self,filename):"""[Deprecated] Please use save_parameters. Note that if you want load from SymbolBlock later, please use export instead. Save parameters to file. filename : str Path to file. """warnings.warn("save_params is deprecated. Please use save_parameters. ""Note that if you want load from SymbolBlock later, please ""use export instead. For details, see ""https://mxnet.incubator.apache.org/tutorials/gluon/save_lo""ad_params.html")try:self.collect_params().save(filename,strip_prefix=self.prefix)exceptValueErrorase:raiseValueError('%s\nsave_params is deprecated. Using ' \
'save_parameters may resolve this error.'%e.message)

[docs]defregister_child(self,block,name=None):"""Registers block as a child of self. :py:class:`Block` s assigned to self as attributes will be registered automatically."""ifnameisNone:name=str(len(self._children))self._children[name]=block

[docs]defregister_forward_pre_hook(self,hook):r"""Registers a forward pre-hook on the block. The hook function is called immediately before :func:`forward`. It should not modify the input or output. Parameters ---------- hook : callable The forward hook function of form `hook(block, input) -> None`. Returns ------- :class:`mxnet.gluon.utils.HookHandle` """handle=HookHandle()handle.attach(self._forward_pre_hooks,hook)returnhandle

[docs]defregister_forward_hook(self,hook):r"""Registers a forward hook on the block. The hook function is called immediately after :func:`forward`. It should not modify the input or output. Parameters ---------- hook : callable The forward hook function of form `hook(block, input, output) -> None`. Returns ------- :class:`mxnet.gluon.utils.HookHandle` """handle=HookHandle()handle.attach(self._forward_hooks,hook)returnhandle

[docs]defapply(self,fn):r"""Applies ``fn`` recursively to every child block as well as self. Parameters ---------- fn : callable Function to be applied to each submodule, of form `fn(block)`. Returns ------- this block """forcldinself._children.values():cld.apply(fn)fn(self)returnself

[docs]definitialize(self,init=initializer.Uniform(),ctx=None,verbose=False,force_reinit=False):"""Initializes :py:class:`Parameter` s of this :py:class:`Block` and its children. Equivalent to ``block.collect_params().initialize(...)`` Parameters ---------- init : Initializer Global default Initializer to be used when :py:meth:`Parameter.init` is ``None``. Otherwise, :py:meth:`Parameter.init` takes precedence. ctx : Context or list of Context Keeps a copy of Parameters on one or many context(s). verbose : bool, default False Whether to verbosely print out details on initialization. force_reinit : bool, default False Whether to force re-initialization if parameter is already initialized. """self.collect_params().initialize(init,ctx,verbose,force_reinit)

[docs]defcast(self,dtype):"""Cast this Block to use another data type. Parameters ---------- dtype : str or numpy.dtype The new data type. """forchildinself._children.values():child.cast(dtype)for_,paraminself.params.items():param.cast(dtype)

[docs]defsummary(self,*inputs):"""Print the summary of the model's output and parameters. The network must have been initialized, and must not have been hybridized. Parameters ---------- inputs : object Any input that the model supports. For any tensor in the input, only :class:`mxnet.ndarray.NDArray` is supported. """summary=OrderedDict()seen=set()hooks=[]def_get_shape_str(args):defflatten(args):ifnotisinstance(args,(list,tuple)):return[args],int(0)flat=[]fmts=[]foriinargs:arg,fmt=flatten(i)flat.extend(arg)fmts.append(fmt)returnflat,fmtsdefregroup(args,fmt):ifisinstance(fmt,int):iffmt==0:returnargs[0],args[1:]returnargs[:fmt],args[fmt:]ret=[]foriinfmt:res,args=regroup(args,i)ret.append(res)returnret,argsflat_args,fmts=flatten(args)flat_arg_shapes=[x.shapeifisinstance(x,ndarray.NDArray)elsexforxinflat_args]shapes=regroup(flat_arg_shapes,fmts)[0]ifisinstance(shapes,list):shape_str=str(shapes)[1:-1]else:shape_str=str(shapes)returnshape_str.replace('L','')def_register_summary_hook(block):assertnotisinstance(block,HybridBlock)ornotblock._active, \
'"{}" must not be hybridized to print summary.'.format(block.name)def_summary_hook(block,_,outputs):class_name=block.__class__.__name__block_idx=len(summary)-1m_key='%s-%i'%(class_name,block_idx+1)summary[m_key]=OrderedDict()summary[m_key]['output_shape']=_get_shape_str(outputs)params=0summary[m_key]['trainable']=0summary[m_key]['shared']=0forpinblock.params.values():params+=p.data().sizesummary[m_key]['trainable']+=0ifp.grad_req=='null'elsep.data().sizeifpinseen:summary[m_key]['shared']+=p.data().sizeelse:seen.add(p)summary[m_key]['n_params']=paramsfrom.nn.basic_layersimportSequential,HybridSequentialifnotisinstance(block,(Sequential,HybridSequential)):hooks.append(block.register_forward_hook(_summary_hook))summary['Input']=OrderedDict()summary['Input']['output_shape']=_get_shape_str(inputs)summary['Input']['n_params']=0summary['Input']['trainable']=0summary['Input']['shared']=0try:self.apply(_register_summary_hook)self(*inputs)line_format='{:>20}{:>42}{:>15}'print('-'*80)print(line_format.format('Layer (type)','Output Shape','Param #'))print('='*80)total_params=0trainable_params=0shared_params=0forlayerinsummary:print(line_format.format(layer,str(summary[layer]['output_shape']),summary[layer]['n_params']))total_params+=summary[layer]['n_params']trainable_params+=summary[layer]['trainable']shared_params+=summary[layer]['shared']print('='*80)print('Parameters in forward computation graph, duplicate included')print(' Total params: '+str(total_params))print(' Trainable params: '+str(trainable_params))print(' Non-trainable params: '+str(total_params-trainable_params))print('Shared params in forward computation graph: '+str(shared_params))print('Unique parameters in model: '+str(total_params-shared_params))print('-'*80)finally:forhinhooks:h.detach()

class HybridBlock(Block):
    """`HybridBlock` supports forwarding with both Symbol and NDArray.

    `HybridBlock` is similar to `Block`, with a few differences::

        import mxnet as mx
        from mxnet.gluon import HybridBlock, nn

        class Model(HybridBlock):
            def __init__(self, **kwargs):
                super(Model, self).__init__(**kwargs)
                # use name_scope to give child Blocks appropriate names.
                with self.name_scope():
                    self.dense0 = nn.Dense(20)
                    self.dense1 = nn.Dense(20)

            def hybrid_forward(self, F, x):
                x = F.relu(self.dense0(x))
                return F.relu(self.dense1(x))

        model = Model()
        model.initialize(ctx=mx.cpu(0))
        model.hybridize()
        model(mx.nd.zeros((10, 10), ctx=mx.cpu(0)))

    Forward computation in :py:class:`HybridBlock` must be static to work with :py:class:`Symbol` s,
    i.e. you cannot call :py:meth:`NDArray.asnumpy`, :py:attr:`NDArray.shape`,
    :py:attr:`NDArray.dtype`, `NDArray` indexing (`x[i]`) etc on tensors.
    Also, you cannot use branching or loop logic that bases on non-constant
    expressions like random numbers or intermediate results, since they change
    the graph structure for each iteration.

    Before activating with :py:meth:`hybridize()`, :py:class:`HybridBlock` works just like normal
    :py:class:`Block`. After activation, :py:class:`HybridBlock` will create a symbolic graph
    representing the forward computation and cache it. On subsequent forwards,
    the cached graph will be used instead of :py:meth:`hybrid_forward`.

    Please see references for detailed tutorial.

    References
    ----------
        `Hybrid - Faster training and easy deployment
        <http://mxnet.io/tutorials/gluon/hybrid.html>`_
    """

def__setattr__(self,name,value):"""Registers parameters."""super(HybridBlock,self).__setattr__(name,value)ifisinstance(value,HybridBlock):self._clear_cached_op()def_get_graph(self,*args):ifnotself._cached_graph:args,self._in_format=_flatten(args,"input")iflen(args)>1:inputs=[symbol.var('data%d'%i)foriinrange(len(args))]else:inputs=[symbol.var('data')]grouped_inputs=_regroup(inputs,self._in_format)[0]params={i:j.var()fori,jinself._reg_params.items()}withself.name_scope():out=self.hybrid_forward(symbol,*grouped_inputs,**params)# pylint: disable=no-value-for-parameterout,self._out_format=_flatten(out,"output")self._cached_graph=inputs,symbol.Group(out)returnself._cached_graphdef_build_cache(self,*args):data,out=self._get_graph(*args)data_names={data.name:ifori,datainenumerate(data)}params=self.collect_params()input_names=out.list_inputs()param_names=set(params.keys())expected_names=set(input_names)fornameinexpected_names:assertnameinparam_namesornameindata_names, \
"Unknown input to HybridBlock: %s"%nameused_data_names=[iforiindata_namesifiinexpected_names]iflen(used_data_names)!=len(data_names):unused=', '.join(['%d-th'%iforname,iindata_names.items()ifnamenotinexpected_names])warnings.warn("The %s input to HybridBlock is not used by any ""computation. Is this intended?"%unused,stacklevel=4)used_param_names=[iforiinparam_namesifiinexpected_names]iflen(used_param_names)!=len(param_names):unused=', '.join(list(param_names-set(used_param_names)))warnings.warn("Parameter %s is not used by any computation. ""Is this intended?"%unused,stacklevel=4)data_indices=[]param_indices=[]self._cached_op_args=[]fori,nameinenumerate(input_names):ifnameindata_names:data_indices.append(i)self._cached_op_args.append((True,data_names[name]))else:param_indices.append(i)self._cached_op_args.append((False,params[name]))flags=[('data_indices',data_indices),('param_indices',param_indices)]+ \
self._flagsself._cached_op=ndarray.CachedOp(out,flags)def_deferred_infer_shape(self,*args):try:self.infer_shape(*args)exceptExceptionase:error_msg="Deferred initialization failed because shape"\
" cannot be inferred. {}".format(e)raiseValueError(error_msg)def_call_cached_op(self,*args):ifself._cached_opisNone:self._build_cache(*args)args,fmt=_flatten(args,"input")assertfmt==self._in_format,"Invalid input format"try:cargs=[args[i]ifis_argelsei.data()foris_arg,iinself._cached_op_args]exceptDeferredInitializationError:self._deferred_infer_shape(*args)cargs=[]foris_arg,iinself._cached_op_args:ifis_arg:cargs.append(args[i])else:i._finish_deferred_init()cargs.append(i.data())out=self._cached_op(*cargs)ifisinstance(out,NDArray):out=[out]return_regroup(out,self._out_format)[0]def_clear_cached_op(self):self._cached_graph=()self._cached_op=Nonedefregister_child(self,block,name=None):ifnotisinstance(block,HybridBlock):raiseValueError("Children of HybridBlock must also be HybridBlock, " \
"but %s has type %s. If you are using Sequential, " \
"please try HybridSequential instead."%(str(block),str(type(block))))super(HybridBlock,self).register_child(block,name)self._clear_cached_op()

[docs]defhybridize(self,active=True,**kwargs):self._active=activeself._flags=list(kwargs.items())self._clear_cached_op()ifactiveandself._forward_hooksorself._forward_pre_hooks:warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. ''If "{}" is a child of HybridBlock, the hooks will not take effect.')super(HybridBlock,self).hybridize(active,**kwargs)

[docs]defexport(self,path,epoch=0):"""Export HybridBlock to json format that can be loaded by `SymbolBlock.imports`, `mxnet.mod.Module` or the C++ interface. .. note:: When there are only one input, it will have name `data`. When there Are more than one inputs, they will be named as `data0`, `data1`, etc. Parameters ---------- path : str Path to save model. Two files `path-symbol.json` and `path-xxxx.params` will be created, where xxxx is the 4 digits epoch number. epoch : int Epoch number of saved model. """ifnotself._cached_graph:raiseRuntimeError("Please first call block.hybridize() and then run forward with ""this block at least once before calling export.")sym=self._cached_graph[1]sym.save('%s-symbol.json'%path)arg_names=set(sym.list_arguments())aux_names=set(sym.list_auxiliary_states())arg_dict={}forname,paraminself.collect_params().items():ifnameinarg_names:arg_dict['arg:%s'%name]=param._reduce()else:assertnameinaux_namesarg_dict['aux:%s'%name]=param._reduce()ndarray.save('%s-%04d.params'%(path,epoch),arg_dict)

defforward(self,x,*args):"""Defines the forward computation. Arguments can be either :py:class:`NDArray` or :py:class:`Symbol`."""ifisinstance(x,NDArray):withx.contextasctx:ifself._active:returnself._call_cached_op(x,*args)try:params={i:j.data(ctx)fori,jinself._reg_params.items()}exceptDeferredInitializationError:self._deferred_infer_shape(x,*args)for_,iinself.params.items():i._finish_deferred_init()params={i:j.data(ctx)fori,jinself._reg_params.items()}returnself.hybrid_forward(ndarray,x,*args,**params)assertisinstance(x,Symbol), \
"HybridBlock requires the first argument to forward be either " \
"Symbol or NDArray, but got %s"%type(x)params={i:j.var()fori,jinself._reg_params.items()}withself.name_scope():returnself.hybrid_forward(symbol,x,*args,**params)

def_common_prefix(names):"""Get the common prefix for all names"""ifnotnames:return''prefix=names[0]fornameinnames:i=0whilei<len(prefix)andi<len(name)andprefix[i]==name[i]:i+=1prefix=prefix[:i]returnprefix

[docs]def__init__(self,outputs,inputs,params=None):super(SymbolBlock,self).__init__(prefix=None,params=None)self._prefix=''self._params=ParameterDict('',params)ifisinstance(inputs,symbol.Symbol)andlen(inputs.list_outputs())==1:inputs=[inputs]ifisinstance(outputs,(list,tuple))andlen(outputs)==1:outputs=outputs[0]syms,self._in_format=_flatten(inputs,"input")out,self._out_format=_flatten(outputs,"output")out=symbol.Group(out)input_names=set()foriinsyms:assertlen(i.get_internals().list_outputs())==1, \
"Input symbols must be variable, but %s is an output of operators"%str(i)input_names.add(i.name)# check if any symbol is row_sparserow_sparse_storage=ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse']foriinout:forjini.get_internals():assert(j.attr("__storage_type__")!=str(row_sparse_storage)), \
"SymbolBlock doesn't support Parameter '%s' because its storage " \
"type is 'row_sparse'."%j.name# Infer type of parameters. Without this, every parameter will be created with# default type i.e., fp32arg_params=out.list_arguments()aux_params=out.list_auxiliary_states()arg_types,aux_types=_infer_param_types(syms,out,arg_params,aux_params)fori,arginenumerate(arg_params):ifargnotininput_names:self.params.get(arg,allow_deferred_init=True,dtype=arg_types[i])fori,auxinenumerate(aux_params):ifauxnotininput_names:self.params.get(aux,grad_req='null',allow_deferred_init=True,dtype=aux_types[i])self._cached_graph=syms,outlen_prefix=len(_common_prefix(list(self._params.keys())))self._reg_params={key[len_prefix:]:valforkey,valinself._params.items()}

defforward(self,x,*args):ifisinstance(x,NDArray):withx.context:returnself._call_cached_op(x,*args)assertisinstance(x,Symbol), \
"HybridBlock requires the first argument to forward be either " \
"Symbol or NDArray, but got %s"%type(x)args,in_fmt=_flatten([x]+list(args),"input")assertin_fmt==self._in_format,"Invalid input format"ret=copy.copy(self._cached_graph[1])ret._compose(**{k.name:vfork,vinzip(self._cached_graph[0],args)})return_regroup(list(ret),self._out_format)[0]def_clear_cached_op(self):tmp=self._cached_graphsuper(SymbolBlock,self)._clear_cached_op()self._cached_graph=tmpdefcast(self,dtype):self._clear_cached_op()super(SymbolBlock,self).cast(dtype)defhybrid_forward(self,F,x,*args,**kwargs):raiseNotImplementedError

def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dtype=mx_real_t):
    """Utility function that helps in inferring DType of args and auxs params
    from given input param.

    Parameters
    ----------
    in_params: List of Symbol
        List of input symbol variables.
    out_params: Symbol
        Output symbol variable.
    arg_params: List of Str
        List of names of argument parametrs.
    aux_params: List of Str
        List of names of auxiliary parameters.
    default_dtype: numpy.dtype or str, default 'float32'
        Default data type for arg_params and aux_params, if unable to infer the type.

    Returns
    -------
    arg_types: List of numpy.dtype
        List of arg_params type. Order is same as arg_params.
        Defaults to 'float32', if unable to infer type.
    aux_types: List of numpy.dtype
        List of aux_params type. Order is same as aux_params.
        Defaults to 'float32', if unable to infer type.
    """
    arg_types = None
    aux_types = None

    # Get Input symbol details. This will be used to infer types of
    # other parameters.
    input_sym_names = [in_param.name for in_param in in_params]

    # Try to infer input types. If not successful, we will set default dtype.
    # If successful, we will try to infer other params in the graph.
    input_sym_arg_types = []
    can_infer_input_type = True
    for in_param in in_params:
        input_sym_arg_type = in_param.infer_type()[0]
        if not input_sym_arg_type or len(input_sym_arg_type) < 1:
            can_infer_input_type = False
            break
        else:
            input_sym_arg_types.append(in_param.infer_type()[0][0])

    # Try to infer types of other parameters.
    if can_infer_input_type:
        params = {k: v for k, v in zip(input_sym_names, input_sym_arg_types)}
        arg_types, _, aux_types = out_params.infer_type(**params)

    # Fall back to the default dtype wherever inference failed or disagreed
    # in length with the expected parameter lists.
    if arg_types is None or len(arg_types) != len(arg_params):
        arg_types = []
        for _ in arg_params:
            arg_types.append(default_dtype)

    if aux_types is None or len(aux_types) != len(aux_params):
        aux_types = []
        for _ in aux_params:
            aux_types.append(default_dtype)

    return (arg_types, aux_types)