diff --git a/cnn.py b/cnn.py index 6592e5bb98c2cd37905978fb50b11e31f21279bd..debd08a9e8bba0cb453999f1eb48afb8adda964a 100755 --- a/cnn.py +++ b/cnn.py @@ -39,7 +39,7 @@ def accuracy(labels, pred, k=1): loss: average cross_entropy loss over batch (scalar) ''' labels = np.argmax(labels, axis=-1) - pred = np.argmax(labels, axis=-1) + pred = np.argmax(pred, axis=-1) return np.mean(pred == labels)