zhJ il y a 1 an
commit
0fb56be3b8
1 fichiers modifiés avec 342 ajouts et 0 suppressions
  1. 342 0
      LigneForce.py

+ 342 - 0
LigneForce.py

@@ -0,0 +1,342 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+import numpy as np
+import os
+import sys
+import copy
+import time
+
+from skimage.io import imread, imsave
+from skimage.color import rgb2gray
+from skimage import exposure
+from skimage.transform import resize
+
+from PIL import Image, ImageDraw, ImageColor
+
+
+def gradient4D(img):
+    (row, col) = img.shape
+    g4d = np.zeros((row, col))
+    
+    for i in range(row-1):
+        for j in range(col-1):
+            g4d[i, j] = abs(img[i+1, j] - img[i, j] ) + abs(img[i, j+1] - img[i, j]) + abs(img[i+1, j-1] - img[i, j]) + abs(img[i+1, j+1] - img[i, j])
+            
+    return npNormalise(g4d)
+
+
+def npNormalise(xArray):
+    XNorm = (xArray - xArray.min()) / (xArray.max() - xArray.min())
+    return XNorm
+
+def getBaseLines(ws):
+    baseLineList = []
+    baseLineListIdex = []
+    for i in range(0,ws-5):
+        #if i == 1: break
+        for j in range(5,ws):  #  cut bord
+            #if j == 6: break
+    #               1
+    #         -------------
+    #         |           |
+    #       4 |           |2
+    #         |           |
+    #         |           |
+    #         --------------
+    #               3
+            # adjacent edge
+            img12 = Image.new('F', (ws,ws),0)
+            draw12 = ImageDraw.Draw(img12)
+            draw12.line(xy=(i, 0, ws-1, j),
+                      fill=(1), width = 1)
+            baseLineList.append(np.asarray(img12)) 
+            baseLineListIdex.append(np.asarray([[i, 0],[ws-1, j]]))
+            baseLineList.append(np.rot90(np.asarray(img12), 1, axes=(0, 1)))
+            baseLineListIdex.append(np.asarray([[j, 0],[0, ws-1-i]]))
+            baseLineList.append(np.rot90(np.asarray(img12), 2, axes=(0, 1)))
+            baseLineListIdex.append(np.asarray([[0, ws-1-j],[ws-1-i, ws-1]]))
+            baseLineList.append(np.rot90(np.asarray(img12), 3, axes=(0, 1)))
+            baseLineListIdex.append(np.asarray([[ws-1, i],[ws-1-j, ws-1]]))
+
+            # opposite side
+            img13 = Image.new('F', (ws,ws),0)
+            draw13 = ImageDraw.Draw(img13)
+            draw13.line(xy=(i, 0, j, ws-1),
+                      fill=(1), width = 1)
+            baseLineList.append(np.asarray(img13)) 
+            baseLineListIdex.append(np.asarray([[i,0],[j, ws-1]]))
+            baseLineList.append(np.asarray(img13).T) 
+            baseLineListIdex.append(np.asarray([[0,i],[ws-1, j]]))
+    print('base line number :', len(baseLineList))
+    return np.asarray(baseLineList), np.asarray(baseLineListIdex)
+
+
+def calculSlope(v1,v2):
+    difX = v2[0] - v1[0]
+    difY = v2[1] - v1[1]
+    if difX == 0 :
+        lk = 5*difY
+    else:
+        lk = difY / difX
+    return lk
+
+def clusterRegion(centerLine, scale = 4, windowSize=64):
+    H = windowSize
+    W = windowSize
+    sMask = np.zeros([H,W])
+    ix = int(centerLine[0][0])
+    iy = int(centerLine[0][1])
+    pixelRange = int(min(H,W) / scale)  #  scale = 10 
+    
+    k = calculSlope(centerLine[0],centerLine[1])
+
+    if abs(k) > 1:  
+        while ix > 0:
+            iy = int(round(((ix-centerLine[0][1]) / k) + centerLine[0][0]))
+            frontY = max(0, iy-pixelRange)
+            backY = min(W,iy+pixelRange+1)
+            sMask[ix, frontY:backY] = 1
+            ix = ix - 1
+        ix = int(centerLine[0][0])
+
+        while ix < H:
+            iy = int(round(((ix-centerLine[0][1]) / k) + centerLine[0][0]))
+            frontY = max(0, iy-pixelRange)
+            backY = min(W,iy+pixelRange+1) 
+            sMask[ix, frontY:backY] = 1
+            ix = ix + 1
+    else:
+        while iy > 0:
+            ix = int(round(((iy-centerLine[0][0]) * k) + centerLine[0][1]))
+            frontX = max(0, ix-pixelRange)
+            backX = min(H,ix+pixelRange+1)
+            sMask[frontX:backX, iy] = 1
+            iy = iy - 1
+        iy = int(centerLine[0][1])
+
+        while iy < W:
+            ix = int(round(((iy-centerLine[0][0]) * k) + centerLine[0][1]))
+            frontX = max(0, ix-pixelRange)
+            backX = min(H,ix+pixelRange+1)
+            sMask[frontX:backX, iy] = 1
+            iy = iy + 1
+    return sMask
+
+
+
+def drawGroupLine(file, lineList, flineListCluster, scale, functionName, colorSTR, outputPath):
+    c = ImageColor.colormap
+    cList = list(c.items())
+    (inputPath,inputFile) = os.path.split(file)
+    print(inputPath)
+    print(inputFile)
+    with Image.open(file) as img4draw:
+        w, h = img4draw.size
+        scale = 1/64
+        wScale = np.ceil(w*scale)
+        hScale = np.ceil(h*scale)
+        
+        img1 = ImageDraw.Draw(img4draw) 
+            
+#         for n,lineSet in enumerate(flineListCluster):
+#             for [v1,v2],w,_ in lineSet[1:]:
+#                 img1.line([(v1[0]*wScale,v1[1]*hScale), (v2[0]*wScale,v2[1]*hScale)], fill = cList[int(n*2)+2][1], width = 4)
+        
+        for [v1,v2] in lineList:
+            img1.line([(v1[0],v1[1]), (v2[0],v2[1])], fill = colorSTR, width = 8)
+            
+        img4draw.save(os.path.join(outputPath, inputFile[:-4] + '_' + str(functionName) + inputFile[-4:] ))
+
+
+def sortSlope(lineListArray):
+    slopeList = []
+    groupWeight = 0
+    for l in lineListArray:
+        if (l[0][1][0] - l[0][0][0] ) == 0:
+            k = 1000
+        else:
+            k = (l[0][1][1] - l[0][0][1]) / (l[0][1][0] - l[0][0][0])
+        slopeList.append(k)
+        groupWeight = groupWeight + l[1]
+#         print('weight = ', l[1])
+    index = np.argsort(np.array(slopeList))
+    #groupWeight = np.mean(groupWeight)
+    
+    
+    return [lineListArray[int(np.median(index))][0], lineListArray[0][1], lineListArray[int(np.median(index))][2]]
+#     index[len(index)//2]
+
+def forceLinesClusterIntegration(cluster):
+    forceL = []
+    
+    for i,lineSet in enumerate(cluster):
+
+        forceL.append(sortSlope(lineSet[1:]))
+    return forceL
+
+
+
+def refine(lineList, fg, wg, iP, ws):
+    wlist = []
+    forceList = []
+    for l in lineList:
+        wlist.append(l[1])
+    npwList = np.array(wlist)
+    sortWeight = npwList.argsort()[::-1] 
+    
+    for n,wId in enumerate(sortWeight):
+        if n == 0:
+            gMask = clusterRegion(lineList[wId][0], fg, ws)
+            forceList.append([gMask, lineList[wId]])
+        else:
+            judge, forceList = judgeVertexAdvanced(lineList[wId][2], lineList[wId][0], npwList[wId], forceList, wg, iP )
+            if  judge == False:
+                gMask = clusterRegion(lineList[wId][0], fg, ws)
+                forceList.append([gMask, lineList[wId]])  
+    flList = forceLinesClusterIntegration(forceList)
+    return flList     
+                
+
+def findSaliantLineCluster(gradient4d,allLines,allLinesIndex,ws, orgW, orgH ):
+    weightList = []
+    fineGrained0 = 8
+    intePrec0 = 0.8
+    forceLinesCluster = []
+    for l in allLines:
+        w = np.sum(gradient4d*l)
+        weightList.append(w)
+
+    npWeightList = np.array(weightList)
+    sortWeightList = npWeightList.argsort()[::-1]   # [::-1] inverse a list
+    for n,wId in enumerate(sortWeightList[:300]):
+        if n == 0:
+            groupMask = clusterRegion(allLinesIndex[wId], fineGrained0, ws)
+            
+            forceLinesCluster.append([groupMask, [allLinesIndex[wId], npWeightList[wId], allLines[wId]]])
+#
+        else:
+#             print(npWeightList[sortWeightList[n-1]])
+#             print(npWeightList[wId])
+            if (npWeightList[sortWeightList[n-1]] - npWeightList[wId]) > 10 :
+                print('weight break------in line ', str(n))
+                break
+            judge, forceLinesCluster = judgeVertexAdvanced(allLines[wId], allLinesIndex[wId], npWeightList[wId], forceLinesCluster , 2, intePrec0)
+            if  judge == False:
+                groupMask = clusterRegion(allLinesIndex[wId], fineGrained0, ws)
+                
+                forceLinesCluster.append([groupMask, [allLinesIndex[wId], npWeightList[wId], allLines[wId]]])
+
+    forceLinesRough = forceLinesClusterIntegration(forceLinesCluster)
+    
+    forceLinesRoughNew = forceLinesRough
+    forceLinesRoughOrg = []
+    fineGrained = 7
+    wGrained = 3
+    intePrec = 0.7
+    for i in range(10000):
+        
+        if len(forceLinesRoughNew) == len(forceLinesRoughOrg):
+            if (fineGrained <= 4 )and (wGrained >= 10) :
+                print('break in loop ', str(i))
+                break
+        
+        forceLinesRoughOrg = forceLinesRoughNew
+        forceLinesRoughNew = refine(forceLinesRoughNew, fineGrained, wGrained, intePrec, ws)
+        if fineGrained > 4:
+            fineGrained = fineGrained-1
+        if intePrec > 0.6:
+            intePrec = intePrec - 0.05
+        if wGrained < 10:
+            wGrained = wGrained + 1
+      
+            
+    
+    forceLines = []
+    
+    for l in forceLinesRoughNew:
+        forceLines.append(l[0])
+     
+    forceLines = np.array(forceLines)
+    scale = 1/ws
+    HWscale = np.array([[np.ceil(orgW*scale),np.ceil(orgH*scale)],
+                        [np.ceil(orgW*scale),np.ceil(orgH*scale)]])
+    HWS = np.expand_dims(HWscale,0).repeat(forceLines.shape[0],axis=0)
+    forceLines = forceLines*HWS
+    return forceLines, forceLinesCluster,HWS
+
+
+
+def judgeVertexAdvanced(line1,v1, v1w, forceL, wSeuil = 4, intersectPrecent = 0.7):
+    v1 = np.array(v1)
+    newGroup = False
+    
+    for cl in forceL:
+        vPossible = cl[0]*line1
+        if np.sum(vPossible) > (np.sum(line1)*intersectPrecent):
+            if abs(cl[1][1] - v1w) < wSeuil: 
+                cl.append([v1,v1w,line1])
+                return True,forceL
+            else:
+                return True,forceL
+
+    return False, forceL
+    
+ 
+def getLeadingLine(imgpath, outPath):
+    windowSize = 64
+    allLines, allLinesIndex = getBaseLines(windowSize)
+
+    img = imread(imgpath)
+    print(img.shape)
+    if (len(img.shape) != 3) or (img.shape[2] != 3):
+        print('NOT a 3 channel image')
+    else:
+        orgH, orgW, _ = img.shape
+            
+        resizeImg = resize(img,(windowSize,windowSize))
+
+         # Add contrast
+        logImg = exposure.adjust_log(resizeImg, 1)
+
+        
+        grayImg = rgb2gray(logImg)
+            
+        gradient4d= gradient4D(grayImg)
+        
+        forceLines, forceLinesCluster, scale = findSaliantLineCluster(gradient4d,allLines,allLinesIndex,windowSize, orgW, orgH )
+        drawGroupLine(imgpath, forceLines, forceLinesCluster, scale, 'forceLines', 'red', outPath)
+
+
+
+
+if __name__ == '__main__':
+    print(sys.argv[1])
+    print(sys.argv[2])
+    
+    INPUT_DIRECTORY =  sys.argv[1]
+    OUTPUT_DIRECTORY = sys.argv[2]
+
+    if not (os.path.exists(OUTPUT_DIRECTORY)):
+        print('Create output path:' , OUTPUT_DIRECTORY)
+        os.makedirs(OUTPUT_DIRECTORY)
+
+    start = time.time()
+    if os.path.isfile( INPUT_DIRECTORY ):
+        if INPUT_DIRECTORY.lower().endswith(('.jpg', '.png')) and not INPUT_DIRECTORY.lower().startswith('.'):
+            getLeadingLine(INPUT_DIRECTORY,OUTPUT_DIRECTORY)
+    elif os.path.isdir( INPUT_DIRECTORY ):
+        files= os.listdir(INPUT_DIRECTORY)
+        for i, file in enumerate(files):
+            if file.lower().endswith(('.jpg', '.png')) and not file.lower().startswith('.'):
+                fullpath = os.path.join(INPUT_DIRECTORY, file)
+                getLeadingLine(fullpath,OUTPUT_DIRECTORY)
+
+    end = time.time()
+    print('  use time = ', str((end - start)/60.0), 'm')
+
+
+            
+            
+    
+