numpy.where的用法
在用Python处理大量数据时,Python的数据科学库极为有用,这里要提到的就是Numpy库。在Numpy库里有个where函数,它是Python中三元表达式 x if condition else y的另一种版本。为了更好的理解numpy.where函数,我们来看几个where函数的使用例子。
传统三元表达式
有种情况,假设我们有xarr和yarr两个NumPy数组,还有一个cond的包含布尔值的数组。当cond值为True时,选取xarr对应的值,否则选取yarr。
xarr = np.array([1, 2, 3, 4, 5, 6])
yarr = np.array([1.1, 2.1, 3.1, 4.1, 5.1, 6.1])
cond = np.array([True, False, True, False, True, False])
result = [(x if c else y) for x, y, c in zip(xarr, yarr, cond)]
result
传统的方式就是向上面这样使用三元表达式 x if condition else y,但是它对大数组的处理速度不佳,因为所有判断都是用纯python来完成的,而且也无法用于多维数组。如果使用使用np.where则可以很方便的解决上述问题并且改写为如下方式:
import numpy as np
xarr = np.array([1, 2, 3, 4, 5, 6])
yarr = np.array([1.1, 2.1, 3.1, 4.1, 5.1, 6.1])
cond = np.array([True, False, True, False, True, False])
np.where(cond, xarr, yarr)
# 输出如下:
# array([ 1. , 2.1, 3. , 4.1, 5. , 6.1])
生产新数组
在数据分析工作中,where通常用于根据另一个数组而产生一个新的数组。假设有一个随机数组成的矩阵如下。
arr = np.random.randn(5, 5)
arr
# 输出如下:
# array([[-1.55727013, -1.14643196, 0.33641317, 1.33949414, -0.0967725 ],
# [-0.94233914, 2.0355648 , 0.02982777, 0.25021319, 0.63198106],
# [-0.49934005, -1.34006463, -0.54190069, -0.21468117, -0.27063406],
# [ 0.76851561, -0.38540229, 1.1577072 , 2.01828294, 0.57194718],
# [ 0.40617705, 0.32532884, 1.96882039, 1.22848036, 0.94634927]])
这时候如果想把,所有数据中正数值替换为0,负数值替换为1,就可以用np.where函数很方便的做到,因为np.where的第二个和第三个参数都不一定是要数组,它们也都可以是标量值。
np.where(arr > 0, 0, 1)
# 输出如下:
# array([[1, 1, 0, 0, 1],
# [1, 0, 0, 0, 0],
# [1, 1, 1, 1, 1],
# [0, 1, 0, 0, 0],
# [0, 0, 0, 0, 0]])
返回条件判断元素下标
什么是返回条件判断元素下标呢,其实它有点类似于布尔型索引,先看一个常用的布尔型索引的例子。
xarr = np.array([1, 2, 3, 4, 5, 6])
xarr[xarr > 2]
# 输出如下:
# array([3, 4, 5, 6])
布尔型索引会返回当索引真值对应的数组内元素。前面说到np.where函数其实类似于布尔型索引,但是有一点不同,它返回的是 元素的下标 而不是 元素值本身。
xarr = np.array([1, 2, 3, 4, 5, 6])
np.where(xarr > 2)
# 输出如下:
# (array([2, 3, 4, 5]),)
可以看见输出的内容2, 3, 4, 5是元素的下标,所对应的元素内容其实是3, 4, 5, 6,这些元素都符合>2的要求。