image_denoising.py

from keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D
from keras.models import Model
from keras import backend as K
from keras.callbacks import TensorBoard
import argparse


def generate_model(input_shape=(3, 200, 200, 1)):
    # Conv3D expects 4D samples (depth, height, width, channels);
    # adapt this shape if using the `channels_first` image data format
    input_img = Input(shape=input_shape)

    # encoder: two conv + pooling stages, each halving the spatial resolution
    x = Conv3D(32, (1, 3, 3), activation='relu', padding='same')(input_img)
    x = MaxPooling3D((1, 2, 2), padding='same')(x)
    x = Conv3D(32, (1, 3, 3), activation='relu', padding='same')(x)
    encoded = MaxPooling3D((1, 2, 2), padding='same')(x)

    # decoder: mirror of the encoder, upsampling back to the input resolution
    x = Conv3D(32, (1, 3, 3), activation='relu', padding='same')(encoded)
    x = UpSampling3D((1, 2, 2))(x)
    x = Conv3D(32, (1, 3, 3), activation='relu', padding='same')(x)
    x = UpSampling3D((1, 2, 2))(x)
    decoded = Conv3D(1, (1, 3, 3), activation='sigmoid', padding='same')(x)

    autoencoder = Model(input_img, decoded)
    autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
    return autoencoder
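
# A minimal shape sanity check of the model above, assuming channels-last data
# (illustrative only, not part of the training script):
#   import numpy as np
#   model = generate_model()
#   batch = np.random.rand(2, 3, 200, 200, 1)   # (batch, depth, height, width, channels)
#   model.predict(batch).shape                   # -> (2, 3, 200, 200, 1)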


def main():
    # load params
    parser = argparse.ArgumentParser(description="Train Keras model and save it into .json file")
    parser.add_argument('--data', type=str, help='dataset filename prefix (without .train and .test)', required=True)
    parser.add_argument('--output', type=str, help='output file name desired for model (without .json extension)', required=True)
    parser.add_argument('--batch_size', type=int, help='batch size used as model input', default=32)
    parser.add_argument('--epochs', type=int, help='number of epochs used for training model', default=100)
    args = parser.parse_args()

    p_data_file = args.data
    p_output = args.output
    p_batch_size = args.batch_size
    p_epochs = args.epochs

    # load data from `p_data_file` (dataset-specific loading not shown here)
    x_train_noisy = []
    x_train = []
    x_test_noisy = []
    x_test = []

    # load model
    autoencoder = generate_model()

    # monitor training with: tensorboard --logdir=/tmp/autoencoder
    autoencoder.fit(x_train_noisy, x_train,
                    epochs=p_epochs,
                    batch_size=p_batch_size,
                    shuffle=True,
                    validation_data=(x_test_noisy, x_test),
                    callbacks=[TensorBoard(log_dir='/tmp/autoencoder', histogram_freq=0, write_graph=False)])
    # save model architecture to JSON as requested via --output; the weights
    # file name convention (.h5 alongside the .json) is an assumption
    with open(p_output + '.json', 'w') as json_file:
        json_file.write(autoencoder.to_json())
    autoencoder.save_weights(p_output + '.h5')


if __name__ == "__main__":
    main()
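
# Example invocation (dataset prefix and output name below are hypothetical):
#   python image_denoising.py --data my_dataset --output denoising_autoencoder --batch_size 32 --epochs 100
# Training can then be monitored with TensorBoard, which reads the log
# directory configured in the fit() callback:
#   tensorboard --logdir=/tmp/autoencoder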