# image_denoising.py
# main imports
import argparse
import json
import os

# third-party imports
import numpy as np
import pandas as pd

# image processing imports
import cv2

# model imports
from keras import backend as K
from keras.callbacks import TensorBoard
from keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D
from keras.models import Model
from sklearn.utils import shuffle

# modules imports
import custom_config as cfg
  17. def generate_model(input_shape):
  18. input_img = Input(shape=input_shape) # adapt this if using `channels_first` image data format
  19. x = Conv3D(32, (1, 3, 3), activation='relu', padding='same')(input_img)
  20. x = MaxPooling3D((1, 2, 2), padding='same')(x)
  21. x = Conv3D(32, (1, 3, 3), activation='relu', padding='same')(x)
  22. x = MaxPooling3D((1, 2, 2), padding='same')(x)
  23. x = Conv3D(32, (1, 3, 3), activation='relu', padding='same')(x)
  24. encoded = MaxPooling3D((1, 2, 2), padding='same')(x)
  25. print(encoded)
  26. x = Conv3D(32, (1, 3, 3), activation='relu', padding='same')(encoded)
  27. x = UpSampling3D((1, 2, 2))(x)
  28. x = Conv3D(32, (1, 3, 3), activation='relu', padding='same')(x)
  29. x = UpSampling3D((1, 2, 2))(x)
  30. x = Conv3D(32, (1, 3, 3), activation='relu', padding='same')(x)
  31. x = UpSampling3D((1, 2, 2))(x)
  32. decoded = Conv3D(3, (1, 3, 3), activation='sigmoid', padding='same')(x)
  33. autoencoder = Model(input_img, decoded)
  34. autoencoder.compile(optimizer='adadelta', loss='mse')
  35. return autoencoder
  36. def main():
  37. # load params
  38. parser = argparse.ArgumentParser(description="Train Keras model and save it into .json file")
  39. parser.add_argument('--data', type=str, help='dataset filename prefix (without .train and .test)', required=True)
  40. parser.add_argument('--output', type=str, help='output file name desired for model (without .json extension)', required=True)
  41. parser.add_argument('--batch_size', type=int, help='batch size used as model input', default=cfg.keras_batch)
  42. parser.add_argument('--epochs', type=int, help='number of epochs used for training model', default=cfg.keras_epochs)
  43. args = parser.parse_args()
  44. p_data_file = args.data
  45. p_output = args.output
  46. p_batch_size = args.batch_size
  47. p_epochs = args.epochs
  48. # load data from `p_data_file`
  49. ########################
  50. # 1. Get and prepare data
  51. ########################
  52. print("Preparing data...")
  53. dataset_train = pd.read_csv(p_data_file + '.train', header=None, sep=";")
  54. dataset_test = pd.read_csv(p_data_file + '.test', header=None, sep=";")
  55. print("Train set size : ", len(dataset_train))
  56. print("Test set size : ", len(dataset_test))
  57. # default first shuffle of data
  58. dataset_train = shuffle(dataset_train)
  59. dataset_test = shuffle(dataset_test)
  60. print("Reading all images data...")
  61. # getting number of chanel
  62. n_channels = len(dataset_train[1][1].split('::'))
  63. print("Number of channels : ", n_channels)
  64. img_width, img_height = cfg.keras_img_size
  65. # specify the number of dimensions
  66. if K.image_data_format() == 'channels_first':
  67. if n_channels > 1:
  68. input_shape = (1, n_channels, img_width, img_height)
  69. else:
  70. input_shape = (n_channels, img_width, img_height)
  71. else:
  72. if n_channels > 1:
  73. input_shape = (1, img_width, img_height, n_channels)
  74. else:
  75. input_shape = (img_width, img_height, n_channels)
  76. # `:` is the separator used for getting each img path
  77. if n_channels > 1:
  78. dataset_train[1] = dataset_train[1].apply(lambda x: [cv2.imread(path, cv2.IMREAD_GRAYSCALE) for path in x.split('::')])
  79. dataset_test[1] = dataset_test[1].apply(lambda x: [cv2.imread(path, cv2.IMREAD_GRAYSCALE) for path in x.split('::')])
  80. else:
  81. dataset_train[1] = dataset_train[1].apply(lambda x: cv2.imread(x, cv2.IMREAD_GRAYSCALE))
  82. dataset_test[1] = dataset_test[1].apply(lambda x: cv2.imread(x, cv2.IMREAD_GRAYSCALE))
  83. x_dataset_train = dataset_train[1].apply(lambda x: np.array(x).reshape(input_shape))
  84. x_dataset_test = dataset_test[1].apply(lambda x: np.array(x).reshape(input_shape))
  85. y_dataset_train = dataset_train[0].apply(lambda x: cv2.imread(x).reshape(input_shape))
  86. y_dataset_test = dataset_test[0].apply(lambda x: cv2.imread(x).reshape(input_shape))
  87. # format data correctly
  88. x_data_train = np.array([item[0].reshape(input_shape) for item in x_dataset_train.values])
  89. x_data_test = np.array([item[0].reshape(input_shape) for item in x_dataset_test.values])
  90. y_data_train = np.array([item[0].reshape(input_shape) for item in y_dataset_train.values])
  91. y_data_test = np.array([item[0].reshape(input_shape) for item in y_dataset_test.values])
  92. # load model
  93. autoencoder = generate_model(input_shape)
  94. # tensorboard --logdir=/tmp/autoencoder
  95. autoencoder.fit(x_data_train, y_data_train,
  96. epochs=100,
  97. batch_size=32,
  98. shuffle=True,
  99. validation_data=(x_data_test, y_data_test),
  100. callbacks=[TensorBoard(log_dir='/tmp/autoencoder', histogram_freq=0, write_graph=False)])
  101. ##############
  102. # save model #
  103. ##############
  104. if not os.path.exists(cfg.saved_models_folder):
  105. os.makedirs(cfg.saved_models_folder)
  106. # save the model into HDF5 file
  107. model_output_path = os.path.join(cfg.saved_models_folder, p_output + '.json')
  108. json_model_content = autoencoder.to_json()
  109. with open(model_output_path, 'w') as f:
  110. print("Model saved into ", model_output_path)
  111. json.dump(json_model_content, f, indent=4)
  112. autoencoder.save_weights(model_output_path.replace('.json', '.h5'))
  113. if __name__ == "__main__":
  114. main()