Tensorflow中应该如何精确地进行Tensor掩蔽和索引编制?

狮子座

我已经使用TF已有2年了,在每个项目中,我都会冒出很多无意义的错误来掩盖信息,这些错误通常无济于事,也没有指出实际上是什么错误。或更糟糕的是,结果是错误的,但没有错误。我总是在训练循环之外使用伪数据测试代码,这很好。但是在训练(称合适)中,我不明白TensorFlow到底期望什么。仅举一个例子,有经验的人可以告诉我为什么此代码对于二进制交叉熵不起作用,结果是错误的,并且在这种情况下模型不收敛但没有错误:

class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, y_true, y_pred):
        y_true = tf.squeeze(y_true)
        mask = tf.where(y_true!=2)
        y_true = tf.gather_nd(y_true, mask)
        y_pred = tf.gather_nd(y_pred, mask)
        loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(loss)

虽然这可以正常工作:

class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, y_true, y_pred):
        mask = tf.where(y_true!=2, True, False)
        y_true = y_true[mask]
        y_pred = y_pred[mask]
        loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(loss)

对于一个明确的例子,情况恰恰相反。我无法将遮罩用作y_pred [mask]或y_pred [mask [0]]之类的索引,也不能使用tf.squeeze()等。但是使用tf.gather_nd()可以。我总是尝试我认为可能的所有组合,但我不明白为什么这么简单的事情会如此艰巨而痛苦。派托克也是这样吗?如果您知道Pytorch没有类似的烦人细节,我们很乐意切换。

编辑1:它们可以在训练循环或图形模式之外正常工作,更准确地说。

y_pred = tf.random.uniform(shape=[10,], minval=0, maxval=1, dtype='float32')
y_true = tf.random.uniform(shape=[10,], minval=0, maxval=2, dtype='int32')

# first method
class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    def call(self, y_true, y_pred):
        y_true = tf.squeeze(y_true)
        mask = tf.where(y_true!=2)
        y_true = tf.gather_nd(y_true, mask)
        y_pred = tf.gather_nd(y_pred, mask)
        loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(loss)

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}

# instantiate
mbxe = MaskedBXE()
print(f'first snippet: {mbxe(y_true, y_pred).numpy()}')


# second method
class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, y_true, y_pred):
        mask = tf.where(y_true!=2, True, False)
        y_true = y_true[mask]
        y_pred = y_pred[mask]
        loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return tf.reduce_mean(loss)
    
    def get_config(self):
        base_config = super().get_config()
        return {**base_config}
    
# instantiate
mbxe = MaskedBXE()
print(f'second snippet: {mbxe(y_true, y_pred).numpy()}')

第一个片段:1.2907861471176147

第二段:1.2907861471176147

编辑2:在以@jdehesa的建议方式在图形模式下打印损失后,它们有所不同,它们不应这样做:

class MaskedBXE(tf.keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, y_true, y_pred):
        # first
        y_t = tf.squeeze(y_true)
        mask = tf.where(y_t!=2)
        y_t = tf.gather_nd(y_t, mask)
        y_p = tf.gather_nd(y_pred, mask)
        loss = tf.keras.losses.binary_crossentropy(y_t, y_p)
        first_loss =  tf.reduce_mean(loss)
        tf.print('first:')
        tf.print(first_loss, summarize=-1)
        # second
        mask = tf.where(y_true!=2, True, False)
        y_t = y_true[mask]
        y_p = y_pred[mask]
        loss = tf.keras.losses.binary_crossentropy(y_t, y_p)
        second_loss = tf.reduce_mean(loss)
        tf.print('second:')
        tf.print(second_loss, summarize=-1)
        return second_loss

第一:

0.814215422

第二:

0.787778914

第一:

0.779697835

第二:

0.802924752

Jdehesa

我认为问题是您无意中在第一个版本中执行广播操作,这会给您带来错误的结果。如果批次(?, 1)由于tf.squeeze操作具有形状则会发生这种情况注意此示例中的形状

import tensorflow as tf

# Make random y_true and y_pred with shape (10, 1)
tf.random.set_seed(10)
y_true = tf.dtypes.cast(tf.random.uniform((10, 1), 0, 3, dtype=tf.int32), tf.float32)
y_pred = tf.random.uniform((10, 1), 0, 1, dtype=tf.float32)

# first
y_t = tf.squeeze(y_true)
mask = tf.where(y_t != 2)
y_t = tf.gather_nd(y_t, mask)
tf.print(tf.shape(y_t))
# [7]
y_p = tf.gather_nd(y_pred, mask)
tf.print(tf.shape(y_p))
# [7 1]
loss = tf.keras.losses.binary_crossentropy(y_t, y_p)
first_loss =  tf.reduce_mean(loss)
tf.print(tf.shape(loss), summarize=-1)
# [7]
tf.print(first_loss, summarize=-1)
# 0.884061277

# second
mask = tf.where(y_true!=2, True, False)
y_t = y_true[mask]
tf.print(tf.shape(y_t))
# [7]
y_p = y_pred[mask]
tf.print(tf.shape(y_p))
# [7]
loss = tf.keras.losses.binary_crossentropy(y_t, y_p)
tf.print(tf.shape(loss), summarize=-1)
# []
second_loss = tf.reduce_mean(loss)
tf.print(second_loss, summarize=-1)
# 1.15896356

在第一个版本,两者y_ty_p成为广播到7x7的张量,因此交叉熵基本上是计算“所有VS一切”,然后取平均值。在第二种情况下,仅对每对对应的值计算交叉熵,这是正确的做法。

如果仅tf.squeeze在上面的示例中删除了该操作,则结果将得到纠正。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

如何精确地在HTML页面中居中引导网格?

来自分类Dev

如何精确地实现$ .width()

来自分类Dev

我应该如何精确地以Django形式实现条件逻辑?

来自分类Dev

Android如何精确地链接AIDL接口和远程服务?

来自分类Dev

如何精确地在自定义乐器上绘制先前的高点和低点?

来自分类Dev

如何使用regexp更精确地指定模式

来自分类Dev

如何完全精确地设置半径线性梯度%

来自分类Dev

如何将值精确地加或减1

来自分类Dev

gsl如何精确地实现替代的cblas链接功能?

来自分类Dev

2>&1如何精确地流水线工作?

来自分类Dev

在 Java 中精确地乘以两个数字

来自分类Dev

使用S3接收器进行流式传输时精确地进行Flink

来自分类Dev

不精确地获得回报

来自分类Dev

对n个数字进行排序:精确地需要多少个查询?

来自分类Dev

如果X中有多个Y,Python如何精确地解析“从X导入Y”?

来自分类Dev

Java 8 mapToInt(mapToInt(e-> e))如何精确地提高性能?

来自分类Dev

OpenGL如何精确地校正线性插值?

来自分类Dev

如何精确地将孩子的div调整为父div

来自分类Dev

您如何精确地使MySQL记录到期(或更新)到到期时间?

来自分类Dev

如何将本地Hadoop配置精确地模拟为GCP Dataproc

来自分类Dev

如何通过SQL精确地按连接的行查找行?

来自分类Dev

在打印浮点数时,如何更精确地打印小数位?

来自分类Dev

如何制作全天变化的墙纸-精确地调节亮度

来自分类Dev

在精确地执行给定操作k次后,如何找到可以得到的不同数组的总数?

来自分类Dev

百分比如何精确地获得顶部或底部等属性的参考高度?

来自分类Dev

Excel公式以精确地获取月份中两个日期之间的天数

来自分类Dev

如何精确地选择数字的一部分作为简单字符并排除CMD中的所有其余字符?

来自分类Dev

如何在Magento EE 1.13中进行手动重新索引编制?

来自分类Dev

Python 2.7:持续进行搜索和索引编制

Related 相关文章

  1. 1

    如何精确地在HTML页面中居中引导网格?

  2. 2

    如何精确地实现$ .width()

  3. 3

    我应该如何精确地以Django形式实现条件逻辑?

  4. 4

    Android如何精确地链接AIDL接口和远程服务?

  5. 5

    如何精确地在自定义乐器上绘制先前的高点和低点?

  6. 6

    如何使用regexp更精确地指定模式

  7. 7

    如何完全精确地设置半径线性梯度%

  8. 8

    如何将值精确地加或减1

  9. 9

    gsl如何精确地实现替代的cblas链接功能?

  10. 10

    2>&1如何精确地流水线工作?

  11. 11

    在 Java 中精确地乘以两个数字

  12. 12

    使用S3接收器进行流式传输时精确地进行Flink

  13. 13

    不精确地获得回报

  14. 14

    对n个数字进行排序:精确地需要多少个查询?

  15. 15

    如果X中有多个Y,Python如何精确地解析“从X导入Y”?

  16. 16

    Java 8 mapToInt(mapToInt(e-> e))如何精确地提高性能?

  17. 17

    OpenGL如何精确地校正线性插值?

  18. 18

    如何精确地将孩子的div调整为父div

  19. 19

    您如何精确地使MySQL记录到期(或更新)到到期时间?

  20. 20

    如何将本地Hadoop配置精确地模拟为GCP Dataproc

  21. 21

    如何通过SQL精确地按连接的行查找行?

  22. 22

    在打印浮点数时,如何更精确地打印小数位?

  23. 23

    如何制作全天变化的墙纸-精确地调节亮度

  24. 24

    在精确地执行给定操作k次后,如何找到可以得到的不同数组的总数?

  25. 25

    百分比如何精确地获得顶部或底部等属性的参考高度?

  26. 26

    Excel公式以精确地获取月份中两个日期之间的天数

  27. 27

    如何精确地选择数字的一部分作为简单字符并排除CMD中的所有其余字符?

  28. 28

    如何在Magento EE 1.13中进行手动重新索引编制?

  29. 29

    Python 2.7:持续进行搜索和索引编制

热门标签

归档