high8 registers

This commit is contained in:
Andreas Abel
2021-01-04 20:51:23 +01:00
parent ddbcc18731
commit 9ce44e1079
3 changed files with 115 additions and 58 deletions

View File

@@ -111,7 +111,9 @@ def getRegMemInit(instrNode, opRegDict, memOffset, useIndexedAddr):
reg = opRegDict[opIdx]
regPrefix = re.sub('\d', '', reg)
if 'MM' in regPrefix and xtype.startswith('f'):
if reg in High8Regs:
init += ['MOV {}, 0'.format(reg)]
elif 'MM' in regPrefix and xtype.startswith('f'):
init += ['MOV RAX, 0x4000000040000000']
for i in range(0, getRegSize(reg)/8, 8): init += ['MOV [R14+' + str(i) + '], RAX']
@@ -155,6 +157,7 @@ def runExperiment(instrNode, instrCode, init=None, unrollCount=500, loopCount=0,
init = list(OrderedDict.fromkeys(init)) # remove duplicates while maintaining the order
initCode = '; '.join(init)
useLateInit = any((reg in initCode.upper()) for reg in High8Regs)
if instrNode is not None and (instrNode.attrib.get('vex', '') == '1' or instrNode.attrib.get('evex', '') == '1'):
# vex and evex encoded instructions need a warm-up period before memory reads operate at full speed;
@@ -174,19 +177,25 @@ def runExperiment(instrNode, instrCode, init=None, unrollCount=500, loopCount=0,
nanoBenchCmd += ' -asm "' + instrCode + '"'
initObjFile = None
lateInitObjFile=None
if initCode:
if debugOutput: print 'init: ' + initCode
initObjFile = '/tmp/ramdisk/init.o'
assemble(initCode, initObjFile, asmFile='/tmp/ramdisk/init.s')
objFile = '/tmp/ramdisk/init.o'
if useLateInit:
lateInitObjFile = objFile
nanoBenchCmd += ' -asm_late_init "' + initCode + '"'
else:
initObjFile = objFile
nanoBenchCmd += ' -asm_init "' + initCode + '"'
assemble(initCode, objFile, asmFile='/tmp/ramdisk/init.s')
localHtmlReports.append('<li>Init: <pre>' + re.sub(';[ \t]*(.)', r';\n\1', initCode) + '</pre></li>\n')
nanoBenchCmd += ' -asm_init &quot;' + initCode + '&quot;'
localHtmlReports.append('<li><a href="javascript:;" onclick="this.outerHTML = \'<pre>' + nanoBenchCmd + '</pre>\'">Show nanoBench command</a></li>\n')
if debugOutput: print nanoBenchCmd
setNanoBenchParameters(unrollCount=unrollCount, loopCount=loopCount, warmUpCount=warmUpCount, basicMode=basicMode)
ret = runNanoBench(codeObjFile=codeObjFile, initObjFile=initObjFile)
ret = runNanoBench(codeObjFile=codeObjFile, initObjFile=initObjFile, lateInitObjFile=lateInitObjFile)
localHtmlReports.append('<li>Results:\n<ul>\n')
for evt, value in ret.items():
@@ -398,7 +407,7 @@ def getInstrInstanceFromNode(instrNode, doNotWriteRegs=None, doNotReadRegs=None,
ignoreRegs |= set(doNotWriteRegs)|globalDoNotWriteRegs|set(opRegDict.values())
if operandNode.attrib.get('r', '0') == '1':
ignoreRegs |= set(doNotReadRegs)|writtenRegs|readRegs|set(opRegDict.values())
regsList = filter(lambda x: not any(y in ignoreRegs for y in getSubRegs(x)) and not (x in [z for y in ignoreRegs for z in getSubRegs(y)]), regsList)
regsList = filter(lambda x: not any(getCanonicalReg(x) == getCanonicalReg(y) for y in ignoreRegs), regsList)
if not regsList:
return None;
reg = sortRegs(regsList)[0]
@@ -639,11 +648,15 @@ def hasCommonRegister(instrNode):
def findCommonRegisters(instrNode):
for opNode1 in instrNode.findall('./operand[@type="reg"]'):
regs1 = set(map(getCanonicalReg, opNode1.text.split(",")))
regs1 = set(opNode1.text.split(","))
regs1Canonical = set(map(getCanonicalReg, regs1))
for opNode2 in instrNode.findall('./operand[@type="reg"]'):
if opNode1 == opNode2: continue
regs2 = set(map(getCanonicalReg, opNode2.text.split(",")))
intersection = regs1.intersection(regs2)
regs2 = set(opNode2.text.split(","))
regs2Canonical = set(map(getCanonicalReg, regs2))
if (regs1.intersection(High8Regs) and regs2.intersection(Low8Regs)) or (regs2.intersection(High8Regs) and regs1.intersection(Low8Regs)):
continue
intersection = regs1Canonical.intersection(regs2Canonical)
if intersection:
return intersection
return set()
@@ -911,12 +924,11 @@ def getTPConfigsForDiv(instrNode):
if 'YMM' in instrNode.attrib['iform']: regType = 'YMM'
if 'ZMM' in instrNode.attrib['iform']: regType = 'ZMM'
targetRegIdx = min(int(opNode.attrib['idx']) for opNode in instrNode.findall('./operand') if opNode.text and regType in opNode.text)
config.init = ['MOV RAX, ' + arg]
for i in range(0, getRegSize(regType)/8, 8): config.init += ['MOV [R14+' + str(i) + '], RAX']
targetRegIdx = min(int(opNode.attrib['idx']) for opNode in instrNode.findall('./operand') if opNode.text and regType in opNode.text)
if memDivisor:
for i in range(0, 64, 8): config.init += ['MOV [R14+' + str(i) + '], RAX']
instrs = [getInstrInstanceFromNode(instrNode, opRegDict={targetRegIdx:regType+str(reg)}) for reg in range(2, 10)]
else:
sourceReg = regType + '0'
@@ -1175,11 +1187,21 @@ basicLatency = {}
def getBasicLatencies(instrNodeList):
movsxResult = runExperiment(instrNodeDict['MOVSXD (R64, R32)'], 'MOVSX RAX, EAX')
movsxCycles = int(round(movsxResult['Core cycles']))
if not movsxCycles == 1:
if movsxCycles != 1:
print 'Latency of MOVSX must be 1'
sys.exit()
basicLatency['MOVSX'] = movsxCycles
movsxR8hResult = runExperiment(None, 'MOVSX EAX, AH; MOV AH, AL')
movsxR8hCycles = int(round(movsxR8hResult['Core cycles']))
if movsxR8hCycles != 2:
print 'Latency of "MOVSX EAX, AH; MOV AH, AL" must be 2'
sys.exit()
basicLatency['MOV_R8h_R8l'] = 1
movR8hR8hResult = runExperiment(instrNodeDict['MOV_88 (R8h, R8h)'], 'MOV AH, AH')
basicLatency['MOV_R8h_R8h'] = int(round(movR8hR8hResult['Core cycles']))
andResult = runExperiment(instrNodeDict['AND_21 (R64, R64)'], 'AND RAX, RBX')
basicLatency['AND'] = int(round(andResult['Core cycles']))
@@ -1353,7 +1375,7 @@ def getDivLatConfigLists(instrNode, opNode1, opNode2, cRep):
if memDivisor:
instrI = getInstrInstanceFromNode(instrNode)
else:
divisorReg = regToSize('RBX', width)
divisorReg = 'BH' if ('BH' in divisorNode.text) else regToSize('RBX', width)
instrI = getInstrInstanceFromNode(instrNode, opRegDict={int(divisorNode.attrib['idx']):divisorReg})
if width == 8:
@@ -1389,15 +1411,15 @@ def getDivLatConfigLists(instrNode, opNode1, opNode2, cRep):
else:
config.notes.append('fast division')
immReg = {'RAX': 'R8', 'RDX': 'R9', 'divisor': 'R10'}
config.init = ['MOV ' + immReg['RAX'] + ', ' + RAX,
'MOV ' + immReg['RDX'] + ', ' + RDX,
'MOV ' + immReg['divisor'] + ', ' + divisor]
immReg = {'RAX': 'R8', 'RDX': 'R9', 'divisor': 'RCX'}
config.init = ['MOV {}, {}'.format(immReg['RAX'], RAX),
'MOV {}, {}'.format(immReg['RDX'], RDX),
'MOV {}, {}'.format(immReg['divisor'], divisor)]
if memDivisor:
config.init += ['MOV [R14], ' + immReg['divisor']]
else:
config.init += ['MOV RBX, ' + immReg['divisor']]
config.init += ['MOV {}, {}'.format(divisorReg, regToSize(immReg['divisor'], width))]
config.init += ['MOV RAX, ' + immReg['RAX'],
'MOV RDX, ' + immReg['RDX']]
@@ -1903,10 +1925,28 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addrMem
reg2Prefix = re.sub('\d', '', reg2)
if reg1 in GPRegs and reg2 in GPRegs:
# MOVSX avoids partial reg stalls and cannot be eliminated by "move elimination"
chainInstrs = 'MOVSX {}, {};'.format(regTo64(reg1), regToSize(reg2, min(32, getRegSize(reg2))))
chainInstrs += 'MOVSX {}, {};'.format(regTo64(reg1), regTo32(reg1)) * cRep
configList.append(LatConfig(instrI, chainInstrs=chainInstrs, chainLatency=basicLatency['MOVSX']*(cRep+1)))
if reg1 in High8Regs:
if reg2 in High8Regs:
chainInstrs = 'MOV {}, {};'.format(reg1, reg2)
chainInstrs += 'MOV {}, {};'.format(reg1, reg1) * cRep
configList.append(LatConfig(instrI, chainInstrs=chainInstrs, chainLatency=basicLatency['MOV_R8h_R8h']*(cRep+1)))
elif reg2 in Low8Regs:
chainInstrs = 'MOV {}, {};'.format(reg1, reg2)
chainInstrs += 'MOV {}, {};'.format(reg1, reg1) * cRep
configList.append(LatConfig(instrI, chainInstrs=chainInstrs, chainLatency=basicLatency['MOV_R8h_R8l'] + basicLatency['MOV_R8h_R8h']*cRep))
else:
chainInstrs = 'MOVSX {}, {};'.format(regTo64(reg1), regToSize(reg2, min(32, getRegSize(reg2))))
chainInstrs += 'MOV {}, {};'.format(reg1, regTo8(reg1))
chainInstrs += 'MOV {}, {};'.format(reg1, reg1) * cRep
configList.append(LatConfig(instrI, chainInstrs=chainInstrs, chainLatency=basicLatency['MOVSX'] + basicLatency['MOV_R8h_R8l']
+ basicLatency['MOV_R8h_R8h']*cRep))
else:
# MOVSX avoids partial reg stalls and cannot be eliminated by "move elimination"
reg1s = regTo32(reg1) if (reg2 in High8Regs) else regTo64(reg1)
reg2s = reg2 if (reg2 in High8Regs) else regToSize(reg2, min(32, getRegSize(reg2)))
chainInstrs = 'MOVSX {}, {};'.format(reg1s, reg2s)
chainInstrs += 'MOVSX {}, {};'.format(regTo64(reg1), regTo32(reg1)) * cRep
configList.append(LatConfig(instrI, chainInstrs=chainInstrs, chainLatency=basicLatency['MOVSX']*(cRep+1)))
elif reg1Prefix == 'K' and reg2Prefix == 'K':
chainInstr = 'KMOVQ {}, {};'.format(reg1, reg2)
chainInstr += 'KMOVQ {0}, {0};'.format(reg1) * cRep
@@ -1965,9 +2005,14 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addrMem
chainLatency = basicLatency['CMOV' + flag[0]]
instrI = getInstrInstanceFromNode(instrNode, ['R15'], ['R15'], useDistinctRegs, {startNodeIdx:reg})
movsxInstr = 'MOVSX {}, {};'.format(regTo64(reg), regToSize(reg, min(32, regSize)))
chainInstrs = chainInstr + movsxInstr * cRep
chainLatency = chainLatency + basicLatency['MOVSX'] * cRep
if reg in High8Regs:
movInstr = 'MOV {}, {};'.format(reg, reg)
chainInstrs = chainInstr + movInstr * cRep
chainLatency = chainLatency + basicLatency['MOV_R8h_R8h'] * cRep
else:
movsxInstr = 'MOVSX {}, {};'.format(regTo64(reg), regToSize(reg, min(32, regSize)))
chainInstrs = chainInstr + movsxInstr * cRep
chainLatency = chainLatency + basicLatency['MOVSX'] * cRep
configList.append(LatConfig(instrI, chainInstrs=chainInstrs, chainLatency=chainLatency))
elif 'MM' in reg:
@@ -1990,9 +2035,14 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addrMem
configList.isUpperBound = True
chainInstrs = 'MOV {}, [{}];'.format(reg, addrReg)
chainInstrs += 'MOVSX {}, {};'.format(regTo64(reg), regToSize(reg, min(32, getRegSize(reg)))) * cRep
chainLatency = int(basicLatency['MOV_10MOVSX_MOV_'+str(getRegSize(reg))] >= 12) # 0 if CPU supports zero-latency store forwarding
chainLatency += basicLatency['MOVSX'] * cRep
if reg in High8Regs:
chainInstrs += 'MOV {}, {};'.format(reg, reg) * cRep
chainLatency = basicLatency['MOV_R8h_R8h'] * cRep
else:
chainInstrs += 'MOVSX {}, {};'.format(regTo64(reg), regToSize(reg, min(32, getRegSize(reg)))) * cRep
chainLatency = basicLatency['MOVSX'] * cRep
chainLatency += int(basicLatency['MOV_10MOVSX_MOV_'+str(getRegSize(reg))] >= 12) # 0 if CPU supports zero-latency store forwarding
if re.search('BT.*MEMv_GPRv', instrNode.attrib['iform']):
chainInstrs += 'AND ' + reg + ', 0;'
@@ -2030,8 +2080,12 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addrMem
chainLatency = basicLatency['TEST']
configList.append(LatConfig(instrI, chainInstrs=chainInstrs, chainLatency=chainLatency))
chainInstrs = 'MOVSX {}, {};'.format(regTo64(reg), regToSize(reg, min(32, getRegSize(reg)))) * cRep + chainInstrs
chainLatency += basicLatency['MOVSX'] * cRep
if reg in High8Regs:
chainInstrs = 'MOV {}, {};'.format(reg, reg) * cRep + chainInstrs
chainLatency += basicLatency['MOV_R8h_R8h'] * cRep
else:
chainInstrs = 'MOVSX {}, {};'.format(regTo64(reg), regToSize(reg, min(32, getRegSize(reg)))) * cRep + chainInstrs
chainLatency += basicLatency['MOVSX'] * cRep
configList.append(LatConfig(instrI, chainInstrs=chainInstrs, chainLatency=chainLatency))
else:
# ToDo: there is no instruction from flag to vector reg; the only non-GPR that is possible are ST(0) and X87STATUS
@@ -2093,7 +2147,6 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addrMem
chainInstrs = 'MOVSX ' + regTo64(reg) + ', ' + regToSize(reg, min(32, regSize)) + ';'
chainInstrs += 'XOR {}, {};'.format(chainReg, regTo64(reg)) * cRep + ('TEST R15, R15;' if instrReadsFlags else '') # cRep is a multiple of 2
chainLatency = basicLatency['MOVSX'] + basicLatency['XOR'] * cRep
configList.append(LatConfig(instrI, chainInstrs=chainInstrs, chainLatency=chainLatency))
else:
# mem -> reg
configList = LatConfigList()
@@ -2103,7 +2156,11 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addrMem
chainInstrs += 'mov [{}], {};'.format(addrReg, regToSize('R12', regSize))
chainLatency = basicLatency['MOVSX'] * cRep
chainLatency += int(basicLatency['MOV_10MOVSX_MOV_'+str(regSize)] >= 12) # 0 if CPU supports zero-latency store forwarding
configList.append(LatConfig(instrI, chainInstrs=chainInstrs, chainLatency=chainLatency))
if reg in High8Regs:
chainInstrs = 'MOVSX {}, {};'.format(regTo32(reg), reg) + chainInstrs
chainInstrs += 'MOV {}, {}'.format(reg, reg) # 'clean' reg again; this is not on the critical path
chainLatency += basicLatency['MOVSX']
configList.append(LatConfig(instrI, chainInstrs=chainInstrs, chainLatency=chainLatency))
elif 'MM' in reg:
if addrMem in ['addr', 'addr_index']:
# addr -> reg
@@ -2311,7 +2368,7 @@ def getLatencies(instrNode, instrNodeList, tpDict, tpDictSameReg, htmlReports):
print 'readOnlyRegOpNodeIdx not found in opRegDict'
continue
reg = latConfig.instrI.opRegDict[readOnlyRegOpNodeIdx]
if not reg in GPRegs or reg in globalDoNotWriteRegs or reg in specialRegs: continue
if (not reg in GPRegs) or (reg in High8Regs) or (reg in globalDoNotWriteRegs) or (reg in specialRegs): continue
if any((opNode is not None) for opNode in instrNode.findall('./operand[@type="reg"][@w="1"]')
if regTo64(latConfig.instrI.opRegDict[int(opNode.attrib['idx'])]) == regTo64(reg)): continue
@@ -2864,7 +2921,7 @@ def main():
latencyDict = {instrNodeDict[k.attrib['string']]:v for k,v in pickle.load(f).items()}
elif not useIACA or iacaVersion == '2.1':
for i, instrNode in enumerate(instrNodeList):
#if not 'AES' in instrNode.attrib['string']: continue
#if not 'DIV' in instrNode.attrib['string']: continue
print 'Measuring latencies for ' + instrNode.attrib['string'] + ' (' + str(i) + '/' + str(len(instrNodeList)) + ')'
htmlReports = ['<h1>' + instrNode.attrib['string'] + ' - Latency' + (' (IACA '+iacaVersion+')' if useIACA else '') + '</h1>\n<hr>\n']

View File

@@ -1,4 +1,5 @@
import re
from collections import namedtuple
GPRegs = {'AH', 'AL', 'AX', 'BH', 'BL', 'BP', 'BPL', 'BX', 'CH', 'CL', 'CX', 'DH', 'DI', 'DIL', 'DL', 'DX', 'EAX',
'EBP', 'EBX', 'ECX', 'EDI', 'EDX', 'ESI', 'ESP', 'R10', 'R10B', 'R10D', 'R10W', 'R11', 'R11B', 'R11D', 'R11W', 'R12',
@@ -6,6 +7,9 @@ GPRegs = {'AH', 'AL', 'AX', 'BH', 'BL', 'BP', 'BPL', 'BX', 'CH', 'CL', 'CX', 'DH
'R8', 'R8B', 'R8D', 'R8W', 'R9', 'R9B', 'R9D', 'R9W', 'RAX', 'RBP', 'RBX', 'RCX', 'RDI', 'RDX', 'RSI', 'RSP', 'SI',
'SIL', 'SP', 'SPL'}
High8Regs = {'AH', 'BH', 'CH', 'DH'}
Low8Regs = {'AL', 'BL', 'BPL', 'CL', 'DIL', 'DL', 'R10B', 'R11B', 'R12B', 'R13B', 'R14B', 'R15B', 'R8B', 'R9B', 'SIL', 'SPL'}
STATUSFLAGS = {'CF', 'PF', 'AF', 'ZF', 'SF', 'OF'}
STATUSFLAGS_noAF = {'CF', 'PF', 'ZF', 'SF', 'OF'}
@@ -87,28 +91,6 @@ def regToSize(reg, size):
elif size == 32: return regTo32(reg)
else: return regTo64(reg)
# Returns a set of registers that are a part of the register that is provided (e.g., EAX is a part of RAX; RAX is also a part of RAX)
def getSubRegs(reg):
subRegs = set()
subRegs.add(reg)
if reg in GPRegs:
regSize = getRegSize(reg)
if regSize > 8:
for size in [16, 32, 64]:
if size > regSize: continue
subRegs.add(regToSize(reg, size))
if 'AX' in reg or 'BX' in reg or 'CX' in reg or 'DX' in reg:
subRegs.add(reg[-2] + 'L')
subRegs.add(reg[-2] + 'H')
else:
subRegs.add(regTo8(reg))
elif 'ZMM' in reg:
subRegs.add('Y' + reg[1:])
subRegs.add('X' + reg[1:])
elif 'YMM' in reg:
subRegs.add('X' + reg[1:])
return subRegs
# Returns for a GPR the corresponding 64-bit registers, and for a (X|Y|Z)MM register the corresponding XMM register
def getCanonicalReg(reg):
if reg in GPRegs:
@@ -141,3 +123,20 @@ def getRegSize(reg):
elif reg.startswith('YMM'): return 256
elif reg.startswith('ZMM'): return 512
else: return -1
MemAddr = namedtuple('MemAddr', ['base', 'index', 'scale', 'displacement'])
def getMemAddr(memAddrAsm):
base = index = None
displacement = 0
scale = 1
for c in re.split('\+|-', re.search('\[(.*)\]', memAddrAsm).group(1)):
if '0x' in c:
displacement = int(c, 0)
if '-0x' in memAddrAsm:
displacement = -displacement
elif '*' in c:
index, scale = c.split('*')
scale = int(scale)
else:
base = c
return MemAddr(base, index, scale, displacement)