# Copyright (C) 2023 ZHANG Jing, Université du Littoral Côte d'Opale
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#!/usr/bin/env python
# coding: utf-8
import argparse
import copy
import functools
import os
import sys
import time

import numpy as np
from PIL import Image, ImageColor, ImageDraw
from skimage import exposure
from skimage.color import rgb2gray
from skimage.io import imread, imsave
from skimage.transform import resize
  28. ## compute image gradient map, take the absolute difference values in 4 dimensions for each pixel
  29. # --------------------------
  30. # | |
  31. # | i, j i, j+1 |
  32. # | i+1,j-1 i+1,j i+1,j+1 |
  33. # --------------------------
  34. def gradient4D(img):
  35. (row, col) = img.shape
  36. g4d = np.zeros((row, col))
  37. for i in range(row-1):
  38. for j in range(col-1):
  39. g4d[i, j] = abs(img[i+1, j] - img[i, j] ) + abs(img[i, j+1] - img[i, j]) + abs(img[i+1, j-1] - img[i, j]) + abs(img[i+1, j+1] - img[i, j])
  40. return npNormalise(g4d)
  41. # normalise values to [0, 1]
  42. def npNormalise(xArray):
  43. XNorm = (xArray - xArray.min()) / (xArray.max() - xArray.min())
  44. return XNorm
  45. # compute all potential lines for a square of size ws*ws
  46. # input : length of the square
  47. # output : 1. a set of binary images, each image contains only one line
  48. # 2. a set containing the coordinates of the start points and the end points of the line
  49. def getBaseLines(ws):
  50. baseLineList = []
  51. baseLineListIdex = []
  52. # Skip 5 pixels to avoid lines near edges
  53. for i in range(0,ws-5):
  54. for j in range(5,ws): # cut bord
  55. # 1
  56. # -------------
  57. # | |
  58. # 4 | |2
  59. # | |
  60. # | |
  61. # --------------
  62. # 3
  63. # adjacent edge (like edge 1 and edge 2, edge 1 and edge 4)
  64. img12 = Image.new('F', (ws,ws),0)
  65. draw12 = ImageDraw.Draw(img12)
  66. # lines that the start point in edge 1 and the end point in edge 2
  67. draw12.line(xy=(i, 0, ws-1, j),
  68. fill=(1), width = 1)
  69. baseLineList.append(np.asarray(img12))
  70. baseLineListIdex.append(np.asarray([[i, 0],[ws-1, j]]))
  71. # lines that the start point in edge 4 and the end point in edge 1
  72. baseLineList.append(np.rot90(np.asarray(img12), 1, axes=(0, 1)))
  73. baseLineListIdex.append(np.asarray([[j, 0],[0, ws-1-i]]))
  74. # lines that the start point in edge 3 and the end point in edge 4
  75. baseLineList.append(np.rot90(np.asarray(img12), 2, axes=(0, 1)))
  76. baseLineListIdex.append(np.asarray([[0, ws-1-j],[ws-1-i, ws-1]]))
  77. # lines that the start point in edge 2 and the end point in edge 3
  78. baseLineList.append(np.rot90(np.asarray(img12), 3, axes=(0, 1)))
  79. baseLineListIdex.append(np.asarray([[ws-1, i],[ws-1-j, ws-1]]))
  80. # opposite side
  81. img13 = Image.new('F', (ws,ws),0)
  82. draw13 = ImageDraw.Draw(img13)
  83. # lines that the start point in edge 4 and the end point in edge 2
  84. draw13.line(xy=(i, 0, j, ws-1),
  85. fill=(1), width = 1)
  86. baseLineList.append(np.asarray(img13))
  87. baseLineListIdex.append(np.asarray([[i,0],[j, ws-1]]))
  88. # lines that the start point in edge 1 and the end point in edge 3
  89. baseLineList.append(np.asarray(img13).T)
  90. baseLineListIdex.append(np.asarray([[0,i],[ws-1, j]]))
  91. print('base line number :', len(baseLineList))
  92. return np.asarray(baseLineList), np.asarray(baseLineListIdex)
  93. # Calculate the slope of the line formed by vertex1 and vertex2
  94. def calculSlope(v1,v2):
  95. difX = v2[0] - v1[0]
  96. difY = v2[1] - v1[1]
  97. if difX == 0 :
  98. lk = 5*difY
  99. else:
  100. lk = difY / difX
  101. return lk
  102. # Compute the band mask of a line
  103. def clusterRegion(centerLine, scale = 4, windowSize=64):
  104. H = windowSize
  105. W = windowSize
  106. sMask = np.zeros([H,W])
  107. ix = int(centerLine[0][0])
  108. iy = int(centerLine[0][1])
  109. # calculate the width of band mask
  110. pixelRange = int(min(H,W) / scale) # scale = 10
  111. # get the slope of line
  112. k = calculSlope(centerLine[0],centerLine[1])
  113. if abs(k) > 1:
  114. while ix > 0:
  115. iy = int(round(((ix-centerLine[0][1]) / k) + centerLine[0][0]))
  116. frontY = max(0, iy-pixelRange)
  117. backY = min(W,iy+pixelRange+1)
  118. sMask[ix, frontY:backY] = 1
  119. ix = ix - 1
  120. ix = int(centerLine[0][0])
  121. while ix < H:
  122. iy = int(round(((ix-centerLine[0][1]) / k) + centerLine[0][0]))
  123. frontY = max(0, iy-pixelRange)
  124. backY = min(W,iy+pixelRange+1)
  125. sMask[ix, frontY:backY] = 1
  126. ix = ix + 1
  127. else:
  128. while iy > 0:
  129. ix = int(round(((iy-centerLine[0][0]) * k) + centerLine[0][1]))
  130. frontX = max(0, ix-pixelRange)
  131. backX = min(H,ix+pixelRange+1)
  132. sMask[frontX:backX, iy] = 1
  133. iy = iy - 1
  134. iy = int(centerLine[0][1])
  135. while iy < W:
  136. ix = int(round(((iy-centerLine[0][0]) * k) + centerLine[0][1]))
  137. frontX = max(0, ix-pixelRange)
  138. backX = min(H,ix+pixelRange+1)
  139. sMask[frontX:backX, iy] = 1
  140. iy = iy + 1
  141. return sMask
  142. # fonction for display all the lines
  143. def drawGroupLine(file, lineList, flineListCluster, scale, functionName, colorSTR, outputPath):
  144. c = ImageColor.colormap
  145. cList = list(c.items())
  146. (inputPath,inputFile) = os.path.split(file)
  147. print(inputPath)
  148. print(inputFile)
  149. # read the orignal file for draw
  150. with Image.open(file) as img4draw:
  151. w, h = img4draw.size
  152. if w >h: # add lineWidth to adapt the visibility of drawing results to different image sizes
  153. lineWidth = int(h/40)
  154. else:
  155. lineWidth = int(w/40)
  156. scale = 1/64
  157. wScale = np.ceil(w*scale)
  158. hScale = np.ceil(h*scale)
  159. img1 = ImageDraw.Draw(img4draw)
  160. # draw all the centers
  161. for [v1,v2] in lineList:
  162. img1.line([(v1[0],v1[1]), (v2[0],v2[1])], fill = colorSTR, width = lineWidth)
  163. img4draw.save(os.path.join(outputPath, inputFile[:-4] + '_' + str(functionName) + inputFile[-4:] ))
  164. # sort the slope of lines, inutile for version 0
  165. def sortSlope(lineListArray):
  166. print('lineListArray', lineListArray)
  167. slopeList = []
  168. groupWeight = 0
  169. for l in lineListArray:
  170. if (l[0][1][0] - l[0][0][0] ) == 0:
  171. k = 1000
  172. else:
  173. k = (l[0][1][1] - l[0][0][1]) / (l[0][1][0] - l[0][0][0])
  174. slopeList.append(k)
  175. groupWeight = groupWeight + l[1]
  176. print('slopeList : ', slopeList)
  177. index = np.argsort(np.array(slopeList))
  178. print('sortSlope index : ', index)
  179. print('sortSlope index median : ', int(np.median(index)))
  180. return [lineListArray[int(np.median(index))][0], lineListArray[0][1], lineListArray[int(np.median(index))][2]]
  181. # extraction the center of each group
  182. def forceLinesClusterIntegration(cluster):
  183. forceL = []
  184. for i,lineSet in enumerate(cluster):
  185. forceL.append(lineSet[1])
  186. return forceL
  187. # refine the cluster result
  188. def refine(lineList, fg, wg, iP, ws):
  189. wlist = []
  190. forceList = []
  191. for l in lineList:
  192. wlist.append(l[1])
  193. npwList = np.array(wlist)
  194. sortWeight = npwList.argsort()[::-1]
  195. for n,wId in enumerate(sortWeight):
  196. if n == 0:
  197. gMask = clusterRegion(lineList[wId][0], fg, ws)
  198. forceList.append([gMask, lineList[wId]])
  199. else:
  200. judge, forceList = judgeVertexAdvanced(lineList[wId][2], lineList[wId][0], npwList[wId], forceList, wg, iP )
  201. if judge == False:
  202. gMask = clusterRegion(lineList[wId][0], fg, ws)
  203. forceList.append([gMask, lineList[wId]])
  204. flList = forceLinesClusterIntegration(forceList)
  205. return flList
  206. # the main process of clustering the lines
  207. def findSaliantLineCluster(gradient4d,allLines,allLinesIndex,ws, orgW, orgH ):
  208. weightList = []
  209. fineGrained0 = 8 # initial refine grained = ws / 8
  210. intePrec0 = 0.8 # initial intersection precision
  211. forceLinesCluster = []
  212. # compute weights of lines
  213. for l in allLines:
  214. w = np.sum(gradient4d*l)
  215. weightList.append(w)
  216. npWeightList = np.array(weightList)
  217. sortWeightList = npWeightList.argsort()[::-1] # [::-1] inverse a list, range from large to small
  218. # top 300 weighted candidates, about 0.14% of the total lines
  219. # initialization of the first group of the leading lines
  220. for n,wId in enumerate(sortWeightList[:300]):
  221. if n == 0:
  222. groupMask = clusterRegion(allLinesIndex[wId], fineGrained0, ws)
  223. forceLinesCluster.append([groupMask, [allLinesIndex[wId], npWeightList[wId], allLines[wId]]])
  224. else:
  225. if (npWeightList[sortWeightList[n-1]] - npWeightList[wId]) > 10 :
  226. print('weight break------in line ', str(n))
  227. break
  228. judge, forceLinesCluster = judgeVertexAdvanced(allLines[wId], allLinesIndex[wId], npWeightList[wId], forceLinesCluster , 2, intePrec0)
  229. if judge == False:
  230. groupMask = clusterRegion(allLinesIndex[wId], fineGrained0, ws)
  231. forceLinesCluster.append([groupMask, [allLinesIndex[wId], npWeightList[wId], allLines[wId]]])
  232. forceLinesRough = forceLinesClusterIntegration(forceLinesCluster)
  233. forceLinesRoughNew = forceLinesRough
  234. forceLinesRoughOrg = []
  235. fineGrained = 7
  236. wGrained = 3
  237. intePrec = 0.7
  238. # regrouping and filtering leading lines, reture center lines and line groups
  239. for i in range(10000):
  240. if len(forceLinesRoughNew) == len(forceLinesRoughOrg):
  241. if (fineGrained <= 4 )and (wGrained >= 10) :
  242. print('break in loop ', str(i))
  243. break
  244. forceLinesRoughOrg = forceLinesRoughNew
  245. forceLinesRoughNew = refine(forceLinesRoughNew, fineGrained, wGrained, intePrec, ws)
  246. # update parameters
  247. if fineGrained > 4:
  248. fineGrained = fineGrained-1
  249. if intePrec > 0.6:
  250. intePrec = intePrec - 0.05
  251. if wGrained < 10:
  252. wGrained = wGrained + 1
  253. forceLines = []
  254. for l in forceLinesRoughNew:
  255. forceLines.append(l[0])
  256. forceLines = np.array(forceLines)
  257. scale = 1/ws
  258. HWscale = np.array([[np.ceil(orgW*scale),np.ceil(orgH*scale)],
  259. [np.ceil(orgW*scale),np.ceil(orgH*scale)]])
  260. HWS = np.expand_dims(HWscale,0).repeat(forceLines.shape[0],axis=0)
  261. forceLines = forceLines*HWS
  262. return forceLines, forceLinesCluster,HWS
  263. # Judging whether a line belongs to an existing group of leading lines
  264. # if a line spatially belongs to the group and the weight are within the threshold, add it to the group;
  265. # else if the weights are beyond the threshold range(which means it is weakly significant),do not add it to the group, ignore
  266. def judgeVertexAdvanced(line1,v1, v1w, forceL, wSeuil = 4, intersectPrecent = 0.7):
  267. v1 = np.array(v1)
  268. newGroup = False
  269. for cl in forceL:
  270. vPossible = cl[0]*line1
  271. if np.sum(vPossible) > (np.sum(line1)*intersectPrecent):
  272. if abs(cl[1][1] - v1w) < wSeuil:
  273. cl.append([v1,v1w,line1])
  274. return True,forceL
  275. else:
  276. return True,forceL
  277. return False, forceL
  278. # compute leading lines of the image and generate an image with leading lines
  279. def getLeadingLine(imgpath, outPath):
  280. windowSize = 64
  281. allLines, allLinesIndex = getBaseLines(windowSize)
  282. img = imread(imgpath)
  283. print(img.shape)
  284. if (len(img.shape) != 3) or (img.shape[2] != 3):
  285. print('NOT a 3 channel image')
  286. else:
  287. orgH, orgW, _ = img.shape
  288. resizeImg = resize(img,(windowSize,windowSize))
  289. # add contrast
  290. logImg = exposure.adjust_log(resizeImg, 1)
  291. # get grayscale image
  292. grayImg = rgb2gray(logImg)
  293. # calculating the gradient
  294. gradient4d= gradient4D(grayImg)
  295. # grouping for leading lines
  296. forceLines, forceLinesCluster, scale = findSaliantLineCluster(gradient4d,allLines,allLinesIndex,windowSize, orgW, orgH )
  297. drawGroupLine(imgpath, forceLines, forceLinesCluster, scale, 'forceLines', 'red', outPath)
  298. if __name__ == '__main__':
  299. parser = argparse.ArgumentParser(description='Find the Probable leading lines, please provide 1) your input image path or a folder path for input images, and 2) the output folder you wish.')
  300. parser.add_argument('input', type=str, help='The path for your input image or folder')
  301. parser.add_argument('-o', '--output', type=str, default='./OUTPUT', help='The path for your output folder ')
  302. args = parser.parse_args()
  303. INPUT_DIRECTORY = args.input
  304. OUTPUT_DIRECTORY = args.output
  305. print('INPUT : ', INPUT_DIRECTORY)
  306. print('OUTPUT : ', OUTPUT_DIRECTORY)
  307. if not (os.path.exists(OUTPUT_DIRECTORY)):
  308. print('Create output path:' , OUTPUT_DIRECTORY)
  309. os.makedirs(OUTPUT_DIRECTORY)
  310. start = time.time()
  311. if os.path.isfile( INPUT_DIRECTORY ):
  312. if INPUT_DIRECTORY.lower().endswith(('.jpg', '.png')) and not INPUT_DIRECTORY.lower().startswith('.'):
  313. getLeadingLine(INPUT_DIRECTORY,OUTPUT_DIRECTORY)
  314. elif os.path.isdir( INPUT_DIRECTORY ):
  315. files= os.listdir(INPUT_DIRECTORY)
  316. for i, file in enumerate(files):
  317. if file.lower().endswith(('.jpg', '.png')) and not file.lower().startswith('.'):
  318. fullpath = os.path.join(INPUT_DIRECTORY, file)
  319. getLeadingLine(fullpath,OUTPUT_DIRECTORY)
  320. end = time.time()
  321. print(' use time = ', str((end - start)/60.0), 'm')