diff --git a/cnn.py b/cnn.py
index 6592e5bb98c2cd37905978fb50b11e31f21279bd..debd08a9e8bba0cb453999f1eb48afb8adda964a 100755
--- a/cnn.py
+++ b/cnn.py
@@ -39,7 +39,7 @@ def accuracy(labels, pred, k=1):
             loss: average cross_entropy loss over batch (scalar)
     '''
     labels = np.argmax(labels, axis=-1)
-    pred = np.argmax(labels, axis=-1)
+    pred = np.argmax(pred, axis=-1)
     
     return np.mean(pred == labels)