From 75a405e33e1d5a099a01423c6a7ce51060090494 Mon Sep 17 00:00:00 2001 From: JanLJL Date: Wed, 12 Jun 2019 18:57:53 +0200 Subject: [PATCH] added dependency check for registers and tests --- osaca/dependency_finder.py | 92 +++++++++++++++++++++++ osaca/parser/base_parser.py | 18 ++++- osaca/parser/parser_AArch64v81.py | 35 +++++---- osaca/parser/parser_x86att.py | 117 ++++++++++++++++++++++-------- tests/test_parser_AArch64v81.py | 76 +++++++++++++++++-- tests/test_parser_x86att.py | 68 +++++++++++++++-- 6 files changed, 347 insertions(+), 59 deletions(-) create mode 100755 osaca/dependency_finder.py diff --git a/osaca/dependency_finder.py b/osaca/dependency_finder.py new file mode 100755 index 0000000..3ee654e --- /dev/null +++ b/osaca/dependency_finder.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 + +import networkx as nx + + +class KernelDAG(nx.DiGraph): + def __init__(self, parsed_kernel, parser, hw_model): + self.kernel = parsed_kernel + self.parser = parser + self.model = hw_model + + # self.dag = self.create_DAG() + + def check_for_loop(self, kernel): + raise NotImplementedError + + def create_DAG(self): + # 1. go through kernel instruction forms (as vertices) + # 2. find edges (to dependent subsequent instructions) + # 3. 
get LT/TP value and set as edge weight + dag = nx.DiGraph() + for i, instruction in enumerate(self.kernel): + throughput = self.model.get_throughput(instruction) + latency = self.model.get_latency(instruction) + for dep in self.find_depending(instruction, self.kernel[i + 1:]): + dag.add_edge( + instruction.line_number, + dep.line_number, + latency=latency, + thorughput=throughput, + ) + + def find_depending(self, instruction_form, kernel): + for dst in instruction_form.operands.destination: + if 'register' in dst: + # Check for read of register until overwrite + for instr_form in kernel: + if self.is_read(dst.register, instr_form): + yield instr_form + elif self.is_written(dst.register, instr_form): + break + elif 'memory' in dst: + # Check if base register is altered during memory access + if 'pre_indexed' in dst.memory or 'post_indexed' in dst.memory: + # Check for read of base register until overwrite + for instr_form in kernel: + if self.is_read(dst.memory.base, instr_form): + yield instr_form + elif self.is_written(dst.memory.base, instr_form): + break + + def is_read(self, register, instruction_form): + is_read = False + for src in instruction_form.operands.source: + if 'register' in src: + is_read = self.parser.is_reg_dependend_of(register, src.register) or is_read + if 'memory' in src: + if src.memory.base is not None: + is_read = self.parser.is_reg_dependend_of(register, src.memory.base) or is_read + if src.memory.index is not None: + is_read = ( + self.parser.is_reg_dependend_of(register, src.memory.index) or is_read + ) + # Check also if read in destination memory address + for dst in instruction_form.operands.destination: + if 'memory' in dst: + if dst.memory.base is not None: + is_read = self.parser.is_reg_dependend_of(register, dst.memory.base) or is_read + if dst.memory.index is not None: + is_read = ( + self.parser.is_reg_dependend_of(register, dst.memory.index) or is_read + ) + return is_read + + def is_written(self, register, instruction_form): + 
is_written = False + for dst in instruction_form.operands.destination: + if 'register' in dst: + is_written = self.parser.is_reg_dependend_of(register, dst.register) or is_written + if 'memory' in dst: + if 'pre_indexed' in dst.memory or 'post_indexed' in dst.memory: + is_written = ( + self.parser.is_reg_dependend_of(register, dst.memory.base) or is_written + ) + # Check also for possible pre- or post-indexing in memory addresses + for src in instruction_form.operands.source: + if 'memory' in src: + if 'pre_indexed' in src.memory or 'post_indexed' in src.memory: + is_written = ( + self.parser.is_reg_dependend_of(register, src.memory.base) or is_written + ) + return is_written diff --git a/osaca/parser/base_parser.py b/osaca/parser/base_parser.py index 3982ba8..d2821dd 100755 --- a/osaca/parser/base_parser.py +++ b/osaca/parser/base_parser.py @@ -1,4 +1,4 @@ -#!usr/bin/env python3 +#!/usr/bin/env python3 class BaseParser(object): @@ -43,3 +43,19 @@ class BaseParser(object): def construct_parser(self): raise NotImplementedError() + + ################## + # Helper functions + ################## + + def process_operand(self, operand): + raise NotImplementedError + + def get_full_reg_name(self, register): + raise NotImplementedError + + def normalize_imd(self, imd): + raise NotImplementedError + + def is_reg_dependend_of(self, reg_a, reg_b): + raise NotImplementedError diff --git a/osaca/parser/parser_AArch64v81.py b/osaca/parser/parser_AArch64v81.py index 9917ae1..8f98e61 100755 --- a/osaca/parser/parser_AArch64v81.py +++ b/osaca/parser/parser_AArch64v81.py @@ -189,7 +189,7 @@ class ParserAArch64v81(BaseParser): # 1. 
Parse comment try: - result = self._process_operand(self.comment.parseString(line, parseAll=True).asDict()) + result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict()) result = AttrDict.convert_dict(result) instruction_form[self.COMMENT_ID] = ' '.join(result[self.COMMENT_ID]) except pp.ParseException: @@ -198,7 +198,7 @@ class ParserAArch64v81(BaseParser): # 2. Parse label if result is None: try: - result = self._process_operand( + result = self.process_operand( self.label.parseString(line, parseAll=True).asDict() ) result = AttrDict.convert_dict(result) @@ -213,7 +213,7 @@ class ParserAArch64v81(BaseParser): # 3. Parse directive if result is None: try: - result = self._process_operand( + result = self.process_operand( self.directive.parseString(line, parseAll=True).asDict() ) result = AttrDict.convert_dict(result) @@ -262,28 +262,28 @@ class ParserAArch64v81(BaseParser): # Check first operand if 'operand1' in result: if is_store: - operands.source.append(self._process_operand(result['operand1'])) + operands.source.append(self.process_operand(result['operand1'])) else: - operands.destination.append(self._process_operand(result['operand1'])) + operands.destination.append(self.process_operand(result['operand1'])) # Check second operand if 'operand2' in result: if is_store and 'operand3' not in result or is_load and 'operand3' in result: # destination - operands.destination.append(self._process_operand(result['operand2'])) + operands.destination.append(self.process_operand(result['operand2'])) else: - operands.source.append(self._process_operand(result['operand2'])) + operands.source.append(self.process_operand(result['operand2'])) # Check third operand if 'operand3' in result: if is_store and 'operand4' not in result or is_load and 'operand4' in result: - operands.destination.append(self._process_operand(result['operand3'])) + operands.destination.append(self.process_operand(result['operand3'])) else: - 
operands.source.append(self._process_operand(result['operand3'])) + operands.source.append(self.process_operand(result['operand3'])) # Check fourth operand if 'operand4' in result: if is_store: - operands.destination.append(self._process_operand(result['operand4'])) + operands.destination.append(self.process_operand(result['operand4'])) else: - operands.source.append(self._process_operand(result['operand4'])) + operands.source.append(self.process_operand(result['operand4'])) return_dict = AttrDict( { @@ -296,7 +296,7 @@ class ParserAArch64v81(BaseParser): ) return return_dict - def _process_operand(self, operand): + def process_operand(self, operand): # structure memory addresses if self.MEMORY_ID in operand: return self.substitute_memory_address(operand[self.MEMORY_ID]) @@ -390,14 +390,19 @@ class ParserAArch64v81(BaseParser): return int(imd['value'], 16) return int(imd['value'], 10) elif 'float' in imd: - return self.ieee_val_to_int(imd['float']) + return self.ieee_to_int(imd['float']) elif 'double' in imd: - return self.ieee_val_to_int(imd['double']) + return self.ieee_to_int(imd['double']) # identifier return imd def ieee_to_int(self, ieee_val): exponent = int(ieee_val['exponent'], 10) - if ieee_val.e_sign == '-': + if ieee_val['e_sign'] == '-': exponent *= -1 return float(ieee_val['mantissa']) * (10 ** exponent) + + def is_reg_dependend_of(self, reg_a, reg_b): + if reg_a['name'] == reg_b['name']: + return True + return False diff --git a/osaca/parser/parser_x86att.py b/osaca/parser/parser_x86att.py index e9bcc7c..08fc8ba 100755 --- a/osaca/parser/parser_x86att.py +++ b/osaca/parser/parser_x86att.py @@ -2,8 +2,8 @@ import pyparsing as pp -from .base_parser import BaseParser from .attr_dict import AttrDict +from .base_parser import BaseParser class ParserX86ATT(BaseParser): @@ -103,19 +103,21 @@ class ParserX86ATT(BaseParser): :param int line_id: default None, identifier of instruction form :return: parsed instruction form """ - instruction_form = AttrDict({ - 
self.INSTRUCTION_ID: None, - self.OPERANDS_ID: None, - self.DIRECTIVE_ID: None, - self.COMMENT_ID: None, - self.LABEL_ID: None, - 'line_number': line_number, - }) + instruction_form = AttrDict( + { + self.INSTRUCTION_ID: None, + self.OPERANDS_ID: None, + self.DIRECTIVE_ID: None, + self.COMMENT_ID: None, + self.LABEL_ID: None, + 'line_number': line_number, + } + ) result = None # 1. Parse comment try: - result = self._process_operand(self.comment.parseString(line, parseAll=True).asDict()) + result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict()) result = AttrDict.convert_dict(result) instruction_form[self.COMMENT_ID] = ' '.join(result[self.COMMENT_ID]) except pp.ParseException: @@ -124,7 +126,7 @@ class ParserX86ATT(BaseParser): # 2. Parse label if result is None: try: - result = self._process_operand( + result = self.process_operand( self.label.parseString(line, parseAll=True).asDict() ) result = AttrDict.convert_dict(result) @@ -139,14 +141,16 @@ class ParserX86ATT(BaseParser): # 3. 
Parse directive if result is None: try: - result = self._process_operand( + result = self.process_operand( self.directive.parseString(line, parseAll=True).asDict() ) result = AttrDict.convert_dict(result) - instruction_form[self.DIRECTIVE_ID] = AttrDict({ - 'name': result[self.DIRECTIVE_ID]['name'], - 'parameters': result[self.DIRECTIVE_ID]['parameters'], - }) + instruction_form[self.DIRECTIVE_ID] = AttrDict( + { + 'name': result[self.DIRECTIVE_ID]['name'], + 'parameters': result[self.DIRECTIVE_ID]['parameters'], + } + ) if self.COMMENT_ID in result[self.DIRECTIVE_ID]: instruction_form[self.COMMENT_ID] = ' '.join( result[self.DIRECTIVE_ID][self.COMMENT_ID] @@ -177,29 +181,31 @@ class ParserX86ATT(BaseParser): # Check from right to left # Check third operand if 'operand3' in result: - operands['destination'].append(self._process_operand(result['operand3'])) + operands['destination'].append(self.process_operand(result['operand3'])) # Check second operand if 'operand2' in result: if len(operands['destination']) != 0: - operands['source'].insert(0, self._process_operand(result['operand2'])) + operands['source'].insert(0, self.process_operand(result['operand2'])) else: - operands['destination'].append(self._process_operand(result['operand2'])) + operands['destination'].append(self.process_operand(result['operand2'])) # Check first operand if 'operand1' in result: if len(operands['destination']) != 0: - operands['source'].insert(0, self._process_operand(result['operand1'])) + operands['source'].insert(0, self.process_operand(result['operand1'])) else: - operands['destination'].append(self._process_operand(result['operand1'])) - return_dict = AttrDict({ - self.INSTRUCTION_ID: result['mnemonic'], - self.OPERANDS_ID: operands, - self.COMMENT_ID: ' '.join(result[self.COMMENT_ID]) - if self.COMMENT_ID in result - else None, - }) + operands['destination'].append(self.process_operand(result['operand1'])) + return_dict = AttrDict( + { + self.INSTRUCTION_ID: result['mnemonic'], + 
self.OPERANDS_ID: operands, + self.COMMENT_ID: ' '.join(result[self.COMMENT_ID]) + if self.COMMENT_ID in result + else None, + } + ) return return_dict - def _process_operand(self, operand): + def process_operand(self, operand): # For the moment, only used to structure memory addresses if self.MEMORY_ID in operand: return self.substitute_memory_address(operand[self.MEMORY_ID]) @@ -242,3 +248,56 @@ class ParserX86ATT(BaseParser): return int(imd['value'], 10) # identifier return imd + + def is_reg_dependend_of(self, reg_a, reg_b): + # Check if they are the same registers + if reg_a.name == reg_b.name: + return True + # Check vector registers first + if self.is_vector_register(reg_a): + if self.is_vector_register(reg_b): + if reg_a.name[1:] == reg_b.name[1:]: + # Registers in the same vector space + return True + return False + # Check basic GPRs + a_dep = ['RAX', 'EAX', 'AX', 'AH', 'AL'] + b_dep = ['RBX', 'EBX', 'BX', 'BH', 'BL'] + c_dep = ['RCX', 'ECX', 'CX', 'CH', 'CL'] + d_dep = ['RDX', 'EDX', 'DX', 'DH', 'DL'] + sp_dep = ['RSP', 'ESP', 'SP', 'SPL'] + src_dep = ['RSI', 'ESI', 'SI', 'SIL'] + dst_dep = ['RDI', 'EDI', 'DI', 'DIL'] + basic_gprs = [a_dep, b_dep, c_dep, d_dep, sp_dep, src_dep, dst_dep] + if self.is_basic_gpr(reg_a): + if self.is_basic_gpr(reg_b): + for dep_group in basic_gprs: + if reg_a['name'].upper() in dep_group: + if reg_b['name'].upper() in dep_group: + return True + return False + # Check other GPRs + gpr_parser = ( + pp.CaselessLiteral('R') + + pp.Word(pp.nums).setResultsName('id') + + pp.Optional(pp.Word('dwbDWB', exact=1)) + ) + try: + id_a = gpr_parser.parseString(reg_a['name'], parseAll=True).asDict()['id'] + id_b = gpr_parser.parseString(reg_b['name'], parseAll=True).asDict()['id'] + if id_a == id_b: + return True + except pp.ParseException: + return False + # No dependencies + return False + + def is_basic_gpr(self, register): + if any(char.isdigit() for char in register['name']): + return False + return True + + def 
is_vector_register(self, register): + if len(register['name']) > 2 and register['name'][1:3] == 'mm': + return True + return False diff --git a/tests/test_parser_AArch64v81.py b/tests/test_parser_AArch64v81.py index e47d18b..797cc87 100755 --- a/tests/test_parser_AArch64v81.py +++ b/tests/test_parser_AArch64v81.py @@ -283,28 +283,90 @@ class TestParserAArch64v81(unittest.TestCase): self.assertEqual(parsed[0].line_number, 1) self.assertEqual(len(parsed), 645) + def test_normalize_imd(self): + imd_decimal_1 = {'value': '79'} + imd_hex_1 = {'value': '0x4f'} + imd_decimal_2 = {'value': '8'} + imd_hex_2 = {'value': '0x8'} + imd_float_11 = {'float': {'mantissa': '0.79', 'e_sign': '+', 'exponent': '2'}} + imd_float_12 = {'float': {'mantissa': '790.0', 'e_sign': '-', 'exponent': '1'}} + imd_double_11 = {'double': {'mantissa': '0.79', 'e_sign': '+', 'exponent': '2'}} + imd_double_12 = {'double': {'mantissa': '790.0', 'e_sign': '-', 'exponent': '1'}} + + value1 = self.parser.normalize_imd(imd_decimal_1) + self.assertEqual(value1, self.parser.normalize_imd(imd_hex_1)) + self.assertEqual( + self.parser.normalize_imd(imd_decimal_2), self.parser.normalize_imd(imd_hex_2) + ) + self.assertEqual(self.parser.normalize_imd(imd_float_11), value1) + self.assertEqual(self.parser.normalize_imd(imd_float_12), value1) + self.assertEqual(self.parser.normalize_imd(imd_double_11), value1) + self.assertEqual(self.parser.normalize_imd(imd_double_12), value1) + + def test_reg_dependency(self): + reg_1_1 = AttrDict({'prefix': 'b', 'name': '1'}) + reg_1_2 = AttrDict({'prefix': 'h', 'name': '1'}) + reg_1_3 = AttrDict({'prefix': 's', 'name': '1'}) + reg_1_4 = AttrDict({'prefix': 'd', 'name': '1'}) + reg_1_4 = AttrDict({'prefix': 'q', 'name': '1'}) + reg_2_1 = AttrDict({'prefix': 'w', 'name': '2'}) + reg_2_2 = AttrDict({'prefix': 'x', 'name': '2'}) + reg_v1_1 = AttrDict({'prefix': 'v', 'name': '11', 'lanes': '16', 'shape': 'b'}) + reg_v1_2 = AttrDict({'prefix': 'v', 'name': '11', 'lanes': '8', 
'shape': 'h'}) + reg_v1_3 = AttrDict({'prefix': 'v', 'name': '11', 'lanes': '4', 'shape': 's'}) + reg_v1_4 = AttrDict({'prefix': 'v', 'name': '11', 'lanes': '2', 'shape': 'd'}) + + reg_b5 = AttrDict({'prefix': 'b', 'name': '5'}) + reg_q15 = AttrDict({'prefix': 'q', 'name': '15'}) + reg_v10 = AttrDict({'prefix': 'v', 'name': '10', 'lanes': '2', 'shape': 's'}) + reg_v20 = AttrDict({'prefix': 'v', 'name': '20', 'lanes': '2', 'shape': 'd'}) + + reg_1 = [reg_1_1, reg_1_2, reg_1_3, reg_1_4] + reg_2 = [reg_2_1, reg_2_2] + reg_v = [reg_v1_1, reg_v1_2, reg_v1_3, reg_v1_4] + reg_others = [reg_b5, reg_q15, reg_v10, reg_v20] + regs = reg_1 + reg_2 + reg_v + reg_others + + # test each register against each other + for ri in reg_1: + for rj in regs: + assert_value = True if rj in reg_1 else False + with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): + self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + for ri in reg_2: + for rj in regs: + assert_value = True if rj in reg_2 else False + with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): + self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + for ri in reg_v: + for rj in regs: + assert_value = True if rj in reg_v else False + with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): + self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + for ri in reg_others: + for rj in regs: + assert_value = True if rj == ri else False + with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): + self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + ################## # Helper functions ################## def _get_comment(self, parser, comment): return ' '.join( AttrDict.convert_dict( - parser._process_operand( - parser.comment.parseString(comment, parseAll=True).asDict() - ) + parser.process_operand(parser.comment.parseString(comment, parseAll=True).asDict()) ).comment ) def _get_label(self, parser, label): return 
AttrDict.convert_dict( - parser._process_operand(parser.label.parseString(label, parseAll=True).asDict()) + parser.process_operand(parser.label.parseString(label, parseAll=True).asDict()) ).label def _get_directive(self, parser, directive): return AttrDict.convert_dict( - parser._process_operand( - parser.directive.parseString(directive, parseAll=True).asDict() - ) + parser.process_operand(parser.directive.parseString(directive, parseAll=True).asDict()) ).directive @staticmethod diff --git a/tests/test_parser_x86att.py b/tests/test_parser_x86att.py index a6d5e9c..9927956 100755 --- a/tests/test_parser_x86att.py +++ b/tests/test_parser_x86att.py @@ -186,28 +186,82 @@ class TestParserX86ATT(unittest.TestCase): self.assertEqual(parsed[0].line_number, 1) self.assertEqual(len(parsed), 353) + def test_normalize_imd(self): + imd_decimal_1 = {'value': '79'} + imd_hex_1 = {'value': '0x4f'} + imd_decimal_2 = {'value': '8'} + imd_hex_2 = {'value': '0x8'} + self.assertEqual( + self.parser.normalize_imd(imd_decimal_1), self.parser.normalize_imd(imd_hex_1) + ) + self.assertEqual( + self.parser.normalize_imd(imd_decimal_2), self.parser.normalize_imd(imd_hex_2) + ) + + def test_reg_dependency(self): + reg_a1 = AttrDict({'name': 'rax'}) + reg_a2 = AttrDict({'name': 'eax'}) + reg_a3 = AttrDict({'name': 'ax'}) + reg_a4 = AttrDict({'name': 'al'}) + reg_r11 = AttrDict({'name': 'r11'}) + reg_r11b = AttrDict({'name': 'r11b'}) + reg_r11d = AttrDict({'name': 'r11d'}) + reg_r11w = AttrDict({'name': 'r11w'}) + reg_xmm1 = AttrDict({'name': 'xmm1'}) + reg_ymm1 = AttrDict({'name': 'ymm1'}) + reg_zmm1 = AttrDict({'name': 'zmm1'}) + + reg_b1 = AttrDict({'name': 'rbx'}) + reg_r15 = AttrDict({'name': 'r15'}) + reg_xmm2 = AttrDict({'name': 'xmm2'}) + reg_ymm3 = AttrDict({'name': 'ymm3'}) + + reg_a = [reg_a1, reg_a2, reg_a3, reg_a4] + reg_r = [reg_r11, reg_r11b, reg_r11d, reg_r11w] + reg_vec_1 = [reg_xmm1, reg_ymm1, reg_zmm1] + reg_others = [reg_b1, reg_r15, reg_xmm2, reg_ymm3] + regs = reg_a + reg_r 
+ reg_vec_1 + reg_others + + # test each register against each other + for ri in reg_a: + for rj in regs: + assert_value = True if rj in reg_a else False + with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): + self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + for ri in reg_r: + for rj in regs: + assert_value = True if rj in reg_r else False + with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): + self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + for ri in reg_vec_1: + for rj in regs: + assert_value = True if rj in reg_vec_1 else False + with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): + self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + for ri in reg_others: + for rj in regs: + assert_value = True if rj == ri else False + with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): + self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + ################## # Helper functions ################## def _get_comment(self, parser, comment): return ' '.join( AttrDict.convert_dict( - parser._process_operand( - parser.comment.parseString(comment, parseAll=True).asDict() - ) + parser.process_operand(parser.comment.parseString(comment, parseAll=True).asDict()) ).comment ) def _get_label(self, parser, label): return AttrDict.convert_dict( - parser._process_operand(parser.label.parseString(label, parseAll=True).asDict()) + parser.process_operand(parser.label.parseString(label, parseAll=True).asDict()) ).label def _get_directive(self, parser, directive): return AttrDict.convert_dict( - parser._process_operand( - parser.directive.parseString(directive, parseAll=True).asDict() - ) + parser.process_operand(parser.directive.parseString(directive, parseAll=True).asDict()) ).directive @staticmethod