Tensorflow 2.x –具有周围单元格平均值的张量

弗里曼

我正在尝试在Tensorflow 2.x中编写自定义损失函数,以鼓励在输出空间(2D矩阵)中进行渐变。因此,作为损失函数的一个组成部分,我想获取一个Tensor并返回一个Tensor,其中每个像元代表原始张量中相应相邻像元的平均值。

调和矩阵变换

例如,以左上方的单元格为例:6.3 =(7 + 9 + 3)/ 3。或取中间格:4.5 =(1 + 3 + 5 + 7 + 8 + 6 + 4 + 2)/ 8。

考虑以下代码:

def gradient_encouraging_loss(y_true: Tensor, y_pred: Tensor) -> Tensor:
    gradient_loss: Tensor = tf.divide(tf.reduce_sum(tf.abs(tf.subtract(
        y_pred, tensor_harmonic(y_pred)
    ))), tf.cast(tf.size(y_pred), tf.float32))

    return gradient_loss

我将如何实施tensor_harmonic()y_pred形状为(None, X, Y),其中X和Y为输出矩阵尺寸。

Jdehesa

大部分情况下,您可以使用2D卷积运算来完成此操作,但是随后您需要特别注意外部值。您可以按照以下方法进行操作:

import tensorflow as tf

def surround_average(x):
    x = tf.convert_to_tensor(x)
    dt = x.dtype
    # Compute surround sum
    filter = tf.constant([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=dt)
    x2 = x[tf.newaxis, :, :, tf.newaxis]
    filter2 = filter[:, :, tf.newaxis, tf.newaxis]
    y2 = tf.nn.conv2d(x2, filter2, strides=1, padding='SAME')
    y = y2[0, :, :, 0]
    # Make matrix of number of surrounding elements
    s = tf.shape(x)
    d = tf.fill(s - 2, tf.constant(8, dtype=dt))
    d = tf.pad(d, [[0, 0], [1, 1]], constant_values=5)
    top_row = tf.concat([[3], tf.fill([s[1] - 2], tf.constant(5, dtype=dt)), [3]], axis=0)
    d = tf.concat([[top_row], d, [top_row]], axis=0)
    # Return average
    return y / d

# Test
x = tf.reshape(tf.range(24.), (4, 6))
print(x.numpy())
# [[ 0.  1.  2.  3.  4.  5.]
#  [ 6.  7.  8.  9. 10. 11.]
#  [12. 13. 14. 15. 16. 17.]
#  [18. 19. 20. 21. 22. 23.]]
print(surround_average(x).numpy())
# [[ 4.6666665  4.6        5.6        6.6        7.6        8.333333 ]
#  [ 6.6        7.         8.         9.        10.        10.4      ]
#  [12.6       13.        14.        15.        16.        16.4      ]
#  [14.666667  15.4       16.4       17.4       18.4       18.333334 ]]

编辑:上面的代码可以进行调整,以处理具有少量微小更改的矩阵批次:

import tensorflow as tf

def surround_average_batch(x):
    x = tf.convert_to_tensor(x)
    dt = x.dtype
    # Compute surround sum
    filter = tf.constant([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=dt)
    x2 = tf.expand_dims(x, axis=-1)
    filter2 = filter[:, :, tf.newaxis, tf.newaxis]
    y2 = tf.nn.conv2d(x2, filter2, strides=1, padding='SAME')
    y = tf.squeeze(y2, axis=-1)
    # Make matrix of number of surrounding elements
    s = tf.shape(x)
    d = tf.fill(s[1:] - 2, tf.constant(8, dtype=dt))
    d = tf.pad(d, [[0, 0], [1, 1]], constant_values=5)
    top_row = tf.concat([[3], tf.fill([s[2] - 2], tf.constant(5, dtype=dt)), [3]], axis=0)
    d = tf.concat([[top_row], d, [top_row]], axis=0)
    # Return average
    return y / d

# Test
x = tf.reshape(tf.range(24.), (2, 4, 3))
print(x.numpy())
# [[[ 0.  1.  2.]
#   [ 3.  4.  5.]
#   [ 6.  7.  8.]
#   [ 9. 10. 11.]]
# 
#  [[12. 13. 14.]
#   [15. 16. 17.]
#   [18. 19. 20.]
#   [21. 22. 23.]]]
print(surround_average_batch(x).numpy())
# [[[ 2.6666667  2.8        3.3333333]
#   [ 3.6        4.         4.4      ]
#   [ 6.6        7.         7.4      ]
#   [ 7.6666665  8.2        8.333333 ]]
# 
#  [[14.666667  14.8       15.333333 ]
#   [15.6       16.        16.4      ]
#   [18.6       19.        19.4      ]
#   [19.666666  20.2       20.333334 ]]]

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

日期 <今天() 的最后 x 个单元格的平均值

来自分类Dev

如何让 tensorflow 在具有 1 x 2 内核的 2 x 2 矩阵上进行卷积?

来自分类Dev

如何计算包含大于和小于excel中平均值的x和y值的单元格对的数量?

来自分类Dev

使用CPU的“ Keras后端+ Tensorflow”和“来自Tensorflow的Keras”之间有什么区别(在Tensorflow 2.x中)

来自分类Dev

使用ggplot2创建具有散点图平均值的网格和颜色单元

来自分类Dev

如何在具有tensorflow v2.x后端的keras中加载具有tensorflow v1.x后端的keras模型?

来自分类Dev

如何使Tensorflow在具有2.x功能的GPU上运行?

来自分类Dev

Tensorflow端口模型从1.x到2.x

来自分类Dev

TensorFlow V2.x和tf.keras的所有随机种子是什么?

来自分类Dev

BigTable到TensorFlow 2.x-是否有连接器?

来自分类Dev

如何在TensorFlow 2.x中加载Tensorflow 1.x保存的模型?

来自分类Dev

将Tensorflow 1.x代码升级到Tensorflow 2.x代码

来自分类Dev

计算DataFrame的2x2行列组的平均值

来自分类Dev

将平均值或“”返回到具有函数的单元格时出现#VALUE错误

来自分类Dev

如何找到具有多个单元格的列表的平均值?

来自分类Dev

在OpenCV上使用Tensorflow 2.X模型

来自分类Dev

仅使用PIP为CPU安装Tensorflow 2.x

来自分类Dev

在tf 2.x中以图形模式运行TensorFlow op

来自分类Dev

如果单元格1 = x,则单元格2 = y或单元格1 = y,则单元格2 = z,依此类推

来自分类Dev

替换包含“ #VALUE!”的单元格 与周围细胞的平均值

来自分类Dev

具有所有按行排列的Tensorflow1 concat 2D张量

来自分类Dev

为什么x = {1:4}返回1x1单元格,而x = {1 2 3 4}返回1x4单元格?[MATLAB]

来自分类Dev

ggplot2:具有平均值/ 95%置信区间线的密度图

来自分类Dev

如何对numpy数组的2x2子数组的平均值进行矢量化处理?

来自分类Dev

Excel:如果单元格1包含X或Y或Z,则单元格2应等于W

来自分类Dev

计算函数内的加权平均值时出现错误“ x”和“ w”必须具有相同的长度

来自分类Dev

UICollectionView接收到具有不存在索引路径的单元格的布局属性:<NSIndexPath:0x79fe0f20> {length = 2,path = 0-4}

来自分类Dev

2D bin(x,y)并计算10个最深数据点(z)的平均值(c)

来自分类Dev

TensorFlow 2.x:无法以h5格式保存经过训练的模型(OSError:无法创建链接(名称已经存在))

Related 相关文章

  1. 1

    日期 <今天() 的最后 x 个单元格的平均值

  2. 2

    如何让 tensorflow 在具有 1 x 2 内核的 2 x 2 矩阵上进行卷积?

  3. 3

    如何计算包含大于和小于excel中平均值的x和y值的单元格对的数量?

  4. 4

    使用CPU的“ Keras后端+ Tensorflow”和“来自Tensorflow的Keras”之间有什么区别(在Tensorflow 2.x中)

  5. 5

    使用ggplot2创建具有散点图平均值的网格和颜色单元

  6. 6

    如何在具有tensorflow v2.x后端的keras中加载具有tensorflow v1.x后端的keras模型?

  7. 7

    如何使Tensorflow在具有2.x功能的GPU上运行?

  8. 8

    Tensorflow端口模型从1.x到2.x

  9. 9

    TensorFlow V2.x和tf.keras的所有随机种子是什么?

  10. 10

    BigTable到TensorFlow 2.x-是否有连接器?

  11. 11

    如何在TensorFlow 2.x中加载Tensorflow 1.x保存的模型?

  12. 12

    将Tensorflow 1.x代码升级到Tensorflow 2.x代码

  13. 13

    计算DataFrame的2x2行列组的平均值

  14. 14

    将平均值或“”返回到具有函数的单元格时出现#VALUE错误

  15. 15

    如何找到具有多个单元格的列表的平均值?

  16. 16

    在OpenCV上使用Tensorflow 2.X模型

  17. 17

    仅使用PIP为CPU安装Tensorflow 2.x

  18. 18

    在tf 2.x中以图形模式运行TensorFlow op

  19. 19

    如果单元格1 = x,则单元格2 = y或单元格1 = y,则单元格2 = z,依此类推

  20. 20

    替换包含“ #VALUE!”的单元格 与周围细胞的平均值

  21. 21

    具有所有按行排列的Tensorflow1 concat 2D张量

  22. 22

    为什么x = {1:4}返回1x1单元格,而x = {1 2 3 4}返回1x4单元格?[MATLAB]

  23. 23

    ggplot2:具有平均值/ 95%置信区间线的密度图

  24. 24

    如何对numpy数组的2x2子数组的平均值进行矢量化处理?

  25. 25

    Excel:如果单元格1包含X或Y或Z,则单元格2应等于W

  26. 26

    计算函数内的加权平均值时出现错误“ x”和“ w”必须具有相同的长度

  27. 27

    UICollectionView接收到具有不存在索引路径的单元格的布局属性:<NSIndexPath:0x79fe0f20> {length = 2,path = 0-4}

  28. 28

    2D bin(x,y)并计算10个最深数据点(z)的平均值(c)

  29. 29

    TensorFlow 2.x:无法以h5格式保存经过训练的模型(OSError:无法创建链接(名称已经存在))

热门标签

归档