added dependency check for registers and tests

This commit is contained in:
JanLJL
2019-06-12 18:57:53 +02:00
parent 8b377f4db1
commit 75a405e33e
6 changed files with 347 additions and 59 deletions

92
osaca/dependency_finder.py Executable file
View File

@@ -0,0 +1,92 @@
#!/usr/bin/env python3
import networkx as nx
class KernelDAG(nx.DiGraph):
def __init__(self, parsed_kernel, parser, hw_model):
self.kernel = parsed_kernel
self.parser = parser
self.model = hw_model
# self.dag = self.create_DAG()
def check_for_loop(self, kernel):
raise NotImplementedError
def create_DAG(self):
# 1. go through kernel instruction forms (as vertices)
# 2. find edges (to dependend further instruction)
# 3. get LT/TP value and set as edge weight
dag = nx.DiGraph()
for i, instruction in enumerate(self.kernel):
throughput = self.model.get_throughput(instruction)
latency = self.model.get_latency(instruction)
for dep in self.find_depending(instruction, self.kernel[i + 1:]):
dag.add_edge(
instruction.line_number,
dep.line_number,
latency=latency,
thorughput=throughput,
)
def find_depending(self, instruction_form, kernel):
for dst in instruction_form.operands.destination:
if 'register' in dst:
# Check for read of register until overwrite
for instr_form in kernel:
if self.is_read(dst.register, instr_form):
yield instr_form
elif self.is_written(dst.register, instr_form):
break
elif 'memory' in dst:
# Check if base register is altered during memory access
if 'pre_indexed' in dst.memory or 'post_indexed' in dst.memory:
# Check for read of base register until overwrite
for instr_form in kernel:
if self.is_read(dst.memory.base, instr_form):
yield instr_form
elif self.is_written(dst.memory.base, instr_form):
break
def is_read(self, register, instruction_form):
is_read = False
for src in instruction_form.operands.source:
if 'register' in src:
is_read = self.parser.is_reg_dependend_of(register, src.register) or is_read
if 'memory' in src:
if src.memory.base is not None:
is_read = self.parser.is_reg_dependend_of(register, src.memory.base) or is_read
if src.memory.index is not None:
is_read = (
self.parser.is_reg_dependend_of(register, src.memory.index) or is_read
)
# Check also if read in destination memory address
for dst in instruction_form.operands.destination:
if 'memory' in dst:
if dst.memory.base is not None:
is_read = self.parser.is_reg_dependend_of(register, dst.memory.base) or is_read
if dst.memory.index is not None:
is_read = (
self.parser.is_reg_dependend_of(register, dst.memory.index) or is_read
)
return is_read
def is_written(self, register, instruction_form):
is_written = False
for dst in instruction_form.operands.destination:
if 'register' in dst:
is_written = self.parser.is_reg_dependend_of(register, dst.register) or is_written
if 'memory' in dst:
if 'pre_indexed' in dst.memory or 'post_indexed' in dst.memory:
is_written = (
self.parser.is_reg_dependend_of(register, dst.memory.base) or is_written
)
# Check also for possible pre- or post-indexing in memory addresses
for src in instruction_form.operands.source:
if 'memory' in src:
if 'pre_indexed' in src.memory or 'post_indexed' in src.memory:
is_written = (
self.parser.is_reg_dependend_of(register, src.memory.base) or is_written
)
return is_written

View File

@@ -1,4 +1,4 @@
#!usr/bin/env python3
#!/usr/bin/env python3
class BaseParser(object):
@@ -43,3 +43,19 @@ class BaseParser(object):
def construct_parser(self):
raise NotImplementedError()
##################
# Helper functions
##################
def process_operand(self, operand):
raise NotImplementedError
def get_full_reg_name(self, register):
raise NotImplementedError
def normalize_imd(self, imd):
raise NotImplementedError
def is_reg_dependend_of(self, reg_a, reg_b):
raise NotImplementedError

View File

@@ -189,7 +189,7 @@ class ParserAArch64v81(BaseParser):
# 1. Parse comment
try:
result = self._process_operand(self.comment.parseString(line, parseAll=True).asDict())
result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict())
result = AttrDict.convert_dict(result)
instruction_form[self.COMMENT_ID] = ' '.join(result[self.COMMENT_ID])
except pp.ParseException:
@@ -198,7 +198,7 @@ class ParserAArch64v81(BaseParser):
# 2. Parse label
if result is None:
try:
result = self._process_operand(
result = self.process_operand(
self.label.parseString(line, parseAll=True).asDict()
)
result = AttrDict.convert_dict(result)
@@ -213,7 +213,7 @@ class ParserAArch64v81(BaseParser):
# 3. Parse directive
if result is None:
try:
result = self._process_operand(
result = self.process_operand(
self.directive.parseString(line, parseAll=True).asDict()
)
result = AttrDict.convert_dict(result)
@@ -262,28 +262,28 @@ class ParserAArch64v81(BaseParser):
# Check first operand
if 'operand1' in result:
if is_store:
operands.source.append(self._process_operand(result['operand1']))
operands.source.append(self.process_operand(result['operand1']))
else:
operands.destination.append(self._process_operand(result['operand1']))
operands.destination.append(self.process_operand(result['operand1']))
# Check second operand
if 'operand2' in result:
if is_store and 'operand3' not in result or is_load and 'operand3' in result:
# destination
operands.destination.append(self._process_operand(result['operand2']))
operands.destination.append(self.process_operand(result['operand2']))
else:
operands.source.append(self._process_operand(result['operand2']))
operands.source.append(self.process_operand(result['operand2']))
# Check third operand
if 'operand3' in result:
if is_store and 'operand4' not in result or is_load and 'operand4' in result:
operands.destination.append(self._process_operand(result['operand3']))
operands.destination.append(self.process_operand(result['operand3']))
else:
operands.source.append(self._process_operand(result['operand3']))
operands.source.append(self.process_operand(result['operand3']))
# Check fourth operand
if 'operand4' in result:
if is_store:
operands.destination.append(self._process_operand(result['operand4']))
operands.destination.append(self.process_operand(result['operand4']))
else:
operands.source.append(self._process_operand(result['operand4']))
operands.source.append(self.process_operand(result['operand4']))
return_dict = AttrDict(
{
@@ -296,7 +296,7 @@ class ParserAArch64v81(BaseParser):
)
return return_dict
def _process_operand(self, operand):
def process_operand(self, operand):
# structure memory addresses
if self.MEMORY_ID in operand:
return self.substitute_memory_address(operand[self.MEMORY_ID])
@@ -390,14 +390,19 @@ class ParserAArch64v81(BaseParser):
return int(imd['value'], 16)
return int(imd['value'], 10)
elif 'float' in imd:
return self.ieee_val_to_int(imd['float'])
return self.ieee_to_int(imd['float'])
elif 'double' in imd:
return self.ieee_val_to_int(imd['double'])
return self.ieee_to_int(imd['double'])
# identifier
return imd
def ieee_to_int(self, ieee_val):
exponent = int(ieee_val['exponent'], 10)
if ieee_val.e_sign == '-':
if ieee_val['e_sign'] == '-':
exponent *= -1
return float(ieee_val['mantissa']) * (10 ** exponent)
def is_reg_dependend_of(self, reg_a, reg_b):
if reg_a['name'] == reg_b['name']:
return True
return False

View File

@@ -2,8 +2,8 @@
import pyparsing as pp
from .base_parser import BaseParser
from .attr_dict import AttrDict
from .base_parser import BaseParser
class ParserX86ATT(BaseParser):
@@ -103,19 +103,21 @@ class ParserX86ATT(BaseParser):
:param int line_id: default None, identifier of instruction form
:return: parsed instruction form
"""
instruction_form = AttrDict({
self.INSTRUCTION_ID: None,
self.OPERANDS_ID: None,
self.DIRECTIVE_ID: None,
self.COMMENT_ID: None,
self.LABEL_ID: None,
'line_number': line_number,
})
instruction_form = AttrDict(
{
self.INSTRUCTION_ID: None,
self.OPERANDS_ID: None,
self.DIRECTIVE_ID: None,
self.COMMENT_ID: None,
self.LABEL_ID: None,
'line_number': line_number,
}
)
result = None
# 1. Parse comment
try:
result = self._process_operand(self.comment.parseString(line, parseAll=True).asDict())
result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict())
result = AttrDict.convert_dict(result)
instruction_form[self.COMMENT_ID] = ' '.join(result[self.COMMENT_ID])
except pp.ParseException:
@@ -124,7 +126,7 @@ class ParserX86ATT(BaseParser):
# 2. Parse label
if result is None:
try:
result = self._process_operand(
result = self.process_operand(
self.label.parseString(line, parseAll=True).asDict()
)
result = AttrDict.convert_dict(result)
@@ -139,14 +141,16 @@ class ParserX86ATT(BaseParser):
# 3. Parse directive
if result is None:
try:
result = self._process_operand(
result = self.process_operand(
self.directive.parseString(line, parseAll=True).asDict()
)
result = AttrDict.convert_dict(result)
instruction_form[self.DIRECTIVE_ID] = AttrDict({
'name': result[self.DIRECTIVE_ID]['name'],
'parameters': result[self.DIRECTIVE_ID]['parameters'],
})
instruction_form[self.DIRECTIVE_ID] = AttrDict(
{
'name': result[self.DIRECTIVE_ID]['name'],
'parameters': result[self.DIRECTIVE_ID]['parameters'],
}
)
if self.COMMENT_ID in result[self.DIRECTIVE_ID]:
instruction_form[self.COMMENT_ID] = ' '.join(
result[self.DIRECTIVE_ID][self.COMMENT_ID]
@@ -177,29 +181,31 @@ class ParserX86ATT(BaseParser):
# Check from right to left
# Check third operand
if 'operand3' in result:
operands['destination'].append(self._process_operand(result['operand3']))
operands['destination'].append(self.process_operand(result['operand3']))
# Check second operand
if 'operand2' in result:
if len(operands['destination']) != 0:
operands['source'].insert(0, self._process_operand(result['operand2']))
operands['source'].insert(0, self.process_operand(result['operand2']))
else:
operands['destination'].append(self._process_operand(result['operand2']))
operands['destination'].append(self.process_operand(result['operand2']))
# Check first operand
if 'operand1' in result:
if len(operands['destination']) != 0:
operands['source'].insert(0, self._process_operand(result['operand1']))
operands['source'].insert(0, self.process_operand(result['operand1']))
else:
operands['destination'].append(self._process_operand(result['operand1']))
return_dict = AttrDict({
self.INSTRUCTION_ID: result['mnemonic'],
self.OPERANDS_ID: operands,
self.COMMENT_ID: ' '.join(result[self.COMMENT_ID])
if self.COMMENT_ID in result
else None,
})
operands['destination'].append(self.process_operand(result['operand1']))
return_dict = AttrDict(
{
self.INSTRUCTION_ID: result['mnemonic'],
self.OPERANDS_ID: operands,
self.COMMENT_ID: ' '.join(result[self.COMMENT_ID])
if self.COMMENT_ID in result
else None,
}
)
return return_dict
def _process_operand(self, operand):
def process_operand(self, operand):
# For the moment, only used to structure memory addresses
if self.MEMORY_ID in operand:
return self.substitute_memory_address(operand[self.MEMORY_ID])
@@ -242,3 +248,56 @@ class ParserX86ATT(BaseParser):
return int(imd['value'], 10)
# identifier
return imd
def is_reg_dependend_of(self, reg_a, reg_b):
# Check if they are the same registers
if reg_a.name == reg_b.name:
return True
# Check vector registers first
if self.is_vector_register(reg_a):
if self.is_vector_register(reg_b):
if reg_a.name[1:] == reg_b.name[1:]:
# Registers in the same vector space
return True
return False
# Check basic GPRs
a_dep = ['RAX', 'EAX', 'AX', 'AH', 'AL']
b_dep = ['RBX', 'EBX', 'BX', 'BH', 'BL']
c_dep = ['RCX', 'ECX', 'CX', 'CH', 'CL']
d_dep = ['RDX', 'EDX', 'DX', 'DH', 'DL']
sp_dep = ['RSP', 'ESP', 'SP', 'SPL']
src_dep = ['RSI', 'ESI', 'SI', 'SIL']
dst_dep = ['RDI', 'EDI', 'DI', 'DIL']
basic_gprs = [a_dep, b_dep, c_dep, d_dep, sp_dep, src_dep, dst_dep]
if self.is_basic_gpr(reg_a):
if self.is_basic_gpr(reg_b):
for dep_group in basic_gprs:
if reg_a['name'].upper() in dep_group:
if reg_b['name'].upper() in dep_group:
return True
return False
# Check other GPRs
gpr_parser = (
pp.CaselessLiteral('R')
+ pp.Word(pp.nums).setResultsName('id')
+ pp.Optional(pp.Word('dwbDWB', exact=1))
)
try:
id_a = gpr_parser.parseString(reg_a['name'], parseAll=True).asDict()['id']
id_b = gpr_parser.parseString(reg_b['name'], parseAll=True).asDict()['id']
if id_a == id_b:
return True
except pp.ParseException:
return False
# No dependencies
return False
def is_basic_gpr(self, register):
if any(char.isdigit() for char in register['name']):
return False
return True
def is_vector_register(self, register):
if len(register['name']) > 2 and register['name'][1:3] == 'mm':
return True
return False

View File

@@ -283,28 +283,90 @@ class TestParserAArch64v81(unittest.TestCase):
self.assertEqual(parsed[0].line_number, 1)
self.assertEqual(len(parsed), 645)
def test_normalize_imd(self):
imd_decimal_1 = {'value': '79'}
imd_hex_1 = {'value': '0x4f'}
imd_decimal_2 = {'value': '8'}
imd_hex_2 = {'value': '0x8'}
imd_float_11 = {'float': {'mantissa': '0.79', 'e_sign': '+', 'exponent': '2'}}
imd_float_12 = {'float': {'mantissa': '790.0', 'e_sign': '-', 'exponent': '1'}}
imd_double_11 = {'double': {'mantissa': '0.79', 'e_sign': '+', 'exponent': '2'}}
imd_double_12 = {'double': {'mantissa': '790.0', 'e_sign': '-', 'exponent': '1'}}
value1 = self.parser.normalize_imd(imd_decimal_1)
self.assertEqual(value1, self.parser.normalize_imd(imd_hex_1))
self.assertEqual(
self.parser.normalize_imd(imd_decimal_2), self.parser.normalize_imd(imd_hex_2)
)
self.assertEqual(self.parser.normalize_imd(imd_float_11), value1)
self.assertEqual(self.parser.normalize_imd(imd_float_12), value1)
self.assertEqual(self.parser.normalize_imd(imd_double_11), value1)
self.assertEqual(self.parser.normalize_imd(imd_double_12), value1)
def test_reg_dependency(self):
reg_1_1 = AttrDict({'prefix': 'b', 'name': '1'})
reg_1_2 = AttrDict({'prefix': 'h', 'name': '1'})
reg_1_3 = AttrDict({'prefix': 's', 'name': '1'})
reg_1_4 = AttrDict({'prefix': 'd', 'name': '1'})
reg_1_4 = AttrDict({'prefix': 'q', 'name': '1'})
reg_2_1 = AttrDict({'prefix': 'w', 'name': '2'})
reg_2_2 = AttrDict({'prefix': 'x', 'name': '2'})
reg_v1_1 = AttrDict({'prefix': 'v', 'name': '11', 'lanes': '16', 'shape': 'b'})
reg_v1_2 = AttrDict({'prefix': 'v', 'name': '11', 'lanes': '8', 'shape': 'h'})
reg_v1_3 = AttrDict({'prefix': 'v', 'name': '11', 'lanes': '4', 'shape': 's'})
reg_v1_4 = AttrDict({'prefix': 'v', 'name': '11', 'lanes': '2', 'shape': 'd'})
reg_b5 = AttrDict({'prefix': 'b', 'name': '5'})
reg_q15 = AttrDict({'prefix': 'q', 'name': '15'})
reg_v10 = AttrDict({'prefix': 'v', 'name': '10', 'lanes': '2', 'shape': 's'})
reg_v20 = AttrDict({'prefix': 'v', 'name': '20', 'lanes': '2', 'shape': 'd'})
reg_1 = [reg_1_1, reg_1_2, reg_1_3, reg_1_4]
reg_2 = [reg_2_1, reg_2_2]
reg_v = [reg_v1_1, reg_v1_2, reg_v1_3, reg_v1_4]
reg_others = [reg_b5, reg_q15, reg_v10, reg_v20]
regs = reg_1 + reg_2 + reg_v + reg_others
# test each register against each other
for ri in reg_1:
for rj in regs:
assert_value = True if rj in reg_1 else False
with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value):
self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value)
for ri in reg_2:
for rj in regs:
assert_value = True if rj in reg_2 else False
with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value):
self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value)
for ri in reg_v:
for rj in regs:
assert_value = True if rj in reg_v else False
with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value):
self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value)
for ri in reg_others:
for rj in regs:
assert_value = True if rj == ri else False
with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value):
self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value)
##################
# Helper functions
##################
def _get_comment(self, parser, comment):
return ' '.join(
AttrDict.convert_dict(
parser._process_operand(
parser.comment.parseString(comment, parseAll=True).asDict()
)
parser.process_operand(parser.comment.parseString(comment, parseAll=True).asDict())
).comment
)
def _get_label(self, parser, label):
return AttrDict.convert_dict(
parser._process_operand(parser.label.parseString(label, parseAll=True).asDict())
parser.process_operand(parser.label.parseString(label, parseAll=True).asDict())
).label
def _get_directive(self, parser, directive):
return AttrDict.convert_dict(
parser._process_operand(
parser.directive.parseString(directive, parseAll=True).asDict()
)
parser.process_operand(parser.directive.parseString(directive, parseAll=True).asDict())
).directive
@staticmethod

View File

@@ -186,28 +186,82 @@ class TestParserX86ATT(unittest.TestCase):
self.assertEqual(parsed[0].line_number, 1)
self.assertEqual(len(parsed), 353)
def test_normalize_imd(self):
imd_decimal_1 = {'value': '79'}
imd_hex_1 = {'value': '0x4f'}
imd_decimal_2 = {'value': '8'}
imd_hex_2 = {'value': '0x8'}
self.assertEqual(
self.parser.normalize_imd(imd_decimal_1), self.parser.normalize_imd(imd_hex_1)
)
self.assertEqual(
self.parser.normalize_imd(imd_decimal_2), self.parser.normalize_imd(imd_hex_2)
)
def test_reg_dependency(self):
reg_a1 = AttrDict({'name': 'rax'})
reg_a2 = AttrDict({'name': 'eax'})
reg_a3 = AttrDict({'name': 'ax'})
reg_a4 = AttrDict({'name': 'al'})
reg_r11 = AttrDict({'name': 'r11'})
reg_r11b = AttrDict({'name': 'r11b'})
reg_r11d = AttrDict({'name': 'r11d'})
reg_r11w = AttrDict({'name': 'r11w'})
reg_xmm1 = AttrDict({'name': 'xmm1'})
reg_ymm1 = AttrDict({'name': 'ymm1'})
reg_zmm1 = AttrDict({'name': 'zmm1'})
reg_b1 = AttrDict({'name': 'rbx'})
reg_r15 = AttrDict({'name': 'r15'})
reg_xmm2 = AttrDict({'name': 'xmm2'})
reg_ymm3 = AttrDict({'name': 'ymm3'})
reg_a = [reg_a1, reg_a2, reg_a3, reg_a4]
reg_r = [reg_r11, reg_r11b, reg_r11d, reg_r11w]
reg_vec_1 = [reg_xmm1, reg_ymm1, reg_zmm1]
reg_others = [reg_b1, reg_r15, reg_xmm2, reg_ymm3]
regs = reg_a + reg_r + reg_vec_1 + reg_others
# test each register against each other
for ri in reg_a:
for rj in regs:
assert_value = True if rj in reg_a else False
with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value):
self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value)
for ri in reg_r:
for rj in regs:
assert_value = True if rj in reg_r else False
with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value):
self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value)
for ri in reg_vec_1:
for rj in regs:
assert_value = True if rj in reg_vec_1 else False
with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value):
self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value)
for ri in reg_others:
for rj in regs:
assert_value = True if rj == ri else False
with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value):
self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value)
##################
# Helper functions
##################
def _get_comment(self, parser, comment):
return ' '.join(
AttrDict.convert_dict(
parser._process_operand(
parser.comment.parseString(comment, parseAll=True).asDict()
)
parser.process_operand(parser.comment.parseString(comment, parseAll=True).asDict())
).comment
)
def _get_label(self, parser, label):
return AttrDict.convert_dict(
parser._process_operand(parser.label.parseString(label, parseAll=True).asDict())
parser.process_operand(parser.label.parseString(label, parseAll=True).asDict())
).label
def _get_directive(self, parser, directive):
return AttrDict.convert_dict(
parser._process_operand(
parser.directive.parseString(directive, parseAll=True).asDict()
)
parser.process_operand(parser.directive.parseString(directive, parseAll=True).asDict())
).directive
@staticmethod