0%

tensor.masked_fill

mask_fill 是 PyTorch 中的一个函数,用于根据给定的条件(掩码)对张量中的元素进行填充。掩码是一个布尔值张量,指示哪些元素需要被填充。 ## 函数原型

1
torch.masked_fill(input, mask, value)

参数说明

  • input (Tensor): 输入的张量。这个张量的形状和类型将决定输出张量的形状和类型。
  • mask (Tensor): 一个布尔类型张量,大小与 input 相同。掩码中为 True 的位置将被填充为 value,而为 False 的位置保持不变。
  • value (scalar): 用于填充的数值。所有掩码为 True 的位置会被这个值替代。

返回值

返回一个新的张量,其形状与输入张量 input 相同,掩码位置的值被 value 替代。

示例

1
2
3
4
5
6
7
8
9
10
11
12
import torch

# 创建一个张量
tensor = torch.tensor([1, 2, 3, 4, 5])

# 创建一个掩码
mask = torch.tensor([False, True, False, True, False])

# 使用mask_fill填充
filled_tensor = torch.masked_fill(tensor, mask, 0)

print(filled_tensor)

输出

1
tensor([1, 0, 3, 0, 5])

在这个例子中,mask 中为 True 的位置(即第二个和第四个位置)会被填充为 0,而其他位置保持不变。

注意事项

  • mask 张量必须与 input 张量的形状相同。
  • 如果 mask 中的元素不为布尔类型,可以先使用 mask.to(torch.bool) 转换为布尔类型。
  • value 参数可以是任意标量类型(如整数、浮点数等)。

应用场景

  • 数据预处理:在进行数据清洗或处理时,可以使用 mask_fill 来清除或替换特定位置的数据。
  • 神经网络训练:在模型训练过程中,常常会使用掩码来忽略一些无关数据或填充缺失值。