Palette.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # import
  2. # ------------------------------------------------------------------------------------------
  3. import os, colour, sklearn.cluster, skimage.color, copy
  4. import numpy as np
  5. # local
  6. import miam.image.Image as MIMG
  7. import miam.processing.ColorSpaceTransform as MCST
  8. import miam.image.ColorSpace as MICS
  9. import miam.math.Distance as MDST
  10. # ------------------------------------------------------------------------------------------
  11. # MIAM project 2020
  12. # ------------------------------------------------------------------------------------------
  13. # author: remi.cozot@univ-littoral.fr
  14. # ------------------------------------------------------------------------------------------
  15. class Palette(object):
  16. """
  17. class Palette:
  18. attribute(s):
  19. name: object name
  20. colorSpace: colorspace (colour.models.RGB_COLOURSPACES, Lab, etc.)
  21. nbColors: number of colors in the Palette
  22. colorData: array of pixels (np.ndarray)
  23. colors: np array colors[0:nbColors,0:2]
  24. sorted according to distance to black (in the palette colorSpace)
  25. """
  26. # constructor
  27. def __init__(self, name, colors, colorSpace):
  28. self.name = name
  29. self.colorSpace = colorSpace
  30. self.nbColors = colors.shape[0]
  31. self.colors = np.asarray(sorted(colors.tolist(), key = lambda u : np.sqrt(np.dot(u,u))))
  32. # DEBUG
  33. # print("Palette.__init__():", self.name)
  34. # methods
  35. def build(image, nbColors, fast=True, method='kmean-Lab', **kwargs):
  36. # according to method
  37. if method == 'kmean-Lab':
  38. # with 'remove black or not'
  39. if 'removeBlack' in kwargs: removeBlack = kwargs['removeBlack']
  40. else:
  41. print('removeBlack set to True')
  42. removeBlack = True
  43. # to Lab then to Vector
  44. if fast: image = image.smartResize()
  45. imageLab = MCST.ColorSpaceTransform().compute(image,dest='Lab')
  46. imgLabDataVector = MIMG.Image.array2vector(imageLab.colorData)
  47. if removeBlack:
  48. # k-means: nb cluster = nbColors + 1
  49. kmeans_cluster_Lab = sklearn.cluster.KMeans(n_clusters=nbColors+1)
  50. kmeans_cluster_Lab.fit(imgLabDataVector)
  51. cluster_centers_Lab = kmeans_cluster_Lab.cluster_centers_
  52. # remove darkness one
  53. idxLmin = np.argmin(cluster_centers_Lab[:,0]) # idx of darkness
  54. cluster_centers_Lab = np.delete(cluster_centers_Lab, idxLmin, axis=0) # remove min from cluster_centers_Lab
  55. else:
  56. # k-means: nb cluster = nbColors
  57. kmeans_cluster_Lab = sklearn.cluster.KMeans(n_clusters=nbColors)
  58. kmeans_cluster_Lab.fit(imgLabDataVector)
  59. cluster_centers_Lab = kmeans_cluster_Lab.cluster_centers_
  60. colors = cluster_centers_Lab
  61. else:
  62. print('unknow palette method')
  63. colors = None
  64. return Palette('Palette_'+image.name,colors, MICS.ColorSpace.buildLab())
  65. def createImageOfPalette(self, colorWidth=100):
  66. if self.colorSpace.name =='Lab':
  67. cRGB = MCST.Lab_to_sRGB(self.colors, apply_cctf_encoding=True)
  68. elif self.colorSpace.name=='sRGB':
  69. cRGB = self.colors
  70. width = colorWidth*cRGB.shape[0]
  71. height=colorWidth
  72. # return image
  73. img = np.ones((height,width,3))
  74. for i in range(cRGB.shape[0]):
  75. xMin= i*colorWidth
  76. xMax= xMin+colorWidth
  77. yMin=0
  78. yMax= colorWidth
  79. img[yMin:yMax, xMin:xMax,0]=cRGB[i,0]
  80. img[yMin:yMax, xMin:xMax,1]=cRGB[i,1]
  81. img[yMin:yMax, xMin:xMax,2]=cRGB[i,2]
  82. # colorData, name, type, linear, colorspace, scalingFactor
  83. # DEBUG
  84. # print("createImageOfPalette:", self.name)
  85. return MIMG.Image(img, self.name, MIMG.imageType.imageType.SDR, False, MICS.ColorSpace.buildsRGB(),1.0)
  86. # create image of multiple palettes
  87. def createImageOfPalettes(palettes, colorWidth=100):
  88. # return image
  89. width = colorWidth*palettes[0].colors.shape[0]
  90. height=colorWidth*len(palettes)
  91. img = np.ones((height,width,3))
  92. for j,palette in enumerate(palettes):
  93. if palette.colorSpace.name =='Lab':
  94. cRGB = MCST.Lab_to_sRGB(palette.colors, apply_cctf_encoding=True)
  95. elif palette.colorSpace.name=='sRGB':
  96. cRGB = palette.colors
  97. for i in range(cRGB.shape[0]):
  98. xMin= i*colorWidth
  99. xMax= xMin+colorWidth
  100. yMin= j*colorWidth
  101. yMax= yMin+colorWidth
  102. img[yMin:yMax, xMin:xMax,0]=cRGB[i,0]
  103. img[yMin:yMax, xMin:xMax,1]=cRGB[i,1]
  104. img[yMin:yMax, xMin:xMax,2]=cRGB[i,2]
  105. return MIMG.Image(img, "palettes", MIMG.imageType.imageType.SDR, False, MICS.ColorSpace.buildsRGB(),1.0)
  106. # magic operators
  107. def __add__(self, other):
  108. # print("DEBUG[Palette.__add__(",self.name,",",other.name,") ]")
  109. # create a copy
  110. res = copy.deepcopy(self)
  111. if isinstance(other,type(self)): # both are Palette
  112. # check if we can add
  113. # 1 - color space should be the same
  114. if self.colorSpace.name == other.colorSpace.name: # in the sma ecolor space
  115. # 2 - number of colors
  116. if self.nbColors == other.nbColors:
  117. # add
  118. res.colors= self.colors + other.colors
  119. res.colors = np.asarray(sorted(res.colors.tolist(), key = lambda u : np.sqrt(np.dot(u,u))))
  120. else:
  121. # error message
  122. print("WARNING[Palette.__add__(self,other): both image palette must have the same number of colors ! return a copy of ",self,"]")
  123. else:
  124. # error message
  125. print("WARNING[Palette.__add__(self,other): both image palette must have the same color space ! return a copy of ",self,"]")
  126. # return
  127. return res
  128. def __radd__(self, other):
  129. # print("DEBUG[Palette.__radd__(",self.name,",",other.name,") ]")
  130. return self.__add__(other)
  131. def __mul__ (self, other):
  132. # print("DEBUG[Palette.__mul__(",self.name,",",other,") ]")
  133. # create a copy
  134. res = copy.deepcopy(self)
  135. if isinstance(other, (int, float)):
  136. res.colors = self.colors * other
  137. else:
  138. # error message
  139. print("WARNING[Palette.__mul__(self,other): other must be int or float ! return a copy of ",self,"]")
  140. return res
  141. def __rmul__ (self, other):
  142. # print("DEBUG[Palette.__rmul__(",self.name,",",other,") ]")
  143. return self.__mul__(other)