mirror of
https://github.com/RRZE-HPC/OSACA.git
synced 2026-01-07 03:30:06 +01:00
improved performance of arch_semantics and reg dependency matching
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import string
|
||||
import re
|
||||
|
||||
import pyparsing as pp
|
||||
|
||||
@@ -362,45 +363,44 @@ class ParserX86ATT(BaseParser):
|
||||
|
||||
def is_reg_dependend_of(self, reg_a, reg_b):
    """
    Check if ``reg_a`` is dependent on ``reg_b``.

    Two registers are treated as dependent when they carry the same name, when both are
    vector registers whose names agree after the first character, when both are basic
    GPRs of the same family (e.g. ``RAX``/``EAX``/``AX``/``AH``/``AL``), or when both are
    numbered GPRs (``R<n>`` with an optional ``D``/``W``/``B`` suffix) sharing the same
    number. All name comparisons are case-insensitive.

    :param reg_a: register operand; its name is read from the ``'name'`` key
    :param reg_b: register operand; its name is read from the ``'name'`` key
    :returns: `True` if the registers depend on each other, `False` otherwise
    """
    # Normalize both names once so that every comparison below is case-insensitive
    reg_a_name = reg_a['name'].upper()
    reg_b_name = reg_b['name'].upper()

    # Check if they are the same registers
    if reg_a_name == reg_b_name:
        return True

    # Check vector registers first
    if self.is_vector_register(reg_a):
        if self.is_vector_register(reg_b):
            if reg_a_name[1:] == reg_b_name[1:]:
                # Registers in the same vector space
                return True
        # A vector register never depends on a non-vector register
        return False

    # Check basic GPRs
    gpr_groups = {
        'A': ['RAX', 'EAX', 'AX', 'AH', 'AL'],
        'B': ['RBX', 'EBX', 'BX', 'BH', 'BL'],
        'C': ['RCX', 'ECX', 'CX', 'CH', 'CL'],
        'D': ['RDX', 'EDX', 'DX', 'DH', 'DL'],
        'SP': ['RSP', 'ESP', 'SP', 'SPL'],
        'SRC': ['RSI', 'ESI', 'SI', 'SIL'],
        'DST': ['RDI', 'EDI', 'DI', 'DIL'],
    }
    if self.is_basic_gpr(reg_a):
        if self.is_basic_gpr(reg_b):
            for dep_group in gpr_groups.values():
                # Both *normalized* names must belong to the same family
                if reg_a_name in dep_group and reg_b_name in dep_group:
                    return True
        return False

    # Check other (numbered) GPRs. fullmatch rejects trailing characters, matching the
    # whole-string parse (parseAll=True) that this regex replaced.
    numbered_gpr = r'R([0-9]+)[DWB]?'
    match_a = re.fullmatch(numbered_gpr, reg_a_name)
    match_b = re.fullmatch(numbered_gpr, reg_b_name)
    if match_a and match_b and match_a.group(1) == match_b.group(1):
        return True

    # No dependencies
    return False
|
||||
|
||||
@@ -414,19 +414,11 @@ class ParserX86ATT(BaseParser):
|
||||
"""Check if register is a general purpose register"""
|
||||
if register is None:
|
||||
return False
|
||||
gpr_parser = (
|
||||
pp.CaselessLiteral('R')
|
||||
+ pp.Word(pp.nums).setResultsName('id')
|
||||
+ pp.Optional(pp.Word('dwbDWB', exact=1))
|
||||
)
|
||||
|
||||
if self.is_basic_gpr(register):
|
||||
return True
|
||||
else:
|
||||
try:
|
||||
gpr_parser.parseString(register['name'], parseAll=True)
|
||||
return True
|
||||
except pp.ParseException:
|
||||
return False
|
||||
|
||||
return re.match(r'R([0-9]+)[DWB]?', register['name'], re.IGNORECASE)
|
||||
|
||||
def is_vector_register(self, register):
|
||||
"""Check if register is a vector register"""
|
||||
|
||||
@@ -398,9 +398,7 @@ class ArchSemantics(ISASemantics):
|
||||
|
||||
def g(obj, value):
|
||||
obj[item] = value
|
||||
|
||||
else:
|
||||
|
||||
def g(obj, *values):
|
||||
for item, value in zip(items, values):
|
||||
obj[item] = value
|
||||
@@ -416,7 +414,9 @@ class ArchSemantics(ISASemantics):
|
||||
@staticmethod
|
||||
def get_throughput_sum(kernel):
|
||||
"""Get the overall throughput sum separated by port of all instructions of a kernel."""
|
||||
port_pressures = [instr['port_pressure'] for instr in kernel]
|
||||
# ignoring all lines with throughput == 0.0, because there won't be anything to sum up
|
||||
# typically comment, label and non-instruction lines
|
||||
port_pressures = [instr['port_pressure'] for instr in kernel if instr['throughput'] != 0.0]
|
||||
# Essentially summing up each column of port_pressures, where each column is one port
|
||||
# and each row is one line of the kernel
|
||||
# round is necessary to ensure termination of ArchSemantics.assign_optimal_throughput
|
||||
|
||||
@@ -225,7 +225,7 @@ class MachineModel(object):
|
||||
for y in list(filter(lambda x: True if x != 'class' else False, op))
|
||||
]
|
||||
operands.append('{}({})'.format(op['class'], ','.join(op_attrs)))
|
||||
return '{} {}'.format(instruction_form['name'], ','.join(operands))
|
||||
return '{} {}'.format(instruction_form['name'].lower(), ','.join(operands))
|
||||
|
||||
@staticmethod
|
||||
def get_isa_for_arch(arch):
|
||||
@@ -285,7 +285,8 @@ class MachineModel(object):
|
||||
{
|
||||
k: v
|
||||
for k, v in self._data.items()
|
||||
if k not in ['instruction_forms', 'load_throughput', 'internal_version']
|
||||
if k not in ['instruction_forms', 'instruction_forms_dict', 'load_throughput',
|
||||
'internal_version']
|
||||
},
|
||||
stream,
|
||||
)
|
||||
|
||||
@@ -83,28 +83,21 @@ class TestSemanticTools(unittest.TestCase):
|
||||
self.assertIsNone(test_mm_x86.get_instruction(None, []))
|
||||
self.assertIsNone(test_mm_arm.get_instruction(None, []))
|
||||
|
||||
# test dict DB creation
|
||||
test_mm_x86._data['instruction_dict'] = test_mm_x86._convert_to_dict(
|
||||
test_mm_x86._data['instruction_forms']
|
||||
)
|
||||
test_mm_arm._data['instruction_dict'] = test_mm_arm._convert_to_dict(
|
||||
test_mm_arm._data['instruction_forms']
|
||||
)
|
||||
# test get_instruction from dict DB
|
||||
self.assertIsNone(test_mm_x86.get_instruction_from_dict(None, []))
|
||||
self.assertIsNone(test_mm_arm.get_instruction_from_dict(None, []))
|
||||
self.assertIsNone(test_mm_x86.get_instruction_from_dict('NOT_IN_DB', []))
|
||||
self.assertIsNone(test_mm_arm.get_instruction_from_dict('NOT_IN_DB', []))
|
||||
# test get_instruction from DB
|
||||
self.assertIsNone(test_mm_x86.get_instruction(None, []))
|
||||
self.assertIsNone(test_mm_arm.get_instruction(None, []))
|
||||
self.assertIsNone(test_mm_x86.get_instruction('NOT_IN_DB', []))
|
||||
self.assertIsNone(test_mm_arm.get_instruction('NOT_IN_DB', []))
|
||||
name_x86_1 = 'vaddpd'
|
||||
operands_x86_1 = [
|
||||
{'class': 'register', 'name': 'xmm'},
|
||||
{'class': 'register', 'name': 'xmm'},
|
||||
{'class': 'register', 'name': 'xmm'},
|
||||
]
|
||||
instr_form_x86_1 = test_mm_x86.get_instruction_from_dict(name_x86_1, operands_x86_1)
|
||||
instr_form_x86_1 = test_mm_x86.get_instruction(name_x86_1, operands_x86_1)
|
||||
self.assertEqual(instr_form_x86_1, test_mm_x86.get_instruction(name_x86_1, operands_x86_1))
|
||||
self.assertEqual(
|
||||
test_mm_x86.get_instruction_from_dict('jg', [{'class': 'identifier'}]),
|
||||
test_mm_x86.get_instruction('jg', [{'class': 'identifier'}]),
|
||||
test_mm_x86.get_instruction('jg', [{'class': 'identifier'}]),
|
||||
)
|
||||
name_arm_1 = 'fadd'
|
||||
@@ -113,10 +106,10 @@ class TestSemanticTools(unittest.TestCase):
|
||||
{'class': 'register', 'prefix': 'v', 'shape': 's'},
|
||||
{'class': 'register', 'prefix': 'v', 'shape': 's'},
|
||||
]
|
||||
instr_form_arm_1 = test_mm_arm.get_instruction_from_dict(name_arm_1, operands_arm_1)
|
||||
instr_form_arm_1 = test_mm_arm.get_instruction(name_arm_1, operands_arm_1)
|
||||
self.assertEqual(instr_form_arm_1, test_mm_arm.get_instruction(name_arm_1, operands_arm_1))
|
||||
self.assertEqual(
|
||||
test_mm_arm.get_instruction_from_dict('b.ne', [{'class': 'identifier'}]),
|
||||
test_mm_arm.get_instruction('b.ne', [{'class': 'identifier'}]),
|
||||
test_mm_arm.get_instruction('b.ne', [{'class': 'identifier'}]),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user