plot_info.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
# Module file which contains helpful display functions
  2. import matplotlib.pyplot as plt
  3. '''
  4. Function which saves data from neural network model
  5. '''
  6. def save(history, filename):
  7. # summarize history for accuracy
  8. plt.plot(history.history['acc'])
  9. plt.plot(history.history['val_acc'])
  10. plt.title('model accuracy')
  11. plt.ylabel('accuracy')
  12. plt.xlabel('epoch')
  13. plt.legend(['train', 'test'], loc='upper left')
  14. plt.savefig(str('%s_accuracy.png' % filename))
  15. # clear plt history
  16. plt.gcf().clear()
  17. # summarize history for loss
  18. plt.plot(history.history['loss'])
  19. plt.plot(history.history['val_loss'])
  20. plt.title('model loss')
  21. plt.ylabel('loss')
  22. plt.xlabel('epoch')
  23. plt.legend(['train', 'test'], loc='upper left')
  24. plt.savefig(str('%s_loss.png' % filename))
  25. def show(history, filename):
  26. # summarize history for accuracy
  27. plt.plot(history.history['acc'])
  28. plt.plot(history.history['val_acc'])
  29. plt.title('model accuracy')
  30. plt.ylabel('accuracy')
  31. plt.xlabel('epoch')
  32. plt.legend(['train', 'test'], loc='upper left')
  33. plt.show()
  34. # summarize history for loss
  35. plt.plot(history.history['loss'])
  36. plt.plot(history.history['val_loss'])
  37. plt.title('model loss')
  38. plt.ylabel('loss')
  39. plt.xlabel('epoch')
  40. plt.legend(['train', 'test'], loc='upper left')
  41. plt.show()