Bedingte Logik als Array-Operationen – where
¶
Die Funktion numpy.where ist eine vektorisierte Version von if
und else
.
Im folgenden Beispiel erzeugen wir zunächst ein boolesches Array und zwei Arrays mit Werten:
[1]:
import numpy as np
[2]:
cond = ([False, True, False, True, False, False, False])
data1 = np.random.randn(1, 7)
data2 = np.random.randn(1, 7)
Nun wollen wir Nehmen wir die Werte aus data1
übernehmen, wenn der entsprechende Wert in cond
True
ist und ansonsten den Wert aus data2
übernommen wird. Mit Pythons if
-else
könnte das wie folgt aussehen:
[3]:
result = [(x if c else y) for x, y, c in zip(data1, data2, cond)]
result
[3]:
[array([ 0.79741059, 1.01235915, -1.02336595, 0.34698571, 0.91228723,
-0.3260451 , -0.38514407])]
Dies hat jedoch die folgenden beiden Probleme:
bei großen Arrays wird die Funktion nicht sehr schnell sein
dies funktioniert nicht mit mehrdimensionalen Arrays
Mit np.where
könnt ihr diese Probleme in einem einzigen Funktionsaufruf umgehen:
[4]:
result = np.where(cond, data1, data2)
result
[4]:
array([[ 0.79741059, 0.29822359, -1.02336595, -0.75315847, 0.91228723,
-0.3260451 , -0.38514407]])
Das zweite und dritte Argument von np.where
müssen keine Arrays sein; eines oder beide können auch Skalare sein. Eine typische Anwendung von where
in der Datenanalyse besteht darin, ein neues Array von Werten auf der Grundlage eines anderen Arrays zu erzeugen. Angenommen, ihr habt eine Matrix mit zufällig generierten Daten und möchtet alle negativen Werte zu positiven Werten machen:
[5]:
data = np.random.randn(4, 4)
data
[5]:
array([[-0.78389643, -0.43952108, 0.91911346, 0.12098948],
[ 0.75084868, 0.36700152, -0.43154287, 0.90047307],
[ 1.28823883, 0.20841103, -0.08082387, 1.12856954],
[-0.48137952, -1.20155362, -0.7572543 , -0.29655235]])
[6]:
data < 0
[6]:
array([[ True, True, False, False],
[False, False, True, False],
[False, False, True, False],
[ True, True, True, True]])
[7]:
np.where(data < 0, data * -1, data)
[7]:
array([[0.78389643, 0.43952108, 0.91911346, 0.12098948],
[0.75084868, 0.36700152, 0.43154287, 0.90047307],
[1.28823883, 0.20841103, 0.08082387, 1.12856954],
[0.48137952, 1.20155362, 0.7572543 , 0.29655235]])