mask_fill
是 PyTorch 中的一个函数,用于根据给定的条件(掩码)对张量中的元素进行填充。掩码是一个布尔值张量,指示哪些元素需要被填充。 ## 函数原型
1 | torch.masked_fill(input, mask, value) |
参数说明
- input (Tensor): 输入的张量。这个张量的形状和类型将决定输出张量的形状和类型。
- mask (Tensor): 一个布尔类型张量,大小与
input
相同。掩码中为True
的位置将被填充为value
,而为False
的位置保持不变。 - value (scalar): 用于填充的数值。所有掩码为
True
的位置会被这个值替代。
返回值
返回一个新的张量,其形状与输入张量 input
相同,掩码位置的值被 value
替代。
示例
1 | import torch |
输出
1 | tensor([1, 0, 3, 0, 5]) |
在这个例子中,mask
中为 True
的位置(即第二个和第四个位置)会被填充为 0
,而其他位置保持不变。
注意事项
mask
张量必须与input
张量的形状相同。- 如果
mask
中的元素不为布尔类型,可以先使用mask.to(torch.bool)
转换为布尔类型。 value
参数可以是任意标量类型(如整数、浮点数等)。
应用场景
- 数据预处理:在进行数据清洗或处理时,可以使用
mask_fill
来清除或替换特定位置的数据。 - 神经网络训练:在模型训练过程中,常常会使用掩码来忽略一些无关数据或填充缺失值。