ts_model_helper.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
# Module which contains helper functions for displaying and saving training-history plots.
# To avoid Tk backend issues, uncomment the matplotlib.use('agg') line below.
  3. import matplotlib
  4. #matplotlib.use('agg')
  5. import matplotlib.pyplot as plt
  6. def save(history, filename):
  7. '''
  8. @brief Function which saves data from neural network model
  9. @param history : tensorflow model history
  10. @param filename : information about model filename
  11. @return nothing
  12. '''
  13. # summarize history for accuracy
  14. plt.plot(history.history['acc'])
  15. plt.plot(history.history['val_acc'])
  16. plt.title('model accuracy')
  17. plt.ylabel('accuracy')
  18. plt.xlabel('epoch')
  19. plt.legend(['train', 'test'], loc='upper left')
  20. plt.savefig(str('%s_accuracy.png' % filename))
  21. # clear plt history
  22. plt.gcf().clear()
  23. # summarize history for loss
  24. plt.plot(history.history['loss'])
  25. plt.plot(history.history['val_loss'])
  26. plt.title('model loss')
  27. plt.ylabel('loss')
  28. plt.xlabel('epoch')
  29. plt.legend(['train', 'test'], loc='upper left')
  30. plt.savefig(str('%s_loss.png' % filename))
  31. def show(history, filename):
  32. '''
  33. @brief Function which shows data from neural network model
  34. @param history : tensorflow model history
  35. @param filename : information about model filename
  36. @return nothing
  37. '''
  38. # summarize history for accuracy
  39. plt.plot(history.history['acc'])
  40. plt.plot(history.history['val_acc'])
  41. plt.title('model accuracy')
  42. plt.ylabel('accuracy')
  43. plt.xlabel('epoch')
  44. plt.legend(['train', 'test'], loc='upper left')
  45. plt.show()
  46. # summarize history for loss
  47. plt.plot(history.history['loss'])
  48. plt.plot(history.history['val_loss'])
  49. plt.title('model loss')
  50. plt.ylabel('loss')
  51. plt.xlabel('epoch')
  52. plt.legend(['train', 'test'], loc='upper left')
  53. plt.show()