LigneForce.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. #Copyright (C) [2023] [ZHANG Jing, Université du Littoral Côte d'Opale]
  2. #
  3. #This program is free software: you can redistribute it and/or modify
  4. #it under the terms of the GNU General Public License as published by
  5. #the Free Software Foundation, either version 3 of the License, or
  6. #(at your option) any later version.
  7. #
  8. #This program is distributed in the hope that it will be useful,
  9. #but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. #MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  11. #GNU General Public License for more details.
  12. #
  13. #You should have received a copy of the GNU General Public License
  14. #along with this program. If not, see <http://www.gnu.org/licenses/>.
  15. #!/usr/bin/env python
  16. # coding: utf-8
  17. import numpy as np
  18. import os
  19. import sys
  20. import copy
  21. import time
  22. import argparse
  23. from skimage.io import imread, imsave
  24. from skimage.color import rgb2gray
  25. from skimage import exposure
  26. from skimage.transform import resize
  27. from PIL import Image, ImageDraw, ImageColor
  28. ## compute image gradient map, take the absolute difference values in 4 dimensions for each pixel
  29. # --------------------------
  30. # | |
  31. # | i, j i, j+1 |
  32. # | i+1,j-1 i+1,j i+1,j+1 |
  33. # --------------------------
  34. def gradient4D(img):
  35. (row, col) = img.shape
  36. g4d = np.zeros((row, col))
  37. for i in range(row-1):
  38. for j in range(col-1):
  39. g4d[i, j] = abs(img[i+1, j] - img[i, j] ) + abs(img[i, j+1] - img[i, j]) + abs(img[i+1, j-1] - img[i, j]) + abs(img[i+1, j+1] - img[i, j])
  40. return npNormalise(g4d)
  41. # normalise values to [0, 1]
  42. def npNormalise(xArray):
  43. XNorm = (xArray - xArray.min()) / (xArray.max() - xArray.min())
  44. return XNorm
  45. # compute all potential lines for a square of size ws*ws
  46. # input : length of the square
  47. # output : 1. a set of binary images, each image contains only one line
  48. # 2. a set containing the coordinates of the start points and the end points of the line
  49. def getBaseLines(ws):
  50. baseLineList = []
  51. baseLineListIdex = []
  52. # Skip 5 pixels to avoid lines near edges
  53. for i in range(0,ws-5):
  54. for j in range(5,ws): # cut bord
  55. # 1
  56. # -------------
  57. # | |
  58. # 4 | |2
  59. # | |
  60. # | |
  61. # --------------
  62. # 3
  63. # adjacent edge (like edge 1 and edge 2, edge 1 and edge 4)
  64. img12 = Image.new('F', (ws,ws),0)
  65. draw12 = ImageDraw.Draw(img12)
  66. # lines that the start point in edge 1 and the end point in edge 2
  67. draw12.line(xy=(i, 0, ws-1, j),
  68. fill=(1), width = 1)
  69. baseLineList.append(np.asarray(img12))
  70. baseLineListIdex.append(np.asarray([[i, 0],[ws-1, j]]))
  71. # lines that the start point in edge 4 and the end point in edge 1
  72. baseLineList.append(np.rot90(np.asarray(img12), 1, axes=(0, 1)))
  73. baseLineListIdex.append(np.asarray([[j, 0],[0, ws-1-i]]))
  74. # lines that the start point in edge 3 and the end point in edge 4
  75. baseLineList.append(np.rot90(np.asarray(img12), 2, axes=(0, 1)))
  76. baseLineListIdex.append(np.asarray([[0, ws-1-j],[ws-1-i, ws-1]]))
  77. # lines that the start point in edge 2 and the end point in edge 3
  78. baseLineList.append(np.rot90(np.asarray(img12), 3, axes=(0, 1)))
  79. baseLineListIdex.append(np.asarray([[ws-1, i],[ws-1-j, ws-1]]))
  80. # opposite side
  81. img13 = Image.new('F', (ws,ws),0)
  82. draw13 = ImageDraw.Draw(img13)
  83. # lines that the start point in edge 4 and the end point in edge 2
  84. draw13.line(xy=(i, 0, j, ws-1),
  85. fill=(1), width = 1)
  86. baseLineList.append(np.asarray(img13))
  87. baseLineListIdex.append(np.asarray([[i,0],[j, ws-1]]))
  88. # lines that the start point in edge 1 and the end point in edge 3
  89. baseLineList.append(np.asarray(img13).T)
  90. baseLineListIdex.append(np.asarray([[0,i],[ws-1, j]]))
  91. print('base line number :', len(baseLineList))
  92. return np.asarray(baseLineList), np.asarray(baseLineListIdex)
  93. # Calculate the slope of the line formed by vertex1 and vertex2
  94. def calculSlope(v1,v2):
  95. difX = v2[0] - v1[0]
  96. difY = v2[1] - v1[1]
  97. if difX == 0 :
  98. lk = 5*difY
  99. else:
  100. lk = difY / difX
  101. return lk
  102. # Compute the band mask of a line
  103. def clusterRegion(centerLine, scale = 4, windowSize=64):
  104. H = windowSize
  105. W = windowSize
  106. sMask = np.zeros([H,W])
  107. ix = int(centerLine[0][0])
  108. iy = int(centerLine[0][1])
  109. # calculate the width of band mask
  110. pixelRange = int(min(H,W) / scale) # scale = 10
  111. # get the slope of line
  112. k = calculSlope(centerLine[0],centerLine[1])
  113. if abs(k) > 1:
  114. while ix > 0:
  115. iy = int(round(((ix-centerLine[0][1]) / k) + centerLine[0][0]))
  116. frontY = max(0, iy-pixelRange)
  117. backY = min(W,iy+pixelRange+1)
  118. sMask[ix, frontY:backY] = 1
  119. ix = ix - 1
  120. ix = int(centerLine[0][0])
  121. while ix < H:
  122. iy = int(round(((ix-centerLine[0][1]) / k) + centerLine[0][0]))
  123. frontY = max(0, iy-pixelRange)
  124. backY = min(W,iy+pixelRange+1)
  125. sMask[ix, frontY:backY] = 1
  126. ix = ix + 1
  127. else:
  128. while iy > 0:
  129. ix = int(round(((iy-centerLine[0][0]) * k) + centerLine[0][1]))
  130. frontX = max(0, ix-pixelRange)
  131. backX = min(H,ix+pixelRange+1)
  132. sMask[frontX:backX, iy] = 1
  133. iy = iy - 1
  134. iy = int(centerLine[0][1])
  135. while iy < W:
  136. ix = int(round(((iy-centerLine[0][0]) * k) + centerLine[0][1]))
  137. frontX = max(0, ix-pixelRange)
  138. backX = min(H,ix+pixelRange+1)
  139. sMask[frontX:backX, iy] = 1
  140. iy = iy + 1
  141. return sMask
  142. # fonction for display all the lines
  143. def drawGroupLine(file, lineList, flineListCluster, scale, functionName, colorSTR, outputPath):
  144. c = ImageColor.colormap
  145. cList = list(c.items())
  146. (inputPath,inputFile) = os.path.split(file)
  147. print(inputPath)
  148. print(inputFile)
  149. # read the orignal file for draw
  150. with Image.open(file) as img4draw:
  151. w, h = img4draw.size
  152. if w >h: # add lineWidth to adapt the visibility of drawing results to different image sizes
  153. lineWidth = int(h/40)
  154. else:
  155. lineWidth = int(w/40)
  156. scale = 1/64
  157. wScale = np.ceil(w*scale)
  158. hScale = np.ceil(h*scale)
  159. img1 = ImageDraw.Draw(img4draw)
  160. # draw the cluster result
  161. # for n,lineSet in enumerate(flineListCluster):
  162. # for [v1,v2],w,_ in lineSet[1:]:
  163. # img1.line([(v1[0]*wScale,v1[1]*hScale), (v2[0]*wScale,v2[1]*hScale)], fill = cList[int(n*2)+2][1], width = 4)
  164. # draw all the centers
  165. for [v1,v2] in lineList:
  166. img1.line([(v1[0],v1[1]), (v2[0],v2[1])], fill = colorSTR, width = lineWidth)
  167. img4draw.save(os.path.join(outputPath, inputFile[:-4] + '_' + str(functionName) + inputFile[-4:] ))
  168. # sort the slope of lines, inutile for version 0
  169. def sortSlope(lineListArray):
  170. print('lineListArray', lineListArray)
  171. slopeList = []
  172. groupWeight = 0
  173. for l in lineListArray:
  174. if (l[0][1][0] - l[0][0][0] ) == 0:
  175. k = 1000
  176. else:
  177. k = (l[0][1][1] - l[0][0][1]) / (l[0][1][0] - l[0][0][0])
  178. slopeList.append(k)
  179. groupWeight = groupWeight + l[1]
  180. # print('weight = ', l[1])
  181. print('slopeList : ', slopeList)
  182. index = np.argsort(np.array(slopeList))
  183. print('sortSlope index : ', index)
  184. print('sortSlope index median : ', int(np.median(index)))
  185. #groupWeight = np.mean(groupWeight)
  186. return [lineListArray[int(np.median(index))][0], lineListArray[0][1], lineListArray[int(np.median(index))][2]]
  187. # return [lineListArray[int(len(index)/2)][0], lineListArray[0][1], lineListArray[int(len(index)/2)][2]]
  188. # index[len(index)//2]
  189. # extraction the center of each group
  190. def forceLinesClusterIntegration(cluster):
  191. forceL = []
  192. for i,lineSet in enumerate(cluster):
  193. forceL.append(lineSet[1])
  194. return forceL
  195. # refine the cluster result
  196. def refine(lineList, fg, wg, iP, ws):
  197. wlist = []
  198. forceList = []
  199. for l in lineList:
  200. wlist.append(l[1])
  201. npwList = np.array(wlist)
  202. sortWeight = npwList.argsort()[::-1]
  203. for n,wId in enumerate(sortWeight):
  204. if n == 0:
  205. gMask = clusterRegion(lineList[wId][0], fg, ws)
  206. forceList.append([gMask, lineList[wId]])
  207. else:
  208. judge, forceList = judgeVertexAdvanced(lineList[wId][2], lineList[wId][0], npwList[wId], forceList, wg, iP )
  209. if judge == False:
  210. gMask = clusterRegion(lineList[wId][0], fg, ws)
  211. forceList.append([gMask, lineList[wId]])
  212. flList = forceLinesClusterIntegration(forceList)
  213. return flList
  214. # the main process of clustering the lines
  215. def findSaliantLineCluster(gradient4d,allLines,allLinesIndex,ws, orgW, orgH ):
  216. weightList = []
  217. fineGrained0 = 8 # initial refine grained = ws / 8
  218. intePrec0 = 0.8 # initial intersection precision
  219. forceLinesCluster = []
  220. # compute weights of lines
  221. for l in allLines:
  222. w = np.sum(gradient4d*l)
  223. weightList.append(w)
  224. npWeightList = np.array(weightList)
  225. sortWeightList = npWeightList.argsort()[::-1] # [::-1] inverse a list, range from large to small
  226. # top 300 weighted candidates, about 0.14% of the total lines
  227. for n,wId in enumerate(sortWeightList[:300]):
  228. if n == 0:
  229. groupMask = clusterRegion(allLinesIndex[wId], fineGrained0, ws)
  230. forceLinesCluster.append([groupMask, [allLinesIndex[wId], npWeightList[wId], allLines[wId]]])
  231. #
  232. else:
  233. # print(npWeightList[sortWeightList[n-1]])
  234. # print(npWeightList[wId])
  235. if (npWeightList[sortWeightList[n-1]] - npWeightList[wId]) > 10 :
  236. print('weight break------in line ', str(n))
  237. break
  238. judge, forceLinesCluster = judgeVertexAdvanced(allLines[wId], allLinesIndex[wId], npWeightList[wId], forceLinesCluster , 2, intePrec0)
  239. if judge == False:
  240. groupMask = clusterRegion(allLinesIndex[wId], fineGrained0, ws)
  241. forceLinesCluster.append([groupMask, [allLinesIndex[wId], npWeightList[wId], allLines[wId]]])
  242. forceLinesRough = forceLinesClusterIntegration(forceLinesCluster)
  243. forceLinesRoughNew = forceLinesRough
  244. forceLinesRoughOrg = []
  245. fineGrained = 7
  246. wGrained = 3
  247. intePrec = 0.7
  248. for i in range(10000):
  249. if len(forceLinesRoughNew) == len(forceLinesRoughOrg):
  250. if (fineGrained <= 4 )and (wGrained >= 10) :
  251. print('break in loop ', str(i))
  252. break
  253. forceLinesRoughOrg = forceLinesRoughNew
  254. forceLinesRoughNew = refine(forceLinesRoughNew, fineGrained, wGrained, intePrec, ws)
  255. if fineGrained > 4:
  256. fineGrained = fineGrained-1
  257. if intePrec > 0.6:
  258. intePrec = intePrec - 0.05
  259. if wGrained < 10:
  260. wGrained = wGrained + 1
  261. forceLines = []
  262. for l in forceLinesRoughNew:
  263. forceLines.append(l[0])
  264. forceLines = np.array(forceLines)
  265. scale = 1/ws
  266. HWscale = np.array([[np.ceil(orgW*scale),np.ceil(orgH*scale)],
  267. [np.ceil(orgW*scale),np.ceil(orgH*scale)]])
  268. HWS = np.expand_dims(HWscale,0).repeat(forceLines.shape[0],axis=0)
  269. forceLines = forceLines*HWS
  270. return forceLines, forceLinesCluster,HWS
  271. def judgeVertexAdvanced(line1,v1, v1w, forceL, wSeuil = 4, intersectPrecent = 0.7):
  272. v1 = np.array(v1)
  273. newGroup = False
  274. for cl in forceL:
  275. vPossible = cl[0]*line1
  276. if np.sum(vPossible) > (np.sum(line1)*intersectPrecent):
  277. if abs(cl[1][1] - v1w) < wSeuil:
  278. cl.append([v1,v1w,line1])
  279. return True,forceL
  280. else:
  281. return True,forceL
  282. return False, forceL
  283. def getLeadingLine(imgpath, outPath):
  284. windowSize = 64
  285. allLines, allLinesIndex = getBaseLines(windowSize)
  286. img = imread(imgpath)
  287. print(img.shape)
  288. if (len(img.shape) != 3) or (img.shape[2] != 3):
  289. print('NOT a 3 channel image')
  290. else:
  291. orgH, orgW, _ = img.shape
  292. resizeImg = resize(img,(windowSize,windowSize))
  293. # Add contrast
  294. logImg = exposure.adjust_log(resizeImg, 1)
  295. grayImg = rgb2gray(logImg)
  296. gradient4d= gradient4D(grayImg)
  297. forceLines, forceLinesCluster, scale = findSaliantLineCluster(gradient4d,allLines,allLinesIndex,windowSize, orgW, orgH )
  298. drawGroupLine(imgpath, forceLines, forceLinesCluster, scale, 'forceLines', 'red', outPath)
  299. if __name__ == '__main__':
  300. parser = argparse.ArgumentParser(description='Find the Probable leading lines, please provide 1) your input image path or a folder path for input images, and 2) the output folder you wish.')
  301. parser.add_argument('input', type=str, help='The path for your input image or folder')
  302. parser.add_argument('-o', '--output', type=str, default='./OUTPUT', help='The path for your output folder ')
  303. args = parser.parse_args()
  304. INPUT_DIRECTORY = args.input
  305. OUTPUT_DIRECTORY = args.output
  306. print('INPUT : ', INPUT_DIRECTORY)
  307. print('OUTPUT : ', OUTPUT_DIRECTORY)
  308. if not (os.path.exists(OUTPUT_DIRECTORY)):
  309. print('Create output path:' , OUTPUT_DIRECTORY)
  310. os.makedirs(OUTPUT_DIRECTORY)
  311. start = time.time()
  312. if os.path.isfile( INPUT_DIRECTORY ):
  313. if INPUT_DIRECTORY.lower().endswith(('.jpg', '.png')) and not INPUT_DIRECTORY.lower().startswith('.'):
  314. getLeadingLine(INPUT_DIRECTORY,OUTPUT_DIRECTORY)
  315. elif os.path.isdir( INPUT_DIRECTORY ):
  316. files= os.listdir(INPUT_DIRECTORY)
  317. for i, file in enumerate(files):
  318. if file.lower().endswith(('.jpg', '.png')) and not file.lower().startswith('.'):
  319. fullpath = os.path.join(INPUT_DIRECTORY, file)
  320. getLeadingLine(fullpath,OUTPUT_DIRECTORY)
  321. end = time.time()
  322. print(' use time = ', str((end - start)/60.0), 'm')