ganSynthesisImage_200.py

#!/usr/bin/env python
import random
import argparse
import cv2
import os, sys, getopt
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from PIL import Image
import torchvision.utils as vutils
import gym
import gym.spaces
import numpy as np

log = gym.logger
log.set_level(gym.logger.INFO)

LATENT_VECTOR_SIZE = 400
DISCR_FILTERS = 200
GENER_FILTERS = 200
BATCH_SIZE = 16
# dimension to which the input images will be rescaled
IMAGE_SIZE = 200
input_shape = (3, IMAGE_SIZE, IMAGE_SIZE)
BACKUP_MODEL_NAME = "synthesis_{}_model.pt"
BACKUP_FOLDER = "saved_models"
BACKUP_EVERY_ITER = 1
LEARNING_RATE = 0.0001
REPORT_EVERY_ITER = 10
SAVE_IMAGE_EVERY_ITER = 20
MAX_ITERATION = 100000
data_folder = 'synthesis_images/generated_blocks'

class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        # this pipe converts the input image into a single probability
        self.conv_pipe = nn.Sequential(
            nn.Conv2d(in_channels=input_shape[0], out_channels=DISCR_FILTERS,
                      kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS, out_channels=DISCR_FILTERS*2,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS*2),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 2, out_channels=DISCR_FILTERS * 4,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 4),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 4, out_channels=DISCR_FILTERS * 8,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 8),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 8, out_channels=DISCR_FILTERS * 16,
                      kernel_size=8, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 16),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 16, out_channels=1,
                      kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        conv_out = self.conv_pipe(x)
        return conv_out.view(-1, 1).squeeze(dim=1)
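
# Hedged sanity check (an added helper, not used by the training flow): verifies that
# the conv pipe collapses a (3, 200, 200) image into one probability per sample.
# The strided convolutions shrink the spatial size 200 -> 100 -> 50 -> 25 -> 12 -> 4,
# and the final 4x4 convolution reduces it to 1x1 before the Sigmoid.
def _check_discriminator_shape():
    net = Discriminator(input_shape=input_shape)
    dummy = torch.zeros(2, *input_shape)       # two dummy RGB images, 200x200
    out = net(dummy)
    assert out.shape == (2,), out.shape        # one probability per image
    return out.shape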

class Generator(nn.Module):
    def __init__(self, output_shape):
        super(Generator, self).__init__()
        # pipe deconvolves the input vector into a (3, 200, 200) image
        self.pipe = nn.Sequential(
            nn.ConvTranspose2d(in_channels=LATENT_VECTOR_SIZE, out_channels=GENER_FILTERS * 5,
                               kernel_size=6, stride=1, padding=0),
            nn.BatchNorm2d(GENER_FILTERS * 5),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 5, out_channels=GENER_FILTERS * 4,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 4, out_channels=GENER_FILTERS * 3,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 3),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 3, out_channels=GENER_FILTERS * 2,
                               kernel_size=6, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 2, out_channels=GENER_FILTERS,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS, out_channels=output_shape[0],
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.pipe(x)
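
# Hedged sanity check (an added helper, not called during training): the transposed-conv
# stack expands the 1x1 latent vector to 6 -> 12 -> 24 -> 50 -> 100 -> 200 pixels,
# so the generator output matches the (3, 200, 200) input expected by the discriminator.
def _check_generator_shape():
    net = Generator(output_shape=input_shape)
    noise = torch.randn(2, LATENT_VECTOR_SIZE, 1, 1)    # batch of 2 latent vectors
    out = net(noise)
    assert out.shape == (2, *input_shape), out.shape    # (2, 3, 200, 200), values in [-1, 1]
    return out.shape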

# generate training batches from the synthesis images stored on disk
def iterate_batches(batch_size=BATCH_SIZE):
    batch = []
    images = os.listdir(data_folder)
    nb_images = len(images)
    while True:
        i = random.randint(0, nb_images - 1)
        # convert to RGB so the array always has 3 channels
        img = Image.open(os.path.join(data_folder, images[i])).convert("RGB")
        img_arr = np.asarray(img)
        new_obs = cv2.resize(img_arr, (IMAGE_SIZE, IMAGE_SIZE))
        # transform (IMAGE_SIZE, IMAGE_SIZE, 3) -> (3, IMAGE_SIZE, IMAGE_SIZE)
        new_obs = np.moveaxis(new_obs, 2, 0)
        batch.append(new_obs)
        if len(batch) == batch_size:
            # normalise input to the range [-1, 1]
            batch_np = np.array(batch, dtype=np.float32) * 2.0 / 255.0 - 1.0
            yield torch.tensor(batch_np)
            batch.clear()
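
# Hedged usage sketch (an added helper, not executed by the script): pull a single batch
# from iterate_batches() to confirm its shape and value range. Assumes data_folder
# already contains at least one image.
def _check_one_batch():
    batch = next(iterate_batches())
    assert batch.shape == (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE), batch.shape
    assert float(batch.min()) >= -1.0 and float(batch.max()) <= 1.0  # pixels scaled to [-1, 1]
    return batch.shape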

if __name__ == "__main__":
    save_model = False
    load_model = False
    p_cuda = False
    #parser = argparse.ArgumentParser()
    #parser.add_argument("--cuda", default=False, action='store_true', help="Enable cuda computation")
    #args = parser.parse_args()
    try:
        # -f/--folder, -l/--load and -c/--cuda all take a value
        opts, args = getopt.getopt(sys.argv[1:], "hf:l:c:", ["help", "folder=", "load=", "cuda="])
    except getopt.GetoptError:
        # print help information and exit
        print('python ganSynthesisImage_200.py --folder folder_name_to_save --cuda 1')
        print('python ganSynthesisImage_200.py --load model_name_to_load')
        sys.exit(2)
    for o, a in opts:
        if o in ("-h", "--help"):
            print('python ganSynthesisImage_200.py --folder folder_name_to_save --cuda 1')
            print('python ganSynthesisImage_200.py --load folder_name_to_load')
            sys.exit()
        elif o in ("-f", "--folder"):
            p_model_folder = a
            save_model = True
        elif o in ("-l", "--load"):
            p_load = a
            load_model = True
        elif o in ("-c", "--cuda"):
            p_cuda = int(a)
        else:
            assert False, "unhandled option"
    if save_model and load_model:
        raise Exception("Cannot save and load a model at the same time. Only one of the two arguments is allowed.")
    if not save_model and not load_model:
        print('python ganSynthesisImage_200.py --folder folder_name_to_save --cuda 1')
        print('python ganSynthesisImage_200.py --load folder_name_to_load')
        print("Need at least one argument.")
        sys.exit(2)
    device = torch.device("cuda" if p_cuda else "cpu")
    #envs = [InputWrapper(gym.make(name)) for name in ('Breakout-v0', 'AirRaid-v0', 'Pong-v0')]

    # prepare folder names to save or load the models
    if save_model:
        models_folder_path = os.path.join(BACKUP_FOLDER, p_model_folder)
        dis_model_path = os.path.join(models_folder_path, BACKUP_MODEL_NAME.format('disc'))
        gen_model_path = os.path.join(models_folder_path, BACKUP_MODEL_NAME.format('gen'))
    if load_model:
        models_folder_path = os.path.join(BACKUP_FOLDER, p_load)
        dis_model_path = os.path.join(models_folder_path, BACKUP_MODEL_NAME.format('disc'))
        gen_model_path = os.path.join(models_folder_path, BACKUP_MODEL_NAME.format('gen'))

    # construct the models
    net_discr = Discriminator(input_shape=input_shape).to(device)
    net_gener = Generator(output_shape=input_shape).to(device)
    print(net_discr)
    print(net_gener)

    objective = nn.BCELoss()
    gen_optimizer = optim.Adam(params=net_gener.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    dis_optimizer = optim.Adam(params=net_discr.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    writer = SummaryWriter()

    gen_losses = []
    dis_losses = []
    iter_no = 0
    true_labels_v = torch.ones(BATCH_SIZE, dtype=torch.float32, device=device)
    fake_labels_v = torch.zeros(BATCH_SIZE, dtype=torch.float32, device=device)

    # load the model checkpoints if requested
    if load_model:
        gen_checkpoint = torch.load(gen_model_path)
        net_gener.load_state_dict(gen_checkpoint['model_state_dict'])
        gen_optimizer.load_state_dict(gen_checkpoint['optimizer_state_dict'])
        gen_losses = gen_checkpoint['gen_losses']
        iteration = gen_checkpoint['iteration']  # the iteration number is retrieved from the gen checkpoint only
        dis_checkpoint = torch.load(dis_model_path)
        net_discr.load_state_dict(dis_checkpoint['model_state_dict'])
        dis_optimizer.load_state_dict(dis_checkpoint['optimizer_state_dict'])
        dis_losses = dis_checkpoint['dis_losses']
        iter_no = iteration

    for batch_v in iterate_batches():
        # sample fake images; the generator input is 4D: (batch, latent, 1, 1)
        gen_input_v = torch.FloatTensor(BATCH_SIZE, LATENT_VECTOR_SIZE, 1, 1).normal_(0, 1).to(device)
        # real data
        batch_v = batch_v.to(device)
        gen_output_v = net_gener(gen_input_v)

        # train discriminator
        dis_optimizer.zero_grad()
        dis_output_true_v = net_discr(batch_v)
        dis_output_fake_v = net_discr(gen_output_v.detach())
        dis_loss = objective(dis_output_true_v, true_labels_v) + objective(dis_output_fake_v, fake_labels_v)
        dis_loss.backward()
        dis_optimizer.step()
        dis_losses.append(dis_loss.item())

        # train generator
        gen_optimizer.zero_grad()
        dis_output_v = net_discr(gen_output_v)
        gen_loss_v = objective(dis_output_v, true_labels_v)
        gen_loss_v.backward()
        gen_optimizer.step()
        gen_losses.append(gen_loss_v.item())

        iter_no += 1
        print("Iteration : ", iter_no)
        if iter_no % REPORT_EVERY_ITER == 0:
            log.info("Iter %d: gen_loss=%.3e, dis_loss=%.3e", iter_no, np.mean(gen_losses), np.mean(dis_losses))
            writer.add_scalar("gen_loss", np.mean(gen_losses), iter_no)
            writer.add_scalar("dis_loss", np.mean(dis_losses), iter_no)
            gen_losses = []
            dis_losses = []
        if iter_no % SAVE_IMAGE_EVERY_ITER == 0:
            writer.add_image("fake", vutils.make_grid(gen_output_v.data[:BATCH_SIZE], normalize=True), iter_no)
            writer.add_image("real", vutils.make_grid(batch_v.data[:BATCH_SIZE], normalize=True), iter_no)
        if iter_no % BACKUP_EVERY_ITER == 0:
            if not os.path.exists(models_folder_path):
                os.makedirs(models_folder_path)
            torch.save({
                'iteration': iter_no,
                'model_state_dict': net_gener.state_dict(),
                'optimizer_state_dict': gen_optimizer.state_dict(),
                'gen_losses': gen_losses
            }, gen_model_path)
            torch.save({
                'iteration': iter_no,
                'model_state_dict': net_discr.state_dict(),
                'optimizer_state_dict': dis_optimizer.state_dict(),
                'dis_losses': dis_losses
            }, dis_model_path)
        if iter_no >= MAX_ITERATION:
            # MAX_ITERATION bounds the otherwise endless batch iterator
            break
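
# Hedged usage sketch (an added helper, never called by this script): rebuild the
# generator from a checkpoint written by the training loop above and draw a few samples.
# The dictionary keys match the torch.save() calls in __main__; the checkpoint path is
# up to the caller, e.g. saved_models/<run>/synthesis_gen_model.pt.
def sample_from_checkpoint(gen_checkpoint_path, n_samples=4, device="cpu"):
    checkpoint = torch.load(gen_checkpoint_path, map_location=device)
    net = Generator(output_shape=input_shape).to(device)
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    with torch.no_grad():
        noise = torch.randn(n_samples, LATENT_VECTOR_SIZE, 1, 1, device=device)
        samples = net(noise)                    # (n_samples, 3, 200, 200), values in [-1, 1]
    # return an image grid ready for writer.add_image() or vutils.save_image()
    return vutils.make_grid(samples, normalize=True)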