ganSynthesisImage.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. #!/usr/bin/env python
  2. import random
  3. import argparse
  4. import cv2
  5. import os
  6. import torch
  7. import torch.nn as nn
  8. import torch.optim as optim
  9. from tensorboardX import SummaryWriter
  10. from PIL import Image
  11. import torchvision.utils as vutils
  12. import gym
  13. import gym.spaces
  14. import numpy as np
  15. log = gym.logger
  16. log.set_level(gym.logger.INFO)
  17. LATENT_VECTOR_SIZE = 100
  18. DISCR_FILTERS = 64
  19. GENER_FILTERS = 64
  20. BATCH_SIZE = 16
  21. # dimension input image will be rescaled
  22. IMAGE_SIZE = 64
  23. input_shape = (3, IMAGE_SIZE, IMAGE_SIZE)
  24. LEARNING_RATE = 0.0001
  25. REPORT_EVERY_ITER = 50
  26. SAVE_IMAGE_EVERY_ITER = 200
  27. MAX_ITERATION = 100000
  28. data_folder = 'synthesis_images/generated_blocks'
  29. models_folder = 'saved_models'
  30. class Discriminator(nn.Module):
  31. def __init__(self, input_shape):
  32. super(Discriminator, self).__init__()
  33. # this pipe converges image into the single number
  34. self.conv_pipe = nn.Sequential(
  35. nn.Conv2d(in_channels=input_shape[0], out_channels=DISCR_FILTERS,
  36. kernel_size=4, stride=2, padding=1),
  37. nn.ReLU(),
  38. nn.Conv2d(in_channels=DISCR_FILTERS, out_channels=DISCR_FILTERS*2,
  39. kernel_size=4, stride=2, padding=1),
  40. nn.BatchNorm2d(DISCR_FILTERS*2),
  41. nn.ReLU(),
  42. nn.Conv2d(in_channels=DISCR_FILTERS * 2, out_channels=DISCR_FILTERS * 4,
  43. kernel_size=4, stride=2, padding=1),
  44. nn.BatchNorm2d(DISCR_FILTERS * 4),
  45. nn.ReLU(),
  46. nn.Conv2d(in_channels=DISCR_FILTERS * 4, out_channels=DISCR_FILTERS * 8,
  47. kernel_size=4, stride=2, padding=1),
  48. nn.BatchNorm2d(DISCR_FILTERS * 8),
  49. nn.ReLU(),
  50. nn.Conv2d(in_channels=DISCR_FILTERS * 8, out_channels=1,
  51. kernel_size=4, stride=1, padding=0),
  52. nn.Sigmoid()
  53. )
  54. def forward(self, x):
  55. conv_out = self.conv_pipe(x)
  56. return conv_out.view(-1, 1).squeeze(dim=1)
  57. class Generator(nn.Module):
  58. def __init__(self, output_shape):
  59. super(Generator, self).__init__()
  60. # pipe deconvolves input vector into (3, 64, 64) image
  61. self.pipe = nn.Sequential(
  62. nn.ConvTranspose2d(in_channels=LATENT_VECTOR_SIZE, out_channels=GENER_FILTERS * 8,
  63. kernel_size=4, stride=1, padding=0),
  64. nn.BatchNorm2d(GENER_FILTERS * 8),
  65. nn.ReLU(),
  66. nn.ConvTranspose2d(in_channels=GENER_FILTERS * 8, out_channels=GENER_FILTERS * 4,
  67. kernel_size=4, stride=2, padding=1),
  68. nn.BatchNorm2d(GENER_FILTERS * 4),
  69. nn.ReLU(),
  70. nn.ConvTranspose2d(in_channels=GENER_FILTERS * 4, out_channels=GENER_FILTERS * 2,
  71. kernel_size=4, stride=2, padding=1),
  72. nn.BatchNorm2d(GENER_FILTERS * 2),
  73. nn.ReLU(),
  74. nn.ConvTranspose2d(in_channels=GENER_FILTERS * 2, out_channels=GENER_FILTERS,
  75. kernel_size=4, stride=2, padding=1),
  76. nn.BatchNorm2d(GENER_FILTERS),
  77. nn.ReLU(),
  78. nn.ConvTranspose2d(in_channels=GENER_FILTERS, out_channels=output_shape[0],
  79. kernel_size=4, stride=2, padding=1),
  80. nn.Tanh()
  81. )
  82. def forward(self, x):
  83. return self.pipe(x)
  84. # here we have to generate our batches from final or noisy synthesis images
  85. def iterate_batches(batch_size=BATCH_SIZE):
  86. batch = []
  87. images = os.listdir(data_folder)
  88. nb_images = len(images)
  89. while True:
  90. i = random.randint(0, nb_images - 1)
  91. img = Image.open(os.path.join(data_folder, images[i]))
  92. img_arr = np.asarray(img)
  93. new_obs = cv2.resize(img_arr, (IMAGE_SIZE, IMAGE_SIZE))
  94. # transform (210, 160, 3) -> (3, 210, 160)
  95. new_obs = np.moveaxis(new_obs, 2, 0)
  96. batch.append(new_obs)
  97. if len(batch) == batch_size:
  98. # Normalising input between -1 to 1
  99. batch_np = np.array(batch, dtype=np.float32) * 2.0 / 255.0 - 1.0
  100. yield torch.tensor(batch_np)
  101. batch.clear()
  102. if __name__ == "__main__":
  103. parser = argparse.ArgumentParser()
  104. parser.add_argument("--cuda", default=False, action='store_true', help="Enable cuda computation")
  105. args = parser.parse_args()
  106. device = torch.device("cuda" if args.cuda else "cpu")
  107. #envs = [InputWrapper(gym.make(name)) for name in ('Breakout-v0', 'AirRaid-v0', 'Pong-v0')]
  108. print(input_shape)
  109. net_discr = Discriminator(input_shape=input_shape).to(device)
  110. net_gener = Generator(output_shape=input_shape).to(device)
  111. print(net_gener)
  112. objective = nn.BCELoss()
  113. gen_optimizer = optim.Adam(params=net_gener.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
  114. dis_optimizer = optim.Adam(params=net_discr.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
  115. writer = SummaryWriter()
  116. gen_losses = []
  117. dis_losses = []
  118. iter_no = 0
  119. true_labels_v = torch.ones(BATCH_SIZE, dtype=torch.float32, device=device)
  120. fake_labels_v = torch.zeros(BATCH_SIZE, dtype=torch.float32, device=device)
  121. for batch_v in iterate_batches():
  122. # generate extra fake samples, input is 4D: batch, filters, x, y
  123. gen_input_v = torch.FloatTensor(BATCH_SIZE, LATENT_VECTOR_SIZE, 1, 1).normal_(0, 1).to(device)
  124. # There we get data
  125. batch_v = batch_v.to(device)
  126. gen_output_v = net_gener(gen_input_v)
  127. # train discriminator
  128. dis_optimizer.zero_grad()
  129. dis_output_true_v = net_discr(batch_v)
  130. dis_output_fake_v = net_discr(gen_output_v.detach())
  131. dis_loss = objective(dis_output_true_v, true_labels_v) + objective(dis_output_fake_v, fake_labels_v)
  132. dis_loss.backward()
  133. dis_optimizer.step()
  134. dis_losses.append(dis_loss.item())
  135. # train generator
  136. gen_optimizer.zero_grad()
  137. dis_output_v = net_discr(gen_output_v)
  138. gen_loss_v = objective(dis_output_v, true_labels_v)
  139. gen_loss_v.backward()
  140. gen_optimizer.step()
  141. gen_losses.append(gen_loss_v.item())
  142. iter_no += 1
  143. if iter_no % REPORT_EVERY_ITER == 0:
  144. log.info("Iter %d: gen_loss=%.3e, dis_loss=%.3e", iter_no, np.mean(gen_losses), np.mean(dis_losses))
  145. writer.add_scalar("gen_loss", np.mean(gen_losses), iter_no)
  146. writer.add_scalar("dis_loss", np.mean(dis_losses), iter_no)
  147. gen_losses = []
  148. dis_losses = []
  149. if iter_no % SAVE_IMAGE_EVERY_ITER == 0:
  150. writer.add_image("fake", vutils.make_grid(gen_output_v.data[:IMAGE_SIZE], normalize=True), iter_no)
  151. writer.add_image("real", vutils.make_grid(batch_v.data[:IMAGE_SIZE], normalize=True), iter_no)
  152. if iter_no >= MAX_ITERATION:
  153. # end of train
  154. break
  155. # now save these two models
  156. torch.save(net_discr.state_dict(), os.path.join(models_folder, 'net_discr_model.pt'))
  157. torch.save(net_gener.state_dict(), os.path.join(models_folder, 'net_gener_model.pt'))