data.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. # main imports
  2. import os
  3. import numpy as np
  4. import random
  5. # image processing imports
  6. from PIL import Image
  7. # modules imports
  8. from ..config.cnn_config import *
  9. _scenes_names_prefix = '_scenes_names'
  10. _scenes_indices_prefix = '_scenes_indices'
  11. # store all variables from current module context
  12. context_vars = vars()
  13. def get_renderer_scenes_indices(renderer_name):
  14. if renderer_name not in renderer_choices:
  15. raise ValueError("Unknown renderer name")
  16. if renderer_name == 'all':
  17. return scenes_indices
  18. else:
  19. return context_vars[renderer_name + _scenes_indices_prefix]
  20. def get_renderer_scenes_names(renderer_name):
  21. if renderer_name not in renderer_choices:
  22. raise ValueError("Unknown renderer name")
  23. if renderer_name == 'all':
  24. return scenes_names
  25. else:
  26. return context_vars[renderer_name + _scenes_names_prefix]
  27. def get_scene_image_quality(img_path):
  28. # if path getting last element (image name) and extract quality
  29. img_postfix = img_path.split('/')[-1].split(scene_image_quality_separator)[-1]
  30. img_quality = img_postfix.replace(scene_image_extension, '')
  31. return int(img_quality)
  32. def get_scene_image_postfix(img_path):
  33. # if path getting last element (image name) and extract quality
  34. img_postfix = img_path.split('/')[-1].split(scene_image_quality_separator)[-1]
  35. img_quality = img_postfix.replace(scene_image_extension, '')
  36. return img_quality
  37. def get_scene_image_prefix(img_path):
  38. # if path getting last element (image name) and extract prefix
  39. img_prefix = img_path.split('/')[-1].split(scene_image_quality_separator)[0]
  40. return img_prefix
  41. def augmented_data_image(block, output_folder, prefix_image_name):
  42. rotations = [0, 90, 180, 270]
  43. img_flip_labels = ['original', 'horizontal', 'vertical', 'both']
  44. horizontal_img = block.transpose(Image.FLIP_LEFT_RIGHT)
  45. vertical_img = block.transpose(Image.FLIP_TOP_BOTTOM)
  46. both_img = block.transpose(Image.TRANSPOSE)
  47. flip_images = [block, horizontal_img, vertical_img, both_img]
  48. # rotate and flip image to increase dataset size
  49. for id, flip in enumerate(flip_images):
  50. for rotation in rotations:
  51. rotated_output_img = flip.rotate(rotation)
  52. output_reconstructed_filename = prefix_image_name + post_image_name_separator
  53. output_reconstructed_filename = output_reconstructed_filename + img_flip_labels[id] + '_' + str(rotation) + '.png'
  54. output_reconstructed_path = os.path.join(output_folder, output_reconstructed_filename)
  55. if not os.path.exists(output_reconstructed_path):
  56. rotated_output_img.save(output_reconstructed_path)
  57. def remove_pixel(img, limit):
  58. width, height = img.shape
  59. output = np.zeros((width, height))
  60. for i in range(width):
  61. for j in range(height):
  62. if img[i,j] <= limit:
  63. output[i,j] = img[i,j]
  64. return output
  65. def get_random_value(distribution):
  66. rand = random.uniform(0, 1)
  67. prob_sum = 0.
  68. for id, prob in enumerate(distribution):
  69. prob_sum += prob
  70. if prob_sum >= rand:
  71. return id
  72. return len(distribution) - 1
  73. def distribution_from_data(data):
  74. occurences = np.array([data.count(x) for x in set(data)])
  75. max_occurences = sum(occurences)
  76. return occurences / max_occurences
  77. def fill_image_with_rand_value(img, func, value_to_replace):
  78. width, height = img.shape
  79. output = np.zeros((width, height))
  80. for i in range(width):
  81. for j in range(height):
  82. if img[i,j] == value_to_replace:
  83. output[i, j] = func()
  84. else:
  85. output[i, j] = img[i, j]
  86. return output