Source code for mxnet.ndarray.sparse

# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.# coding: utf-8# pylint: disable=wildcard-import, unused-wildcard-import, too-many-lines"""Sparse NDArray API of MXNet."""from__future__importabsolute_importfrom__future__importdivisiontry:from__builtin__importsliceaspy_slicefrom__builtin__importsumaspy_sumexceptImportError:frombuiltinsimportsliceaspy_slicefrombuiltinsimportsumaspy_sumimportctypesimportwarningsimportoperatorfromarrayimportarrayasnative_array__all__=["_ndarray_cls","csr_matrix","row_sparse_array","BaseSparseNDArray","CSRNDArray","RowSparseNDArray","add","subtract","multiply","divide"]importnumpyasnpfrom..baseimportNotSupportedForSparseNDArrayfrom..baseimport_LIB,numeric_typesfrom..baseimportc_array_buf,mx_real_t,integer_typesfrom..baseimportmx_uint,NDArrayHandle,check_callfrom..contextimportContext,current_contextfrom.import_internalfrom.importoptry:from.gen_sparseimportretainasgs_retain# pylint: 
disable=redefined-builtinexceptImportError:gs_retain=Nonefrom._internalimport_set_ndarray_classfrom.ndarrayimportNDArray,_storage_type,_DTYPE_NP_TO_MX,_DTYPE_MX_TO_NPfrom.ndarrayimport_STORAGE_TYPE_STR_TO_ID,_STORAGE_TYPE_ROW_SPARSE,_STORAGE_TYPE_CSRfrom.ndarrayimport_STORAGE_TYPE_UNDEFINED,_STORAGE_TYPE_DEFAULTfrom.ndarrayimportzerosas_zeros_ndarrayfrom.ndarrayimportarrayas_arrayfrom.ndarrayimport_ufunc_helpertry:importscipy.sparseasspspexceptImportError:spsp=None_STORAGE_AUX_TYPES={'row_sparse':[np.int64],'csr':[np.int64,np.int64]}def_new_alloc_handle(stype,shape,ctx,delay_alloc,dtype,aux_types,aux_shapes=None):"""Return a new handle with specified storage type, shape, dtype and context. Empty handle is only used to hold results Returns ------- handle A new empty ndarray handle """hdl=NDArrayHandle()foraux_tinaux_types:ifnp.dtype(aux_t)!=np.dtype("int64"):raiseNotImplementedError("only int64 is supported for aux types")aux_type_ids=[int(_DTYPE_NP_TO_MX[np.dtype(aux_t).type])foraux_tinaux_types]aux_shapes=[(0,)foraux_tinaux_types]ifaux_shapesisNoneelseaux_shapesaux_shape_lens=[len(aux_shape)foraux_shapeinaux_shapes]aux_shapes=py_sum(aux_shapes,())num_aux=mx_uint(len(aux_types))check_call(_LIB.MXNDArrayCreateSparseEx(ctypes.c_int(int(_STORAGE_TYPE_STR_TO_ID[stype])),c_array_buf(mx_uint,native_array('I',shape)),mx_uint(len(shape)),ctypes.c_int(ctx.device_typeid),ctypes.c_int(ctx.device_id),ctypes.c_int(int(delay_alloc)),ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),num_aux,c_array_buf(ctypes.c_int,native_array('i',aux_type_ids)),c_array_buf(mx_uint,native_array('I',aux_shape_lens)),c_array_buf(mx_uint,native_array('I',aux_shapes)),ctypes.byref(hdl)))returnhdlclassBaseSparseNDArray(NDArray):"""The base class of an NDArray stored in a sparse storage format. See CSRNDArray and RowSparseNDArray for more details. 
"""def__repr__(self):"""Returns a string representation of the sparse array."""shape_info='x'.join(['%d'%xforxinself.shape])# The data content is not displayed since the array usually has big shapereturn'\n<%s%s @%s>'%(self.__class__.__name__,shape_info,self.context)def__add__(self,other):returnadd(self,other)def__sub__(self,other):returnsubtract(self,other)def__mul__(self,other):returnmultiply(self,other)def__div__(self,other):returndivide(self,other)def__iadd__(self,other):raiseNotImplementedError()def__isub__(self,other):raiseNotImplementedError()def__imul__(self,other):raiseNotImplementedError()def__idiv__(self,other):raiseNotImplementedError()def__itruediv__(self,other):raiseNotImplementedError()def_sync_copyfrom(self,source_array):raiseNotImplementedError()def_at(self,idx):raiseNotSupportedForSparseNDArray(self._at,'[idx]',idx)def_slice(self,start,stop):raiseNotSupportedForSparseNDArray(self._slice,None,start,stop)defreshape(self,*shape,**kwargs):raiseNotSupportedForSparseNDArray(self.reshape,None,shape)@propertydefsize(self):# the `size` for a sparse ndarray is ambiguous, hence disabled.raiseNotImplementedError()def_aux_type(self,i):"""Data-type of the array's ith aux data. Returns ------- numpy.dtype This BaseSparseNDArray's aux data type. """aux_type=ctypes.c_int()check_call(_LIB.MXNDArrayGetAuxType(self.handle,i,ctypes.byref(aux_type)))return_DTYPE_MX_TO_NP[aux_type.value]@propertydef_num_aux(self):"""The number of aux data used to help store the sparse ndarray. """returnlen(_STORAGE_AUX_TYPES[self.stype])@propertydef_aux_types(self):"""The data types of the aux data for the BaseSparseNDArray. """aux_types=[]num_aux=self._num_auxforiinrange(num_aux):aux_types.append(self._aux_type(i))returnaux_typesdefasnumpy(self):"""Return a dense ``numpy.ndarray`` object with value copied from this array """returnself.tostype('default').asnumpy()defastype(self,dtype,copy=True):"""Return a copy of the array after casting to a specified type. 
Parameters ---------- dtype : numpy.dtype or str The type of the returned array. copy : bool Default `True`. By default, astype always returns a newly allocated ndarray on the same context. If this is set to `False`, and the dtype requested is the same as the ndarray's dtype, the ndarray is returned instead of a copy. Examples -------- >>> x = mx.nd.sparse.zeros('row_sparse', (2,3), dtype='float32') >>> y = x.astype('int32') >>> y.dtype <type 'numpy.int32'> """ifnotcopyandnp.dtype(dtype)==self.dtype:returnselfres=zeros(shape=self.shape,ctx=self.context,dtype=dtype,stype=self.stype)self.copyto(res)returnresdefcopyto(self,other):"""Copies the value of this array to another array. Parameters ---------- other : NDArray or CSRNDArray or RowSparseNDArray or Context The destination array or context. Returns ------- NDArray or CSRNDArray or RowSparseNDArray The copied array. """# pylint: disable= no-member, protected-accessifisinstance(other,NDArray):ifother.handleisself.handle:warnings.warn('You are attempting to copy an array to itself',RuntimeWarning)returnFalsereturn_internal._copyto(self,out=other)elifisinstance(other,Context):hret=_ndarray_cls(_new_alloc_handle(self.stype,self.shape,other,True,self.dtype,self._aux_types))return_internal._copyto(self,out=hret)else:raiseTypeError('copyto does not support type '+str(type(other)))# pylint: enable= no-member, protected-accessdefcheck_format(self,full_check=True):"""Check whether the NDArray format is valid. Parameters ---------- full_check : bool, optional If `True`, rigorous check, O(N) operations. Otherwise basic check, O(1) operations (default True). """check_call(_LIB.MXNDArraySyncCheckFormat(self.handle,ctypes.c_bool(full_check)))def_data(self):"""A deep copy NDArray of the data array associated with the BaseSparseNDArray. This function blocks. Do not use it in performance critical code. 
"""self.wait_to_read()hdl=NDArrayHandle()check_call(_LIB.MXNDArrayGetDataNDArray(self.handle,ctypes.byref(hdl)))returnNDArray(hdl)def_aux_data(self,i):""" Get a deep copy NDArray of the i-th aux data array associated with the BaseSparseNDArray. This function blocks. Do not use it in performance critical code. """self.wait_to_read()hdl=NDArrayHandle()check_call(_LIB.MXNDArrayGetAuxNDArray(self.handle,i,ctypes.byref(hdl)))returnNDArray(hdl)# pylint: disable=abstract-method

class CSRNDArray(BaseSparseNDArray):
    """A sparse representation of 2D NDArray in the Compressed Sparse Row format.

    A CSRNDArray represents an NDArray as three separate arrays: `data`,
    `indptr` and `indices`. It uses the CSR representation where the column indices for
    row i are stored in ``indices[indptr[i]:indptr[i+1]]`` and their corresponding values are stored
    in ``data[indptr[i]:indptr[i+1]]``.

    The column indices for a given row are expected to be sorted in ascending order.
    Duplicate column entries for the same row are not allowed.

    Example
    -------
    >>> a = mx.nd.array([[0, 1, 0], [2, 0, 0], [0, 0, 0], [0, 0, 3]])
    >>> a = a.tostype('csr')
    >>> a.data.asnumpy()
    array([ 1.,  2.,  3.], dtype=float32)
    >>> a.indices.asnumpy()
    array([1, 0, 2])
    >>> a.indptr.asnumpy()
    array([0, 1, 2, 2, 3])

    See Also
    --------
    csr_matrix: Several ways to construct a CSRNDArray
    """

    def __reduce__(self):
        state = super(CSRNDArray, self).__getstate__()
        return CSRNDArray, (None,), state

    # In-place arithmetic: compute the out-of-place result, then write it back
    # into this array's storage and return ``self``.
    def __iadd__(self, other):
        result = self + other
        result.copyto(self)
        return self

    def __isub__(self, other):
        result = self - other
        result.copyto(self)
        return self

    def __imul__(self, other):
        result = self * other
        result.copyto(self)
        return self

    def __idiv__(self, other):
        result = self / other
        result.copyto(self)
        return self

    def __itruediv__(self, other):
        result = self / other
        result.copyto(self)
        return self

    @property
    def indices(self):
        """Column indices of the stored entries (aux array 1), as a deep-copied NDArray.

        Returns
        -------
        NDArray
            This CSRNDArray's indices array.
        """
        return self._aux_data(1)

    @property
    def indptr(self):
        """Row pointer array (aux array 0), as a deep-copied NDArray.

        Returns
        -------
        NDArray
            This CSRNDArray's indptr array.
        """
        return self._aux_data(0)

    @property
    def data(self):
        """Stored non-zero values, as a deep-copied NDArray.

        Returns
        -------
        NDArray
            This CSRNDArray's data array.
        """
        return self._data()

    # The component arrays are read-only views of the storage.
    @indices.setter
    def indices(self, value):
        raise NotImplementedError()

    @indptr.setter
    def indptr(self, value):
        raise NotImplementedError()

    @data.setter
    def data(self, value):
        raise NotImplementedError()

    def tostype(self, stype):
        """Return a copy of the array with chosen storage type.

        Conversion to 'row_sparse' is rejected with a ValueError.

        Returns
        -------
        NDArray or CSRNDArray
            A copy of the array with the chosen storage stype
        """
        # pylint: disable= no-member, protected-access
        if stype != 'row_sparse':
            return op.cast_storage(self, stype=stype)
        raise ValueError("cast_storage from csr to row_sparse is not supported")
        # pylint: enable= no-member, protected-access

    def copyto(self, other):
        """Copies the value of this array to another array.

        If ``other`` is a ``NDArray`` or ``CSRNDArray`` object, then ``other.shape`` and
        ``self.shape`` should be the same. This function copies the value from
        ``self`` to ``other``.

        If ``other`` is a context, a new ``CSRNDArray`` will be first created on
        the target context, and the value of ``self`` is copied.

        Parameters
        ----------
        other : NDArray or CSRNDArray or Context
            The destination array or context.

        Returns
        -------
        NDArray or CSRNDArray
            The copied array. If ``other`` is an ``NDArray`` or ``CSRNDArray``, then the
            return value and ``other`` will point to the same ``NDArray`` or ``CSRNDArray``.
        """
        if isinstance(other, Context):
            return super(CSRNDArray, self).copyto(other)
        if not isinstance(other, NDArray):
            raise TypeError('copyto does not support type ' + str(type(other)))
        dest_stype = other.stype
        if dest_stype not in ('default', 'csr'):
            raise TypeError('copyto does not support destination NDArray stype ' + str(dest_stype))
        return super(CSRNDArray, self).copyto(other)

# pylint: enable= no-member, protected-access
@property
def indices(self):
    """Row indices of the stored rows (aux array 0), as a deep-copied NDArray.

    Returns
    -------
    NDArray
        This RowSparseNDArray's indices array.
    """
    return self._aux_data(0)

@property
def data(self):
    """Values of the stored rows, as a deep-copied NDArray.

    Returns
    -------
    NDArray
        This RowSparseNDArray's data array.
    """
    return self._data()

# Both component arrays are read-only; assignment is rejected.
@indices.setter
def indices(self, value):
    raise NotImplementedError()

@data.setter
def data(self, value):
    raise NotImplementedError()

def tostype(self, stype):
    """Return a copy of the array with chosen storage type.

    Conversion to 'csr' is rejected with a ValueError.

    Returns
    -------
    NDArray or RowSparseNDArray
        A copy of the array with the chosen storage stype
    """
    # pylint: disable= no-member, protected-access
    if stype != 'csr':
        return op.cast_storage(self, stype=stype)
    raise ValueError("cast_storage from row_sparse to csr is not supported")
# pylint: enable= no-member, protected-access

def copyto(self, other):
    """Copies the value of this array to another array.

    If ``other`` is a ``NDArray`` or ``RowSparseNDArray`` object, then ``other.shape``
    and ``self.shape`` should be the same. This function copies the value from
    ``self`` to ``other``.

    If ``other`` is a context, a new ``RowSparseNDArray`` will be first created on
    the target context, and the value of ``self`` is copied.

    Parameters
    ----------
    other : NDArray or RowSparseNDArray or Context
        The destination array or context.

    Returns
    -------
    NDArray or RowSparseNDArray
        The copied array. If ``other`` is an ``NDArray`` or ``RowSparseNDArray``, then the
        return value and ``other`` will point to the same ``NDArray`` or ``RowSparseNDArray``.

    Raises
    ------
    TypeError
        If ``other`` is neither a Context nor an NDArray, or is an NDArray whose
        stype is not 'default' or 'row_sparse'.
    """
    if isinstance(other, Context):
        return super(RowSparseNDArray, self).copyto(other)
    if not isinstance(other, NDArray):
        raise TypeError('copyto does not support type ' + str(type(other)))
    dest_stype = other.stype
    if dest_stype not in ('default', 'row_sparse'):
        raise TypeError('copyto does not support destination NDArray stype ' + str(dest_stype))
    return super(RowSparseNDArray, self).copyto(other)

def retain(self, *args, **kwargs):
    """Convenience fluent method for :py:func:`retain`.

    The arguments are the same as for :py:func:`retain`, with
    this array as data.

    Raises
    ------
    ImportError
        If the generated ``gen_sparse.retain`` operator could not be imported.
    """
    if gs_retain is None:
        raise ImportError("gen_sparse could not be imported")
    # Forward `self` as the leading `data` argument, as documented above;
    # previously it was dropped and callers' first argument became `data`.
    return gs_retain(self, *args, **kwargs)

def _prepare_src_array(source_array, dtype):
    """Prepare `source_array` so that it can be used to construct NDArray.

    `source_array` is converted to a `np.ndarray` of the given `dtype` if it's
    neither an `NDArray` nor an `np.ndarray`; otherwise it is returned unchanged.

    Raises
    ------
    TypeError
        If numpy cannot convert `source_array` to an array.
    """
    if not isinstance(source_array, (NDArray, np.ndarray)):
        try:
            source_array = np.array(source_array, dtype=dtype)
        except Exception:  # pylint: disable=broad-except
            # Narrowed from a bare `except:` so that KeyboardInterrupt and
            # SystemExit are no longer converted into a TypeError.
            raise TypeError('values must be array like object')
    return source_array


def _prepare_default_dtype(src_array, dtype):
    """Prepare the value of dtype if `dtype` is None.

    If `src_array` is an NDArray, numpy.ndarray or scipy.sparse.csr_matrix,
    return src_array.dtype. float32 (``mx_real_t``) is returned otherwise.
    A non-None `dtype` is returned unchanged.
    """
    if dtype is None:
        if isinstance(src_array, (NDArray, np.ndarray)):
            dtype = src_array.dtype
        # `spsp.csr_matrix` is the public name of the class previously reached
        # via the deprecated `scipy.sparse.csr` submodule.
        elif spsp and isinstance(src_array, spsp.csr_matrix):
            dtype = src_array.dtype
        else:
            dtype = mx_real_t
    return dtype


def _check_shape(s1, s2):
    """Raise ValueError if both shapes are given (non-empty) and differ."""
    if s1 and s2 and s1 != s2:
        raise ValueError("Shape mismatch detected. " + str(s1) + " v.s. " + str(s2))

def csr_matrix(arg1, shape=None, ctx=None, dtype=None):
    """Creates a `CSRNDArray`, a 2D array with compressed sparse row (CSR) format.

    The CSRNDArray can be instantiated in several ways:

    - csr_matrix(D):
        to construct a CSRNDArray with a dense 2D array ``D``
            -  **D** (*array_like*) - An object exposing the array interface, an object whose \
            `__array__` method returns an array, or any (nested) sequence.
            - **ctx** (*Context, optional*) - Device context \
            (default is the current default context).
            - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
            The default dtype is ``D.dtype`` if ``D`` is an NDArray or numpy.ndarray, \
            float32 otherwise.

    - csr_matrix(S)
        to construct a CSRNDArray with a sparse 2D array ``S``
            -  **S** (*CSRNDArray or scipy.sparse.csr_matrix*) - A sparse matrix.
            - **ctx** (*Context, optional*) - Device context \
            (default is the current default context).
            - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
            The default dtype is ``S.dtype``.

    - csr_matrix((M, N))
        to construct an empty CSRNDArray with shape ``(M, N)``
            -  **M** (*int*) - Number of rows in the matrix
            -  **N** (*int*) - Number of columns in the matrix
            - **ctx** (*Context, optional*) - Device context \
            (default is the current default context).
            - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
            The default dtype is float32.

    - csr_matrix((data, indices, indptr))
        to construct a CSRNDArray based on the definition of compressed sparse row format \
        using three separate arrays, \
        where the column indices for row i are stored in ``indices[indptr[i]:indptr[i+1]]`` \
        and their corresponding values are stored in ``data[indptr[i]:indptr[i+1]]``. \
        The column indices for a given row are expected to be **sorted in ascending order.** \
        Duplicate column entries for the same row are not allowed.
            - **data** (*array_like*) - An object exposing the array interface, which \
            holds all the non-zero entries of the matrix in row-major order.
            - **indices** (*array_like*) - An object exposing the array interface, which \
            stores the column index for each non-zero element in ``data``.
            - **indptr** (*array_like*) - An object exposing the array interface, which \
            stores the offset into ``data`` of the first non-zero element number of each \
            row of the matrix.
            - **shape** (*tuple of int, optional*) - The shape of the array. The default \
            shape is inferred from the indices and indptr arrays.
            - **ctx** (*Context, optional*) - Device context \
            (default is the current default context).
            - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
            The default dtype is ``data.dtype`` if ``data`` is an NDArray or numpy.ndarray, \
            float32 otherwise.

    - csr_matrix((data, (row, col)))
        to construct a CSRNDArray based on the COOrdinate format \
        using three separate arrays, \
        where ``row[i]`` is the row index of the element, \
        ``col[i]`` is the column index of the element \
        and ``data[i]`` is the data corresponding to the element. All the missing \
        elements in the input are taken to be zeroes. This form requires scipy.
            - **data** (*array_like*) - An object exposing the array interface, which \
            holds all the non-zero entries of the matrix in COO format.
            - **row** (*array_like*) - An object exposing the array interface, which \
            stores the row index for each non zero element in ``data``.
            - **col** (*array_like*) - An object exposing the array interface, which \
            stores the col index for each non zero element in ``data``.
            - **shape** (*tuple of int, optional*) - The shape of the array. The default \
            shape is inferred from the ``row`` and ``col`` arrays.
            - **ctx** (*Context, optional*) - Device context \
            (default is the current default context).
            - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
            By default the dtype of the intermediate scipy matrix built from ``data`` \
            is kept (see :py:func:`array`).

    Parameters
    ----------
    arg1 : array_like, CSRNDArray, scipy.sparse.csr_matrix, tuple of int or tuple of array_like
        The argument to help instantiate the csr matrix. See above for further details.
    shape : tuple of int, optional
        The shape of the csr matrix.
    ctx: Context, optional
        Device context (default is the current default context).
    dtype: str or numpy.dtype, optional
        The data type of the output array.

    Returns
    -------
    CSRNDArray
        A `CSRNDArray` with the `csr` storage representation.

    Raises
    ------
    ValueError
        On an input tuple of unexpected length, a RowSparseNDArray input,
        or a shape mismatch.
    ImportError
        If the COO form is used and scipy is not installed.

    Example
    -------
    >>> a = mx.nd.sparse.csr_matrix(([1, 2, 3], [1, 0, 2], [0, 1, 2, 2, 3]), shape=(4, 3))
    >>> a.asnumpy()
    array([[ 0.,  1.,  0.],
           [ 2.,  0.,  0.],
           [ 0.,  0.,  0.],
           [ 0.,  0.,  3.]], dtype=float32)

    See Also
    --------
    CSRNDArray : MXNet NDArray in compressed sparse row format.
    """
    # construct a csr matrix from (M, N) or (data, indices, indptr)
    if isinstance(arg1, tuple):
        arg_len = len(arg1)
        if arg_len == 2:
            # construct a sparse csr matrix from
            # scipy coo matrix if input format is coo
            if isinstance(arg1[1], tuple) and len(arg1[1]) == 2:
                if spsp is None:
                    # Fail with a clear message instead of an AttributeError on None.
                    raise ImportError("scipy is required to construct a CSRNDArray "
                                      "from the (data, (row, col)) COO format")
                data, (row, col) = arg1
                if isinstance(data, NDArray):
                    data = data.asnumpy()
                if isinstance(row, NDArray):
                    row = row.asnumpy()
                if isinstance(col, NDArray):
                    col = col.asnumpy()
                coo = spsp.coo_matrix((data, (row, col)), shape=shape)
                _check_shape(coo.shape, shape)
                csr = coo.tocsr()
                return array(csr, ctx=ctx, dtype=dtype)
            else:
                # empty matrix with shape
                _check_shape(arg1, shape)
                return empty('csr', arg1, ctx=ctx, dtype=dtype)
        elif arg_len == 3:
            # data, indices, indptr
            return _csr_matrix_from_definition(arg1[0], arg1[1], arg1[2], shape=shape,
                                               ctx=ctx, dtype=dtype)
        else:
            raise ValueError("Unexpected length of input tuple: " + str(arg_len))
    else:
        # construct a csr matrix from a sparse / dense one
        if isinstance(arg1, CSRNDArray) or (spsp and isinstance(arg1, spsp.csr_matrix)):
            # construct a csr matrix from scipy or CSRNDArray
            _check_shape(arg1.shape, shape)
            return array(arg1, ctx=ctx, dtype=dtype)
        elif isinstance(arg1, RowSparseNDArray):
            raise ValueError("Unexpected input type: RowSparseNDArray")
        else:
            # construct a csr matrix from a dense one
            # prepare default ctx and dtype since mx.nd.array doesn't use default values
            # based on source_array
            dtype = _prepare_default_dtype(arg1, dtype)
            # create dns array with provided dtype. ctx is not passed since copy across
            # ctx requires dtype to be the same
            dns = _array(arg1, dtype=dtype)
            if ctx is not None and dns.context != ctx:
                dns = dns.as_in_context(ctx)
            _check_shape(dns.shape, shape)
            return dns.tostype('csr')

def _csr_matrix_from_definition(data, indices, indptr, shape=None, ctx=None,
                                dtype=None, indices_type=None, indptr_type=None):
    """Build a `CSRNDArray` from its three component arrays.

    Inputs that are not already NDArrays are converted to NDArrays on `ctx`
    before being copied into a freshly allocated CSR handle. When `shape` is
    omitted it is inferred as ``(len(indptr) - 1, max(indices) + 1)``.

    Raises
    ------
    ValueError
        If the shape cannot be inferred (empty `indices`), any component array
        is not 1-D, `indptr` is empty, or `shape` is not 2-D.
    """
    # pylint: disable= no-member, protected-access
    storage_type = 'csr'
    if ctx is None:
        ctx = current_context()
    dtype = _prepare_default_dtype(data, dtype)
    # Aux dtype defaults, in aux-index order: indptr (0), indices (1).
    default_indptr_type, default_indices_type = _STORAGE_AUX_TYPES[storage_type]
    if indptr_type is None:
        indptr_type = default_indptr_type
    if indices_type is None:
        indices_type = default_indices_type

    # Normalize the inputs to numpy arrays (NDArrays pass through untouched).
    data = _prepare_src_array(data, dtype)
    indptr = _prepare_src_array(indptr, indptr_type)
    indices = _prepare_src_array(indices, indices_type)

    # TODO(junwu): Convert data, indptr, and indices to mxnet NDArrays
    # if they are not for now. In the future, we should provide a c-api
    # to accept np.ndarray types to copy from to result.data and aux_data
    if not isinstance(data, NDArray):
        data = _array(data, ctx, dtype)
    if not isinstance(indptr, NDArray):
        indptr = _array(indptr, ctx, indptr_type)
    if not isinstance(indices, NDArray):
        indices = _array(indices, ctx, indices_type)

    if shape is None:
        if indices.shape[0] == 0:
            raise ValueError('invalid shape')
        shape = (len(indptr) - 1, op.max(indices).asscalar() + 1)

    # verify shapes
    aux_shapes = [indptr.shape, indices.shape]
    malformed = (data.ndim != 1 or indptr.ndim != 1 or indices.ndim != 1
                 or indptr.shape[0] == 0 or len(shape) != 2)
    if malformed:
        raise ValueError('invalid shape')

    result = CSRNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype,
                                          [indptr_type, indices_type], aux_shapes))
    # Copy the value array (slot -1) and then each aux array into the new handle.
    for source, slot in ((data, -1), (indptr, 0), (indices, 1)):
        check_call(_LIB.MXNDArraySyncCopyFromNDArray(result.handle, source.handle,
                                                     ctypes.c_int(slot)))
    return result
    # pylint: enable= no-member, protected-access

def empty(stype, shape, ctx=None, dtype=None):
    """Returns a new array of given shape and type, without initializing entries.

    The array is created by delegating to ``zeros`` for the given sparse stype.

    Parameters
    ----------
    stype: string
        The storage type of the empty array, such as 'row_sparse', 'csr', etc
    shape : int or tuple of int
        The shape of the empty array. A single integer (including the integer
        types in ``integer_types``) is treated as a 1-D shape.
    ctx : Context, optional
        An optional device context (default is the current default context).
    dtype : str or numpy.dtype, optional
        An optional value type (default is `float32`).

    Returns
    -------
    CSRNDArray or RowSparseNDArray
        A created array.

    Raises
    ------
    ValueError
        If `stype` is not 'csr' or 'row_sparse' (including None).
    """
    if isinstance(shape, integer_types):
        shape = (shape,)
    if ctx is None:
        ctx = current_context()
    if dtype is None:
        dtype = mx_real_t
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    # ValueError subclasses Exception, so existing `except Exception` still works.
    if stype not in ('csr', 'row_sparse'):
        raise ValueError("unknown stype : " + str(stype))
    return zeros(stype, shape, ctx=ctx, dtype=dtype)

def array(source_array, ctx=None, dtype=None):
    """Creates a sparse array from a sparse source array.

    Only sparse sources are accepted: a RowSparseNDArray, a CSRNDArray or a
    scipy.sparse.csr_matrix. Dense inputs must be converted with ``tostype``
    (for NDArray) or created with ``mx.nd.array`` (for numpy data).

    Parameters
    ----------
    source_array : RowSparseNDArray, CSRNDArray or scipy.sparse.csr_matrix
        The source sparse array
    ctx : Context, optional
        The device context of the result. Defaults to the current default
        context when not given.
    dtype : str or numpy.dtype, optional
        The data type of the output array. The default dtype is ``source_array.dtype``
        if `source_array` is an `NDArray`, `numpy.ndarray` or `scipy.sparse.csr_matrix`, \
        `float32` otherwise.

    Returns
    -------
    RowSparseNDArray or CSRNDArray
        An array with the same contents as the `source_array`.

    Raises
    ------
    AssertionError
        If `source_array` is an NDArray with 'default' storage.
    ValueError
        If `source_array` is a numpy array/scalar or of any other unsupported type.

    Examples
    --------
    >>> import scipy.sparse as spsp
    >>> csr = spsp.csr_matrix((2, 100))
    >>> mx.nd.sparse.array(csr)
    <CSRNDArray 2x100 @cpu(0)>
    >>> mx.nd.sparse.array(mx.nd.sparse.zeros('csr', (3, 2)))
    <CSRNDArray 3x2 @cpu(0)>
    >>> mx.nd.sparse.array(mx.nd.sparse.zeros('row_sparse', (3, 2)))
    <RowSparseNDArray 3x2 @cpu(0)>
    """
    ctx = current_context() if ctx is None else ctx
    if isinstance(source_array, NDArray):
        assert(source_array.stype != 'default'), \
               "Please use `tostype` to create RowSparseNDArray or CSRNDArray from an NDArray"
        # prepare dtype and ctx based on source_array, if not provided
        dtype = _prepare_default_dtype(source_array, dtype)
        # if both dtype and ctx are different from source_array, we cannot copy directly
        if source_array.dtype != dtype and source_array.context != ctx:
            arr = empty(source_array.stype, source_array.shape, dtype=dtype)
            arr[:] = source_array
            arr = arr.as_in_context(ctx)
        else:
            arr = empty(source_array.stype, source_array.shape, dtype=dtype, ctx=ctx)
            arr[:] = source_array
        return arr
    elif spsp and isinstance(source_array, spsp.csr_matrix):
        # TODO(haibin) implement `_sync_copy_from` with scipy csr object to reduce a copy
        # preprocess scipy csr to canonical form (sorted indices, no duplicates);
        # sorted_indices() returns a copy, so the caller's matrix is not mutated.
        csr = source_array.sorted_indices()
        csr.sum_duplicates()
        dtype = _prepare_default_dtype(source_array, dtype)
        return csr_matrix((csr.data, csr.indices, csr.indptr), shape=csr.shape,
                          dtype=dtype, ctx=ctx)
    elif isinstance(source_array, (np.ndarray, np.generic)):
        # Build one message string; passing the type as a second ValueError
        # argument rendered the message as a tuple repr.
        raise ValueError("Please use mx.nd.array to create an NDArray with source_array of type "
                         + str(type(source_array)))
    else:
        raise ValueError("Unexpected source_array type: " + str(type(source_array)))