在pytorch中图片的张量结构与plt可以显示的图片格式要求是不一样的,所以plt是不能直接显示tensor格式的图片的,那么pytorch怎么用plt显示tensor图片呢?这就需要涉及到数据转换了,基本思路就是将tensor转换为numpy类型的数据结构,而numpy类型的格式刚好可以被plt支持。接下来就来看具体怎么操作吧!
问题
图像的张量结构为(C,H,W),而plt可以显示的图片格式要求(H,W,C),C为颜色通道数,可以没有。
所以问题就是将Tensor(C,H,W)=> numpy(H,W,C)
解决办法
def transimg(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
npimg1 = np.transpose(npimg,(1,2,0)) # C*H*W => H*W*C
return npimg1
以上就是pytorch怎么用plt显示tensor的方法介绍了,希望能给大家一个参考,也希望大家多多支持W3Cschool。