やったこと
“invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number” というエラーが発生したので、調査します。
確認環境
Google Colaboratory で試しました。
import torch
print(torch.__version__)
1.1.0
調査
問題となったコード <class 'torch.Tensor'>
は0インデックスが使えないようです。
修正前
total_loss += loss.data[0]
修正後
total_loss += loss.data.item()