tf.gather和tf.gather_nd的详细用法--tensorflow通过索引取tensor里的数据
发布网友
发布时间:2023-03-27 09:26
我来回答
共1个回答
热心网友
时间:2023-11-19 12:39
在numpy里取矩阵数据非常方便,比如:
这样就把矩阵a中的1,3,5行取出来了。
如果是只取某一维中单个索引的数据可以直接写成 tensor[:, 2] , 但如果要提取的索引不连续的话,在tensorflow里面的用法就要用到tf.gather.
tf.gather_nd允许在*上进行索引:
matrix中直接通过坐标取数(索引维度与tensor维度相同):
取第二行和第一行:
3维tensor的结果:
另外还有tf.batch_gather的用法如下:
tf.batch_gather(params, indices, name=None)
Gather slices from params according to indices with leading batch dims.
This operation assumes that the leading dimensions of indices are dense,
and the gathers on the axis corresponding to the last dimension of indices .
Therefore params should be a Tensor of shape [A1, ..., AN, B1, ..., BM],
indices should be a Tensor of shape [A1, ..., AN-1, C] and result will be
a Tensor of size [A1, ..., AN-1, C, B1, ..., BM] .
如果索引是一维的tensor,结果和 tf.gather 是一样的.