Skip to content

Commit 386a8f4

Browse files
committedAug 22, 2017
Fix the multibox problem and add tests.
1 parent 8b9d222 commit 386a8f4

File tree

8 files changed

+273
-15
lines changed

8 files changed

+273
-15
lines changed
 

‎__init__.py

Whitespace-only changes.

‎data/test.record

-532 KB
Binary file not shown.

‎data/train.record

-987 KB
Binary file not shown.

‎environment.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: raccoon-dataset
1+
name: raccoon_dataset
22
channels: !!python/tuple
33
- menpo
44
- defaults
@@ -36,5 +36,5 @@ dependencies:
3636
- protobuf==3.3.0
3737
- tensorflow==1.2.1
3838
- werkzeug==0.12.2
39-
prefix: /Users/datitran/anaconda/envs/raccoon-dataset
39+
prefix: /Users/datitran/anaconda/envs/raccoon_dataset
4040

‎generate_tfrecord.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
import os
1515
import io
16-
import operator
1716
import pandas as pd
1817
import tensorflow as tf
1918

@@ -37,10 +36,8 @@ def class_text_to_int(row_label):
3736

3837
def split(df, group):
3938
data = namedtuple('data', ['filename', 'object'])
40-
gb = df.groupby(group, sort=True)
41-
# Use ordered dict as gp.groups() creates random order due to dict
42-
sorted_group = OrderedDict(sorted([i for i in gb.groups.items()]))
43-
return [data(filename, gb.get_group(x)) for filename, x in zip(sorted_group.keys(), sorted_group)]
39+
gb = df.groupby(group)
40+
return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
4441

4542

4643
def create_tf_example(group, path):
@@ -88,11 +85,14 @@ def main(_):
8885
writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
8986
path = os.path.join(os.getcwd(), 'images')
9087
examples = pd.read_csv(FLAGS.csv_input)
91-
for index, row in examples.iterrows():
92-
tf_example = create_tf_example(row, path)
88+
grouped = split(examples, 'filename')
89+
for group in grouped:
90+
tf_example = create_tf_example(group, path)
9391
writer.write(tf_example.SerializeToString())
9492

9593
writer.close()
94+
output_path = os.path.join(os.getcwd(), FLAGS.output_path)
95+
print('Successfully created the TFRecords: {}'.format(output_path))
9696

9797

9898
if __name__ == '__main__':

‎test_generate_tfrecord.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
22
import PIL
3+
import generate_tfrecord
34
import numpy as np
45
import pandas as pd
56
import tensorflow as tf
6-
import generate_tfrecord
77

88

99
class CSVToTFExampleTest(tf.test.TestCase):
@@ -186,3 +186,87 @@ def test_csv_to_tf_example_one_raccoons_multiple_files(self):
186186
self._assertProtoEqual(
187187
example.features.feature['image/object/class/label'].int64_list.value,
188188
[1])
189+
190+
def test_csv_to_tf_example_multiple_raccoons_multiple_files(self):
191+
"""Generate tf records for multiple raccoons for multiple files."""
192+
image_file_one = 'tmp_raccoon_image_1.jpg'
193+
image_file_two = 'tmp_raccoon_image_2.jpg'
194+
image_data = np.random.rand(256, 256, 3)
195+
save_path_one = os.path.join(self.get_temp_dir(), image_file_one)
196+
save_path_two = os.path.join(self.get_temp_dir(), image_file_two)
197+
image = PIL.Image.fromarray(image_data, 'RGB')
198+
image.save(save_path_one)
199+
image.save(save_path_two)
200+
201+
column_names = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
202+
raccoon_data = [('tmp_raccoon_image_1.jpg', 256, 256, 'raccoon', 64, 64, 192, 192),
203+
('tmp_raccoon_image_1.jpg', 256, 256, 'raccoon', 32, 32, 96, 96),
204+
('tmp_raccoon_image_2.jpg', 256, 256, 'raccoon', 96, 96, 128, 128)]
205+
raccoon_df = pd.DataFrame(raccoon_data, columns=column_names)
206+
207+
grouped = generate_tfrecord.split(raccoon_df, 'filename')
208+
for group in grouped:
209+
if group.filename == image_file_one:
210+
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir())
211+
self._assertProtoEqual(
212+
example.features.feature['image/height'].int64_list.value, [256])
213+
self._assertProtoEqual(
214+
example.features.feature['image/width'].int64_list.value, [256])
215+
self._assertProtoEqual(
216+
example.features.feature['image/filename'].bytes_list.value,
217+
[image_file_one.encode('utf-8')])
218+
self._assertProtoEqual(
219+
example.features.feature['image/source_id'].bytes_list.value,
220+
[image_file_one.encode('utf-8')])
221+
self._assertProtoEqual(
222+
example.features.feature['image/format'].bytes_list.value, [b'jpg'])
223+
self._assertProtoEqual(
224+
example.features.feature['image/object/bbox/xmin'].float_list.value,
225+
[0.25, 0.125])
226+
self._assertProtoEqual(
227+
example.features.feature['image/object/bbox/ymin'].float_list.value,
228+
[0.25, 0.125])
229+
self._assertProtoEqual(
230+
example.features.feature['image/object/bbox/xmax'].float_list.value,
231+
[0.75, 0.375])
232+
self._assertProtoEqual(
233+
example.features.feature['image/object/bbox/ymax'].float_list.value,
234+
[0.75, 0.375])
235+
self._assertProtoEqual(
236+
example.features.feature['image/object/class/text'].bytes_list.value,
237+
[b'raccoon', b'raccoon'])
238+
self._assertProtoEqual(
239+
example.features.feature['image/object/class/label'].int64_list.value,
240+
[1, 1])
241+
elif group.filename == image_file_two:
242+
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir())
243+
self._assertProtoEqual(
244+
example.features.feature['image/height'].int64_list.value, [256])
245+
self._assertProtoEqual(
246+
example.features.feature['image/width'].int64_list.value, [256])
247+
self._assertProtoEqual(
248+
example.features.feature['image/filename'].bytes_list.value,
249+
[image_file_two.encode('utf-8')])
250+
self._assertProtoEqual(
251+
example.features.feature['image/source_id'].bytes_list.value,
252+
[image_file_two.encode('utf-8')])
253+
self._assertProtoEqual(
254+
example.features.feature['image/format'].bytes_list.value, [b'jpg'])
255+
self._assertProtoEqual(
256+
example.features.feature['image/object/bbox/xmin'].float_list.value,
257+
[0.375])
258+
self._assertProtoEqual(
259+
example.features.feature['image/object/bbox/ymin'].float_list.value,
260+
[0.375])
261+
self._assertProtoEqual(
262+
example.features.feature['image/object/bbox/xmax'].float_list.value,
263+
[0.5])
264+
self._assertProtoEqual(
265+
example.features.feature['image/object/bbox/ymax'].float_list.value,
266+
[0.5])
267+
self._assertProtoEqual(
268+
example.features.feature['image/object/class/text'].bytes_list.value,
269+
[b'raccoon'])
270+
self._assertProtoEqual(
271+
example.features.feature['image/object/class/label'].int64_list.value,
272+
[1])

‎test_xml_to_csv.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import shutil
2+
import os
3+
import tempfile
4+
import unittest
5+
import xml_to_csv
6+
from xml.etree import ElementTree as ET
7+
8+
9+
class XMLToCSVTest(unittest.TestCase):
10+
def test_one_raccoon_one_xml(self):
11+
xml_file_one = """
12+
<annotation verified="yes">
13+
<folder>images</folder>
14+
<filename>raccoon-1.png</filename>
15+
<path>raccoon-1.png</path>
16+
<source>
17+
<database>Unknown</database>
18+
</source>
19+
<size>
20+
<width>256</width>
21+
<height>256</height>
22+
<depth>3</depth>
23+
</size>
24+
<segmented>0</segmented>
25+
<object>
26+
<name>raccoon</name>
27+
<pose>Unspecified</pose>
28+
<truncated>0</truncated>
29+
<difficult>0</difficult>
30+
<bndbox>
31+
<xmin>96</xmin>
32+
<ymin>96</ymin>
33+
<xmax>128</xmax>
34+
<ymax>128</ymax>
35+
</bndbox>
36+
</object>
37+
</annotation>
38+
"""
39+
40+
xml = ET.fromstring(xml_file_one)
41+
with tempfile.TemporaryDirectory() as tmpdirname:
42+
tree = ET.ElementTree(xml)
43+
tree.write(tmpdirname + '/test_raccoon_one.xml')
44+
raccoon_df = xml_to_csv.xml_to_csv(tmpdirname)
45+
self.assertEqual(raccoon_df.columns.values.tolist(),
46+
['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax'])
47+
self.assertEqual(raccoon_df.values.tolist()[0], ['raccoon-1.png', 256, 256, 'raccoon', 96, 96, 128, 128])
48+
49+
def test_multiple_raccoon_one_xml(self):
50+
xml_file_one = """
51+
<annotation verified="yes">
52+
<folder>images</folder>
53+
<filename>raccoon-1.png</filename>
54+
<path>raccoon-1.png</path>
55+
<source>
56+
<database>Unknown</database>
57+
</source>
58+
<size>
59+
<width>256</width>
60+
<height>256</height>
61+
<depth>3</depth>
62+
</size>
63+
<segmented>0</segmented>
64+
<object>
65+
<name>raccoon</name>
66+
<pose>Unspecified</pose>
67+
<truncated>0</truncated>
68+
<difficult>0</difficult>
69+
<bndbox>
70+
<xmin>96</xmin>
71+
<ymin>96</ymin>
72+
<xmax>128</xmax>
73+
<ymax>128</ymax>
74+
</bndbox>
75+
</object>
76+
<object>
77+
<name>raccoon</name>
78+
<pose>Unspecified</pose>
79+
<truncated>0</truncated>
80+
<difficult>0</difficult>
81+
<bndbox>
82+
<xmin>32</xmin>
83+
<ymin>32</ymin>
84+
<xmax>64</xmax>
85+
<ymax>64</ymax>
86+
</bndbox>
87+
</object>
88+
</annotation>
89+
"""
90+
91+
xml = ET.fromstring(xml_file_one)
92+
with tempfile.TemporaryDirectory() as tmpdirname:
93+
tree = ET.ElementTree(xml)
94+
tree.write(tmpdirname + '/test_raccoon_one.xml')
95+
raccoon_df = xml_to_csv.xml_to_csv(tmpdirname)
96+
self.assertEqual(raccoon_df.columns.values.tolist(),
97+
['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax'])
98+
self.assertEqual(raccoon_df.values.tolist()[0], ['raccoon-1.png', 256, 256, 'raccoon', 96, 96, 128, 128])
99+
self.assertEqual(raccoon_df.values.tolist()[1], ['raccoon-1.png', 256, 256, 'raccoon', 32, 32, 64, 64])
100+
101+
def test_one_raccoon_multiple_xml(self):
102+
xml_file_one = """
103+
<annotation verified="yes">
104+
<folder>images</folder>
105+
<filename>raccoon-1.png</filename>
106+
<path>raccoon-1.png</path>
107+
<source>
108+
<database>Unknown</database>
109+
</source>
110+
<size>
111+
<width>256</width>
112+
<height>256</height>
113+
<depth>3</depth>
114+
</size>
115+
<segmented>0</segmented>
116+
<object>
117+
<name>raccoon</name>
118+
<pose>Unspecified</pose>
119+
<truncated>0</truncated>
120+
<difficult>0</difficult>
121+
<bndbox>
122+
<xmin>96</xmin>
123+
<ymin>96</ymin>
124+
<xmax>128</xmax>
125+
<ymax>128</ymax>
126+
</bndbox>
127+
</object>
128+
</annotation>
129+
"""
130+
xml_file_two = """
131+
<annotation verified="yes">
132+
<folder>images</folder>
133+
<filename>raccoon-2.png</filename>
134+
<path>raccoon-2.png</path>
135+
<source>
136+
<database>Unknown</database>
137+
</source>
138+
<size>
139+
<width>256</width>
140+
<height>256</height>
141+
<depth>3</depth>
142+
</size>
143+
<segmented>0</segmented>
144+
<object>
145+
<name>raccoon</name>
146+
<pose>Unspecified</pose>
147+
<truncated>0</truncated>
148+
<difficult>0</difficult>
149+
<bndbox>
150+
<xmin>128</xmin>
151+
<ymin>128</ymin>
152+
<xmax>194</xmax>
153+
<ymax>194</ymax>
154+
</bndbox>
155+
</object>
156+
</annotation>
157+
"""
158+
xml_list = [xml_file_one, xml_file_two]
159+
tmpdirname = tempfile.mkdtemp()
160+
for index, x in enumerate(xml_list):
161+
xml = ET.fromstring(x)
162+
tree = ET.ElementTree(xml)
163+
tree.write(tmpdirname + '/test_raccoon_{}.xml'.format(index))
164+
165+
raccoon_df = xml_to_csv.xml_to_csv(tmpdirname)
166+
self.assertEqual(raccoon_df.columns.values.tolist(),
167+
['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax'])
168+
self.assertEqual(raccoon_df.values.tolist()[0], ['raccoon-1.png', 256, 256, 'raccoon', 96, 96, 128, 128])
169+
self.assertEqual(raccoon_df.values.tolist()[1], ['raccoon-2.png', 256, 256, 'raccoon', 128, 128, 194, 194])
170+
shutil.rmtree(tmpdirname)

‎xml_to_csv.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
import xml.etree.ElementTree as ET
55

66

7-
def xml_to_csv():
8-
image_path = os.path.join(os.getcwd(), 'annotations')
9-
7+
def xml_to_csv(path):
108
xml_list = []
11-
for xml_file in glob.glob(image_path + '/*.xml'):
9+
for xml_file in glob.glob(path + '/*.xml'):
1210
tree = ET.parse(xml_file)
1311
root = tree.getroot()
1412
for member in root.findall('object'):
@@ -24,8 +22,14 @@ def xml_to_csv():
2422
xml_list.append(value)
2523
column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
2624
xml_df = pd.DataFrame(xml_list, columns=column_name)
25+
return xml_df
26+
27+
28+
def main():
29+
image_path = os.path.join(os.getcwd(), 'annotations')
30+
xml_df = xml_to_csv(image_path)
2731
xml_df.to_csv('raccoon_labels.csv', index=None)
2832
print('Successfully converted xml to csv.')
2933

3034

31-
xml_to_csv()
35+
main()

0 commit comments

Comments
 (0)
Please sign in to comment.