cuda 9.0 "error: more than one operator "==" matches these operands" #797

Open
@dllehr81


While attempting to build torch from master with cutorch against CUDA 9.0.103-1 on Ubuntu 16.04, I hit an error caused by conflicting overloads of the "==" and "!=" operators.

Below is an example of the error I receive.

lib/THC/CMakeFiles/THC.dir/build.make:4243: recipe for target 'lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorMathPairwise.cu.o' failed
make[2]: *** [lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorMathPairwise.cu.o] Error 1
/pkgbuild/torch/torch/extra/cutorch/lib/THC/generic/THCTensorMath.cu(393): error: more than one operator "==" matches these operands:
            function "operator==(const __half &, const __half &)"
            function "operator==(half, half)"
            operand types are: half == half

/pkgbuild/torch/torch/extra/cutorch/lib/THC/generic/THCTensorMath.cu(414): error: more than one operator "==" matches these operands:
            function "operator==(const __half &, const __half &)"
            function "operator==(half, half)"
            operand types are: half == half

I was able to track down the two operator overloads.
One is in
https://github.com/torch/cutorch/blob/master/lib/THC/THCTensorTypeUtils.cuh#L176

And the other is in
/usr/local/cuda-9.0/targets/ppc64le-linux/include/cuda_fp16.hpp

The operator in cuda_fp16.hpp is provided by the CUDA package, but it is declared only for __device__, not for __host__. So we still need a "==" overload for halfs on the __host__ side; however, the overload currently in cutorch makes the build fail at compile time.

It looks like @csarofeen worked on the initial port to cuda9.0 for cutorch. I'm not sure if he can provide some help on what's going on here?

Is there any additional information you need from me? Thanks in advance!!

csarofeen (Contributor) commented on Aug 23, 2017

Before you build, try: export TORCH_NVCC_FLAGS="-D__CUDA_NO_HALF_OPERATORS__"

dllehr81 (Author) commented on Aug 23, 2017

Hey @csarofeen, that did the trick! On a side note, this appears to disable the half operators in the CUDA code. Will this affect the performance of half variables when run on the device?

csarofeen (Contributor) commented on Aug 23, 2017

It will, for the better.

betterjordache commented on Sep 8, 2017

@csarofeen this did it for me as well, thank you!

ProGamerGov commented on Oct 24, 2017

I had the same issue:

[  4%] Building NVCC (Device) object lib/THC/CMakeFiles/THC.dir/THC_generated_THCSleep.cu.o
[  5%] Building NVCC (Device) object lib/THC/CMakeFiles/THC.dir/THC_generated_THCStorage.cu.o
[  6%] Building NVCC (Device) object lib/THC/CMakeFiles/THC.dir/THC_generated_THCStorageCopy.cu.o
[  7%] Building NVCC (Device) object lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensor.cu.o
[  8%] Building NVCC (Device) object lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorCopy.cu.o
[ 10%] Building NVCC (Device) object lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorMath.cu.o
[ 11%] Building NVCC (Device) object lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorMath2.cu.o
[ 12%] Building NVCC (Device) object lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorMathBlas.cu.o
[ 13%] Building NVCC (Device) object lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorMathMagma.cu.o
[ 14%] Building NVCC (Device) object lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorMathPairwise.cu.o
/home/ubuntu/torch/extra/cutorch/lib/THC/generic/THCTensorMath.cu(393): error: more than one operator "==" matches these operands:
            function "operator==(const __half &, const __half &)"
            function "operator==(half, half)"
            operand types are: half == half

/home/ubuntu/torch/extra/cutorch/lib/THC/generic/THCTensorMath.cu(414): error: more than one operator "==" matches these operands:
            function "operator==(const __half &, const __half &)"
            function "operator==(half, half)"
            operand types are: half == half

[ 15%] Building NVCC (Device) object lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorMathReduce.cu.o
2 errors detected in the compilation of "/tmp/tmpxft_00002141_00000000-4_THCTensorMath.cpp4.ii".
CMake Error at THC_generated_THCTensorMath.cu.o.cmake:267 (message):
  Error generating file
  /home/ubuntu/torch/extra/cutorch/build/lib/THC/CMakeFiles/THC.dir//./THC_generated_THCTensorMath.cu.o


lib/THC/CMakeFiles/THC.dir/build.make:112: recipe for target 'lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorMath.cu.o' failed
make[2]: *** [lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorMath.cu.o] Error 1
make[2]: *** Waiting for unfinished jobs....
^Clib/THC/CMakeFiles/THC.dir/build.make:105: recipe for target 'lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorCopy.cu.o' failed
make[2]: *** [lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorCopy.cu.o] Interrupt
lib/THC/CMakeFiles/THC.dir/build.make:140: recipe for target 'lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorMathPairwise.cu.o' failed
make[2]: *** [lib/THC/CMakeFiles/THC.dir/THC_generated_THCTensorMathPairwise.cu.o] Interrupt
CMakeFiles/Makefile2:172: recipe for target 'lib/THC/CMakeFiles/THC.dir/all' failed
make[1]: *** [lib/THC/CMakeFiles/THC.dir/all] Interrupt
Makefile:127: recipe for target 'all' failed
make: *** [all] Interrupt

Error: Build error: Failed building.
ubuntu@ip-Address:~/torch$

Running ./clean.sh, then setting export TORCH_NVCC_FLAGS="-D__CUDA_NO_HALF_OPERATORS__", and finally running ./install.sh worked!

I was using Ubuntu 16.04.

ProGamerGov commented on Oct 24, 2017

@csarofeen If it's better to disable the half operators, then what are they used for? Why are they included in the cuda code? And what kind of performance boost are we talking about here?

csarofeen (Contributor) commented on Oct 24, 2017

CUDA 9 added half operators to the CUDA half header. Half operations in Torch predate that, so Torch already had its own. The flag keeps the half definition from the CUDA header while not compiling the header's operators.
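
For anyone who wants to see the mechanism in isolation, here is a minimal host-only C++ sketch of it. It is not the real CUDA header or the real cutorch code: mock_half, MOCK_NO_HALF_OPERATORS, same_bits and different_bits are hypothetical names made up for illustration, and comparing raw bits is a simplification. It only assumes what the error message above shows, namely one "==" taking const references and another taking its arguments by value.

```cpp
// Illustrative mock only: mock_half and MOCK_NO_HALF_OPERATORS are
// hypothetical stand-ins, not the real CUDA __half type or macro.
#include <cassert>
#include <cstdint>

// Plays the role of passing -D__CUDA_NO_HALF_OPERATORS__ to the compiler.
// Comment this line out and the comparisons below stop compiling,
// because two equally good overloads become visible.
#define MOCK_NO_HALF_OPERATORS

struct mock_half {
    std::uint16_t x;  // raw 16-bit storage
};

// "Toolkit header" side: operators taking const references, compiled
// only when the opt-out macro is NOT defined.
#ifndef MOCK_NO_HALF_OPERATORS
inline bool operator==(const mock_half& a, const mock_half& b) { return a.x == b.x; }
inline bool operator!=(const mock_half& a, const mock_half& b) { return a.x != b.x; }
#endif

// "Library" side: operators that existed first, taking arguments by value.
// Comparing raw bits is a simplification; a real half comparison has to
// deal with NaN and signed zero.
inline bool operator==(mock_half a, mock_half b) { return a.x == b.x; }
inline bool operator!=(mock_half a, mock_half b) { return a.x != b.x; }

// With the macro defined, only the by-value pair is a candidate here.
inline bool same_bits(mock_half a, mock_half b) { return a == b; }
inline bool different_bits(mock_half a, mock_half b) { return a != b; }
```

The type definition stays visible either way; only the first pair of operators is compiled out, which is the same split the flag makes between the half definition and the header's operators.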

ProGamerGov commented on Oct 25, 2017

@csarofeen Do you have any other performance tips for Cuda and/or cuDNN with Torch7?

Because I've noticed that Cuda 9.0 and cuDNN v7 have even worse performance than Cuda 8.0 and cuDNN v5: jcjohnson/neural-style#429

sfzyk commented on Nov 15, 2017

Same issue, but export TORCH_NVCC_FLAGS="-D__CUDA_NO_HALF_OPERATORS__" didn't work for me. How can I solve that?

csarofeen (Contributor) commented on Nov 16, 2017

@sfzyk Could you please explain all the steps you took to install CUDA, NCCL, cuDNN, and pytorch, and paste some of the error output here? It is very hard to assist when the only information provided is "didn't work".
