Ошибка RuntimeError: ожидаемый объект скалярного типа Double, но получил скалярный тип. Ошибка с плавающей запятой в Pytorch возникает, когда функция или операция ожидала скалярное значение типа double(64-битное число с плавающей запятой) для тензора. Но на вход он получил скалярное значение типа float(32-битное число с плавающей запятой).
Чтобы исправить ошибку, входное значение должно быть преобразовано в double или функция должна быть изменена для поддержки входных данных с плавающей запятой.
- Чтобы преобразовать тензор в float64, вы можете использовать метод torch.double().
- Чтобы преобразовать тензор в float32, вы можете использовать метод torch.float().
Вы также можете использовать метод .to() для преобразования тензора в нужный тип данных.
|
1 2 3 4 5 6 7 8 9 10 11 12 13 |
import torch # Create a tensor main_tensor = torch.randn(2, 2) double_tensor = main_tensor.double() float_tensor = main_tensor.float() print(main_tensor) print(double_tensor) print(float_tensor) |
Выход
|
1 2 3 4 5 6 |
tensor([[-1.0531, 0.6726], [ 0.6947, -0.3040]]) tensor([[-1.0531, 0.6726], [ 0.6947, -0.3040]], dtype=torch.float64) tensor([[-1.0531, 0.6726], [ 0.6947, -0.3040]]) |
Из вывода видно, что мы преобразовали обычный тензор в тензоры типов float32 и float64.
При работе с типом данных float32 или float64 убедитесь, что вы не используете оба типа одновременно.
Для большей точности всегда лучше использовать тип данных float64, и если входные данные относятся к типу данных float32, вы можете преобразовать их в тип данных float64.
PyTorch имеет три основных типа данных, которые могут представлять числа с плавающей запятой.
- torch.float32 или torch.float
- torch.float64 или torch.double
- факел.bfloat16
Что такое torch.float32?
torch.float32 — это 32-битное число с плавающей запятой, также известное как число с плавающей запятой «одинарной точности». Он занимает меньше памяти и быстрее выполняет вычисления, но может быть не таким точным, как torch.float64.
|
1 2 3 4 5 6 |
import torch # Create a float32 tensor main_tensor = torch.randn(2, 2, dtype=torch.float32) print(main_tensor) |
Выход
|
1 2 |
tensor([[0.5858, 1.0883], [0.6733, 0.4548]]) |
Что такое torch.float64?
torch.float64 — это 64-битное число с плавающей запятой, также известное как число с плавающей запятой «двойной точности». Он занимает больше памяти и медленнее выполняет вычисления, но более точен, чем torch.float32.
|
1 2 3 4 5 6 7 |
import torch # Create a float64 tensor main_tensor = torch.randn(2, 2, dtype=torch.float64) print(main_tensor) |
Выход
|
1 2 |
tensor([[-0.8973, 1.3078], [-0.9175, -1.3086]], dtype=torch.float64) |
Тип данных double имеет 64-битное число с плавающей запятой.
Тип данных float имеет 32-битное число с плавающей запятой.
Если вы передадите 32-битное число с плавающей запятой вместо 64-битного числа с плавающей запятой, будет выдано RuntimeError.
