plot_info.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. # Module containing helper functions for plotting neural-network training history.
  2. # Use the non-interactive 'agg' backend to avoid Tk issues.
  3. import matplotlib
  4. matplotlib.use('agg')
  5. import matplotlib.pyplot as plt
  6. '''
  7. Function which saves data from neural network model
  8. '''
  9. def save(history, filename):
  10. # summarize history for accuracy
  11. plt.plot(history.history['acc'])
  12. plt.plot(history.history['val_acc'])
  13. plt.title('model accuracy')
  14. plt.ylabel('accuracy')
  15. plt.xlabel('epoch')
  16. plt.legend(['train', 'test'], loc='upper left')
  17. plt.savefig(str('%s_accuracy.png' % filename))
  18. # clear plt history
  19. plt.gcf().clear()
  20. # summarize history for loss
  21. plt.plot(history.history['loss'])
  22. plt.plot(history.history['val_loss'])
  23. plt.title('model loss')
  24. plt.ylabel('loss')
  25. plt.xlabel('epoch')
  26. plt.legend(['train', 'test'], loc='upper left')
  27. plt.savefig(str('%s_loss.png' % filename))
  28. def show(history, filename):
  29. # summarize history for accuracy
  30. plt.plot(history.history['acc'])
  31. plt.plot(history.history['val_acc'])
  32. plt.title('model accuracy')
  33. plt.ylabel('accuracy')
  34. plt.xlabel('epoch')
  35. plt.legend(['train', 'test'], loc='upper left')
  36. plt.show()
  37. # summarize history for loss
  38. plt.plot(history.history['loss'])
  39. plt.plot(history.history['val_loss'])
  40. plt.title('model loss')
  41. plt.ylabel('loss')
  42. plt.xlabel('epoch')
  43. plt.legend(['train', 'test'], loc='upper left')
  44. plt.show()