Is there a way in Keras to apply different weights to a cost function? #2115

Closed
ayalalazaro opened this issue Mar 29, 2016 · 103 comments

@ayalalazaro

Hi there,
I am trying to implement a classification problem with three classes: 0, 1 and 2. I would like to fine-tune my cost function so that misclassification is weighted somehow. In particular, predicting 1 instead of 2 should cost twice as much as predicting 0. Written in table format, it should look something like this:

Costs (rows: actual class, columns: predicted class):

             Predicted
             0      1      2
Actual 0     0      0.25   0.25
       1     0.25   0      0.5
       2     0.25   0.5    0
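
For reference, a minimal sketch of that cost table as a NumPy array (assuming rows index the actual class and columns the predicted class):

import numpy as np

# rows: actual class, columns: predicted class (values from the table above)
cost_matrix = np.array([[0.00, 0.25, 0.25],
                        [0.25, 0.00, 0.50],
                        [0.25, 0.50, 0.00]])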

I really like the Keras framework; it would be nice if this could be implemented without having to dig into TensorFlow or Theano code.

Thanks

@ayalalazaro
Author

Sorry, the table has lost its format, I am sending an image:
[image of the cost table above]

@carlthome
Contributor

Similar: #2121

@tboquet
Contributor

tboquet commented Mar 29, 2016

You could use class_weight.
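
For example (a minimal sketch; note that class_weight assigns one weight per class, not per (actual, predicted) pair):

# hypothetical weights: make classes 1 and 2 count twice as much as class 0 during training
model.fit(X_train, Y_train, class_weight={0: 1.0, 1: 2.0, 2: 2.0})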

@ayalalazaro
Author

class_weight applies a weight to all data belonging to a class; what I need should depend on the misclassification.

@tboquet
Contributor

tboquet commented Mar 30, 2016

You are absolutely right; I'm sorry I misunderstood your question. I will try to come back with something tomorrow using partial to define the weights. What you want to achieve should be doable with the Keras abstract backend.

@tboquet
Contributor

tboquet commented Mar 31, 2016

Ok, so I had the time to quickly test it.
This is a fully reproducible example on MNIST where we put a higher cost when a 1 is misclassified as a 7 and when a 7 is misclassified as a 1.

So if you want to pass constants included in the cost function, just build a new function with functools.partial.

'''Train a simple deep NN on the MNIST dataset.
Get to 98.40% test accuracy after 20 epochs
(there is *a lot* of margin for parameter tuning).
2 seconds per epoch on a K520 GPU.
'''

from __future__ import print_function
import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD, Adam, RMSprop
from keras.utils import np_utils
import keras.backend as K
from itertools import product

# Custom loss function with costs

def w_categorical_crossentropy(y_true, y_pred, weights):
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
    y_pred_max_mat = K.equal(y_pred, y_pred_max)
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    return K.categorical_crossentropy(y_pred, y_true) * final_mask

w_array = np.ones((10, 10))
w_array[1, 7] = 1.2
w_array[7, 1] = 1.2

ncce = partial(w_categorical_crossentropy, weights=w_array)

batch_size = 128
nb_classes = 10
nb_epoch = 20

# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()

X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

model = Sequential()
model.add(Dense(512, input_shape=(784,)))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(10))
model.add(Activation('softmax'))

rms = RMSprop()
model.compile(loss=ncce, optimizer=rms)

model.fit(X_train, Y_train,
          batch_size=batch_size, nb_epoch=nb_epoch,
          show_accuracy=True, verbose=1,
          validation_data=(X_test, Y_test))
score = model.evaluate(X_test, Y_test,
                       show_accuracy=True, verbose=1)
print('Test score:', score[0])
print('Test accuracy:', score[1])

@ayalalazaro
Author

Wow, that's nice. Thanks for the detailed answer!

@tboquet
Contributor

tboquet commented Apr 1, 2016

Try to test it on a toy example to verify that it actually works. If it's what you are looking for, feel free to close the issue!
Keras 1.0 will provide a more flexible way to introduce new objectives and metrics.

@ayalalazaro
Author

Well, I am stuck; I can't make it run in my model. It says:

line 56, in w_categorical_crossentropy
y_pred_max = K.reshape(y_pred_max, (y_pred.shape[0], 1))

AttributeError: 'Tensor' object has no attribute 'shape'

This is the model I am using:

def w_categorical_crossentropy(y_true, y_pred, weights):
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (y_pred.shape[0], 1))
    y_pred_max_mat = K.equal(y_pred, y_pred_max)
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    return K.categorical_crossentropy(y_pred, y_true) * final_mask

w_array = np.ones((3,3))
w_array[2,1] = 1.2
w_array[1,2] = 1.2
ncce = partial(w_categorical_crossentropy, weights=w_array)

def build_model(X_data):
    data_dim = X_data.shape[2]
    timesteps = X_data.shape[1]
    model = Sequential()
    model.add(BatchNormalization(input_shape=(timesteps, data_dim)))
    model.add(GRU(output_dim=50, init='glorot_normal', return_sequences=True,
                  W_regularizer=l2(0.00), U_regularizer=l1(0.01), dropout_W=0.2))
    model.add(GRU(output_dim=50, init='glorot_normal', return_sequences=True,
                  W_regularizer=l2(0.00), U_regularizer=l1(0.01), dropout_W=0.2))
    model.add(GRU(50, init='glorot_normal', return_sequences=False, dropout_W=0.01,
                  W_regularizer=l2(0.00), U_regularizer=l1(0.01)))
    model.add(Dense(3, init='glorot_normal'))
    model.add(Activation('softmax'))

    model.compile(loss=ncce, optimizer='Adam')
    return model

@tboquet
Contributor

tboquet commented Apr 4, 2016

Sure, sorry, I was using Theano functionality. I replaced the following line in my previous example:

y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))

It should do the trick!

@ayalalazaro
Author

Sounds like the way to go; I was using TensorFlow as the backend. I'll tell you if it works as soon as possible. Thanks!

@ayalalazaro
Author

I still get an error:

line 57, in w_categorical_crossentropy
y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 271, in reshape
return tf.reshape(x, shape)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 682, in reshape
name=name)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 411, in apply_op
as_ref=input_arg.is_ref)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 529, in convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/constant_op.py", line 178, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/constant_op.py", line 161, in constant
tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 319, in make_tensor_proto
_AssertCompatible(values, dtype)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 259, in _AssertCompatible
(dtype.name, repr(mismatch), type(mismatch).name))

TypeError: Expected int32, got list containing Tensors of type '_Message' instead.

I've tried your first reply under the Theano backend and it works, though.

@tboquet
Contributor

tboquet commented Apr 5, 2016

Ok, I was not sure about how K.shape would behave with TensorFlow. It seems you should use:

y_pred_max = K.reshape(y_pred_max, (K.int_shape(y_pred)[0], 1))

@ayalalazaro
Author

I get more or less the same:

line 59, in w_categorical_crossentropy
y_pred_max = K.reshape(y_pred_max, (K.int_shape(y_pred)[0], 1))

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 271, in reshape
return tf.reshape(x, shape)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 682, in reshape
name=name)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 411, in apply_op
as_ref=input_arg.is_ref)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 529, in convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/constant_op.py", line 178, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/constant_op.py", line 161, in constant
tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 319, in make_tensor_proto
_AssertCompatible(values, dtype)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 259, in _AssertCompatible
(dtype.name, repr(mismatch), type(mismatch).name))

TypeError: Expected int32, got None of type '_Message' instead.

It seems like it cannot get the shape of y_pred as an integer, right?

@tboquet
Contributor

tboquet commented Apr 6, 2016

Mm, ok, I will take a look at it today and work directly with tensors to try to find a way to have it work properly for both backends.

@pgallego25

Hi there, I tried something like this:

import tensorflow as tf

def w_categorical_crossentropy(y_true, y_pred, weights):
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, K.shape(y_pred))
    y_pred_max_mat = K.equal(y_pred, y_pred_max)
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):

        final_mask += (K.cast(weights[c_t, c_p],tf.float32) * K.cast(y_pred_max_mat[:, c_p] ,tf.float32)* K.cast(y_true[:, c_t],tf.float32))
    return K.categorical_crossentropy(y_pred, y_true) * final_mask

I think it will do it.

@ayalalazaro
Author

The latter only works for non-recurrent networks, but this code works for RNNs following the same idea. It only works for TensorFlow, though. I couldn't find a way to reshape a tensor the way we want with the Keras backend:

import tensorflow as tf

def w_categorical_crossentropy(y_true, y_pred, weights):
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = tf.expand_dims(y_pred_max, 1)
    y_pred_max_mat = K.equal(y_pred, y_pred_max)
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):

        final_mask += (K.cast(weights[c_t, c_p],K.floatx()) * K.cast(y_pred_max_mat[:, c_p] ,K.floatx())* K.cast(y_true[:, c_t],K.floatx()))
    return K.categorical_crossentropy(y_pred, y_true) * final_mask

@ayalalazaro
Author

My bad, just replacing tf.expand_dims with K.expand_dims worked for me:

def w_categorical_crossentropy(y_true, y_pred, weights):
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.expand_dims(y_pred_max, 1)
    y_pred_max_mat = K.equal(y_pred, y_pred_max)
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):

        final_mask += (K.cast(weights[c_t, c_p],K.floatx()) * K.cast(y_pred_max_mat[:, c_p] ,K.floatx())* K.cast(y_true[:, c_t],K.floatx()))
    return K.categorical_crossentropy(y_pred, y_true) * final_mask
w_array = np.ones((3,3))
w_array[2,1] = 1.2
w_array[1,2] = 1.2
ncce = partial(w_categorical_crossentropy, weights=w_array)
ncce.__name__ = 'w_categorical_crossentropy'

The last line is necessary for the TensorBoard callback to work. Thanks!!

@kimardenmiller

kimardenmiller commented Dec 3, 2016

Is the Mar 31 solution for @ayalalazaro above still recommended as of v1.2? (Noticed @tboquet's comment: Keras 1.0 will provide a more flexible way to introduce new objectives and metrics.)

My problem is binary classification where true positive accuracy is more important, and some false negatives are acceptable. Would I need the approach above to achieve that objective? I tried class_weights = {0: 1, 1: 10}, but saw no change. (examples are 25% positive, 75% negative)

@curiale

curiale commented Jan 20, 2017

Just a small detail about the w_categorical_crossentropy implementation: there is no need to cast weights and y_true. The following code works in Theano and TensorFlow:

def w_categorical_crossentropy(y_true, y_pred, weights):
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
    y_pred_max_mat = K.cast(K.equal(y_pred, y_pred_max), K.floatx())
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    return K.categorical_crossentropy(y_pred, y_true) * final_mask

@jerpint

jerpint commented Feb 20, 2017

Hello, I am trying to implement this in TensorFlow.

I am confused as to what partial is in the line:

ncce = partial(w_categorical_crossentropy, weights=np.ones((10,10)))

I do not see it defined anywhere in this thread, and get

NameError: name 'partial' is not defined

as output...

Thanks

@0x00b1

0x00b1 commented Feb 20, 2017

@jerpint It’s available from functools, i.e.

import functools

ncce = functools.partial(w_categorical_crossentropy, weights=np.ones((10,10)))

@mongoose54

mongoose54 commented Feb 23, 2017

I am trying to incorporate @curiale's w_categorical_crossentropy implementation for a binary classification problem where the output of my model has shape (?, 5120, 2), but I am running into a couple of issues:

  1. Assuming my class weight distribution is e.g. class_weights = [0.85144055, 1.14855945], what should the w_array be like? Something like this below?

w_array = np.ones((2,2))
w_array[1,0] = class_weights[0]
w_array[0,1] = class_weights[1]
ncce = functools.partial(w_categorical_crossentropy, weights=w_array)

  2. When I run model.compile(optimizer=Adam(lr=1e-5), loss=ncce, metrics=[dice_coef]) I get the following error:

ValueError: Dimensions must be equal, but are 5120 and 2 for 'mul_339' (op: 'Mul') with input shapes: [?,5120], [?,2].

These are the variables' shapes inside w_categorical_crossentropy

y_pred shape: (?, 5120, 2) y_true shape: (?, ?, ?) final_mask.shape: (?, 2)

Frankly, I am lost in the w_categorical_crossentropy function (e.g. what is final_mask? What is its shape?). Any help would be much appreciated.

@recluze

recluze commented Feb 23, 2017

Hmm, I'm sorry but I don't quite understand: what does this (?, 5120, 2) entail? If ? is the batch size and 2 is the number of classes, what is 5120?

@mongoose54

@recluze Sorry for the confusion. Let me clarify: The model is an image segmentation network with output (?, 5120, 2) where ? : batch_size , 5120 : total_number_of_pixels_per_image and 2 : classes (foreground, background). So basically the network does classification per pixel.

@eliadl

eliadl commented Aug 27, 2019

TypeError: Value passed to parameter 'x' has DataType bool not in list of allowed values: float16, float32, float64, uint8, int8, uint16, int16, int32, int64, complex64, complex128

@enikkari this error can be resolved by adding another line after:

y_pred_max_mat = K.equal(y_pred, y_pred_max)

as follows:

y_pred_max_mat = K.equal(y_pred, y_pred_max)
y_pred_max_mat = K.cast(y_pred_max_mat, 'float32')

@eliadl

eliadl commented Aug 28, 2019

Also, to prevent a row in y_pred like [.4, .4, .2] being encoded into [1, 1, 0], this:

y_pred_max = K.max(y_pred, axis=1)
y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
y_pred_max_mat = K.equal(y_pred, y_pred_max)

can be replaced with more robust (and intuitive) code:

y_pred_arg_max = K.argmax(y_pred)
y_pred_max_mat = K.one_hot(y_pred_arg_max, num_classes=y_pred.shape[1])

Another benefit of this is that it no longer requires the K.cast fix above.

@eliadl

eliadl commented Sep 12, 2019

Adding to the class solution by @SpikingNeuron in #2115 (comment), here's a more robust and vectorized implementation:

import tensorflow.keras.backend as K
from tensorflow.keras.losses import CategoricalCrossentropy


class WeightedCategoricalCrossentropy(CategoricalCrossentropy):
    
    def __init__(self, cost_mat, name='weighted_categorical_crossentropy', **kwargs):
        assert cost_mat.ndim == 2
        assert cost_mat.shape[0] == cost_mat.shape[1]
        
        super().__init__(name=name, **kwargs)
        self.cost_mat = K.cast_to_floatx(cost_mat)
    
    def __call__(self, y_true, y_pred, sample_weight=None):
        assert sample_weight is None, "should only be derived from the cost matrix"
      
        return super().__call__(
            y_true=y_true,
            y_pred=y_pred,
            sample_weight=get_sample_weights(y_true, y_pred, self.cost_mat),
        )


def get_sample_weights(y_true, y_pred, cost_m):
    num_classes = len(cost_m)

    y_pred.shape.assert_has_rank(2)
    y_pred.shape[1:].assert_is_compatible_with(num_classes)
    y_pred.shape.assert_is_compatible_with(y_true.shape)

    y_pred = K.one_hot(K.argmax(y_pred), num_classes)

    y_true_nk1 = K.expand_dims(y_true, 2)
    y_pred_n1k = K.expand_dims(y_pred, 1)
    cost_m_1kk = K.expand_dims(cost_m, 0)

    sample_weights_nkk = cost_m_1kk * y_true_nk1 * y_pred_n1k
    sample_weights_n = K.sum(sample_weights_nkk, axis=[1, 2])

    return sample_weights_n

Usage:

model.compile(loss=WeightedCategoricalCrossentropy(cost_matrix), ...)

Similarly, this can be applied to the CategoricalAccuracy metric too:

from tensorflow.keras.metrics import CategoricalAccuracy

        
class WeightedCategoricalAccuracy(CategoricalAccuracy):

    def __init__(self, cost_mat, name='weighted_categorical_accuracy', **kwargs):
        assert cost_mat.ndim == 2
        assert cost_mat.shape[0] == cost_mat.shape[1]
        
        super().__init__(name=name, **kwargs)
        self.cost_mat = K.cast_to_floatx(cost_mat)
    
    def update_state(self, y_true, y_pred, sample_weight=None):
        assert sample_weight is None, "should only be derived from the cost matrix"
    
        return super().update_state(
            y_true=y_true,
            y_pred=y_pred,
            sample_weight=get_sample_weights(y_true, y_pred, self.cost_mat),
        )

Usage:

model.compile(metrics=[WeightedCategoricalAccuracy(cost_matrix), ...], ...)

@girigk

girigk commented Sep 23, 2019

In addition to the w_array given by @tboquet in the post above, how do I construct the cost_matrix?
For example, for binary classification:
w_array = np.ones((2, 2))
w_array[1, 0] = 5.0  (to penalize 1s being misclassified)
y_true and y_pred are the targets.

Can somebody help, please?

@ledakk

ledakk commented Dec 3, 2019

@zaher88abd Say that you have 3 available classes.
Then you would start by defining a 3x3 matrix:

w_array = np.ones((3, 3))

Then you can add the weights you'd like to have.
As I said in the comment above,
w_array[i, j] defines the weight for an example of class i falsely classified as class j.

e.g. if you would like to penalize examples of class 1 falsely classified as class 2 more heavily, you could do

w_array[1, 2] = high_weight

If you would like your model to put more overall emphasis on a certain class, you could put high weights on all occurrences of that class.

For example if you'd like to put an overall emphasis on class 2 you could do the following:

w_array[2, :] = high_weight

This will penalize every mistake made with an example with class 2.
But notice that this assignment also includes

w_array[2, 2] = high_weight

This means that this will also penalize an example of class 2 which was labeled correctly but with low confidence.

This behavior may or may not fit your needs.
If you would like to avoid that behavior, you could just do the following:

w_array[2, :] = high_weight
w_array[2, 2] = 1 # restore the original weight 
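
Putting the quoted advice together, a minimal consolidated sketch (high_weight is a hypothetical value):

import numpy as np

high_weight = 5.0
w_array = np.ones((3, 3))
w_array[2, :] = high_weight  # penalize every mistake involving an example of true class 2
w_array[2, 2] = 1.0          # restore the diagonal so correctly classified class-2 examples are not penalized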

@GalAvineri
I want to put an overall emphasis on class 2 (I have 3 classes: 0, 1, 2). In your opinion, I should give w[2][0] and w[2][1] a high weight, but should I assign the same high weight to w[0][2] and w[0][1]??

@dest-dir

dest-dir commented Dec 4, 2019

@eliadl I'm getting an unexpected keyword argument 'sample_weight'
tf python version r1.13

@eliadl

eliadl commented Dec 5, 2019

@dest-dir Please post a StackOverflow question with your code, and share the link here. I'll try to assist there.

@damhurmuller

@eliadl How do I insert the cost matrix into another custom loss, like focal loss?

class FocalLoss(tf.keras.losses.Loss):
    def __init__(self, gamma=2.0, alpha=1.0,
                 reduction=tf.keras.losses.Reduction.AUTO, name='focal_loss'):
        super(FocalLoss, self).__init__(reduction=reduction, name=name)
        self.gamma = float(gamma)
        self.alpha = float(alpha)

    def call(self, y_true, y_pred):
        epsilon = 1.e-9
        y_true = tf.convert_to_tensor(y_true, tf.float32)
        y_pred = tf.convert_to_tensor(y_pred, tf.float32)

        model_out = tf.add(y_pred, epsilon)
        ce = tf.multiply(y_true, -tf.math.log(model_out))
        weight = tf.multiply(y_true, tf.pow(
            tf.subtract(self.alpha, model_out), self.gamma))
        fl = tf.multiply(1., tf.multiply(weight, ce))
        reduced_fl = tf.reduce_max(ce, axis=1)
        return reduced_fl

@eliadl

eliadl commented Dec 18, 2019

@damhurmuller Please post a StackOverflow question with your code, and share the link here. I'll try to assist there.

@mendi80

mendi80 commented Dec 29, 2019

For semantic segmentation, with:
Input (rgb) shape=(batch_size, width, height, 3)
Output (one-hot) shape=(batch_size, width, height, n_classes)
The weighted categorical crossentropy loss function is:

def weighted_categorical_crossentropy(weights):
    # weights = [0.9,0.05,0.04,0.01]
    def wcce(y_true, y_pred):
        Kweights = K.constant(weights)
        if not K.is_tensor(y_pred): y_pred = K.constant(y_pred)
        y_true = K.cast(y_true, y_pred.dtype)
        return K.categorical_crossentropy(y_true, y_pred) * K.sum(y_true * Kweights, axis=-1)
    return wcce

Usage:

loss = weighted_categorical_crossentropy(weights)
optimizer = keras.optimizers.Adam(lr=0.01)
model.compile(optimizer=optimizer, loss=loss)

@yacine074

@mendi80 Please, is your function right?

@PhilAlton

PhilAlton commented May 16, 2020

@dest-dir , @eliadl
I encountered the same unexpected sample weight problem. I also ran into some issues when trying to save the entire model (in order to restore from interrupted training, including the optimizer state).

The sample weight problem seems to be solved by changing the magic function __call__ to call. I also modified the return of call to multiply the output of super().call(y_t, y_p) by the return value of get_sample_weights.

@eliadl - I think your approach, from what I understood, was to overwrite/overload rather than access the categorical crossentropy call method and pass in sample_weight as an expected parameter of this call; however, I couldn't figure out why this worked for you and not for us. (And, frankly, my Python knowledge isn't really up to figuring this out!)

I utilised @SpikingNeuron's class code in order to get this working. I also changed the weights argument from a positional argument to a named argument as part of trying to get the model loading working.

The loss class therefore became:

class weighted_categorical_crossentropy(tensorflow.keras.losses.CategoricalCrossentropy):
    
  def __init__(
      self,
      *,
      weights,
      from_logits=False,
      label_smoothing=0,
      reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
      name='categorical_crossentropy',
  ):

      super().__init__(
          from_logits, label_smoothing, reduction, name=f"weighted_{name}"
      )
      self.weights = weights

  def call(self, y_true, y_pred):
     return super().call(y_true, y_pred) * get_sample_weights(y_true, y_pred, self.weights)

  def get_config(self):
    return {'weights': self.weights}

  @classmethod
  def from_config(cls, config):
    return cls(**config)


def get_sample_weights(y_true, y_pred, cost_m):
    num_classes = len(cost_m)

    cost_m = K.cast(cost_m, 'float32')
    y_pred.shape.assert_has_rank(2)
    assert(y_pred.shape[1] == num_classes)
    y_pred.shape.assert_is_compatible_with(y_true.shape)

    y_pred = K.one_hot(K.argmax(y_pred), num_classes)

    y_true_nk1 = K.expand_dims(y_true, 2)
    y_pred_n1k = K.expand_dims(y_pred, 1)
    cost_m_1kk = K.expand_dims(cost_m, 0)

    sample_weights_nkk = cost_m_1kk * y_true_nk1 * y_pred_n1k
    sample_weights_n = K.sum(sample_weights_nkk, axis=[1, 2])

    return sample_weights_n

Note the inclusion of:

  def get_config(self):
    return {'weights': self.weights}

  @classmethod
  def from_config(cls, config):
    return cls(**config)

This is necessary in order for the custom loss function to be registered with Keras for model saving.
I also included the following (after the class code) to make sure that this registration happens:

tf.keras.losses.weighted_categorical_crossentropy = weighted_categorical_crossentropy

Usage:

model.compile(
    optimizer='adam',
    loss={'output': weighted_categorical_crossentropy(weights=cost_matrix)}
)

Saving:

model.save(filepath, save_format='tf')

Loading:

model = tf.keras.models.load_model(
    filepath,
    compile=True,
    custom_objects={
        'weighted_categorical_crossentropy': weighted_categorical_crossentropy(weights=cost_matrix)
        }
    )

Feedback welcome.
Hope this helps.

@eliadl
Copy link

eliadl commented May 17, 2020

@PhilAlton

  1. __call__ accepts sample_weight and handles it inherently, while call doesn't. You had to provide your own implementation there. I didn't.
  2. __call__ does access the categorical crossentropy call method, as my class inherits from CategoricalCrossentropy which uses the categorical_crossentropy function.
  3. CategoricalCrossentropy.from_config is already implemented (or inherited) so there's no need to override it with the same code.
  4. I'm not sure what you exactly mean by "this doesn't work for us". If you post a link to a StackOverflow question, I'll do my best to answer it without polluting this GitHub issue.
  5. Your override of get_config doesn't account for arguments of base class. This does:
def get_config(self):
    config = super().get_config().copy()
    config.update({'weights': self.weights})
    return config

@PhilAlton
Copy link

PhilAlton commented May 20, 2020

@eliadl - Thanks; SO Question

@eliadl
Copy link

eliadl commented May 20, 2020

@eliadl I'm getting an unexpected keyword argument 'sample_weight'
tf python version r1.13

@dest-dir as @PhilAlton found, the problem was that __call__ didn't match its original signature.

    def __call__(self, y_true, y_pred):

should have been this:

    def __call__(self, y_true, y_pred, sample_weight=None):

@william-allen-harris

Hello, does anyone know how to do this for sparse categorical crossentropy?

@hiyamgh

hiyamgh commented Feb 7, 2021

Hello, thank you for this awesome thread. I have a small question though: I am trying to implement this solution in TensorFlow rather than with the Keras backend. Are these weights applied to the logits (the output of the last layer of the neural network, i.e. raw values to which we did NOT apply softmax), or to the probabilities (AFTER softmax)?

In other words, is K.categorical_crossentropy the equivalent of tf.nn.softmax_cross_entropy_with_logits or not?

Thank you in advance.

def w_categorical_crossentropy(y_true, y_pred, weights):
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
    y_pred_max_mat = K.equal(y_pred, y_pred_max)
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    return K.categorical_crossentropy(y_pred, y_true) * final_mask

@hiyamgh

hiyamgh commented Feb 7, 2021

I guess I found the answer.

I have seen the documentation of tf.keras.backend.categorical_crossentropy and it states the following:

tf.keras.backend.categorical_crossentropy(
    target,
    output,
    from_logits=False
)

target: A tensor of the same shape as output.
output: A tensor resulting from a softmax (unless from_logits is True, in which case output is expected to be the logits).
from_logits: Boolean, whether output is the result of a softmax, or is a tensor of logits.

I will just do it on the logits then?
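
A minimal sketch of that usage (y_true and logits are illustrative names), assuming the model's final layer has no softmax:

# from_logits=True lets the backend apply the softmax internally
loss = K.categorical_crossentropy(y_true, logits, from_logits=True)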

@TejashwiniDuluri

@machisuke shouldn't it be return
K.categorical_crossentropy(y_true, y_pred) * final_mask
instead of
K.categorical_crossentropy(y_pred, y_true) * final_mask

as @blakewest pointed out from the Keras source code?

Yeah, I have the same question on this.

@isaranto

isaranto commented Apr 8, 2021

With the tf.keras implementation I would propose a more vectorized approach (avoiding the for loop):

def weighted_categorical_crossentropy_new(y_true, y_pred, weights):
    idx1 = K.argmax(y_pred, axis=1)
    idx2 = K.argmax(y_true, axis=1)
    mask = tf.gather_nd(weights, tf.stack((idx1, idx2), -1))
    return K.categorical_crossentropy(y_true, y_pred) * mask

You can modify the above to fit your needs, but it worked for me with an example weight matrix for cost-sensitive learning, where certain mispredictions are more/less important than others:

 [[0.0, 1.5, 3.0, 4.5, 6.0],
  [1.0, 0.0, 1.5, 3.0, 4.5],
  [2.0, 1.0, 0.0, 1.5, 3.0],
  [3.0, 2.0, 1.0, 0.0, 1.5],
  [4.0, 3.0, 2.0, 1.0, 0.0]]
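
For anyone wanting to try it, a hedged usage sketch (the variable names and the use of functools.partial are illustrative, mirroring earlier comments in this thread); casting the weights to float32 avoids dtype mismatches under TensorFlow:

import functools
import numpy as np

weights = np.array([[0.0, 1.5, 3.0, 4.5, 6.0],
                    [1.0, 0.0, 1.5, 3.0, 4.5],
                    [2.0, 1.0, 0.0, 1.5, 3.0],
                    [3.0, 2.0, 1.0, 0.0, 1.5],
                    [4.0, 3.0, 2.0, 1.0, 0.0]], dtype='float32')

# model is assumed to be an existing Keras model with a 5-class softmax output
loss = functools.partial(weighted_categorical_crossentropy_new, weights=weights)
loss.__name__ = 'weighted_categorical_crossentropy_new'
model.compile(optimizer='adam', loss=loss)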

@rafaelagrc

Hello. How can we implement this for Sparse Categorical Cross Entropy?

@jsbryaniv

I would also like to know how to implement this for SparseCategoricalCrossentropy.

@RachelRamirez

RachelRamirez commented Feb 14, 2023

With the tf.keras implementation I would propose a more vectorized approach (avoiding the for loop):

def weighted_categorical_crossentropy_new(y_true, y_pred, weights):
    idx1 = K.argmax(y_pred, axis=1)
    idx2 = K.argmax(y_true, axis=1)
    mask = tf.gather_nd(weights, tf.stack((idx1, idx2), -1))
    return K.categorical_crossentropy(y_true, y_pred) * mask

Has anyone verified that the code above works? If so, can they share a minimal working example? @isaranto have you verified that your vectorized approach of the original method works on the MNIST network example as given?

I put a high weight on the misclassification that naturally seems to be highest when running the dense neural network given by @tboquet, and the results are not intuitive: the number of misclassifications does not decrease. I've compared the confusion matrices for weights w[7,9] = 1, 1.1, 1.2, 1.5, 1.7, 2, ... 10, ... 100, and one would expect the number of misclassifications on [7,9] to decrease as the weight increases, but there doesn't seem to be a consistent pattern; if anything, in 7 out of 30 runs the misclassification count for [7,9] increases dramatically (like from 20 to 386). So I tried negative numbers, and that did have the immediate effect of decreasing the misclassification rates. However, using negative numbers isn't consistent with any of the above discussion.

Here's the code I've used - it's long so I posted a link to my Public Google Colab Notebook: https://github.com/RachelRamirez/misclassification_matrix/blob/main/w%5B7%2C9%5D%3D100_Misclassification_Cost_Matrix_Example.ipynb

This is the output of one of the worst confusion matrices (run 14) using w[7,9] = 100. It seems like it's rewarding the misclassification instead of the reverse.
CM:

      0     1     2     3     4     5     6     7     8     9
0   970     2     0     1     2     0     2     0     1     2
1     0  1129     1     1     0     0     2     0     2     0
2     5     2  1005     4     2     0     2     4     7     1
3     0     0     4   975     0    12     0     2     3    14
4     1     0     3     0   956     0     2     0     0    20
5     2     0     0     3     1   881     1     0     0     4
6     3     3     1     0    10    30   909     0     2     0
7     2     6     7     2     5     1     0   661     2   342
8     1     2     2    11     9    20     0     1   898    30
9     0     2     0     2     8     1     0     0     0   996

@isaranto

Hey @RachelRamirez, it has been 2 years since I wrote that comment, so I don't remember it all that well.
The thing is that it works, meaning that training is done right, but what is not trivial at all are the weight values that you are going to put in there. I would suggest playing around with some high and low weights and checking what happens to the confusion matrix and your metrics.

@RachelRamirez

RachelRamirez commented Feb 14, 2023

Thank you for the quick reply. I have played with lots of weights, and all of the numbers seem to reward misclassifications until I use a negative weight, which isn't consistent with the comments above. I wish I could comment on a more recent thread, but this seems to be the only issue that addresses misclassifications and is continually referenced in all the other Keras threads.

@PhilAlton

Hi @RachelRamirez - I verified this class https://stackoverflow.com/a/61963004 extensively at the time... Hopefully it provides a starting point for your specific problem?

@RachelRamirez

@PhilAlton Thanks! I verified your process works in line with how I expected it to work using the MNIST example! If I raise the cost of a misclassification, the resulting costly misclassification goes down. I still wish Keras would make this an easier feature to implement.

@PhilAlton

Yep @RachelRamirez - I remember this being a real pain at the time! Tbh, we might be massively overcomplicating this... Loss functions do take a sample_weight argument, but it's not well documented (imo). It wasn't 100% clear to me if this was equivalent to class weights, plus I only discovered this when I had my own implementation working...

@eliadl

eliadl commented Feb 16, 2023

@PhilAlton Loss functions support a sample_weight argument only in their __call__ method, but not in __init__. (example)

That's basically why we needed this #2115 (comment).
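
A minimal sketch of that distinction (variable names are illustrative):

cce = tf.keras.losses.CategoricalCrossentropy()
loss_value = cce(y_true, y_pred, sample_weight=per_sample_weights)  # accepted by __call__
# tf.keras.losses.CategoricalCrossentropy(sample_weight=...)        # no such constructor argument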

@PhilAlton

PhilAlton commented Feb 16, 2023

@eliadl - ah yes, it's all coming back to me now! @RachelRamirez - if you were sufficiently motivated, you could raise a pull request to get this included... Not something I've done before! (Imbalanced problems are very common, though accessing this via call is clearly TF/Keras' preferred approach, e.g. https://keras.io/examples/structured_data/imbalanced_classification/ - though it's not intuitive that the weights should be passed through model.fit.)
