Histogram.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # import
  2. # ------------------------------------------------------------------------------------------
  3. from .. import image
  4. import copy, functools
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. # miam import
  8. import miam.math.Distance
  9. # ------------------------------------------------------------------------------------------
  10. # MIAM project 2020
  11. # ------------------------------------------------------------------------------------------
  12. # author: remi.cozot@univ-littoral.fr
  13. # ------------------------------------------------------------------------------------------
  14. class Histogram(object):
  15. """description of class"""
  16. def __init__(self, histValue, edgeValue, name, channel, logSpace = False):
  17. """ constructor """
  18. self.name = name
  19. self.channel = channel
  20. self.histValue = histValue
  21. self.edgeValue = edgeValue
  22. self.logSpace = logSpace
  23. def __repr__(self):
  24. res = " Histogram{ name:" + self.name + "\n" + \
  25. " nb bins: " + str(len(self.histValue)) + "\n" + \
  26. " channel: " + str(self.channel.name)+"("+self.channel.colorSpace()+")" + "\n" + \
  27. " logSpace: " + str(self.logSpace) + "\n }"
  28. return res
  29. def __str__(self): return self.__repr__()
  30. def normalise(self, norm=None):
  31. """ normalise histogram according to norm='probability' | 'dot' """
  32. res = copy.deepcopy(self)
  33. if not norm: norm = 'probability'
  34. if norm == 'probability':
  35. sum = np.sum(res.histValue)
  36. res.histValue = res.histValue/sum
  37. elif norm == 'dot':
  38. dot2 = np.dot(res.histValue,res.histValue)
  39. res.histValue = res.histValue/np.sqrt(dot2)
  40. else:
  41. print("WARNING[miam.hisrogram.Histogram.normalise(",self.name,"): unknown norm:", norm,"!]")
  42. return res
  43. def build(img, channel, nbBins=100, range= None, logSpace = None):
  44. """
  45. build an Histogram object from image
  46. @params:
  47. img - Required : input image from witch hsitogram will be build (miam.image.Image.Image)
  48. channel - Required : image channel used to build histogram (miam.image.channel.channel)
  49. nbBins - Optional : histogram number of bins (Int)
  50. range - Optional : range of histogram, if None min max of channel (Float,Float)
  51. logSpace - Optional : compute in log space if True, if None guess from image (Boolean)
  52. """
  53. # logSpace
  54. if not logSpace: logSpace = 'auto'
  55. if isinstance(logSpace,str):
  56. if logSpace=='auto':
  57. if img.type == image.imageType.imageType.SDR : logSpace = False
  58. if img.type == image.imageType.imageType.HDR : logSpace = True
  59. elif not isinstance(logSpace,bool):
  60. logSpace = False
  61. channelVector = img.getChannelVector(channel)
  62. # range
  63. if not range:
  64. if channel.colorSpace() == 'Lab':
  65. range= (0.0,100.0)
  66. elif channel.colorSpace() == 'sRGB'or channel.colorSpace() == 'XYZ':
  67. range= (0.0,1.0)
  68. else:
  69. range= (0.0,1.0)
  70. print("WARNING[miam.hisrogram.Histogram.build(",img.name,"):",
  71. "colour space:",channel.colorSpace(), "not yet implemented > range(0.0,1.0)!]")
  72. # compute bins
  73. if logSpace:
  74. ((minR,maxR),(minG,maxG),(minB,maxB)) = img.getMinMaxPerChannel()
  75. minRGB = min(minR, minG, minB)
  76. maxRGB = max(maxR, maxG, maxB)
  77. #bins
  78. bins = 10 ** np.linspace(np.log10(minRGB), np.log10(maxRGB), nbBins+1)
  79. else:
  80. bins = np.linspace(range[0],range[1],nbBins+1)
  81. nphist, npedges = np.histogram(channelVector, bins)
  82. nphist = nphist/channelVector.shape
  83. return Histogram(nphist,
  84. npedges,
  85. img.name,
  86. #copy.deepcopy(img.colorSpace),
  87. channel,
  88. logSpace = logSpace
  89. )
  90. def plot(self, ax,color='r', shortName =True,title=True):
  91. if not color : color = 'r'
  92. ax.plot(self.edgeValue[1:],self.histValue,color)
  93. if self.logSpace: ax.set_xscale("log")
  94. name = self.name.split("/")[-1]+"(H("+self.channel.name+"))"if shortName else self.name+"(Histogram:"+self.channel.name+")"
  95. if title: ax.set_title(name)
  96. def scale(alpha,h):
  97. res = copy.deepcopy(h)
  98. res.histValue = res.histValue * alpha
  99. return res
  100. def add(hu,hv):
  101. res = copy.deepcopy(hu)
  102. # check edges
  103. if (hu.edgeValue==hv.edgeValue).all():
  104. # porceed to summation
  105. res.histValue = hu.histValue + hv.histValue
  106. else:
  107. # remap
  108. pass
  109. return res
  110. def computeDistance(hu,hv,distance=None):
  111. """ compute distance between histograms 'hu' and 'hv' according to distance 'distance' """
  112. if not distance:
  113. # default distance is cosine
  114. distance = miam.math.Distance.Distance(miam.math.Distance.cosineDistance)
  115. res = 1
  116. # some checking
  117. # histogram must have the same color space
  118. if hu.channel.name == hv.channel.name :
  119. # histogram must have the same number of bin
  120. if len(hu.histValue) == len(hv.histValue):
  121. res = distance.eval( hu.histValue, hv.histValue)
  122. else:
  123. print("WARNING[miam.hisrogram.Histogram.computeDistance(",str(hu),",",str(hv),"):", "have different length: return distance=1 !]")
  124. else :
  125. print("WARNING[miam.hisrogram.Histogram.computeDistance(",str(hu),",",str(hv),"):", "have different channels: return distance=1!]")
  126. return res
  127. def segmentPics(self, nbSegs=3):
  128. """
  129. segment histogram by pics
  130. """
  131. # local functions
  132. def isMaxWindow(a3): return ((a3[0] <= a3[1]) and (a3[1] >= a3[2]))
  133. def isMinWindow(a3): return ((a3[0] >= a3[1]) and (a3[1] <= a3[2]))
  134. def filterWindow(a3,weights): return (a3[0]*weights[0]+a3[1]*weights[1]+a3[2]*weights[2])/(weights[0]+weights[1]+weights[2])
  135. def filter(v, weights):
  136. res = copy.deepcopy(v)
  137. res[0] = filterWindow([v[0],v[0],v[1]],weights)
  138. res[-1] = filterWindow([v[-2],v[-1],v[-1]],weights)
  139. for i in range(1, len(v)-1): res[i] = filterWindow(v[(i-1):(i+2)],weights)
  140. return res
  141. def getSegmentBoundaries(v):
  142. seg = []
  143. seg.append(0) # add fist
  144. for i in range(1,len(v)-2):
  145. isMin = isMinWindow(v[i-1:i+2])
  146. isMax = isMaxWindow(v[i-1:i+2])
  147. if isMin and not isMax:
  148. seg.append(i)
  149. seg.append(len(v)-1) # add last
  150. return seg
  151. value = copy.deepcopy(self.histValue)
  152. while (len(getSegmentBoundaries(value))-1)> nbSegs:
  153. weights = [1,1,1]
  154. newValue = filter(value, weights)
  155. newnbs = len(getSegmentBoundaries(newValue))-1
  156. # next iter
  157. value = newValue
  158. # plot for debug
  159. #segmentBoundaries = getSegmentBoundaries(value)
  160. #plt.figure("segments")
  161. #plt.plot(self.histValue,'k--')
  162. #for i in range(len(segmentBoundaries)):
  163. # plt.plot(segmentBoundaries[i],self.histValue[segmentBoundaries[i]],'ro')
  164. #plt.show(block=True)
  165. #print(segmentBoundaries)
  166. # index of boundaries
  167. segmentBoundaries = getSegmentBoundaries(value)
  168. # values of boundaries
  169. segmentBoundariesValues = list(map(lambda i: self.edgeValue[i] if i< len(self.edgeValue)/2 else self.edgeValue[i+1] ,segmentBoundaries))
  170. return segmentBoundariesValues