LigneForce.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. import numpy as np
  4. import os
  5. import sys
  6. import copy
  7. import time
  8. from skimage.io import imread, imsave
  9. from skimage.color import rgb2gray
  10. from skimage import exposure
  11. from skimage.transform import resize
  12. from PIL import Image, ImageDraw, ImageColor
  13. def gradient4D(img):
  14. (row, col) = img.shape
  15. g4d = np.zeros((row, col))
  16. for i in range(row-1):
  17. for j in range(col-1):
  18. g4d[i, j] = abs(img[i+1, j] - img[i, j] ) + abs(img[i, j+1] - img[i, j]) + abs(img[i+1, j-1] - img[i, j]) + abs(img[i+1, j+1] - img[i, j])
  19. return npNormalise(g4d)
  20. def npNormalise(xArray):
  21. XNorm = (xArray - xArray.min()) / (xArray.max() - xArray.min())
  22. return XNorm
  23. def getBaseLines(ws):
  24. baseLineList = []
  25. baseLineListIdex = []
  26. for i in range(0,ws-5):
  27. #if i == 1: break
  28. for j in range(5,ws): # cut bord
  29. #if j == 6: break
  30. # 1
  31. # -------------
  32. # | |
  33. # 4 | |2
  34. # | |
  35. # | |
  36. # --------------
  37. # 3
  38. # adjacent edge
  39. img12 = Image.new('F', (ws,ws),0)
  40. draw12 = ImageDraw.Draw(img12)
  41. draw12.line(xy=(i, 0, ws-1, j),
  42. fill=(1), width = 1)
  43. baseLineList.append(np.asarray(img12))
  44. baseLineListIdex.append(np.asarray([[i, 0],[ws-1, j]]))
  45. baseLineList.append(np.rot90(np.asarray(img12), 1, axes=(0, 1)))
  46. baseLineListIdex.append(np.asarray([[j, 0],[0, ws-1-i]]))
  47. baseLineList.append(np.rot90(np.asarray(img12), 2, axes=(0, 1)))
  48. baseLineListIdex.append(np.asarray([[0, ws-1-j],[ws-1-i, ws-1]]))
  49. baseLineList.append(np.rot90(np.asarray(img12), 3, axes=(0, 1)))
  50. baseLineListIdex.append(np.asarray([[ws-1, i],[ws-1-j, ws-1]]))
  51. # opposite side
  52. img13 = Image.new('F', (ws,ws),0)
  53. draw13 = ImageDraw.Draw(img13)
  54. draw13.line(xy=(i, 0, j, ws-1),
  55. fill=(1), width = 1)
  56. baseLineList.append(np.asarray(img13))
  57. baseLineListIdex.append(np.asarray([[i,0],[j, ws-1]]))
  58. baseLineList.append(np.asarray(img13).T)
  59. baseLineListIdex.append(np.asarray([[0,i],[ws-1, j]]))
  60. print('base line number :', len(baseLineList))
  61. return np.asarray(baseLineList), np.asarray(baseLineListIdex)
  62. def calculSlope(v1,v2):
  63. difX = v2[0] - v1[0]
  64. difY = v2[1] - v1[1]
  65. if difX == 0 :
  66. lk = 5*difY
  67. else:
  68. lk = difY / difX
  69. return lk
  70. def clusterRegion(centerLine, scale = 4, windowSize=64):
  71. H = windowSize
  72. W = windowSize
  73. sMask = np.zeros([H,W])
  74. ix = int(centerLine[0][0])
  75. iy = int(centerLine[0][1])
  76. pixelRange = int(min(H,W) / scale) # scale = 10
  77. k = calculSlope(centerLine[0],centerLine[1])
  78. if abs(k) > 1:
  79. while ix > 0:
  80. iy = int(round(((ix-centerLine[0][1]) / k) + centerLine[0][0]))
  81. frontY = max(0, iy-pixelRange)
  82. backY = min(W,iy+pixelRange+1)
  83. sMask[ix, frontY:backY] = 1
  84. ix = ix - 1
  85. ix = int(centerLine[0][0])
  86. while ix < H:
  87. iy = int(round(((ix-centerLine[0][1]) / k) + centerLine[0][0]))
  88. frontY = max(0, iy-pixelRange)
  89. backY = min(W,iy+pixelRange+1)
  90. sMask[ix, frontY:backY] = 1
  91. ix = ix + 1
  92. else:
  93. while iy > 0:
  94. ix = int(round(((iy-centerLine[0][0]) * k) + centerLine[0][1]))
  95. frontX = max(0, ix-pixelRange)
  96. backX = min(H,ix+pixelRange+1)
  97. sMask[frontX:backX, iy] = 1
  98. iy = iy - 1
  99. iy = int(centerLine[0][1])
  100. while iy < W:
  101. ix = int(round(((iy-centerLine[0][0]) * k) + centerLine[0][1]))
  102. frontX = max(0, ix-pixelRange)
  103. backX = min(H,ix+pixelRange+1)
  104. sMask[frontX:backX, iy] = 1
  105. iy = iy + 1
  106. return sMask
  107. def drawGroupLine(file, lineList, flineListCluster, scale, functionName, colorSTR, outputPath):
  108. c = ImageColor.colormap
  109. cList = list(c.items())
  110. (inputPath,inputFile) = os.path.split(file)
  111. print(inputPath)
  112. print(inputFile)
  113. with Image.open(file) as img4draw:
  114. w, h = img4draw.size
  115. scale = 1/64
  116. wScale = np.ceil(w*scale)
  117. hScale = np.ceil(h*scale)
  118. img1 = ImageDraw.Draw(img4draw)
  119. # for n,lineSet in enumerate(flineListCluster):
  120. # for [v1,v2],w,_ in lineSet[1:]:
  121. # img1.line([(v1[0]*wScale,v1[1]*hScale), (v2[0]*wScale,v2[1]*hScale)], fill = cList[int(n*2)+2][1], width = 4)
  122. for [v1,v2] in lineList:
  123. img1.line([(v1[0],v1[1]), (v2[0],v2[1])], fill = colorSTR, width = 8)
  124. img4draw.save(os.path.join(outputPath, inputFile[:-4] + '_' + str(functionName) + inputFile[-4:] ))
  125. def sortSlope(lineListArray):
  126. slopeList = []
  127. groupWeight = 0
  128. for l in lineListArray:
  129. if (l[0][1][0] - l[0][0][0] ) == 0:
  130. k = 1000
  131. else:
  132. k = (l[0][1][1] - l[0][0][1]) / (l[0][1][0] - l[0][0][0])
  133. slopeList.append(k)
  134. groupWeight = groupWeight + l[1]
  135. # print('weight = ', l[1])
  136. index = np.argsort(np.array(slopeList))
  137. #groupWeight = np.mean(groupWeight)
  138. return [lineListArray[int(np.median(index))][0], lineListArray[0][1], lineListArray[int(np.median(index))][2]]
  139. # index[len(index)//2]
  140. def forceLinesClusterIntegration(cluster):
  141. forceL = []
  142. for i,lineSet in enumerate(cluster):
  143. forceL.append(sortSlope(lineSet[1:]))
  144. return forceL
  145. def refine(lineList, fg, wg, iP, ws):
  146. wlist = []
  147. forceList = []
  148. for l in lineList:
  149. wlist.append(l[1])
  150. npwList = np.array(wlist)
  151. sortWeight = npwList.argsort()[::-1]
  152. for n,wId in enumerate(sortWeight):
  153. if n == 0:
  154. gMask = clusterRegion(lineList[wId][0], fg, ws)
  155. forceList.append([gMask, lineList[wId]])
  156. else:
  157. judge, forceList = judgeVertexAdvanced(lineList[wId][2], lineList[wId][0], npwList[wId], forceList, wg, iP )
  158. if judge == False:
  159. gMask = clusterRegion(lineList[wId][0], fg, ws)
  160. forceList.append([gMask, lineList[wId]])
  161. flList = forceLinesClusterIntegration(forceList)
  162. return flList
  163. def findSaliantLineCluster(gradient4d,allLines,allLinesIndex,ws, orgW, orgH ):
  164. weightList = []
  165. fineGrained0 = 8
  166. intePrec0 = 0.8
  167. forceLinesCluster = []
  168. for l in allLines:
  169. w = np.sum(gradient4d*l)
  170. weightList.append(w)
  171. npWeightList = np.array(weightList)
  172. sortWeightList = npWeightList.argsort()[::-1] # [::-1] inverse a list
  173. for n,wId in enumerate(sortWeightList[:300]):
  174. if n == 0:
  175. groupMask = clusterRegion(allLinesIndex[wId], fineGrained0, ws)
  176. forceLinesCluster.append([groupMask, [allLinesIndex[wId], npWeightList[wId], allLines[wId]]])
  177. #
  178. else:
  179. # print(npWeightList[sortWeightList[n-1]])
  180. # print(npWeightList[wId])
  181. if (npWeightList[sortWeightList[n-1]] - npWeightList[wId]) > 10 :
  182. print('weight break------in line ', str(n))
  183. break
  184. judge, forceLinesCluster = judgeVertexAdvanced(allLines[wId], allLinesIndex[wId], npWeightList[wId], forceLinesCluster , 2, intePrec0)
  185. if judge == False:
  186. groupMask = clusterRegion(allLinesIndex[wId], fineGrained0, ws)
  187. forceLinesCluster.append([groupMask, [allLinesIndex[wId], npWeightList[wId], allLines[wId]]])
  188. forceLinesRough = forceLinesClusterIntegration(forceLinesCluster)
  189. forceLinesRoughNew = forceLinesRough
  190. forceLinesRoughOrg = []
  191. fineGrained = 7
  192. wGrained = 3
  193. intePrec = 0.7
  194. for i in range(10000):
  195. if len(forceLinesRoughNew) == len(forceLinesRoughOrg):
  196. if (fineGrained <= 4 )and (wGrained >= 10) :
  197. print('break in loop ', str(i))
  198. break
  199. forceLinesRoughOrg = forceLinesRoughNew
  200. forceLinesRoughNew = refine(forceLinesRoughNew, fineGrained, wGrained, intePrec, ws)
  201. if fineGrained > 4:
  202. fineGrained = fineGrained-1
  203. if intePrec > 0.6:
  204. intePrec = intePrec - 0.05
  205. if wGrained < 10:
  206. wGrained = wGrained + 1
  207. forceLines = []
  208. for l in forceLinesRoughNew:
  209. forceLines.append(l[0])
  210. forceLines = np.array(forceLines)
  211. scale = 1/ws
  212. HWscale = np.array([[np.ceil(orgW*scale),np.ceil(orgH*scale)],
  213. [np.ceil(orgW*scale),np.ceil(orgH*scale)]])
  214. HWS = np.expand_dims(HWscale,0).repeat(forceLines.shape[0],axis=0)
  215. forceLines = forceLines*HWS
  216. return forceLines, forceLinesCluster,HWS
  217. def judgeVertexAdvanced(line1,v1, v1w, forceL, wSeuil = 4, intersectPrecent = 0.7):
  218. v1 = np.array(v1)
  219. newGroup = False
  220. for cl in forceL:
  221. vPossible = cl[0]*line1
  222. if np.sum(vPossible) > (np.sum(line1)*intersectPrecent):
  223. if abs(cl[1][1] - v1w) < wSeuil:
  224. cl.append([v1,v1w,line1])
  225. return True,forceL
  226. else:
  227. return True,forceL
  228. return False, forceL
  229. def getLeadingLine(imgpath, outPath):
  230. windowSize = 64
  231. allLines, allLinesIndex = getBaseLines(windowSize)
  232. img = imread(imgpath)
  233. print(img.shape)
  234. if (len(img.shape) != 3) or (img.shape[2] != 3):
  235. print('NOT a 3 channel image')
  236. else:
  237. orgH, orgW, _ = img.shape
  238. resizeImg = resize(img,(windowSize,windowSize))
  239. # Add contrast
  240. logImg = exposure.adjust_log(resizeImg, 1)
  241. grayImg = rgb2gray(logImg)
  242. gradient4d= gradient4D(grayImg)
  243. forceLines, forceLinesCluster, scale = findSaliantLineCluster(gradient4d,allLines,allLinesIndex,windowSize, orgW, orgH )
  244. drawGroupLine(imgpath, forceLines, forceLinesCluster, scale, 'forceLines', 'red', outPath)
  245. if __name__ == '__main__':
  246. print(sys.argv[1])
  247. print(sys.argv[2])
  248. INPUT_DIRECTORY = sys.argv[1]
  249. OUTPUT_DIRECTORY = sys.argv[2]
  250. if not (os.path.exists(OUTPUT_DIRECTORY)):
  251. print('Create output path:' , OUTPUT_DIRECTORY)
  252. os.makedirs(OUTPUT_DIRECTORY)
  253. start = time.time()
  254. if os.path.isfile( INPUT_DIRECTORY ):
  255. if INPUT_DIRECTORY.lower().endswith(('.jpg', '.png')) and not INPUT_DIRECTORY.lower().startswith('.'):
  256. getLeadingLine(INPUT_DIRECTORY,OUTPUT_DIRECTORY)
  257. elif os.path.isdir( INPUT_DIRECTORY ):
  258. files= os.listdir(INPUT_DIRECTORY)
  259. for i, file in enumerate(files):
  260. if file.lower().endswith(('.jpg', '.png')) and not file.lower().startswith('.'):
  261. fullpath = os.path.join(INPUT_DIRECTORY, file)
  262. getLeadingLine(fullpath,OUTPUT_DIRECTORY)
  263. end = time.time()
  264. print(' use time = ', str((end - start)/60.0), 'm')