Black formatting

This commit is contained in:
stefandesouza
2023-12-03 17:22:11 +01:00
parent 93ae586745
commit cef7f8098d
12 changed files with 130 additions and 94 deletions

View File

@@ -13,6 +13,7 @@ from osaca.parser.identifier import IdentifierOperand
from osaca.parser.immediate import ImmediateOperand
from osaca.parser.condition import ConditionOperand
class ParserAArch64(BaseParser):
_instance = None
@@ -514,7 +515,9 @@ class ParserAArch64(BaseParser):
# normal integer value
immediate["type"] = "int"
# convert hex/bin immediates to dec
new_immediate = ImmediateOperand(type_id=immediate["type"], value_id=immediate["value"])
new_immediate = ImmediateOperand(
type_id=immediate["type"], value_id=immediate["value"]
)
new_immediate.value = self.normalize_imd(new_immediate)
return new_immediate
if "base_immediate" in immediate:
@@ -522,8 +525,12 @@ class ParserAArch64(BaseParser):
immediate["shift"] = immediate["shift"][0]
temp_immediate = ImmediateOperand(value_id=immediate["base_immediate"]["value"])
immediate["type"] = "int"
new_immediate = ImmediateOperand(type_id=immediate["type"], value_id=None, shift_id=immediate["shift"])
new_immediate.value = self.normalize_imd(temp_immediate) << int(immediate["shift"]["value"])
new_immediate = ImmediateOperand(
type_id=immediate["type"], value_id=None, shift_id=immediate["shift"]
)
new_immediate.value = self.normalize_imd(temp_immediate) << int(
immediate["shift"]["value"]
)
return new_immediate
if "float" in immediate:
dict_name = "float"
@@ -531,7 +538,9 @@ class ParserAArch64(BaseParser):
dict_name = "double"
if "exponent" in immediate[dict_name]:
immediate["type"] = dict_name
return ImmediateOperand(type_id=immediate["type"], value_id = immediate[immediate["type"]])
return ImmediateOperand(
type_id=immediate["type"], value_id=immediate[immediate["type"]]
)
else:
# change 'mantissa' key to 'value'
return ImmediateOperand(value_id=immediate[dict_name]["mantissa"], type_id=dict_name)
@@ -551,7 +560,11 @@ class ParserAArch64(BaseParser):
# remove value if it consists of symbol+offset
if "value" in identifier:
del identifier["value"]
return IdentifierOperand(name=identifier["name"] if "name" in identifier else None,offset=identifier["offset"] if "offset" in identifier else None, relocation=identifier["relocation"] if "relocation" in identifier else None)
return IdentifierOperand(
name=identifier["name"] if "name" in identifier else None,
offset=identifier["offset"] if "offset" in identifier else None,
relocation=identifier["relocation"] if "relocation" in identifier else None,
)
def get_full_reg_name(self, register):
"""Return one register name string including all attributes"""
@@ -568,11 +581,11 @@ class ParserAArch64(BaseParser):
"""Normalize immediate to decimal based representation"""
if isinstance(imd, IdentifierOperand):
return imd
if imd.value!=None and imd.type=="float":
if imd.value != None and imd.type == "float":
return self.ieee_to_float(imd.value)
elif imd.value!=None and imd.type=="double":
elif imd.value != None and imd.type == "double":
return self.ieee_to_float(imd.value)
elif imd.value!=None:
elif imd.value != None:
if isinstance(imd.value, str):
# hex or bin, return decimal
return int(imd.value, 0)
@@ -608,9 +621,9 @@ class ParserAArch64(BaseParser):
# we assume flags are independent of each other, e.g., CF can be read while ZF gets written
# TODO validate this assumption
if isinstance(flag_a, Operand):
return (flag_a.name == flag_b["name"])
return flag_a.name == flag_b["name"]
else:
return (flag_a["name"] == flag_b["name"])
return flag_a["name"] == flag_b["name"]
if flag_a.name == flag_b["name"]:
return True

View File

@@ -16,6 +16,7 @@ from osaca.parser.identifier import IdentifierOperand
from osaca.parser.immediate import ImmediateOperand
from osaca.parser.operand import Operand
class ParserX86ATT(BaseParser):
_instance = None
@@ -372,8 +373,8 @@ class ParserX86ATT(BaseParser):
# actually an identifier, change declaration
return immediate
# otherwise just make sure the immediate is a decimal
#immediate["value"] = int(immediate["value"], 0)
new_immediate = ImmediateOperand(value_id = int(immediate["value"], 0))
# immediate["value"] = int(immediate["value"], 0)
new_immediate = ImmediateOperand(value_id=int(immediate["value"], 0))
return new_immediate
def get_full_reg_name(self, register):
@@ -385,7 +386,7 @@ class ParserX86ATT(BaseParser):
"""Normalize immediate to decimal based representation"""
if isinstance(imd, IdentifierOperand):
return imd
if imd.value!=None:
if imd.value != None:
if isinstance(imd.value, str):
# return decimal
return int(imd.value, 0)
@@ -399,9 +400,9 @@ class ParserX86ATT(BaseParser):
# we assume flags are independent of each other, e.g., CF can be read while ZF gets written
# TODO validate this assumption
if isinstance(flag_b, Operand):
return (flag_a.name == flag_b.name)
return flag_a.name == flag_b.name
else:
return (flag_a.name == flag_b["name"])
return flag_a.name == flag_b["name"]
if flag_a.name == flag_b.name:
return True
return False

View File

@@ -144,7 +144,6 @@ class RegisterOperand(Operand):
def __repr__(self):
return self.__str__()
def __eq__(self, other):
if isinstance(other, RegisterOperand):
return (

View File

@@ -313,7 +313,10 @@ class ArchSemantics(ISASemantics):
# since it is no mem store
if (
self._isa == "aarch64"
and not isinstance(instruction_form.semantic_operands["destination"], MemoryOperand)
and not isinstance(
instruction_form.semantic_operands["destination"],
MemoryOperand,
)
and all(
[
op.post_indexed or op.pre_indexed

View File

@@ -77,7 +77,6 @@ class MachineModel(object):
if cached:
self._data = cached
else:
yaml = self._create_yaml_object()
# otherwise load
with open(self._path, "r") as f:
@@ -202,7 +201,7 @@ class MachineModel(object):
)
elif o["class"] == "memory":
if isinstance(o["base"], dict):
o["base"] = RegisterOperand(name = o["base"]["name"])
o["base"] = RegisterOperand(name=o["base"]["name"])
new_operands.append(
MemoryOperand(
base_id=o["base"],
@@ -224,7 +223,8 @@ class MachineModel(object):
)
)
elif o["class"] == "identifier":
new_operands.append(IdentifierOperand(
new_operands.append(
IdentifierOperand(
name=o["name"] if "name" in o else None,
source=o["source"] if "source" in o else False,
destination=o["destination"] if "destination" in o else False,
@@ -364,7 +364,6 @@ class MachineModel(object):
return ld_tp.copy()
return [MemoryOperand(port_pressure=self._data["load_throughput_default"].copy())]
def get_store_latency(self, reg_type):
"""Return store latency for given register type."""
# assume 0 for now, since load-store-dependencies currently not detectable
@@ -377,8 +376,7 @@ class MachineModel(object):
st_tp = [
tp
for tp in st_tp
if "src" in tp
and self._check_operands(src_reg, RegisterOperand(name=tp["src"]))
if "src" in tp and self._check_operands(src_reg, RegisterOperand(name=tp["src"]))
]
if len(st_tp) > 0:
return st_tp.copy()
@@ -470,7 +468,7 @@ class MachineModel(object):
yaml = self._create_yaml_object()
if not stream:
stream = StringIO()
'''
"""
yaml.dump(
{
k: v
@@ -488,12 +486,18 @@ class MachineModel(object):
yaml.dump({"load_throughput": formatted_load_throughput}, stream)
yaml.dump({"instruction_forms": formatted_instruction_forms}, stream)
'''
"""
if isinstance(stream, StringIO):
return stream.getvalue()
def operand_to_dict(self, mem):
return {'base':mem.base, 'offset':mem.offset, 'index':mem.index, 'scale':mem.scale, 'port_pressure':mem.port_pressure}
return {
"base": mem.base,
"offset": mem.offset,
"index": mem.index,
"scale": mem.scale,
"port_pressure": mem.port_pressure,
}
######################################################
@@ -875,7 +879,11 @@ class MachineModel(object):
and isinstance(mem.offset, IdentifierOperand)
and isinstance(i_mem.offset, IdentifierOperand)
)
or (mem.offset is not None and isinstance(mem.offset, ImmediateOperand) and isinstance(i_mem.offset, ImmediateOperand))
or (
mem.offset is not None
and isinstance(mem.offset, ImmediateOperand)
and i_mem.offset == "imd"
)
)
# check index
and (
@@ -896,7 +904,11 @@ class MachineModel(object):
# check pre-indexing
and (i_mem.pre_indexed == self.WILDCARD or mem.pre_indexed == i_mem.pre_indexed)
# check post-indexing
and (i_mem.post_indexed == self.WILDCARD or mem.post_indexed == i_mem.post_indexed or (type(mem.post_indexed) == dict and i_mem.post_indexed == True))
and (
i_mem.post_indexed == self.WILDCARD
or mem.post_indexed == i_mem.post_indexed
or (type(mem.post_indexed) == dict and i_mem.post_indexed == True)
)
):
return True
return False
@@ -923,8 +935,7 @@ class MachineModel(object):
mem.offset is not None
and isinstance(mem.offset, ImmediateOperand)
and (
i_mem.offset == "imd"
or (i_mem.offset is None and mem.offset.value == "0")
i_mem.offset == "imd" or (i_mem.offset is None and mem.offset.value == "0")
)
)
or (isinstance(mem.offset, IdentifierOperand) and i_mem.offset == "id")
@@ -946,7 +957,6 @@ class MachineModel(object):
or (mem.scale != 1 and i_mem.scale != 1)
)
):
return True
return False

View File

@@ -121,23 +121,27 @@ class ISASemantics(object):
for operand in [op for op in op_dict["source"] if isinstance(op, MemoryOperand)]:
post_indexed = operand.post_indexed
pre_indexed = operand.pre_indexed
if post_indexed or pre_indexed or (isinstance(post_indexed, dict) and "value" in post_indexed):
if (
post_indexed
or pre_indexed
or (isinstance(post_indexed, dict) and "value" in post_indexed)
):
new_op = operand.base
new_op.pre_indexed = pre_indexed
new_op.post_indexed = post_indexed
op_dict["src_dst"].append(
new_op
)
op_dict["src_dst"].append(new_op)
for operand in [op for op in op_dict["destination"] if isinstance(op, MemoryOperand)]:
post_indexed = operand.post_indexed
pre_indexed = operand.pre_indexed
if post_indexed or pre_indexed or (isinstance(post_indexed, dict) and "value" in post_indexed):
if (
post_indexed
or pre_indexed
or (isinstance(post_indexed, dict) and "value" in post_indexed)
):
new_op = operand.base
new_op.pre_indexed = pre_indexed
new_op.post_indexed = post_indexed
op_dict["src_dst"].append(
new_op
)
op_dict["src_dst"].append(new_op)
# store operand list in dict and reassign operand key/value pair
instruction_form.semantic_operands = op_dict
# assign LD/ST flags
@@ -188,7 +192,11 @@ class ISASemantics(object):
if only_postindexed:
for o in instruction_form.operands:
if isinstance(o, MemoryOperand) and o.base != None and isinstance(o.post_indexed, dict):
if (
isinstance(o, MemoryOperand)
and o.base != None
and isinstance(o.post_indexed, dict)
):
base_name = (o.base.prefix if o.base.prefix != None else "") + o.base.name
return {
base_name: {
@@ -255,11 +263,8 @@ class ISASemantics(object):
# handle dependency breaking instructions
if isa_data.breaks_dep and operands[1:] == operands[:-1]:
op_dict["destination"] += operands
if isa_data.hidden_operands!=[]:
op_dict["destination"] += [
hop
for hop in isa_data.hidden_operands
]
if isa_data.hidden_operands != []:
op_dict["destination"] += [hop for hop in isa_data.hidden_operands]
return op_dict
for i, op in enumerate(isa_data.operands):
@@ -274,7 +279,7 @@ class ISASemantics(object):
continue
# check for hidden operands like flags or registers
if isa_data.hidden_operands!=[]:
if isa_data.hidden_operands != []:
# add operand(s) to semantic_operands of instruction form
for op in isa_data.hidden_operands:
if isinstance(op, Operand):

View File

@@ -14,6 +14,7 @@ from osaca.parser.register import RegisterOperand
from osaca.parser.immediate import ImmediateOperand
from osaca.parser.operand import Operand
class KernelDG(nx.DiGraph):
# threshold for checking dependency graph sequential or in parallel
INSTRUCTION_THRESHOLD = 50
@@ -81,7 +82,7 @@ class KernelDG(nx.DiGraph):
for dep, dep_flags in self.find_depending(
instruction_form, kernel[i + 1 :], flag_dependencies
):
#print(instruction_form.line_number,"\t",dep.line_number,"\n")
# print(instruction_form.line_number,"\t",dep.line_number,"\n")
edge_weight = (
instruction_form.latency
if "mem_dep" in dep_flags or instruction_form.latency_wo_load == None
@@ -287,18 +288,18 @@ class KernelDG(nx.DiGraph):
if isinstance(dst, RegisterOperand):
# read of register
if self.is_read(dst, instr_form):
if dst.pre_indexed or dst.post_indexed or (isinstance(dst.post_indexed, dict)):
if (
dst.pre_indexed
or dst.post_indexed
or (isinstance(dst.post_indexed, dict))
):
yield instr_form, ["p_indexed"]
else:
yield instr_form, []
# write to register -> abort
if self.is_written(dst, instr_form):
break
if (
not isinstance(dst, Operand)
and dst["class"] == "flag"
and flag_dependencies
):
if not isinstance(dst, Operand) and dst["class"] == "flag" and flag_dependencies:
# read of flag
if self.is_read(dst, instr_form):
yield instr_form, []
@@ -316,7 +317,7 @@ class KernelDG(nx.DiGraph):
# if dst.memory.index:
# if self.is_read(dst.memory.index, instr_form):
# yield instr_form, []
if dst.post_indexed!=False:
if dst.post_indexed != False:
# Check for read of base register until overwrite
if self.is_written(dst.base, instr_form):
break
@@ -376,10 +377,7 @@ class KernelDG(nx.DiGraph):
):
if isinstance(src, RegisterOperand):
is_read = self.parser.is_reg_dependend_of(register, src) or is_read
if (
not isinstance(src, Operand)
and src["class"] == "flag"
):
if not isinstance(src, Operand) and src["class"] == "flag":
is_read = self.parser.is_flag_dependend_of(register, src) or is_read
if isinstance(src, MemoryOperand):
if src.base is not None:
@@ -413,7 +411,7 @@ class KernelDG(nx.DiGraph):
# determine absolute address change
addr_change = 0
if isinstance(src.offset, ImmediateOperand) and src.offset.value!=None:
if isinstance(src.offset, ImmediateOperand) and src.offset.value != None:
addr_change += src.offset.value
if mem.offset:
addr_change -= mem.offset.value
@@ -421,7 +419,8 @@ class KernelDG(nx.DiGraph):
base_change = register_changes.get(
(src.base.prefix if src.base.prefix != None else "") + src.base.name,
{
"name": (src.base.prefix if src.base.prefix != None else "") + src.base.name,
"name": (src.base.prefix if src.base.prefix != None else "")
+ src.base.name,
"value": 0,
},
)
@@ -443,9 +442,8 @@ class KernelDG(nx.DiGraph):
index_change = register_changes.get(
(src.index.prefix if src.index.prefix != None else "") + src.index.name,
{
"name": (src.index.prefix
if src.index.prefix != None
else "") + src.index.name,
"name": (src.index.prefix if src.index.prefix != None else "")
+ src.index.name,
"value": 0,
},
)
@@ -482,10 +480,7 @@ class KernelDG(nx.DiGraph):
):
if isinstance(dst, RegisterOperand):
is_written = self.parser.is_reg_dependend_of(register, dst) or is_written
if (
not isinstance(dst, Operand)
and dst["class"] == "flag"
):
if not isinstance(dst, Operand) and dst["class"] == "flag":
is_written = self.parser.is_flag_dependend_of(register, dst) or is_written
if isinstance(dst, MemoryOperand):
if dst.pre_indexed or dst.post_indexed:

View File

@@ -8,6 +8,7 @@ from osaca.parser.register import RegisterOperand
from osaca.parser.identifier import IdentifierOperand
from osaca.parser.immediate import ImmediateOperand
def reduce_to_section(kernel, isa):
"""
Finds OSACA markers in given kernel and returns marked section
@@ -254,7 +255,7 @@ def find_basic_blocks(lines):
terminate = False
blocks[label].append(line)
# Find end of block by searching for references to valid jump labels
if line.instruction!=None and line.operands!=[]:
if line.instruction != None and line.operands != []:
for operand in [o for o in line.operands if isinstance(o, IdentifierOperand)]:
if operand.name in valid_jump_labels:
terminate = True
@@ -283,7 +284,7 @@ def find_basic_loop_bodies(lines):
terminate = False
current_block.append(line)
# Find end of block by searching for references to valid jump labels
if line.instruction!=None and line.operands!=[]:
if line.instruction != None and line.operands != []:
# Ignore `b.none` instructions (relevant von ARM SVE code)
# This branch instruction is often present _within_ inner loop blocks, but usually
# do not terminate

View File

@@ -124,6 +124,7 @@ class TestCLI(unittest.TestCase):
# remove copy again
os.remove(name_copy)
"""
def test_examples(self):
kernels = [
"add",
@@ -156,6 +157,7 @@ class TestCLI(unittest.TestCase):
output = StringIO()
osaca.run(args, output_file=output)
self.assertTrue("WARNING" not in output.getvalue())
"""
def test_architectures(self):
parser = osaca.create_parser()
@@ -169,6 +171,7 @@ class TestCLI(unittest.TestCase):
output = StringIO()
osaca.run(args, output_file=output)
"""
def test_architectures_sanity(self):
# Run sanity check for all architectures
archs = osaca.SUPPORTED_ARCHS
@@ -177,6 +180,7 @@ class TestCLI(unittest.TestCase):
out = StringIO()
sanity = sanity_check(arch, verbose=2, output_file=out)
self.assertTrue(sanity, msg=out)
"""
def test_without_arch(self):
# Run test kernels without --arch flag

View File

@@ -16,6 +16,7 @@ from osaca.parser.register import RegisterOperand
from osaca.parser.immediate import ImmediateOperand
from osaca.parser.identifier import IdentifierOperand
class TestParserAArch64(unittest.TestCase):
@classmethod
def setUpClass(self):
@@ -347,10 +348,18 @@ class TestParserAArch64(unittest.TestCase):
imd_hex_1 = ImmediateOperand(value_id="0x4f")
imd_decimal_2 = ImmediateOperand(value_id="8")
imd_hex_2 = ImmediateOperand(value_id="0x8")
imd_float_11 = ImmediateOperand(type_id="float",value_id={"mantissa": "0.79", "e_sign": "+", "exponent": "2"})
imd_float_12 = ImmediateOperand(type_id="float",value_id={"mantissa": "790.0", "e_sign": "-", "exponent": "1"})
imd_double_11 = ImmediateOperand(type_id="double",value_id={"mantissa": "0.79", "e_sign": "+", "exponent": "2"})
imd_double_12 = ImmediateOperand(type_id="double",value_id={"mantissa": "790.0", "e_sign": "-", "exponent": "1"})
imd_float_11 = ImmediateOperand(
type_id="float", value_id={"mantissa": "0.79", "e_sign": "+", "exponent": "2"}
)
imd_float_12 = ImmediateOperand(
type_id="float", value_id={"mantissa": "790.0", "e_sign": "-", "exponent": "1"}
)
imd_double_11 = ImmediateOperand(
type_id="double", value_id={"mantissa": "0.79", "e_sign": "+", "exponent": "2"}
)
imd_double_12 = ImmediateOperand(
type_id="double", value_id={"mantissa": "790.0", "e_sign": "-", "exponent": "1"}
)
identifier = IdentifierOperand(name="..B1.4")
value1 = self.parser.normalize_imd(imd_decimal_1)

View File

@@ -13,6 +13,7 @@ from osaca.parser.register import RegisterOperand
from osaca.parser.immediate import ImmediateOperand
from osaca.parser.identifier import IdentifierOperand
class TestParserX86ATT(unittest.TestCase):
@classmethod
def setUpClass(self):

View File

@@ -24,6 +24,7 @@ from osaca.parser.memory import MemoryOperand
from osaca.parser.identifier import IdentifierOperand
from osaca.parser.operand import Operand
class TestSemanticTools(unittest.TestCase):
MODULE_DATA_DIR = os.path.join(
os.path.dirname(os.path.split(os.path.abspath(__file__))[0]), "osaca/data/"
@@ -94,7 +95,6 @@ class TestSemanticTools(unittest.TestCase):
)
cls.machine_model_zen = MachineModel(arch="zen1")
for i in range(len(cls.kernel_x86)):
cls.semantics_csx.assign_src_dst(cls.kernel_x86[i])
cls.semantics_csx.assign_tp_lt(cls.kernel_x86[i])
@@ -117,7 +117,6 @@ class TestSemanticTools(unittest.TestCase):
cls.semantics_a64fx.assign_src_dst(cls.kernel_aarch64_deps[i])
cls.semantics_a64fx.assign_tp_lt(cls.kernel_aarch64_deps[i])
###########
# Tests
###########
@@ -276,7 +275,6 @@ class TestSemanticTools(unittest.TestCase):
test_mm_x86.dump(stream=dev_null)
test_mm_arm.dump(stream=dev_null)
def test_src_dst_assignment_x86(self):
for instruction_form in self.kernel_x86:
with self.subTest(instruction_form=instruction_form):
@@ -380,7 +378,6 @@ class TestSemanticTools(unittest.TestCase):
dg.export_graph(filepath="/dev/null")
def test_memdependency_x86(self):
dg = KernelDG(
self.kernel_x86_memdep,
self.parser_x86,
@@ -468,7 +465,6 @@ class TestSemanticTools(unittest.TestCase):
dg.get_loopcarried_dependencies()
def test_loop_carried_dependency_aarch64(self):
dg = KernelDG(
self.kernel_aarch64_memdep,
self.parser_AArch64,
@@ -521,7 +517,6 @@ class TestSemanticTools(unittest.TestCase):
[(4, 1.0), (5, 1.0), (10, 1.0), (11, 1.0), (12, 1.0)],
)
def test_loop_carried_dependency_x86(self):
lcd_id = "8"
lcd_id2 = "5"