Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit d77d937

Browse files
JulianSlzrpiiswrong
authored andcommittedDec 9, 2017
Fix __repr__ for gluon.Parameter (#8956)
* Fix __repr__ for gluon.parameter * Add to contributors (PRs #8565, #8956, etc.) * Add unit test for gluon.Parameter string
1 parent 9d4bb9c commit d77d937

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed
 

‎CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,4 @@ List of Contributors
152152
* [Andre Tamm](https://github.com/andretamm)
153153
* [Marco de Abreu](https://github.com/marcoabreu)
154154
- Marco is the creator of the current MXNet CI.
155+
* [Julian Salazar](https://github.com/JulianSlzr)

‎python/mxnet/gluon/parameter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
117117

118118
def __repr__(self):
119119
s = 'Parameter {name} (shape={shape}, dtype={dtype})'
120-
return s.format(**self.__dict__)
120+
return s.format(name=self.name, shape=self.shape, dtype=self.dtype)
121121

122122
@property
123123
def grad_req(self):

‎tests/python/unittest/test_gluon.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,23 @@ def forward(self, x):
7070
net3.load_params('net1.params', mx.cpu())
7171

7272

73+
def test_parameter_str():
74+
class Net(gluon.Block):
75+
def __init__(self, **kwargs):
76+
super(Net, self).__init__(**kwargs)
77+
with self.name_scope():
78+
self.dense0 = nn.Dense(10, in_units=5, use_bias=False)
79+
80+
net = Net(prefix='net1_')
81+
lines = str(net.collect_params()).splitlines()
82+
83+
assert lines[0] == 'net1_ ('
84+
assert 'net1_dense0_weight' in lines[1]
85+
assert '(10, 5)' in lines[1]
86+
assert 'numpy.float32' in lines[1]
87+
assert lines[2] == ')'
88+
89+
7390
def test_basic():
7491
model = nn.Sequential()
7592
model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False))

0 commit comments

Comments
 (0)
This repository has been archived.