Просмотр исходного кода

Update of data; Save and load option added

Jérôme BUISINE 5 лет назад
Родитель
Commit
6901223dae

+ 127 - 18
ganSynthesisImage_200.py

@@ -2,7 +2,7 @@
 import random
 import argparse
 import cv2
-import os
+import os, sys, getopt
 
 import torch
 import torch.nn as nn
@@ -20,7 +20,7 @@ import numpy as np
 log = gym.logger
 log.set_level(gym.logger.INFO)
 
-LATENT_VECTOR_SIZE = 100
+LATENT_VECTOR_SIZE = 400
 DISCR_FILTERS = 200
 GENER_FILTERS = 200
 BATCH_SIZE = 16
@@ -29,9 +29,13 @@ BATCH_SIZE = 16
 IMAGE_SIZE = 200
 input_shape = (3, IMAGE_SIZE, IMAGE_SIZE)
 
+BACKUP_MODEL_NAME = "synthesis_{}_model.pt"
+BACKUP_FOLDER = "saved_models"
+BACKUP_EVERY_ITER = 1
+
 LEARNING_RATE = 0.0001
-REPORT_EVERY_ITER = 50
-SAVE_IMAGE_EVERY_ITER = 200
+REPORT_EVERY_ITER = 10
+SAVE_IMAGE_EVERY_ITER = 20
 MAX_ITERATION = 100000
 
 data_folder = 'synthesis_images/generated_blocks'
@@ -56,14 +60,19 @@ class Discriminator(nn.Module):
                       kernel_size=4, stride=2, padding=1),
             nn.BatchNorm2d(DISCR_FILTERS * 8),
             nn.ReLU(),
-            nn.Conv2d(in_channels=DISCR_FILTERS * 8, out_channels=1,
+
+            nn.Conv2d(in_channels=DISCR_FILTERS * 8, out_channels=DISCR_FILTERS * 16,
+                      kernel_size=8, stride=2, padding=1),
+            nn.BatchNorm2d(DISCR_FILTERS * 16),
+            nn.ReLU(),
+
+            nn.Conv2d(in_channels=DISCR_FILTERS * 16, out_channels=1,
                       kernel_size=4, stride=1, padding=0),
             nn.Sigmoid()
         )
 
     def forward(self, x):
         conv_out = self.conv_pipe(x)
-        print(conv_out.view(-1, 1).squeeze(dim=1).shape)
         return conv_out.view(-1, 1).squeeze(dim=1)
 
 
@@ -72,22 +81,29 @@ class Generator(nn.Module):
         super(Generator, self).__init__()
         # pipe deconvolves input vector into (3, 64, 64) image
         self.pipe = nn.Sequential(
-            nn.ConvTranspose2d(in_channels=LATENT_VECTOR_SIZE, out_channels=GENER_FILTERS * 6,
-                               kernel_size=4, stride=1, padding=0),
-            nn.BatchNorm2d(GENER_FILTERS * 6),
+            nn.ConvTranspose2d(in_channels=LATENT_VECTOR_SIZE, out_channels=GENER_FILTERS * 5,
+                               kernel_size=6, stride=1, padding=0),
+            nn.BatchNorm2d(GENER_FILTERS * 5),
             nn.ReLU(),
-            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 6, out_channels=GENER_FILTERS * 4,
+            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 5, out_channels=GENER_FILTERS * 4,
                                kernel_size=4, stride=2, padding=1),
             nn.BatchNorm2d(GENER_FILTERS * 4),
             nn.ReLU(),
-            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 4, out_channels=GENER_FILTERS * 2,
+            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 4, out_channels=GENER_FILTERS * 3,
                                kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(GENER_FILTERS * 3),
+            nn.ReLU(),
+
+            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 3, out_channels=GENER_FILTERS * 2,
+                               kernel_size=6, stride=2, padding=1),
             nn.BatchNorm2d(GENER_FILTERS * 2),
             nn.ReLU(),
+
             nn.ConvTranspose2d(in_channels=GENER_FILTERS * 2, out_channels=GENER_FILTERS,
                                kernel_size=4, stride=2, padding=1),
             nn.BatchNorm2d(GENER_FILTERS),
             nn.ReLU(),
+
             nn.ConvTranspose2d(in_channels=GENER_FILTERS, out_channels=output_shape[0],
                                kernel_size=4, stride=2, padding=1),
             nn.Tanh()
@@ -123,14 +139,64 @@ def iterate_batches(batch_size=BATCH_SIZE):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--cuda", default=False, action='store_true', help="Enable cuda computation")
-    args = parser.parse_args()
 
-    device = torch.device("cuda" if args.cuda else "cpu")
+    save_model = False
+    load_model = False
+    p_cuda = False
+
+    #parser = argparse.ArgumentParser()
+    #parser.add_argument("--cuda", default=False, action='store_true', help="Enable cuda computation")
+    #args = parser.parse_args()
+
+    try:
+        opts, args = getopt.getopt(sys.argv[1:], "hflc", ["help=", "folder=", "load=", "cuda="])
+    except getopt.GetoptError:
+        # print help information and exit:
+        print('python ganSynthesisImage_200.py --folder folder_name_to_save --cuda 1')
+        print('python ganSynthesisImage_200.py --load folder_name_to_load ')
+        sys.exit(2)
+    for o, a in opts:
+        if o in ("-h", "--help"):
+            print('python ganSynthesisImage_200.py --folder folder_name_to_save --cuda 1')
+            print('python ganSynthesisImage_200.py --load folder_name_to_load ')
+            sys.exit()
+        elif o in ("-f", "--folder"):
+            p_model_folder = a
+            save_model = True
+        elif o in ("-l", "--load"):
+            p_load = a
+            load_model = True
+        elif o in ("-c", "--cuda"):
+            p_cuda = int(a)
+        else:
+            assert False, "unhandled option"
+
+    if save_model and load_model:
+        raise Exception("Cannot save and load model. Only one argument is required.")
+    if not save_model and not load_model:
+        print('python ganSynthesisImage_200.py --folder folder_name_to_save --cuda 1')
+        print('python ganSynthesisImage_200.py --load folder_name_to_load ')
+        print("Need at least one argument.")
+        sys.exit(2)
+
+    device = torch.device("cuda" if p_cuda else "cpu")
     #envs = [InputWrapper(gym.make(name)) for name in ('Breakout-v0', 'AirRaid-v0', 'Pong-v0')]
 
-    print(input_shape)
+
+    # prepare folder names to save models
+    if save_model:
+
+        models_folder_path = os.path.join(BACKUP_FOLDER, p_model_folder)
+        dis_model_path = os.path.join(models_folder_path, BACKUP_MODEL_NAME.format('disc'))
+        gen_model_path = os.path.join(models_folder_path, BACKUP_MODEL_NAME.format('gen'))
+
+    if load_model:
+
+        models_folder_path = os.path.join(BACKUP_FOLDER, p_load)
+        dis_model_path = os.path.join(models_folder_path, BACKUP_MODEL_NAME.format('disc'))
+        gen_model_path = os.path.join(models_folder_path, BACKUP_MODEL_NAME.format('gen'))
+
+    # Construct model
     net_discr = Discriminator(input_shape=input_shape).to(device)
     net_gener = Generator(output_shape=input_shape).to(device)
     print(net_discr)
@@ -148,6 +214,24 @@ if __name__ == "__main__":
     true_labels_v = torch.ones(BATCH_SIZE, dtype=torch.float32, device=device)
     fake_labels_v = torch.zeros(BATCH_SIZE, dtype=torch.float32, device=device)
 
+
+    # load models checkpoint if exists
+    if load_model:
+        gen_checkpoint = torch.load(gen_model_path)
+
+        net_gener.load_state_dict(gen_checkpoint['model_state_dict'])
+        gen_optimizer.load_state_dict(gen_checkpoint['optimizer_state_dict'])
+        gen_losses = gen_checkpoint['gen_losses']
+        iteration = gen_checkpoint['iteration'] # retrieve only from the gen net the iteration number
+
+        dis_checkpoint = torch.load(dis_model_path)
+
+        net_discr.load_state_dict(dis_checkpoint['model_state_dict'])
+        dis_optimizer.load_state_dict(dis_checkpoint['optimizer_state_dict'])
+        dis_losses = dis_checkpoint['dis_losses']
+
+        iter_no = iteration
+
     for batch_v in iterate_batches():
 
         # generate extra fake samples, input is 4D: batch, filters, x, y
@@ -155,14 +239,13 @@ if __name__ == "__main__":
 
         # There we get data
         batch_v = batch_v.to(device)
+
         gen_output_v = net_gener(gen_input_v)
 
         # train discriminator
         dis_optimizer.zero_grad()
         dis_output_true_v = net_discr(batch_v)
-        print(len(dis_output_true_v))
         dis_output_fake_v = net_discr(gen_output_v.detach())
-        print(len(dis_output_true_v))
         dis_loss = objective(dis_output_true_v, true_labels_v) + objective(dis_output_fake_v, fake_labels_v)
         dis_loss.backward()
         dis_optimizer.step()
@@ -177,12 +260,38 @@ if __name__ == "__main__":
         gen_losses.append(gen_loss_v.item())
 
         iter_no += 1
+        print("Iteration : ", iter_no)
+
         if iter_no % REPORT_EVERY_ITER == 0:
             log.info("Iter %d: gen_loss=%.3e, dis_loss=%.3e", iter_no, np.mean(gen_losses), np.mean(dis_losses))
             writer.add_scalar("gen_loss", np.mean(gen_losses), iter_no)
             writer.add_scalar("dis_loss", np.mean(dis_losses), iter_no)
             gen_losses = []
             dis_losses = []
+
         if iter_no % SAVE_IMAGE_EVERY_ITER == 0:
             writer.add_image("fake", vutils.make_grid(gen_output_v.data[:IMAGE_SIZE], normalize=True), iter_no)
             writer.add_image("real", vutils.make_grid(batch_v.data[:IMAGE_SIZE], normalize=True), iter_no)
+
+        if iter_no % BACKUP_EVERY_ITER == 0:
+            if not os.path.exists(models_folder_path):
+                os.makedirs(models_folder_path)
+
+            torch.save({
+                        'iteration': iter_no,
+                        'model_state_dict': net_gener.state_dict(),
+                        'optimizer_state_dict': gen_optimizer.state_dict(),
+                        'gen_losses': gen_losses
+                    }, gen_model_path)
+
+            torch.save({
+                        'iteration': iter_no,
+                        'model_state_dict': net_discr.state_dict(),
+                        'optimizer_state_dict': dis_optimizer.state_dict(),
+                        'dis_losses': dis_losses
+                    }, dis_model_path)
+
+
+
+
+

+ 58 - 36
noise_gan.ipynb

@@ -9,54 +9,76 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "(3, 64, 64)\n",
+      "Discriminator(\n",
+      "  (conv_pipe): Sequential(\n",
+      "    (0): Conv2d(3, 200, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+      "    (1): ReLU()\n",
+      "    (2): Conv2d(200, 400, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+      "    (3): BatchNorm2d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+      "    (4): ReLU()\n",
+      "    (5): Conv2d(400, 800, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+      "    (6): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+      "    (7): ReLU()\n",
+      "    (8): Conv2d(800, 1600, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+      "    (9): BatchNorm2d(1600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+      "    (10): ReLU()\n",
+      "    (11): Conv2d(1600, 3200, kernel_size=(8, 8), stride=(2, 2), padding=(1, 1))\n",
+      "    (12): BatchNorm2d(3200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+      "    (13): ReLU()\n",
+      "    (14): Conv2d(3200, 1, kernel_size=(4, 4), stride=(1, 1))\n",
+      "    (15): Sigmoid()\n",
+      "  )\n",
+      ")\n",
       "Generator(\n",
       "  (pipe): Sequential(\n",
-      "    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1))\n",
-      "    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+      "    (0): ConvTranspose2d(400, 1000, kernel_size=(6, 6), stride=(1, 1))\n",
+      "    (1): BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU()\n",
-      "    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
-      "    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+      "    (3): ConvTranspose2d(1000, 800, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+      "    (4): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (5): ReLU()\n",
-      "    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
-      "    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+      "    (6): ConvTranspose2d(800, 600, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+      "    (7): BatchNorm2d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (8): ReLU()\n",
-      "    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
-      "    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+      "    (9): ConvTranspose2d(600, 400, kernel_size=(6, 6), stride=(2, 2), padding=(1, 1))\n",
+      "    (10): BatchNorm2d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (11): ReLU()\n",
-      "    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
-      "    (13): Tanh()\n",
+      "    (12): ConvTranspose2d(400, 200, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+      "    (13): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+      "    (14): ReLU()\n",
+      "    (15): ConvTranspose2d(200, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+      "    (16): Tanh()\n",
       "  )\n",
       ")\n",
-      "INFO: Iter 50: gen_loss=3.107e+00, dis_loss=3.545e-01\n",
-      "INFO: Iter 100: gen_loss=5.438e+00, dis_loss=3.737e-02\n",
-      "INFO: Iter 150: gen_loss=6.108e+00, dis_loss=1.413e-02\n",
-      "INFO: Iter 200: gen_loss=6.497e+00, dis_loss=7.928e-03\n",
-      "INFO: Iter 250: gen_loss=6.661e+00, dis_loss=4.426e-03\n",
-      "INFO: Iter 300: gen_loss=6.974e+00, dis_loss=3.058e-03\n",
-      "INFO: Iter 350: gen_loss=7.264e+00, dis_loss=3.352e-03\n",
-      "INFO: Iter 400: gen_loss=8.008e+00, dis_loss=1.973e-03\n",
-      "INFO: Iter 450: gen_loss=7.612e+00, dis_loss=7.998e-02\n",
-      "INFO: Iter 500: gen_loss=7.241e+00, dis_loss=2.489e-02\n",
-      "INFO: Iter 550: gen_loss=6.099e+00, dis_loss=3.014e-01\n",
-      "INFO: Iter 600: gen_loss=5.541e+00, dis_loss=3.168e-01\n",
-      "INFO: Iter 650: gen_loss=4.301e+00, dis_loss=3.208e-01\n",
-      "INFO: Iter 700: gen_loss=4.265e+00, dis_loss=3.301e-01\n",
-      "INFO: Iter 750: gen_loss=4.739e+00, dis_loss=1.881e-01\n",
-      "INFO: Iter 800: gen_loss=4.413e+00, dis_loss=2.648e-01\n",
-      "INFO: Iter 850: gen_loss=5.190e+00, dis_loss=1.595e-01\n",
-      "INFO: Iter 900: gen_loss=5.641e+00, dis_loss=1.241e-01\n",
-      "INFO: Iter 950: gen_loss=4.454e+00, dis_loss=4.375e-01\n",
-      "INFO: Iter 1000: gen_loss=4.169e+00, dis_loss=2.384e-01\n",
-      "INFO: Iter 1050: gen_loss=4.946e+00, dis_loss=1.716e-01\n",
-      "INFO: Iter 1100: gen_loss=4.234e+00, dis_loss=2.619e-01\n",
-      "INFO: Iter 1150: gen_loss=4.644e+00, dis_loss=1.127e-01\n",
-      "INFO: Iter 1200: gen_loss=4.943e+00, dis_loss=2.967e-01\n"
+      "Iteration :  20\n",
+      "INFO: Iter 20: gen_loss=9.649e+00, dis_loss=5.266e-01\n",
+      "Iteration :  21\n",
+      "Iteration :  22\n",
+      "Iteration :  23\n",
+      "Iteration :  24\n",
+      "Iteration :  25\n",
+      "Iteration :  26\n",
+      "Iteration :  27\n",
+      "Iteration :  28\n",
+      "Iteration :  29\n",
+      "Iteration :  30\n",
+      "INFO: Iter 30: gen_loss=6.851e+00, dis_loss=1.911e-01\n",
+      "Iteration :  31\n",
+      "Iteration :  32\n",
+      "Iteration :  33\n",
+      "Iteration :  34\n",
+      "Iteration :  35\n",
+      "Iteration :  36\n",
+      "Iteration :  37\n",
+      "Iteration :  38\n",
+      "Iteration :  39\n",
+      "Iteration :  40\n",
+      "INFO: Iter 40: gen_loss=6.467e+00, dis_loss=3.223e-01\n"
      ]
     }
    ],
    "source": [
-    "!python ganSynthesisImage.py"
+    "!python ganSynthesisImage_200.py --load test_model"
    ]
   },
   {

+ 4 - 3
synthesis_images/prepare_data.py

@@ -4,8 +4,9 @@ from PIL import Image
 import shutil
 import os
 
-images_folder = "images"
-dest_folder = "generated_blocks"
+main_image_folder = "synthesis_images"
+images_folder = os.path.join(main_image_folder, "images")
+dest_folder = os.path.join(main_image_folder, "generated_blocks")
 
 if os.path.exists(dest_folder):
     # first remove folder if necessary
@@ -21,7 +22,7 @@ for img_path in images:
 
     img = Image.open(os.path.join(images_folder, img_path))
 
-    blocks = processing.divide_in_blocks(img, (80, 80), pil=True)
+    blocks = processing.divide_in_blocks(img, (200, 200), pil=True)
 
     for id, pil_block in enumerate(blocks):
         img_name = img_path.split('/')[-1]

BIN
synthesis_images/images/SdB2_00930.png


BIN
synthesis_images/images/SdB2_00940.png


BIN
synthesis_images/images/SdB2_00950.png


BIN
synthesis_images/images/SdB2_D_00930.png


BIN
synthesis_images/images/SdB2_D_00940.png


BIN
synthesis_images/images/SdB2_D_00950.png


BIN
synthesis_images/images/appartAopt_00850.png


BIN
synthesis_images/images/appartAopt_00860.png


BIN
synthesis_images/images/appartAopt_00870.png


BIN
synthesis_images/images/cuisine01_01180.png


BIN
synthesis_images/images/cuisine01_01190.png


BIN
synthesis_images/images/cuisine01_01200.png


+ 11 - 9
tensorboard.ipynb

@@ -1,5 +1,14 @@
 {
  "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!rm -rf runs/"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -9,20 +18,13 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "TensorBoard 1.12.2 at http://f735e09f9d2f:6006 (Press CTRL+C to quit)\n"
+      "TensorBoard 1.12.2 at http://7447cb2679c8:6006 (Press CTRL+C to quit)\n"
      ]
     }
    ],
    "source": [
-    "!tensorboard --logdir runs/Jan23_10-21-50_f735e09f9d2f"
+    "!tensorboard --logdir runs"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {