Xử lí missing labels trong bài toán Multi-task Learning

deep-learning

#1

Xin chào mọi người, em là một người mới trong lĩnh vực Computer Vision. Hiện tại em đang làm 1 bài toán multi-task classification learning với 3 task (task 1: 7 class, task 2: 3 class, task 3: 3 class), train với mô hình có backbone là VGG16. Tuy nhiên có 1 số image data chỉ có label trong task 3 chứ không có label task 1,2. Theo như em tìm hiểu thì có 1 số cách để train với missing data trong Keras/Tensorflow như là thiết lập mask value: đặt các data với missing label là -1, sau đó viết hàm loss và accuracy riêng. Em đã làm thử như dưới đây (label được thiết lập dưới dạng sparse:0,1,2,3…)

# Loss function for unmasked value

def masked_loss_function(y_true, y_pred):
valid_idxs = tf.where(y_true > -1)[:, 0]
valid_logits = tf.gather(y_pred, valid_idxs)
valid_labels = tf.gather(y_true, valid_idxs)
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=valid_labels, logits=valid_logits)
return loss

# Accuracy function for unmasked value

def masked_accuracy(y_true, y_pred):
valid_idxs = tf.where(y_true > -1)[:, 0]
valid_logits = tf.gather(y_pred, valid_idxs)
valid_labels = tf.gather(y_true, valid_idxs)
total = K.sum(valid_labels)
correct = K.sum(valid_logits)
return correct / total

Tuy nhiên vẫn không thành công. Không biết có ai đã có kinh nghiệm trong bài toán này chưa, có thể chia sẻ kinh nghiệm implement được không ạ?