1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
"""Display helpers for neural-network training history.

Plots the accuracy and loss curves recorded during training, either saving
them to PNG files or showing them on screen.
"""
import matplotlib.pyplot as plt
# Keras releases before 2.3 log accuracy under 'acc'/'val_acc'; later
# releases use 'accuracy'/'val_accuracy'. Candidate names are tried in order.
_ACCURACY_KEYS = ('acc', 'accuracy')
_LOSS_KEYS = ('loss',)

# (candidate metric keys, figure title, y-axis label). The y-axis label also
# serves as the output-file suffix in ``save``.
_PLOTS = (
    (_ACCURACY_KEYS, 'model accuracy', 'accuracy'),
    (_LOSS_KEYS, 'model loss', 'loss'),
)


def _lookup(logs, keys):
    """Return ``logs[k]`` for the first key ``k`` in *keys* present in *logs*.

    Raises:
        KeyError: if none of *keys* is present. The first candidate is
            reported, matching the error raised when only 'acc' was accepted.
    """
    for key in keys:
        if key in logs:
            return logs[key]
    raise KeyError(keys[0])


def _plot_metric(history, keys, title, ylabel):
    """Plot training and validation curves for one metric on a new figure.

    Args:
        history: object whose ``history`` attribute maps metric names to
            per-epoch values (presumably a Keras ``History``; confirm).
        keys: candidate names for the training metric. The validation series
            is looked up under the same names prefixed with ``val_``.
        title: figure title.
        ylabel: y-axis label.

    Returns:
        The newly created matplotlib figure.
    """
    logs = history.history
    # A fresh figure keeps these curves off any figure the caller (or an
    # earlier plot) left as pyplot's current figure.
    fig = plt.figure()
    try:
        plt.plot(_lookup(logs, keys))
        plt.plot(_lookup(logs, tuple('val_' + key for key in keys)))
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel('epoch')
        plt.legend(['train', 'test'], loc='upper left')
    except Exception:
        # Don't leak a half-drawn figure when a metric is missing.
        plt.close(fig)
        raise
    return fig


def save(history, filename):
    """Save accuracy and loss curves from a training history as PNG files.

    Writes ``<filename>_accuracy.png`` and ``<filename>_loss.png``.

    Args:
        history: object whose ``history`` attribute maps metric names to
            per-epoch values. Accuracy may be logged as ``acc`` or
            ``accuracy`` (with matching ``val_`` entries). ``loss`` and
            ``val_loss`` are required.
        filename: path prefix for the two output files.

    Raises:
        KeyError: if a required metric is missing from ``history.history``.
    """
    for keys, title, ylabel in _PLOTS:
        fig = _plot_metric(history, keys, title, ylabel)
        try:
            # Tuple form so any filename object is formatted with %s.
            fig.savefig('%s_%s.png' % (filename, ylabel))
        finally:
            # Close so figures don't accumulate across calls and later pyplot
            # calls don't draw onto the saved plot.
            plt.close(fig)


def show(history, filename=None):
    """Display accuracy and loss curves from a training history.

    Each plot is drawn on its own figure and passed to ``plt.show()`` before
    the next one is drawn.

    Args:
        history: object whose ``history`` attribute maps metric names to
            per-epoch values. Accuracy may be logged as ``acc`` or
            ``accuracy`` (with matching ``val_`` entries). ``loss`` and
            ``val_loss`` are required.
        filename: unused; kept for backward compatibility with existing
            callers.

    Raises:
        KeyError: if a required metric is missing from ``history.history``.
    """
    for keys, title, ylabel in _PLOTS:
        _plot_metric(history, keys, title, ylabel)
        plt.show()
|