Parcourir la source

Update of tend curves

Jérôme BUISINE il y a 5 ans
Parent
commit
36b9da2ff2
1 fichiers modifiés avec 24 ajouts et 24 suppressions
  1. 24 24
      noise_svd_tend_visualization.py

+ 24 - 24
noise_svd_tend_visualization.py

@@ -2,6 +2,7 @@ import sys, os, getopt
 from PIL import Image
 
 from ipfml import processing, utils
+import ipfml.iqa.fr as fr_iqa
 
 from modules.utils import config as cfg
 from modules.utils import data_type as dt
@@ -20,29 +21,21 @@ pictures_folder       = cfg.pictures_output_folder
 
 step_picture          = 10
 
-error_data_choices    = ['MAE', 'MSE']
+error_data_choices    = ['mae', 'mse', 'ssim', 'psnr']
 
-def compute_mae(previous_data, current_data):
 
-    n = len(previous_data)
-    mae_sum = 0.
+def get_error_distance(p_error, y_true, y_test):
 
-    for id, x in enumerate(current_data):
-        y = previous_data[id] # current data reduces error
-        mae_sum += abs(x - y)
+    noise_method = None
+    function_name = p_error
 
-    return mae_sum / n
-
-def compute_mse(previous_data, current_data):
-
-    n = len(previous_data)
-    mse_sum = 0.
+    try:
+        error_method = getattr(fr_iqa, function_name)
+    except AttributeError:
+        raise NotImplementedError("Error method `{}` not implement `{}`".format(fr_iqa.__name__, function_name))
 
-    for id, x in enumerate(current_data):
-        y = previous_data[id] # current data reduces error
-        mse_sum += abs(x - y)
+    return error_method(y_true, y_test)
 
-    return mse_sum / n
 
 def main():
 
@@ -56,17 +49,17 @@ def main():
     min_value_svd = sys.maxsize
 
     if len(sys.argv) <= 1:
-        print('python noise_svd_mae_visualization.py --prefix generated/prefix/noise --metric lab --mode svdn --n 300 --interval "0, 200" --step 30 --color 1 --norm 1 --ylim "0, 1" --error MAE')
+        print('python noise_svd_tend_visualization.py --prefix generated/prefix/noise --metric lab --mode svdn --n 300 --interval "0, 200" --step 30 --color 1 --norm 1 --ylim "0, 1" --error mae')
         sys.exit(2)
     try:
         opts, args = getopt.getopt(sys.argv[1:], "h:p:m:m:n:i:s:c:n:y:e", ["help=", "prefix=", "metric=", "mode=", "n=", "interval=", "step=", "color=", "norm=", "ylim=", "error="])
     except getopt.GetoptError:
         # print help information and exit:
-        print('python noise_svd_mae_visualization.py --prefix generated/prefix/noise --metric lab --mode svdn --n 300 --interval "0, 200" --step 30 --color 1 --norm 1 --ylim "0, 1" --error MAE')
+        print('python noise_svd_tend_visualization.py --prefix generated/prefix/noise --metric lab --mode svdn --n 300 --interval "0, 200" --step 30 --color 1 --norm 1 --ylim "0, 1" --error mae')
         sys.exit(2)
     for o, a in opts:
         if o == "-h":
-            print('python noise_svd_mae_visualization.py --prefix generated/prefix/noise --metric lab --mode svdn --n 300 --interval "0, 200" --step 30 --color 1 --norm 1 --ylim "0, 1" --error MAE')
+            print('python noise_svd_tend_visualization.py --prefix generated/prefix/noise --metric lab --mode svdn --n 300 --interval "0, 200" --step 30 --color 1 --norm 1 --ylim "0, 1" --error MAE')
             sys.exit()
         elif o in ("-p", "--prefix"):
             p_path = a
@@ -158,6 +151,7 @@ def main():
         if current_id % p_step == 0:
 
             current_data = data
+
             if p_mode == 'svdn':
                 current_data = utils.normalize_arr(current_data)
 
@@ -167,9 +161,15 @@ def main():
             svd_data.append(current_data)
             image_indices.append(current_id)
 
+            # use of whole image data for computation of ssim or psnr
+            if p_error == 'ssim' or p_error == 'psnr':
+                image_path = file_path.format(str(current_id))
+                current_data = np.asarray(Image.open(image_path))
+
             if len(previous_data) > 0:
-                current_mae = compute_mae(previous_data, current_data)
-                error_data.append(current_mae)
+
+                current_error = get_error_distance(p_error, previous_data, current_data)
+                error_data.append(current_error)
 
             if len(previous_data) == 0:
                 previous_data = current_data
@@ -188,7 +188,7 @@ def main():
 
     for id, data in enumerate(svd_data):
 
-        p_label = p_prefix + str(image_indices[id]) + " | MAE : " + str(error_data[id])
+        p_label = p_prefix + str(image_indices[id]) + " | " + p_error + ": " + str(error_data[id])
         ax1.plot(data, label=p_label)
 
     ax1.legend(bbox_to_anchor=(0.8, 1), loc=2, borderaxespad=0.2, fontsize=12)
@@ -206,7 +206,7 @@ def main():
         output_filename = output_filename + '_color'
 
     ax2.set_title(p_error + " information for : " + p_prefix  + ', ' + noise_name + ' noise, interval information ['+ str(begin) +', '+ str(end) +'], ' + p_metric + ' metric, step ' + str(p_step) + ', normalization ' + p_mode)
-    ax2.set_ylabel('Mean Squared Error')
+    ax2.set_ylabel(p_error + ' error')
     ax2.set_xlabel('Number of samples per pixels')
     ax2.set_xticks(range(len(image_indices)))
     ax2.set_xticklabels(image_indices)