LigneForce.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. import numpy as np
  4. import os
  5. import sys
  6. import copy
  7. import time
  8. import argparse
  9. from skimage.io import imread, imsave
  10. from skimage.color import rgb2gray
  11. from skimage import exposure
  12. from skimage.transform import resize
  13. from PIL import Image, ImageDraw, ImageColor
  14. def gradient4D(img):
  15. (row, col) = img.shape
  16. g4d = np.zeros((row, col))
  17. for i in range(row-1):
  18. for j in range(col-1):
  19. g4d[i, j] = abs(img[i+1, j] - img[i, j] ) + abs(img[i, j+1] - img[i, j]) + abs(img[i+1, j-1] - img[i, j]) + abs(img[i+1, j+1] - img[i, j])
  20. return npNormalise(g4d)
  21. def npNormalise(xArray):
  22. XNorm = (xArray - xArray.min()) / (xArray.max() - xArray.min())
  23. return XNorm
  24. def getBaseLines(ws):
  25. baseLineList = []
  26. baseLineListIdex = []
  27. for i in range(0,ws-5):
  28. #if i == 1: break
  29. for j in range(5,ws): # cut bord
  30. #if j == 6: break
  31. # 1
  32. # -------------
  33. # | |
  34. # 4 | |2
  35. # | |
  36. # | |
  37. # --------------
  38. # 3
  39. # adjacent edge
  40. img12 = Image.new('F', (ws,ws),0)
  41. draw12 = ImageDraw.Draw(img12)
  42. draw12.line(xy=(i, 0, ws-1, j),
  43. fill=(1), width = 1)
  44. baseLineList.append(np.asarray(img12))
  45. baseLineListIdex.append(np.asarray([[i, 0],[ws-1, j]]))
  46. baseLineList.append(np.rot90(np.asarray(img12), 1, axes=(0, 1)))
  47. baseLineListIdex.append(np.asarray([[j, 0],[0, ws-1-i]]))
  48. baseLineList.append(np.rot90(np.asarray(img12), 2, axes=(0, 1)))
  49. baseLineListIdex.append(np.asarray([[0, ws-1-j],[ws-1-i, ws-1]]))
  50. baseLineList.append(np.rot90(np.asarray(img12), 3, axes=(0, 1)))
  51. baseLineListIdex.append(np.asarray([[ws-1, i],[ws-1-j, ws-1]]))
  52. # opposite side
  53. img13 = Image.new('F', (ws,ws),0)
  54. draw13 = ImageDraw.Draw(img13)
  55. draw13.line(xy=(i, 0, j, ws-1),
  56. fill=(1), width = 1)
  57. baseLineList.append(np.asarray(img13))
  58. baseLineListIdex.append(np.asarray([[i,0],[j, ws-1]]))
  59. baseLineList.append(np.asarray(img13).T)
  60. baseLineListIdex.append(np.asarray([[0,i],[ws-1, j]]))
  61. print('base line number :', len(baseLineList))
  62. return np.asarray(baseLineList), np.asarray(baseLineListIdex)
  63. def calculSlope(v1,v2):
  64. difX = v2[0] - v1[0]
  65. difY = v2[1] - v1[1]
  66. if difX == 0 :
  67. lk = 5*difY
  68. else:
  69. lk = difY / difX
  70. return lk
  71. def clusterRegion(centerLine, scale = 4, windowSize=64):
  72. H = windowSize
  73. W = windowSize
  74. sMask = np.zeros([H,W])
  75. ix = int(centerLine[0][0])
  76. iy = int(centerLine[0][1])
  77. pixelRange = int(min(H,W) / scale) # scale = 10
  78. k = calculSlope(centerLine[0],centerLine[1])
  79. if abs(k) > 1:
  80. while ix > 0:
  81. iy = int(round(((ix-centerLine[0][1]) / k) + centerLine[0][0]))
  82. frontY = max(0, iy-pixelRange)
  83. backY = min(W,iy+pixelRange+1)
  84. sMask[ix, frontY:backY] = 1
  85. ix = ix - 1
  86. ix = int(centerLine[0][0])
  87. while ix < H:
  88. iy = int(round(((ix-centerLine[0][1]) / k) + centerLine[0][0]))
  89. frontY = max(0, iy-pixelRange)
  90. backY = min(W,iy+pixelRange+1)
  91. sMask[ix, frontY:backY] = 1
  92. ix = ix + 1
  93. else:
  94. while iy > 0:
  95. ix = int(round(((iy-centerLine[0][0]) * k) + centerLine[0][1]))
  96. frontX = max(0, ix-pixelRange)
  97. backX = min(H,ix+pixelRange+1)
  98. sMask[frontX:backX, iy] = 1
  99. iy = iy - 1
  100. iy = int(centerLine[0][1])
  101. while iy < W:
  102. ix = int(round(((iy-centerLine[0][0]) * k) + centerLine[0][1]))
  103. frontX = max(0, ix-pixelRange)
  104. backX = min(H,ix+pixelRange+1)
  105. sMask[frontX:backX, iy] = 1
  106. iy = iy + 1
  107. return sMask
  108. def drawGroupLine(file, lineList, flineListCluster, scale, functionName, colorSTR, outputPath):
  109. c = ImageColor.colormap
  110. cList = list(c.items())
  111. (inputPath,inputFile) = os.path.split(file)
  112. print(inputPath)
  113. print(inputFile)
  114. with Image.open(file) as img4draw:
  115. w, h = img4draw.size
  116. scale = 1/64
  117. wScale = np.ceil(w*scale)
  118. hScale = np.ceil(h*scale)
  119. img1 = ImageDraw.Draw(img4draw)
  120. # for n,lineSet in enumerate(flineListCluster):
  121. # for [v1,v2],w,_ in lineSet[1:]:
  122. # img1.line([(v1[0]*wScale,v1[1]*hScale), (v2[0]*wScale,v2[1]*hScale)], fill = cList[int(n*2)+2][1], width = 4)
  123. for [v1,v2] in lineList:
  124. img1.line([(v1[0],v1[1]), (v2[0],v2[1])], fill = colorSTR, width = 8)
  125. img4draw.save(os.path.join(outputPath, inputFile[:-4] + '_' + str(functionName) + inputFile[-4:] ))
  126. def sortSlope(lineListArray):
  127. slopeList = []
  128. groupWeight = 0
  129. for l in lineListArray:
  130. if (l[0][1][0] - l[0][0][0] ) == 0:
  131. k = 1000
  132. else:
  133. k = (l[0][1][1] - l[0][0][1]) / (l[0][1][0] - l[0][0][0])
  134. slopeList.append(k)
  135. groupWeight = groupWeight + l[1]
  136. # print('weight = ', l[1])
  137. index = np.argsort(np.array(slopeList))
  138. #groupWeight = np.mean(groupWeight)
  139. return [lineListArray[int(np.median(index))][0], lineListArray[0][1], lineListArray[int(np.median(index))][2]]
  140. # index[len(index)//2]
  141. def forceLinesClusterIntegration(cluster):
  142. forceL = []
  143. for i,lineSet in enumerate(cluster):
  144. forceL.append(sortSlope(lineSet[1:]))
  145. return forceL
  146. def refine(lineList, fg, wg, iP, ws):
  147. wlist = []
  148. forceList = []
  149. for l in lineList:
  150. wlist.append(l[1])
  151. npwList = np.array(wlist)
  152. sortWeight = npwList.argsort()[::-1]
  153. for n,wId in enumerate(sortWeight):
  154. if n == 0:
  155. gMask = clusterRegion(lineList[wId][0], fg, ws)
  156. forceList.append([gMask, lineList[wId]])
  157. else:
  158. judge, forceList = judgeVertexAdvanced(lineList[wId][2], lineList[wId][0], npwList[wId], forceList, wg, iP )
  159. if judge == False:
  160. gMask = clusterRegion(lineList[wId][0], fg, ws)
  161. forceList.append([gMask, lineList[wId]])
  162. flList = forceLinesClusterIntegration(forceList)
  163. return flList
  164. def findSaliantLineCluster(gradient4d,allLines,allLinesIndex,ws, orgW, orgH ):
  165. weightList = []
  166. fineGrained0 = 8
  167. intePrec0 = 0.8
  168. forceLinesCluster = []
  169. for l in allLines:
  170. w = np.sum(gradient4d*l)
  171. weightList.append(w)
  172. npWeightList = np.array(weightList)
  173. sortWeightList = npWeightList.argsort()[::-1] # [::-1] inverse a list
  174. for n,wId in enumerate(sortWeightList[:300]):
  175. if n == 0:
  176. groupMask = clusterRegion(allLinesIndex[wId], fineGrained0, ws)
  177. forceLinesCluster.append([groupMask, [allLinesIndex[wId], npWeightList[wId], allLines[wId]]])
  178. #
  179. else:
  180. # print(npWeightList[sortWeightList[n-1]])
  181. # print(npWeightList[wId])
  182. if (npWeightList[sortWeightList[n-1]] - npWeightList[wId]) > 10 :
  183. print('weight break------in line ', str(n))
  184. break
  185. judge, forceLinesCluster = judgeVertexAdvanced(allLines[wId], allLinesIndex[wId], npWeightList[wId], forceLinesCluster , 2, intePrec0)
  186. if judge == False:
  187. groupMask = clusterRegion(allLinesIndex[wId], fineGrained0, ws)
  188. forceLinesCluster.append([groupMask, [allLinesIndex[wId], npWeightList[wId], allLines[wId]]])
  189. forceLinesRough = forceLinesClusterIntegration(forceLinesCluster)
  190. forceLinesRoughNew = forceLinesRough
  191. forceLinesRoughOrg = []
  192. fineGrained = 7
  193. wGrained = 3
  194. intePrec = 0.7
  195. for i in range(10000):
  196. if len(forceLinesRoughNew) == len(forceLinesRoughOrg):
  197. if (fineGrained <= 4 )and (wGrained >= 10) :
  198. print('break in loop ', str(i))
  199. break
  200. forceLinesRoughOrg = forceLinesRoughNew
  201. forceLinesRoughNew = refine(forceLinesRoughNew, fineGrained, wGrained, intePrec, ws)
  202. if fineGrained > 4:
  203. fineGrained = fineGrained-1
  204. if intePrec > 0.6:
  205. intePrec = intePrec - 0.05
  206. if wGrained < 10:
  207. wGrained = wGrained + 1
  208. forceLines = []
  209. for l in forceLinesRoughNew:
  210. forceLines.append(l[0])
  211. forceLines = np.array(forceLines)
  212. scale = 1/ws
  213. HWscale = np.array([[np.ceil(orgW*scale),np.ceil(orgH*scale)],
  214. [np.ceil(orgW*scale),np.ceil(orgH*scale)]])
  215. HWS = np.expand_dims(HWscale,0).repeat(forceLines.shape[0],axis=0)
  216. forceLines = forceLines*HWS
  217. return forceLines, forceLinesCluster,HWS
  218. def judgeVertexAdvanced(line1,v1, v1w, forceL, wSeuil = 4, intersectPrecent = 0.7):
  219. v1 = np.array(v1)
  220. newGroup = False
  221. for cl in forceL:
  222. vPossible = cl[0]*line1
  223. if np.sum(vPossible) > (np.sum(line1)*intersectPrecent):
  224. if abs(cl[1][1] - v1w) < wSeuil:
  225. cl.append([v1,v1w,line1])
  226. return True,forceL
  227. else:
  228. return True,forceL
  229. return False, forceL
  230. def getLeadingLine(imgpath, outPath):
  231. windowSize = 64
  232. allLines, allLinesIndex = getBaseLines(windowSize)
  233. img = imread(imgpath)
  234. print(img.shape)
  235. if (len(img.shape) != 3) or (img.shape[2] != 3):
  236. print('NOT a 3 channel image')
  237. else:
  238. orgH, orgW, _ = img.shape
  239. resizeImg = resize(img,(windowSize,windowSize))
  240. # Add contrast
  241. logImg = exposure.adjust_log(resizeImg, 1)
  242. grayImg = rgb2gray(logImg)
  243. gradient4d= gradient4D(grayImg)
  244. forceLines, forceLinesCluster, scale = findSaliantLineCluster(gradient4d,allLines,allLinesIndex,windowSize, orgW, orgH )
  245. drawGroupLine(imgpath, forceLines, forceLinesCluster, scale, 'forceLines', 'red', outPath)
  246. if __name__ == '__main__':
  247. parser = argparse.ArgumentParser(description='Find the Probable leading lines, please provide 1) your input image path or a folder path for input images, and 2) the output folder you wish.')
  248. parser.add_argument('input', type=str, help='The path for your input image or folder')
  249. parser.add_argument('-o', '--output', type=str, default='./OUTPUT', help='The path for your output folder ')
  250. args = parser.parse_args()
  251. INPUT_DIRECTORY = args.input
  252. OUTPUT_DIRECTORY = args.output
  253. print('INPUT : ', INPUT_DIRECTORY)
  254. print('OUTPUT : ', OUTPUT_DIRECTORY)
  255. if not (os.path.exists(OUTPUT_DIRECTORY)):
  256. print('Create output path:' , OUTPUT_DIRECTORY)
  257. os.makedirs(OUTPUT_DIRECTORY)
  258. start = time.time()
  259. if os.path.isfile( INPUT_DIRECTORY ):
  260. if INPUT_DIRECTORY.lower().endswith(('.jpg', '.png')) and not INPUT_DIRECTORY.lower().startswith('.'):
  261. getLeadingLine(INPUT_DIRECTORY,OUTPUT_DIRECTORY)
  262. elif os.path.isdir( INPUT_DIRECTORY ):
  263. files= os.listdir(INPUT_DIRECTORY)
  264. for i, file in enumerate(files):
  265. if file.lower().endswith(('.jpg', '.png')) and not file.lower().startswith('.'):
  266. fullpath = os.path.join(INPUT_DIRECTORY, file)
  267. getLeadingLine(fullpath,OUTPUT_DIRECTORY)
  268. end = time.time()
  269. print(' use time = ', str((end - start)/60.0), 'm')