Skip to content

Commit

Permalink
Fix the multibox problem and add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
datitran committed Aug 22, 2017
1 parent 8b9d222 commit 386a8f4
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 15 deletions.
Empty file added __init__.py
Empty file.
Binary file modified data/test.record
Binary file not shown.
Binary file modified data/train.record
Binary file not shown.
4 changes: 2 additions & 2 deletions environment.yml
@@ -1,4 +1,4 @@
name: raccoon-dataset
name: raccoon_dataset
channels: !!python/tuple
- menpo
- defaults
Expand Down Expand Up @@ -36,5 +36,5 @@ dependencies:
- protobuf==3.3.0
- tensorflow==1.2.1
- werkzeug==0.12.2
prefix: /Users/datitran/anaconda/envs/raccoon-dataset
prefix: /Users/datitran/anaconda/envs/raccoon_dataset

14 changes: 7 additions & 7 deletions generate_tfrecord.py
Expand Up @@ -13,7 +13,6 @@

import os
import io
import operator
import pandas as pd
import tensorflow as tf

Expand All @@ -37,10 +36,8 @@ def class_text_to_int(row_label):

def split(df, group):
    """Split ``df`` into one record per distinct value of column ``group``.

    Args:
        df: pandas DataFrame holding one row per annotated object.
        group: name of the column to group by (the caller passes 'filename').

    Returns:
        A list of ``data`` namedtuples. ``filename`` is the group key and
        ``object`` is the sub-DataFrame with every row belonging to that key.
        The list follows the iteration order of ``groupby(...).groups``.
    """
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    # The keys of gb.groups are exactly the values get_group() expects, so a
    # single pass over them is enough; there is no need to zip them with
    # themselves.
    return [data(filename, gb.get_group(filename)) for filename in gb.groups]


def create_tf_example(group, path):
Expand Down Expand Up @@ -88,11 +85,14 @@ def main(_):
writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
path = os.path.join(os.getcwd(), 'images')
examples = pd.read_csv(FLAGS.csv_input)
for index, row in examples.iterrows():
tf_example = create_tf_example(row, path)
grouped = split(examples, 'filename')
for group in grouped:
tf_example = create_tf_example(group, path)
writer.write(tf_example.SerializeToString())

writer.close()
output_path = os.path.join(os.getcwd(), FLAGS.output_path)
print('Successfully created the TFRecords: {}'.format(output_path))


if __name__ == '__main__':
Expand Down
86 changes: 85 additions & 1 deletion test_generate_tfrecord.py
@@ -1,9 +1,9 @@
import os
import PIL
import generate_tfrecord
import numpy as np
import pandas as pd
import tensorflow as tf
import generate_tfrecord


class CSVToTFExampleTest(tf.test.TestCase):
Expand Down Expand Up @@ -186,3 +186,87 @@ def test_csv_to_tf_example_one_raccoons_multiple_files(self):
self._assertProtoEqual(
example.features.feature['image/object/class/label'].int64_list.value,
[1])

def test_csv_to_tf_example_multiple_raccoons_multiple_files(self):
    """Generate tf records for multiple raccoons for multiple files."""
    image_file_one = 'tmp_raccoon_image_1.jpg'
    image_file_two = 'tmp_raccoon_image_2.jpg'
    image_data = np.random.rand(256, 256, 3)
    image = PIL.Image.fromarray(image_data, 'RGB')
    # The same image content is stored under both names; only the CSV rows
    # differ between the two files.
    for image_file in (image_file_one, image_file_two):
        image.save(os.path.join(self.get_temp_dir(), image_file))

    column_names = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    raccoon_data = [(image_file_one, 256, 256, 'raccoon', 64, 64, 192, 192),
                    (image_file_one, 256, 256, 'raccoon', 32, 32, 96, 96),
                    (image_file_two, 256, 256, 'raccoon', 96, 96, 128, 128)]
    raccoon_df = pd.DataFrame(raccoon_data, columns=column_names)

    # Expected box corners per file, one entry per CSV row of that file.
    # Each value equals the pixel coordinate above divided by 256.
    expected_boxes = {
        image_file_one: {'xmin': [0.25, 0.125], 'ymin': [0.25, 0.125],
                         'xmax': [0.75, 0.375], 'ymax': [0.75, 0.375]},
        image_file_two: {'xmin': [0.375], 'ymin': [0.375],
                         'xmax': [0.5], 'ymax': [0.5]},
    }

    grouped = generate_tfrecord.split(raccoon_df, 'filename')
    # Without this check the loop below would pass vacuously if split()
    # returned no groups or groups under unexpected names.
    self.assertEqual(sorted(group.filename for group in grouped),
                     sorted(expected_boxes))

    for group in grouped:
        boxes = expected_boxes[group.filename]
        box_count = len(boxes['xmin'])
        encoded_name = group.filename.encode('utf-8')

        example = generate_tfrecord.create_tf_example(group, self.get_temp_dir())
        feature = example.features.feature

        self._assertProtoEqual(feature['image/height'].int64_list.value, [256])
        self._assertProtoEqual(feature['image/width'].int64_list.value, [256])
        self._assertProtoEqual(
            feature['image/filename'].bytes_list.value, [encoded_name])
        self._assertProtoEqual(
            feature['image/source_id'].bytes_list.value, [encoded_name])
        self._assertProtoEqual(feature['image/format'].bytes_list.value, [b'jpg'])
        for corner in ('xmin', 'ymin', 'xmax', 'ymax'):
            self._assertProtoEqual(
                feature['image/object/bbox/' + corner].float_list.value,
                boxes[corner])
        self._assertProtoEqual(
            feature['image/object/class/text'].bytes_list.value,
            [b'raccoon'] * box_count)
        self._assertProtoEqual(
            feature['image/object/class/label'].int64_list.value,
            [1] * box_count)
170 changes: 170 additions & 0 deletions test_xml_to_csv.py
@@ -0,0 +1,170 @@
import shutil
import os
import tempfile
import unittest
import xml_to_csv
from xml.etree import ElementTree as ET


class XMLToCSVTest(unittest.TestCase):
    """Tests for ``xml_to_csv.xml_to_csv`` using generated annotation files."""

    _COLUMN_NAMES = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']

    # Annotation document for a 256x256 image; ``{objects}`` receives zero or
    # more rendered ``_OBJECT_TEMPLATE`` fragments.
    _ANNOTATION_TEMPLATE = """
    <annotation verified="yes">
        <folder>images</folder>
        <filename>{filename}</filename>
        <path>{filename}</path>
        <source>
            <database>Unknown</database>
        </source>
        <size>
            <width>256</width>
            <height>256</height>
            <depth>3</depth>
        </size>
        <segmented>0</segmented>{objects}
    </annotation>
    """

    # One annotated raccoon with its bounding box.
    _OBJECT_TEMPLATE = """
        <object>
            <name>raccoon</name>
            <pose>Unspecified</pose>
            <truncated>0</truncated>
            <difficult>0</difficult>
            <bndbox>
                <xmin>{xmin}</xmin>
                <ymin>{ymin}</ymin>
                <xmax>{xmax}</xmax>
                <ymax>{ymax}</ymax>
            </bndbox>
        </object>"""

    def _annotation_xml(self, filename, boxes):
        """Return an annotation document for ``filename``.

        Args:
            filename: value used for the <filename> and <path> elements.
            boxes: iterable of (xmin, ymin, xmax, ymax) tuples, one per
                <object> element, emitted in the given order.
        """
        objects = ''.join(
            self._OBJECT_TEMPLATE.format(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)
            for xmin, ymin, xmax, ymax in boxes)
        return self._ANNOTATION_TEMPLATE.format(filename=filename, objects=objects)

    def _convert(self, xml_documents):
        """Write the documents to a temp dir and return the converted rows.

        The directory is managed by a context manager so that it is removed
        even when the conversion raises. The column names of the resulting
        DataFrame are checked here, since every test expects the same header.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            for index, document in enumerate(xml_documents):
                tree = ET.ElementTree(ET.fromstring(document))
                tree.write(os.path.join(tmpdirname, 'test_raccoon_{}.xml'.format(index)))
            raccoon_df = xml_to_csv.xml_to_csv(tmpdirname)
        self.assertEqual(raccoon_df.columns.values.tolist(), self._COLUMN_NAMES)
        return raccoon_df.values.tolist()

    def test_one_raccoon_one_xml(self):
        rows = self._convert([
            self._annotation_xml('raccoon-1.png', [(96, 96, 128, 128)]),
        ])
        self.assertEqual(rows, [['raccoon-1.png', 256, 256, 'raccoon', 96, 96, 128, 128]])

    def test_multiple_raccoon_one_xml(self):
        rows = self._convert([
            self._annotation_xml('raccoon-1.png', [(96, 96, 128, 128), (32, 32, 64, 64)]),
        ])
        # Both rows come from a single file, so their order is asserted as-is.
        self.assertEqual(rows, [['raccoon-1.png', 256, 256, 'raccoon', 96, 96, 128, 128],
                                ['raccoon-1.png', 256, 256, 'raccoon', 32, 32, 64, 64]])

    def test_one_raccoon_multiple_xml(self):
        rows = self._convert([
            self._annotation_xml('raccoon-1.png', [(96, 96, 128, 128)]),
            self._annotation_xml('raccoon-2.png', [(128, 128, 194, 194)]),
        ])
        # The files are discovered through glob, whose result order is not
        # guaranteed, so the rows are compared independent of file order.
        self.assertEqual(sorted(rows),
                         [['raccoon-1.png', 256, 256, 'raccoon', 96, 96, 128, 128],
                          ['raccoon-2.png', 256, 256, 'raccoon', 128, 128, 194, 194]])
14 changes: 9 additions & 5 deletions xml_to_csv.py
Expand Up @@ -4,11 +4,9 @@
import xml.etree.ElementTree as ET


def xml_to_csv():
image_path = os.path.join(os.getcwd(), 'annotations')

def xml_to_csv(path):
xml_list = []
for xml_file in glob.glob(image_path + '/*.xml'):
for xml_file in glob.glob(path + '/*.xml'):
tree = ET.parse(xml_file)
root = tree.getroot()
for member in root.findall('object'):
Expand All @@ -24,8 +22,14 @@ def xml_to_csv():
xml_list.append(value)
column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
xml_df = pd.DataFrame(xml_list, columns=column_name)
return xml_df


def main():
    """Convert the XML files under ./annotations into raccoon_labels.csv."""
    # The annotations directory is resolved against the current working
    # directory, and the CSV is written without the DataFrame index column.
    annotations_dir = os.path.join(os.getcwd(), 'annotations')
    xml_to_csv(annotations_dir).to_csv('raccoon_labels.csv', index=None)
    print('Successfully converted xml to csv.')


# Run the conversion only when this file is executed as a script. The module
# is also imported (test_xml_to_csv imports it), and an unguarded call would
# write raccoon_labels.csv as a side effect of that import.
if __name__ == '__main__':
    main()

0 comments on commit 386a8f4

Please sign in to comment.