浏览代码

input normalization

Jérôme BUISINE 4 年之前
父节点
当前提交
7b6a3ed93b
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 1 1
      train_lstm_weighted.py

+ 1 - 1
train_lstm_weighted.py

@@ -50,7 +50,7 @@ def build_input(df, seq_norm):
             for img_path in column:
                 img = Image.open(img_path)
                 # seq_elems.append(np.array(img).flatten())
-                seq_elems.append(np.array(img))
+                seq_elems.append(np.array(img) / 255.)
 
             #seq_arr.append(np.array(seq_elems).flatten())
             seq_arr.append(np.array(seq_elems))