When I run the code below, it raises an error. This is a very simple example and it should not raise an error at all.

import numpy as np
import torch as th

# Assuming tt is some data (example as list)
tt = [1, 2, 3, 4, 5]  # Example data

# Check if tt is a NumPy array, and convert if necessary
if not isinstance(tt, np.ndarray):
    tt = np.array(tt)

# Now, convert tt to a PyTorch tensor
tensor_tt = th.from_numpy(tt)
print(tensor_tt)

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[26], line 12
      9     tt = np.array(tt)
     11 # Now, convert tt to a PyTorch tensor
---> 12 tensor_tt = th.from_numpy(tt)
     13 print(tensor_tt)

TypeError: expected np.ndarray (got numpy.ndarray)
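
To check in isolation whether PyTorch's NumPy bridge works in the current interpreter, a small helper like the one below can be run. It is a hypothetical diagnostic (the name `check_numpy_bridge` is made up, not part of either library). It also reports which NumPy version and file were actually imported, which matters when more than one NumPy is installed in the same environment.

```python
def check_numpy_bridge() -> str:
    """Hypothetical diagnostic: try torch.from_numpy on a tiny array and report the outcome."""
    try:
        import numpy as np
        import torch
    except Exception as exc:  # a broken install can fail with more than ImportError
        return f"import failed: {exc}"

    # Record which NumPy was actually imported; with a mixed install this may be unexpected.
    where = f"numpy {np.__version__} from {np.__file__}, torch {torch.__version__}"
    try:
        torch.from_numpy(np.zeros(3, dtype=np.float32))
    except (TypeError, RuntimeError) as exc:
        # "expected np.ndarray (got numpy.ndarray)" would be caught here.
        return f"bridge broken ({where}): {exc}"
    return f"ok ({where})"


if __name__ == "__main__":
    print(check_numpy_bridge())
```

If the result starts with "bridge broken", the file path shown is a good first clue about which NumPy installation PyTorch is actually seeing.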

I am using the following conda environment:

conda list
# packages in environment at /opt/miniconda3/envs/ethos:
#
# Name                    Version                   Build  Channel
appnope                   0.1.4              pyhd8ed1ab_0    conda-forge
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
blas                      1.0                    openblas  
bottleneck                1.3.7           py312ha86b861_0  
brotli                    1.0.9                h80987f9_8  
brotli-bin                1.0.9                h80987f9_8  
bzip2                     1.0.8                h80987f9_6  
ca-certificates           2024.9.24            hca03da5_0  
click                     8.1.7                    pypi_0    pypi
colorlog                  6.8.2                    pypi_0    pypi
comm                      0.2.2              pyhd8ed1ab_0    conda-forge
contourpy                 1.2.0           py312h48ca7d4_0  
cycler                    0.11.0             pyhd3eb1b0_0  
debugpy                   1.6.7           py312h313beb8_0  
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
ethos                     0.1.0                    pypi_0    pypi
exceptiongroup            1.2.2              pyhd8ed1ab_0    conda-forge
executing                 2.1.0              pyhd8ed1ab_0    conda-forge
expat                     2.6.3                h313beb8_0  
filelock                  3.16.1                   pypi_0    pypi
fonttools                 4.51.0          py312h80987f9_0  
freetype                  2.12.1               h1192e45_0  
fsspec                    2024.9.0                 pypi_0    pypi
h5py                      3.12.1                   pypi_0    pypi
importlib-metadata        8.5.0              pyha770c72_0    conda-forge
ipykernel                 6.29.5             pyh57ce528_0    conda-forge
ipython                   8.28.0             pyh707e725_0    conda-forge
jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
jinja2                    3.1.4                    pypi_0    pypi
joblib                    1.4.2                    pypi_0    pypi
jpeg                      9e                   h80987f9_3  
jupyter_client            8.6.3              pyhd8ed1ab_0    conda-forge
jupyter_core              5.7.2              pyh31011fe_1    conda-forge
kiwisolver                1.4.4           py312h313beb8_0  
lcms2                     2.12                 hba8e193_0  
lerc                      3.0                  hc377ac9_0  
libbrotlicommon           1.0.9                h80987f9_8  
libbrotlidec              1.0.9                h80987f9_8  
libbrotlienc              1.0.9                h80987f9_8  
libcxx                    14.0.6               h848a8c0_0  
libdeflate                1.17                 h80987f9_1  
libffi                    3.4.4                hca03da5_1  
libgfortran               5.0.0           11_3_0_hca03da5_28  
libgfortran5              11.3.0              h009349e_28  
libopenblas               0.3.21               h269037a_0  
libpng                    1.6.39               h80987f9_0  
libsodium                 1.0.18               h27ca646_1    conda-forge
libtiff                   4.5.1                h313beb8_0  
libwebp-base              1.3.2                h80987f9_0  
llvm-openmp               14.0.6               hc6e5704_0  
lz4-c                     1.9.4                h313beb8_1  
markupsafe                2.1.5                    pypi_0    pypi
matplotlib-base           3.9.2           py312h2df2da3_0  
matplotlib-inline         0.1.7              pyhd8ed1ab_0    conda-forge
mpmath                    1.3.0                    pypi_0    pypi
ncurses                   6.4                  h313beb8_0  
nest-asyncio              1.6.0              pyhd8ed1ab_0    conda-forge
networkx                  3.3                      pypi_0    pypi
numexpr                   2.8.7           py312h0f3ea24_0  
numpy                     2.1.2                    pypi_0    pypi
numpy-base                1.26.4          py312he047099_0  
openjpeg                  2.5.2                h54b8e55_0  
openssl                   3.3.2                h8359307_0    conda-forge
packaging                 24.1               pyhd8ed1ab_0    conda-forge
pandas                    2.2.3                    pypi_0    pypi
parso                     0.8.4              pyhd8ed1ab_0    conda-forge
pexpect                   4.9.0              pyhd8ed1ab_0    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    10.4.0          py312h80987f9_0  
pip                       24.2            py312hca03da5_0  
platformdirs              4.3.6              pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.48             pyha770c72_0    conda-forge
psutil                    5.9.0           py312h80987f9_0  
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.3              pyhd8ed1ab_0    conda-forge
pyarrow                   17.0.0                   pypi_0    pypi
pygments                  2.18.0             pyhd8ed1ab_0    conda-forge
pyparsing                 3.1.2           py312hca03da5_0  
python                    3.12.7               h99e199e_0  
python-dateutil           2.9.0.post0              pypi_0    pypi
python-tzdata             2023.3             pyhd3eb1b0_0  
pytz                      2024.2                   pypi_0    pypi
pyzmq                     25.1.2          py312h313beb8_0  
readline                  8.2                  h1a28f6b_0  
seaborn                   0.13.2                   pypi_0    pypi
setuptools                75.1.0          py312hca03da5_0  
six                       1.16.0             pyh6c4a22f_0    conda-forge
sqlite                    3.45.3               h80987f9_0  
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
sympy                     1.13.3                   pypi_0    pypi
tk                        8.6.14               h6ba3021_0  
torch                     2.4.1                    pypi_0    pypi
tornado                   6.4.1           py312h80987f9_0  
tqdm                      4.66.5                   pypi_0    pypi
traitlets                 5.14.3             pyhd8ed1ab_0    conda-forge
typing_extensions         4.12.2             pyha770c72_0    conda-forge
tzdata                    2024.2                   pypi_0    pypi
unicodedata2              15.1.0          py312h80987f9_0  
wcwidth                   0.2.13             pyhd8ed1ab_0    conda-forge
wheel                     0.44.0          py312hca03da5_0  
xz                        5.4.6                h80987f9_1  
zeromq                    4.3.5                h313beb8_0  
zipp                      3.20.2             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               h18a0788_1  
zstd                      1.5.5                hd90d995_2  
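
The listing above contains both `numpy 2.1.2` (channel `pypi`) and `numpy-base 1.26.4` (from conda). A quick way to see which NumPy distributions Python's own package metadata knows about is a sketch like the one below, which uses only the standard library's `importlib.metadata`. The helper name `find_distributions` is hypothetical, and whether a mixed conda/pip install shows up as one entry or two depends on how the packages overwrote each other, so treat the output as a hint rather than proof.

```python
import importlib.metadata as md


def find_distributions(name: str = "numpy") -> list:
    """Hypothetical helper: return (version, location) for every installed distribution called `name`."""
    target = name.lower()
    found = []
    for dist in md.distributions():
        dist_name = dist.metadata.get("Name")
        # Broken installs can leave metadata without a Name field; skip those.
        if dist_name and dist_name.lower() == target:
            found.append((dist.version, str(dist.locate_file(""))))
    return found


if __name__ == "__main__":
    for version, location in find_distributions("numpy"):
        print(version, location)
```

More than one line of output for `numpy` would point to exactly the kind of duplicate installation described in the accepted fix.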

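I am trying to convert np.array data to a tensor in torch. The error message is confusing, and I don't know whether there is a conflict between package versions.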

Even after downgrading NumPy, it still reports the same error:

import numpy as np
import torch
print(torch.__version__)
print(np.__version__)
x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
y = torch.from_numpy(x)

2.4.1
1.26.0
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 6
      4 print(np.__version__)
      5 x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
----> 6 y = torch.from_numpy(x)

TypeError: expected np.ndarray (got numpy.ndarray)


Best answer

This is a known NumPy bug; for now you need to downgrade:

pip install "numpy<1.26.4"

  • This doesn't seem to work for me: with NumPy 1.26.0 it still reports the same error.
  • What is your Torch version?
  • My Torch version is 2.4.1.

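After many attempts, I finally found the cause of the problem. Note that the environment contains two NumPy packages: numpy 2.1.2 (installed from PyPI with pip) and numpy-base 1.26.4 (installed with conda). If you only remove the conda NumPy (1.26.4) and then reinstall it with conda install numpy==1.26.0, the error will persist; you also need to make sure the pip-installed version 2.1.2 is removed with pip uninstall numpy.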
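
After running pip uninstall numpy and reinstalling a single version, one way to confirm the cleanup worked is to compare the version NumPy reports when imported with the version recorded in its installed metadata; a mismatch suggests that files from two installations are still mixed. This is a hypothetical check (the name `numpy_versions_agree` is made up), not part of the original answer, and `importlib.metadata.version` reports only the first matching distribution it finds.

```python
import importlib.metadata as md
from typing import Optional, Tuple


def numpy_versions_agree() -> Tuple[Optional[bool], str]:
    """Hypothetical check: does the imported NumPy match its installed metadata?"""
    try:
        import numpy as np
    except Exception as exc:  # a half-removed install can fail with more than ImportError
        return None, f"numpy not importable: {exc}"

    try:
        meta_version = md.version("numpy")
    except md.PackageNotFoundError:
        return False, f"imported numpy {np.__version__} from {np.__file__}, but found no metadata"

    agree = meta_version == np.__version__
    detail = f"imported {np.__version__} from {np.__file__}; metadata reports {meta_version}"
    return agree, detail


if __name__ == "__main__":
    ok, detail = numpy_versions_agree()
    print("consistent" if ok else "inconsistent or missing", "-", detail)
```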