Python np.argmax() 函数

NumPy (np) 是最流行的数学和科学计算库之一。它提供了许多处理多维数组的函数。在本文中,我们将重点关注Python np.argmax() 函数


Python np.argmax() 函数

顾名思义,argmax()函数返回 NumPy 数组中最大值的索引。如果有多个索引具有相同的最大值,则将返回第一个索引。

argmax() 语法:

np.argmax( a axis=None out=None * keepdims=<无值> )

第一个参数是输入数组。如果没有提供轴,则将数组展平,然后返回最大值的索引。

如果我们指定axis,它将返回沿给定轴的索引值。

第三个参数用于传递数组参数来存储结果,它应该具有正确的形状和数据类型才能正常工作。

如果keepdims传递为 True,则减少的轴将作为大小为 1 的维度保留在结果中。

让我们看一些使用 argmax() 函数的示例,以正确理解不同参数的用法。


1. 使用 np.argmax() 查找最大值的索引

>>> import numpy as np
>>> arr = np.array([[4,2,3], [1,6,2]])
>>> arr
array([[4, 2, 3],
       [1, 6, 2]])
>>> np.ndarray.flatten(arr)
array([4, 2, 3, 1, 6, 2])
>>> np.argmax(arr)
4

np.argmax() 返回 4,因为数组首先被展平,然后返回最大值的索引。因此,在本例中,最大值为 6,其在展平数组中的索引为 4。

但是,我们希望索引值位于普通数组中,而不是扁平数组中。因此,我们必须使用argmax()unravel_index()函数来获取正确格式的索引值。

>>> np.unravel_index(np.argmax(arr), arr.shape)
(1, 1)
>>>

2. 查找沿轴最大值的索引

如果您想要沿不同轴的最大值的索引,请传递 axis 参数值。如果我们传递 axis=0,则返回沿列的最大值的索引。对于 axis=1,返回沿行的最大值的索引。

>>> arr
array([[4, 2, 3],
       [1, 6, 2]])
>>> np.argmax(arr, axis=0)
array([0, 1, 0])
>>> np.argmax(arr, axis=1)
array([0, 1])

对于 axis = 0,第一列值为 4 和 1。因此最大值索引为 0。类似地,对于第二列,值为 2 和 6,因此最大值索引为 1。对于第三列,值是 3 和 2,所以最大值索引是 0。这就是我们以数组 ([0, 1, 0]) 形式获取输出的原因。

对于 axis = 1,第一行值为 (4, 2, 3),因此最大值索引为 0。对于第二行,值为 (1, 6, 2),因此最大值索引为 1。因此输出数组([0, 1])。


3. 使用具有多个最大值的 np.argmax()

>>> import numpy as np
>>> arr = np.arange(6).reshape(2,3)
>>> arr
array([[0, 1, 2],
       [3, 4, 5]])
>>> arr[0][1] = 5
>>> arr
array([[0, 5, 2],
       [3, 4, 5]])
>>> np.argmax(arr)
1
>>> arr[0][2] = 5
>>> arr
array([[0, 5, 5],
       [3, 4, 5]])
>>> np.argmax(arr)
1
>>> np.argmax(arr, axis=0)
array([1, 0, 0])
>>> np.argmax(arr, axis=1)
array([1, 2])
>>>

我们使用arange() 函数创建一个具有一些默认值的二维数组。然后我们将其中一个值更改为具有多个具有最大值的索引。从输出中可以清楚地看出,当有多个位置具有最大值时,将返回最大值的第一个索引。


概括

NumPy argmax() 函数很容易理解,只需记住在查找最大值的索引之前将数组展平即可。此外, axis 参数对于查找行和列中最大值的索引非常有帮助。

下一步是什么?

资源