Skip to content
Snippets Groups Projects
Commit b3e0fbae authored by Ben Riggan's avatar Ben Riggan
Browse files

fixed display_images

parent 2f691919
No related branches found
No related tags found
No related merge requests found
cifar_ica_basis_64.png

57 KiB | W: | H:

cifar_ica_basis_64.png

60.1 KiB | W: | H:

cifar_ica_basis_64.png
cifar_ica_basis_64.png
cifar_ica_basis_64.png
cifar_ica_basis_64.png
  • 2-up
  • Swipe
  • Onion skin
......@@ -22,9 +22,9 @@ def ica(x, **args):
# default parameters
if not len(args):
args['lr'] = 1e-4
args['nsteps'] = 1000
args['k'] = 20
args['lr'] = 1
args['nsteps'] = 200
args['k'] = 64
lr=args['lr']
nsteps = args['nsteps']
......@@ -42,6 +42,23 @@ def ica(x, **args):
return L, W
def sample(x, patch_size=16, num_patches=10):
''' randomly sample patches from x
inputs:
x: numpy array of images
patch_size: patch dims for patch_size x patch_size images
(default 16)
num_patches: number of patches to sample (default 10)
outputs:
y: numpy array of patches
'''
return y
if __name__ == '__main__':
# command line arg parser
parser = argparse.ArgumentParser(description='Perform ICA on cifar-10')
......@@ -53,17 +70,17 @@ if __name__ == '__main__':
parser.add_argument('--k',
type=int,
required=False,
default=20,
default=64,
help="number of latent variables / basis images")
parser.add_argument('--lr',
type=float,
required=False,
default=1e-3,
help="learning rate")
default=1,
help="initial learning rate")
parser.add_argument('--nsteps',
type=int,
required=False,
default=1000,
default=200,
help="number of iterations")
args = parser.parse_args()
......@@ -74,6 +91,11 @@ if __name__ == '__main__':
# load cifar10 data
images, labels = cifar10(batch=batch)
images = images / 255. # normalize
# sample patches
# ***complete sample function in ica.py***
images = sample(images)
# perform zca whitening
# ***complete zca_white function in zca.py***
......@@ -84,7 +106,7 @@ if __name__ == '__main__':
L, W = ica(images_, lr=lr, nsteps=nsteps, k=k)
# display ICA basis images
display_images(W.T)
display_images(W)
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
#from matplotlib.colors import Normalize
import skimage
import numpy as np
def display_images(x, figsize=(8,8), rows=5, cols=4, reshape=True, shuffle=False, normalize=True):
''' display images in a grid of rows and cols '''
n = x.shape[0]
if reshape:
# unravel vector to channel first image
x = x.reshape((-1,3,32,32))
# channel last image
x = np.moveaxis(x,1,-1)
#shuffle images to show different images
if shuffle:
idx = np.random.permutation(n)
x = x[idx]
# grid of images
fig=plt.figure(figsize=figsize)
for i in range(rows*cols):
# break if too few images exist
if i>=n:
break
def display_images(x):
''' construct grid of image filters '''
cols = int(np.round(np.sqrt(x.shape[1]))) # grid columns
rows = int(np.ceil(x.shape[1]/cols)) # grid rows
step = x.shape[0] // 3 # pixels per channel
dim = int(np.sqrt(step)) # patch dims
dim_ = dim + 1 # patch dims + 1
fig.add_subplot(rows, cols, i+1)
if not normalize:
plt.imshow(x[i])
else:
x_ = x[i]-x[i].min()
x_ = x_ / x_.max()
plt.imshow(x_)
# channels
r = x[:step]
g = x[step:2*step]
b = x[2*step:]
plt.show()
# normalize each channel
r = r / np.amax(np.abs(r), axis=0)
g = g / np.amax(np.abs(g), axis=0)
b = b / np.amax(np.abs(b), axis=0)
# image
I = np.ones((dim*rows+rows-1, dim*cols+cols-1, 3))
for i in range(rows):
for j in range(cols):
if i*cols+j >= r.shape[1]:
break
I[i*dim_:i*dim_+dim, j*dim_:j*dim_+dim, 0] = r[:,i*cols+j].reshape((dim,dim))
I[i*dim_:i*dim_+dim, j*dim_:j*dim_+dim, 1] = g[:,i*cols+j].reshape((dim,dim))
I[i*dim_:i*dim_+dim, j*dim_:j*dim_+dim, 2] = b[:,i*cols+j].reshape((dim,dim))
I += 1
I /= 2
plt.imshow(I)
plt.axis('off')
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment