1 . tensor 的数据结构分为头信息区和存储区,信息区主要保存着 tensor 形状、步长、数据类型等信息,而真正的数据则保存成连续数组存放在存储区。一般来说一个 tensor 有着与之对应的 storage,storage 是在 data 之上封装的接口,便于使用,而不同 tensor 的头信息一般不同,但却可能使用相同的数据。

Input:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch as t

a = t.arange(0, 6)
print(a.storage())

b = a.view(2, 3)
print(b.storage())

# 一个对象的 id 值可以看作它在内存中的地址
# storage 的内存地址一样,即是同一个 storage
print(id(a.storage()) == id(b.storage()))

# a 改变,b 也随之改变,因为它们共享 storage
a[1] = 100
print(b)

c = a[2:]
print(c.storage())

# data_ptr 返回 tensor 首元素的内存地址
# 可以看出相差 16,这是因为 2*8=16
# 相差两个元素,每个元素占 8 个字节(long)
print(c.data_ptr(), a.data_ptr())

c[0] = -100 # c[0] 的内存地址对应 a[2]
print(a)

d = t.Tensor(c.storage().float())
print(id(c.storage()) == id(d.storage()))

# 下面 3 个 tensor 共享 storage
print(id(a.storage()) == id(b.storage()) == id(c.storage))

print(a.storage_offset(), c.storage_offset())

e = b[::2, ::2] # 隔 2 行/列取一个元素
print(id(e.storage()) == id(a.storage()))

print(b.stride(), e.stride())

print(e.is_contiguous())

Output:

可见绝大多数操作并不修改 tensor 的数据,而只是修改了 tensor 的头信息。这种做法更节省内存,同时提升了处理速度。此外有些操作会导致 tensor 不连续,这时需要调用 tensor.contiguous 方法将它们变成连续的数据,该方法会使数据复制一份,不再与原来的数据共享 storage。

2 . tensor 可以随意地在 GPU/CPU 上传输,使用 tensor.cuda(device_id) 或者 tensor.cpu(),另外一个更通用的方法是 tensor.to(device)

  • 尽量使用 tensor.to(device),将 device 设为一个可配置的参数,这样可以很轻松地使程序同时兼容 GPU 和 CPU;
  • 数据在 GPU 之中传输的速度要远快于内存(CPU)到显存(GPU),所以尽量避免在内存和显存之间传输数据。

3 . tensor 的保存和加载十分简单,使用 torch.savetorch.load 即可完成相应的功能。在 save/load 时可以指定使用的 pickle 模块,在 load 时还可以将 GPU tensor 映射到 CPU 或者其他 GPU 上。

1
2
3
4
5
6
7
8
9
10
if t.cuda.is_available():
a = a.cuda(1) # 把 a 转为 GPU1 上的 tensor
t.save(a, 'a.pth')

# 加载 b,存储于 GPU1 上(因为保存时 tensor 就在 GPU1 上)
b = t.load('a.pth')
# 加载为 c,存储于 CPU 上
c = t.load('a.pth', map_location=lambda storage, loc: storage)
# 加载为 d,存储于 GPU0 上
d = t.load('a.pth', map_location={'cuda:1':'cuda:0'})

4 . 关于 tensor 还有几点需要注意:

  • 大多数 torch.function 都有一个参数 out,这时候产生的结果将保存在 out 指定的 tensor 之中;
  • torch.set_num_threads 可以设置 PyTorch 进行 CPU 多线程并行计算时候所占用的线程数,这个可以用来限制 PyTorch 所占用的 CPU 数目;
  • torch.set_printoptions 可以用来设置打印 tensor 时的数值精度和格式。

Input:

1
2
3
4
5
6
import torch as t

a = t.randn(2, 3)
print(a)
t.set_printoptions(precision=10)
print(a)

Output:


笔记来源:《pytorch-book》