やったこと
“invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number” というエラーが発生したので、調査します。
確認環境
Google Colaboratory で試しました。
import torch
print(torch.__version__)
1.1.0調査
問題となったコード <class 'torch.Tensor'> は0インデックスが使えないようです。
修正前
total_loss += loss.data[0]修正後
total_loss += loss.data.item()