# kmeans.py
  1. # import
  2. # ------------------------------------------------------------------------------------------
  3. import os
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import copy
  7. # multiprocessing and functools
  8. # import multiprocessing as mp
  9. from pathos.multiprocessing import ProcessingPool as Pool
  10. from functools import partial
  11. # miam import
  12. import miam.utils
  13. import miam.image.Image as MIMG
  14. import miam.processing.ColorSpaceTransform as MCST
  15. import miam.histogram.Histogram as MHIST
  16. import miam.math as MMATH
  17. # ------------------------------------------------------------------------------------------
  18. # MIAM project 2020
  19. # ------------------------------------------------------------------------------------------
  20. # author: remi.cozot@univ-littoral.fr
  21. # ------------------------------------------------------------------------------------------
  22. class kmeans(object):
  23. # set a random seed for reproductability
  24. #np.random.seed(1968)
  25. np.random.seed(1968)
  26. def __init__(self, distance, normalize):
  27. # distance use for computation between samples
  28. self.distance = distance
  29. self.normalize = normalize
  30. def MPassignSamplesToCentroids(self,centroids, samples, previousAssigmentIdx):
  31. # parallel function
  32. def assignSampleToCentroid(previousAssigmentIdx, centroids, distance,isamp):
  33. """ return (jdist, samp, i, dist, change, remain) """
  34. # recovering data from parameters
  35. i, samp = isamp
  36. # init data in centroids loop
  37. dist = 0.0 # distance to centroid
  38. jdist = 0 # minimal distance
  39. for j,cent in enumerate(centroids):
  40. #compute distance samps[i] et cents[j]
  41. if j==0:
  42. # first iteration
  43. dist = distance.eval(samp,cent)
  44. jdist =0
  45. else:
  46. # other iteration
  47. d= distance.eval(samp,cent)
  48. # compare dist to current minimal dist
  49. if d<dist:
  50. dist =d
  51. jdist=j
  52. # end for cents
  53. if not i in previousAssigmentIdx[jdist]:
  54. change, remain = 1, 0
  55. else:
  56. change, remain = 0, 1
  57. #return data
  58. return (jdist, samp, i, dist, change, remain)
  59. # create partial to avoid multiple input parameters
  60. pAss = partial(assignSampleToCentroid, previousAssigmentIdx, centroids, self.distance)
  61. # prepare input
  62. isamps = list(enumerate(samples))
  63. # parallel computation
  64. _pool = Pool()
  65. rawResults = _pool.map(pAss, isamps) # launching on all samples
  66. results =list(rawResults)
  67. # formatting results
  68. numberSamples = len(samples)
  69. numberOfChange, numberOfRemain = 0,0 # number of samples that change of/remain in centroid
  70. sumDistance = 0 # sum of distance to centroids
  71. assigments, assigmentsIdx= [[]], [[]] # return list
  72. for i in range(len(centroids)-1):
  73. assigments.append([])
  74. assigmentsIdx.append([])
  75. for res in results:
  76. jdist, samp, i, dist, change, remain = res
  77. assigments[jdist].append(samp)
  78. assigmentsIdx[jdist].append(i)
  79. sumDistance += dist
  80. numberOfChange += change
  81. numberOfRemain += remain
  82. # add data to follow convergence
  83. conv = (numberOfChange,numberOfRemain, sumDistance/numberSamples)
  84. return (assigments, assigmentsIdx, conv)
  85. def assignSamplesToCentroids(self,centroids, samples, previousAssigmentIdx):
  86. # assChanged set to True if at least one assignment changes of centroid
  87. assChanged = False
  88. # number of samples that change of/remain in centroid
  89. numberOfChange, numberOfRemain = 0,0
  90. # sum of distance to centroids
  91. sumDistance = 0
  92. # return list
  93. assigments, assigmentsIdx= [[]], [[]]
  94. for i in range(len(centroids)-1):
  95. assigments.append([])
  96. assigmentsIdx.append([])
  97. numberSamples = len(samples)
  98. # DEBUG
  99. print("")
  100. # END DEBUG
  101. for i,samp in enumerate(samples):
  102. # DEBUG
  103. print("\r kmeans.assignSamplesToCentroids: ",i,"/",numberSamples,"[remains:",numberOfRemain,"][changes:",numberOfChange,"][mean distance:",sumDistance/(i+1),"] ", end = '\r')
  104. # END DEBUG
  105. dist = 0.0
  106. jdist = 0
  107. for j,cent in enumerate(centroids):
  108. #compute distance samps[i] et cents[j]
  109. if j==0:
  110. # first iteration
  111. dist = self.distance.eval(samp,cent)
  112. jdist =0
  113. else:
  114. # other iteration
  115. d= self.distance.eval(samp,cent)
  116. # compare dist to current minimal dist
  117. if d<dist:
  118. dist =d
  119. jdist=j
  120. # end if
  121. #end if
  122. # end for cents
  123. assigments[jdist].append(samp)
  124. assigmentsIdx[jdist].append(i)
  125. sumDistance += dist
  126. if not i in previousAssigmentIdx[jdist]:
  127. assChanged = True
  128. numberOfChange += 1
  129. else:
  130. numberOfRemain +=1
  131. # end for samp
  132. # add dta to follow convergence
  133. conv = (numberOfChange,numberOfRemain, sumDistance/numberSamples)
  134. return (assigments, assigmentsIdx, conv)
  135. def averageAssigments(self, assignements):
  136. # debug
  137. # print("kmeans.averageAssigment>> start")
  138. # end debug
  139. # return list
  140. assigmentAverage = [[]]
  141. for i in range(len(assignements)-1): assigmentAverage.append([])
  142. for i,assigment_i in enumerate(assignements):
  143. # debug
  144. # print("kmeans.averageAssigment::sassigment_i.size>>",np.asarray(assigment_i).size)
  145. # end debug
  146. if np.asarray(assigment_i).size >0 :
  147. assavi=np.mean(np.asarray(assigment_i),axis=0)
  148. assavi = self.normalize.eval(assavi)
  149. assigmentAverage[i]= assavi
  150. # debug
  151. # print("kmeans.averageAssigment>> end")
  152. # end debug
  153. return assigmentAverage
  154. def kmeans(self,samples, nbClusters, nbIter, display = None, initRange=None, multiPro=False):
  155. """
  156. method keams:
  157. attribute(s):
  158. self: insytance method
  159. samples: samples to cluster (np.ndarray)
  160. samples.shape[0] : number of samples
  161. samples.shape[1:]: dimension of sample, if samples.shape[1:] is int then sample are vectors and amples.shape[1:] is vector size
  162. if samples.shape[1:] is tuple are matrices or tensors and amples.shape[1:] is matrix/tensor dimension
  163. exemple samples.shape[1:] =(5,3) sample is matrix 5x3
  164. nbClusters: number of cluster to compute (int)
  165. nbIter: number of iteration of k-means (int)
  166. display: class with init and plot method, plot is called at each iteration !!!!!!!!!!!!!!!!!!!! require refactoring !!!!!!!!!!!!!!!!!!!!
  167. initRange: None or list of tuple
  168. random centroids are used to init the k-means
  169. centroids are stored in a np.ndarray which shape is (number of clusters, *samples.shape[1:])
  170. samples.shape[-1] is size of "atomic" vector for example is centroids is 5x3 it means 5 vector of size 3 (the case for color palettes)
  171. init should be [{minRange0,maxRange0}, {minRange1,maxRange1}, {minRange2,maxRange2}]*5 with initRange = [(minRange0,maxRange0), (minRange1,maxRange1), (minRange2,maxRange2)]
  172. initRange=None range in 0..1
  173. """
  174. # dimension of centroids
  175. dimSamp = samples.shape
  176. dimCentroid = dimSamp[1:] # palette
  177. # dimCentroid = dimSamp[1] # histo
  178. # init centroids
  179. if isinstance(dimSamp,tuple):
  180. u = np.random.rand(nbClusters,*dimCentroid)
  181. if initRange:
  182. minRange, maxRange = [],[]
  183. for _range in initRange:
  184. minr, maxr = _range
  185. minRange.append(minr)
  186. maxRange.append(maxr)
  187. v = (1-u)*np.asarray(minRange) + u*np.asarray(maxRange)
  188. else:
  189. v =u
  190. centroids = self.normalize.evals(v) # palette
  191. else: #integer
  192. u = np.random.rand(nbClusters,dimCentroid)
  193. if initRange:
  194. minRange, maxRange = [],[]
  195. for _range in initRange:
  196. minr, maxr = _range
  197. minRange.append(minr)
  198. maxRange.append(maxr)
  199. v = (1-u)*np.asarray(minRange) + u*np.asarray(maxRange)
  200. else:
  201. v =u
  202. centroids = self.normalize.eval(v) # histo
  203. # return assigments and assigments index
  204. previousAssigmentsIdx = [[]]
  205. assigments,assigmentsIdx = [[]], [[]]
  206. for i in range(nbClusters-1):
  207. assigments.append([])
  208. assigmentsIdx.append([])
  209. previousAssigmentsIdx.append([])
  210. # convergence
  211. changes = []
  212. remains=[]
  213. meanDistances= []
  214. # MAIN LOOP
  215. # -----------------------------------------------------------------------------------------
  216. # for iter in range(nbIter)
  217. for iter in range(nbIter):
  218. print("\r","kmeans(iteration): ",iter,"/",nbIter,":",iter*100//nbIter," % ",end = '\r')
  219. # assign sample to centoids
  220. if multiPro:
  221. (assigments,assigmentsIdx,conv) = self.MPassignSamplesToCentroids(centroids,samples, previousAssigmentsIdx)
  222. else:
  223. (assigments,assigmentsIdx,conv) = self.assignSamplesToCentroids(centroids,samples, previousAssigmentsIdx)
  224. # recover data from results
  225. change, remain, meanDist = conv
  226. changes.append(change)
  227. remains.append(remain)
  228. meanDistances.append(meanDist)
  229. # compute mean of (assigment) cluster
  230. assigmentsAverage = self.averageAssigments(assigments)
  231. # update centroids and stopping criteria
  232. canBreak= True
  233. for i,ass_av in enumerate(assigmentsAverage):
  234. emptyAss = True
  235. if isinstance(ass_av,np.ndarray):
  236. if (ass_av.size!=0):
  237. emptyAss = False
  238. centroids[i] = ass_av
  239. if emptyAss:
  240. canBreak= False
  241. if isinstance(dimSamp,tuple):
  242. u = np.random.rand(1,*dimCentroid)
  243. if initRange:
  244. minRange, maxRange = [],[]
  245. for _range in initRange:
  246. minr, maxr = _range
  247. minRange.append(minr)
  248. maxRange.append(maxr)
  249. v = (1-u)*np.asarray(minRange) + u*np.asarray(maxRange)
  250. else: v =u
  251. newcentroid = self.normalize.evals(v) # palette
  252. else: # histogram
  253. u = np.random.rand(nbClusters,dimCentroid)
  254. if initRange:
  255. minRange, maxRange = [],[]
  256. for _range in initRange:
  257. minr, maxr = _range
  258. minRange.append(minr)
  259. maxRange.append(maxr)
  260. v = (1-u)*np.asarray(minRange) + u*np.asarray(maxRange)
  261. else:
  262. v =u
  263. newcentroid = self.normalize.eval(v) # histo
  264. centroids[i] = newcentroid
  265. print("")
  266. print("WARNING[miam.classification.kmeans(): (iteration:",iter,"/centroid:",i,"): no assigment! >> compute new centroid]")
  267. # display
  268. if display: display.plot(centroids, assigmentsIdx, iter,(changes,remains,meanDistances), len(samples))
  269. # memory
  270. previousAssigmentsIdx = copy.deepcopy(assigmentsIdx)
  271. # break iteration if change=0
  272. if (change==0) and(canBreak): break
  273. # -----------------------------------------------------------------------------------------
  274. # return centroids
  275. print(" ")
  276. return (centroids,assigments,assigmentsIdx)