Is there a way in Keras to apply different weights to a cost function? #2115

Closed

Description

ayalalazaro

Hi there,
I am trying to implement a classification problem with three classes: 0, 1 and 2. I would like to fine-tune my cost function so that misclassification is weighted somehow. In particular, predicting 1 instead of 2 should incur twice the cost of predicting 0. Written as a table, it should look something like this:

Costs (rows = actual class, columns = predicted class):

            Predicted 0   Predicted 1   Predicted 2
Actual 0        0             0.25          0.25
Actual 1        0.25          0             0.5
Actual 2        0.25          0.5           0

I really like the Keras framework; it would be nice if this could be implemented without having to dig into TensorFlow or Theano code.

Thanks

Activity

ayalalazaro (Author) commented on Mar 29, 2016

Sorry, the table has lost its format; I am attaching an image instead:
[image: the cost table]

carlthome (Contributor) commented on Mar 29, 2016

Similar: #2121

tboquet (Contributor) commented on Mar 29, 2016

You could use class_weight.

ayalalazaro (Author) commented on Mar 29, 2016

class_weight applies a weight to all data belonging to a class; what I need is a weight that depends on the particular misclassification.
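
For reference, a minimal sketch of what class_weight actually does (an illustration with a made-up toy model on a recent tf.keras API, not code from this thread): it scales the loss of every sample of a class uniformly, so it cannot express per-misclassification costs.

import numpy as np
from tensorflow import keras

# Toy three-class setup; the data and model are purely illustrative.
X_train = np.random.rand(128, 4).astype('float32')
y_train = np.random.randint(0, 3, size=(128,))

model = keras.Sequential([keras.layers.Dense(3, input_shape=(4,), activation='softmax')])
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')

# class_weight multiplies the loss of every sample of a given class by a constant,
# regardless of which wrong class the model predicted for it.
model.fit(X_train, y_train, class_weight={0: 1.0, 1: 1.0, 2: 2.0}, epochs=1, verbose=0)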

tboquet (Contributor) commented on Mar 30, 2016

You are absolutely right, I'm sorry I misunderstood your question. I will try to come back with something tomorrow, using partial to define the weights. What you want to achieve should be doable with Keras' abstract backend.

tboquet (Contributor) commented on Mar 31, 2016

OK, so I had time to quickly test it.
This is a fully reproducible example on MNIST where we put a higher cost on misclassifying a 1 as a 7 and a 7 as a 1.

So if you want to pass constants into the cost function, just build a new function with partial.

'''Train a simple deep NN on the MNIST dataset.
Get to 98.40% test accuracy after 20 epochs
(there is *a lot* of margin for parameter tuning).
2 seconds per epoch on a K520 GPU.
'''

from __future__ import print_function
import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD, Adam, RMSprop
from keras.utils import np_utils
import keras.backend as K
from itertools import product
from functools import partial

# Custom loss function with costs

def w_categorical_crossentropy(y_true, y_pred, weights):
    # Scale the usual crossentropy by a per-sample weight taken from the
    # (true class, predicted class) entry of the weight matrix.
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
    y_pred_max_mat = K.equal(y_pred, y_pred_max)  # one-hot mask of the predicted class
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    return K.categorical_crossentropy(y_pred, y_true) * final_mask

w_array = np.ones((10, 10))
w_array[1, 7] = 1.2
w_array[7, 1] = 1.2

ncce = partial(w_categorical_crossentropy, weights=w_array)

batch_size = 128
nb_classes = 10
nb_epoch = 20

# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()

X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

model = Sequential()
model.add(Dense(512, input_shape=(784,)))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(10))
model.add(Activation('softmax'))

rms = RMSprop()
model.compile(loss=ncce, optimizer=rms)

model.fit(X_train, Y_train,
          batch_size=batch_size, nb_epoch=nb_epoch,
          show_accuracy=True, verbose=1,
          validation_data=(X_test, Y_test))
score = model.evaluate(X_test, Y_test,
                       show_accuracy=True, verbose=1)
print('Test score:', score[0])
print('Test accuracy:', score[1])

ayalalazaro (Author) commented on Apr 1, 2016

Wow, that s nice. Thanks for the detailed answer!

tboquet (Contributor) commented on Apr 1, 2016

Try to test it on a toy example to verify that it actually works. If it's what you are looking for, feel free to close the issue!
Keras 1.0 will provide a more flexible way to introduce new objectives and metrics.

ayalalazaro (Author) commented on Apr 2, 2016

Well, I am stuck; I can't make it run in my model. It says:

line 56, in w_categorical_crossentropy
y_pred_max = K.reshape(y_pred_max, (y_pred.shape[0], 1))

AttributeError: 'Tensor' object has no attribute 'shape'

This is the model I am using:

def w_categorical_crossentropy(y_true, y_pred, weights):
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (y_pred.shape[0], 1))
    y_pred_max_mat = K.equal(y_pred, y_pred_max)
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    return K.categorical_crossentropy(y_pred, y_true) * final_mask

w_array = np.ones((3,3))
w_array[2,1] = 1.2
w_array[1,2] = 1.2
ncce = partial(w_categorical_crossentropy, weights=w_array)

def build_model(X_data):
    data_dim = X_data.shape[2]
    timesteps = X_data.shape[1]
    model = Sequential()
    model.add(BatchNormalization(input_shape = (timesteps,data_dim)))  
    model.add(GRU(output_dim=50, init='glorot_normal', return_sequences=True,
                  W_regularizer=l2(0.00), U_regularizer=l1(0.01), dropout_W=0.2))
    model.add(GRU(output_dim=50, init='glorot_normal', return_sequences=True,
                  W_regularizer=l2(0.00), U_regularizer=l1(0.01), dropout_W=0.2))
    model.add(GRU(50, init='glorot_normal', return_sequences=False,
                  W_regularizer=l2(0.00), U_regularizer=l1(0.01), dropout_W=0.01))
    model.add(Dense(3, init='glorot_normal'))
    model.add(Activation('softmax'))

    model.compile(loss=ncce,
              optimizer='Adam'
            )
    return model

tboquet (Contributor) commented on Apr 4, 2016

Sure, sorry, I was using Theano functionality. I replaced the following line in my previous example:

y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))

It should do the trick!
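
(As a side note, not from this thread: a backend-agnostic alternative is K.expand_dims, which avoids looking up the batch size altogether and so sidesteps the Theano/TensorFlow difference.)

# Hypothetical alternative to the reshape line above; K.expand_dims adds the
# trailing axis without needing the symbolic batch size.
y_pred_max = K.max(y_pred, axis=1)
y_pred_max = K.expand_dims(y_pred_max, -1)  # shape (batch_size, 1)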

ayalalazaro (Author) commented on Apr 5, 2016

Sounds like the way to go; I was using TensorFlow as the backend. I'll tell you if it works as soon as possible. Thanks!

ayalalazaro (Author) commented on Apr 5, 2016

I still get an error:

line 57, in w_categorical_crossentropy
y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 271, in reshape
return tf.reshape(x, shape)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 682, in reshape
name=name)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 411, in apply_op
as_ref=input_arg.is_ref)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 529, in convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/constant_op.py", line 178, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/constant_op.py", line 161, in constant
tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 319, in make_tensor_proto
_AssertCompatible(values, dtype)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 259, in _AssertCompatible
(dtype.name, repr(mismatch), type(mismatch).__name__))

TypeError: Expected int32, got list containing Tensors of type '_Message' instead.

I've tried your first reply under the Theano backend and it works, though.

tboquet (Contributor) commented on Apr 5, 2016

Ok, I was not sure about how K.shape would behave with TensorFlow. It seems you should use:

y_pred_max = K.reshape(y_pred_max, (K.int_shape(y_pred)[0], 1))

ayalalazaro (Author) commented on Apr 6, 2016

I get more or less the same:

line 59, in w_categorical_crossentropy
y_pred_max = K.reshape(y_pred_max, (K.int_shape(y_pred)[0], 1))

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 271, in reshape
return tf.reshape(x, shape)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 682, in reshape
name=name)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 411, in apply_op
as_ref=input_arg.is_ref)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 529, in convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/constant_op.py", line 178, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/ops/constant_op.py", line 161, in constant
tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 319, in make_tensor_proto
_AssertCompatible(values, dtype)

File "/home/hal/anaconda2/envs/tflow/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 259, in _AssertCompatible
(dtype.name, repr(mismatch), type(mismatch).__name__))

TypeError: Expected int32, got None of type '_Message' instead.

It seems like it cannot get the shape of y_pred as an integer, right?

tboquet (Contributor) commented on Apr 6, 2016

Mm, OK, I will take a look at it today and work directly with tensors to try to find a way to have it work properly for both backends.

83 remaining items

william-allen-harris commented on Sep 13, 2020

Hello, does anyone know how to do this for sparse categorical crossentropy?

hiyamgh commented on Feb 7, 2021

Hello, thank you for this awesome thread. I have a small question, though: I am trying to implement this solution in TensorFlow rather than with the Keras backend. Is the loss applied to the logits (the raw output of the last layer, to which softmax has NOT been applied) or to the probabilities (AFTER softmax)?

In other words, is K.categorical_crossentropy the equivalent of tf.nn.softmax_cross_entropy_with_logits or not?

Thank you in advance.

def w_categorical_crossentropy(y_true, y_pred, weights):
    nb_cl = len(weights)
    final_mask = K.zeros_like(y_pred[:, 0])
    y_pred_max = K.max(y_pred, axis=1)
    y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
    y_pred_max_mat = K.equal(y_pred, y_pred_max)
    for c_p, c_t in product(range(nb_cl), range(nb_cl)):
        final_mask += (weights[c_t, c_p] * y_pred_max_mat[:, c_p] * y_true[:, c_t])
    return K.categorical_crossentropy(y_pred, y_true) * final_mask

hiyamgh commented on Feb 7, 2021

I guess I found the answer.

I have seen the documentation of tf.keras.backend.categorical_crossentropy and it states the following:

tf.keras.backend.categorical_crossentropy(
    target,
    output,
    from_logits=False
)

target: A tensor of the same shape as output.
output: A tensor resulting from a softmax (unless from_logits is True, in which case output is expected to be the logits).
from_logits: Boolean, whether output is the result of a softmax, or is a tensor of logits.

I will just do it on the logits then?
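
A small sketch of the two call styles (assuming TensorFlow 2.x; the tensors are made-up values):

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])
labels = tf.constant([[1.0, 0.0, 0.0]])

# Pass raw logits and let the backend apply the softmax internally...
loss_from_logits = tf.keras.backend.categorical_crossentropy(labels, logits, from_logits=True)

# ...or apply the softmax yourself and keep the default from_logits=False.
probs = tf.nn.softmax(logits)
loss_from_probs = tf.keras.backend.categorical_crossentropy(labels, probs)

# The two agree, which is the same behaviour tf.nn.softmax_cross_entropy_with_logits
# gives when fed the logits directly.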

TejashwiniDuluri commented on Feb 10, 2021

@machisuke shouldn't it be
return K.categorical_crossentropy(y_true, y_pred) * final_mask
instead of
return K.categorical_crossentropy(y_pred, y_true) * final_mask

as @blakewest pointed out from the Keras source code?

Yeah, I have the same question about this.

isaranto commented on Apr 8, 2021

With the tf.keras implementation I would propose a more vectorized approach (avoiding the for loop):

def weighted_categorical_crossentropy_new(y_true, y_pred, weights):
    idx1 = K.argmax(y_pred, axis=1)
    idx2 = K.argmax(y_true, axis=1)
    mask = tf.gather_nd(weights, tf.stack((idx1, idx2), -1))
    return K.categorical_crossentropy(y_true, y_pred) * mask

You can modify the above to fit your needs, but it worked for me with an example weight matrix like the one below, for cost-sensitive learning where certain mispredictions are more or less important than others (a usage sketch follows the matrix).

 [[0.0, 1.5, 3.0, 4.5, 6.0],
  [1.0, 0.0, 1.5, 3.0, 4.5],
  [2.0, 1.0, 0.0, 1.5, 3.0],
  [3.0, 2.0, 1.0, 0.0, 1.5],
  [4.0, 3.0, 2.0, 1.0, 0.0]]
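
One way to plug a matrix like this into training is a closure over the weights (a sketch only, assuming a tf.keras model with five softmax outputs; the model variable is hypothetical):

import tensorflow as tf
from tensorflow.keras import backend as K

cost_matrix = tf.constant(
    [[0.0, 1.5, 3.0, 4.5, 6.0],
     [1.0, 0.0, 1.5, 3.0, 4.5],
     [2.0, 1.0, 0.0, 1.5, 3.0],
     [3.0, 2.0, 1.0, 0.0, 1.5],
     [4.0, 3.0, 2.0, 1.0, 0.0]], dtype=tf.float32)

def make_weighted_loss(weights):
    # Bind the weight matrix so the returned function has the (y_true, y_pred)
    # signature that model.compile expects.
    def loss(y_true, y_pred):
        idx_pred = K.argmax(y_pred, axis=1)
        idx_true = K.argmax(y_true, axis=1)
        mask = tf.gather_nd(weights, tf.stack((idx_pred, idx_true), -1))
        return K.categorical_crossentropy(y_true, y_pred) * mask
    return loss

# model.compile(optimizer='adam', loss=make_weighted_loss(cost_matrix))  # 'model' assumed to exist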

rafaelagrc commented on Apr 12, 2021

Hello. How can we implement this for Sparse Categorical Cross Entropy?

jsbryaniv commented on Nov 17, 2021

I would also like to know how to implement this for SparseCategoricalCrossEntropy

RachelRamirez commented on Feb 14, 2023

With the tf.keras implementation I would propose a more vectorized approach (avoiding the for loop):

def weighted_categorical_crossentropy_new(y_true, y_pred, weights):
    idx1 = K.argmax(y_pred, axis=1)
    idx2 = K.argmax(y_true, axis=1)
    mask = tf.gather_nd(weights, tf.stack((idx1, idx2), -1))
    return K.categorical_crossentropy(y_true, y_pred) * mask

Has anyone verified that the code above works? If so, can they share a minimal working example? @isaranto, have you verified that your vectorized version of the original method works on the MNIST network example as given? I put a high weight on the misclassification that naturally seems to be the most frequent when running the dense neural network given by @tboquet, and the results are not intuitive: the number of misclassifications does not decrease. I compared the confusion matrices for weights w[7,9] = 1, 1.1, 1.2, 1.5, 1.7, 2, ... 10, ... 100, and one would expect the number of misclassifications at [7,9] to decrease as the weight increases, but there doesn't seem to be a consistent pattern; if anything, in 7 of the 30 times I ran it, the misclassification count for [7,9] increased dramatically (e.g. from 20 to 386). So I tried negative numbers, and that did have the immediate effect of decreasing the misclassification rates. However, using negative numbers isn't consistent with any of the above discussion.

Here's the code I've used; it's long, so I posted a link to my public Google Colab notebook: https://github.com/RachelRamirez/misclassification_matrix/blob/main/w%5B7%2C9%5D%3D100_Misclassification_Cost_Matrix_Example.ipynb

This is the output of one of the worst confusion matrices (run 14) using w[7,9]=100. It seems like it's rewarding the misclassification instead of penalizing it.
CM (rows = actual digit, columns = predicted digit):

        0     1     2     3     4     5     6     7     8     9
0     970     2     0     1     2     0     2     0     1     2
1       0  1129     1     1     0     0     2     0     2     0
2       5     2  1005     4     2     0     2     4     7     1
3       0     0     4   975     0    12     0     2     3    14
4       1     0     3     0   956     0     2     0     0    20
5       2     0     0     3     1   881     1     0     0     4
6       3     3     1     0    10    30   909     0     2     0
7       2     6     7     2     5     1     0   661     2   342
8       1     2     2    11     9    20     0     1   898    30
9       0     2     0     2     8     1     0     0     0   996

isaranto commented on Feb 14, 2023

Hey @RachelRamirez, it has been 2 years since I wrote that comment, so I don't remember it all that well.
The thing is that it works, meaning that training is done right, but the weight values you put in there are not trivial at all. I think playing around with some high and low weights and checking what happens to the confusion matrix and your metrics is the way to go.

RachelRamirez commented on Feb 14, 2023

Thank you for the quick reply. I have played with lots of weights, and all of the numbers seem to reward misclassifications until I use a negative weight, which isn't consistent with the comments above. I wish I could comment on a more recent thread, but this seems to be the only issue that addresses misclassification costs, and it is continually referenced in all the other Keras threads.

PhilAlton commented on Feb 14, 2023

Hi @RachelRamirez - I verified this class https://stackoverflow.com/a/61963004 extensively at the time... Hopefully it provides a starting point for your specific problem?

RachelRamirez commented on Feb 15, 2023

@PhilAlton Thanks! I verified that your process works in line with how I expected it to on the MNIST example: if I raise the cost of a misclassification, that misclassification becomes less frequent. I still wish Keras would make this an easier feature to implement.

PhilAlton commented on Feb 16, 2023

Yep @RachelRamirez - I remember this being a real pain at the time! Tbh, we might be massively overcomplicating this... Loss functions do take a sample_weight argument, but it's not well documented (imo). It wasn't 100% clear to me whether this was equivalent to class weights, plus I only discovered it when I had my own implementation working...

eliadl commented on Feb 16, 2023

@PhilAlton Loss functions support a sample_weight argument only in their __call__ method, but not in __init__. (example)

That's basically why we needed this #2115 (comment).
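
For illustration, a minimal sketch of that distinction (assuming the tf.keras losses API; the values are made up):

import tensorflow as tf

loss_fn = tf.keras.losses.CategoricalCrossentropy()

y_true = tf.constant([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
y_pred = tf.constant([[0.1, 0.8, 0.1], [0.3, 0.4, 0.3]])

# sample_weight is accepted per call; there is no constructor argument that
# would let you bake a fixed misclassification matrix into the loss object.
weighted = loss_fn(y_true, y_pred, sample_weight=tf.constant([1.0, 2.0]))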

PhilAlton commented on Feb 16, 2023

@eliadl - ah yes, it's all coming back to me now! @RachelRamirez - if you were sufficiently motivated, you could raise a pull request to get this included... Not something I've done before! (Imbalanced problems are very common, though accessing this via call is clearly TF/Keras' preferred approach, e.g. https://keras.io/examples/structured_data/imbalanced_classification/ - though it's not intuitive that the weights should be passed through model.fit.)
