colorPaletteKmeanDisplay.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from . import KmeanDisplay
  4. import miam.aesthetics.Palette as MPAL
  5. import miam.image.ColorSpace as MICS
  6. import miam.pointcloud.PointCloud2D as MPC2D
  7. class colorPaletteKmeanDisplay(KmeanDisplay.KmeanDisplay):
  8. """ xxx """
  9. def __init__(self, pcolor= None):
  10. # colors
  11. self.colors = ['b', 'g','r','c','m','y','k'] if pcolor == None else pcolor
  12. # figure and axes
  13. fig, ax = plt.subplots(2,2) # 1 row x 2 columns
  14. fig.suptitle('Palette Classification by k-means[color palette distance]')
  15. wm = plt.get_current_fig_manager()
  16. wm.window.state('zoomed')
  17. self.fig = fig
  18. self.axes = ax
  19. def plot(self, centroids, assigmentsIdx, iter, convergence,nbSamples):
  20. # display centroids shape
  21. self.axes[0,0].cla()
  22. _range = max([np.amax(np.abs(centroids[:,:,1])),np.amax(np.abs(centroids[:,:,2]))])
  23. for i,c in enumerate(centroids):
  24. # plot palette
  25. x = c[:,1] # a -> x
  26. y = c[:,2] # b -> y
  27. pc = MPC2D.PointCloud2D(x,y)
  28. xx,yy = MPC2D.PointCloud2D.toXYarray(pc.convexHull())
  29. self.axes[0,0].plot(xx,yy,self.colors[i%len(self.colors)]+'o--', linewidth=1, markersize=3)
  30. self.axes[0,0].plot([- _range, _range],[0,0],'k', linewidth=0.1)
  31. self.axes[0,0].plot([0,0],[-_range,_range],'k', linewidth=0.1)
  32. self.axes[0,0].set_title("centroids (iter:"+str(iter)+")"+"['b', 'g','r','c','m','y','k']")
  33. # display image of palette
  34. self.axes[0,1].cla()
  35. self.axes[0,1].set_title("palettes("+str(centroids.shape[0])+")")
  36. palettes = []
  37. for c in centroids:
  38. palettes.append(MPAL.Palette("",c, MICS.ColorSpace.buildLab()))
  39. img = MPAL.Palette.createImageOfPalettes(palettes)
  40. img.plot(self.axes[0,1],title=False)
  41. # distance
  42. numberOfChange, numberOfRemain, meanDistance = convergence
  43. self.axes[1,0].cla()
  44. self.axes[1,0].set_title("mean distance:"+str(meanDistance[-1]/centroids.shape[0]))
  45. self.axes[1,0].plot(np.asarray(meanDistance)/centroids.shape[0],'r')
  46. # change
  47. self.axes[1,1].cla()
  48. self.axes[1,1].set_title("convergence index (% change)/ absolute:"+str(numberOfChange[-1]))
  49. self.axes[1,1].plot(100*np.asarray(numberOfChange)/nbSamples,'g')
  50. # pause
  51. plt.pause(0.05)