reconstruct_keras.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import numpy as np
  2. import pandas as pd
  3. import json
  4. import os, sys, argparse
  5. from keras.models import model_from_json
  6. import modules.config as cfg
  7. import modules.metrics as metrics
  8. from joblib import dump, load
  9. from PIL import Image
  def reconstruct(_scene_name, _model_path, _n):
      """Rebuild a scene image by predicting each pixel with a trained Keras model.

      For every column/row position, the file
      ``<scenes>/<scene>/<column>/<scene>_<column>_<row>.dat`` is read, its first
      ``_n`` values are divided by 255 and fed to the model, and the prediction
      is multiplied back by 255 and stored in the output image.

      Args:
          _scene_name: Name of the scene folder under ``cfg.folder_scenes_path``;
              also used as the prefix of every pixel file name.
          _model_path: Path of the JSON model description. The weights are
              loaded from the same path with ``.json`` replaced by ``.h5``.
          _n: Number of leading values of each pixel file given to the model
              (anything accepted by ``int()``).

      Returns:
          A numpy array of shape ``(cfg.number_of_rows, cfg.number_of_columns)``
          holding the predicted values scaled by 255.

      Raises:
          ValueError: If ``_n`` is not an integer or a pixel file holds fewer
              than ``_n`` values.
      """
      # convert once instead of on every pixel; also fails fast on a bad `_n`
      n_values = int(_n)

      # load the trained model; the file only needs to stay open for parsing
      with open(_model_path, 'r') as f:
          json_model = json.load(f)

      model = model_from_json(json_model)
      model.load_weights(_model_path.replace('.json', '.h5'))
      model.compile(loss='binary_crossentropy',
                    optimizer='adam',
                    metrics=['accuracy'])

      n_rows = cfg.number_of_rows
      n_columns = cfg.number_of_columns

      # construct the empty output image
      output_image = np.empty([n_rows, n_columns])

      # load scene and its `n` first pixel value data
      scene_path = os.path.join(cfg.folder_scenes_path, _scene_name)

      for id_column in range(n_columns):
          folder_path = os.path.join(scene_path, str(id_column))

          # normalized inputs of every pixel of this column, one row per pixel
          column_inputs = []

          for id_row in range(n_rows):
              pixel_filename = _scene_name + '_' + str(id_column) + '_' + str(id_row) + ".dat"
              pixel_file_path = os.path.join(folder_path, pixel_filename)

              with open(pixel_file_path, 'r') as f:
                  # only the first `n` values are used, so only those are parsed
                  raw_values = f.readlines()[:n_values]

              if len(raw_values) < n_values:
                  raise ValueError(
                      "{0} contains {1} values, {2} expected".format(
                          pixel_file_path, len(raw_values), n_values))

              column_inputs.append([float(value) / 255. for value in raw_values])

          # one prediction call per column instead of one per pixel
          batch = np.array(column_inputs).reshape(n_rows, n_values)
          predictions = np.asarray(model.predict(batch))

          # change normalized predicted value to pixel value
          # NOTE(review): assumes the model emits a single value per sample, as
          # the previous per-element assignment already required
          output_image[:, id_column] = predictions.reshape(n_rows) * 255.

          # progress display, rewritten in place by moving the cursor up one line
          print("{0:.2f}%".format(id_column / cfg.number_of_columns * 100))
          sys.stdout.write("\033[F")

      return output_image
  def main():
      """Parse the command line, reconstruct the scene and save it as an image.

      Command line arguments:
          --scene: scene name to reconstruct (one of ``cfg.scenes_list``).
          --model_path: path of the JSON model file.
          --n: number of pixel values given to the model; when omitted it is
              taken from the model file name (text before the first ``_``).
          --image_name: file name of the image written into
              ``cfg.reconstructed_folder``.
      """
      parser = argparse.ArgumentParser(description="Reconstruct a scene image using a trained model")

      # these three are mandatory: fail at parsing time rather than after the
      # (long) reconstruction when one of them is missing
      parser.add_argument('--scene', type=str, help='Scene name to reconstruct', choices=cfg.scenes_list, required=True)
      parser.add_argument('--model_path', type=str, help='Json model file path', required=True)
      parser.add_argument('--n', type=str, help='Number of pixel values approximated to keep')
      parser.add_argument('--image_name', type=str, help="The output image name", required=True)

      args = parser.parse_args()

      param_scene_name = args.scene
      param_n = args.n
      param_model_path = args.model_path
      param_image_name = args.image_name

      # get default value of `n` param from the model file name; the directory
      # part is dropped first so that folders containing `_` do not interfere
      if not param_n:
          param_n = os.path.basename(param_model_path).split('_')[0]

      output_image = reconstruct(param_scene_name, param_model_path, param_n)

      # exist_ok avoids the check-then-create race of a separate exists() test
      os.makedirs(cfg.reconstructed_folder, exist_ok=True)

      image_path = os.path.join(cfg.reconstructed_folder, param_image_name)

      # clip before the 8-bit cast so out-of-range predictions saturate
      # instead of wrapping around
      img = Image.fromarray(np.uint8(np.clip(output_image, 0, 255)))
      img.save(image_path)
  # run the command line entry point only when executed as a script
  if __name__ == "__main__":
      main()