Browse source

try using 200-pixel images

Jérôme BUISINE 5 years ago
Parent
commit
70d4356303
5 files changed with 224 additions and 8 deletions
  1. .gitignore (+1 -0)
  2. ganSynthesisImage.py (+9 -0)
  3. ganSynthesisImage_200.py (+188 -0)
  4. noise_gan.ipynb (+24 -6)
  5. tensorboard.ipynb (+2 -2)

+ 1 - 0
.gitignore

@@ -4,3 +4,4 @@ runs
 
 # do not track blocks images
 synthesis_images/generated_blocks
+saved_models

+ 9 - 0
ganSynthesisImage.py

@@ -35,6 +35,7 @@ SAVE_IMAGE_EVERY_ITER = 200
 MAX_ITERATION = 100000
 
 data_folder = 'synthesis_images/generated_blocks'
+models_folder = 'saved_models'
 
 class Discriminator(nn.Module):
     def __init__(self, input_shape):
@@ -182,3 +183,11 @@ if __name__ == "__main__":
         if iter_no % SAVE_IMAGE_EVERY_ITER == 0:
             writer.add_image("fake", vutils.make_grid(gen_output_v.data[:IMAGE_SIZE], normalize=True), iter_no)
             writer.add_image("real", vutils.make_grid(batch_v.data[:IMAGE_SIZE], normalize=True), iter_no)
+
+        if iter_no >= MAX_ITERATION:
+            # end of training
+            break
+
+    # now save these two models
+    torch.save(net_discr.state_dict(), os.path.join(models_folder, 'net_discr_model.pt'))
+    torch.save(net_gener.state_dict(), os.path.join(models_folder, 'net_gener_model.pt'))
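
Since the state dicts are now written at the end of training, a short sketch of how they could be reloaded later may be useful (standard PyTorch load_state_dict usage; Discriminator, Generator, input_shape and models_folder are the names defined in this script). Note that torch.save will fail if the saved_models directory does not exist yet; a guard such as os.makedirs(models_folder, exist_ok=True) before saving would be a safe addition.

    # sketch, not part of the commit: restore the saved weights
    net_discr = Discriminator(input_shape=input_shape)
    net_gener = Generator(output_shape=input_shape)
    net_discr.load_state_dict(torch.load(os.path.join(models_folder, 'net_discr_model.pt')))
    net_gener.load_state_dict(torch.load(os.path.join(models_folder, 'net_gener_model.pt')))
    net_gener.eval()  # freeze BatchNorm statistics before sampling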

+ 188 - 0
ganSynthesisImage_200.py

@@ -0,0 +1,188 @@
+#!/usr/bin/env python
+import random
+import argparse
+import cv2
+import os
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from tensorboardX import SummaryWriter
+from PIL import Image
+
+import torchvision.utils as vutils
+
+import gym
+import gym.spaces
+
+import numpy as np
+
+log = gym.logger
+log.set_level(gym.logger.INFO)
+
+LATENT_VECTOR_SIZE = 100
+DISCR_FILTERS = 200
+GENER_FILTERS = 200
+BATCH_SIZE = 16
+
+# dimension to which input images will be rescaled
+IMAGE_SIZE = 200
+input_shape = (3, IMAGE_SIZE, IMAGE_SIZE)
+
+LEARNING_RATE = 0.0001
+REPORT_EVERY_ITER = 50
+SAVE_IMAGE_EVERY_ITER = 200
+MAX_ITERATION = 100000
+
+data_folder = 'synthesis_images/generated_blocks'
+
+class Discriminator(nn.Module):
+    def __init__(self, input_shape):
+        super(Discriminator, self).__init__()
+        # this pipe converts the image into a single number
+        self.conv_pipe = nn.Sequential(
+            nn.Conv2d(in_channels=input_shape[0], out_channels=DISCR_FILTERS,
+                      kernel_size=4, stride=2, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(in_channels=DISCR_FILTERS, out_channels=DISCR_FILTERS*2,
+                      kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(DISCR_FILTERS*2),
+            nn.ReLU(),
+            nn.Conv2d(in_channels=DISCR_FILTERS * 2, out_channels=DISCR_FILTERS * 4,
+                      kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(DISCR_FILTERS * 4),
+            nn.ReLU(),
+            nn.Conv2d(in_channels=DISCR_FILTERS * 4, out_channels=DISCR_FILTERS * 8,
+                      kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(DISCR_FILTERS * 8),
+            nn.ReLU(),
+            nn.Conv2d(in_channels=DISCR_FILTERS * 8, out_channels=1,
+                      kernel_size=4, stride=1, padding=0),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        conv_out = self.conv_pipe(x)
+        # note: with 200x200 inputs the conv pipe leaves a (1, 9, 9) map,
+        # so this returns batch_size * 81 values instead of batch_size
+        return conv_out.view(-1, 1).squeeze(dim=1)
+
+
+class Generator(nn.Module):
+    def __init__(self, output_shape):
+        super(Generator, self).__init__()
+        # pipe deconvolves the input vector into a (3, 64, 64) image (still 64x64, not IMAGE_SIZE)
+        self.pipe = nn.Sequential(
+            nn.ConvTranspose2d(in_channels=LATENT_VECTOR_SIZE, out_channels=GENER_FILTERS * 6,
+                               kernel_size=4, stride=1, padding=0),
+            nn.BatchNorm2d(GENER_FILTERS * 6),
+            nn.ReLU(),
+            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 6, out_channels=GENER_FILTERS * 4,
+                               kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(GENER_FILTERS * 4),
+            nn.ReLU(),
+            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 4, out_channels=GENER_FILTERS * 2,
+                               kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(GENER_FILTERS * 2),
+            nn.ReLU(),
+            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 2, out_channels=GENER_FILTERS,
+                               kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(GENER_FILTERS),
+            nn.ReLU(),
+            nn.ConvTranspose2d(in_channels=GENER_FILTERS, out_channels=output_shape[0],
+                               kernel_size=4, stride=2, padding=1),
+            nn.Tanh()
+        )
+
+    def forward(self, x):
+        return self.pipe(x)
+
+# generate batches from the final or noisy synthesis images
+def iterate_batches(batch_size=BATCH_SIZE):
+
+    batch = []
+    images = os.listdir(data_folder)
+    nb_images = len(images)
+
+    while True:
+        i = random.randint(0, nb_images - 1)
+
+        img = Image.open(os.path.join(data_folder, images[i]))
+        img_arr = np.asarray(img)
+
+        new_obs = cv2.resize(img_arr, (IMAGE_SIZE, IMAGE_SIZE))
+        # transform (IMAGE_SIZE, IMAGE_SIZE, 3) -> (3, IMAGE_SIZE, IMAGE_SIZE)
+        new_obs = np.moveaxis(new_obs, 2, 0)
+
+        batch.append(new_obs)
+
+        if len(batch) == batch_size:
+            # normalise input to the range [-1, 1]
+            batch_np = np.array(batch, dtype=np.float32) * 2.0 / 255.0 - 1.0
+            yield torch.tensor(batch_np)
+            batch.clear()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--cuda", default=False, action='store_true', help="Enable cuda computation")
+    args = parser.parse_args()
+
+    device = torch.device("cuda" if args.cuda else "cpu")
+
+    print(input_shape)
+    net_discr = Discriminator(input_shape=input_shape).to(device)
+    net_gener = Generator(output_shape=input_shape).to(device)
+    print(net_discr)
+    print(net_gener)
+
+    objective = nn.BCELoss()
+    gen_optimizer = optim.Adam(params=net_gener.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
+    dis_optimizer = optim.Adam(params=net_discr.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
+    writer = SummaryWriter()
+
+    gen_losses = []
+    dis_losses = []
+    iter_no = 0
+
+    true_labels_v = torch.ones(BATCH_SIZE, dtype=torch.float32, device=device)
+    fake_labels_v = torch.zeros(BATCH_SIZE, dtype=torch.float32, device=device)
+
+    for batch_v in iterate_batches():
+
+        # sample latent vectors for fake samples; input is 4D: batch, filters, x, y
+        gen_input_v = torch.FloatTensor(BATCH_SIZE, LATENT_VECTOR_SIZE, 1, 1).normal_(0, 1).to(device)
+
+        # move the real batch to the device
+        batch_v = batch_v.to(device)
+        gen_output_v = net_gener(gen_input_v)
+
+        # train discriminator
+        dis_optimizer.zero_grad()
+        dis_output_true_v = net_discr(batch_v)
+        dis_output_fake_v = net_discr(gen_output_v.detach())
+        dis_loss = objective(dis_output_true_v, true_labels_v) + objective(dis_output_fake_v, fake_labels_v)
+        dis_loss.backward()
+        dis_optimizer.step()
+        dis_losses.append(dis_loss.item())
+
+        # train generator
+        gen_optimizer.zero_grad()
+        dis_output_v = net_discr(gen_output_v)
+        gen_loss_v = objective(dis_output_v, true_labels_v)
+        gen_loss_v.backward()
+        gen_optimizer.step()
+        gen_losses.append(gen_loss_v.item())
+
+        iter_no += 1
+        if iter_no % REPORT_EVERY_ITER == 0:
+            log.info("Iter %d: gen_loss=%.3e, dis_loss=%.3e", iter_no, np.mean(gen_losses), np.mean(dis_losses))
+            writer.add_scalar("gen_loss", np.mean(gen_losses), iter_no)
+            writer.add_scalar("dis_loss", np.mean(dis_losses), iter_no)
+            gen_losses = []
+            dis_losses = []
+        if iter_no % SAVE_IMAGE_EVERY_ITER == 0:
+            writer.add_image("fake", vutils.make_grid(gen_output_v.data[:IMAGE_SIZE], normalize=True), iter_no)
+            writer.add_image("real", vutils.make_grid(batch_v.data[:IMAGE_SIZE], normalize=True), iter_no)

+ 24 - 6
noise_gan.ipynb

@@ -28,12 +28,30 @@
       "    (13): Tanh()\n",
       "  )\n",
       ")\n",
-      "INFO: Iter 50: gen_loss=2.939e+00, dis_loss=4.371e-01\n",
-      "INFO: Iter 100: gen_loss=5.404e+00, dis_loss=5.104e-02\n",
-      "INFO: Iter 150: gen_loss=6.753e+00, dis_loss=2.676e-02\n",
-      "INFO: Iter 200: gen_loss=6.347e+00, dis_loss=2.015e-02\n",
-      "INFO: Iter 250: gen_loss=6.294e+00, dis_loss=3.237e-01\n",
-      "INFO: Iter 300: gen_loss=5.793e+00, dis_loss=2.149e-01\n"
+      "INFO: Iter 50: gen_loss=3.107e+00, dis_loss=3.545e-01\n",
+      "INFO: Iter 100: gen_loss=5.438e+00, dis_loss=3.737e-02\n",
+      "INFO: Iter 150: gen_loss=6.108e+00, dis_loss=1.413e-02\n",
+      "INFO: Iter 200: gen_loss=6.497e+00, dis_loss=7.928e-03\n",
+      "INFO: Iter 250: gen_loss=6.661e+00, dis_loss=4.426e-03\n",
+      "INFO: Iter 300: gen_loss=6.974e+00, dis_loss=3.058e-03\n",
+      "INFO: Iter 350: gen_loss=7.264e+00, dis_loss=3.352e-03\n",
+      "INFO: Iter 400: gen_loss=8.008e+00, dis_loss=1.973e-03\n",
+      "INFO: Iter 450: gen_loss=7.612e+00, dis_loss=7.998e-02\n",
+      "INFO: Iter 500: gen_loss=7.241e+00, dis_loss=2.489e-02\n",
+      "INFO: Iter 550: gen_loss=6.099e+00, dis_loss=3.014e-01\n",
+      "INFO: Iter 600: gen_loss=5.541e+00, dis_loss=3.168e-01\n",
+      "INFO: Iter 650: gen_loss=4.301e+00, dis_loss=3.208e-01\n",
+      "INFO: Iter 700: gen_loss=4.265e+00, dis_loss=3.301e-01\n",
+      "INFO: Iter 750: gen_loss=4.739e+00, dis_loss=1.881e-01\n",
+      "INFO: Iter 800: gen_loss=4.413e+00, dis_loss=2.648e-01\n",
+      "INFO: Iter 850: gen_loss=5.190e+00, dis_loss=1.595e-01\n",
+      "INFO: Iter 900: gen_loss=5.641e+00, dis_loss=1.241e-01\n",
+      "INFO: Iter 950: gen_loss=4.454e+00, dis_loss=4.375e-01\n",
+      "INFO: Iter 1000: gen_loss=4.169e+00, dis_loss=2.384e-01\n",
+      "INFO: Iter 1050: gen_loss=4.946e+00, dis_loss=1.716e-01\n",
+      "INFO: Iter 1100: gen_loss=4.234e+00, dis_loss=2.619e-01\n",
+      "INFO: Iter 1150: gen_loss=4.644e+00, dis_loss=1.127e-01\n",
+      "INFO: Iter 1200: gen_loss=4.943e+00, dis_loss=2.967e-01\n"
      ]
     }
    ],

+ 2 - 2
tensorboard.ipynb

@@ -9,12 +9,12 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "TensorBoard 1.12.2 at http://e0614dc78f89:6006 (Press CTRL+C to quit)\n"
+      "TensorBoard 1.12.2 at http://f735e09f9d2f:6006 (Press CTRL+C to quit)\n"
      ]
     }
    ],
    "source": [
-    "!tensorboard --logdir runs/Jan22_17-09-18_e0614dc78f89"
+    "!tensorboard --logdir runs/Jan23_10-21-50_f735e09f9d2f"
    ]
   },
   {
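
Since the run directory name changes on every run (timestamp plus container hostname), pointing TensorBoard at the parent runs/ directory avoids editing this cell each time; TensorBoard scans subdirectories and shows each run as a separate series:

    !tensorboard --logdir runs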