This commit is contained in:
Andreas Abel
2020-05-06 00:15:27 +02:00
parent b7315a4dfb
commit be71cd8bdf
2 changed files with 161 additions and 156 deletions

View File

@@ -19,6 +19,7 @@ import shutil
import tarfile import tarfile
from utils import * from utils import *
from x64_lib import *
sys.path.append('../..') sys.path.append('../..')
from kernelNanoBench import * from kernelNanoBench import *
@@ -44,14 +45,7 @@ specialRegs = {'ES', 'CS', 'SS', 'DS', 'FS', 'GS', 'IP', 'EIP', 'FSBASEy', 'GDTR
'TR', 'TSC', 'TSCAUX', 'X87CONTROL', 'X87POP', 'X87POP2', 'X87PUSH', 'X87STATUS', 'X87TAG', 'XCR0', 'XMM0dq', 'CR0', 'CR2', 'CR3', 'CR4', 'CR8', 'ERROR', 'TR', 'TSC', 'TSCAUX', 'X87CONTROL', 'X87POP', 'X87POP2', 'X87PUSH', 'X87STATUS', 'X87TAG', 'XCR0', 'XMM0dq', 'CR0', 'CR2', 'CR3', 'CR4', 'CR8', 'ERROR',
'BND0', 'BND1', 'BND2', 'BND3'} 'BND0', 'BND1', 'BND2', 'BND3'}
# Names of the x86-64 general-purpose registers in every width listed here:
# 8-bit (e.g. AL, AH, SIL, R8B), 16-bit (AX, R8W), 32-bit (EAX, R8D) and
# 64-bit (RAX, R8).
GPRRegs = {'AH', 'AL', 'AX', 'BH', 'BL', 'BP', 'BPL', 'BX', 'CH', 'CL', 'CX', 'DH', 'DI', 'DIL', 'DL', 'DX', 'EAX',
'EBP', 'EBX', 'ECX', 'EDI', 'EDX', 'ESI', 'ESP', 'R10', 'R10B', 'R10D', 'R10W', 'R11', 'R11B', 'R11D', 'R11W', 'R12',
'R12B', 'R12D', 'R12W', 'R13', 'R13B', 'R13D', 'R13W', 'R14', 'R14B', 'R14D', 'R14W', 'R15', 'R15B', 'R15D', 'R15W',
'R8', 'R8B', 'R8D', 'R8W', 'R9', 'R9B', 'R9D', 'R9W', 'RAX', 'RBP', 'RBX', 'RCX', 'RDI', 'RDX', 'RSI', 'RSP', 'SI',
'SIL', 'SP', 'SPL'}
# Names of the status flags: carry, parity, auxiliary carry, zero, sign, overflow.
STATUSFLAGS = {'CF', 'PF', 'AF', 'ZF', 'SF', 'OF'}
# The same flag names as STATUSFLAGS, with AF left out.
STATUSFLAGS_noAF = {'CF', 'PF', 'ZF', 'SF', 'OF'}
maxTPRep = 16 maxTPRep = 16
@@ -65,138 +59,7 @@ def isAMDCPU():
def isIntelCPU(): def isIntelCPU():
return not isAMDCPU() return not isAMDCPU()
# One entry per general-purpose register family: the substrings that identify
# the family, followed by the family's 8-, 16-, 32- and 64-bit register names.
# Entries are tried from top to bottom and the first entry with a matching
# substring wins, so the order of the entries must not be changed.
_GPR_FAMILIES = (
    (('AX', 'AH', 'AL'), ('AL', 'AX', 'EAX', 'RAX')),
    (('BX', 'BH', 'BL'), ('BL', 'BX', 'EBX', 'RBX')),
    (('CX', 'CH', 'CL'), ('CL', 'CX', 'ECX', 'RCX')),
    (('DX', 'DH', 'DL'), ('DL', 'DX', 'EDX', 'RDX')),
    (('SP',), ('SPL', 'SP', 'ESP', 'RSP')),
    (('BP',), ('BPL', 'BP', 'EBP', 'RBP')),
    (('SI',), ('SIL', 'SI', 'ESI', 'RSI')),
    (('DI',), ('DIL', 'DI', 'EDI', 'RDI')),
    (('8',), ('R8B', 'R8W', 'R8D', 'R8')),
    (('9',), ('R9B', 'R9W', 'R9D', 'R9')),
    (('10',), ('R10B', 'R10W', 'R10D', 'R10')),
    (('11',), ('R11B', 'R11W', 'R11D', 'R11')),
    (('12',), ('R12B', 'R12W', 'R12D', 'R12')),
    (('13',), ('R13B', 'R13W', 'R13D', 'R13')),
    (('14',), ('R14B', 'R14W', 'R14D', 'R14')),
    (('15',), ('R15B', 'R15W', 'R15D', 'R15')),
)


def _lookupGPR(reg, column):
    """Return entry `column` (0: 8-bit, 1: 16-bit, 2: 32-bit, 3: 64-bit) of the
    first family in _GPR_FAMILIES one of whose substrings occurs in `reg`.

    Returns None when no family matches.
    """
    for substrings, names in _GPR_FAMILIES:
        if any(s in reg for s in substrings):
            return names[column]
    return None


def regTo64(reg):
    """Return the 64-bit register of the family `reg` belongs to (e.g. 'AL' -> 'RAX').

    Returns None if `reg` matches no general-purpose register family.
    """
    return _lookupGPR(reg, 3)


def regTo32(reg):
    """Return the 32-bit register of the family `reg` belongs to (e.g. 'R8' -> 'R8D').

    Returns None if `reg` matches no general-purpose register family.
    """
    return _lookupGPR(reg, 2)


def regTo16(reg):
    """Return the 16-bit register of the family `reg` belongs to (e.g. 'RSI' -> 'SI').

    Returns None if `reg` matches no general-purpose register family.
    """
    return _lookupGPR(reg, 1)


def regTo8(reg):
    """Return the low 8-bit register of the family `reg` belongs to (e.g. 'RDX' -> 'DL').

    The high-byte names map to the low byte as well ('AH' -> 'AL').
    Returns None if `reg` matches no general-purpose register family.
    """
    return _lookupGPR(reg, 0)


def regToSize(reg, size):
    """Return the register of the family of `reg` that is `size` bits wide.

    `size` values 8, 16 and 32 select the corresponding width; every other
    value selects the 64-bit register.
    """
    if size == 8:
        return regTo8(reg)
    if size == 16:
        return regTo16(reg)
    if size == 32:
        return regTo32(reg)
    return regTo64(reg)
def getSubRegs(reg):
    """Return the set of registers that are a part of `reg`, `reg` included.

    For example, EAX is a part of RAX, and RAX is also a part of RAX.

    - General-purpose register wider than 8 bits: every 16/32/64-bit register
      of the same family that is not wider than `reg`, plus the family's 8-bit
      registers (low and high byte for the A/B/C/D families, the low byte
      otherwise).
    - 8-bit general-purpose register: only `reg` itself.
    - ZMM register: the YMM and XMM registers with the same number.
    - YMM register: the XMM register with the same number.
    - Anything else: only `reg` itself.
    """
    subRegs = {reg}
    if reg in GPRRegs:
        regSize = getRegSize(reg)
        if regSize > 8:
            for size in (16, 32, 64):
                if size <= regSize:
                    subRegs.add(regToSize(reg, size))
            # NOTE(review): the else below is taken to belong to this A/B/C/D
            # test (so that e.g. RSI also yields SIL) -- confirm against the
            # upstream file, whose indentation was not available here.
            if any(name in reg for name in ('AX', 'BX', 'CX', 'DX')):
                # reg[-2] is the family letter: 'A' for AX, EAX and RAX alike
                subRegs.add(reg[-2] + 'L')
                subRegs.add(reg[-2] + 'H')
            else:
                subRegs.add(regTo8(reg))
    elif 'ZMM' in reg:
        subRegs.add('Y' + reg[1:])
        subRegs.add('X' + reg[1:])
    elif 'YMM' in reg:
        subRegs.add('X' + reg[1:])
    return subRegs
def getCanonicalReg(reg):
    """Return the canonical register for `reg`.

    For a general-purpose register this is the corresponding 64-bit register;
    for a name containing 'MM' a leading 'Y' or 'Z' is replaced by 'X', so a
    (X|Y|Z)MM register maps to the corresponding XMM register. Any other name
    is returned unchanged.
    """
    if reg in GPRRegs:
        return regTo64(reg)
    if 'MM' in reg:
        return re.sub(r'^[YZ]', 'X', reg)
    return reg
def getRegForMemPrefix(reg, memPrefix):
    """Return the register of the family of `reg` whose width matches `memPrefix`.

    The width in bits is obtained from getSizeOfMemPrefix(memPrefix) and
    handed to regToSize(), which yields the 64-bit register for every width
    other than 8, 16 and 32.
    """
    sizeInBits = getSizeOfMemPrefix(memPrefix)
    return regToSize(reg, sizeInBits)
def getSizeOfMemPrefix(memPrefix):
    """Return the operand width in bits named by the memory-size prefix.

    `memPrefix` is searched for one of the size keywords 'zmmword', 'ymmword',
    'xmmword', 'qword', 'dword', 'word' and 'byte'. Returns -1 if it contains
    none of them.
    """
    # The keywords are tested in this order on purpose: 'word' is a substring
    # of every keyword listed before it, so it has to be tested after them.
    widthByKeyword = (
        ('zmmword', 512),
        ('ymmword', 256),
        ('xmmword', 128),
        ('qword', 64),
        ('dword', 32),
        ('word', 16),
        ('byte', 8),
    )
    for keyword, width in widthByKeyword:
        if keyword in memPrefix:
            return width
    return -1
def getRegSize(reg):
    """Return the width in bits of the register named `reg`, or -1 if unknown.

    The checks run in this order and the first one that matches decides:
    a name ending in L, H or B is 8 bits; a name ending in W or one of the
    two-letter names AX, BX, CX, DX, SP, BP, SI, DI is 16 bits; a name starting
    with E or ending in D is 32 bits; any other general-purpose register is
    64 bits; MM* is 64, XMM* 128, YMM* 256 and ZMM* 512 bits.

    `reg` must be a non-empty string.
    """
    lastChar = reg[-1]
    if lastChar in ('L', 'H', 'B'):
        return 8
    # The two-letter 16-bit names carry no size suffix, so they are listed
    # explicitly; each name has to be a separate element of the tuple.
    if lastChar == 'W' or reg in ('AX', 'BX', 'CX', 'DX', 'SP', 'BP', 'SI', 'DI'):
        return 16
    if reg[0] == 'E' or lastChar == 'D':
        return 32
    if reg in GPRRegs:
        return 64
    if reg.startswith('MM'):
        return 64
    if reg.startswith('XMM'):
        return 128
    if reg.startswith('YMM'):
        return 256
    if reg.startswith('ZMM'):
        return 512
    return -1
def getAddrReg(instrNode, opNode): def getAddrReg(instrNode, opNode):
if opNode.attrib.get('suppressed', '0') == '1': if opNode.attrib.get('suppressed', '0') == '1':
@@ -1308,7 +1171,7 @@ def getDependencyBreakingInstrs(instrNode, opRegDict, ignoreOperand = None):
elif opNode.attrib.get('suppressed', '0') == '1': elif opNode.attrib.get('suppressed', '0') == '1':
reg = opNode.text reg = opNode.text
regPrefix = re.sub('\d', '', reg) regPrefix = re.sub('\d', '', reg)
if reg in GPRRegs: if reg in GPRegs:
if reg not in globalDoNotWriteRegs: if reg not in globalDoNotWriteRegs:
depBreakingInstrs[opNode] = 'MOV ' + reg + ', 0' # don't use XOR as this would also break flag dependencies depBreakingInstrs[opNode] = 'MOV ' + reg + ', 0' # don't use XOR as this would also break flag dependencies
elif reg in ['RSP', 'RBP']: elif reg in ['RSP', 'RBP']:
@@ -1351,7 +1214,7 @@ def getDependencyBreakingInstrsForSuppressedOperands(instrNode):
if opNode.attrib.get('suppressed', '0') == '0' and ',' in opNode.text: continue if opNode.attrib.get('suppressed', '0') == '0' and ',' in opNode.text: continue
reg = opNode.text reg = opNode.text
if not reg in GPRRegs: continue if not reg in GPRegs: continue
if reg in globalDoNotWriteRegs|specialRegs: continue if reg in globalDoNotWriteRegs|specialRegs: continue
writeOfRegFound = False writeOfRegFound = False
@@ -1680,13 +1543,13 @@ def getAllChainInstrsFromRegToReg(instrNode, startReg, targetReg):
if dataType and any((d in iclass) for d in allFPDataTypes) and not dataType in iclass: continue if dataType and any((d in iclass) for d in allFPDataTypes) and not dataType in iclass: continue
for chainOpNode1 in chainInstrNode.findall('./operand[@type="reg"][@r="1"]'): for chainOpNode1 in chainInstrNode.findall('./operand[@type="reg"][@r="1"]'):
regs1 = [r for r in chainOpNode1.text.split(',') if (r in GPRRegs and startReg in GPRRegs and regTo64(startReg)==regTo64(r)) or regs1 = [r for r in chainOpNode1.text.split(',') if (r in GPRegs and startReg in GPRegs and regTo64(startReg)==regTo64(r)) or
((r not in GPRRegs) and startReg[1:] == r[1:] and getRegSize(r) <= getRegSize(startReg))] ((r not in GPRegs) and startReg[1:] == r[1:] and getRegSize(r) <= getRegSize(startReg))]
if not regs1: continue if not regs1: continue
reg1 = regs1[0] reg1 = regs1[0]
for chainOpNode2 in chainInstrNode.findall('./operand[@type="reg"][@w="1"]'): for chainOpNode2 in chainInstrNode.findall('./operand[@type="reg"][@w="1"]'):
regs2 = [r for r in chainOpNode2.text.split(',') if r!=reg1 and ((r in GPRRegs and targetReg in GPRRegs and regTo64(targetReg)==regTo64(r)) or regs2 = [r for r in chainOpNode2.text.split(',') if r!=reg1 and ((r in GPRegs and targetReg in GPRegs and regTo64(targetReg)==regTo64(r)) or
((r not in GPRRegs) and targetReg[1:] == r[1:] and getRegSize(r) <= getRegSize(targetReg)))] ((r not in GPRegs) and targetReg[1:] == r[1:] and getRegSize(r) <= getRegSize(targetReg)))]
if not regs2: continue if not regs2: continue
reg2 = regs2[0] reg2 = regs2[0]
result.append(getInstrInstanceFromNode(chainInstrNode, [reg1, reg2], [reg1, reg2], True, {int(chainOpNode1.attrib['idx']):reg1, int(chainOpNode2.attrib['idx']):reg2})) result.append(getInstrInstanceFromNode(chainInstrNode, [reg1, reg2], [reg1, reg2], True, {int(chainOpNode1.attrib['idx']):reg1, int(chainOpNode2.attrib['idx']):reg2}))
@@ -1896,7 +1759,7 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addr_me
else: else:
if len(regs2) == 1: if len(regs2) == 1:
reg2 = sortRegs(regs2)[0] reg2 = sortRegs(regs2)[0]
otherRegs = filter(lambda x: (x in GPRRegs and regTo64(x)!=regTo64(reg2)) or (x not in GPRRegs and x[1:]!=reg2[1:]), regs1) otherRegs = filter(lambda x: (x in GPRegs and regTo64(x)!=regTo64(reg2)) or (x not in GPRegs and x[1:]!=reg2[1:]), regs1)
if otherRegs: if otherRegs:
reg1 = sortRegs(otherRegs)[0] reg1 = sortRegs(otherRegs)[0]
else: else:
@@ -1906,7 +1769,7 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addr_me
if not useDistinctRegs and reg1 in regs2: if not useDistinctRegs and reg1 in regs2:
reg2 = reg1 reg2 = reg1
else: else:
otherRegs = filter(lambda x: (x in GPRRegs and regTo64(x)!=regTo64(reg1)) or (x not in GPRRegs and x[1:]!=reg1[1:]), regs2) otherRegs = filter(lambda x: (x in GPRegs and regTo64(x)!=regTo64(reg1)) or (x not in GPRegs and x[1:]!=reg1[1:]), regs2)
if otherRegs: if otherRegs:
reg2 = sortRegs(otherRegs)[0] reg2 = sortRegs(otherRegs)[0]
else: else:
@@ -1920,7 +1783,7 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addr_me
reg1Prefix = re.sub('\d', '', reg1) reg1Prefix = re.sub('\d', '', reg1)
reg2Prefix = re.sub('\d', '', reg2) reg2Prefix = re.sub('\d', '', reg2)
if reg1 in GPRRegs and reg2 in GPRRegs: if reg1 in GPRegs and reg2 in GPRegs:
# MOVSX avoids partial reg stalls and cannot be eliminated by "move elimination" # MOVSX avoids partial reg stalls and cannot be eliminated by "move elimination"
chainInstrs = 'MOVSX {}, {};'.format(regTo64(reg1), regToSize(reg2, min(32, getRegSize(reg2)))) chainInstrs = 'MOVSX {}, {};'.format(regTo64(reg1), regToSize(reg2, min(32, getRegSize(reg2))))
chainInstrs += 'MOVSX {}, {};'.format(regTo64(reg1), regTo32(reg1)) * cRep chainInstrs += 'MOVSX {}, {};'.format(regTo64(reg1), regTo32(reg1)) * cRep
@@ -1973,7 +1836,7 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addr_me
if not ('flag_'+flag) in targetNode.attrib: continue if not ('flag_'+flag) in targetNode.attrib: continue
if not 'w' in targetNode.attrib[('flag_'+flag)]: continue if not 'w' in targetNode.attrib[('flag_'+flag)]: continue
if reg in GPRRegs: if reg in GPRegs:
regSize = getRegSize(reg) regSize = getRegSize(reg)
if regSize == 8: if regSize == 8:
chainInstr = 'SET{} {};'.format(flag[0], reg) chainInstr = 'SET{} {};'.format(flag[0], reg)
@@ -2003,7 +1866,7 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addr_me
reg = sortRegs(regs1)[0] reg = sortRegs(regs1)[0]
addrReg = getAddrReg(instrNode, targetNode) addrReg = getAddrReg(instrNode, targetNode)
if reg in GPRRegs: if reg in GPRegs:
instrI = getInstrInstanceFromNode(instrNode, useDistinctRegs=useDistinctRegs, opRegDict={startNodeIdx:reg}) instrI = getInstrInstanceFromNode(instrNode, useDistinctRegs=useDistinctRegs, opRegDict={startNodeIdx:reg})
configList.isUpperBound = True configList.isUpperBound = True
@@ -2042,7 +1905,7 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addr_me
reg = sortRegs(regs)[0] reg = sortRegs(regs)[0]
if reg in GPRRegs: if reg in GPRegs:
instrI = getInstrInstanceFromNode(instrNode, useDistinctRegs=useDistinctRegs, opRegDict={targetNodeIdx:reg}) instrI = getInstrInstanceFromNode(instrNode, useDistinctRegs=useDistinctRegs, opRegDict={targetNodeIdx:reg})
chainInstrs = 'TEST {0}, {0};'.format(reg) chainInstrs = 'TEST {0}, {0};'.format(reg)
chainLatency = basicLatency['TEST'] chainLatency = basicLatency['TEST']
@@ -2096,11 +1959,11 @@ def getLatConfigLists(instrNode, startNode, targetNode, useDistinctRegs, addr_me
regSize = getRegSize(reg) regSize = getRegSize(reg)
if suppressedStart: if suppressedStart:
if not regs.issubset(GPRRegs): if not regs.issubset(GPRegs):
print 'read from suppressed mem to non-GPR reg not yet supported' print 'read from suppressed mem to non-GPR reg not yet supported'
return None return None
if reg in GPRRegs: if reg in GPRegs:
instrI = getInstrInstanceFromNode(instrNode, [addrReg, 'R12'], [addrReg, 'R12'], useDistinctRegs, {targetNodeIdx:reg}) instrI = getInstrInstanceFromNode(instrNode, [addrReg, 'R12'], [addrReg, 'R12'], useDistinctRegs, {targetNodeIdx:reg})
if addr_mem == 'addr': if addr_mem == 'addr':
@@ -2309,7 +2172,7 @@ def getLatencies(instrNode, instrNodeList, tpDict, htmlReports):
print 'readOnlyRegOpNodeIdx not found in opRegDict' print 'readOnlyRegOpNodeIdx not found in opRegDict'
continue continue
reg = latConfig.instrI.opRegDict[readOnlyRegOpNodeIdx] reg = latConfig.instrI.opRegDict[readOnlyRegOpNodeIdx]
if not reg in GPRRegs or reg in globalDoNotWriteRegs or reg in specialRegs: continue if not reg in GPRegs or reg in globalDoNotWriteRegs or reg in specialRegs: continue
if any((opNode is not None) for opNode in instrNode.findall('./operand[@type="reg"][@w="1"]') if any((opNode is not None) for opNode in instrNode.findall('./operand[@type="reg"][@w="1"]')
if regTo64(latConfig.instrI.opRegDict[int(opNode.attrib['idx'])]) == regTo64(reg)): continue if regTo64(latConfig.instrI.opRegDict[int(opNode.attrib['idx'])]) == regTo64(reg)): continue
@@ -2516,7 +2379,7 @@ def filterInstructions(XMLRoot):
isaSet = XMLInstr.attrib['isa-set'] isaSet = XMLInstr.attrib['isa-set']
# Future instruction set extensions # Future instruction set extensions
if extension in ['CET', 'RDPRU']: instrSet.discard(XMLInstr) if extension in ['CET', 'RDPRU', 'SERIALIZE', 'TSX_LDTRK']: instrSet.discard(XMLInstr)
# Not supported by assembler # Not supported by assembler
if XMLInstr.attrib['iclass'] == 'NOP' and len(XMLInstr.findall('operand')) > 1: if XMLInstr.attrib['iclass'] == 'NOP' and len(XMLInstr.findall('operand')) > 1:
@@ -2621,7 +2484,8 @@ def filterInstructions(XMLRoot):
if extension == 'TBM' and not cpuid.get_bit(ecx8_1, 21): instrSet.discard(XMLInstr) if extension == 'TBM' and not cpuid.get_bit(ecx8_1, 21): instrSet.discard(XMLInstr)
if extension == 'RDTSCP' and not cpuid.get_bit(edx8_1, 27): instrSet.discard(XMLInstr) if extension == 'RDTSCP' and not cpuid.get_bit(edx8_1, 27): instrSet.discard(XMLInstr)
if extension == '3DNOW' and not cpuid.get_bit(edx8_1, 31): instrSet.discard(XMLInstr) if extension == '3DNOW' and not cpuid.get_bit(edx8_1, 31): instrSet.discard(XMLInstr)
if extension in ['CLZERO']and not cpuid.get_bit(ebx8_8, 0): instrSet.discard(XMLInstr) if extension == 'CLZERO' and not cpuid.get_bit(ebx8_8, 0): instrSet.discard(XMLInstr)
if extension == 'MCOMMIT' and not cpuid.get_bit(ebx8_8, 8): instrSet.discard(XMLInstr)
# Virtualization instructions # Virtualization instructions
if extension in ['SVM', 'VMFUNC', 'VTX']: instrSet.discard(XMLInstr) if extension in ['SVM', 'VMFUNC', 'VTX']: instrSet.discard(XMLInstr)
@@ -2640,7 +2504,7 @@ def filterInstructions(XMLRoot):
if XMLInstr.attrib['category'] in ['X87_ALU']: instrSet.discard(XMLInstr) if XMLInstr.attrib['category'] in ['X87_ALU']: instrSet.discard(XMLInstr)
# System instructions # System instructions
if extension in ['INVPCID', 'MONITOR', 'MONITORX', 'RDWRFSGS', 'SMAP', 'XSAVE', 'XSAVEC', 'XSAVEOPT', 'XSAVES']: instrSet.discard(XMLInstr) if extension in ['INVPCID', 'MONITOR', 'MONITORX', 'RDWRFSGS', 'SMAP', 'SNP', 'XSAVE', 'XSAVEC', 'XSAVEOPT', 'XSAVES']: instrSet.discard(XMLInstr)
if XMLInstr.attrib['category'] in ['INTERRUPT', 'SEGOP', 'SYSCALL', 'SYSRET']: instrSet.discard(XMLInstr) if XMLInstr.attrib['category'] in ['INTERRUPT', 'SEGOP', 'SYSCALL', 'SYSRET']: instrSet.discard(XMLInstr)
if XMLInstr.attrib['iclass'] in ['CALL_FAR', 'HLT', 'INVD', 'IRET', 'IRETD', 'IRETQ', 'JMP_FAR', 'LTR', 'RET_FAR', 'UD2']: if XMLInstr.attrib['iclass'] in ['CALL_FAR', 'HLT', 'INVD', 'IRET', 'IRETD', 'IRETQ', 'JMP_FAR', 'LTR', 'RET_FAR', 'UD2']:
instrSet.discard(XMLInstr) instrSet.discard(XMLInstr)

141
tools/cpuBench/x64_lib.py Normal file
View File

@@ -0,0 +1,141 @@
# Names of the x86-64 general-purpose registers in every width listed here:
# 8-bit (e.g. AL, AH, SIL, R8B), 16-bit (AX, R8W), 32-bit (EAX, R8D) and
# 64-bit (RAX, R8).
GPRegs = {'AH', 'AL', 'AX', 'BH', 'BL', 'BP', 'BPL', 'BX', 'CH', 'CL', 'CX', 'DH', 'DI', 'DIL', 'DL', 'DX', 'EAX',
'EBP', 'EBX', 'ECX', 'EDI', 'EDX', 'ESI', 'ESP', 'R10', 'R10B', 'R10D', 'R10W', 'R11', 'R11B', 'R11D', 'R11W', 'R12',
'R12B', 'R12D', 'R12W', 'R13', 'R13B', 'R13D', 'R13W', 'R14', 'R14B', 'R14D', 'R14W', 'R15', 'R15B', 'R15D', 'R15W',
'R8', 'R8B', 'R8D', 'R8W', 'R9', 'R9B', 'R9D', 'R9W', 'RAX', 'RBP', 'RBX', 'RCX', 'RDI', 'RDX', 'RSI', 'RSP', 'SI',
'SIL', 'SP', 'SPL'}
# Names of the status flags: carry, parity, auxiliary carry, zero, sign, overflow.
STATUSFLAGS = {'CF', 'PF', 'AF', 'ZF', 'SF', 'OF'}
# The same flag names as STATUSFLAGS, with AF left out.
STATUSFLAGS_noAF = {'CF', 'PF', 'ZF', 'SF', 'OF'}
# One entry per general-purpose register family: the substrings that identify
# the family, followed by the family's 8-, 16-, 32- and 64-bit register names.
# Entries are tried from top to bottom and the first entry with a matching
# substring wins, so the order of the entries must not be changed.
_GPR_FAMILIES = (
    (('AX', 'AH', 'AL'), ('AL', 'AX', 'EAX', 'RAX')),
    (('BX', 'BH', 'BL'), ('BL', 'BX', 'EBX', 'RBX')),
    (('CX', 'CH', 'CL'), ('CL', 'CX', 'ECX', 'RCX')),
    (('DX', 'DH', 'DL'), ('DL', 'DX', 'EDX', 'RDX')),
    (('SP',), ('SPL', 'SP', 'ESP', 'RSP')),
    (('BP',), ('BPL', 'BP', 'EBP', 'RBP')),
    (('SI',), ('SIL', 'SI', 'ESI', 'RSI')),
    (('DI',), ('DIL', 'DI', 'EDI', 'RDI')),
    (('8',), ('R8B', 'R8W', 'R8D', 'R8')),
    (('9',), ('R9B', 'R9W', 'R9D', 'R9')),
    (('10',), ('R10B', 'R10W', 'R10D', 'R10')),
    (('11',), ('R11B', 'R11W', 'R11D', 'R11')),
    (('12',), ('R12B', 'R12W', 'R12D', 'R12')),
    (('13',), ('R13B', 'R13W', 'R13D', 'R13')),
    (('14',), ('R14B', 'R14W', 'R14D', 'R14')),
    (('15',), ('R15B', 'R15W', 'R15D', 'R15')),
)


def _lookupGPR(reg, column):
    """Return entry `column` (0: 8-bit, 1: 16-bit, 2: 32-bit, 3: 64-bit) of the
    first family in _GPR_FAMILIES one of whose substrings occurs in `reg`.

    Returns None when no family matches.
    """
    for substrings, names in _GPR_FAMILIES:
        if any(s in reg for s in substrings):
            return names[column]
    return None


def regTo64(reg):
    """Return the 64-bit register of the family `reg` belongs to (e.g. 'AL' -> 'RAX').

    Returns None if `reg` matches no general-purpose register family.
    """
    return _lookupGPR(reg, 3)


def regTo32(reg):
    """Return the 32-bit register of the family `reg` belongs to (e.g. 'R8' -> 'R8D').

    Returns None if `reg` matches no general-purpose register family.
    """
    return _lookupGPR(reg, 2)


def regTo16(reg):
    """Return the 16-bit register of the family `reg` belongs to (e.g. 'RSI' -> 'SI').

    Returns None if `reg` matches no general-purpose register family.
    """
    return _lookupGPR(reg, 1)


def regTo8(reg):
    """Return the low 8-bit register of the family `reg` belongs to (e.g. 'RDX' -> 'DL').

    The high-byte names map to the low byte as well ('AH' -> 'AL').
    Returns None if `reg` matches no general-purpose register family.
    """
    return _lookupGPR(reg, 0)


def regToSize(reg, size):
    """Return the register of the family of `reg` that is `size` bits wide.

    `size` values 8, 16 and 32 select the corresponding width; every other
    value selects the 64-bit register.
    """
    if size == 8:
        return regTo8(reg)
    if size == 16:
        return regTo16(reg)
    if size == 32:
        return regTo32(reg)
    return regTo64(reg)
def getSubRegs(reg):
    """Return the set of registers that are a part of `reg`, `reg` included.

    For example, EAX is a part of RAX, and RAX is also a part of RAX.

    - General-purpose register wider than 8 bits: every 16/32/64-bit register
      of the same family that is not wider than `reg`, plus the family's 8-bit
      registers (low and high byte for the A/B/C/D families, the low byte
      otherwise).
    - 8-bit general-purpose register: only `reg` itself.
    - ZMM register: the YMM and XMM registers with the same number.
    - YMM register: the XMM register with the same number.
    - Anything else: only `reg` itself.
    """
    subRegs = {reg}
    if reg in GPRegs:
        regSize = getRegSize(reg)
        if regSize > 8:
            for size in (16, 32, 64):
                if size <= regSize:
                    subRegs.add(regToSize(reg, size))
            # NOTE(review): the else below is taken to belong to this A/B/C/D
            # test (so that e.g. RSI also yields SIL) -- confirm against the
            # upstream file, whose indentation was not available here.
            if any(name in reg for name in ('AX', 'BX', 'CX', 'DX')):
                # reg[-2] is the family letter: 'A' for AX, EAX and RAX alike
                subRegs.add(reg[-2] + 'L')
                subRegs.add(reg[-2] + 'H')
            else:
                subRegs.add(regTo8(reg))
    elif 'ZMM' in reg:
        subRegs.add('Y' + reg[1:])
        subRegs.add('X' + reg[1:])
    elif 'YMM' in reg:
        subRegs.add('X' + reg[1:])
    return subRegs
def getCanonicalReg(reg):
    """Return the canonical register for `reg`.

    For a general-purpose register this is the corresponding 64-bit register;
    for a name containing 'MM' a leading 'Y' or 'Z' is replaced by 'X', so a
    (X|Y|Z)MM register maps to the corresponding XMM register. Any other name
    is returned unchanged.
    """
    # Imported here because no module-level `import re` is visible in this
    # file; without it the re.sub() call below raises NameError.
    import re
    if reg in GPRegs:
        return regTo64(reg)
    if 'MM' in reg:
        return re.sub(r'^[YZ]', 'X', reg)
    return reg
def getRegForMemPrefix(reg, memPrefix):
    """Return the register of the family of `reg` whose width matches `memPrefix`.

    The width in bits is obtained from getSizeOfMemPrefix(memPrefix) and
    handed to regToSize(), which yields the 64-bit register for every width
    other than 8, 16 and 32.
    """
    sizeInBits = getSizeOfMemPrefix(memPrefix)
    return regToSize(reg, sizeInBits)
def getSizeOfMemPrefix(memPrefix):
    """Return the operand width in bits named by the memory-size prefix.

    `memPrefix` is searched for one of the size keywords 'zmmword', 'ymmword',
    'xmmword', 'qword', 'dword', 'word' and 'byte'. Returns -1 if it contains
    none of them.
    """
    # The keywords are tested in this order on purpose: 'word' is a substring
    # of every keyword listed before it, so it has to be tested after them.
    widthByKeyword = (
        ('zmmword', 512),
        ('ymmword', 256),
        ('xmmword', 128),
        ('qword', 64),
        ('dword', 32),
        ('word', 16),
        ('byte', 8),
    )
    for keyword, width in widthByKeyword:
        if keyword in memPrefix:
            return width
    return -1
def getRegSize(reg):
    """Return the width in bits of the register named `reg`, or -1 if unknown.

    The checks run in this order and the first one that matches decides:
    a name ending in L, H or B is 8 bits; a name ending in W or one of the
    two-letter names AX, BX, CX, DX, SP, BP, SI, DI is 16 bits; a name starting
    with E or ending in D is 32 bits; any other general-purpose register is
    64 bits; MM* is 64, XMM* 128, YMM* 256 and ZMM* 512 bits.

    `reg` must be a non-empty string.
    """
    lastChar = reg[-1]
    if lastChar in ('L', 'H', 'B'):
        return 8
    # The two-letter 16-bit names carry no size suffix, so they are listed
    # explicitly; each name has to be a separate element of the tuple.
    if lastChar == 'W' or reg in ('AX', 'BX', 'CX', 'DX', 'SP', 'BP', 'SI', 'DI'):
        return 16
    if reg[0] == 'E' or lastChar == 'D':
        return 32
    if reg in GPRegs:
        return 64
    if reg.startswith('MM'):
        return 64
    if reg.startswith('XMM'):
        return 128
    if reg.startswith('YMM'):
        return 256
    if reg.startswith('ZMM'):
        return 512
    return -1