Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Layers for MobileNet from TensorFlow #9517

Merged
merged 1 commit into from Sep 18, 2017

Conversation

dkurt
Copy link
Member

@dkurt dkurt commented Aug 30, 2017

This pullrequest changes

resolves #9462 (waiting for feedback)

  • ReLU6 layer added
  • depthwise_conv2d layer from TensorFlow (convolution with #groups == #input_channels)
  • Not fused batch normalization by single Mul and Add support

Merge with extra: opencv/opencv_extra#370

How to run MobileNet using DNN:

  • Go to https://github.com/tensorflow/models/blob/master/slim/nets/mobilenet_v1.md and download checkpoint for MobileNet_v1_1.0_224 model. Unpack and navigate into the folder that contains:

    mobilenet_v1_1.0_224.ckpt.data-00000-of-00001
    mobilenet_v1_1.0_224.ckpt.index
    mobilenet_v1_1.0_224.ckpt.meta
    
  • Create .pb model by:

    python ~/tensorflow/tensorflow_models/slim/export_inference_graph.py  \
      --model_name=mobilenet_v1 \
      --output_file=mobilenet_v1.pb \
      --image_size=224

    source: https://github.com/tensorflow/models/blob/master/slim/README.md#exporting-the-inference-graph

  • Freeze

    python ~/tensorflow/tensorflow/python/tools/freeze_graph.py \
      --input_graph=mobilenet_v1.pb \
      --input_checkpoint=mobilenet_v1_1.0_224.ckpt \
      --output_graph=mobilenet_v1_frozen.pb \
      --output_node_names=MobilenetV1/Predictions/Softmax \
      --input_binary

    source: https://github.com/tensorflow/models/blob/master/slim/README.md#freezing-the-exported-graph

  • Modify for DNN: fuse batch normalizations and remove Squeeze op.

    ~/tensorflow/bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
      --in_graph=mobilenet_v1_frozen.pb \
      --out_graph=mobilenet_v1_for_dnn.pb \
      --inputs=input \
      --outputs=MobilenetV1/Predictions/Softmax \
      --transforms="fold_constants sort_by_execution_order remove_nodes(op=Squeeze)"
  • Enjoy with DNN:

    #include <iostream>
    
    #include <opencv2/opencv.hpp>
    #include <opencv2/dnn.hpp>
    
    int main() {
      cv::dnn::Net net = cv::dnn::readNetFromTensorflow("../mobilenet_v1_for_dnn.pb");
    
      cv::Mat input = cv::imread("toucan.jpg");
      cv::Mat blob = cv::dnn::blobFromImage(input, 1.0 / 255, cv::Size(224, 224));
    
      net.setInput(blob);
      cv::Mat output = net.forward();
    
      double minVal, maxVal;
      cv::Point minLoc, maxLoc;
      cv::minMaxLoc(output, &minVal, &maxVal, &minLoc, &maxLoc);
      std::cout << maxLoc << " " << maxVal << std::endl;
    
      return 0;
    }

    output:

    [97, 0] 0.999997
    

    And if I'm right and MobileNet uses 0th class as None, 97th class is a toucan (see synset_words.txt)

@@ -343,6 +343,12 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
static Ptr<ReLULayer> create(const LayerParams &params);
};

class CV_EXPORTS ReLU6Layer : public ActivationLayer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make the layer slightly more universal? result(x, y, c) = min(max(src(x, y, c), a), b) with customizable a and b?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generalized. Name of layer is the same.

v_float32x4 x1 = v_load(srcptr + i + 4);
v_float32x4 x2 = v_load(srcptr + i + 8);
v_float32x4 x3 = v_load(srcptr + i + 12);
x0 = v_select(x0 >= z, x0, z);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use more efficient v_min and v_max instead of v_select

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

}
kshape[0] = inCh * chMultiplier;
kshape[1] = 1;
}
layerParams.set("kernel_h", kshape[2]);
layerParams.set("kernel_w", kshape[3]);
layerParams.set("num_output", kshape[0]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set some min_value=0 and max_value=6 here

void attachHalide(const Halide::Expr& input, Halide::Func& top)
{
Halide::Var x("x"), y("y"), c("c"), n("n");
top(x, y, c, n) = min(max(minValue, input), maxValue);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clamp(input, minValue, maxValue) ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, thank you!

}
#endif // HAVE_HALIDE

int64 getFLOPSPerElement() const { return 1; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 ?

@quarcko
Copy link

quarcko commented Sep 1, 2017

Tried your patch and it works, except one thing.
Recently it became possible to "retrain" model, and when doing this, there is added new layer
that is still unsupported: "PlaceholderWithDefault"

https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/placeholder-with-default

Tried removing that layer from model, but it seems without it model fails with cvassert on "add" operation later on.

@dkurt
Copy link
Member Author

dkurt commented Sep 1, 2017

@quarcko, could you please provide some way to reproduce it? I think we could resolve it faster if there were some steps like at the PR's topic.

@quarcko
Copy link

quarcko commented Sep 3, 2017

Sure,

I used this blog post to train my model:
https://hackernoon.com/creating-insanely-fast-image-classifiers-with-mobilenet-in-tensorflow-f030ce0a2991

Used model is "mobilenet_1.0_224"

After "retraining" the model as i mentioned there is added new unsupported layer.
Which is between input layer and "add" operation later on, so removing it crashes the model (i think so).

Here i will attach my retrained sample model so you can test it without doing the training part.
This is unmodified file, so you still have to: "fold_constants sort_by_execution_order remove_nodes(op=Squeeze)"

You will notice, that after this model will fail at "PlaceholderWithDefault"
Sure, you can try to remove it also, but then when running "forward()" model will crash.

https://www.dropbox.com/s/r1u6w52flwgt8ft/output_graph.pb?dl=0

@dkurt
Copy link
Member Author

dkurt commented Sep 4, 2017

@quarcko, Could you try it again? The necessary changes were made. There are transformations that must be applied to referenced model:

~/tensorflow/bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
  --in_graph=output_graph.pb \
  --out_graph=transformed_graph.pb \
  --inputs=input \
  --outputs=final_result \
  --transforms="fold_constants sort_by_execution_order remove_nodes(op=Squeeze, op=PlaceholderWithDefault)"

@vpisarev
Copy link
Contributor

@dkurt, cannot merge the patch because of conflicts, could you please fix it?

@dkurt
Copy link
Member Author

dkurt commented Sep 15, 2017

@vpisarev, the conflicts were solved.

@vpisarev vpisarev self-assigned this Sep 18, 2017
@vpisarev
Copy link
Contributor

👍

@opencv-pushbot opencv-pushbot merged commit d891e9b into opencv:master Sep 18, 2017
@dkurt dkurt deleted the tf_mobilenet branch September 20, 2017 16:38
@JosiahKane JosiahKane mentioned this pull request Oct 20, 2017
@sEasonsQAQ
Copy link

@dkurt I met the problem : unsupported layer:PlaceholderWithDefault
and I transformed my retained model per your comment on 4 Sep:
~/tensorflow/bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph=output_graph.pb --out_graph=transformed_graph.pb --inputs=input --outputs=final_result --transforms="fold_constants sort_by_execution_order remove_nodes(op=Squeeze, op=PlaceholderWithDefault)"
and dnn could read transformed_graph.pb now, but the result(class and probability) is not correct. Seems the model wasn't read correctly by dnn,
Then I run tf's own test code to read transformed_graph.pb and it gave a correct result, could you please help to find out why the result of dnn wrong?

@dkurt
Copy link
Member Author

dkurt commented Dec 19, 2017

@sEasonsQAQ, can you try to add an extra NHWC->NCHW permutation node before the first MatMul's Reshape as described at http://answers.opencv.org/question/180474/dnn-different-results-between-version-330-and-331/?

@sEasonsQAQ
Copy link

sEasonsQAQ commented Dec 20, 2017

@dkurt ,thanks for your help, and I tested per what you suggested,but the result is still not corrected

  1. I modified the model and save it as 'dnn.pbtxt' :
import tensorflow as tf

# Read the graph.
with tf.gfile.FastGFile('dnn.pb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Remove Const nodes.
for i in reversed(range(len(graph_def.node))):
    if graph_def.node[i].op == 'Const':
        del graph_def.node[i]
    for attr in ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim',
                 'use_cudnn_on_gpu', 'Index', 'Tperm', 'is_training',
                 'Tpaddings']:
        if attr in graph_def.node[i].attr:
            del graph_def.node[i].attr[attr]

# Save as text.
tf.train.write_graph(graph_def, "", "dnn.pbtxt", as_text=True)
  1. replaced node in pbtxt:
node {
  name: "MobilenetV1/Predictions/Reshape"
  op: "Reshape"
  input: "MobilenetV1/Logits/Conv2d_1c_1x1/biases"
  input: "MobilenetV1/Predictions/Reshape/shape"
}

onto:
node {
  name: "order"
  op: "Const"
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
        int_val: 1
        int_val: 2
        int_val: 3
      }
    }
  }
}
node {
  name: "transpose"
  op: "Transpose"
  input: "MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd"
  input: "order"
}
node {
  name: "MobilenetV1/Predictions/Reshape"
  op: "Reshape"
  input: "transpose"
  input: "MobilenetV1/Predictions/Reshape/shape"
}
  1. Then I test the model tf & dnn with b0.jpg(class 'b'/'s source train picture):
import numpy as np
import tensorflow as tf
import time
import cv2 as cv

with tf.gfile.FastGFile('dnn.pb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

img = cv.imread('b0.jpg')
img = cv.resize(img,(128,128))
#escala=1.0/255.0;
inp = cv.dnn.blobFromImage(img,1.0,(128,128),(127.5,127.5,127.5));
with tf.Session() as sess:
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
    for node in graph_def.node:
       print node.op, node.name

    #Generate input
    #np.random.seed(2701)
    #inp = np.random.standard_normal([1,128,128,3]).astype(np.float32)


    # Get output tensor
    outTensor = sess.graph.get_tensor_by_name('final_result:0')
    out = sess.run(outTensor, feed_dict={'input:0': inp.transpose(0,2,3,1)})

cvNet = cv.dnn.readNetFromTensorflow('dnn.pb','dnn.pbtxt')
cvNet.setInput(inp)
cvOut = cvNet.forward()
print cvOut,out
print np.max(np.abs(cvOut - out)) 	

and got 1.28523e-07 maximal absolute difference;
but the result: class both are 'a'(which should have been 'b') , both prob:0.99+

in another c++ tf test code,it returns an expected result: class 'b', prob:0.99+;

are my test steps wrong?kindly help with this condition,thanks!

another c++ code:

int main(int argc, char* argv[]) {
	// These are the command-line flags the program can understand.
	// They define where the graph and input data is located, and what kind of
	// input the model expects. If you train your own model, or use something
	// other than inception_v3, then you'll need to update these.
	string image = "b0.jpg";
	string graph =
		"./models/dnn.pb";
	string labels =
		"./models/pics_labels.txt";
	int32 input_width = 128;
	int32 input_height = 128;
	int32 input_mean = 127.5;
	int32 input_std = 127.5;
	string input_layer = "input";
	string output_layer = "final_result";
	int32 self_test = 0;
	string root_dir = "";
	std::vector<Flag> flag_list = {
		Flag("image", &image, "image to be processed"),
		Flag("graph", &graph, "graph to be executed"),
		Flag("labels", &labels, "name of file containing labels"),
		Flag("input_width", &input_width, "resize image to this width in pixels"),
		Flag("input_height", &input_height,
		"resize image to this height in pixels"),
		Flag("input_mean", &input_mean, "scale pixel values to this mean"),
		Flag("input_std", &input_std, "scale pixel values to this std deviation"),
		Flag("input_layer", &input_layer, "name of input layer"),
		Flag("output_layer", &output_layer, "name of output layer"),
		Flag("self_test", &self_test, "run a self test"),
		Flag("root_dir", &root_dir,
		"interpret image and graph file names relative to this directory"),
	};
	string usage = tensorflow::Flags::Usage(argv[0], flag_list);
	const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
	if (!parse_result) {
		LOG(ERROR) << usage;
		return -1;
	}

	// We need to call this to set up global state for TensorFlow.
	tensorflow::port::InitMain(argv[0], &argc, &argv);
	if (argc > 1) {
		LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
		return -1;
	}

	// First we load and initialize the model.
	std::unique_ptr<tensorflow::Session> session;
	string graph_path = tensorflow::io::JoinPath(root_dir, graph);
	Status load_graph_status = LoadGraph(graph_path, &session);
	if (!load_graph_status.ok()) {
		LOG(ERROR) << load_graph_status;
		return -1;
	}

	// Get the image from disk as a float array of numbers, resized and normalized
	// to the specifications the main graph expects.
	std::vector<Tensor> resized_tensors;
	string image_path = tensorflow::io::JoinPath(root_dir, image);
	Status read_tensor_status =
		ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
			input_std, &resized_tensors);
	if (!read_tensor_status.ok()) {
		LOG(ERROR) << read_tensor_status;
		return -1;
	}
	const Tensor& resized_tensor = resized_tensors[0];


	//std::cout << resized_tensors[0] <<std::endl;


	// Actually run the image through the model.
	std::vector<Tensor> outputs;
	Status run_status = session->Run({ { input_layer, resized_tensor } },
	{ output_layer }, {}, &outputs);
	if (!run_status.ok()) {
		LOG(ERROR) << "Running model failed: " << run_status;
		return -1;
	}

	// This is for automated testing to make sure we get the expected result with
	// the default settings. We know that label 653 (military uniform) should be
	// the top label for the Admiral Hopper image.
	if (self_test) {
		bool expected_matches;
		Status check_status = CheckTopLabel(outputs, 653, &expected_matches);
		if (!check_status.ok()) {
			LOG(ERROR) << "Running check failed: " << check_status;
			return -1;
		}
		if (!expected_matches) {
			LOG(ERROR) << "Self-test failed!";
			return -1;
		}
	}

	// Do something interesting with the results we've generated.
	Status print_status = PrintTopLabels(outputs, labels);
	if (!print_status.ok()) {
		LOG(ERROR) << "Running print failed: " << print_status;
		return -1;
	}

	return 0;
}

@naguirre
Copy link

@sEasonsQAQ did you find a way to get your retrained model running correctly with opencv dnn module ?
I get the same problem as yours (wrong labels and predictions values) after trying retrained model of https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2

@mevatron
Copy link
Contributor

mevatron commented Feb 22, 2018

@sEasonsQAQ @naguirre I'm also seeing this exact behavior. It seems to test fine on the original training dataset, but real-world samples don't predict correctly, however the same model with TensorFlow Mobile works fine. We should probably create a bug report so this can be tracked.

@dkurt
Copy link
Member Author

dkurt commented Feb 22, 2018

@mevatron, May I ask you to open a topic at http://answers.opencv.org? Please make it as much reproducible as possible. Do not insert cross references to old questions or issues. The best way is to attach a .pb model and describe how it was obtained from TensorFlow. Thanks!

@mevatron
Copy link
Contributor

@dkurt As requested I posted a detailed write-up to reproduce the issue we are seeing here: http://answers.opencv.org/question/185283/opencv_dnn-provides-incorrect-inferences-after-transform_graph/

Any of your insights would be greatly appreciated!

Thanks for your time!

@sEasonsQAQ
Copy link

sEasonsQAQ commented Feb 24, 2018

@naguirre sorry for the delay, I just transformed the tensorflow-1.2.0 retrained .pb:
./transform_graph --in_graph=retrained_mobilenet.pb --out_graph=mobilenet_v1_for_dnn.pb --inputs=input --outputs=final_result --transforms="fold_constants sort_by_execution_order remove_nodes(op=Squeeze,op=PlaceholderWithDefault)"
and it works well with opencv-3.4.0, hope it useful to you.

@mevatron
Copy link
Contributor

@sEasonsQAQ Interesting, I'm using tensorflow master, I wonder if that might be the source of my issues.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Unable to import mobilenet model using latest OpenCV.
8 participants