Pytorch のエラー 'invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number' を解消する
Python
Published: 2019-09-24

やったこと

“invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number” というエラーが発生したので、調査します。

確認環境

Google Colaboratory で試しました。

import torch
print(torch.__version__)

1.1.0

調査

問題となったコード <class 'torch.Tensor'> は0インデックスが使えないようです。

修正前

total_loss += loss.data[0]

修正後

total_loss += loss.data.item()

参考