ganAtariImage.py

#!/usr/bin/env python
import random
import argparse
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
import torchvision.utils as vutils

import gym
import gym.spaces
import numpy as np

log = gym.logger
log.set_level(gym.logger.INFO)

LATENT_VECTOR_SIZE = 100
DISCR_FILTERS = 64
GENER_FILTERS = 64
BATCH_SIZE = 16

# dimension to which the input image will be rescaled
IMAGE_SIZE = 64

LEARNING_RATE = 0.0001
REPORT_EVERY_ITER = 100
SAVE_IMAGE_EVERY_ITER = 200


class InputWrapper(gym.ObservationWrapper):
    """
    Preprocessing of input numpy array:
    1. resize the image to a predefined size
    2. move the color channel axis to the first position
    """
    def __init__(self, *args):
        super(InputWrapper, self).__init__(*args)
        assert isinstance(self.observation_space, gym.spaces.Box)
        old_space = self.observation_space
        self.observation_space = gym.spaces.Box(
            self.observation(old_space.low),
            self.observation(old_space.high),
            dtype=np.float32)

    def observation(self, observation):
        # resize image
        new_obs = cv2.resize(observation, (IMAGE_SIZE, IMAGE_SIZE))
        # transform (64, 64, 3) -> (3, 64, 64): channels first for PyTorch
        new_obs = np.moveaxis(new_obs, 2, 0)
        return new_obs.astype(np.float32)
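
# Quick shape check (commented out so the script's behaviour is unchanged):
# the wrapper turns raw Atari frames of shape (210, 160, 3) into float32
# arrays of shape (3, IMAGE_SIZE, IMAGE_SIZE):
#
#   env = InputWrapper(gym.make("Breakout-v0"))
#   obs = env.reset()
#   assert obs.shape == (3, IMAGE_SIZE, IMAGE_SIZE)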


class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        # this pipe converts the image into a single number
        self.conv_pipe = nn.Sequential(
            nn.Conv2d(in_channels=input_shape[0], out_channels=DISCR_FILTERS,
                      kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS, out_channels=DISCR_FILTERS * 2,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 2),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 2, out_channels=DISCR_FILTERS * 4,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 4),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 4, out_channels=DISCR_FILTERS * 8,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 8),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 8, out_channels=1,
                      kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        conv_out = self.conv_pipe(x)
        return conv_out.view(-1, 1).squeeze(dim=1)
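
# With IMAGE_SIZE = 64, the four stride-2 convolutions halve the spatial
# size 64 -> 32 -> 16 -> 8 -> 4, and the final 4x4 convolution (stride 1,
# no padding) collapses the 4x4 map to 1x1; Sigmoid turns that single
# number into a per-image "real" probability.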


class Generator(nn.Module):
    def __init__(self, output_shape):
        super(Generator, self).__init__()
        # pipe deconvolves the input vector into a (3, 64, 64) image
        self.pipe = nn.Sequential(
            nn.ConvTranspose2d(in_channels=LATENT_VECTOR_SIZE, out_channels=GENER_FILTERS * 8,
                               kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(GENER_FILTERS * 8),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 8, out_channels=GENER_FILTERS * 4,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 4, out_channels=GENER_FILTERS * 2,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 2, out_channels=GENER_FILTERS,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS, out_channels=output_shape[0],
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.pipe(x)
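
# The generator mirrors the discriminator: the latent vector, treated as a
# 1x1 "image" with LATENT_VECTOR_SIZE channels, is upsampled
# 1 -> 4 -> 8 -> 16 -> 32 -> 64 by the transposed convolutions, and Tanh
# keeps the output in [-1, 1], matching the normalisation of real frames
# in iterate_batches() below.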


def iterate_batches(envs, batch_size=BATCH_SIZE):
    # infinite generator of observation batches, collected by taking random
    # actions across the given environments
    batch = [e.reset() for e in envs]
    env_gen = iter(lambda: random.choice(envs), None)

    while True:
        e = next(env_gen)
        obs, reward, is_done, _ = e.step(e.action_space.sample())
        # skip frames that are almost entirely black
        if np.mean(obs) > 0.01:
            batch.append(obs)
        if len(batch) == batch_size:
            # normalise input to the range [-1, 1]
            batch_np = np.array(batch, dtype=np.float32) * 2.0 / 255.0 - 1.0
            yield torch.tensor(batch_np)
            batch.clear()
        if is_done:
            e.reset()
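
# Sanity check (commented out): every yielded batch is a float32 tensor of
# shape (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE) with values in [-1, 1]:
#
#   test_envs = [InputWrapper(gym.make("Breakout-v0"))]
#   batch = next(iterate_batches(test_envs))
#   assert batch.shape == (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)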


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", default=False, action='store_true',
                        help="Enable cuda computation")
    args = parser.parse_args()

    device = torch.device("cuda" if args.cuda else "cpu")
    envs = [InputWrapper(gym.make(name))
            for name in ('Breakout-v0', 'AirRaid-v0', 'Pong-v0')]
    input_shape = envs[0].observation_space.shape
    print(input_shape)

    net_discr = Discriminator(input_shape=input_shape).to(device)
    net_gener = Generator(output_shape=input_shape).to(device)
    print(net_gener)

    objective = nn.BCELoss()
    gen_optimizer = optim.Adam(params=net_gener.parameters(),
                               lr=LEARNING_RATE, betas=(0.5, 0.999))
    dis_optimizer = optim.Adam(params=net_discr.parameters(),
                               lr=LEARNING_RATE, betas=(0.5, 0.999))
    writer = SummaryWriter()

    gen_losses = []
    dis_losses = []
    iter_no = 0

    true_labels_v = torch.ones(BATCH_SIZE, dtype=torch.float32, device=device)
    fake_labels_v = torch.zeros(BATCH_SIZE, dtype=torch.float32, device=device)
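
    # Classic GAN alternation: each iteration first trains the discriminator
    # on a real batch (label 1) and a detached fake batch (label 0), then
    # trains the generator to make the discriminator score the same fakes as
    # real. The detach() call keeps the discriminator step from pushing
    # gradients into the generator.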
    for batch_v in iterate_batches(envs):
        # generate extra fake samples, input is 4D: batch, filters, x, y
        gen_input_v = torch.FloatTensor(
            BATCH_SIZE, LATENT_VECTOR_SIZE, 1, 1).normal_(0, 1).to(device)
        batch_v = batch_v.to(device)
        gen_output_v = net_gener(gen_input_v)

        # train discriminator
        dis_optimizer.zero_grad()
        dis_output_true_v = net_discr(batch_v)
        dis_output_fake_v = net_discr(gen_output_v.detach())
        dis_loss = objective(dis_output_true_v, true_labels_v) + \
            objective(dis_output_fake_v, fake_labels_v)
        dis_loss.backward()
        dis_optimizer.step()
        dis_losses.append(dis_loss.item())

        # train generator
        gen_optimizer.zero_grad()
        dis_output_v = net_discr(gen_output_v)
        gen_loss_v = objective(dis_output_v, true_labels_v)
        gen_loss_v.backward()
        gen_optimizer.step()
        gen_losses.append(gen_loss_v.item())

        iter_no += 1
        if iter_no % REPORT_EVERY_ITER == 0:
            log.info("Iter %d: gen_loss=%.3e, dis_loss=%.3e",
                     iter_no, np.mean(gen_losses), np.mean(dis_losses))
            writer.add_scalar("gen_loss", np.mean(gen_losses), iter_no)
            writer.add_scalar("dis_loss", np.mean(dis_losses), iter_no)
            gen_losses = []
            dis_losses = []
        if iter_no % SAVE_IMAGE_EVERY_ITER == 0:
            writer.add_image("fake", vutils.make_grid(
                gen_output_v.data[:64], normalize=True), iter_no)
            writer.add_image("real", vutils.make_grid(
                batch_v.data[:64], normalize=True), iter_no)