12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- import matplotlib.pyplot as plt
- import numpy as np
- from . import KmeanDisplay
- import miam.aesthetics.Palette as MPAL
- import miam.image.ColorSpace as MICS
- import miam.pointcloud.PointCloud2D as MPC2D
class colorPaletteKmeanDisplay(KmeanDisplay.KmeanDisplay):
    """Live matplotlib display of a k-means clustering of colour palettes.

    A 2x2 figure is created once and redrawn on every call to :meth:`plot`:

    * [0, 0]: convex hull of each centroid palette, drawing colour channel 1
      as x and channel 2 as y (the palettes are built with
      ``ColorSpace.buildLab()``, so presumably Lab a/b -- confirm);
    * [0, 1]: an image of the centroid palettes;
    * [1, 0]: history of the mean distance divided by the number of centroids;
    * [1, 1]: history of the percentage of samples whose assignment changed.
    """

    def __init__(self, pcolor=None):
        """Create the 2x2 figure used by :meth:`plot`.

        Args:
            pcolor: sequence of matplotlib colour strings, used cyclically for
                the centroids. Defaults to ['b', 'g', 'r', 'c', 'm', 'y', 'k'].
        """
        # `is None` instead of `== None`: with a numpy array `pcolor`, `== None`
        # is evaluated element-wise and the `if` raises "truth value ambiguous".
        self.colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] if pcolor is None else pcolor
        fig, ax = plt.subplots(2, 2)  # 2 rows x 2 columns
        fig.suptitle('Palette Classification by k-means[color palette distance]')
        self._maximize_window(plt.get_current_fig_manager())
        self.fig = fig
        self.axes = ax

    @staticmethod
    def _maximize_window(manager):
        """Best-effort maximisation of the figure window; purely cosmetic."""
        window = getattr(manager, "window", None)
        if window is None:
            # Non-interactive backends (e.g. Agg) have no GUI window at all.
            return
        try:
            if hasattr(window, "state"):
                # Tk backend; Tk accepts 'zoomed' only on Windows and macOS.
                window.state('zoomed')
            elif hasattr(window, "showMaximized"):
                # Qt backends.
                window.showMaximized()
        except Exception:  # e.g. tkinter.TclError on X11 Tk (no 'zoomed' state)
            # Deliberate best effort: keep the default window size rather than
            # aborting the whole clustering run over a cosmetic failure.
            pass

    def plot(self, centroids, assigmentsIdx, iter, convergence, nbSamples):
        """Redraw the four panels for the current k-means iteration.

        Args:
            centroids: array indexed ``centroids[k, j, channel]`` -- one palette
                per centroid; channels 1 and 2 of each colour are plotted
                as x and y.
            assigmentsIdx: not used by this display.
            iter: iteration number shown in the top-left title.
            convergence: ``(numberOfChange, numberOfRemain, meanDistance)``
                histories (indexable, last entry = current iteration);
                ``numberOfRemain`` is not used here.
            nbSamples: number of samples; turns ``numberOfChange`` into a %.
        """
        nbCentroids = centroids.shape[0]

        # --- [0, 0]: convex hull of every centroid palette in the (x=ch1, y=ch2) plane
        ax = self.axes[0, 0]
        ax.cla()
        # symmetric extent for the axis cross: largest |ch1| or |ch2| over all centroids
        _range = max([np.amax(np.abs(centroids[:, :, 1])), np.amax(np.abs(centroids[:, :, 2]))])
        for i, c in enumerate(centroids):
            x = c[:, 1]  # a -> x
            y = c[:, 2]  # b -> y
            hull = MPC2D.PointCloud2D(x, y).convexHull()
            xx, yy = MPC2D.PointCloud2D.toXYarray(hull)
            ax.plot(xx, yy, self.colors[i % len(self.colors)] + 'o--', linewidth=1, markersize=3)
        ax.plot([-_range, _range], [0, 0], 'k', linewidth=0.1)
        ax.plot([0, 0], [-_range, _range], 'k', linewidth=0.1)
        # Show the colour cycle actually in use; it used to be hard-coded to the
        # default list, which was wrong whenever a custom `pcolor` was given.
        legend = ", ".join(repr(str(col)) for col in self.colors)
        ax.set_title("centroids (iter:" + str(iter) + ")" + "[" + legend + "]")

        # --- [0, 1]: image of the centroid palettes
        ax = self.axes[0, 1]
        ax.cla()
        ax.set_title("palettes(" + str(nbCentroids) + ")")
        palettes = [MPAL.Palette("", c, MICS.ColorSpace.buildLab()) for c in centroids]
        img = MPAL.Palette.createImageOfPalettes(palettes)
        img.plot(ax, title=False)

        # --- [1, 0]: mean distance history, normalised by the number of centroids
        numberOfChange, _numberOfRemain, meanDistance = convergence
        ax = self.axes[1, 0]
        ax.cla()
        ax.set_title("mean distance:" + str(meanDistance[-1] / nbCentroids))
        ax.plot(np.asarray(meanDistance) / nbCentroids, 'r')

        # --- [1, 1]: % of samples that changed cluster at each iteration
        ax = self.axes[1, 1]
        ax.cla()
        ax.set_title("convergence index (% change)/ absolute:" + str(numberOfChange[-1]))
        ax.plot(100 * np.asarray(numberOfChange) / nbSamples, 'g')

        # let the GUI event loop draw the updated figure
        plt.pause(0.05)
|