diff --git a/.gitignore b/.gitignore index a996ec7..426ace6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,5 @@ # OSACA specific files and folders *.*.pickle -osaca_testfront_venv/ -examples/riscy_asm_files/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/docs/conf.py b/docs/conf.py index 27b1634..4a8d3a3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -61,6 +61,8 @@ html_theme = "sphinx_rtd_theme" # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = [] htmlhelp_basename = "osaca_doc" -html_sidebars = {"**": ["globaltoc.html", "relations.html", "sourcelink.html", "searchbox.html"]} +html_sidebars = { + "**": ["globaltoc.html", "relations.html", "sourcelink.html", "searchbox.html"] +} autodoc_member_order = "bysource" diff --git a/examples/add/add.s.rv6.gcc.s b/examples/add/add.s.rv64.gcc.s similarity index 100% rename from examples/add/add.s.rv6.gcc.s rename to examples/add/add.s.rv64.gcc.s diff --git a/examples/copy/copy.s.rv6.gcc.s b/examples/copy/copy.s.rv64.gcc.s similarity index 100% rename from examples/copy/copy.s.rv6.gcc.s rename to examples/copy/copy.s.rv64.gcc.s diff --git a/examples/daxpy/daxpy.s.rv6.gcc.s b/examples/daxpy/daxpy.s.rv64.gcc.s similarity index 100% rename from examples/daxpy/daxpy.s.rv6.gcc.s rename to examples/daxpy/daxpy.s.rv64.gcc.s diff --git a/examples/gs/gs.s.rv6.gcc.s b/examples/gs/gs.s.rv64.gcc.s similarity index 100% rename from examples/gs/gs.s.rv6.gcc.s rename to examples/gs/gs.s.rv64.gcc.s diff --git a/examples/j2d/j2d.s.rv6.gcc.s b/examples/j2d/j2d.s.rv64.gcc.s similarity index 100% rename from examples/j2d/j2d.s.rv6.gcc.s rename to examples/j2d/j2d.s.rv64.gcc.s diff --git a/examples/striad/striad.s.rv6.gcc.s b/examples/striad/striad.s.rv64.gcc.s similarity index 100% rename from examples/striad/striad.s.rv6.gcc.s rename to examples/striad/striad.s.rv64.gcc.s diff --git a/examples/sum_reduction/sum_reduction.s.rv6.gcc.s b/examples/sum_reduction/sum_reduction.s.rv64.gcc.s similarity index 100% rename from examples/sum_reduction/sum_reduction.s.rv6.gcc.s rename to examples/sum_reduction/sum_reduction.s.rv64.gcc.s diff --git a/examples/triad/triad.s.rv6.gcc.s b/examples/triad/triad.s.rv64.gcc.s similarity index 100% rename from examples/triad/triad.s.rv6.gcc.s rename to examples/triad/triad.s.rv64.gcc.s diff --git a/examples/update/update.s.rv6.gcc.s b/examples/update/update.s.rv64.gcc.s similarity index 100% rename from examples/update/update.s.rv6.gcc.s rename to examples/update/update.s.rv64.gcc.s diff --git a/osaca/__init__.py b/osaca/__init__.py index 6eb46ce..bcf050c 100644 --- a/osaca/__init__.py +++ b/osaca/__init__.py @@ -8,4 +8,5 @@ __version__ = "0.7.0" # 2. commit to RRZE-HPC/osaca's master branch # 3. wait for Github Actions to complete successful (unless already tested) # 4. tag commit with 'v{}'.format(__version__) (`git tag vX.Y.Z`) -# 5. push tag to github (`git push origin vX.Y.Z` or push all tags with `git push --tags`) +# 5. push tag to github (`git push origin vX.Y.Z` or push all tags with +# `git push --tags`) diff --git a/osaca/data/_build_cache.py b/osaca/data/_build_cache.py index 74803cb..0390f96 100755 --- a/osaca/data/_build_cache.py +++ b/osaca/data/_build_cache.py @@ -10,9 +10,9 @@ try: from osaca.semantics.hw_model import MachineModel except ModuleNotFoundError: print( - "Unable to import MachineModel, probably some dependency is not yet installed. SKIPPING. " - "First run of OSACA may take a while to build caches, subsequent runs will be as fast as " - "ever." + "Unable to import MachineModel, probably some dependency is not yet " + "installed. SKIPPING. First run of OSACA may take a while to build " + "caches, subsequent runs will be as fast as ever." ) sys.exit() diff --git a/osaca/data/create_db_entry.py b/osaca/data/create_db_entry.py index c7b61ae..83b7720 100644 --- a/osaca/data/create_db_entry.py +++ b/osaca/data/create_db_entry.py @@ -19,7 +19,9 @@ class EntryBuilder: vec = False if any([vecr in operands_types for vecr in ["mm", "xmm", "ymm", "zmm"]]): vec = True - assert not (load and store), "Can not process a combined load-store instruction." + assert not ( + load and store + ), "Can not process a combined load-store instruction." return load, store, vec def build_description( @@ -37,7 +39,9 @@ class EntryBuilder: if ot == "imd": description += " - class: immediate\n imd: int\n" elif ot.startswith("mem"): - description += " - class: memory\n" ' base: "*"\n' ' offset: "*"\n' + description += ( + " - class: memory\n" ' base: "*"\n' ' offset: "*"\n' + ) if ot == "mem_simple": description += " index: ~\n" elif ot == "mem_complex": @@ -47,8 +51,10 @@ class EntryBuilder: description += ' scale: "*"\n' else: if "{k}" in ot: - description += " - class: register\n name: {}\n mask: True\n".format( - ot.replace("{k}", "") + description += ( + " - class: register\n name: {}\n mask: True\n".format( + ot.replace("{k}", "") + ) ) else: description += " - class: register\n name: {}\n".format(ot) @@ -87,22 +93,29 @@ class EntryBuilder: def process_item(self, instruction_form, resources): """ Example: - ('mov xmm mem', ('1*p45+2*p0', 7) -> ('mov', ['xmm', 'mem'], [[1, '45'], [2, '0']], 7) + ('mov xmm mem', ('1*p45+2*p0', 7) -> ('mov', ['xmm', 'mem'], + [[1, '45'], [2, '0']], 7) """ if instruction_form.startswith("[") and "]" in instruction_form: instr_elements = instruction_form.split("]") - instr_elements = [instr_elements[0] + "]"] + instr_elements[1].strip().split(" ") + instr_elements = [instr_elements[0] + "]"] + instr_elements[ + 1 + ].strip().split(" ") else: instr_elements = instruction_form.split(" ") latency = int(resources[1]) port_pressure = self.parse_port_pressure(resources[0]) instruction_name = instr_elements[0] operand_types = instr_elements[1:] - return self.build_description(instruction_name, operand_types, port_pressure, latency) + return self.build_description( + instruction_name, operand_types, port_pressure, latency + ) class ArchEntryBuilder(EntryBuilder): - def build_description(self, instruction_name, operand_types, port_pressure=[], latency=0): + def build_description( + self, instruction_name, operand_types, port_pressure=[], latency=0 + ): # Intel ICX # LD_pressure = [[1, "23"], [1, ["2D", "3D"]]] # LD_pressure_vec = LD_pressure @@ -160,7 +173,9 @@ def get_description(instruction_form, port_pressure, latency, rhs_comment=None): commented_entry = "" for line in entry.split("\n"): - commented_entry += ("{:<" + str(max_length) + "} # {}\n").format(line, rhs_comment) + commented_entry += ("{:<" + str(max_length) + "} # {}\n").format( + line, rhs_comment + ) entry = commented_entry return entry @@ -170,7 +185,11 @@ if __name__ == "__main__": import sys if len(sys.argv) != 4 and len(sys.argv) != 5: - print("Usage: {} [COMMENT]".format(sys.argv[0])) + print( + "Usage: {} [COMMENT]".format( + sys.argv[0] + ) + ) sys.exit(0) try: diff --git a/osaca/data/generate_mov_entries.py b/osaca/data/generate_mov_entries.py index 0bde448..0934cc5 100644 --- a/osaca/data/generate_mov_entries.py +++ b/osaca/data/generate_mov_entries.py @@ -19,7 +19,9 @@ class MOVEntryBuilder: vec = False if any([vecr in operands_types for vecr in ["mm", "xmm", "ymm", "zmm"]]): vec = True - assert not (load and store), "Can not process a combined load-store instruction." + assert not ( + load and store + ), "Can not process a combined load-store instruction." return load, store, vec def build_description( @@ -35,7 +37,9 @@ class MOVEntryBuilder: if ot == "imd": description += " - class: immediate\n imd: int\n" elif ot.startswith("mem"): - description += " - class: memory\n" ' base: "*"\n' ' offset: "*"\n' + description += ( + " - class: memory\n" ' base: "*"\n' ' offset: "*"\n' + ) if ot == "mem_simple": description += " index: ~\n" elif ot == "mem_complex": @@ -77,19 +81,24 @@ class MOVEntryBuilder: def process_item(self, instruction_form, resources): """ Example: - ('mov xmm mem', ('1*p45+2*p0', 7) -> ('mov', ['xmm', 'mem'], [[1, '45'], [2, '0']], 7) + ('mov xmm mem', ('1*p45+2*p0', 7) -> ('mov', ['xmm', 'mem'], + [[1, '45'], [2, '0']], 7) """ instr_elements = instruction_form.split(" ") latency = resources[1] port_pressure = self.parse_port_pressure(resources[0]) instruction_name = instr_elements[0] operand_types = instr_elements[1:] - return self.build_description(instruction_name, operand_types, port_pressure, latency) + return self.build_description( + instruction_name, operand_types, port_pressure, latency + ) class MOVEntryBuilderIntelNoPort7AGU(MOVEntryBuilder): # for SNB and IVB - def build_description(self, instruction_name, operand_types, port_pressure=[], latency=0): + def build_description( + self, instruction_name, operand_types, port_pressure=[], latency=0 + ): load, store, vec = self.classify(operand_types) comment = None @@ -117,7 +126,9 @@ class MOVEntryBuilderIntelNoPort7AGU(MOVEntryBuilder): class MOVEntryBuilderIntelPort11(MOVEntryBuilder): # for SPR - def build_description(self, instruction_name, operand_types, port_pressure=[], latency=0): + def build_description( + self, instruction_name, operand_types, port_pressure=[], latency=0 + ): load, store, vec = self.classify(operand_types) if load: @@ -154,7 +165,9 @@ class MOVEntryBuilderIntelPort11(MOVEntryBuilder): class MOVEntryBuilderIntelPort9(MOVEntryBuilder): # for ICX - def build_description(self, instruction_name, operand_types, port_pressure=[], latency=0): + def build_description( + self, instruction_name, operand_types, port_pressure=[], latency=0 + ): load, store, vec = self.classify(operand_types) if load: @@ -185,7 +198,9 @@ class MOVEntryBuilderIntelPort9(MOVEntryBuilder): class MOVEntryBuilderAMDZen3(MOVEntryBuilder): # for Zen 3 - def build_description(self, instruction_name, operand_types, port_pressure=[], latency=0): + def build_description( + self, instruction_name, operand_types, port_pressure=[], latency=0 + ): load, store, vec = self.classify(operand_types) if load and vec: @@ -788,8 +803,6 @@ icx_mov_instructions = [ ("vmovshdup mem xmm", ("", 0)), ("vmovshdup ymm ymm", ("1*p15", 1)), ("vmovshdup mem ymm", ("", 0)), - ("vmovshdup zmm zmm", ("1*p5", 1)), - ("vmovshdup mem zmm", ("", 0)), # https://www.felixcloutier.com/x86/movsldup ("movsldup xmm xmm", ("1*p15", 1)), ("movsldup mem xmm", ("", 0)), @@ -797,8 +810,6 @@ icx_mov_instructions = [ ("vmovsldup mem xmm", ("", 0)), ("vmovsldup ymm ymm", ("1*p15", 1)), ("vmovsldup mem ymm", ("", 0)), - ("vmovsldup zmm zmm", ("1*p5", 1)), - ("vmovsldup mem zmm", ("", 0)), # https://www.felixcloutier.com/x86/movss ("movss xmm xmm", ("1*p015", 1)), ("movss mem xmm", ("", 0)), @@ -934,27 +945,426 @@ icx_mov_instructions = [ ("vpmovsxbq xmm xmm", ("1*p15", 1)), ("vpmovsxbq mem xmm", ("1*p15", 1)), ("vpmovsxbw xmm ymm", ("1*p5", 1)), - ("vpmovsxbw mem ymm", ("1*p5", 1)), + ("vpmovsxbw mem ymm", ("1*p12", 1)), ("vpmovsxbd xmm ymm", ("1*p5", 1)), - ("vpmovsxbd mem ymm", ("1*p5", 1)), + ("vpmovsxbd mem ymm", ("1*p12", 1)), ("vpmovsxbq xmm ymm", ("1*p5", 1)), - ("vpmovsxbq mem ymm", ("1*p5", 1)), - ("vpmovsxbw ymm zmm", ("1*p5", 3)), - ("vpmovsxbw mem zmm", ("1*p5", 1)), + ("vpmovsxbq mem ymm", ("1*p12", 1)), # https://www.felixcloutier.com/x86/pmovzx ("pmovzxbw xmm xmm", ("1*p15", 1)), ("pmovzxbw mem xmm", ("1*p15", 1)), ("vpmovzxbw xmm xmm", ("1*p15", 1)), ("vpmovzxbw mem xmm", ("1*p15", 1)), ("vpmovzxbw xmm ymm", ("1*p5", 1)), - ("vpmovzxbw mem ymm", ("1*p5", 1)), - ("vpmovzxbw ymm zmm", ("1*p5", 1)), - ("vpmovzxbw mem zmm", ("1*p5", 1)), + ("vpmovzxbw mem ymm", ("1*p12", 1)), ################################################################# # https://www.felixcloutier.com/x86/movbe ("movbe gpr mem", ("1*p15", 6)), ("movbe mem gpr", ("1*p15", 6)), ################################################ + # https://www.felixcloutier.com/x86/movq2dq + ("movq2dq mm xmm", ("2*p0123", 1)), +] + + +p11 = MOVEntryBuilderIntelPort11() + +spr_mov_instructions = [ + # https://www.felixcloutier.com/x86/mov + ("mov gpr gpr", ("1*p0,1,5,6,10", 1)), + ("mov gpr mem", ("", 0)), + ("mov mem gpr", ("", 0)), + ("mov imd gpr", ("1*p0,1,5,6,10", 1)), + ("mov imd mem", ("", 0)), + ("movabs imd gpr", ("1*p0,1,5,6,10", 1)), # AT&T version + # https://www.felixcloutier.com/x86/movapd + ("movapd xmm xmm", ("1*p0,1,5,6,10", 1)), + ("movapd xmm mem", ("", 0)), + ("movapd mem xmm", ("", 0)), + ("vmovapd xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovapd xmm mem", ("", 0)), + ("vmovapd mem xmm", ("", 0)), + ("vmovapd ymm ymm", ("1*p0,1,5,6,10", 1)), + ("vmovapd ymm mem", ("", 0)), + ("vmovapd mem ymm", ("", 0)), + ("vmovapd zmm zmm", ("1*p0,1,5,6,10", 1)), + ("vmovapd zmm mem", ("", 0)), + ("vmovapd mem zmm", ("", 0)), + # https://www.felixcloutier.com/x86/movaps + ("movaps xmm xmm", ("1*p0,1,5,6,10", 1)), + ("movaps xmm mem", ("", 0)), + ("movaps mem xmm", ("", 0)), + ("vmovaps xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovaps xmm mem", ("", 0)), + ("vmovaps mem xmm", ("", 0)), + ("vmovaps ymm ymm", ("1*p0,1,5,6,10", 1)), + ("vmovaps ymm mem", ("", 0)), + ("vmovaps mem ymm", ("", 0)), + ("vmovaps zmm zmm", ("1*p0,1,5,6,10", 1)), + ("vmovaps zmm mem", ("", 0)), + ("vmovaps mem zmm", ("", 0)), + # # https://www.felixcloutier.com/x86/movd:movq + # ("movd gpr mm", ("1*p5", 1)), + # ("movd mem mm", ("", 0)), + # ("movq gpr mm", ("1*p5", 1)), + # ("movq mem mm", ("", 0)), + # ("movd mm gpr", ("1*p0", 1)), + # ("movd mm mem", ("", 0)), + # ("movq mm gpr", ("1*p0", 1)), + # ("movq mm mem", ("", 0)), + # ("movd gpr xmm", ("1*p5", 1)), + # ("movd mem xmm", ("", 0)), + # ("movq gpr xmm", ("1*p5", 1)), + # ("movq mem xmm", ("", 0)), + # ("movd xmm gpr", ("1*p0", 1)), + # ("movd xmm mem", ("", 0)), + # ("movq xmm gpr", ("1*p0", 1)), + # ("movq xmm mem", ("", 0)), + # ("vmovd gpr xmm", ("1*p5", 1)), + # ("vmovd mem xmm", ("", 0)), + # ("vmovq gpr xmm", ("1*p5", 1)), + # ("vmovq mem xmm", ("", 0)), + # ("vmovd xmm gpr", ("1*p0", 1)), + # ("vmovd xmm mem", ("", 0)), + # ("vmovq xmm gpr", ("1*p0", 1)), + # ("vmovq xmm mem", ("", 0)), + # # https://www.felixcloutier.com/x86/movddup + # ("movddup xmm xmm", ("1*p5", 1)), + # ("movddup mem xmm", ("", 0)), + # ("vmovddup xmm xmm", ("1*p5", 1)), + # ("vmovddup mem xmm", ("", 0)), + # ("vmovddup ymm ymm", ("1*p5", 1)), + # ("vmovddup mem ymm", ("", 0)), + # ("vmovddup zmm zmm", ("1*p5", 1)), + # ("vmovddup mem zmm", ("", 0)), + # https://www.felixcloutier.com/x86/movdq2q + # ("movdq2q xmm mm", ("1*p015+1*p5", 1)), + # https://www.felixcloutier.com/x86/movdqa:vmovdqa32:vmovdqa64 + ("movdqa xmm xmm", ("1*p0,1,5,6,10", 1)), + ("movdqa mem xmm", ("", 0)), + ("movdqa xmm mem", ("", 0)), + ("vmovdqa xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqa mem xmm", ("", 0)), + ("vmovdqa xmm mem", ("", 0)), + ("vmovdqa ymm ymm", ("1*p0,1,5,6,10", 1)), + ("vmovdqa mem ymm", ("", 0)), + ("vmovdqa ymm mem", ("", 0)), + ("vmovdqa32 xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqa32 mem xmm", ("", 0)), + ("vmovdqa32 xmm mem", ("", 0)), + ("vmovdqa32 ymm ymm", ("1*p0,1,5,6,10", 1)), + ("vmovdqa32 mem ymm", ("", 0)), + ("vmovdqa32 ymm mem", ("", 0)), + ("vmovdqa32 zmm zmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqa32 mem zmm", ("", 0)), + ("vmovdqa32 zmm mem", ("", 0)), + ("vmovdqa64 xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqa64 mem xmm", ("", 0)), + ("vmovdqa64 xmm mem", ("", 0)), + ("vmovdqa64 ymm ymm", ("1*p0,1,5,6,10", 1)), + ("vmovdqa64 mem ymm", ("", 0)), + ("vmovdqa64 ymm mem", ("", 0)), + ("vmovdqa64 zmm zmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqa64 mem zmm", ("", 0)), + ("vmovdqa64 zmm mem", ("", 0)), + # https://www.felixcloutier.com/x86/movdqu:vmovdqu8:vmovdqu16:vmovdqu32:vmovdqu64 + ("movdqu xmm xmm", ("1*p0,1,5,6,10", 1)), + ("movdqu mem xmm", ("", 0)), + ("movdqu xmm mem", ("", 0)), + ("vmovdqu xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu mem xmm", ("", 0)), + ("vmovdqu xmm mem", ("", 0)), + ("vmovdqu ymm ymm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu mem ymm", ("", 0)), + ("vmovdqu ymm mem", ("", 0)), + ("vmovdqu8 xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu8 mem xmm", ("", 0)), + ("vmovdqu8 xmm mem", ("", 0)), + ("vmovdqu8 ymm ymm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu8 mem ymm", ("", 0)), + ("vmovdqu8 ymm mem", ("", 0)), + ("vmovdqu8 zmm zmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu8 mem zmm", ("", 0)), + ("vmovdqu8 zmm mem", ("", 0)), + ("vmovdqu16 xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu16 mem xmm", ("", 0)), + ("vmovdqu16 xmm mem", ("", 0)), + ("vmovdqu16 ymm ymm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu16 mem ymm", ("", 0)), + ("vmovdqu16 ymm mem", ("", 0)), + ("vmovdqu16 zmm zmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu16 mem zmm", ("", 0)), + ("vmovdqu16 zmm mem", ("", 0)), + ("vmovdqu32 xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu32 mem xmm", ("", 0)), + ("vmovdqu32 xmm mem", ("", 0)), + ("vmovdqu32 ymm ymm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu32 mem ymm", ("", 0)), + ("vmovdqu32 ymm mem", ("", 0)), + ("vmovdqu32 zmm zmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu32 mem zmm", ("", 0)), + ("vmovdqu32 zmm mem", ("", 0)), + ("vmovdqu64 xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu64 mem xmm", ("", 0)), + ("vmovdqu64 xmm mem", ("", 0)), + ("vmovdqu64 ymm ymm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu64 mem ymm", ("", 0)), + ("vmovdqu64 ymm mem", ("", 0)), + ("vmovdqu64 zmm zmm", ("1*p0,1,5,6,10", 1)), + ("vmovdqu64 mem zmm", ("", 0)), + ("vmovdqu64 zmm mem", ("", 0)), + # # https://www.felixcloutier.com/x86/movhlps + # ("movhlps xmm xmm", ("1*p5", 1)), + # ("vmovhlps xmm xmm xmm", ("1*p5", 1)), + # # https://www.felixcloutier.com/x86/movhpd + # ("movhpd mem xmm", ("1*p5", 1)), + # ("vmovhpd mem xmm xmm", ("1*p5", 1)), + # ("movhpd xmm mem", ("", 0)), + # ("vmovhpd mem xmm", ("", 0)), + # # https://www.felixcloutier.com/x86/movhps + # ("movhps mem xmm", ("1*p5", 1)), + # ("vmovhps mem xmm xmm", ("1*p5", 1)), + # ("movhps xmm mem", ("", 0)), + # ("vmovhps mem xmm", ("", 0)), + # # https://www.felixcloutier.com/x86/movlhps + # ("movlhps xmm xmm", ("1*p5", 1)), + # ("vmovlhps xmm xmm xmm", ("1*p5", 1)), + # # https://www.felixcloutier.com/x86/movlpd + # ("movlpd mem xmm", ("1*p5", 1)), + # ("vmovlpd mem xmm xmm", ("1*p5", 1)), + # ("movlpd xmm mem", ("", 0)), + # ("vmovlpd mem xmm", ("1*p5", 1)), + # # https://www.felixcloutier.com/x86/movlps + # ("movlps mem xmm", ("1*p5", 1)), + # ("vmovlps mem xmm xmm", ("1*p5", 1)), + # ("movlps xmm mem", ("", 0)), + # ("vmovlps mem xmm", ("1*p5", 1)), + # # https://www.felixcloutier.com/x86/movmskpd + # ("movmskpd xmm gpr", ("1*p0", 1)), + # ("vmovmskpd xmm gpr", ("1*p0", 1)), + # ("vmovmskpd ymm gpr", ("1*p0", 1)), + # # https://www.felixcloutier.com/x86/movmskps + # ("movmskps xmm gpr", ("1*p0", 1)), + # ("vmovmskps xmm gpr", ("1*p0", 1)), + # ("vmovmskps ymm gpr", ("1*p0", 1)), + # https://www.felixcloutier.com/x86/movntdq + ("movntdq xmm mem", ("", 0)), # TODO NT-store: what latency to use? + ("vmovntdq xmm mem", ("", 0)), # TODO NT-store: what latency to use? + ("vmovntdq ymm mem", ("", 0)), # TODO NT-store: what latency to use? + ("vmovntdq zmm mem", ("", 0)), # TODO NT-store: what latency to use? + # https://www.felixcloutier.com/x86/movntdqa + ("movntdqa mem xmm", ("", 0)), # TODO NT-store: what latency to use? + ("vmovntdqa mem xmm", ("", 0)), # TODO NT-store: what latency to use? + ("vmovntdqa mem ymm", ("", 0)), # TODO NT-store: what latency to use? + ("vmovntdqa mem zmm", ("", 0)), # TODO NT-store: what latency to use? + # https://www.felixcloutier.com/x86/movnti + ("movnti gpr mem", ("", 0)), # TODO NT-store: what latency to use? + # https://www.felixcloutier.com/x86/movntpd + ("movntpd xmm mem", ("", 0)), # TODO NT-store: what latency to use? + ("vmovntpd xmm mem", ("", 0)), # TODO NT-store: what latency to use? + ("vmovntpd ymm mem", ("", 0)), # TODO NT-store: what latency to use? + ("vmovntpd zmm mem", ("", 0)), # TODO NT-store: what latency to use? + # https://www.felixcloutier.com/x86/movntps + ("movntps xmm mem", ("", 0)), # TODO NT-store: what latency to use? + ("vmovntps xmm mem", ("", 0)), # TODO NT-store: what latency to use? + ("vmovntps ymm mem", ("", 0)), # TODO NT-store: what latency to use? + ("vmovntps zmm mem", ("", 0)), # TODO NT-store: what latency to use? + # https://www.felixcloutier.com/x86/movntq + ("movntq mm mem", ("", 0)), # TODO NT-store: what latency to use? + # https://www.felixcloutier.com/x86/movq + ("movq mm mm", ("", 0)), + ("movq mem mm", ("", 0)), + ("movq mm mem", ("", 0)), + ("movq xmm xmm", ("1*p0,1,5,6,10", 1)), + ("movq mem xmm", ("", 0)), + ("movq xmm mem", ("", 0)), + ("vmovq xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovq mem xmm", ("", 0)), + ("vmovq xmm mem", ("", 0)), + # https://www.felixcloutier.com/x86/movs:movsb:movsw:movsd:movsq + # TODO combined load-store is currently not supported + # ('movs mem mem', ()), + # https://www.felixcloutier.com/x86/movsd + ("movsd xmm xmm", ("1*p0,1,5,6,10", 1)), + ("movsd mem xmm", ("", 0)), + ("movsd xmm mem", ("", 0)), + ("vmovsd xmm xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovsd mem xmm", ("", 0)), + ("vmovsd xmm mem", ("", 0)), + # # https://www.felixcloutier.com/x86/movshdup + # ("movshdup xmm xmm", ("1*p15", 1)), + # ("movshdup mem xmm", ("", 0)), + # ("vmovshdup xmm xmm", ("1*p15", 1)), + # ("vmovshdup mem xmm", ("", 0)), + # ("vmovshdup ymm ymm", ("1*p15", 1)), + # ("vmovshdup mem ymm", ("", 0)), + # ("vmovshdup zmm zmm", ("1*p5", 1)), + # ("vmovshdup mem zmm", ("", 0)), + # # https://www.felixcloutier.com/x86/movsldup + # ("movsldup xmm xmm", ("1*p15", 1)), + # ("movsldup mem xmm", ("", 0)), + # ("vmovsldup xmm xmm", ("1*p15", 1)), + # ("vmovsldup mem xmm", ("", 0)), + # ("vmovsldup ymm ymm", ("1*p15", 1)), + # ("vmovsldup mem ymm", ("", 0)), + # ("vmovsldup zmm zmm", ("1*p5", 1)), + # ("vmovsldup mem zmm", ("", 0)), + # https://www.felixcloutier.com/x86/movss + ("movss xmm xmm", ("1*p0,1,5,6,10", 1)), + ("movss mem xmm", ("", 0)), + ("vmovss xmm xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovss mem xmm", ("", 0)), + ("vmovss xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovss xmm mem", ("", 0)), + ("movss mem xmm", ("", 0)), + # https://www.felixcloutier.com/x86/movsx:movsxd + ("movsx gpr gpr", ("1*p0,1,5,6,10", 1)), + ("movsx mem gpr", ("", 0)), + ("movsxd gpr gpr", ("", 0)), + ("movsxd mem gpr", ("", 0)), + ("movsb gpr gpr", ("1*p0,1,5,6,10", 1)), # AT&T version + ("movsb mem gpr", ("", 0)), # AT&T version + ("movsw gpr gpr", ("1*p0,1,5,6,10", 1)), # AT&T version + ("movsw mem gpr", ("", 0)), # AT&T version + ("movsl gpr gpr", ("1*p0,1,5,6,10", 1)), # AT&T version + ("movsl mem gpr", ("", 0)), # AT&T version + ("movsq gpr gpr", ("1*p0,1,5,6,10", 1)), # AT&T version + ("movsq mem gpr", ("", 0)), # AT&T version + # https://www.felixcloutier.com/x86/movupd + ("movupd xmm xmm", ("1*p0,1,5,6,10", 1)), + ("movupd mem xmm", ("", 0)), + ("movupd xmm mem", ("", 0)), + ("vmovupd xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovupd mem xmm", ("", 0)), + ("vmovupd xmm mem", ("", 0)), + ("vmovupd ymm ymm", ("1*p0,1,5,6,10", 1)), + ("vmovupd mem ymm", ("", 0)), + ("vmovupd ymm mem", ("", 0)), + ("vmovupd zmm zmm", ("1*p0,1,5,6,10", 1)), + ("vmovupd mem zmm", ("", 0)), + ("vmovupd zmm mem", ("", 0)), + # https://www.felixcloutier.com/x86/movups + ("movups xmm xmm", ("1*p0,1,5,6,10", 1)), + ("movups mem xmm", ("", 0)), + ("movups xmm mem", ("", 0)), + ("vmovups xmm xmm", ("1*p0,1,5,6,10", 1)), + ("vmovups mem xmm", ("", 0)), + ("vmovups xmm mem", ("", 0)), + ("vmovups ymm ymm", ("1*p0,1,5,6,10", 1)), + ("vmovups mem ymm", ("", 0)), + ("vmovups ymm mem", ("", 0)), + ("vmovups zmm zmm", ("1*p0,1,5,6,10", 1)), + ("vmovups mem zmm", ("", 0)), + ("vmovups zmm mem", ("", 0)), + # # https://www.felixcloutier.com/x86/movzx + # ("movzx gpr gpr", ("1*p0,1,5,6,10", 1)), + # ("movzx mem gpr", ("", 0)), + # ("movzb gpr gpr", ("1*p0,1,5,6,10", 1)), # AT&T version + # ("movzb mem gpr", ("", 0)), # AT&T version + # ("movzw gpr gpr", ("1*p0,1,5,6,106", 1)), # AT&T version + # ("movzw mem gpr", ("", 0)), # AT&T version + # ("movzl gpr gpr", ("1*p0156", 1)), # AT&T version + # ("movzl mem gpr", ("", 0)), # AT&T version + # ("movzq gpr gpr", ("1*p0156", 1)), # AT&T version + # ("movzq mem gpr", ("", 0)), # AT&T version + # # https://www.felixcloutier.com/x86/cmovcc + # ("cmova gpr gpr", ("2*p06", 1)), + # ("cmova mem gpr", ("", 0)), + # ("cmovae gpr gpr", ("1*p06", 1)), + # ("cmovae mem gpr", ("", 0)), + # ("cmovb gpr gpr", ("2*p06", 1)), + # ("cmovb mem gpr", ("", 0)), + # ("cmovbe gpr gpr", ("2*p06", 1)), + # ("cmovbe mem gpr", ("", 0)), + # ("cmovc gpr gpr", ("1*p06", 1)), + # ("cmovc mem gpr", ("", 0)), + # ("cmove gpr gpr", ("1*p06", 1)), + # ("cmove mem gpr", ("", 0)), + # ("cmovg gpr gpr", ("1*p06", 1)), + # ("cmovg mem gpr", ("", 0)), + # ("cmovge gpr gpr", ("1*p06", 1)), + # ("cmovge mem gpr", ("", 0)), + # ("cmovl gpr gpr", ("1*p06", 1)), + # ("cmovl mem gpr", ("", 0)), + # ("cmovle gpr gpr", ("1*p06", 1)), + # ("cmovle mem gpr", ("", 0)), + # ("cmovna gpr gpr", ("2*p06", 1)), + # ("cmovna mem gpr", ("", 0)), + # ("cmovnae gpr gpr", ("1*p06", 1)), + # ("cmovnae mem gpr", ("", 0)), + # ("cmovnb gpr gpr", ("1*p06", 1)), + # ("cmovnb mem gpr", ("", 0)), + # ("cmovnbe gpr gpr", ("2*p06", 1)), + # ("cmovnbe mem gpr", ("", 0)), + # ("cmovnc gpr gpr", ("1*p06", 1)), + # ("cmovnc mem gpr", ("", 0)), + # ("cmovne gpr gpr", ("1*p06", 1)), + # ("cmovne mem gpr", ("", 0)), + # ("cmovng gpr gpr", ("1*p06", 1)), + # ("cmovng mem gpr", ("", 0)), + # ("cmovnge gpr gpr", ("1*p06", 1)), + # ("cmovnge mem gpr", ("", 0)), + # ("cmovnl gpr gpr", ("1*p06", 1)), + # ("cmovnl mem gpr", ("", 0)), + # ("cmovno gpr gpr", ("1*p06", 1)), + # ("cmovno mem gpr", ("", 0)), + # ("cmovnp gpr gpr", ("1*p06", 1)), + # ("cmovnp mem gpr", ("", 0)), + # ("cmovns gpr gpr", ("1*p06", 1)), + # ("cmovns mem gpr", ("", 0)), + # ("cmovnz gpr gpr", ("1*p06", 1)), + # ("cmovnz mem gpr", ("", 0)), + # ("cmovo gpr gpr", ("1*p06", 1)), + # ("cmovo mem gpr", ("", 0)), + # ("cmovp gpr gpr", ("1*p06", 1)), + # ("cmovp mem gpr", ("", 0)), + # ("cmovpe gpr gpr", ("1*p06", 1)), + # ("cmovpe mem gpr", ("", 0)), + # ("cmovpo gpr gpr", ("1*p06", 1)), + # ("cmovpo mem gpr", ("", 0)), + # ("cmovs gpr gpr", ("1*p06", 1)), + # ("cmovs mem gpr", ("", 0)), + # ("cmovz gpr gpr", ("1*p06", 1)), + # ("cmovz mem gpr", ("", 0)), + # # https://www.felixcloutier.com/x86/pmovmskb + # ("pmovmskb mm gpr", ("1*p0", 1)), + # ("pmovmskb xmm gpr", ("1*p0", 1)), + # ("vpmovmskb xmm gpr", ("1*p0", 1)), + # # https://www.felixcloutier.com/x86/pmovsx + # ("pmovsxbw xmm xmm", ("1*p15", 1)), + # ("pmovsxbw mem xmm", ("1*p15", 1)), + # ("pmovsxbd xmm xmm", ("1*p15", 1)), + # ("pmovsxbd mem xmm", ("1*p15", 1)), + # ("pmovsxbq xmm xmm", ("1*p15", 1)), + # ("pmovsxbq mem xmm", ("1*p15", 1)), + # ("vpmovsxbw xmm xmm", ("1*p15", 1)), + # ("vpmovsxbw mem xmm", ("1*p15", 1)), + # ("vpmovsxbd xmm xmm", ("1*p15", 1)), + # ("vpmovsxbd mem xmm", ("1*p15", 1)), + # ("vpmovsxbq xmm xmm", ("1*p15", 1)), + # ("vpmovsxbq mem xmm", ("1*p15", 1)), + # ("vpmovsxbw xmm ymm", ("1*p5", 1)), + # ("vpmovsxbw mem ymm", ("1*p5", 1)), + # ("vpmovsxbd xmm ymm", ("1*p5", 1)), + # ("vpmovsxbd mem ymm", ("1*p5", 1)), + # ("vpmovsxbq xmm ymm", ("1*p5", 1)), + # ("vpmovsxbq mem ymm", ("1*p5", 1)), + # ("vpmovsxbw ymm zmm", ("1*p5", 3)), + # ("vpmovsxbw mem zmm", ("1*p5", 1)), + # # https://www.felixcloutier.com/x86/pmovzx + # ("pmovzxbw xmm xmm", ("1*p15", 1)), + # ("pmovzxbw mem xmm", ("1*p15", 1)), + # ("vpmovzxbw xmm xmm", ("1*p15", 1)), + # ("vpmovzxbw mem xmm", ("1*p15", 1)), + # ("vpmovzxbw xmm ymm", ("1*p5", 1)), + # ("vpmovzxbw mem ymm", ("1*p5", 1)), + # ("vpmovzxbw ymm zmm", ("1*p5", 1)), + # ("vpmovzxbw mem zmm", ("1*p5", 1)), + ################################################################## + # # https://www.felixcloutier.com/x86/movbe + # ("movbe gpr mem", ("1*p15", 6)), + # ("movbe mem gpr", ("1*p15", 6)), + ################################################ # https://www.felixcloutier.com/x86/movapd # TODO with masking! # https://www.felixcloutier.com/x86/movaps @@ -966,7 +1376,7 @@ icx_mov_instructions = [ # https://www.felixcloutier.com/x86/movdqu:vmovdqu8:vmovdqu16:vmovdqu32:vmovdqu64 # TODO with masking! # https://www.felixcloutier.com/x86/movq2dq - ("movq2dq mm xmm", ("1*p0+1*p015", 1)), + # ("movq2dq mm xmm", ("1*p0+1*p015", 1)), # https://www.felixcloutier.com/x86/movsd # TODO with masking! # https://www.felixcloutier.com/x86/movshdup @@ -1413,7 +1823,9 @@ spr_mov_instructions = [ class MOVEntryBuilderIntelWithPort7AGU(MOVEntryBuilder): # for HSW, BDW, SKX and CSX - def build_description(self, instruction_name, operand_types, port_pressure=[], latency=0): + def build_description( + self, instruction_name, operand_types, port_pressure=[], latency=0 + ): load, store, vec = self.classify(operand_types) if load: @@ -1427,7 +1839,9 @@ class MOVEntryBuilderIntelWithPort7AGU(MOVEntryBuilder): port_pressure_simple = port_pressure + [[1, "237"], [1, "4"]] operands_simple = ["mem_simple" if o == "mem" else o for o in operand_types] port_pressure_complex = port_pressure + [[1, "23"], [1, "4"]] - operands_complex = ["mem_complex" if o == "mem" else o for o in operand_types] + operands_complex = [ + "mem_complex" if o == "mem" else o for o in operand_types + ] latency += 0 return ( MOVEntryBuilder.build_description( diff --git a/osaca/data/model_importer.py b/osaca/data/model_importer.py index f1ab348..03c520c 100644 --- a/osaca/data/model_importer.py +++ b/osaca/data/model_importer.py @@ -47,7 +47,9 @@ def port_pressure_from_tag_attributes(attrib): def extract_paramters(instruction_tag, parser, isa): # Extract parameter components parameters = [] # used to store string representations - parameter_tags = sorted(instruction_tag.findall("operand"), key=lambda p: int(p.attrib["idx"])) + parameter_tags = sorted( + instruction_tag.findall("operand"), key=lambda p: int(p.attrib["idx"]) + ) for parameter_tag in parameter_tags: parameter = {} # Ignore parameters with suppressed=1 @@ -68,7 +70,9 @@ def extract_paramters(instruction_tag, parser, isa): parameters.append(parameter) elif p_type == "reg": parameter["class"] = "register" - possible_regs = [parser.parse_register("%" + r) for r in parameter_tag.text.split(",")] + possible_regs = [ + parser.parse_register("%" + r) for r in parameter_tag.text.split(",") + ] if possible_regs[0] is None: raise ValueError( "Unknown register type for {} with {}.".format( @@ -170,10 +174,14 @@ def extract_model(tree, arch, skip_mem=True): if throughput == float("inf"): throughput = None uops = ( - int(measurement_tag.attrib["uops"]) if "uops" in measurement_tag.attrib else None + int(measurement_tag.attrib["uops"]) + if "uops" in measurement_tag.attrib + else None ) if "ports" in measurement_tag.attrib: - port_pressure.append(port_pressure_from_tag_attributes(measurement_tag.attrib)) + port_pressure.append( + port_pressure_from_tag_attributes(measurement_tag.attrib) + ) latencies = [ int(l_tag.attrib["cycles"]) for l_tag in measurement_tag.iter("latency") @@ -250,7 +258,9 @@ def extract_model(tree, arch, skip_mem=True): mm.add_port(p) throughput = max(mm.average_port_pressure(port_pressure)) - mm.set_instruction(mnemonic, parameters, latency, port_pressure, throughput, uops) + mm.set_instruction( + mnemonic, parameters, latency, port_pressure, throughput, uops + ) # TODO eliminate entries which could be covered by automatic load / store expansion return mm @@ -260,7 +270,9 @@ def rhs_comment(uncommented_string, comment): commented_string = "" for line in uncommented_string.split("\n"): - commented_string += ("{:<" + str(max_length) + "} # {}\n").format(line, comment) + commented_string += ("{:<" + str(max_length) + "} # {}\n").format( + line, comment + ) return commented_string diff --git a/osaca/data/pmevo_importer.py b/osaca/data/pmevo_importer.py index ba8d041..5e21712 100644 --- a/osaca/data/pmevo_importer.py +++ b/osaca/data/pmevo_importer.py @@ -65,7 +65,17 @@ def build_bench_instruction(name, operands): constraint = "0" elif name in ["and", "ands", "eor", "eors", "orr", "orrs"]: constraint = "255" - elif name in ["bfi", "extr", "sbfiz", "sbfx", "shl", "sshr", "ubfiz", "ubfx", "ushr"]: + elif name in [ + "bfi", + "extr", + "sbfiz", + "sbfx", + "shl", + "sshr", + "ubfiz", + "ubfx", + "ushr", + ]: constraint = "7" else: constraint = "42" @@ -75,7 +85,9 @@ def build_bench_instruction(name, operands): shift += " {}".format(operand["shift"]) else: return None - asmbench_inst += "{}{{{}:{}:{}}}{}".format(separator, direction, shape, constraint, shift) + asmbench_inst += "{}{{{}:{}:{}}}{}".format( + separator, direction, shape, constraint, shift + ) direction = "src" separator = ", " return asmbench_inst @@ -119,7 +131,9 @@ def operand_parse(op, state): elif bits == "64": parameter["prefix"] = "x" else: - raise ValueError("Invalid register bits for {} {}".format(register_type, bits)) + raise ValueError( + "Invalid register bits for {} {}".format(register_type, bits) + ) elif register_type == "F": if bits == "32": parameter["prefix"] = "s" @@ -145,7 +159,9 @@ def operand_parse(op, state): else: raise ValueError("Invalid vector shape {}".format(vec_shape)) else: - raise ValueError("Invalid register bits for {} {}".format(register_type, bits)) + raise ValueError( + "Invalid register bits for {} {}".format(register_type, bits) + ) else: raise ValueError("Unknown register type {}".format(register_type)) elif op.startswith("_[((MEM:"): @@ -272,7 +288,9 @@ def extract_model(mapping, arch, template_model, asmbench): if bench_throughput is not None: throughput = round_cycles(bench_throughput) else: - print("Failed to measure throughput for instruction {}.".format(insn)) + print( + "Failed to measure throughput for instruction {}.".format(insn) + ) if bench_latency is not None: latency = round_cycles(bench_latency) else: @@ -283,7 +301,9 @@ def extract_model(mapping, arch, template_model, asmbench): # Insert instruction if not already found (can happen with template) if mm.get_instruction(name, operands) is None: - mm.set_instruction(name, operands, latency, port_pressure, throughput, uops) + mm.set_instruction( + name, operands, latency, port_pressure, throughput, uops + ) except ValueError as e: print("Failed to parse instruction {}: {}.".format(insn, e)) @@ -295,7 +315,9 @@ def main(): parser.add_argument("json", help="path of mapping.json") parser.add_argument("yaml", help="path of template.yml", nargs="?") parser.add_argument( - "--asmbench", help="Benchmark latency and throughput using asmbench.", action="store_true" + "--asmbench", + help="Benchmark latency and throughput using asmbench.", + action="store_true", ) args = parser.parse_args() diff --git a/osaca/data/rv64.yml b/osaca/data/rv64.yml index fdfca03..03e92bb 100644 --- a/osaca/data/rv64.yml +++ b/osaca/data/rv64.yml @@ -661,7 +661,7 @@ instruction_forms: latency: 3 throughput: 1 port_pressure: [[1, ["FP"]]] - + - name: VFMADD.VV operands: - class: register diff --git a/osaca/db_interface.py b/osaca/db_interface.py index ce02418..8f78f7c 100644 --- a/osaca/db_interface.py +++ b/osaca/db_interface.py @@ -16,7 +16,9 @@ from osaca.parser.immediate import ImmediateOperand from osaca.parser.instruction_form import InstructionForm -def sanity_check(arch: str, verbose=False, internet_check=False, output_file=sys.stdout): +def sanity_check( + arch: str, verbose=False, internet_check=False, output_file=sys.stdout +): """ Checks the database for missing TP/LT values, instructions might missing int the ISA DB and duplicate instructions. @@ -151,8 +153,12 @@ def _get_asmbench_output(input_data, isa): entry = InstructionForm( mnemonic=mnemonic_parsed, operands=operands, - throughput=_validate_measurement(float(input_data[i + 2].split()[1]), "tp"), - latency=_validate_measurement(float(input_data[i + 1].split()[1]), "lt"), + throughput=_validate_measurement( + float(input_data[i + 2].split()[1]), "tp" + ), + latency=_validate_measurement( + float(input_data[i + 1].split()[1]), "lt" + ), port_pressure=None, ) if not entry.throughput or not entry.latency: @@ -350,7 +356,9 @@ def _scrape_from_felixcloutier(mnemonic): elif r.status_code == 404: # Check for alternative href index = BeautifulSoup(requests.get(url=index).text, "html.parser") - alternatives = [ref for ref in index.findAll("a") if ref.text == mnemonic.upper()] + alternatives = [ + ref for ref in index.findAll("a") if ref.text == mnemonic.upper() + ] if len(alternatives) > 0: # alternative(s) found, take first one url = base_url + alternatives[0].attrs["href"][2:] @@ -412,20 +420,21 @@ def _check_sanity_arch_db(arch_mm, isa_mm, internet_check=True): suspicious_prefixes_x86 = ["vfm", "fm"] suspicious_prefixes_arm = ["fml", "ldp", "stp", "str"] suspicious_prefixes_riscv = [ - "vse", # Vector store (register is source, memory is destination) + "vse", # Vector store (register is source, memory is destination) "vfmacc", # Vector FMA with accumulation (first operand is both source and destination) - "vfmadd", # Vector FMA with addition (first operand is implicitly both source and destination) - "vset", # Vector configuration (complex operand pattern) - "csrs", # CSR Set (first operand is both source and destination) - "csrc", # CSR Clear (first operand is both source and destination) - "csrsi", # CSR Set Immediate (first operand is both source and destination) - "csrci", # CSR Clear Immediate (first operand is both source and destination) - "amo", # Atomic memory operations (read-modify-write to memory) - "lr", # Load-Reserved (part of atomic operations) - "sc", # Store-Conditional (part of atomic operations) - "czero", # Conditional zero instructions (Zicond extension) + "vfmadd", # Vector FMA with addition (first operand is implicitly both + # source and destination) + "vset", # Vector configuration (complex operand pattern) + "csrs", # CSR Set (first operand is both source and destination) + "csrc", # CSR Clear (first operand is both source and destination) + "csrsi", # CSR Set Immediate (first operand is both source and destination) + "csrci", # CSR Clear Immediate (first operand is both source and destination) + "amo", # Atomic memory operations (read-modify-write to memory) + "lr", # Load-Reserved (part of atomic operations) + "sc", # Store-Conditional (part of atomic operations) + "czero", # Conditional zero instructions (Zicond extension) ] - + # Default to empty list if ISA not recognized suspicious_prefixes = [] @@ -458,7 +467,10 @@ def _check_sanity_arch_db(arch_mm, isa_mm, internet_check=True): for prefix in suspicious_prefixes: if instr_form["name"].lower().startswith(prefix): # check if instruction in ISA DB - if isa_mm.get_instruction(instr_form["name"], instr_form["operands"]) is None: + if ( + isa_mm.get_instruction(instr_form["name"], instr_form["operands"]) + is None + ): # if not, mark them as suspicious and print it on the screen suspicious_instructions.append(instr_form) # instr forms with less than 3 operands might need an ISA DB entry due to src_reg operands @@ -468,7 +480,8 @@ def _check_sanity_arch_db(arch_mm, isa_mm, internet_check=True): and "mov" not in instr_form["name"].lower() and not instr_form["name"].lower().startswith("j") and instr_form not in suspicious_instructions - and isa_mm.get_instruction(instr_form["name"], instr_form["operands"]) is None + and isa_mm.get_instruction(instr_form["name"], instr_form["operands"]) + is None ): # validate with data from internet if connected flag is set if internet_check: @@ -556,7 +569,9 @@ def _get_sanity_report( s += "{} duplicate instruction forms in uarch DB.\n".format(len(dup_arch)) s += "{} duplicate instruction forms in ISA DB.\n".format(len(dup_isa)) s += ( - "{} instruction forms in ISA DB are not referenced by instruction ".format(len(only_isa)) + "{} instruction forms in ISA DB are not referenced by instruction ".format( + len(only_isa) + ) + "forms in uarch DB.\n" ) s += "{} bad operands found in uarch DB\n".format(len(bad_operands)) @@ -602,13 +617,19 @@ def _get_sanity_report_verbose( s = "Instruction forms without throughput value:\n" if m_tp else "" for instr_form in sorted(m_tp, key=lambda i: i["name"]): - s += "{}{}{}\n".format(BRIGHT_BLUE, _get_full_instruction_name(instr_form), WHITE) + s += "{}{}{}\n".format( + BRIGHT_BLUE, _get_full_instruction_name(instr_form), WHITE + ) s += "Instruction forms without latency value:\n" if m_l else "" for instr_form in sorted(m_l, key=lambda i: i["name"]): - s += "{}{}{}\n".format(BRIGHT_RED, _get_full_instruction_name(instr_form), WHITE) + s += "{}{}{}\n".format( + BRIGHT_RED, _get_full_instruction_name(instr_form), WHITE + ) s += "Instruction forms without port pressure assignment:\n" if m_pp else "" for instr_form in sorted(m_pp, key=lambda i: i["name"]): - s += "{}{}{}\n".format(BRIGHT_MAGENTA, _get_full_instruction_name(instr_form), WHITE) + s += "{}{}{}\n".format( + BRIGHT_MAGENTA, _get_full_instruction_name(instr_form), WHITE + ) s += "Instruction forms which might miss an ISA DB entry:\n" if suspic_instr else "" for instr_form in sorted(suspic_instr, key=lambda i: i["name"]): s += "{}{}{}{}\n".format( @@ -622,13 +643,25 @@ def _get_sanity_report_verbose( s += "{}{}{}\n".format(YELLOW, _get_full_instruction_name(instr_form), WHITE) s += "Duplicate instruction forms in ISA DB:\n" if dup_isa else "" for instr_form in sorted(dup_isa, key=lambda i: i["name"]): - s += "{}{}{}\n".format(BRIGHT_YELLOW, _get_full_instruction_name(instr_form), WHITE) - s += "Instruction forms existing in ISA DB but not in uarch DB:\n" if only_isa else "" + s += "{}{}{}\n".format( + BRIGHT_YELLOW, _get_full_instruction_name(instr_form), WHITE + ) + s += ( + "Instruction forms existing in ISA DB but not in uarch DB:\n" + if only_isa + else "" + ) for instr_form in sorted(only_isa, key=lambda i: i["name"]): s += "{}{}{}\n".format(CYAN, _get_full_instruction_name(instr_form), WHITE) - s += "{} bad operands found in uarch DB:\n".format(len(bad_operands)) if bad_operands else "" + s += ( + "{} bad operands found in uarch DB:\n".format(len(bad_operands)) + if bad_operands + else "" + ) for instr_form in sorted(bad_operands, key=lambda i: i["name"]): - s += "{}{}{}\n".format(BRIGHT_RED, _get_full_instruction_name(instr_form), WHITE) + s += "{}{}{}\n".format( + BRIGHT_RED, _get_full_instruction_name(instr_form), WHITE + ) return s @@ -688,4 +721,6 @@ def __dump_data_to_yaml(filepath, data): default_style="|", ) # finally, add instruction forms - ruamel.yaml.dump({"instruction_forms": data["instruction_forms"]}, f, allow_unicode=True) + ruamel.yaml.dump( + {"instruction_forms": data["instruction_forms"]}, f, allow_unicode=True + ) diff --git a/osaca/frontend.py b/osaca/frontend.py index 763eb5c..3358517 100644 --- a/osaca/frontend.py +++ b/osaca/frontend.py @@ -12,7 +12,9 @@ from osaca.semantics import INSTR_FLAGS, ArchSemantics, KernelDG, MachineModel def _get_version(*file_paths): """Searches for a version attribute in the given file(s)""" - with io.open(os.path.join(os.path.dirname(__file__), *file_paths), encoding="utf-8") as fp: + with io.open( + os.path.join(os.path.dirname(__file__), *file_paths), encoding="utf-8" + ) as fp: version_file = fp.read() version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) if version_match: @@ -53,7 +55,9 @@ class Frontend(object): :type instruction_form: `dict` :returns: `True` if comment line, `False` otherwise """ - return instruction_form.comment is not None and instruction_form.mnemonic is None + return ( + instruction_form.comment is not None and instruction_form.mnemonic is None + ) def throughput_analysis(self, kernel, show_lineno=False, show_cmnts=True): """ @@ -69,7 +73,9 @@ class Frontend(object): lineno_filler = " " if show_lineno else "" port_len = self._get_max_port_len(kernel) separator = "-" * sum([x + 3 for x in port_len]) + "-" - separator += "--" + len(str(kernel[-1]["line_number"])) * "-" if show_lineno else "" + separator += ( + "--" + len(str(kernel[-1]["line_number"])) * "-" if show_lineno else "" + ) col_sep = "|" sep_list = self._get_separator_list(col_sep) headline = "Port pressure in cycles" @@ -92,13 +98,19 @@ class Frontend(object): ), instruction_form.line.strip().replace("\t", " "), ) - line = line if show_lineno else col_sep + col_sep.join(line.split(col_sep)[1:]) + line = ( + line if show_lineno else col_sep + col_sep.join(line.split(col_sep)[1:]) + ) if show_cmnts is False and self._is_comment(instruction_form): continue s += line + "\n" s += "\n" tp_sum = ArchSemantics.get_throughput_sum(kernel) - s += lineno_filler + self._get_port_pressure(tp_sum, port_len, separator=" ") + "\n" + s += ( + lineno_filler + + self._get_port_pressure(tp_sum, port_len, separator=" ") + + "\n" + ) return s def latency_analysis(self, cp_kernel, separator="|"): @@ -125,7 +137,8 @@ class Frontend(object): ) + "\n" s += ( "\n{:4} {} {:4.1f}".format( - " " * max([len(str(instr_form.line_number)) for instr_form in cp_kernel]), + " " + * max([len(str(instr_form.line_number)) for instr_form in cp_kernel]), " " * len(separator), sum([instr_form.latency_cp for instr_form in cp_kernel]), ) @@ -341,7 +354,8 @@ class Frontend(object): longest_lcd = max(dep_dict, key=lambda ln: dep_dict[ln]["latency"]) lcd_sum = dep_dict[longest_lcd]["latency"] lcd_lines = { - instr.line_number: lat for instr, lat in dep_dict[longest_lcd]["dependencies"] + instr.line_number: lat + for instr, lat in dep_dict[longest_lcd]["dependencies"] } port_line = ( @@ -425,7 +439,8 @@ class Frontend(object): "-------------------------- WARNING: No micro-architecture was specified " "-------------------------\n" " A default uarch for this particular ISA was used. Specify " - "the uarch with --arch.\n See --help for more information.\n" + dashed_line + "the uarch with --arch.\n See --help for more information.\n" + + dashed_line ) length_text = ( "----------------- WARNING: You are analyzing a large amount of instruction forms " @@ -465,7 +480,11 @@ class Frontend(object): for i in range(len(self._machine_model.get_ports()) - 1): match_1 = re.search(r"\d+", self._machine_model.get_ports()[i]) match_2 = re.search(r"\d+", self._machine_model.get_ports()[i + 1]) - if match_1 is not None and match_2 is not None and match_1.group() == match_2.group(): + if ( + match_1 is not None + and match_2 is not None + and match_1.group() == match_2.group() + ): separator_list.append(separator_2) else: separator_list.append(separator) @@ -488,11 +507,20 @@ class Frontend(object): separator = [separator for x in ports] string_result = "{} ".format(separator[-1]) for i in range(len(ports)): - if float(ports[i]) == 0.0 and self._machine_model.get_ports()[i] not in used_ports: + if ( + float(ports[i]) == 0.0 + and self._machine_model.get_ports()[i] not in used_ports + ): string_result += port_len[i] * " " + " {} ".format(separator[i]) continue left_len = len(str(float(ports[i])).split(".")[0]) - substr = "{:" + str(left_len) + "." + str(max(port_len[i] - left_len - 1, 0)) + "f}" + substr = ( + "{:" + + str(left_len) + + "." + + str(max(port_len[i] - left_len - 1, 0)) + + "f}" + ) substr = substr.format(ports[i]) string_result += ( substr + " {} ".format(separator[i]) @@ -513,7 +541,9 @@ class Frontend(object): lat_cp = float(self._get_node_by_lineno(line_number, cp_dg).latency_cp) if dep_lat is not None: lat_lcd = float(dep_lat) - return "{} {:>4} {} {:>4} {}".format(separator, lat_cp, separator, lat_lcd, separator) + return "{} {:>4} {} {:>4} {}".format( + separator, lat_cp, separator, lat_lcd, separator + ) def _get_max_port_len(self, kernel): """Returns the maximal length needed to print all throughputs of the kernel.""" @@ -530,7 +560,9 @@ class Frontend(object): separator_list = self._get_separator_list(separator, "-") for i, length in enumerate(port_len): substr = "{:^" + str(length + 2) + "s}" - string_result += substr.format(self._machine_model.get_ports()[i]) + separator_list[i] + string_result += ( + substr.format(self._machine_model.get_ports()[i]) + separator_list[i] + ) return string_result def _header_report(self): @@ -538,7 +570,9 @@ class Frontend(object): version = _get_version("__init__.py") adjust = 20 header = "" - header += "Open Source Architecture Code Analyzer (OSACA) - {}\n".format(version) + header += "Open Source Architecture Code Analyzer (OSACA) - {}\n".format( + version + ) header += "Analyzed file:".ljust(adjust) + "{}\n".format(self._filename) header += "Architecture:".ljust(adjust) + "{}\n".format(self._arch.upper()) header += "Timestamp:".ljust(adjust) + "{}\n".format( @@ -566,7 +600,9 @@ class Frontend(object): } symbol_map = "" for flag in sorted(symbol_dict.keys()): - symbol_map += " {} - {}\n".format(self._get_flag_symbols([flag]), symbol_dict[flag]) + symbol_map += " {} - {}\n".format( + self._get_flag_symbols([flag]), symbol_dict[flag] + ) return symbol_map def _port_binding_summary(self): diff --git a/osaca/osaca.py b/osaca/osaca.py index 755face..657b977 100644 --- a/osaca/osaca.py +++ b/osaca/osaca.py @@ -11,7 +11,13 @@ from ruamel.yaml import YAML from osaca.db_interface import import_benchmark_output, sanity_check from osaca.frontend import Frontend -from osaca.parser import BaseParser, ParserAArch64, ParserX86ATT, ParserX86Intel, ParserRISCV +from osaca.parser import ( + BaseParser, + ParserAArch64, + ParserX86ATT, + ParserX86Intel, + ParserRISCV, +) from osaca.semantics import ( INSTR_FLAGS, ArchSemantics, @@ -111,8 +117,8 @@ def create_parser(parser=None): "--arch", type=str, help="Define architecture (SNB, IVB, HSW, BDW, SKX, CSX, ICL, ICX, SPR, ZEN1, ZEN2, ZEN3, " - "ZEN4, TX2, N1, A64FX, TSV110, A72, M1, V2, RV64). If no architecture is given, OSACA assumes a " - "default uarch for the detected ISA.", + "ZEN4, TX2, N1, A64FX, TSV110, A72, M1, V2, RV64). If no architecture is given, " + "OSACA assumes a default uarch for the detected ISA.", ) parser.add_argument( "--syntax", @@ -245,7 +251,9 @@ def check_arguments(args, parser): "Microarchitecture not supported. Please see --help for all valid architecture codes." ) if args.syntax and args.arch and MachineModel.get_isa_for_arch(args.arch) != "x86": - parser.error("Syntax can only be explicitly specified for an x86 microarchitecture") + parser.error( + "Syntax can only be explicitly specified for an x86 microarchitecture" + ) if args.syntax: args.syntax = args.syntax.upper() if args.syntax not in SUPPORTED_SYNTAXES: @@ -303,7 +311,9 @@ def insert_byte_marker(args): # Check if ISA is RISC-V and raise NotImplementedError isa = MachineModel.get_isa_for_arch(args.arch) if isa == "riscv": - raise NotImplementedError("Marker insertion is not supported for RISC-V architecture.") + raise NotImplementedError( + "Marker insertion is not supported for RISC-V architecture." + ) assembly = args.file.read() unmarked_assembly = io.StringIO(assembly) @@ -408,7 +418,12 @@ def inspect(args, output_file=sys.stdout): # Create DiGrahps kernel_graph = KernelDG( - kernel, parser, machine_model, semantics, args.lcd_timeout, args.consider_flag_deps + kernel, + parser, + machine_model, + semantics, + args.lcd_timeout, + args.consider_flag_deps, ) if args.dotpath is not None: kernel_graph.export_graph(args.dotpath if args.dotpath != "." else None) @@ -459,7 +474,9 @@ def run(args, output_file=sys.stdout): ) elif "import_data" in args: # Import microbench output file into DB - import_data(args.import_data, args.arch, args.file.name, output_file=output_file) + import_data( + args.import_data, args.arch, args.file.name, output_file=output_file + ) elif args.insert_marker: # Try to add IACA marker insert_byte_marker(args) @@ -487,11 +504,15 @@ def get_asm_parser(arch, syntax="ATT") -> BaseParser: else: raise ValueError("Unknown ISA: {}".format(isa)) + def get_unmatched_instruction_ratio(kernel): """Return ratio of unmatched from total instructions in kernel.""" unmatched_counter = 0 for instruction in kernel: - if INSTR_FLAGS.TP_UNKWN in instruction.flags and INSTR_FLAGS.LT_UNKWN in instruction.flags: + if ( + INSTR_FLAGS.TP_UNKWN in instruction.flags + and INSTR_FLAGS.LT_UNKWN in instruction.flags + ): unmatched_counter += 1 return unmatched_counter / len(kernel) diff --git a/osaca/parser/base_parser.py b/osaca/parser/base_parser.py index 7606933..fa10636 100644 --- a/osaca/parser/base_parser.py +++ b/osaca/parser/base_parser.py @@ -1,5 +1,5 @@ # TODO: Heuristics for detecting the RISCV ISA -#!/usr/bin/env python3 +#!/usr/bin/env python3 # noqa: E265 """Parser superclass of specific parsers.""" import operator import re @@ -72,14 +72,19 @@ class BaseParser(object): # 3) check for RISC-V registers (x0-x31, a0-a7, t0-t6, s0-s11) and instructions heuristics_riscv = [ r"\bx[0-9]|x[1-2][0-9]|x3[0-1]\b", # x0-x31 registers - r"\ba[0-7]\b", # a0-a7 registers - r"\bt[0-6]\b", # t0-t6 registers - r"\bs[0-9]|s1[0-1]\b", # s0-s11 registers - r"\bzero\b|\bra\b|\bsp\b|\bgp\b", # zero, ra, sp, gp registers - r"\bvsetvli\b|\bvle\b|\bvse\b", # RV Vector instructions - r"\baddi\b|\bsd\b|\bld\b|\bjal\b" # Common RISC-V instructions + r"\ba[0-7]\b", # a0-a7 registers + r"\bt[0-6]\b", # t0-t6 registers + r"\bs[0-9]|s1[0-1]\b", # s0-s11 registers + r"\bzero\b|\bra\b|\bsp\b|\bgp\b", # zero, ra, sp, gp registers + r"\bvsetvli\b|\bvle\b|\bvse\b", # RV Vector instructions + r"\baddi\b|\bsd\b|\bld\b|\bjal\b", # Common RISC-V instructions ] - matches = {("x86", "ATT"): 0, ("x86", "INTEL"): 0, ("aarch64", None): 0, ("riscv", None): 0} + matches = { + ("x86", "ATT"): 0, + ("x86", "INTEL"): 0, + ("aarch64", None): 0, + ("riscv", None): 0, + } for h in heuristics_x86ATT: matches[("x86", "ATT")] += len(re.findall(h, file_content)) diff --git a/osaca/parser/condition.py b/osaca/parser/condition.py index 1acdf50..ed2d1a1 100644 --- a/osaca/parser/condition.py +++ b/osaca/parser/condition.py @@ -22,7 +22,10 @@ class ConditionOperand(Operand): self._ccode = ccode def __str__(self): - return f"Condition(ccode={self._ccode}, source={self._source}, destination={self._destination})" + return ( + f"Condition(ccode={self._ccode}, source={self._source}, " + f"destination={self._destination})" + ) def __repr__(self): return self.__str__() diff --git a/osaca/parser/directive.py b/osaca/parser/directive.py index fff2616..0bc7cc0 100644 --- a/osaca/parser/directive.py +++ b/osaca/parser/directive.py @@ -28,7 +28,9 @@ class DirectiveOperand(Operand): if isinstance(other, DirectiveOperand): return self._name == other._name and self._parameters == other._parameters elif isinstance(other, dict): - return self._name == other["name"] and self._parameters == other["parameters"] + return ( + self._name == other["name"] and self._parameters == other["parameters"] + ) return False def __str__(self): diff --git a/osaca/parser/identifier.py b/osaca/parser/identifier.py index 87c3d76..f8a51ca 100644 --- a/osaca/parser/identifier.py +++ b/osaca/parser/identifier.py @@ -4,7 +4,9 @@ from osaca.parser.operand import Operand class IdentifierOperand(Operand): - def __init__(self, name=None, offset=None, relocation=None, source=False, destination=False): + def __init__( + self, name=None, offset=None, relocation=None, source=False, destination=False + ): super().__init__(source, destination) self._name = name self._offset = offset @@ -36,7 +38,8 @@ class IdentifierOperand(Operand): def __str__(self): return ( - f"Identifier(name={self._name}, offset={self._offset}, relocation={self._relocation})" + f"Identifier(name={self._name}, offset={self._offset}, " + f"relocation={self._relocation})" ) def __repr__(self): diff --git a/osaca/parser/immediate.py b/osaca/parser/immediate.py index afda34c..0dd9c1e 100644 --- a/osaca/parser/immediate.py +++ b/osaca/parser/immediate.py @@ -10,6 +10,8 @@ class ImmediateOperand(Operand): imd_type=None, value=None, shift=None, + reloc_type=None, + symbol=None, source=False, destination=False, ): @@ -18,6 +20,8 @@ class ImmediateOperand(Operand): self._imd_type = imd_type self._value = value self._shift = shift + self._reloc_type = reloc_type + self._symbol = symbol @property def identifier(self): @@ -33,7 +37,15 @@ class ImmediateOperand(Operand): @property def shift(self): - return self._imd_type + return self._shift + + @property + def reloc_type(self): + return self._reloc_type + + @property + def symbol(self): + return self._symbol @imd_type.setter def imd_type(self, itype): @@ -51,10 +63,19 @@ class ImmediateOperand(Operand): def shift(self, shift): self._shift = shift + @reloc_type.setter + def reloc_type(self, reloc_type): + self._reloc_type = reloc_type + + @symbol.setter + def symbol(self, symbol): + self._symbol = symbol + def __str__(self): return ( f"Immediate(identifier={self._identifier}, imd_type={self._imd_type}, " - f"value={self._value}, shift={self._shift}, source={self._source}, destination={self._destination})" + f"value={self._value}, shift={self._shift}, reloc_type={self._reloc_type}, " + f"symbol={self._symbol}, source={self._source}, destination={self._destination})" ) def __repr__(self): @@ -62,10 +83,18 @@ class ImmediateOperand(Operand): def __eq__(self, other): if isinstance(other, ImmediateOperand): + # Handle cases where old instances might not have the new attributes + self_reloc_type = getattr(self, "_reloc_type", None) + self_symbol = getattr(self, "_symbol", None) + other_reloc_type = getattr(other, "_reloc_type", None) + other_symbol = getattr(other, "_symbol", None) + return ( self._identifier == other._identifier and self._imd_type == other._imd_type and self._value == other._value and self._shift == other._shift + and self_reloc_type == other_reloc_type + and self_symbol == other_symbol ) return False diff --git a/osaca/parser/parser_AArch64.py b/osaca/parser/parser_AArch64.py index 12f44b2..507b9fa 100644 --- a/osaca/parser/parser_AArch64.py +++ b/osaca/parser/parser_AArch64.py @@ -34,10 +34,15 @@ class ParserAArch64(BaseParser): return [ InstructionForm( mnemonic="mov", - operands=[RegisterOperand(name="1", prefix="x"), ImmediateOperand(value=111)], + operands=[ + RegisterOperand(name="1", prefix="x"), + ImmediateOperand(value=111), + ], ), InstructionForm( - directive_id=DirectiveOperand(name="byte", parameters=["213", "3", "32", "31"]) + directive_id=DirectiveOperand( + name="byte", parameters=["213", "3", "32", "31"] + ) ), ] @@ -45,10 +50,15 @@ class ParserAArch64(BaseParser): return [ InstructionForm( mnemonic="mov", - operands=[RegisterOperand(name="1", prefix="x"), ImmediateOperand(value=222)], + operands=[ + RegisterOperand(name="1", prefix="x"), + ImmediateOperand(value=222), + ], ), InstructionForm( - directive_id=DirectiveOperand(name="byte", parameters=["213", "3", "32", "31"]) + directive_id=DirectiveOperand( + name="byte", parameters=["213", "3", "32", "31"] + ) ), ] @@ -88,7 +98,9 @@ class ParserAArch64(BaseParser): hex_number = pp.Combine( pp.Optional(pp.Literal("-")) + pp.Literal("0x") + pp.Word(pp.hexnums) ).setResultsName("value") - relocation = pp.Combine(pp.Literal(":") + pp.Word(pp.alphanums + "_") + pp.Literal(":")) + relocation = pp.Combine( + pp.Literal(":") + pp.Word(pp.alphanums + "_") + pp.Literal(":") + ) first = pp.Word(pp.alphas + "_.", exact=1) rest = pp.Word(pp.alphanums + "_.") identifier = pp.Group( @@ -101,7 +113,9 @@ class ParserAArch64(BaseParser): ).setResultsName(self.identifier) # Label self.label = pp.Group( - identifier.setResultsName("name") + pp.Literal(":") + pp.Optional(self.comment) + identifier.setResultsName("name") + + pp.Literal(":") + + pp.Optional(self.comment) ).setResultsName(self.label_id) # Directive directive_option = pp.Combine( @@ -109,13 +123,21 @@ class ParserAArch64(BaseParser): + pp.Optional(pp.Word(pp.printables + " ", excludeChars=",")) ) directive_parameter = ( - pp.quotedString | directive_option | identifier | hex_number | decimal_number + pp.quotedString + | directive_option + | identifier + | hex_number + | decimal_number + ) + commaSeparatedList = pp.delimitedList( + pp.Optional(directive_parameter), delim="," ) - commaSeparatedList = pp.delimitedList(pp.Optional(directive_parameter), delim=",") self.directive = pp.Group( pp.Literal(".") + pp.Word(pp.alphanums + "_").setResultsName("name") - + (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName("parameters") + + (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName( + "parameters" + ) + pp.Optional(self.comment) ).setResultsName(self.directive_id) # LLVM-MCA markers @@ -137,7 +159,10 @@ class ParserAArch64(BaseParser): # int: ^-?[0-9]+ | hex: ^0x[0-9a-fA-F]+ | fp: ^[0-9]{1}.[0-9]+[eE]{1}[\+-]{1}[0-9]+[fF]? symbol_immediate = "#" mantissa = pp.Combine( - pp.Optional(pp.Literal("-")) + pp.Word(pp.nums) + pp.Literal(".") + pp.Word(pp.nums) + pp.Optional(pp.Literal("-")) + + pp.Word(pp.nums) + + pp.Literal(".") + + pp.Word(pp.nums) ).setResultsName("mantissa") exponent = ( pp.CaselessLiteral("e") @@ -182,7 +207,9 @@ class ParserAArch64(BaseParser): + pp.Word(pp.nums).setResultsName("name") + pp.Optional(pp.Literal("!")).setResultsName("pre_indexed") ) - index = pp.Literal("[") + pp.Word(pp.nums).setResultsName("index") + pp.Literal("]") + index = ( + pp.Literal("[") + pp.Word(pp.nums).setResultsName("index") + pp.Literal("]") + ) vector = ( pp.oneOf("v z", caseless=True).setResultsName("prefix") + pp.Word(pp.nums).setResultsName("name") @@ -212,10 +239,12 @@ class ParserAArch64(BaseParser): register_list = ( pp.Literal("{") + ( - pp.delimitedList(pp.Combine(self.list_element), delim=",").setResultsName("list") - ^ pp.delimitedList(pp.Combine(self.list_element), delim="-").setResultsName( - "range" - ) + pp.delimitedList( + pp.Combine(self.list_element), delim="," + ).setResultsName("list") + ^ pp.delimitedList( + pp.Combine(self.list_element), delim="-" + ).setResultsName("range") ) + pp.Literal("}") + pp.Optional(index) @@ -238,21 +267,30 @@ class ParserAArch64(BaseParser): pp.Literal("[") + pp.Optional(register.setResultsName("base")) + pp.Optional(pp.Suppress(pp.Literal(","))) - + pp.Optional(register_index ^ (immediate ^ arith_immediate).setResultsName("offset")) + + pp.Optional( + register_index ^ (immediate ^ arith_immediate).setResultsName("offset") + ) + pp.Literal("]") + pp.Optional( pp.Literal("!").setResultsName("pre_indexed") - | (pp.Suppress(pp.Literal(",")) + immediate.setResultsName("post_indexed")) + | ( + pp.Suppress(pp.Literal(",")) + + immediate.setResultsName("post_indexed") + ) ) ).setResultsName(self.memory_id) prefetch_op = pp.Group( - pp.Group(pp.CaselessLiteral("PLD") ^ pp.CaselessLiteral("PST")).setResultsName("type") + pp.Group( + pp.CaselessLiteral("PLD") ^ pp.CaselessLiteral("PST") + ).setResultsName("type") + pp.Group( - pp.CaselessLiteral("L1") ^ pp.CaselessLiteral("L2") ^ pp.CaselessLiteral("L3") + pp.CaselessLiteral("L1") + ^ pp.CaselessLiteral("L2") + ^ pp.CaselessLiteral("L3") ).setResultsName("target") - + pp.Group(pp.CaselessLiteral("KEEP") ^ pp.CaselessLiteral("STRM")).setResultsName( - "policy" - ) + + pp.Group( + pp.CaselessLiteral("KEEP") ^ pp.CaselessLiteral("STRM") + ).setResultsName("policy") ).setResultsName("prfop") # Condition codes, based on http://tiny.cc/armcc condition = ( @@ -323,7 +361,9 @@ class ParserAArch64(BaseParser): # 1. Parse comment try: - result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict()) + result = self.process_operand( + self.comment.parseString(line, parseAll=True).asDict() + ) instruction_form.comment = " ".join(result[self.comment_id]) except pp.ParseException: pass @@ -339,7 +379,9 @@ class ParserAArch64(BaseParser): if result is None: try: # returns tuple with label operand and comment, if any - result = self.process_operand(self.label.parseString(line, parseAll=True).asDict()) + result = self.process_operand( + self.label.parseString(line, parseAll=True).asDict() + ) instruction_form.label = result[0].name if result[1] is not None: instruction_form.comment = " ".join(result[1]) @@ -381,33 +423,57 @@ class ParserAArch64(BaseParser): :param str instruction: Assembly line string. :returns: `dict` -- parsed instruction form """ - result = self.instruction_parser.parseString(instruction, parseAll=True).asDict() + result = self.instruction_parser.parseString( + instruction, parseAll=True + ).asDict() operands = [] # Add operands to list # Check first operand if "operand1" in result: operand = self.process_operand(result["operand1"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) + ( + operands.extend(operand) + if isinstance(operand, list) + else operands.append(operand) + ) # Check second operand if "operand2" in result: operand = self.process_operand(result["operand2"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) + ( + operands.extend(operand) + if isinstance(operand, list) + else operands.append(operand) + ) # Check third operand if "operand3" in result: operand = self.process_operand(result["operand3"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) + ( + operands.extend(operand) + if isinstance(operand, list) + else operands.append(operand) + ) # Check fourth operand if "operand4" in result: operand = self.process_operand(result["operand4"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) + ( + operands.extend(operand) + if isinstance(operand, list) + else operands.append(operand) + ) # Check fifth operand if "operand5" in result: operand = self.process_operand(result["operand5"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) + ( + operands.extend(operand) + if isinstance(operand, list) + else operands.append(operand) + ) return_dict = InstructionForm( mnemonic=result["mnemonic"], operands=operands, - comment_id=" ".join(result[self.comment_id]) if self.comment_id in result else None, + comment_id=( + " ".join(result[self.comment_id]) if self.comment_id in result else None + ), ) return return_dict @@ -421,8 +487,13 @@ class ParserAArch64(BaseParser): "list" in operand[self.register_id] or "range" in operand[self.register_id] ): # resolve ranges and lists - return self.resolve_range_list(self.process_register_list(operand[self.register_id])) - if self.register_id in operand and operand[self.register_id]["name"].lower() == "sp": + return self.resolve_range_list( + self.process_register_list(operand[self.register_id]) + ) + if ( + self.register_id in operand + and operand[self.register_id]["name"].lower() == "sp" + ): return self.process_sp_register(operand[self.register_id]) # add value attribute to floating point immediates without exponent if self.immediate_id in operand: @@ -464,7 +535,9 @@ class ParserAArch64(BaseParser): shape=operand["shape"].lower() if "shape" in operand else None, lanes=operand["lanes"] if "lanes" in operand else None, index=operand["index"] if "index" in operand else None, - predication=operand["predication"].lower() if "predication" in operand else None, + predication=( + operand["predication"].lower() if "predication" in operand else None + ), pre_indexed=True if "pre_indexed" in operand else False, ) @@ -511,7 +584,9 @@ class ParserAArch64(BaseParser): new_dict.pre_indexed = True if "post_indexed" in memory_address: if "value" in memory_address["post_indexed"]: - new_dict.post_indexed = {"value": int(memory_address["post_indexed"]["value"], 0)} + new_dict.post_indexed = { + "value": int(memory_address["post_indexed"]["value"], 0) + } else: new_dict.post_indexed = memory_address["post_indexed"] return new_dict @@ -585,13 +660,17 @@ class ParserAArch64(BaseParser): # normal integer value immediate["type"] = "int" # convert hex/bin immediates to dec - new_immediate = ImmediateOperand(imd_type=immediate["type"], value=immediate["value"]) + new_immediate = ImmediateOperand( + imd_type=immediate["type"], value=immediate["value"] + ) new_immediate.value = self.normalize_imd(new_immediate) return new_immediate if "base_immediate" in immediate: # arithmetic immediate, add calculated value as value immediate["shift"] = immediate["shift"][0] - temp_immediate = ImmediateOperand(value=immediate["base_immediate"]["value"]) + temp_immediate = ImmediateOperand( + value=immediate["base_immediate"]["value"] + ) immediate["type"] = "int" new_immediate = ImmediateOperand( imd_type=immediate["type"], value=None, shift=immediate["shift"] @@ -606,10 +685,14 @@ class ParserAArch64(BaseParser): dict_name = "double" if "exponent" in immediate[dict_name]: immediate["type"] = dict_name - return ImmediateOperand(imd_type=immediate["type"], value=immediate[immediate["type"]]) + return ImmediateOperand( + imd_type=immediate["type"], value=immediate[immediate["type"]] + ) else: # change 'mantissa' key to 'value' - return ImmediateOperand(value=immediate[dict_name]["mantissa"], imd_type=dict_name) + return ImmediateOperand( + value=immediate[dict_name]["mantissa"], imd_type=dict_name + ) def process_label(self, label): """Post-process label asm line""" @@ -632,7 +715,9 @@ class ParserAArch64(BaseParser): name = register.prefix + str(register.name) if register.shape is not None: name += ( - "." + str(register.lanes if register.lanes is not None else "") + register.shape + "." + + str(register.lanes if register.lanes is not None else "") + + register.shape ) if register.index is not None: name += "[" + str(register.index) + "]" @@ -707,9 +792,15 @@ class ParserAArch64(BaseParser): prefixes_gpr = "wx" prefixes_vec = "bhsdqvz" if reg_a.name == reg_b.name: - if reg_a.prefix.lower() in prefixes_gpr and reg_b.prefix.lower() in prefixes_gpr: + if ( + reg_a.prefix.lower() in prefixes_gpr + and reg_b.prefix.lower() in prefixes_gpr + ): return True - if reg_a.prefix.lower() in prefixes_vec and reg_b.prefix.lower() in prefixes_vec: + if ( + reg_a.prefix.lower() in prefixes_vec + and reg_b.prefix.lower() in prefixes_vec + ): return True return False diff --git a/osaca/parser/parser_RISCV.py b/osaca/parser/parser_RISCV.py index b1db9fe..bdf72f9 100644 --- a/osaca/parser/parser_RISCV.py +++ b/osaca/parser/parser_RISCV.py @@ -1,7 +1,5 @@ #!/usr/bin/env python3 import re -import os -from copy import deepcopy import pyparsing as pp from osaca.parser import BaseParser @@ -13,7 +11,6 @@ from osaca.parser.label import LabelOperand from osaca.parser.register import RegisterOperand from osaca.parser.identifier import IdentifierOperand from osaca.parser.immediate import ImmediateOperand -from osaca.parser.condition import ConditionOperand class ParserRISCV(BaseParser): @@ -61,7 +58,7 @@ class ParserRISCV(BaseParser): self.comment = pp.Literal(symbol_comment) + pp.Group( pp.ZeroOrMore(pp.Word(pp.printables)) ).setResultsName(self.comment_id) - + # Define RISC-V assembly identifier decimal_number = pp.Combine( pp.Optional(pp.Literal("-")) + pp.Word(pp.nums) @@ -69,18 +66,32 @@ class ParserRISCV(BaseParser): hex_number = pp.Combine( pp.Optional(pp.Literal("-")) + pp.Literal("0x") + pp.Word(pp.hexnums) ).setResultsName("value") - - # Additional identifiers used in vector instructions - vector_identifier = pp.Word(pp.alphas, pp.alphanums) - special_identifier = pp.Word(pp.alphas + "%") - - # First character of an identifier + + # RISC-V specific relocation attributes + reloc_type = ( + pp.Literal("%hi") + | pp.Literal("%lo") + | pp.Literal("%pcrel_hi") + | pp.Literal("%pcrel_lo") + | pp.Literal("%tprel_hi") + | pp.Literal("%tprel_lo") + | pp.Literal("%tprel_add") + ).setResultsName("reloc_type") + + reloc_expr = pp.Group( + reloc_type + + pp.Suppress("(") + + pp.Word(pp.alphas + pp.nums + "_").setResultsName("symbol") + + pp.Suppress(")") + ).setResultsName("relocation") + + # First character of an identifier first = pp.Word(pp.alphas + "_.", exact=1) # Rest of the identifier rest = pp.Word(pp.alphanums + "_.") # PLT suffix (@plt) for calls to shared libraries plt_suffix = pp.Optional(pp.Literal("@") + pp.Word(pp.alphas)) - + identifier = pp.Group( (pp.Combine(first + pp.Optional(rest) + plt_suffix)).setResultsName("name") + pp.Optional( @@ -88,31 +99,44 @@ class ParserRISCV(BaseParser): + (hex_number | decimal_number).setResultsName("offset") ) ).setResultsName(self.identifier) - + + # Immediate with optional relocation + immediate = pp.Group( + reloc_expr | (hex_number ^ decimal_number) | identifier + ).setResultsName(self.immediate_id) + # Label self.label = pp.Group( - identifier.setResultsName("name") + pp.Literal(":") + pp.Optional(self.comment) + identifier.setResultsName("name") + + pp.Literal(":") + + pp.Optional(self.comment) ).setResultsName(self.label_id) - + # Directive directive_option = pp.Combine( pp.Word(pp.alphas + "#@.%", exact=1) + pp.Optional(pp.Word(pp.printables + " ", excludeChars=",")) ) - - # For vector instructions - vector_parameter = pp.Word(pp.alphas) + directive_parameter = ( - pp.quotedString | directive_option | identifier | hex_number | decimal_number + pp.quotedString + | directive_option + | identifier + | hex_number + | decimal_number + ) + commaSeparatedList = pp.delimitedList( + pp.Optional(directive_parameter), delim="," ) - commaSeparatedList = pp.delimitedList(pp.Optional(directive_parameter), delim=",") self.directive = pp.Group( pp.Literal(".") + pp.Word(pp.alphanums + "_").setResultsName("name") - + (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName("parameters") + + (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName( + "parameters" + ) + pp.Optional(self.comment) ).setResultsName(self.directive_id) - + # LLVM-MCA markers self.llvm_markers = pp.Group( pp.Literal("#") @@ -127,61 +151,58 @@ class ParserRISCV(BaseParser): # Instructions # Mnemonic mnemonic = pp.Word(pp.alphanums + ".").setResultsName("mnemonic") - - # Immediate: - # int: ^-?[0-9]+ | hex: ^0x[0-9a-fA-F]+ - immediate = pp.Group( - (hex_number ^ decimal_number) - | identifier - ).setResultsName(self.immediate_id) - + # Register: # RISC-V has two main types of registers: # 1. Integer registers (x0-x31 or ABI names) # 2. Floating-point registers (f0-f31 or ABI names) - + # Integer register ABI names integer_reg_abi = ( - pp.CaselessLiteral("zero") | - pp.CaselessLiteral("ra") | - pp.CaselessLiteral("sp") | - pp.CaselessLiteral("gp") | - pp.CaselessLiteral("tp") | - pp.Regex(r"[tas][0-9]+") # t0-t6, a0-a7, s0-s11 + pp.CaselessLiteral("zero") + | pp.CaselessLiteral("ra") + | pp.CaselessLiteral("sp") + | pp.CaselessLiteral("gp") + | pp.CaselessLiteral("tp") + | pp.Regex(r"[tas][0-9]+") # t0-t6, a0-a7, s0-s11 ).setResultsName("name") - + # Integer registers x0-x31 - integer_reg_x = ( - pp.CaselessLiteral("x").setResultsName("prefix") + - pp.Word(pp.nums).setResultsName("name") - ) - + integer_reg_x = pp.CaselessLiteral("x").setResultsName("prefix") + pp.Word( + pp.nums + ).setResultsName("name") + # Floating point registers - fp_reg_abi = pp.Regex(r"f[tas][0-9]+").setResultsName("name") # ft0-ft11, fa0-fa7, fs0-fs11 - - fp_reg_f = ( - pp.CaselessLiteral("f").setResultsName("prefix") + - pp.Word(pp.nums).setResultsName("name") - ) - + fp_reg_abi = pp.Regex(r"f[tas][0-9]+").setResultsName( + "name" + ) # ft0-ft11, fa0-fa7, fs0-fs11 + + fp_reg_f = pp.CaselessLiteral("f").setResultsName("prefix") + pp.Word( + pp.nums + ).setResultsName("name") + # Control and status registers (CSRs) csr_reg = pp.Combine( pp.CaselessLiteral("csr") + pp.Word(pp.alphanums + "_") ).setResultsName("name") - + # Vector registers (for the "V" extension) - vector_reg = ( - pp.CaselessLiteral("v").setResultsName("prefix") + - pp.Word(pp.nums).setResultsName("name") - ) - + vector_reg = pp.CaselessLiteral("v").setResultsName("prefix") + pp.Word( + pp.nums + ).setResultsName("name") + # Combined register definition register = pp.Group( - integer_reg_x | integer_reg_abi | fp_reg_f | fp_reg_abi | vector_reg | csr_reg + integer_reg_x + | integer_reg_abi + | fp_reg_f + | fp_reg_abi + | vector_reg + | csr_reg ).setResultsName(self.register_id) - + self.register = register - + # Memory addressing mode in RISC-V: offset(base_register) memory = pp.Group( pp.Optional(immediate.setResultsName("offset")) @@ -189,24 +210,19 @@ class ParserRISCV(BaseParser): + register.setResultsName("base") + pp.Suppress(pp.Literal(")")) ).setResultsName(self.memory_id) - + # Combine to instruction form - operand_first = pp.Group( - register ^ immediate ^ memory ^ identifier - ) - operand_rest = pp.Group( - register ^ immediate ^ memory ^ identifier - ) - - # Vector instruction special parameters (e.g., e32, m4, ta, ma) - vector_param = pp.Word(pp.alphas + pp.nums) - + operand_first = pp.Group(register ^ immediate ^ memory ^ identifier) + operand_rest = pp.Group(register ^ immediate ^ memory ^ identifier) + # Handle additional vector parameters additional_params = pp.ZeroOrMore( - pp.Suppress(pp.Literal(",")) + - vector_param.setResultsName("vector_param", listAllMatches=True) + pp.Suppress(pp.Literal(",")) + + pp.Word(pp.alphas + pp.nums).setResultsName( + "vector_param", listAllMatches=True + ) ) - + # Main instruction parser self.instruction_parser = ( mnemonic @@ -217,7 +233,7 @@ class ParserRISCV(BaseParser): + pp.Optional(operand_rest.setResultsName("operand3")) + pp.Optional(pp.Suppress(pp.Literal(","))) + pp.Optional(operand_rest.setResultsName("operand4")) - + pp.Optional(additional_params) # For vector instructions with more params + + pp.Optional(additional_params) + pp.Optional(self.comment) ) @@ -228,7 +244,8 @@ class ParserRISCV(BaseParser): :param str line: line of assembly code :param line_number: identifier of instruction form, defaults to None :type line_number: int, optional - :return: `dict` -- parsed asm line (comment, label, directive or instruction form) + :return: `dict` -- parsed asm line (comment, label, directive or + instruction form) """ instruction_form = InstructionForm( mnemonic=None, @@ -243,11 +260,13 @@ class ParserRISCV(BaseParser): # 1. Parse comment try: - result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict()) + result = self.process_operand( + self.comment.parseString(line, parseAll=True).asDict() + ) instruction_form.comment = " ".join(result[self.comment_id]) except pp.ParseException: pass - + # 1.2 check for llvm-mca marker try: result = self.process_operand( @@ -256,12 +275,14 @@ class ParserRISCV(BaseParser): instruction_form.comment = " ".join(result[self.comment_id]) except pp.ParseException: pass - + # 2. Parse label if result is None: try: # returns tuple with label operand and comment, if any - result = self.process_operand(self.label.parseString(line, parseAll=True).asDict()) + result = self.process_operand( + self.label.parseString(line, parseAll=True).asDict() + ) instruction_form.label = result[0].name if result[1] is not None: instruction_form.comment = " ".join(result[1]) @@ -294,7 +315,7 @@ class ParserRISCV(BaseParser): instruction_form.mnemonic = result.mnemonic instruction_form.operands = result.operands instruction_form.comment = result.comment - + return instruction_form def parse_instruction(self, instruction): @@ -304,75 +325,99 @@ class ParserRISCV(BaseParser): :param str instruction: Assembly line string. :returns: `dict` -- parsed instruction form """ + # Store current instruction for context in operand processing + if instruction.startswith("vsetvli"): + self.current_instruction = "vsetvli" + else: + # Extract mnemonic for context + parts = instruction.split("#")[0].strip().split() + self.current_instruction = parts[0] if parts else None + # Special handling for vector instructions like vsetvli with many parameters if instruction.startswith("vsetvli"): - parts = instruction.split("#")[0].strip().split() + # Split into mnemonic and operands part + parts = ( + instruction.split("#")[0].strip().split(None, 1) + ) # Split on first whitespace only mnemonic = parts[0] - + # Split operands by commas if len(parts) > 1: operand_part = parts[1] operands_list = [op.strip() for op in operand_part.split(",")] - + # Process each operand operands = [] for op in operands_list: - if op.startswith("x") or op in ["zero", "ra", "sp", "gp", "tp"] or re.match(r"[tas][0-9]+", op): + if ( + op.startswith("x") + or op in ["zero", "ra", "sp", "gp", "tp"] + or re.match(r"[tas][0-9]+", op) + ): operands.append(RegisterOperand(name=op)) - elif op in ["e8", "e16", "e32", "e64", "m1", "m2", "m4", "m8", "ta", "tu", "ma", "mu"]: - operands.append(IdentifierOperand(name=op)) else: - operands.append(IdentifierOperand(name=op)) - + # Vector parameters get appropriate attributes + if op.startswith("e"): # Element width + operands.append(IdentifierOperand(name=op)) + elif op.startswith("m"): # LMUL setting + operands.append(IdentifierOperand(name=op)) + elif op in ["ta", "tu", "ma", "mu"]: # Tail/mask policies + operands.append(IdentifierOperand(name=op)) + else: + operands.append(IdentifierOperand(name=op)) + # Get comment if present comment = None if "#" in instruction: comment = instruction.split("#", 1)[1].strip() - + return InstructionForm( - mnemonic=mnemonic, - operands=operands, - comment_id=comment + mnemonic=mnemonic, operands=operands, comment_id=comment ) - + # Regular instruction parsing try: - result = self.instruction_parser.parseString(instruction, parseAll=True).asDict() + result = self.instruction_parser.parseString( + instruction, parseAll=True + ).asDict() operands = [] - # Add operands to list - # Check first operand - if "operand1" in result: - operand = self.process_operand(result["operand1"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) - # Check second operand - if "operand2" in result: - operand = self.process_operand(result["operand2"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) - # Check third operand - if "operand3" in result: - operand = self.process_operand(result["operand3"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) - # Check fourth operand - if "operand4" in result: - operand = self.process_operand(result["operand4"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) - - # Handle vector_param for vector instructions + + # Process operands + for i in range(1, 5): + operand_key = f"operand{i}" + if operand_key in result: + operand = self.process_operand(result[operand_key]) + ( + operands.extend(operand) + if isinstance(operand, list) + else operands.append(operand) + ) + + # Handle vector parameters as identifiers with appropriate attributes if "vector_param" in result: if isinstance(result["vector_param"], list): for param in result["vector_param"]: - operands.append(IdentifierOperand(name=param)) + if param.startswith("e"): # Element width + operands.append(IdentifierOperand(name=param)) + elif param.startswith("m"): # LMUL setting + operands.append(IdentifierOperand(name=param)) + else: + operands.append(IdentifierOperand(name=param)) else: operands.append(IdentifierOperand(name=result["vector_param"])) - + return_dict = InstructionForm( mnemonic=result["mnemonic"], operands=operands, - comment_id=" ".join(result[self.comment_id]) if self.comment_id in result else None, + comment_id=( + " ".join(result[self.comment_id]) + if self.comment_id in result + else None + ), ) return return_dict - - except Exception as e: + + except Exception: # For special vector instructions or ones with % in them if "%" in instruction or instruction.startswith("v"): parts = instruction.split("#")[0].strip().split(None, 1) @@ -383,21 +428,26 @@ class ParserRISCV(BaseParser): operands_list = [op.strip() for op in operand_part.split(",")] for op in operands_list: # Process '%hi(data)' to 'data' for certain operands - if op.startswith("%") and '(' in op and ')' in op: - # Extract data from %hi(data) format - data = op[op.index('(')+1:op.index(')')] - operands.append(IdentifierOperand(name=data)) + if op.startswith("%") and "(" in op and ")" in op: + reloc_type = op[: op.index("(")] + symbol = op[op.index("(") + 1 : op.index(")")] + operands.append( + ImmediateOperand( + imd_type="reloc", + value=None, + reloc_type=reloc_type, + symbol=symbol, + ) + ) else: operands.append(IdentifierOperand(name=op)) - + comment = None if "#" in instruction: comment = instruction.split("#", 1)[1].strip() - + return InstructionForm( - mnemonic=mnemonic, - operands=operands, - comment_id=comment + mnemonic=mnemonic, operands=operands, comment_id=comment ) else: raise @@ -430,62 +480,127 @@ class ParserRISCV(BaseParser): ) def process_register_operand(self, operand): - """Process register operands, including ABI name to x-register mapping""" - # If already has prefix (x#, f#, v#), just return as is + """Process register operands, including ABI name to x-register mapping + and vector attributes""" + # If already has prefix (x#, f#, v#), process with appropriate attributes if "prefix" in operand: - return RegisterOperand( - prefix=operand["prefix"].lower(), - name=operand["name"] - ) - + prefix = operand["prefix"].lower() + + # Special handling for vector registers + if prefix == "v": + return RegisterOperand( + prefix=prefix, + name=operand["name"], + regtype="vector", + # Vector registers can have different element widths (e8,e16,e32,e64) + width=operand.get("width", None), + # Number of elements (m1,m2,m4,m8) + lanes=operand.get("lanes", None), + # For vector mask registers + mask=operand.get("mask", False), + # For tail agnostic/undisturbed policies + zeroing=operand.get("zeroing", False), + ) + # For floating point registers + elif prefix == "f": + return RegisterOperand( + prefix=prefix, + name=operand["name"], + regtype="float", + width=64, # RISC-V typically uses 64-bit float registers + ) + # For integer registers + elif prefix == "x": + return RegisterOperand( + prefix=prefix, + name=operand["name"], + regtype="int", + width=64, # RV64 uses 64-bit registers + ) + # Handle ABI names by converting to x-register numbers name = operand["name"].lower() - + # ABI name mapping for integer registers abi_to_x = { - "zero": "0", "ra": "1", "sp": "2", "gp": "3", "tp": "4", - "t0": "5", "t1": "6", "t2": "7", - "s0": "8", "fp": "8", "s1": "9", - "a0": "10", "a1": "11", "a2": "12", "a3": "13", - "a4": "14", "a5": "15", "a6": "16", "a7": "17", - "s2": "18", "s3": "19", "s4": "20", "s5": "21", - "s6": "22", "s7": "23", "s8": "24", "s9": "25", - "s10": "26", "s11": "27", - "t3": "28", "t4": "29", "t5": "30", "t6": "31" + "zero": "x0", + "ra": "x1", + "sp": "x2", + "gp": "x3", + "tp": "x4", + "t0": "x5", + "t1": "x6", + "t2": "x7", + "s0": "x8", + "s1": "x9", + "a0": "x10", + "a1": "x11", + "a2": "x12", + "a3": "x13", + "a4": "x14", + "a5": "x15", + "a6": "x16", + "a7": "x17", + "s2": "x18", + "s3": "x19", + "s4": "x20", + "s5": "x21", + "s6": "x22", + "s7": "x23", + "s8": "x24", + "s9": "x25", + "s10": "x26", + "s11": "x27", + "t3": "x28", + "t4": "x29", + "t5": "x30", + "t6": "x31", } - + # Integer register ABI names if name in abi_to_x: return RegisterOperand( prefix="x", - name=abi_to_x[name] + name=abi_to_x[name], + regtype="int", + width=64, # RV64 uses 64-bit registers ) # Floating point register ABI names elif name.startswith("f") and name[1] in ["t", "a", "s"]: if name[1] == "a": # fa0-fa7 idx = int(name[2:]) - return RegisterOperand(prefix="f", name=str(idx + 10)) + return RegisterOperand( + prefix="f", name=str(idx + 10), regtype="float", width=64 + ) elif name[1] == "s": # fs0-fs11 idx = int(name[2:]) if idx <= 1: - return RegisterOperand(prefix="f", name=str(idx + 8)) + return RegisterOperand( + prefix="f", name=str(idx + 8), regtype="float", width=64 + ) else: - return RegisterOperand(prefix="f", name=str(idx + 16)) + return RegisterOperand( + prefix="f", name=str(idx + 16), regtype="float", width=64 + ) elif name[1] == "t": # ft0-ft11 idx = int(name[2:]) if idx <= 7: - return RegisterOperand(prefix="f", name=str(idx)) + return RegisterOperand( + prefix="f", name=str(idx), regtype="float", width=64 + ) else: - return RegisterOperand(prefix="f", name=str(idx + 20)) + return RegisterOperand( + prefix="f", name=str(idx + 20), regtype="float", width=64 + ) # CSR registers elif name.startswith("csr"): - return RegisterOperand(prefix="", name=name) - + return RegisterOperand(prefix="", name=name, regtype="csr") + # If no mapping found, return as is return RegisterOperand(prefix="", name=name) def process_memory_address(self, memory_address): - """Post-process memory address operand""" + """Post-process memory address operand with RISC-V specific attributes""" # Process offset offset = memory_address.get("offset", None) if isinstance(offset, list) and len(offset) == 1: @@ -494,18 +609,38 @@ class ParserRISCV(BaseParser): offset = ImmediateOperand(value=int(offset["value"], 0)) if isinstance(offset, dict) and "identifier" in offset: offset = self.process_identifier(offset["identifier"]) - + # Process base register base = memory_address.get("base", None) if base is not None: base = self.process_register_operand(base) - - # Create memory operand + + # Determine data type from instruction context if available + # RISC-V load/store instructions encode the data width in the mnemonic + # e.g., lw (word), lh (half), lb (byte), etc. + data_type = None + if hasattr(self, "current_instruction"): + mnemonic = self.current_instruction.lower() + if any(x in mnemonic for x in ["b", "bu"]): # byte operations + data_type = "byte" + elif any(x in mnemonic for x in ["h", "hu"]): # halfword operations + data_type = "halfword" + elif any(x in mnemonic for x in ["w", "wu"]): # word operations + data_type = "word" + elif "d" in mnemonic: # doubleword operations + data_type = "doubleword" + + # Create memory operand with enhanced attributes return MemoryOperand( offset=offset, base=base, - index=None, - scale=1 + index=None, # RISC-V doesn't use index registers + scale=1, # RISC-V doesn't use scaling + data_type=data_type, + # Handle vector memory operations + mask=memory_address.get("mask", None), # For vector masked loads/stores + src=memory_address.get("src", None), # Source register type for stores + dst=memory_address.get("dst", None), # Destination register type for loads ) def process_label(self, label): @@ -519,21 +654,102 @@ class ParserRISCV(BaseParser): """Post-process identifier operand""" return IdentifierOperand( name=identifier["name"] if "name" in identifier else None, - offset=identifier["offset"] if "offset" in identifier else None + offset=identifier["offset"] if "offset" in identifier else None, ) - + def process_immediate(self, immediate): - """Post-process immediate operand""" + """Post-process immediate operand with RISC-V specific handling""" + # Handle relocations + if "relocation" in immediate: + reloc = immediate["relocation"] + return ImmediateOperand( + imd_type="reloc", + value=None, + reloc_type=reloc["reloc_type"], + symbol=reloc["symbol"], + ) + + # Handle identifiers if "identifier" in immediate: - # actually an identifier, change declaration return self.process_identifier(immediate["identifier"]) + + # Handle numeric values with validation if "value" in immediate: - # normal integer value - immediate["type"] = "int" - # convert hex/bin immediates to dec - new_immediate = ImmediateOperand(imd_type=immediate["type"], value=immediate["value"]) - new_immediate.value = self.normalize_imd(new_immediate) - return new_immediate + value = int( + immediate["value"], 0 + ) # Convert to integer, handling hex/decimal + + # Determine immediate type and validate range based on instruction type + if hasattr(self, "current_instruction"): + mnemonic = self.current_instruction.lower() + + # I-type instructions (12-bit signed immediate) + if any( + x in mnemonic + for x in [ + "addi", + "slti", + "xori", + "ori", + "andi", + "slli", + "srli", + "srai", + ] + ): + if not -2048 <= value <= 2047: + raise ValueError( + f"Immediate value {value} out of range for I-type " + f"instruction (-2048 to 2047)" + ) + return ImmediateOperand(imd_type="I", value=value) + + # S-type instructions (12-bit signed immediate for store) + elif any(x in mnemonic for x in ["sb", "sh", "sw", "sd"]): + if not -2048 <= value <= 2047: + raise ValueError( + f"Immediate value {value} out of range for S-type " + f"instruction (-2048 to 2047)" + ) + return ImmediateOperand(imd_type="S", value=value) + + # B-type instructions (13-bit signed immediate for branches, must be even) + elif any( + x in mnemonic for x in ["beq", "bne", "blt", "bge", "bltu", "bgeu"] + ): + if not -4096 <= value <= 4095 or value % 2 != 0: + raise ValueError( + f"Immediate value {value} out of range or not even " + f"for B-type instruction (-4096 to 4095, must be even)" + ) + return ImmediateOperand(imd_type="B", value=value) + + # U-type instructions (20-bit upper immediate) + elif any(x in mnemonic for x in ["lui", "auipc"]): + if not 0 <= value <= 1048575: + raise ValueError( + f"Immediate value {value} out of range for U-type " + f"instruction (0 to 1048575)" + ) + return ImmediateOperand(imd_type="U", value=value) + + # J-type instructions (21-bit signed immediate for jumps, must be even) + elif any(x in mnemonic for x in ["jal"]): + if not -1048576 <= value <= 1048575 or value % 2 != 0: + raise ValueError( + f"Immediate value {value} out of range or not even " + f"for J-type instruction (-1048576 to 1048575, must be even)" + ) + return ImmediateOperand(imd_type="J", value=value) + + # Vector instructions might have specific immediate ranges + elif mnemonic.startswith("v"): + # Handle vector specific immediates (implementation specific) + return ImmediateOperand(imd_type="V", value=value) + + # Default case - no specific validation + return ImmediateOperand(imd_type="int", value=value) + return immediate def get_full_reg_name(self, register): @@ -558,44 +774,83 @@ class ParserRISCV(BaseParser): def parse_register(self, register_string): """ Parse register string and return register dictionary. - + :param str register_string: register representation as string :returns: dict with register info """ # Remove any leading/trailing whitespace register_string = register_string.strip() - + # Check for integer registers (x0-x31) - x_match = re.match(r'^x([0-9]|[1-2][0-9]|3[0-1])$', register_string) + x_match = re.match(r"^x([0-9]|[1-2][0-9]|3[0-1])$", register_string) if x_match: reg_num = int(x_match.group(1)) - return {"class": "register", "register": {"prefix": "x", "name": str(reg_num)}} - + return { + "class": "register", + "register": {"prefix": "x", "name": str(reg_num)}, + } + # Check for floating-point registers (f0-f31) - f_match = re.match(r'^f([0-9]|[1-2][0-9]|3[0-1])$', register_string) + f_match = re.match(r"^f([0-9]|[1-2][0-9]|3[0-1])$", register_string) if f_match: reg_num = int(f_match.group(1)) - return {"class": "register", "register": {"prefix": "f", "name": str(reg_num)}} - + return { + "class": "register", + "register": {"prefix": "f", "name": str(reg_num)}, + } + # Check for vector registers (v0-v31) - v_match = re.match(r'^v([0-9]|[1-2][0-9]|3[0-1])$', register_string) + v_match = re.match(r"^v([0-9]|[1-2][0-9]|3[0-1])$", register_string) if v_match: reg_num = int(v_match.group(1)) - return {"class": "register", "register": {"prefix": "v", "name": str(reg_num)}} - + return { + "class": "register", + "register": {"prefix": "v", "name": str(reg_num)}, + } + # Check for ABI names abi_names = { - "zero": 0, "ra": 1, "sp": 2, "gp": 3, "tp": 4, - "t0": 5, "t1": 6, "t2": 7, - "s0": 8, "fp": 8, "s1": 9, - "a0": 10, "a1": 11, "a2": 12, "a3": 13, "a4": 14, "a5": 15, "a6": 16, "a7": 17, - "s2": 18, "s3": 19, "s4": 20, "s5": 21, "s6": 22, "s7": 23, "s8": 24, "s9": 25, "s10": 26, "s11": 27, - "t3": 28, "t4": 29, "t5": 30, "t6": 31 + "zero": 0, + "ra": 1, + "sp": 2, + "gp": 3, + "tp": 4, + "t0": 5, + "t1": 6, + "t2": 7, + "s0": 8, + "fp": 8, + "s1": 9, + "a0": 10, + "a1": 11, + "a2": 12, + "a3": 13, + "a4": 14, + "a5": 15, + "a6": 16, + "a7": 17, + "s2": 18, + "s3": 19, + "s4": 20, + "s5": 21, + "s6": 22, + "s7": 23, + "s8": 24, + "s9": 25, + "s10": 26, + "s11": 27, + "t3": 28, + "t4": 29, + "t5": 30, + "t6": 31, } - + if register_string in abi_names: - return {"class": "register", "register": {"prefix": "", "name": register_string}} - + return { + "class": "register", + "register": {"prefix": "", "name": register_string}, + } + # If no match is found return None @@ -626,38 +881,61 @@ class ParserRISCV(BaseParser): """Check if ``reg_a`` is dependent on ``reg_b``""" if not isinstance(reg_a, Operand): reg_a = RegisterOperand(name=reg_a["name"]) - + # Get canonical register names reg_a_canonical = self._get_canonical_reg_name(reg_a) reg_b_canonical = self._get_canonical_reg_name(reg_b) - + # Same register type and number means dependency return reg_a_canonical == reg_b_canonical - + def _get_canonical_reg_name(self, register): """Get the canonical form of a register (x-form for integer, f-form for FP)""" # If already in canonical form (x# or f#) if register.prefix in ["x", "f", "v"] and register.name.isdigit(): return f"{register.prefix}{register.name}" - + # ABI name mapping for integer registers abi_to_x = { - "zero": "x0", "ra": "x1", "sp": "x2", "gp": "x3", "tp": "x4", - "t0": "x5", "t1": "x6", "t2": "x7", - "s0": "x8", "s1": "x9", - "a0": "x10", "a1": "x11", "a2": "x12", "a3": "x13", - "a4": "x14", "a5": "x15", "a6": "x16", "a7": "x17", - "s2": "x18", "s3": "x19", "s4": "x20", "s5": "x21", - "s6": "x22", "s7": "x23", "s8": "x24", "s9": "x25", - "s10": "x26", "s11": "x27", - "t3": "x28", "t4": "x29", "t5": "x30", "t6": "x31" + "zero": "x0", + "ra": "x1", + "sp": "x2", + "gp": "x3", + "tp": "x4", + "t0": "x5", + "t1": "x6", + "t2": "x7", + "s0": "x8", + "s1": "x9", + "a0": "x10", + "a1": "x11", + "a2": "x12", + "a3": "x13", + "a4": "x14", + "a5": "x15", + "a6": "x16", + "a7": "x17", + "s2": "x18", + "s3": "x19", + "s4": "x20", + "s5": "x21", + "s6": "x22", + "s7": "x23", + "s8": "x24", + "s9": "x25", + "s10": "x26", + "s11": "x27", + "t3": "x28", + "t4": "x29", + "t5": "x30", + "t6": "x31", } - + # For integer register ABI names name = register.name.lower() if name in abi_to_x: return abi_to_x[name] - + # For FP register ABI names like fa0, fs1, etc. if name.startswith("f") and len(name) > 1: if name[1] == "a": # fa0-fa7 @@ -675,7 +953,7 @@ class ParserRISCV(BaseParser): return f"f{idx}" else: return f"f{idx + 20}" - + # Return as is if no mapping found return f"{register.prefix}{register.name}" @@ -684,7 +962,7 @@ class ParserRISCV(BaseParser): # Return register prefix if exists if register.prefix: return register.prefix - + # Determine type from ABI name name = register.name.lower() if name in ["zero", "ra", "sp", "gp", "tp"] or name[0] in ["t", "a", "s"]: @@ -693,30 +971,30 @@ class ParserRISCV(BaseParser): return "f" # Floating point register elif name.startswith("csr"): return "csr" # Control and Status Register - + return "unknown" def normalize_instruction_form(self, instruction_form, isa_model, arch_model): """ Normalize instruction form for RISC-V instructions. - + :param instruction_form: instruction form to normalize :param isa_model: ISA model to use for normalization :param arch_model: architecture model to use for normalization """ if instruction_form.normalized: return - + if instruction_form.mnemonic is None: instruction_form.normalized = True return - + # Normalize the mnemonic if needed if instruction_form.mnemonic: # Handle any RISC-V specific mnemonic normalization # For example, convert aliases or pseudo-instructions to their base form pass - + # Normalize the operands if needed for i, operand in enumerate(instruction_form.operands): if isinstance(operand, ImmediateOperand): @@ -725,8 +1003,8 @@ class ParserRISCV(BaseParser): elif isinstance(operand, RegisterOperand): # Convert register names to canonical form if needed pass - - instruction_form.normalized = True + + instruction_form.normalized = True def get_regular_source_operands(self, instruction_form): """Get source operand of given instruction form assuming regular src/dst behavior.""" @@ -736,14 +1014,14 @@ class ParserRISCV(BaseParser): return [instruction_form.operands[0]] else: return [op for op in instruction_form.operands[1:]] - + def get_regular_destination_operands(self, instruction_form): """Get destination operand of given instruction form assuming regular src/dst behavior.""" # For RISC-V, the first operand is typically the destination if len(instruction_form.operands) == 1: return [] else: - return instruction_form.operands[:1] + return instruction_form.operands[:1] def process_immediate_operand(self, operand): """Process immediate operands, converting them to ImmediateOperand objects""" @@ -751,7 +1029,7 @@ class ParserRISCV(BaseParser): # For raw integer values or string immediates return ImmediateOperand( imd_type="int", - value=str(operand) if isinstance(operand, int) else operand + value=str(operand) if isinstance(operand, int) else operand, ) elif isinstance(operand, dict) and "imd" in operand: # For immediate operands from instruction definitions @@ -759,11 +1037,8 @@ class ParserRISCV(BaseParser): imd_type=operand["imd"], value=operand.get("value"), identifier=operand.get("identifier"), - shift=operand.get("shift") + shift=operand.get("shift"), ) else: # For any other immediate format - return ImmediateOperand( - imd_type="int", - value=str(operand) - ) \ No newline at end of file + return ImmediateOperand(imd_type="int", value=str(operand)) diff --git a/osaca/parser/parser_x86att.py b/osaca/parser/parser_x86att.py index 54f5125..f63d57c 100644 --- a/osaca/parser/parser_x86att.py +++ b/osaca/parser/parser_x86att.py @@ -38,7 +38,9 @@ class ParserX86ATT(ParserX86): ), ], InstructionForm( - directive_id=DirectiveOperand(name="byte", parameters=["100", "103", "144"]) + directive_id=DirectiveOperand( + name="byte", parameters=["100", "103", "144"] + ) ), ] @@ -55,7 +57,9 @@ class ParserX86ATT(ParserX86): ), ], InstructionForm( - directive_id=DirectiveOperand(name="byte", parameters=["100", "103", "144"]) + directive_id=DirectiveOperand( + name="byte", parameters=["100", "103", "144"] + ) ), ] @@ -113,7 +117,9 @@ class ParserX86ATT(ParserX86): label_identifier = pp.Group( pp.Optional(id_offset).setResultsName("offset") + pp.Combine( - pp.delimitedList(pp.Combine(first + pp.Optional(label_rest)), delim="::"), + pp.delimitedList( + pp.Combine(first + pp.Optional(label_rest)), delim="::" + ), joinString="::", ).setResultsName("name") + pp.Optional(relocation).setResultsName("relocation") @@ -179,9 +185,9 @@ class ParserX86ATT(ParserX86): + segment_extension.setResultsName(self.segment_ext) ) # Memory: offset | seg:seg_ext | offset(base, index, scale){mask} - memory_abs = pp.Suppress(pp.Literal("*")) + (offset | self.register).setResultsName( - "offset" - ) + memory_abs = pp.Suppress(pp.Literal("*")) + ( + offset | self.register + ).setResultsName("offset") memory = pp.Group( ( pp.Optional(pp.Suppress(pp.Literal("*"))) @@ -268,7 +274,9 @@ class ParserX86ATT(ParserX86): # 1. Parse comment try: - result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict()) + result = self.process_operand( + self.comment.parseString(line, parseAll=True).asDict() + ) instruction_form.comment = " ".join(result[self.comment_id]) except pp.ParseException: pass @@ -277,7 +285,9 @@ class ParserX86ATT(ParserX86): if result is None: try: # returns tuple with label operand and comment, if any - result = self.process_operand(self.label.parseString(line, parseAll=True).asDict()) + result = self.process_operand( + self.label.parseString(line, parseAll=True).asDict() + ) instruction_form.label = result[0].name if result[1] is not None: instruction_form.comment = " ".join(result[1]) @@ -307,7 +317,9 @@ class ParserX86ATT(ParserX86): result = self.parse_instruction(line) except pp.ParseException as e: raise ValueError( - "Could not parse instruction on line {}: {!r}".format(line_number, line) + "Could not parse instruction on line {}: {!r}".format( + line_number, line + ) ) from e instruction_form.mnemonic = result.mnemonic instruction_form.operands = result.operands @@ -321,7 +333,9 @@ class ParserX86ATT(ParserX86): :param str instruction: Assembly line string. :returns: `dict` -- parsed instruction form """ - result = self.instruction_parser.parseString(instruction, parseAll=True).asDict() + result = self.instruction_parser.parseString( + instruction, parseAll=True + ).asDict() operands = [] # Add operands to list # Check first operand @@ -339,7 +353,9 @@ class ParserX86ATT(ParserX86): return_dict = InstructionForm( mnemonic=result["mnemonic"].split(",")[0], operands=operands, - comment_id=" ".join(result[self.comment_id]) if self.comment_id in result else None, + comment_id=( + " ".join(result[self.comment_id]) if self.comment_id in result else None + ), ) return return_dict @@ -368,7 +384,9 @@ class ParserX86ATT(ParserX86): shape=operand["shape"].lower() if "shape" in operand else None, lanes=operand["lanes"] if "lanes" in operand else None, index=operand["index"] if "index" in operand else None, - predication=operand["predication"].lower() if "predication" in operand else None, + predication=( + operand["predication"].lower() if "predication" in operand else None + ), ) def process_directive(self, directive): @@ -400,7 +418,8 @@ class ParserX86ATT(ParserX86): ) if index is not None: indexOp = RegisterOperand( - name=index["name"], prefix=index["prefix"] if "prefix" in index else None + name=index["name"], + prefix=index["prefix"] if "prefix" in index else None, ) if isinstance(offset, dict) and "identifier" in offset: offset = IdentifierOperand(name=offset["identifier"]["name"]) @@ -414,7 +433,9 @@ class ParserX86ATT(ParserX86): """Post-process label asm line""" # remove duplicated 'name' level due to identifier label["name"] = label["name"][0]["name"] - return LabelOperand(name=label["name"]), label["comment"] if "comment" in label else None + return LabelOperand(name=label["name"]), ( + label["comment"] if "comment" in label else None + ) def process_immediate(self, immediate): """Post-process immediate operand""" diff --git a/osaca/parser/parser_x86intel.py b/osaca/parser/parser_x86intel.py index 8802b59..96d3bdd 100644 --- a/osaca/parser/parser_x86intel.py +++ b/osaca/parser/parser_x86intel.py @@ -17,7 +17,7 @@ from osaca.parser.register import RegisterOperand # syntax characters. # This approach is described at the end of https://www.unicode.org/reports/tr55/#Whitespace-Syntax. # It is appropriate for tools, such as this one, which process source code but do not fully validate -# it (in this case, that’s the job of the assembler). +# it (in this case, that's the job of the assembler). NON_ASCII_PRINTABLE_CHARACTERS = "".join( chr(cp) for cp in range(0x80, 0x10FFFF + 1) @@ -26,9 +26,13 @@ NON_ASCII_PRINTABLE_CHARACTERS = "".join( # References: -# ASM386 Assembly Language Reference, document number 469165-003, https://mirror.math.princeton.edu/pub/oldlinux/Linux.old/Ref-docs/asm-ref.pdf. -# Microsoft Macro Assembler BNF Grammar, https://learn.microsoft.com/en-us/cpp/assembler/masm/masm-bnf-grammar?view=msvc-170. -# Intel Architecture Code Analyzer User's Guide, https://www.intel.com/content/dam/develop/external/us/en/documents/intel-architecture-code-analyzer-3-0-users-guide-157552.pdf. +# ASM386 Assembly Language Reference, document number 469165-003, +# https://mirror.math.princeton.edu/pub/oldlinux/Linux.old/Ref-docs/asm-ref.pdf. +# Microsoft Macro Assembler BNF Grammar, +# https://learn.microsoft.com/en-us/cpp/assembler/masm/masm-bnf-grammar?view=msvc-170. +# Intel Architecture Code Analyzer User's Guide, +# https://www.intel.com/content/dam/develop/external/us/en/documents/ +# intel-architecture-code-analyzer-3-0-users-guide-157552.pdf. class ParserX86Intel(ParserX86): _instance = None @@ -52,7 +56,8 @@ class ParserX86Intel(ParserX86): mnemonic="mov", operands=[ MemoryOperand( - base=RegisterOperand(name="GS"), offset=ImmediateOperand(value=111) + base=RegisterOperand(name="GS"), + offset=ImmediateOperand(value=111), ), ImmediateOperand(value=111), ], @@ -65,7 +70,8 @@ class ParserX86Intel(ParserX86): mnemonic="mov", operands=[ MemoryOperand( - base=RegisterOperand(name="GS"), offset=ImmediateOperand(value=222) + base=RegisterOperand(name="GS"), + offset=ImmediateOperand(value=222), ), ImmediateOperand(value=222), ], @@ -92,11 +98,15 @@ class ParserX86Intel(ParserX86): if not arch_model.get_instruction(mnemonic, len(instruction_form.operands)): if mnemonic[0] == "v": unvexed_mnemonic = mnemonic[1:] - if arch_model.get_instruction(unvexed_mnemonic, len(instruction_form.operands)): + if arch_model.get_instruction( + unvexed_mnemonic, len(instruction_form.operands) + ): mnemonic = unvexed_mnemonic else: vexed_mnemonic = "v" + mnemonic - if arch_model.get_instruction(vexed_mnemonic, len(instruction_form.operands)): + if arch_model.get_instruction( + vexed_mnemonic, len(instruction_form.operands) + ): mnemonic = vexed_mnemonic instruction_form.mnemonic = mnemonic @@ -286,9 +296,9 @@ class ParserX86Intel(ParserX86): pp.CaselessKeyword("ST") + pp.Optional(pp.Literal("(") + pp.Word("01234567") + pp.Literal(")")) ).setResultsName("name") - xmm_register = pp.Combine(pp.CaselessLiteral("XMM") + pp.Word(pp.nums)) | pp.Combine( - pp.CaselessLiteral("XMM1") + pp.Word("012345") - ) + xmm_register = pp.Combine( + pp.CaselessLiteral("XMM") + pp.Word(pp.nums) + ) | pp.Combine(pp.CaselessLiteral("XMM1") + pp.Word("012345")) simd_register = ( pp.Combine(pp.CaselessLiteral("MM") + pp.Word("01234567")) | xmm_register @@ -345,9 +355,21 @@ class ParserX86Intel(ParserX86): + ( base ^ (base + operator_displacement + displacement) - ^ (base + operator_displacement + displacement + operator_index + indexed) + ^ ( + base + + operator_displacement + + displacement + + operator_index + + indexed + ) ^ (base + operator_index + indexed) - ^ (base + operator_index + indexed + operator_displacement + displacement) + ^ ( + base + + operator_index + + indexed + + operator_displacement + + displacement + ) ^ (displacement + operator + base) ^ (displacement + operator + base + operator_index + indexed) ^ ( @@ -397,16 +419,19 @@ class ParserX86Intel(ParserX86): ptr_expression = pp.Group( data_type + pp.CaselessKeyword("PTR") + address_expression ).setResultsName("ptr_expression") - short_expression = pp.Group(pp.CaselessKeyword("SHORT") + identifier).setResultsName( - "short_expression" - ) + short_expression = pp.Group( + pp.CaselessKeyword("SHORT") + identifier + ).setResultsName("short_expression") # Instructions. mnemonic = pp.Word(pp.alphas, pp.alphanums).setResultsName("mnemonic") operand = pp.Group( self.register | pp.Group( - offset_expression | ptr_expression | short_expression | address_expression + offset_expression + | ptr_expression + | short_expression + | address_expression ).setResultsName(self.memory_id) | immediate ) @@ -433,10 +458,16 @@ class ParserX86Intel(ParserX86): # Directives. # The identifiers at the beginnig of a directive cannot start with a "." otherwise we end up # with ambiguities. - directive_first = pp.Word(pp.alphas + NON_ASCII_PRINTABLE_CHARACTERS + "$?@_<>", exact=1) - directive_rest = pp.Word(pp.alphanums + NON_ASCII_PRINTABLE_CHARACTERS + ".$?@_<>") + directive_first = pp.Word( + pp.alphas + NON_ASCII_PRINTABLE_CHARACTERS + "$?@_<>", exact=1 + ) + directive_rest = pp.Word( + pp.alphanums + NON_ASCII_PRINTABLE_CHARACTERS + ".$?@_<>" + ) directive_identifier = pp.Group( - pp.Combine(directive_first + pp.Optional(directive_rest)).setResultsName("name") + pp.Combine(directive_first + pp.Optional(directive_rest)).setResultsName( + "name" + ) ).setResultsName("identifier") # Parameter can be any quoted string or sequence of characters besides ';' (for comments) @@ -444,7 +475,9 @@ class ParserX86Intel(ParserX86): directive_parameter = ( pp.quotedString ^ ( - pp.Word(pp.printables + NON_ASCII_PRINTABLE_CHARACTERS, excludeChars=",;") + pp.Word( + pp.printables + NON_ASCII_PRINTABLE_CHARACTERS, excludeChars=",;" + ) + pp.Optional(pp.Suppress(pp.Literal(","))) ) ^ pp.Suppress(pp.Literal(",")) @@ -548,7 +581,9 @@ class ParserX86Intel(ParserX86): if not result: try: # Returns tuple with label operand and comment, if any. - result = self.process_operand(self.label.parseString(line, parseAll=True)) + result = self.process_operand( + self.label.parseString(line, parseAll=True) + ) instruction_form.label = result[0].name if result[1]: instruction_form.comment = " ".join(result[1]) @@ -559,7 +594,9 @@ class ParserX86Intel(ParserX86): if not result: try: # Returns tuple with directive operand and comment, if any. - result = self.process_operand(self.directive.parseString(line, parseAll=True)) + result = self.process_operand( + self.directive.parseString(line, parseAll=True) + ) instruction_form.directive = result[0] if result[1]: instruction_form.comment = " ".join(result[1]) @@ -572,7 +609,9 @@ class ParserX86Intel(ParserX86): result = self.parse_instruction(line) except pp.ParseException as e: raise ValueError( - "Could not parse instruction on line {}: {!r}".format(line_number, line) + "Could not parse instruction on line {}: {!r}".format( + line_number, line + ) ) from e instruction_form.mnemonic = result.mnemonic instruction_form.operands = result.operands @@ -627,7 +666,9 @@ class ParserX86Intel(ParserX86): def parse_register(self, register_string): """Parse register string""" try: - return self.process_operand(self.register.parseString(register_string, parseAll=True)) + return self.process_operand( + self.register.parseString(register_string, parseAll=True) + ) except pp.ParseException: return None @@ -651,7 +692,9 @@ class ParserX86Intel(ParserX86): # TODO: This is putting the identifier in the parameters. No idea if it's right. parameters = [directive.identifier.name] if "identifier" in directive else [] parameters.extend(directive.parameters) - directive_new = DirectiveOperand(name=directive.name, parameters=parameters or None) + directive_new = DirectiveOperand( + name=directive.name, parameters=parameters or None + ) # Interpret the "=" directives because the generated assembly is full of symbols that are # defined there. if directive.name == "=": @@ -672,7 +715,9 @@ class ParserX86Intel(ParserX86): scale = int(indexed.get("scale", "1"), 0) if register_expression.get("operator_index") == "-": scale *= -1 - displacement_op = self.process_immediate(displacement.immediate) if displacement else None + displacement_op = ( + self.process_immediate(displacement.immediate) if displacement else None + ) if displacement_op and register_expression.get("operator_disp") == "-": displacement_op.value *= -1 base_op = RegisterOperand(name=base.name) if base else None @@ -712,7 +757,9 @@ class ParserX86Intel(ParserX86): register_expression.data_type = data_type return register_expression elif segment: - return MemoryOperand(base=segment, offset=immediate_operand, data_type=data_type) + return MemoryOperand( + base=segment, offset=immediate_operand, data_type=data_type + ) elif identifier: if immediate_operand: identifier.offset = immediate_operand @@ -774,7 +821,9 @@ class ParserX86Intel(ParserX86): if "identifier" in immediate: # Actually an identifier, change declaration. return self.process_identifier(immediate.identifier) - new_immediate = ImmediateOperand(value=immediate.get("sign", "") + immediate.value) + new_immediate = ImmediateOperand( + value=immediate.get("sign", "") + immediate.value + ) new_immediate.value = self.normalize_imd(new_immediate) return new_immediate diff --git a/osaca/parser/register.py b/osaca/parser/register.py index e3effaf..23b404c 100644 --- a/osaca/parser/register.py +++ b/osaca/parser/register.py @@ -156,8 +156,9 @@ class RegisterOperand(Operand): f"Register(name={self._name}, width={self._width}, " f"prefix={self._prefix}, regtype={self._regtype}, " f"lanes={self._lanes}, shape={self._shape}, index={self._index}, " - f"mask={self._mask}, zeroing={self._zeroing},source={self._source},destination={self._destination}," - f"pre_indexed={self._pre_indexed}, post_indexed={self._post_indexed}) " + f"mask={self._mask}, zeroing={self._zeroing}, source={self._source}, " + f"destination={self._destination}, pre_indexed={self._pre_indexed}, " + f"post_indexed={self._post_indexed})" ) def __repr__(self): diff --git a/osaca/semantics/arch_semantics.py b/osaca/semantics/arch_semantics.py index f159ada..a6af343 100644 --- a/osaca/semantics/arch_semantics.py +++ b/osaca/semantics/arch_semantics.py @@ -59,22 +59,24 @@ class ArchSemantics(ISASemantics): multiple_assignments = False for idx, instruction_form in enumerate(kernel[start:], start): multiple_assignments = False - # if iform has multiple possible port assignments, check all in a DFS manner and take the best + # if iform has multiple possible port assignments, check all in a DFS + # manner and take the best if isinstance(instruction_form.port_uops, dict): best_kernel = None best_kernel_tp = sys.maxsize for port_util_alt in list(instruction_form.port_uops.values())[1:]: k_tmp = deepcopy(kernel) k_tmp[idx].port_uops = deepcopy(port_util_alt) - k_tmp[idx].port_pressure = self._machine_model.average_port_pressure( - k_tmp[idx].port_uops + k_tmp[idx].port_pressure = ( + self._machine_model.average_port_pressure(k_tmp[idx].port_uops) ) k_tmp.reverse() self.assign_optimal_throughput(k_tmp, idx) if max(self.get_throughput_sum(k_tmp)) < best_kernel_tp: best_kernel = k_tmp best_kernel_tp = max(self.get_throughput_sum(best_kernel)) - # check the first option in the main branch and compare against the best option later + # check the first option in the main branch and compare against the + # best option later multiple_assignments = True kernel[idx].port_uops = list(instruction_form.port_uops.values())[0] for uop in instruction_form.port_uops: @@ -82,8 +84,12 @@ class ArchSemantics(ISASemantics): ports = list(uop[1]) indices = [port_list.index(p) for p in ports] # check if port sum of used ports for uop are unbalanced - port_sums = self._to_list(itemgetter(*indices)(self.get_throughput_sum(kernel))) - instr_ports = self._to_list(itemgetter(*indices)(instruction_form.port_pressure)) + port_sums = self._to_list( + itemgetter(*indices)(self.get_throughput_sum(kernel)) + ) + instr_ports = self._to_list( + itemgetter(*indices)(instruction_form.port_pressure) + ) if len(set(port_sums)) > 1: # balance ports # init list for keeping track of the current change @@ -99,17 +105,19 @@ class ArchSemantics(ISASemantics): differences[max_port_idx] -= INC differences[min_port_idx] += INC # instr_ports = [round(p, 2) for p in instr_ports] - self._itemsetter(*indices)(instruction_form.port_pressure, *instr_ports) + self._itemsetter(*indices)( + instruction_form.port_pressure, *instr_ports + ) # check if min port is zero if round(min(instr_ports), 2) <= 0: - # if port_pressure is not exactly 0.00, add the residual to - # the former port + # if port_pressure is not exactly 0.00, add the + # residual to the former port if min(instr_ports) != 0.0: min_port_idx = port_sums.index(min(port_sums)) instr_ports[min_port_idx] += min(instr_ports) differences[min_port_idx] += min(instr_ports) - # we don't need to decrease difference for other port, just - # delete it + # we don't need to decrease difference for + # other port, just delete it del differences[instr_ports.index(min(instr_ports))] self._itemsetter(*indices)( instruction_form.port_pressure, *instr_ports @@ -122,16 +130,20 @@ class ArchSemantics(ISASemantics): ][0] instruction_form.port_pressure[zero_index] = 0.0 # Remove from further balancing - indices = [p for p in indices if instruction_form.port_pressure[p] > 0] + indices = [ + p + for p in indices + if instruction_form.port_pressure[p] > 0 + ] instr_ports = self._to_list( itemgetter(*indices)(instruction_form.port_pressure) ) - # never remove more than the fixed utilization per uop and port, i.e., - # cycles/len(ports) + # never remove more than the fixed utilization + # per uop and port, i.e., cycles/len(ports) if round(min(differences), 2) <= 0: - # don't worry if port_pressure isn't exactly 0 and just - # remove from further balancing by deleting index since - # pressure is not 0 + # don't worry if port_pressure isn't exactly 0 and + # just remove from further balancing by deleting + # index since pressure is not 0 del indices[differences.index(min(differences))] instr_ports = self._to_list( itemgetter(*indices)(instruction_form.port_pressure) @@ -180,7 +192,11 @@ class ArchSemantics(ISASemantics): if INSTR_FLAGS.HIDDEN_LD not in load_instr.flags ] ) - load = [instr for instr in kernel if instr.line_number == min_distance_load[1]][0] + load = [ + instr + for instr in kernel + if instr.line_number == min_distance_load[1] + ][0] # Hide load load.flags += [INSTR_FLAGS.HIDDEN_LD] load.port_pressure = self._nullify_data_ports(load.port_pressure) @@ -243,7 +259,9 @@ class ArchSemantics(ISASemantics): load_perf_data = self._machine_model.get_load_throughput( [ x - for x in instruction_form.semantic_operands["source"] + for x in instruction_form.semantic_operands[ + "source" + ] + instruction_form.semantic_operands["src_dst"] if isinstance(x, MemoryOperand) ][0] @@ -261,14 +279,18 @@ class ArchSemantics(ISASemantics): data_port_uops = load_perf_data[0][1] else: data_port_uops = data_port_uops[0] - data_port_pressure = self._machine_model.average_port_pressure( - data_port_uops + data_port_pressure = ( + self._machine_model.average_port_pressure( + data_port_uops + ) ) if "load_throughput_multiplier" in self._machine_model: - multiplier = self._machine_model["load_throughput_multiplier"][ - reg_type + multiplier = self._machine_model[ + "load_throughput_multiplier" + ][reg_type] + data_port_pressure = [ + pp * multiplier for pp in data_port_pressure ] - data_port_pressure = [pp * multiplier for pp in data_port_pressure] if INSTR_FLAGS.HAS_ST in instruction_form.flags: # STORE performance data destinations = ( @@ -276,7 +298,11 @@ class ArchSemantics(ISASemantics): + instruction_form.semantic_operands["src_dst"] ) store_perf_data = self._machine_model.get_store_throughput( - [x for x in destinations if isinstance(x, MemoryOperand)][0], + [ + x + for x in destinations + if isinstance(x, MemoryOperand) + ][0], dummy_reg, ) st_data_port_uops = store_perf_data[0][1] @@ -294,7 +320,9 @@ class ArchSemantics(ISASemantics): and all( [ op.post_indexed or op.pre_indexed - for op in instruction_form.semantic_operands["src_dst"] + for op in instruction_form.semantic_operands[ + "src_dst" + ] if isinstance(op, MemoryOperand) ] ) @@ -303,21 +331,26 @@ class ArchSemantics(ISASemantics): instruction_form.flags.remove(INSTR_FLAGS.HAS_ST) # sum up all data ports in case for LOAD and STORE - st_data_port_pressure = self._machine_model.average_port_pressure( - st_data_port_uops + st_data_port_pressure = ( + self._machine_model.average_port_pressure( + st_data_port_uops + ) ) if "store_throughput_multiplier" in self._machine_model: - multiplier = self._machine_model["store_throughput_multiplier"][ - reg_type - ] + multiplier = self._machine_model[ + "store_throughput_multiplier" + ][reg_type] st_data_port_pressure = [ pp * multiplier for pp in st_data_port_pressure ] data_port_pressure = [ - sum(x) for x in zip(data_port_pressure, st_data_port_pressure) + sum(x) + for x in zip(data_port_pressure, st_data_port_pressure) ] data_port_uops += st_data_port_uops - throughput = max(max(data_port_pressure), instruction_data_reg.throughput) + throughput = max( + max(data_port_pressure), instruction_data_reg.throughput + ) latency = instruction_data_reg.latency # Add LD and ST latency latency += ( @@ -380,11 +413,15 @@ class ArchSemantics(ISASemantics): instruction_form.latency_cp = 0 instruction_form.latency_lcd = 0 - def _handle_instruction_found(self, instruction_data, port_number, instruction_form, flags): + def _handle_instruction_found( + self, instruction_data, port_number, instruction_form, flags + ): """Apply performance data to instruction if it was found in the archDB""" instruction_form.check_normalized() throughput = instruction_data.throughput - port_pressure = self._machine_model.average_port_pressure(instruction_data.port_pressure) + port_pressure = self._machine_model.average_port_pressure( + instruction_data.port_pressure + ) instruction_form.port_uops = instruction_data.port_pressure try: assert isinstance(port_pressure, list) @@ -469,7 +506,9 @@ class ArchSemantics(ISASemantics): """Get the overall throughput sum separated by port of all instructions of a kernel.""" # ignoring all lines with throughput == 0.0, because there won't be anything to sum up # typically comment, label and non-instruction lines - port_pressures = [instr.port_pressure for instr in kernel if instr.throughput != 0.0] + port_pressures = [ + instr.port_pressure for instr in kernel if instr.throughput != 0.0 + ] # Essentially summing up each columns of port_pressures, where each column is one port # and each row is one line of the kernel # round is necessary to ensure termination of ArchsSemantics.assign_optimal_throughput diff --git a/osaca/semantics/hw_model.py b/osaca/semantics/hw_model.py index 4979119..26c2a0c 100644 --- a/osaca/semantics/hw_model.py +++ b/osaca/semantics/hw_model.py @@ -1,5 +1,5 @@ # TODO -#!/usr/bin/env python3 +#!/usr/bin/env python3 # noqa: E265 import hashlib import os @@ -26,7 +26,9 @@ from ruamel.yaml.compat import StringIO class MachineModel(object): WILDCARD = "*" - INTERNAL_VERSION = 1 # increase whenever self._data format changes to invalidate cache! + INTERNAL_VERSION = ( + 1 # increase whenever self._data format changes to invalidate cache! + ) _runtime_cache = {} def __init__(self, arch=None, path_to_yaml=None, isa=None, lazy=False): @@ -50,7 +52,9 @@ class MachineModel(object): "offset": o, "scale": s, } - for b, i, o, s in product(["gpr"], ["gpr", None], ["imd", None], [1, 8]) + for b, i, o, s in product( + ["gpr"], ["gpr", None], ["imd", None], [1, 8] + ) ], "load_throughput_default": [], "store_throughput": [], @@ -92,7 +96,9 @@ class MachineModel(object): self._data["instruction_forms"] = [] # separate multi-alias instruction forms for entry in [ - x for x in self._data["instruction_forms"] if isinstance(x["name"], list) + x + for x in self._data["instruction_forms"] + if isinstance(x["name"], list) ]: for name in entry["name"]: new_entry = {"name": name} @@ -125,16 +131,26 @@ class MachineModel(object): mnemonic=iform["name"].upper() if "name" in iform else None, operands=iform["operands"] if "operands" in iform else [], hidden_operands=( - iform["hidden_operands"] if "hidden_operands" in iform else [] + iform["hidden_operands"] + if "hidden_operands" in iform + else [] + ), + directive_id=( + iform["directive"] if "directive" in iform else None ), - directive_id=iform["directive"] if "directive" in iform else None, comment_id=iform["comment"] if "comment" in iform else None, line=iform["line"] if "line" in iform else None, - line_number=iform["line_number"] if "line_number" in iform else None, + line_number=( + iform["line_number"] if "line_number" in iform else None + ), latency=iform["latency"] if "latency" in iform else None, - throughput=iform["throughput"] if "throughput" in iform else None, + throughput=( + iform["throughput"] if "throughput" in iform else None + ), uops=iform["uops"] if "uops" in iform else None, - port_pressure=iform["port_pressure"] if "port_pressure" in iform else None, + port_pressure=( + iform["port_pressure"] if "port_pressure" in iform else None + ), operation=iform["operation"] if "operation" in iform else None, breaks_dependency_on_equal_operands=( iform["breaks_dependency_on_equal_operands"] @@ -148,7 +164,9 @@ class MachineModel(object): ), ) # List containing classes with same name/instruction - self._data["instruction_forms_dict"][iform["name"]].append(new_iform) + self._data["instruction_forms_dict"][iform["name"]].append( + new_iform + ) self._data["internal_version"] = self.INTERNAL_VERSION # Convert load and store throughput memory operands to classes @@ -398,11 +416,19 @@ class MachineModel(object): def get_load_latency(self, reg_type): """Return load latency for given register type.""" - return self._data["load_latency"][reg_type] if self._data["load_latency"][reg_type] else 0 + return ( + self._data["load_latency"][reg_type] + if self._data["load_latency"][reg_type] + else 0 + ) def get_load_throughput(self, memory): """Return load thorughput for given register type.""" - ld_tp = [m for m in self._data["load_throughput"] if self._match_mem_entries(memory, m[0])] + ld_tp = [ + m + for m in self._data["load_throughput"] + if self._match_mem_entries(memory, m[0]) + ] if len(ld_tp) > 0: return ld_tp.copy() return [(memory, self._data["load_throughput_default"].copy())] @@ -415,7 +441,9 @@ class MachineModel(object): def get_store_throughput(self, memory, src_reg=None): """Return store throughput for a given destination and register type.""" st_tp = [ - m for m in self._data["store_throughput"] if self._match_mem_entries(memory, m[0]) + m + for m in self._data["store_throughput"] + if self._match_mem_entries(memory, m[0]) ] if src_reg is not None: st_tp = [ @@ -486,7 +514,8 @@ class MachineModel(object): raise ValueError("Unknown architecture {!r}.".format(arch)) def class_to_dict(self, op): - """Need to convert operand classes to dicts for the dump. Memory operand types may have their index/base/offset as a register operand/""" + """Need to convert operand classes to dicts for the dump. Memory operand types + may have their index/base/offset as a register operand/""" if isinstance(op, Operand): dict_op = dict( (key.lstrip("_"), value) @@ -527,7 +556,9 @@ class MachineModel(object): if not callable(value) and not key.startswith("__") ) if instruction_form["port_pressure"] is not None: - cs = ruamel.yaml.comments.CommentedSeq(instruction_form["port_pressure"]) + cs = ruamel.yaml.comments.CommentedSeq( + instruction_form["port_pressure"] + ) cs.fa.set_flow_style() instruction_form["port_pressure"] = cs dict_operands = [] @@ -596,7 +627,9 @@ class MachineModel(object): hexhash = hashlib.sha256(p.read_bytes()).hexdigest() # 1. companion cachefile: same location, with '._.pickle' - companion_cachefile = p.with_name("." + p.stem + "_" + hexhash).with_suffix(".pickle") + companion_cachefile = p.with_name("." + p.stem + "_" + hexhash).with_suffix( + ".pickle" + ) if companion_cachefile.exists(): # companion file (must be up-to-date, due to equal hash) with companion_cachefile.open("rb") as f: @@ -605,7 +638,9 @@ class MachineModel(object): return data # 2. home cachefile: ~/.osaca/cache/_.pickle - home_cachefile = (Path(utils.CACHE_DIR) / (p.stem + "_" + hexhash)).with_suffix(".pickle") + home_cachefile = (Path(utils.CACHE_DIR) / (p.stem + "_" + hexhash)).with_suffix( + ".pickle" + ) if home_cachefile.exists(): # home file (must be up-to-date, due to equal hash) with home_cachefile.open("rb") as f: @@ -624,7 +659,9 @@ class MachineModel(object): p = Path(filepath) hexhash = hashlib.sha256(p.read_bytes()).hexdigest() # 1. companion cachefile: same location, with '._.pickle' - companion_cachefile = p.with_name("." + p.stem + "_" + hexhash).with_suffix(".pickle") + companion_cachefile = p.with_name("." + p.stem + "_" + hexhash).with_suffix( + ".pickle" + ) if os.access(str(companion_cachefile.parent), os.W_OK): with companion_cachefile.open("wb") as f: pickle.dump(self._data, f) @@ -668,7 +705,9 @@ class MachineModel(object): operand_string += operand["prefix"] operand_string += operand["shape"] if "shape" in operand else "" elif "name" in operand: - operand_string += "r" if operand["name"] == "gpr" else operand["name"][0] + operand_string += ( + "r" if operand["name"] == "gpr" else operand["name"][0] + ) elif opclass == "memory": # Memory operand_string += "m" @@ -803,7 +842,10 @@ class MachineModel(object): return False return self._is_AArch64_mem_type(i_operand, operand) # immediate - if isinstance(i_operand, ImmediateOperand) and i_operand.imd_type == self.WILDCARD: + if ( + isinstance(i_operand, ImmediateOperand) + and i_operand.imd_type == self.WILDCARD + ): return isinstance(operand, ImmediateOperand) and (operand.value is not None) if isinstance(i_operand, ImmediateOperand) and i_operand.imd_type == "int": @@ -838,7 +880,9 @@ class MachineModel(object): # condition if isinstance(operand, ConditionOperand): if isinstance(i_operand, ConditionOperand): - return (i_operand.ccode == self.WILDCARD) or (i_operand.ccode == operand.ccode) + return (i_operand.ccode == self.WILDCARD) or ( + i_operand.ccode == operand.ccode + ) # no match return False @@ -860,7 +904,9 @@ class MachineModel(object): # immediate if isinstance(operand, ImmediateOperand): # if "immediate" in operand.name or operand.value != None: - return isinstance(i_operand, ImmediateOperand) and i_operand.imd_type == "int" + return ( + isinstance(i_operand, ImmediateOperand) and i_operand.imd_type == "int" + ) # identifier (e.g., labels) if isinstance(operand, IdentifierOperand): return isinstance(i_operand, IdentifierOperand) @@ -873,13 +919,13 @@ class MachineModel(object): if not isinstance(i_operand, RegisterOperand): return False return self._is_RISCV_reg_type(i_operand, operand) - + # memory if isinstance(operand, MemoryOperand): if not isinstance(i_operand, MemoryOperand): return False return self._is_RISCV_mem_type(i_operand, operand) - + # immediate if isinstance(operand, (ImmediateOperand, int)): if not isinstance(i_operand, ImmediateOperand): @@ -895,7 +941,7 @@ class MachineModel(object): if i_operand.imd_type == self.WILDCARD: return True return False - + # identifier if isinstance(operand, IdentifierOperand) or ( isinstance(operand, ImmediateOperand) and operand.identifier is not None @@ -929,7 +975,8 @@ class MachineModel(object): if reg.prefix == self.WILDCARD or i_reg.prefix == self.WILDCARD: if reg.shape is not None: if i_reg.shape is not None and ( - reg.shape == i_reg.shape or self.WILDCARD in (reg.shape + i_reg.shape) + reg.shape == i_reg.shape + or self.WILDCARD in (reg.shape + i_reg.shape) ): return True return False @@ -994,7 +1041,10 @@ class MachineModel(object): # one instruction is missing zeroing while the other has it zero_ok = False # check for wildcard - if i_reg.zeroing == self.WILDCARD or reg.zeroing == self.WILDCARD: + if ( + i_reg.zeroing == self.WILDCARD + or reg.zeroing == self.WILDCARD + ): zero_ok = True if not mask_ok or not zero_ok: return False @@ -1011,7 +1061,7 @@ class MachineModel(object): # check for wildcards if reg.prefix == self.WILDCARD or i_reg.prefix == self.WILDCARD: return True - + # First handle potentially None values to avoid AttributeError if reg.name is None or i_reg.name is None: # If both have same prefix, they might still match @@ -1019,12 +1069,15 @@ class MachineModel(object): return True # If we can't determine canonical names, be conservative and return False return False - + # Check for ABI name (a0, t0, etc.) vs x-prefix registers (x10, x5, etc.) - if (reg.prefix is None and i_reg.prefix == "x") or (reg.prefix == "x" and i_reg.prefix is None): + if (reg.prefix is None and i_reg.prefix == "x") or ( + reg.prefix == "x" and i_reg.prefix is None + ): try: # Need to check if they refer to the same register from osaca.parser import ParserRISCV + parser = ParserRISCV() reg_canonical = parser._get_canonical_reg_name(reg) i_reg_canonical = parser._get_canonical_reg_name(i_reg) @@ -1032,16 +1085,18 @@ class MachineModel(object): return True except (AttributeError, KeyError): return False - + # Check for direct prefix matches if reg.prefix == i_reg.prefix: # For vector registers, check lanes if present if reg.prefix == "v" and reg.lanes is not None and i_reg.lanes is not None: - return reg.lanes == i_reg.lanes or self.WILDCARD in (reg.lanes + i_reg.lanes) + return reg.lanes == i_reg.lanes or self.WILDCARD in ( + reg.lanes + i_reg.lanes + ) return True - + return False - + def _is_AArch64_mem_type(self, i_mem, mem): """Check if memory addressing type match.""" if ( @@ -1049,7 +1104,10 @@ class MachineModel(object): ( (mem.base is None and i_mem.base is None) or i_mem.base == self.WILDCARD - or (isinstance(mem.base, RegisterOperand) and (mem.base.prefix == i_mem.base)) + or ( + isinstance(mem.base, RegisterOperand) + and (mem.base.prefix == i_mem.base) + ) ) # check offset and ( @@ -1083,7 +1141,10 @@ class MachineModel(object): or (mem.scale != 1 and i_mem.scale != 1) ) # check pre-indexing - and (i_mem.pre_indexed == self.WILDCARD or mem.pre_indexed == i_mem.pre_indexed) + and ( + i_mem.pre_indexed == self.WILDCARD + or mem.pre_indexed == i_mem.pre_indexed + ) # check post-indexing and ( i_mem.post_indexed == self.WILDCARD @@ -1116,7 +1177,8 @@ class MachineModel(object): mem.offset is not None and isinstance(mem.offset, ImmediateOperand) and ( - i_mem.offset == "imd" or (i_mem.offset is None and mem.offset.value == "0") + i_mem.offset == "imd" + or (i_mem.offset is None and mem.offset.value == "0") ) ) or (isinstance(mem.offset, IdentifierOperand) and i_mem.offset == "id") @@ -1148,9 +1210,13 @@ class MachineModel(object): ( (mem.base is None and i_mem.base is None) or i_mem.base == self.WILDCARD - or (isinstance(mem.base, RegisterOperand) and - (mem.base.prefix == i_mem.base or - (mem.base.name is not None and i_mem.base is not None))) + or ( + isinstance(mem.base, RegisterOperand) + and ( + mem.base.prefix == i_mem.base + or (mem.base.name is not None and i_mem.base is not None) + ) + ) ) # check offset and ( @@ -1181,4 +1247,4 @@ class MachineModel(object): def __represent_none(self, yaml_obj, data): """YAML representation for `None`""" - return yaml_obj.represent_scalar("tag:yaml.org,2002:null", "~") \ No newline at end of file + return yaml_obj.represent_scalar("tag:yaml.org,2002:null", "~") diff --git a/osaca/semantics/isa_semantics.py b/osaca/semantics/isa_semantics.py index 37a74cb..1c8a5d0 100644 --- a/osaca/semantics/isa_semantics.py +++ b/osaca/semantics/isa_semantics.py @@ -53,7 +53,11 @@ class ISASemantics(object): instruction_form.check_normalized() # if the instruction form doesn't have operands or is None, there's nothing to do if instruction_form.operands is None or instruction_form.mnemonic is None: - instruction_form.semantic_operands = {"source": [], "destination": [], "src_dst": []} + instruction_form.semantic_operands = { + "source": [], + "destination": [], + "src_dst": [], + } return # check if instruction form is in ISA yaml, otherwise apply standard operand assignment # (one dest, others source) @@ -82,7 +86,9 @@ class ISASemantics(object): if assign_default: # no irregular operand structure, apply default - op_dict["source"] = self._parser.get_regular_source_operands(instruction_form) + op_dict["source"] = self._parser.get_regular_source_operands( + instruction_form + ) op_dict["destination"] = self._parser.get_regular_destination_operands( instruction_form ) @@ -105,7 +111,9 @@ class ISASemantics(object): op_dict["src_dst"].append(reg) # post-process pre- and post-indexing for aarch64 memory operands if self._parser.isa() == "aarch64": - for operand in [op for op in op_dict["source"] if isinstance(op, MemoryOperand)]: + for operand in [ + op for op in op_dict["source"] if isinstance(op, MemoryOperand) + ]: post_indexed = operand.post_indexed pre_indexed = operand.pre_indexed if ( @@ -117,7 +125,9 @@ class ISASemantics(object): new_op.pre_indexed = pre_indexed new_op.post_indexed = post_indexed op_dict["src_dst"].append(new_op) - for operand in [op for op in op_dict["destination"] if isinstance(op, MemoryOperand)]: + for operand in [ + op for op in op_dict["destination"] if isinstance(op, MemoryOperand) + ]: post_indexed = operand.post_indexed pre_indexed = operand.pre_indexed if ( @@ -170,7 +180,9 @@ class ISASemantics(object): and o.base is not None and isinstance(o.post_indexed, dict) ): - base_name = (o.base.prefix if o.base.prefix is not None else "") + o.base.name + base_name = ( + o.base.prefix if o.base.prefix is not None else "" + ) + o.base.name return { base_name: { "name": (o.base.prefix if o.base.prefix is not None else "") @@ -181,7 +193,9 @@ class ISASemantics(object): return {} reg_operand_names = {} # e.g., {'rax': 'op1'} - operand_state = {} # e.g., {'op1': {'name': 'rax', 'value': 0}} 0 means unchanged + operand_state = ( + {} + ) # e.g., {'op1': {'name': 'rax', 'value': 0}} 0 means unchanged for o in instruction_form.operands: if isinstance(o, MemoryOperand) and o.pre_indexed: @@ -192,10 +206,14 @@ class ISASemantics(object): "This is currently not supprted.".format(instruction_form.line) ) - base_name = (o.base.prefix if o.base.prefix is not None else "") + o.base.name + base_name = ( + o.base.prefix if o.base.prefix is not None else "" + ) + o.base.name reg_operand_names = {base_name: "op1"} if o.offset: - operand_state = {"op1": {"name": base_name, "value": o.offset.value}} + operand_state = { + "op1": {"name": base_name, "value": o.offset.value} + } else: # no offset (e.g., with Arm9 memops) -> base is updated operand_state = {"op1": None} @@ -239,7 +257,10 @@ class ISASemantics(object): op_dict["src_dst"] = [] # handle dependency breaking instructions - if isa_data.breaks_dependency_on_equal_operands and operands[1:] == operands[:-1]: + if ( + isa_data.breaks_dependency_on_equal_operands + and operands[1:] == operands[:-1] + ): op_dict["destination"] += operands if isa_data.hidden_operands != []: op_dict["destination"] += [hop for hop in isa_data.hidden_operands] @@ -301,7 +322,8 @@ class ISASemantics(object): def substitute_mem_address(self, operands): """Create memory wildcard for all memory operands""" return [ - self._create_reg_wildcard() if isinstance(op, MemoryOperand) else op for op in operands + self._create_reg_wildcard() if isinstance(op, MemoryOperand) else op + for op in operands ] def _create_reg_wildcard(self): diff --git a/osaca/semantics/kernel_dg.py b/osaca/semantics/kernel_dg.py index 2dd46fb..2e74736 100644 --- a/osaca/semantics/kernel_dg.py +++ b/osaca/semantics/kernel_dg.py @@ -63,7 +63,9 @@ class KernelDG(nx.DiGraph): dg = nx.DiGraph() for i, instruction_form in enumerate(kernel): dg.add_node(instruction_form.line_number) - dg.nodes[instruction_form.line_number]["instruction_form"] = instruction_form + dg.nodes[instruction_form.line_number][ + "instruction_form" + ] = instruction_form # add load as separate node if existent if ( INSTR_FLAGS.HAS_LD in instruction_form.flags @@ -71,7 +73,9 @@ class KernelDG(nx.DiGraph): ): # add new node dg.add_node(instruction_form.line_number + 0.1) - dg.nodes[instruction_form.line_number + 0.1]["instruction_form"] = instruction_form + dg.nodes[instruction_form.line_number + 0.1][ + "instruction_form" + ] = instruction_form # and set LD latency as edge weight dg.add_edge( instruction_form.line_number + 0.1, @@ -84,7 +88,8 @@ class KernelDG(nx.DiGraph): # print(instruction_form.line_number,"\t",dep.line_number,"\n") edge_weight = ( instruction_form.latency - if "mem_dep" in dep_flags or instruction_form.latency_wo_load is None + if "mem_dep" in dep_flags + or instruction_form.latency_wo_load is None else instruction_form.latency_wo_load ) if "storeload_dep" in dep_flags and self.model is not None: @@ -327,7 +332,9 @@ class KernelDG(nx.DiGraph): # store to same location (presumed) if self.is_memstore(dst, instr_form, register_changes): break - self._update_reg_changes(instr_form, register_changes, only_postindexed=True) + self._update_reg_changes( + instr_form, register_changes, only_postindexed=True + ) def _update_reg_changes(self, iform, reg_state=None, only_postindexed=False): if self.arch_sem is None: @@ -335,7 +342,9 @@ class KernelDG(nx.DiGraph): return {} if reg_state is None: reg_state = {} - for reg, change in self.arch_sem.get_reg_changes(iform, only_postindexed).items(): + for reg, change in self.arch_sem.get_reg_changes( + iform, only_postindexed + ).items(): if change is None or reg_state.get(reg, {}) is None: reg_state[reg] = None else: @@ -378,9 +387,13 @@ class KernelDG(nx.DiGraph): is_read = self.parser.is_flag_dependend_of(register, src) or is_read if isinstance(src, MemoryOperand): if src.base is not None: - is_read = self.parser.is_reg_dependend_of(register, src.base) or is_read + is_read = ( + self.parser.is_reg_dependend_of(register, src.base) or is_read + ) if src.index is not None and isinstance(src.index, RegisterOperand): - is_read = self.parser.is_reg_dependend_of(register, src.index) or is_read + is_read = ( + self.parser.is_reg_dependend_of(register, src.index) or is_read + ) # Check also if read in destination memory address for dst in chain( instruction_form.semantic_operands["destination"], @@ -388,9 +401,13 @@ class KernelDG(nx.DiGraph): ): if isinstance(dst, MemoryOperand): if dst.base is not None: - is_read = self.parser.is_reg_dependend_of(register, dst.base) or is_read + is_read = ( + self.parser.is_reg_dependend_of(register, dst.base) or is_read + ) if dst.index is not None: - is_read = self.parser.is_reg_dependend_of(register, dst.index) or is_read + is_read = ( + self.parser.is_reg_dependend_of(register, dst.index) or is_read + ) return is_read def is_memload(self, mem, instruction_form, register_changes={}): @@ -408,13 +425,20 @@ class KernelDG(nx.DiGraph): # determine absolute address change addr_change = 0 - if isinstance(src.offset, ImmediateOperand) and src.offset.value is not None: + if ( + isinstance(src.offset, ImmediateOperand) + and src.offset.value is not None + ): addr_change += src.offset.value - if isinstance(mem.offset, ImmediateOperand) and mem.offset.value is not None: + if ( + isinstance(mem.offset, ImmediateOperand) + and mem.offset.value is not None + ): addr_change -= mem.offset.value if mem.base and src.base: base_change = register_changes.get( - (src.base.prefix if src.base.prefix is not None else "") + src.base.name, + (src.base.prefix if src.base.prefix is not None else "") + + src.base.name, { "name": (src.base.prefix if src.base.prefix is not None else "") + src.base.name, @@ -437,9 +461,12 @@ class KernelDG(nx.DiGraph): continue if mem.index and src.index: index_change = register_changes.get( - (src.index.prefix if src.index.prefix is not None else "") + src.index.name, + (src.index.prefix if src.index.prefix is not None else "") + + src.index.name, { - "name": (src.index.prefix if src.index.prefix is not None else "") + "name": ( + src.index.prefix if src.index.prefix is not None else "" + ) + src.index.name, "value": 0, }, @@ -476,12 +503,19 @@ class KernelDG(nx.DiGraph): instruction_form.semantic_operands["src_dst"], ): if isinstance(dst, RegisterOperand): - is_written = self.parser.is_reg_dependend_of(register, dst) or is_written + is_written = ( + self.parser.is_reg_dependend_of(register, dst) or is_written + ) if isinstance(dst, FlagOperand): - is_written = self.parser.is_flag_dependend_of(register, dst) or is_written + is_written = ( + self.parser.is_flag_dependend_of(register, dst) or is_written + ) if isinstance(dst, MemoryOperand): if dst.pre_indexed or dst.post_indexed: - is_written = self.parser.is_reg_dependend_of(register, dst.base) or is_written + is_written = ( + self.parser.is_reg_dependend_of(register, dst.base) + or is_written + ) # Check also for possible pre- or post-indexing in memory addresses for src in chain( instruction_form.semantic_operands["source"], @@ -489,7 +523,10 @@ class KernelDG(nx.DiGraph): ): if isinstance(src, MemoryOperand): if src.pre_indexed or src.post_indexed: - is_written = self.parser.is_reg_dependend_of(register, src.base) or is_written + is_written = ( + self.parser.is_reg_dependend_of(register, src.base) + or is_written + ) return is_written def is_memstore(self, mem, instruction_form, register_changes={}): @@ -519,7 +556,9 @@ class KernelDG(nx.DiGraph): lcd = self.get_loopcarried_dependencies() lcd_line_numbers = {} for dep in lcd: - lcd_line_numbers[dep] = [x.line_number for x, lat in lcd[dep]["dependencies"]] + lcd_line_numbers[dep] = [ + x.line_number for x, lat in lcd[dep]["dependencies"] + ] # create LCD edges for dep in lcd_line_numbers: @@ -527,7 +566,9 @@ class KernelDG(nx.DiGraph): max_line_number = max(lcd_line_numbers[dep]) graph.add_edge(min_line_number, max_line_number, dir="back") graph.edges[min_line_number, max_line_number]["latency"] = [ - lat for x, lat in lcd[dep]["dependencies"] if x.line_number == max_line_number + lat + for x, lat in lcd[dep]["dependencies"] + if x.line_number == max_line_number ] # add label to edges @@ -568,7 +609,7 @@ class KernelDG(nx.DiGraph): (latency, list(deps)) for latency, deps in groupby(lcd, lambda dep: lcd[dep]["latency"]) ), - reverse=True + reverse=True, ) node_colors = {} edge_colors = {} @@ -591,17 +632,16 @@ class KernelDG(nx.DiGraph): edge_colors[u, v] = color max_color = min(11, colors_used) colorscheme = f"spectral{max(3, max_color)}" - graph.graph["node"] = {"colorscheme" : colorscheme} - graph.graph["edge"] = {"colorscheme" : colorscheme} + graph.graph["node"] = {"colorscheme": colorscheme} + graph.graph["edge"] = {"colorscheme": colorscheme} for n, color in node_colors.items(): if "style" not in graph.nodes[n]: graph.nodes[n]["style"] = "filled" else: graph.nodes[n]["style"] += ",filled" graph.nodes[n]["fillcolor"] = color - if ( - (max_color >= 4 and color in (1, max_color)) or - (max_color >= 10 and color in (1, 2, max_color - 1 , max_color)) + if (max_color >= 4 and color in (1, max_color)) or ( + max_color >= 10 and color in (1, 2, max_color - 1, max_color) ): graph.nodes[n]["fontcolor"] = "white" for (u, v), color in edge_colors.items(): @@ -624,7 +664,11 @@ class KernelDG(nx.DiGraph): else: label = "label" if node.label is not None else None label = "directive" if node.directive is not None else label - label = "comment" if node.comment is not None and label is None else label + label = ( + "comment" + if node.comment is not None and label is None + else label + ) mapping[n] = "{}: {}".format(n, label) graph.nodes[n]["fontname"] = "italic" graph.nodes[n]["fontsize"] = 11.0 diff --git a/osaca/semantics/marker_utils.py b/osaca/semantics/marker_utils.py index 50f6c88..b627376 100644 --- a/osaca/semantics/marker_utils.py +++ b/osaca/semantics/marker_utils.py @@ -1,5 +1,5 @@ # TODO -#!/usr/bin/env python3 +#!/usr/bin/env python3 # noqa: E265 from collections import OrderedDict from enum import Enum @@ -55,7 +55,11 @@ def find_marked_section(lines, parser, comments=None): end_marker = parser.end_marker() for i, line in enumerate(lines): try: - if line.mnemonic is None and comments is not None and line.comment is not None: + if ( + line.mnemonic is None + and comments is not None + and line.comment is not None + ): if comments["start"] == line.comment: index_start = i + 1 elif comments["end"] == line.comment: @@ -218,7 +222,9 @@ def match_operands(line_operands, marker_line_operands): return False return all( match_operand(line_operand, marker_line_operand) - for line_operand, marker_line_operand in zip(line_operands, marker_line_operands) + for line_operand, marker_line_operand in zip( + line_operands, marker_line_operands + ) ) @@ -275,7 +281,9 @@ def match_parameter(parser, line_parameter, marker_line_parameter): # If the parameters don't match verbatim, check if they represent the same immediate value. line_immediate = ImmediateOperand(value=line_parameter) marker_line_immediate = ImmediateOperand(value=marker_line_parameter) - return parser.normalize_imd(line_immediate) == parser.normalize_imd(marker_line_immediate) + return parser.normalize_imd(line_immediate) == parser.normalize_imd( + marker_line_immediate + ) def find_jump_labels(lines): @@ -337,7 +345,9 @@ def find_basic_blocks(lines): blocks[label].append(line) # Find end of block by searching for references to valid jump labels if line.mnemonic is not None and line.operands != []: - for operand in [o for o in line.operands if isinstance(o, IdentifierOperand)]: + for operand in [ + o for o in line.operands if isinstance(o, IdentifierOperand) + ]: if operand.name in valid_jump_labels: terminate = True elif line.label is not None: @@ -371,7 +381,9 @@ def find_basic_loop_bodies(lines): # do not terminate if line.mnemonic == "b.none": continue - for operand in [o for o in line.operands if isinstance(o, IdentifierOperand)]: + for operand in [ + o for o in line.operands if isinstance(o, IdentifierOperand) + ]: if operand.name in valid_jump_labels: if operand.name == label: loop_bodies[label] = current_block diff --git a/setup.py b/setup.py index df74dc6..b71869f 100755 --- a/setup.py +++ b/setup.py @@ -58,7 +58,9 @@ class install(_install): class sdist(_sdist): def make_release_tree(self, basedir, files): _sdist.make_release_tree(self, basedir, files) - self.execute(_run_build_cache, (basedir,), msg="Build ISA and architecture cache") + self.execute( + _run_build_cache, (basedir,), msg="Build ISA and architecture cache" + ) # Get the long description from the README file diff --git a/tests/test_cli.py b/tests/test_cli.py index 3a2a018..a318859 100755 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -28,7 +28,9 @@ class TestCLI(unittest.TestCase): def test_check_arguments(self): parser = osaca.create_parser(parser=ErrorRaisingArgumentParser()) - args = parser.parse_args(["--arch", "WRONG_ARCH", self._find_file("gs", "csx", "gcc")]) + args = parser.parse_args( + ["--arch", "WRONG_ARCH", self._find_file("gs", "csx", "gcc")] + ) with self.assertRaises(ValueError): osaca.check_arguments(args, parser) args = parser.parse_args( @@ -83,7 +85,9 @@ class TestCLI(unittest.TestCase): def test_get_parser(self): self.assertTrue(isinstance(osaca.get_asm_parser("csx"), ParserX86ATT)) - self.assertTrue(isinstance(osaca.get_asm_parser("csx", "intel"), ParserX86Intel)) + self.assertTrue( + isinstance(osaca.get_asm_parser("csx", "intel"), ParserX86Intel) + ) self.assertTrue(isinstance(osaca.get_asm_parser("tx2"), ParserAArch64)) self.assertTrue(isinstance(osaca.get_asm_parser("rv64"), ParserRISCV)) with self.assertRaises(ValueError): @@ -155,7 +159,12 @@ class TestCLI(unittest.TestCase): "update", ] archs = ["csx", "tx2", "zen1", "rv64"] - comps = {"csx": ["gcc", "icc"], "tx2": ["gcc", "clang"], "zen1": ["gcc"], "rv64": ["gcc"]} + comps = { + "csx": ["gcc", "icc"], + "tx2": ["gcc", "clang"], + "zen1": ["gcc"], + "rv64": ["gcc"], + } parser = osaca.create_parser() # Analyze all asm files resulting out of kernels, archs and comps for k in kernels: @@ -256,7 +265,9 @@ class TestCLI(unittest.TestCase): # Run tests with --lines option parser = osaca.create_parser() kernel_x86 = "triad_x86_iaca.s" - args_base = parser.parse_args(["--arch", "csx", self._find_test_file(kernel_x86)]) + args_base = parser.parse_args( + ["--arch", "csx", self._find_test_file(kernel_x86)] + ) output_base = StringIO() osaca.run(args_base, output_file=output_base) output_base = output_base.getvalue().split("\n")[8:] @@ -307,11 +318,13 @@ class TestCLI(unittest.TestCase): @staticmethod def _find_file(kernel, arch, comp): testdir = os.path.dirname(__file__) + # Handle special case for rv64 architecture + arch_prefix = arch.lower() if arch.lower() == "rv64" else arch[:3].lower() name = os.path.join( testdir, "../examples", kernel, - kernel + ".s." + arch[:3].lower() + "." + comp.lower() + ".s", + kernel + ".s." + arch_prefix + "." + comp.lower() + ".s", ) if kernel == "j2d" and arch.lower() == "csx": name = name[:-1] + "AVX.s" diff --git a/tests/test_db_interface.py b/tests/test_db_interface.py index 4ff6c0c..2c68759 100755 --- a/tests/test_db_interface.py +++ b/tests/test_db_interface.py @@ -34,7 +34,19 @@ class TestDBInterface(unittest.TestCase): self.entry_zen1 = copy.copy(sample_entry) self.entry_rv64 = copy.copy(sample_entry) - self.entry_csx.port_pressure = [1.25, 0, 1.25, 0.5, 0.5, 0.5, 0.5, 0, 1.25, 1.25, 0] + self.entry_csx.port_pressure = [ + 1.25, + 0, + 1.25, + 0.5, + 0.5, + 0.5, + 0.5, + 0, + 1.25, + 1.25, + 0, + ] self.entry_csx.port_pressure = [[5, "0156"], [1, "23"], [1, ["2D", "3D"]]] self.entry_tx2.port_pressure = [2.5, 2.5, 0, 0, 0.5, 0.5] self.entry_tx2.port_pressure = [[5, "01"], [1, "45"]] @@ -49,8 +61,15 @@ class TestDBInterface(unittest.TestCase): ] # For RV64, adapt to match its port structure self.entry_rv64.port_pressure = [1, 1, 1, 1] # [ALU, MEM, DIV, FP] - self.entry_rv64.port_pressure = [[1, ["ALU"]], [1, ["MEM"]], [1, ["DIV"]], [1, ["FP"]]] - self.entry_rv64.operands[1].prefix = "f" # Using f prefix for floating point registers + self.entry_rv64.port_pressure = [ + [1, ["ALU"]], + [1, ["MEM"]], + [1, ["DIV"]], + [1, ["FP"]], + ] + self.entry_rv64.operands[1].prefix = ( + "f" # Using f prefix for floating point registers + ) self.entry_rv64.operands[1].name = "1" ########### diff --git a/tests/test_frontend.py b/tests/test_frontend.py index adc25ec..5fa9821 100755 --- a/tests/test_frontend.py +++ b/tests/test_frontend.py @@ -22,14 +22,14 @@ class TestFrontend(unittest.TestCase): self.parser_x86 = ParserX86ATT() self.parser_AArch64 = ParserAArch64() self.parser_RISCV = ParserRISCV() - + with open(self._find_file("kernel_x86.s")) as f: code_x86 = f.read() with open(self._find_file("kernel_aarch64.s")) as f: code_AArch64 = f.read() with open(self._find_file("kernel_riscv.s")) as f: code_RISCV = f.read() - + self.kernel_x86 = self.parser_x86.parse_file(code_x86) self.kernel_AArch64 = self.parser_AArch64.parse_file(code_AArch64) self.kernel_RISCV = self.parser_RISCV.parse_file(code_RISCV) @@ -40,7 +40,7 @@ class TestFrontend(unittest.TestCase): ) self.machine_model_tx2 = MachineModel(arch="tx2") self.machine_model_rv64 = MachineModel(arch="rv64") - + self.semantics_csx = ArchSemantics( self.parser_x86, self.machine_model_csx, @@ -79,7 +79,9 @@ class TestFrontend(unittest.TestCase): with self.assertRaises(ValueError): Frontend() with self.assertRaises(ValueError): - Frontend(arch="csx", path_to_yaml=os.path.join(self.MODULE_DATA_DIR, "csx.yml")) + Frontend( + arch="csx", path_to_yaml=os.path.join(self.MODULE_DATA_DIR, "csx.yml") + ) with self.assertRaises(FileNotFoundError): Frontend(path_to_yaml=os.path.join(self.MODULE_DATA_DIR, "THE_MACHINE.yml")) with self.assertRaises(FileNotFoundError): @@ -88,7 +90,9 @@ class TestFrontend(unittest.TestCase): Frontend(arch="rv64") def test_frontend_x86(self): - dg = KernelDG(self.kernel_x86, self.parser_x86, self.machine_model_csx, self.semantics_csx) + dg = KernelDG( + self.kernel_x86, self.parser_x86, self.machine_model_csx, self.semantics_csx + ) fe = Frontend(path_to_yaml=os.path.join(self.MODULE_DATA_DIR, "csx.yml")) fe.throughput_analysis(self.kernel_x86, show_cmnts=False) fe.latency_analysis(dg.get_critical_path()) @@ -117,7 +121,9 @@ class TestFrontend(unittest.TestCase): # TODO compare output with checked string def test_dict_output_x86(self): - dg = KernelDG(self.kernel_x86, self.parser_x86, self.machine_model_csx, self.semantics_csx) + dg = KernelDG( + self.kernel_x86, self.parser_x86, self.machine_model_csx, self.semantics_csx + ) fe = Frontend(path_to_yaml=os.path.join(self.MODULE_DATA_DIR, "csx.yml")) analysis_dict = fe.full_analysis_dict(self.kernel_x86, dg) self.assertEqual(len(self.kernel_x86), len(analysis_dict["Kernel"])) @@ -131,7 +137,9 @@ class TestFrontend(unittest.TestCase): ) self.assertEqual(line.latency_cp, analysis_dict["Kernel"][i]["LatencyCP"]) self.assertEqual(line.mnemonic, analysis_dict["Kernel"][i]["Instruction"]) - self.assertEqual(len(line.operands), len(analysis_dict["Kernel"][i]["Operands"])) + self.assertEqual( + len(line.operands), len(analysis_dict["Kernel"][i]["Operands"]) + ) self.assertEqual( len(line.semantic_operands["source"]), len(analysis_dict["Kernel"][i]["SemanticOperands"]["source"]), @@ -168,7 +176,9 @@ class TestFrontend(unittest.TestCase): ) self.assertEqual(line.latency_cp, analysis_dict["Kernel"][i]["LatencyCP"]) self.assertEqual(line.mnemonic, analysis_dict["Kernel"][i]["Instruction"]) - self.assertEqual(len(line.operands), len(analysis_dict["Kernel"][i]["Operands"])) + self.assertEqual( + len(line.operands), len(analysis_dict["Kernel"][i]["Operands"]) + ) self.assertEqual( len(line.semantic_operands["source"]), len(analysis_dict["Kernel"][i]["SemanticOperands"]["source"]), @@ -206,7 +216,9 @@ class TestFrontend(unittest.TestCase): ) self.assertEqual(line.latency_cp, analysis_dict["Kernel"][i]["LatencyCP"]) self.assertEqual(line.mnemonic, analysis_dict["Kernel"][i]["Instruction"]) - self.assertEqual(len(line.operands), len(analysis_dict["Kernel"][i]["Operands"])) + self.assertEqual( + len(line.operands), len(analysis_dict["Kernel"][i]["Operands"]) + ) self.assertEqual( len(line.semantic_operands["source"]), len(analysis_dict["Kernel"][i]["SemanticOperands"]["source"]), diff --git a/tests/test_marker_utils.py b/tests/test_marker_utils.py index bcd05c1..c607f72 100755 --- a/tests/test_marker_utils.py +++ b/tests/test_marker_utils.py @@ -22,7 +22,7 @@ class TestMarkerUtils(unittest.TestCase): self.parser_x86_att = ParserX86ATT() self.parser_x86_intel = ParserX86Intel() self.parser_RISCV = ParserRISCV() - + with open(self._find_file("triad_arm_iaca.s")) as f: triad_code_arm = f.read() with open(self._find_file("triad_x86_iaca.s")) as f: @@ -31,7 +31,7 @@ class TestMarkerUtils(unittest.TestCase): triad_code_x86_intel = f.read() with open(self._find_file("kernel_riscv.s")) as f: kernel_code_riscv = f.read() - + self.parsed_AArch = self.parser_AArch.parse_file(triad_code_arm) self.parsed_x86_att = self.parser_x86_att.parse_file(triad_code_x86_att) self.parsed_x86_intel = self.parser_x86_intel.parse_file(triad_code_x86_intel) @@ -74,7 +74,9 @@ class TestMarkerUtils(unittest.TestCase): bytes_3_lines_1 = ".byte 213,3\n" + ".byte 32\n" + ".byte 31\n" bytes_3_lines_2 = ".byte 213\n" + ".byte 3,32\n" + ".byte 31\n" bytes_3_lines_3 = ".byte 213\n" + ".byte 3\n" + ".byte 32,31\n" - bytes_4_lines = ".byte 213\n" + ".byte 3\n" + ".byte 32\n" + ".byte 31\n" + bytes_4_lines = ( + ".byte 213\n" + ".byte 3\n" + ".byte 32\n" + ".byte 31\n" + ) bytes_hex = ".byte 0xd5, 0x3, 0x20, 0x1f\n" bytes_mixed = ".byte 0xd5\n.byte 3,0x20\n.byte 31\n" mov_start_1 = "mov x1, #111\n" @@ -130,13 +132,17 @@ class TestMarkerUtils(unittest.TestCase): bytes_end=bytes_var_2, ): sample_parsed = self.parser_AArch.parse_file(sample_code) - sample_kernel = reduce_to_section(sample_parsed, ParserAArch64()) + sample_kernel = reduce_to_section( + sample_parsed, ParserAArch64() + ) self.assertEqual(len(sample_kernel), kernel_length) kernel_start = len( list( filter( None, - (prologue + mov_start_var + bytes_var_1).split("\n"), + (prologue + mov_start_var + bytes_var_1).split( + "\n" + ), ) ) ) @@ -202,13 +208,17 @@ class TestMarkerUtils(unittest.TestCase): bytes_end=bytes_var_2, ): sample_parsed = self.parser_x86_att.parse_file(sample_code) - sample_kernel = reduce_to_section(sample_parsed, ParserX86ATT()) + sample_kernel = reduce_to_section( + sample_parsed, ParserX86ATT() + ) self.assertEqual(len(sample_kernel), kernel_length) kernel_start = len( list( filter( None, - (prologue + mov_start_var + bytes_var_1).split("\n"), + (prologue + mov_start_var + bytes_var_1).split( + "\n" + ), ) ) ) @@ -245,7 +255,7 @@ class TestMarkerUtils(unittest.TestCase): ] li_start_variations = [li_start_1, li_start_2] li_end_variations = [li_end_1, li_end_2] - + # actual tests for RISC-V for li_start_var in li_start_variations: for bytes_var_1 in bytes_variations: @@ -267,13 +277,17 @@ class TestMarkerUtils(unittest.TestCase): bytes_end=bytes_var_2, ): sample_parsed = self.parser_RISCV.parse_file(sample_code) - sample_kernel = reduce_to_section(sample_parsed, ParserRISCV()) + sample_kernel = reduce_to_section( + sample_parsed, ParserRISCV() + ) self.assertEqual(len(sample_kernel), kernel_length) kernel_start = len( list( filter( None, - (prologue + li_start_var + bytes_var_1).split("\n"), + (prologue + li_start_var + bytes_var_1).split( + "\n" + ), ) ) ) @@ -323,7 +337,9 @@ class TestMarkerUtils(unittest.TestCase): kernel_start = len((pro).strip().split("\n")) else: kernel_start = 0 - parsed_kernel = self.parser_AArch.parse_file(kernel, start_line=kernel_start) + parsed_kernel = self.parser_AArch.parse_file( + kernel, start_line=kernel_start + ) self.assertEqual( test_kernel, parsed_kernel, @@ -334,7 +350,9 @@ class TestMarkerUtils(unittest.TestCase): bytes_line = ".byte 100\n" ".byte 103\n" ".byte 144\n" start_marker = "movl $111, %ebx\n" + bytes_line end_marker = "movl $222, %ebx\n" + bytes_line - prologue = "movl -88(%rbp), %r10d\n" "xorl %r11d, %r11d\n" ".p2align 4,,10\n" + prologue = ( + "movl -88(%rbp), %r10d\n" "xorl %r11d, %r11d\n" ".p2align 4,,10\n" + ) kernel = ( ".L3: #L3\n" "vmovsd .LC1(%rip), %xmm0\n" @@ -342,7 +360,9 @@ class TestMarkerUtils(unittest.TestCase): "cmpl %ecx, %ebx\n" "jle .L3\n" ) - epilogue = "leaq -56(%rbp), %rsi\n" "movl %r10d, -88(%rbp)\n" "call timing\n" + epilogue = ( + "leaq -56(%rbp), %rsi\n" "movl %r10d, -88(%rbp)\n" "call timing\n" + ) samples = [ # (test name, # ignored prologue, section to be extraced, ignored epilogue) @@ -371,7 +391,9 @@ class TestMarkerUtils(unittest.TestCase): kernel_start = len((pro).strip().split("\n")) else: kernel_start = 0 - parsed_kernel = self.parser_x86_att.parse_file(kernel, start_line=kernel_start) + parsed_kernel = self.parser_x86_att.parse_file( + kernel, start_line=kernel_start + ) self.assertEqual( test_kernel, @@ -422,7 +444,9 @@ class TestMarkerUtils(unittest.TestCase): kernel_start = len((pro).strip().split("\n")) else: kernel_start = 0 - parsed_kernel = self.parser_RISCV.parse_file(kernel, start_line=kernel_start) + parsed_kernel = self.parser_RISCV.parse_file( + kernel, start_line=kernel_start + ) self.assertEqual( test_kernel, parsed_kernel, @@ -490,7 +514,7 @@ class TestMarkerUtils(unittest.TestCase): ] ), ) - + # Check that find_jump_labels works for RISC-V riscv_labels = find_jump_labels(self.parsed_RISCV) self.assertIsInstance(riscv_labels, OrderedDict) @@ -559,7 +583,7 @@ class TestMarkerUtils(unittest.TestCase): ("main", 575, 590), ], ) - + # Check that find_basic_blocks works for RISC-V riscv_blocks = find_basic_blocks(self.parsed_RISCV) self.assertGreater(len(riscv_blocks), 0) @@ -587,7 +611,7 @@ class TestMarkerUtils(unittest.TestCase): (".LBB0_35", 494, 504), ], ) - + # Check that find_basic_loop_bodies works for RISC-V riscv_loop_bodies = find_basic_loop_bodies(self.parsed_RISCV) self.assertGreater(len(riscv_loop_bodies), 0) diff --git a/tests/test_parser_AArch64.py b/tests/test_parser_AArch64.py index 167dcfb..ea0a163 100755 --- a/tests/test_parser_AArch64.py +++ b/tests/test_parser_AArch64.py @@ -32,7 +32,9 @@ class TestParserAArch64(unittest.TestCase): ################## def test_comment_parser(self): - self.assertEqual(self._get_comment(self.parser, "// some comments"), "some comments") + self.assertEqual( + self._get_comment(self.parser, "// some comments"), "some comments" + ) self.assertEqual( self._get_comment(self.parser, "\t\t//AA BB CC \t end \t"), "AA BB CC end" ) @@ -44,8 +46,12 @@ class TestParserAArch64(unittest.TestCase): def test_label_parser(self): self.assertEqual(self._get_label(self.parser, "main:")[0].name, "main") self.assertEqual(self._get_label(self.parser, "..B1.10:")[0].name, "..B1.10") - self.assertEqual(self._get_label(self.parser, ".2.3_2_pack.3:")[0].name, ".2.3_2_pack.3") - self.assertEqual(self._get_label(self.parser, ".L1:\t\t\t//label1")[0].name, ".L1") + self.assertEqual( + self._get_label(self.parser, ".2.3_2_pack.3:")[0].name, ".2.3_2_pack.3" + ) + self.assertEqual( + self._get_label(self.parser, ".L1:\t\t\t//label1")[0].name, ".L1" + ) self.assertEqual( " ".join(self._get_label(self.parser, ".L1:\t\t\t//label1")[1]), "label1", @@ -55,29 +61,36 @@ class TestParserAArch64(unittest.TestCase): def test_directive_parser(self): self.assertEqual(self._get_directive(self.parser, "\t.text")[0].name, "text") - self.assertEqual(len(self._get_directive(self.parser, "\t.text")[0].parameters), 0) - self.assertEqual(self._get_directive(self.parser, "\t.align\t16,0x90")[0].name, "align") + self.assertEqual( + len(self._get_directive(self.parser, "\t.text")[0].parameters), 0 + ) + self.assertEqual( + self._get_directive(self.parser, "\t.align\t16,0x90")[0].name, "align" + ) self.assertEqual( len(self._get_directive(self.parser, "\t.align\t16,0x90")[0].parameters), 2 ) self.assertEqual( - self._get_directive(self.parser, "\t.align\t16,0x90")[0].parameters[1], "0x90" + self._get_directive(self.parser, "\t.align\t16,0x90")[0].parameters[1], + "0x90", ) self.assertEqual( - self._get_directive(self.parser, " .byte 100,103,144 //IACA START")[ - 0 - ].name, + self._get_directive( + self.parser, " .byte 100,103,144 //IACA START" + )[0].name, "byte", ) self.assertEqual( - self._get_directive(self.parser, " .byte 100,103,144 //IACA START")[ - 0 - ].parameters[2], + self._get_directive( + self.parser, " .byte 100,103,144 //IACA START" + )[0].parameters[2], "144", ) self.assertEqual( " ".join( - self._get_directive(self.parser, " .byte 100,103,144 //IACA START")[1] + self._get_directive( + self.parser, " .byte 100,103,144 //IACA START" + )[1] ), "IACA START", ) @@ -169,7 +182,9 @@ class TestParserAArch64(unittest.TestCase): self.assertEqual(parsed_8.operands[1].name, "16") self.assertEqual(parsed_8.operands[1].prefix, "v") self.assertEqual(parsed_8.operands[1].index, "1") - self.assertEqual(self.parser.get_full_reg_name(parsed_8.operands[1]), "v16.d[1]") + self.assertEqual( + self.parser.get_full_reg_name(parsed_8.operands[1]), "v16.d[1]" + ) self.assertEqual(parsed_9.mnemonic, "ccmp") self.assertEqual(parsed_9.operands[0].name, "0") @@ -215,7 +230,9 @@ class TestParserAArch64(unittest.TestCase): instruction_form_3 = InstructionForm( mnemonic=None, operands=[], - directive_id=DirectiveOperand(name="cfi_def_cfa", parameters=["w29", "-16"]), + directive_id=DirectiveOperand( + name="cfi_def_cfa", parameters=["w29", "-16"] + ), comment_id=None, label_id=None, line=".cfi_def_cfa w29, -16", @@ -359,13 +376,16 @@ class TestParserAArch64(unittest.TestCase): imd_type="float", value={"mantissa": "0.79", "e_sign": "+", "exponent": "2"} ) imd_float_12 = ImmediateOperand( - imd_type="float", value={"mantissa": "790.0", "e_sign": "-", "exponent": "1"} + imd_type="float", + value={"mantissa": "790.0", "e_sign": "-", "exponent": "1"}, ) imd_double_11 = ImmediateOperand( - imd_type="double", value={"mantissa": "0.79", "e_sign": "+", "exponent": "2"} + imd_type="double", + value={"mantissa": "0.79", "e_sign": "+", "exponent": "2"}, ) imd_double_12 = ImmediateOperand( - imd_type="double", value={"mantissa": "790.0", "e_sign": "-", "exponent": "1"} + imd_type="double", + value={"mantissa": "790.0", "e_sign": "-", "exponent": "1"}, ) identifier = IdentifierOperand(name="..B1.4") @@ -441,35 +461,45 @@ class TestParserAArch64(unittest.TestCase): for rj in regs: assert_value = True if rj in reg_1 else False with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): - self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + self.assertEqual( + self.parser.is_reg_dependend_of(ri, rj), assert_value + ) for ri in reg_2: for rj in regs: assert_value = True if rj in reg_2 else False with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): - self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + self.assertEqual( + self.parser.is_reg_dependend_of(ri, rj), assert_value + ) for ri in reg_v: for rj in regs: assert_value = True if rj in reg_v else False with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): - self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + self.assertEqual( + self.parser.is_reg_dependend_of(ri, rj), assert_value + ) for ri in reg_others: for rj in regs: assert_value = True if rj == ri else False with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): - self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + self.assertEqual( + self.parser.is_reg_dependend_of(ri, rj), assert_value + ) ################## # Helper functions ################## def _get_comment(self, parser, comment): return " ".join( - parser.process_operand(parser.comment.parseString(comment, parseAll=True).asDict())[ - "comment" - ] + parser.process_operand( + parser.comment.parseString(comment, parseAll=True).asDict() + )["comment"] ) def _get_label(self, parser, label): - return parser.process_operand(parser.label.parseString(label, parseAll=True).asDict()) + return parser.process_operand( + parser.label.parseString(label, parseAll=True).asDict() + ) def _get_directive(self, parser, directive): return parser.process_operand( diff --git a/tests/test_parser_RISCV.py b/tests/test_parser_RISCV.py index 0f0b621..a069cb5 100644 --- a/tests/test_parser_RISCV.py +++ b/tests/test_parser_RISCV.py @@ -8,9 +8,7 @@ import unittest from pyparsing import ParseException -from osaca.parser import ParserRISCV, InstructionForm -from osaca.parser.directive import DirectiveOperand -from osaca.parser.memory import MemoryOperand +from osaca.parser import ParserRISCV from osaca.parser.register import RegisterOperand from osaca.parser.immediate import ImmediateOperand from osaca.parser.identifier import IdentifierOperand @@ -28,9 +26,12 @@ class TestParserRISCV(unittest.TestCase): ################## def test_comment_parser(self): - self.assertEqual(self._get_comment(self.parser, "# some comments"), "some comments") self.assertEqual( - self._get_comment(self.parser, "\t\t# RISC-V comment \t end \t"), "RISC-V comment end" + self._get_comment(self.parser, "# some comments"), "some comments" + ) + self.assertEqual( + self._get_comment(self.parser, "\t\t# RISC-V comment \t end \t"), + "RISC-V comment end", ) self.assertEqual( self._get_comment(self.parser, "\t## comment ## comment"), @@ -39,9 +40,13 @@ class TestParserRISCV(unittest.TestCase): def test_label_parser(self): # Test common label patterns from kernel_riscv.s - self.assertEqual(self._get_label(self.parser, "saxpy_golden:")[0].name, "saxpy_golden") + self.assertEqual( + self._get_label(self.parser, "saxpy_golden:")[0].name, "saxpy_golden" + ) self.assertEqual(self._get_label(self.parser, ".L4:")[0].name, ".L4") - self.assertEqual(self._get_label(self.parser, ".L25:\t\t\t# Return")[0].name, ".L25") + self.assertEqual( + self._get_label(self.parser, ".L25:\t\t\t# Return")[0].name, ".L25" + ) self.assertEqual( " ".join(self._get_label(self.parser, ".L25:\t\t\t# Return")[1]), "Return", @@ -51,28 +56,42 @@ class TestParserRISCV(unittest.TestCase): def test_directive_parser(self): self.assertEqual(self._get_directive(self.parser, "\t.text")[0].name, "text") - self.assertEqual(len(self._get_directive(self.parser, "\t.text")[0].parameters), 0) - self.assertEqual(self._get_directive(self.parser, "\t.word\t1113498583")[0].name, "word") self.assertEqual( - len(self._get_directive(self.parser, "\t.word\t1113498583")[0].parameters), 1 + len(self._get_directive(self.parser, "\t.text")[0].parameters), 0 ) self.assertEqual( - self._get_directive(self.parser, "\t.word\t1113498583")[0].parameters[0], "1113498583" + self._get_directive(self.parser, "\t.word\t1113498583")[0].name, "word" + ) + self.assertEqual( + len(self._get_directive(self.parser, "\t.word\t1113498583")[0].parameters), + 1, + ) + self.assertEqual( + self._get_directive(self.parser, "\t.word\t1113498583")[0].parameters[0], + "1113498583", ) # Test string directive self.assertEqual( - self._get_directive(self.parser, '.string "fail, %f=!%f\\n"')[0].name, "string" + self._get_directive(self.parser, '.string "fail, %f=!%f\\n"')[0].name, + "string", ) self.assertEqual( - self._get_directive(self.parser, '.string "fail, %f=!%f\\n"')[0].parameters[0], - '"fail, %f=!%f\\n"' + self._get_directive(self.parser, '.string "fail, %f=!%f\\n"')[0].parameters[ + 0 + ], + '"fail, %f=!%f\\n"', ) # Test set directive self.assertEqual( self._get_directive(self.parser, "\t.set\t.LANCHOR0,. + 0")[0].name, "set" ) self.assertEqual( - len(self._get_directive(self.parser, "\t.set\t.LANCHOR0,. + 0")[0].parameters), 2 + len( + self._get_directive(self.parser, "\t.set\t.LANCHOR0,. + 0")[ + 0 + ].parameters + ), + 2, ) def test_parse_instruction(self): @@ -105,7 +124,7 @@ class TestParserRISCV(unittest.TestCase): # Test 1: Line with label and instruction parsed_1 = self.parser.parse_line(".L2:") self.assertEqual(parsed_1.label, ".L2") - + # Test 2: Line with instruction and comment parsed_2 = self.parser.parse_line("addi x10, x10, 1 # increment") self.assertEqual(parsed_2.mnemonic, "addi") @@ -118,24 +137,33 @@ class TestParserRISCV(unittest.TestCase): def test_parse_file(self): parsed = self.parser.parse_file(self.riscv_code) self.assertGreater(len(parsed), 10) # There should be multiple lines - + # Find common elements that should exist in any RISC-V file # without being tied to specific line numbers - + # Verify that we can find at least one label label_forms = [form for form in parsed if form.label is not None] self.assertGreater(len(label_forms), 0, "No labels found in the file") - + # Verify that we can find at least one branch instruction - branch_forms = [form for form in parsed if form.mnemonic and form.mnemonic.startswith("b")] - self.assertGreater(len(branch_forms), 0, "No branch instructions found in the file") - + branch_forms = [ + form for form in parsed if form.mnemonic and form.mnemonic.startswith("b") + ] + self.assertGreater( + len(branch_forms), 0, "No branch instructions found in the file" + ) + # Verify that we can find at least one store/load instruction - mem_forms = [form for form in parsed if form.mnemonic and ( - form.mnemonic.startswith("s") or - form.mnemonic.startswith("l"))] - self.assertGreater(len(mem_forms), 0, "No memory instructions found in the file") - + mem_forms = [ + form + for form in parsed + if form.mnemonic + and (form.mnemonic.startswith("s") or form.mnemonic.startswith("l")) + ] + self.assertGreater( + len(mem_forms), 0, "No memory instructions found in the file" + ) + # Verify that we can find at least one directive directive_forms = [form for form in parsed if form.directive is not None] self.assertGreater(len(directive_forms), 0, "No directives found in the file") @@ -148,15 +176,17 @@ class TestParserRISCV(unittest.TestCase): reg_a0 = RegisterOperand(name="a0") reg_t1 = RegisterOperand(name="t1") reg_s2 = RegisterOperand(name="s2") - + reg_x0 = RegisterOperand(prefix="x", name="0") reg_x1 = RegisterOperand(prefix="x", name="1") reg_x2 = RegisterOperand(prefix="x", name="2") - reg_x5 = RegisterOperand(prefix="x", name="5") # Define reg_x5 for use in tests below + reg_x5 = RegisterOperand( + prefix="x", name="5" + ) # Define reg_x5 for use in tests below reg_x10 = RegisterOperand(prefix="x", name="10") reg_x6 = RegisterOperand(prefix="x", name="6") reg_x18 = RegisterOperand(prefix="x", name="18") - + # Test canonical name conversion self.assertEqual(self.parser._get_canonical_reg_name(reg_zero), "x0") self.assertEqual(self.parser._get_canonical_reg_name(reg_ra), "x1") @@ -164,7 +194,7 @@ class TestParserRISCV(unittest.TestCase): self.assertEqual(self.parser._get_canonical_reg_name(reg_a0), "x10") self.assertEqual(self.parser._get_canonical_reg_name(reg_t1), "x6") self.assertEqual(self.parser._get_canonical_reg_name(reg_s2), "x18") - + # Test register dependency self.assertTrue(self.parser.is_reg_dependend_of(reg_zero, reg_x0)) self.assertTrue(self.parser.is_reg_dependend_of(reg_ra, reg_x1)) @@ -172,29 +202,27 @@ class TestParserRISCV(unittest.TestCase): self.assertTrue(self.parser.is_reg_dependend_of(reg_a0, reg_x10)) self.assertTrue(self.parser.is_reg_dependend_of(reg_t1, reg_x6)) self.assertTrue(self.parser.is_reg_dependend_of(reg_s2, reg_x18)) - + # Test non-dependent registers self.assertFalse(self.parser.is_reg_dependend_of(reg_zero, reg_x1)) self.assertFalse(self.parser.is_reg_dependend_of(reg_ra, reg_x2)) self.assertFalse(self.parser.is_reg_dependend_of(reg_a0, reg_t1)) - + # Test floating-point registers reg_fa0 = RegisterOperand(prefix="f", name="a0") - reg_fa1 = RegisterOperand(prefix="f", name="a1") reg_f10 = RegisterOperand(prefix="f", name="10") - + # Test vector registers reg_v1 = RegisterOperand(prefix="v", name="1") - reg_v2 = RegisterOperand(prefix="v", name="2") - + # Test register type detection self.assertTrue(self.parser.is_gpr(reg_a0)) self.assertTrue(self.parser.is_gpr(reg_x5)) self.assertTrue(self.parser.is_gpr(reg_sp)) - + self.assertFalse(self.parser.is_gpr(reg_fa0)) self.assertFalse(self.parser.is_gpr(reg_f10)) - + self.assertTrue(self.parser.is_vector_register(reg_v1)) self.assertFalse(self.parser.is_vector_register(reg_x10)) self.assertFalse(self.parser.is_vector_register(reg_fa0)) @@ -235,13 +263,15 @@ class TestParserRISCV(unittest.TestCase): ################## def _get_comment(self, parser, comment): return " ".join( - parser.process_operand(parser.comment.parseString(comment, parseAll=True).asDict())[ - "comment" - ] + parser.process_operand( + parser.comment.parseString(comment, parseAll=True).asDict() + )["comment"] ) def _get_label(self, parser, label): - return parser.process_operand(parser.label.parseString(label, parseAll=True).asDict()) + return parser.process_operand( + parser.label.parseString(label, parseAll=True).asDict() + ) def _get_directive(self, parser, directive): return parser.process_operand( @@ -258,4 +288,4 @@ class TestParserRISCV(unittest.TestCase): if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromTestCase(TestParserRISCV) - unittest.TextTestRunner(verbosity=2).run(suite) \ No newline at end of file + unittest.TextTestRunner(verbosity=2).run(suite) diff --git a/tests/test_parser_x86att.py b/tests/test_parser_x86att.py index b7a1459..4c23a0a 100755 --- a/tests/test_parser_x86att.py +++ b/tests/test_parser_x86att.py @@ -25,8 +25,12 @@ class TestParserX86ATT(unittest.TestCase): ################## def test_comment_parser(self): - self.assertEqual(self._get_comment(self.parser, "# some comments"), "some comments") - self.assertEqual(self._get_comment(self.parser, "\t\t#AA BB CC \t end \t"), "AA BB CC end") + self.assertEqual( + self._get_comment(self.parser, "# some comments"), "some comments" + ) + self.assertEqual( + self._get_comment(self.parser, "\t\t#AA BB CC \t end \t"), "AA BB CC end" + ) self.assertEqual( self._get_comment(self.parser, "\t## comment ## comment"), "# comment ## comment", @@ -35,8 +39,12 @@ class TestParserX86ATT(unittest.TestCase): def test_label_parser(self): self.assertEqual(self._get_label(self.parser, "main:")[0].name, "main") self.assertEqual(self._get_label(self.parser, "..B1.10:")[0].name, "..B1.10") - self.assertEqual(self._get_label(self.parser, ".2.3_2_pack.3:")[0].name, ".2.3_2_pack.3") - self.assertEqual(self._get_label(self.parser, ".L1:\t\t\t#label1")[0].name, ".L1") + self.assertEqual( + self._get_label(self.parser, ".2.3_2_pack.3:")[0].name, ".2.3_2_pack.3" + ) + self.assertEqual( + self._get_label(self.parser, ".L1:\t\t\t#label1")[0].name, ".L1" + ) self.assertEqual( " ".join(self._get_label(self.parser, ".L1:\t\t\t#label1")[1]), "label1", @@ -46,22 +54,36 @@ class TestParserX86ATT(unittest.TestCase): def test_directive_parser(self): self.assertEqual(self._get_directive(self.parser, "\t.text")[0].name, "text") - self.assertEqual(len(self._get_directive(self.parser, "\t.text")[0].parameters), 0) - self.assertEqual(self._get_directive(self.parser, "\t.align\t16,0x90")[0].name, "align") + self.assertEqual( + len(self._get_directive(self.parser, "\t.text")[0].parameters), 0 + ) + self.assertEqual( + self._get_directive(self.parser, "\t.align\t16,0x90")[0].name, "align" + ) self.assertEqual( len(self._get_directive(self.parser, "\t.align\t16,0x90")[0].parameters), 2 ) - self.assertEqual(len(self._get_directive(self.parser, ".text")[0].parameters), 0) self.assertEqual( - len(self._get_directive(self.parser, '.file\t1 "path/to/file.c"')[0].parameters), + len(self._get_directive(self.parser, ".text")[0].parameters), 0 + ) + self.assertEqual( + len( + self._get_directive(self.parser, '.file\t1 "path/to/file.c"')[ + 0 + ].parameters + ), 2, ) self.assertEqual( - self._get_directive(self.parser, '.file\t1 "path/to/file.c"')[0].parameters[1], + self._get_directive(self.parser, '.file\t1 "path/to/file.c"')[0].parameters[ + 1 + ], '"path/to/file.c"', ) self.assertEqual( - self._get_directive(self.parser, "\t.set\tL$set$0,LECIE1-LSCIE1")[0].parameters, + self._get_directive(self.parser, "\t.set\tL$set$0,LECIE1-LSCIE1")[ + 0 + ].parameters, ["L$set$0", "LECIE1-LSCIE1"], ) self.assertEqual( @@ -77,29 +99,32 @@ class TestParserX86ATT(unittest.TestCase): ], ) self.assertEqual( - self._get_directive(self.parser, "\t.section\t__TEXT,__literal16,16byte_literals")[ - 0 - ].parameters, + self._get_directive( + self.parser, "\t.section\t__TEXT,__literal16,16byte_literals" + )[0].parameters, ["__TEXT", "__literal16", "16byte_literals"], ) self.assertEqual( - self._get_directive(self.parser, "\t.align\t16,0x90")[0].parameters[1], "0x90" + self._get_directive(self.parser, "\t.align\t16,0x90")[0].parameters[1], + "0x90", ) self.assertEqual( - self._get_directive(self.parser, " .byte 100,103,144 #IACA START")[ - 0 - ].name, + self._get_directive( + self.parser, " .byte 100,103,144 #IACA START" + )[0].name, "byte", ) self.assertEqual( - self._get_directive(self.parser, " .byte 100,103,144 #IACA START")[ - 0 - ].parameters[2], + self._get_directive( + self.parser, " .byte 100,103,144 #IACA START" + )[0].parameters[2], "144", ) self.assertEqual( " ".join( - self._get_directive(self.parser, " .byte 100,103,144 #IACA START")[1] + self._get_directive( + self.parser, " .byte 100,103,144 #IACA START" + )[1] ), "IACA START", ) @@ -175,7 +200,8 @@ class TestParserX86ATT(unittest.TestCase): parsed_8.operands[1].segment_ext[0]["offset"]["identifier"]["name"], "var" ) self.assertEqual( - parsed_8.operands[1].segment_ext[0]["offset"]["identifier"]["relocation"], "@TPOFF" + parsed_8.operands[1].segment_ext[0]["offset"]["identifier"]["relocation"], + "@TPOFF", ) self.assertEqual(parsed_9.mnemonic, "movq") @@ -185,10 +211,12 @@ class TestParserX86ATT(unittest.TestCase): parsed_9.operands[1].segment_ext[0]["offset"]["identifier"]["name"], "var" ) self.assertEqual( - parsed_9.operands[1].segment_ext[0]["offset"]["identifier"]["relocation"], "@TPOFF" + parsed_9.operands[1].segment_ext[0]["offset"]["identifier"]["relocation"], + "@TPOFF", ) self.assertEqual( - parsed_9.operands[1].segment_ext[0]["offset"]["identifier"]["offset"][0], "-8" + parsed_9.operands[1].segment_ext[0]["offset"]["identifier"]["offset"][0], + "-8", ) self.assertEqual(parsed_9.operands[1].segment_ext[0]["base"]["name"], "rcx") @@ -199,10 +227,12 @@ class TestParserX86ATT(unittest.TestCase): parsed_10.operands[1].segment_ext[0]["offset"]["identifier"]["name"], "var" ) self.assertEqual( - parsed_10.operands[1].segment_ext[0]["offset"]["identifier"]["relocation"], "@TPOFF" + parsed_10.operands[1].segment_ext[0]["offset"]["identifier"]["relocation"], + "@TPOFF", ) self.assertEqual( - parsed_10.operands[1].segment_ext[0]["offset"]["identifier"]["offset"][0], "4992" + parsed_10.operands[1].segment_ext[0]["offset"]["identifier"]["offset"][0], + "4992", ) def test_parse_line(self): @@ -333,35 +363,45 @@ class TestParserX86ATT(unittest.TestCase): for rj in regs: assert_value = True if rj in reg_a else False with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): - self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + self.assertEqual( + self.parser.is_reg_dependend_of(ri, rj), assert_value + ) for ri in reg_r: for rj in regs: assert_value = True if rj in reg_r else False with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): - self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + self.assertEqual( + self.parser.is_reg_dependend_of(ri, rj), assert_value + ) for ri in reg_vec_1: for rj in regs: assert_value = True if rj in reg_vec_1 else False with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): - self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + self.assertEqual( + self.parser.is_reg_dependend_of(ri, rj), assert_value + ) for ri in reg_others: for rj in regs: assert_value = True if rj == ri else False with self.subTest(reg_a=ri, reg_b=rj, assert_val=assert_value): - self.assertEqual(self.parser.is_reg_dependend_of(ri, rj), assert_value) + self.assertEqual( + self.parser.is_reg_dependend_of(ri, rj), assert_value + ) ################## # Helper functions ################## def _get_comment(self, parser, comment): return " ".join( - parser.process_operand(parser.comment.parseString(comment, parseAll=True).asDict())[ - "comment" - ] + parser.process_operand( + parser.comment.parseString(comment, parseAll=True).asDict() + )["comment"] ) def _get_label(self, parser, label): - return parser.process_operand(parser.label.parseString(label, parseAll=True).asDict()) + return parser.process_operand( + parser.label.parseString(label, parseAll=True).asDict() + ) def _get_directive(self, parser, directive): return parser.process_operand( diff --git a/tests/test_parser_x86intel.py b/tests/test_parser_x86intel.py index 1918810..efb7a46 100755 --- a/tests/test_parser_x86intel.py +++ b/tests/test_parser_x86intel.py @@ -33,8 +33,12 @@ class TestParserX86Intel(unittest.TestCase): ################## def test_comment_parser(self): - self.assertEqual(self._get_comment(self.parser, "; some comments"), "some comments") - self.assertEqual(self._get_comment(self.parser, "\t\t;AA BB CC \t end \t"), "AA BB CC end") + self.assertEqual( + self._get_comment(self.parser, "; some comments"), "some comments" + ) + self.assertEqual( + self._get_comment(self.parser, "\t\t;AA BB CC \t end \t"), "AA BB CC end" + ) self.assertEqual( self._get_comment(self.parser, "\t;; comment ;; comment"), "; comment ;; comment", @@ -44,11 +48,15 @@ class TestParserX86Intel(unittest.TestCase): self.assertEqual(self._get_label(self.parser, "main:")[0].name, "main") self.assertEqual(self._get_label(self.parser, "$$B1?10:")[0].name, "$$B1?10") self.assertEqual( - self._get_label(self.parser, "$LN9:\tcall\t__CheckForDebuggerJustMyCode")[0].name, + self._get_label(self.parser, "$LN9:\tcall\t__CheckForDebuggerJustMyCode")[ + 0 + ].name, "$LN9", ) self.assertEqual( - self._get_label(self.parser, "$LN9:\tcall\t__CheckForDebuggerJustMyCode")[1], + self._get_label(self.parser, "$LN9:\tcall\t__CheckForDebuggerJustMyCode")[ + 1 + ], InstructionForm( mnemonic="call", operands=[ @@ -81,7 +89,9 @@ class TestParserX86Intel(unittest.TestCase): ) self.assertEqual( self._get_directive(self.parser, "$pdata$kernel DD imagerel $LN9")[0], - DirectiveOperand(name="DD", parameters=["$pdata$kernel", "imagerel", "$LN9"]), + DirectiveOperand( + name="DD", parameters=["$pdata$kernel", "imagerel", "$LN9"] + ), ) self.assertEqual( self._get_directive(self.parser, "repeat$ = 320")[0], @@ -133,7 +143,9 @@ class TestParserX86Intel(unittest.TestCase): self.assertEqual( parsed_3.operands[1], MemoryOperand( - base=RegisterOperand(name="RDX"), index=RegisterOperand(name="RCX"), scale=8 + base=RegisterOperand(name="RDX"), + index=RegisterOperand(name="RCX"), + scale=8, ), ) @@ -149,7 +161,9 @@ class TestParserX86Intel(unittest.TestCase): self.assertEqual(parsed_5.mnemonic, "mov") self.assertEqual( parsed_5.operands[0], - MemoryOperand(offset=ImmediateOperand(value=24), base=RegisterOperand(name="RSP")), + MemoryOperand( + offset=ImmediateOperand(value=24), base=RegisterOperand(name="RSP") + ), ) self.assertEqual(parsed_5.operands[1], RegisterOperand(name="R8")) @@ -166,7 +180,9 @@ class TestParserX86Intel(unittest.TestCase): self.assertEqual(parsed_8.mnemonic, "mov") self.assertEqual( parsed_8.operands[0], - MemoryOperand(base=RegisterOperand(name="GS"), offset=ImmediateOperand(value=111)), + MemoryOperand( + base=RegisterOperand(name="GS"), offset=ImmediateOperand(value=111) + ), ) self.assertEqual(parsed_8.operands[1], RegisterOperand(name="AL")) @@ -193,7 +209,9 @@ class TestParserX86Intel(unittest.TestCase): self.assertEqual( parsed_11.operands[1], MemoryOperand( - offset=IdentifierOperand(name="??_R0N@8", offset=ImmediateOperand(value=8)) + offset=IdentifierOperand( + name="??_R0N@8", offset=ImmediateOperand(value=8) + ) ), ) @@ -206,7 +224,9 @@ class TestParserX86Intel(unittest.TestCase): ) self.assertEqual(parsed_13.mnemonic, "jmp") - self.assertEqual(parsed_13.operands[0], IdentifierOperand(name="$LN18@operator")) + self.assertEqual( + parsed_13.operands[0], IdentifierOperand(name="$LN18@operator") + ) self.assertEqual(parsed_14.mnemonic, "vaddsd") self.assertEqual(parsed_14.operands[0], RegisterOperand(name="XMM0")) @@ -294,7 +314,9 @@ class TestParserX86Intel(unittest.TestCase): InstructionForm( mnemonic="mov", operands=[ - MemoryOperand(base=RegisterOperand("RSP"), offset=ImmediateOperand(value=8)), + MemoryOperand( + base=RegisterOperand("RSP"), offset=ImmediateOperand(value=8) + ), RegisterOperand(name="RCX"), ], line="\tmov\tQWORD PTR [rsp+8], rcx", @@ -347,7 +369,8 @@ class TestParserX86Intel(unittest.TestCase): ), ], comment_id="26.19", - line=" vmovsd xmm5, QWORD PTR [16+r11+r10]" + " #26.19", + line=" vmovsd xmm5, QWORD PTR [16+r11+r10]" + + " #26.19", line_number=114, ), ) @@ -419,14 +442,18 @@ class TestParserX86Intel(unittest.TestCase): ################## def _get_comment(self, parser, comment): return " ".join( - parser.process_operand(parser.comment.parseString(comment, parseAll=True))["comment"] + parser.process_operand(parser.comment.parseString(comment, parseAll=True))[ + "comment" + ] ) def _get_label(self, parser, label): return parser.process_operand(parser.label.parseString(label, parseAll=True)) def _get_directive(self, parser, directive): - return parser.process_operand(parser.directive.parseString(directive, parseAll=True)) + return parser.process_operand( + parser.directive.parseString(directive, parseAll=True) + ) @staticmethod def _find_file(name): diff --git a/tests/test_semantics.py b/tests/test_semantics.py index 4f001a0..0d3859f 100755 --- a/tests/test_semantics.py +++ b/tests/test_semantics.py @@ -70,7 +70,8 @@ class TestSemanticTools(unittest.TestCase): cls.parser_x86_intel.parse_file(cls.code_x86_intel), cls.parser_x86_intel ) cls.kernel_x86_intel_memdep = reduce_to_section( - cls.parser_x86_intel.parse_file(cls.code_x86_intel_memdep), cls.parser_x86_intel + cls.parser_x86_intel.parse_file(cls.code_x86_intel_memdep), + cls.parser_x86_intel, ) cls.kernel_AArch64 = reduce_to_section( cls.parser_AArch64.parse_file(cls.code_AArch64), cls.parser_AArch64 @@ -194,7 +195,9 @@ class TestSemanticTools(unittest.TestCase): RegisterOperand(name="xmm"), ] instr_form_x86_1 = test_mm_x86.get_instruction(name_x86_1, operands_x86_1) - self.assertEqual(instr_form_x86_1, test_mm_x86.get_instruction(name_x86_1, operands_x86_1)) + self.assertEqual( + instr_form_x86_1, test_mm_x86.get_instruction(name_x86_1, operands_x86_1) + ) self.assertEqual( test_mm_x86.get_instruction("jg", [IdentifierOperand()]), test_mm_x86.get_instruction("jg", [IdentifierOperand()]), @@ -206,20 +209,26 @@ class TestSemanticTools(unittest.TestCase): RegisterOperand(prefix="v", shape="s"), ] instr_form_arm_1 = test_mm_arm.get_instruction(name_arm_1, operands_arm_1) - self.assertEqual(instr_form_arm_1, test_mm_arm.get_instruction(name_arm_1, operands_arm_1)) + self.assertEqual( + instr_form_arm_1, test_mm_arm.get_instruction(name_arm_1, operands_arm_1) + ) self.assertEqual( test_mm_arm.get_instruction("b.ne", [IdentifierOperand()]), test_mm_arm.get_instruction("b.ne", [IdentifierOperand()]), ) self.assertEqual( - test_mm_arm.get_instruction("b.someNameThatDoesNotExist", [IdentifierOperand()]), + test_mm_arm.get_instruction( + "b.someNameThatDoesNotExist", [IdentifierOperand()] + ), test_mm_arm.get_instruction("b.someOtherName", [IdentifierOperand()]), ) # test get_store_tp self.assertEqual( test_mm_x86.get_store_throughput( - MemoryOperand(base=RegisterOperand(name="x"), offset=None, index=None, scale=1) + MemoryOperand( + base=RegisterOperand(name="x"), offset=None, index=None, scale=1 + ) )[0][1], [[2, "237"], [2, "4"]], ) @@ -263,7 +272,9 @@ class TestSemanticTools(unittest.TestCase): # test get_store_lt self.assertEqual( test_mm_x86.get_store_latency( - MemoryOperand(base=RegisterOperand(name="x"), offset=None, index=None, scale=1) + MemoryOperand( + base=RegisterOperand(name="x"), offset=None, index=None, scale=1 + ) ), 0, ) @@ -285,7 +296,9 @@ class TestSemanticTools(unittest.TestCase): # test default load tp self.assertEqual( test_mm_x86.get_load_throughput( - MemoryOperand(base=RegisterOperand(name="x"), offset=None, index=None, scale=1) + MemoryOperand( + base=RegisterOperand(name="x"), offset=None, index=None, scale=1 + ) )[0][1], [[1, "23"], [1, ["2D", "3D"]]], ) @@ -379,8 +392,12 @@ class TestSemanticTools(unittest.TestCase): tmp_semantics.assign_optimal_throughput(tmp_kernel_2) k1i1_pp = [round(x, 2) for x in tmp_kernel_1[0].port_pressure] k2i1_pp = [round(x, 2) for x in tmp_kernel_2[0].port_pressure] - self.assertEqual(k1i1_pp, [0.33, 0.0, 0.33, 0.0, 0.0, 0.0, 0.0, 0.0, 0.33, 0.0, 0.0]) - self.assertEqual(k2i1_pp, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]) + self.assertEqual( + k1i1_pp, [0.33, 0.0, 0.33, 0.0, 0.0, 0.0, 0.0, 0.0, 0.33, 0.0, 0.0] + ) + self.assertEqual( + k2i1_pp, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0] + ) def test_optimal_throughput_assignment_x86_intel(self): kernel_fixed = deepcopy(self.kernel_x86_intel) @@ -408,8 +425,12 @@ class TestSemanticTools(unittest.TestCase): tmp_semantics.assign_optimal_throughput(tmp_kernel_2) k1i1_pp = [round(x, 2) for x in tmp_kernel_1[0].port_pressure] k2i1_pp = [round(x, 2) for x in tmp_kernel_2[0].port_pressure] - self.assertEqual(k1i1_pp, [0.33, 0.0, 0.33, 0.0, 0.0, 0.0, 0.0, 0.0, 0.33, 0.0, 0.0]) - self.assertEqual(k2i1_pp, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]) + self.assertEqual( + k1i1_pp, [0.33, 0.0, 0.33, 0.0, 0.0, 0.0, 0.0, 0.0, 0.33, 0.0, 0.0] + ) + self.assertEqual( + k2i1_pp, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0] + ) def test_optimal_throughput_assignment_AArch64(self): kernel_fixed = deepcopy(self.kernel_AArch64) @@ -433,16 +454,27 @@ class TestSemanticTools(unittest.TestCase): # 5_______>9 # dg = KernelDG( - self.kernel_x86, self.parser_x86_att, self.machine_model_csx, self.semantics_csx + self.kernel_x86, + self.parser_x86_att, + self.machine_model_csx, + self.semantics_csx, ) self.assertTrue(nx.algorithms.dag.is_directed_acyclic_graph(dg.dg)) - self.assertEqual(len(list(dg.get_dependent_instruction_forms(line_number=3))), 1) + self.assertEqual( + len(list(dg.get_dependent_instruction_forms(line_number=3))), 1 + ) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=3)), 6) - self.assertEqual(len(list(dg.get_dependent_instruction_forms(line_number=4))), 1) + self.assertEqual( + len(list(dg.get_dependent_instruction_forms(line_number=4))), 1 + ) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=4)), 6) - self.assertEqual(len(list(dg.get_dependent_instruction_forms(line_number=5))), 1) + self.assertEqual( + len(list(dg.get_dependent_instruction_forms(line_number=5))), 1 + ) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=5)), 9) - self.assertEqual(len(list(dg.get_dependent_instruction_forms(line_number=6))), 1) + self.assertEqual( + len(list(dg.get_dependent_instruction_forms(line_number=6))), 1 + ) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=6)), 7) self.assertEqual(list(dg.get_dependent_instruction_forms(line_number=7)), []) self.assertEqual(list(dg.get_dependent_instruction_forms(line_number=8)), []) @@ -467,13 +499,21 @@ class TestSemanticTools(unittest.TestCase): self.semantics_csx_intel, ) self.assertTrue(nx.algorithms.dag.is_directed_acyclic_graph(dg.dg)) - self.assertEqual(len(list(dg.get_dependent_instruction_forms(line_number=3))), 1) + self.assertEqual( + len(list(dg.get_dependent_instruction_forms(line_number=3))), 1 + ) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=3)), 5) - self.assertEqual(len(list(dg.get_dependent_instruction_forms(line_number=4))), 1) + self.assertEqual( + len(list(dg.get_dependent_instruction_forms(line_number=4))), 1 + ) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=4)), 5) - self.assertEqual(len(list(dg.get_dependent_instruction_forms(line_number=5))), 1) + self.assertEqual( + len(list(dg.get_dependent_instruction_forms(line_number=5))), 1 + ) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=5)), 6) - self.assertEqual(len(list(dg.get_dependent_instruction_forms(line_number=5.1))), 1) + self.assertEqual( + len(list(dg.get_dependent_instruction_forms(line_number=5.1))), 1 + ) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=5.1)), 5) self.assertEqual(list(dg.get_dependent_instruction_forms(line_number=6)), []) self.assertEqual(list(dg.get_dependent_instruction_forms(line_number=7)), []) @@ -492,7 +532,9 @@ class TestSemanticTools(unittest.TestCase): ) self.assertTrue(nx.algorithms.dag.is_directed_acyclic_graph(dg.dg)) self.assertEqual(set(dg.get_dependent_instruction_forms(line_number=3)), {6, 8}) - self.assertEqual(set(dg.get_dependent_instruction_forms(line_number=5)), {10, 12}) + self.assertEqual( + set(dg.get_dependent_instruction_forms(line_number=5)), {10, 12} + ) with self.assertRaises(ValueError): dg.get_dependent_instruction_forms() # test dot creation @@ -507,7 +549,9 @@ class TestSemanticTools(unittest.TestCase): ) self.assertTrue(nx.algorithms.dag.is_directed_acyclic_graph(dg.dg)) self.assertEqual(set(dg.get_dependent_instruction_forms(line_number=3)), {6, 8}) - self.assertEqual(set(dg.get_dependent_instruction_forms(line_number=5)), {10, 12}) + self.assertEqual( + set(dg.get_dependent_instruction_forms(line_number=5)), {10, 12} + ) with self.assertRaises(ValueError): dg.get_dependent_instruction_forms() # test dot creation @@ -522,23 +566,41 @@ class TestSemanticTools(unittest.TestCase): ) self.assertTrue(nx.algorithms.dag.is_directed_acyclic_graph(dg.dg)) self.assertEqual(set(dg.get_dependent_instruction_forms(line_number=3)), {7, 8}) - self.assertEqual(set(dg.get_dependent_instruction_forms(line_number=4)), {9, 10}) - self.assertEqual(set(dg.get_dependent_instruction_forms(line_number=5)), {6, 7, 8}) - self.assertEqual(set(dg.get_dependent_instruction_forms(line_number=6)), {9, 10}) + self.assertEqual( + set(dg.get_dependent_instruction_forms(line_number=4)), {9, 10} + ) + self.assertEqual( + set(dg.get_dependent_instruction_forms(line_number=5)), {6, 7, 8} + ) + self.assertEqual( + set(dg.get_dependent_instruction_forms(line_number=6)), {9, 10} + ) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=7)), 13) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=8)), 14) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=9)), 16) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=10)), 17) - self.assertEqual(set(dg.get_dependent_instruction_forms(line_number=11)), {13, 14}) - self.assertEqual(set(dg.get_dependent_instruction_forms(line_number=12)), {16, 17}) + self.assertEqual( + set(dg.get_dependent_instruction_forms(line_number=11)), {13, 14} + ) + self.assertEqual( + set(dg.get_dependent_instruction_forms(line_number=12)), {16, 17} + ) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=13)), 15) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=14)), 15) - self.assertEqual(len(list(dg.get_dependent_instruction_forms(line_number=15))), 0) + self.assertEqual( + len(list(dg.get_dependent_instruction_forms(line_number=15))), 0 + ) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=16)), 18) self.assertEqual(next(dg.get_dependent_instruction_forms(line_number=17)), 18) - self.assertEqual(len(list(dg.get_dependent_instruction_forms(line_number=18))), 0) - self.assertEqual(len(list(dg.get_dependent_instruction_forms(line_number=19))), 0) - self.assertEqual(len(list(dg.get_dependent_instruction_forms(line_number=20))), 0) + self.assertEqual( + len(list(dg.get_dependent_instruction_forms(line_number=18))), 0 + ) + self.assertEqual( + len(list(dg.get_dependent_instruction_forms(line_number=19))), 0 + ) + self.assertEqual( + len(list(dg.get_dependent_instruction_forms(line_number=20))), 0 + ) with self.assertRaises(ValueError): dg.get_dependent_instruction_forms() # test dot creation @@ -586,19 +648,33 @@ class TestSemanticTools(unittest.TestCase): for instruction_form in self.kernel_mops_1[:-1]: with self.subTest(instruction_form=instruction_form): if not instruction_form.line.startswith("//"): - self.assertTrue(mops_dest in instruction_form.semantic_operands["destination"]) - self.assertTrue(mops_src in instruction_form.semantic_operands["source"]) - self.assertTrue(mops_n in instruction_form.semantic_operands["src_dst"]) + self.assertTrue( + mops_dest in instruction_form.semantic_operands["destination"] + ) + self.assertTrue( + mops_src in instruction_form.semantic_operands["source"] + ) + self.assertTrue( + mops_n in instruction_form.semantic_operands["src_dst"] + ) self.assertTrue( mops_dest.base in instruction_form.semantic_operands["src_dst"] ) - self.assertTrue(mops_src.base in instruction_form.semantic_operands["src_dst"]) + self.assertTrue( + mops_src.base in instruction_form.semantic_operands["src_dst"] + ) for instruction_form in self.kernel_mops_2[-2:-1]: with self.subTest(instruction_form=instruction_form): if not instruction_form.line.startswith("//"): - self.assertTrue(mops_dest in instruction_form.semantic_operands["destination"]) - self.assertTrue(mops_x1 in instruction_form.semantic_operands["source"]) - self.assertTrue(mops_n in instruction_form.semantic_operands["src_dst"]) + self.assertTrue( + mops_dest in instruction_form.semantic_operands["destination"] + ) + self.assertTrue( + mops_x1 in instruction_form.semantic_operands["source"] + ) + self.assertTrue( + mops_n in instruction_form.semantic_operands["src_dst"] + ) self.assertTrue( mops_dest.base in instruction_form.semantic_operands["src_dst"] ) @@ -622,16 +698,25 @@ class TestSemanticTools(unittest.TestCase): semantics_hld.add_semantics(kernel_hld_2) semantics_hld.add_semantics(kernel_hld_3) - num_hidden_loads = len([x for x in kernel_hld if INSTR_FLAGS.HIDDEN_LD in x.flags]) - num_hidden_loads_2 = len([x for x in kernel_hld_2 if INSTR_FLAGS.HIDDEN_LD in x.flags]) - num_hidden_loads_3 = len([x for x in kernel_hld_3 if INSTR_FLAGS.HIDDEN_LD in x.flags]) + num_hidden_loads = len( + [x for x in kernel_hld if INSTR_FLAGS.HIDDEN_LD in x.flags] + ) + num_hidden_loads_2 = len( + [x for x in kernel_hld_2 if INSTR_FLAGS.HIDDEN_LD in x.flags] + ) + num_hidden_loads_3 = len( + [x for x in kernel_hld_3 if INSTR_FLAGS.HIDDEN_LD in x.flags] + ) self.assertEqual(num_hidden_loads, 1) self.assertEqual(num_hidden_loads_2, 0) self.assertEqual(num_hidden_loads_3, 1) def test_cyclic_dag(self): dg = KernelDG( - self.kernel_x86, self.parser_x86_att, self.machine_model_csx, self.semantics_csx + self.kernel_x86, + self.parser_x86_att, + self.machine_model_csx, + self.semantics_csx, ) dg.dg.add_edge(100, 101, latency=1.0) dg.dg.add_edge(101, 102, latency=2.0) @@ -656,7 +741,10 @@ class TestSemanticTools(unittest.TestCase): dep_path = "6-10-11-12-13-14" self.assertEqual(lc_deps[dep_path]["latency"], 29.0) self.assertEqual( - [(iform.line_number, lat) for iform, lat in lc_deps[dep_path]["dependencies"]], + [ + (iform.line_number, lat) + for iform, lat in lc_deps[dep_path]["dependencies"] + ], [(6, 4.0), (10, 6.0), (11, 6.0), (12, 6.0), (13, 6.0), (14, 1.0)], ) @@ -673,7 +761,10 @@ class TestSemanticTools(unittest.TestCase): dep_path = "4-5-6-9-10-11-12" self.assertEqual(lc_deps[dep_path]["latency"], 7.0) self.assertEqual( - [(iform.line_number, lat) for iform, lat in lc_deps[dep_path]["dependencies"]], + [ + (iform.line_number, lat) + for iform, lat in lc_deps[dep_path]["dependencies"] + ], [(4, 1.0), (5, 1.0), (6, 1.0), (9, 1.0), (10, 1.0), (11, 1.0), (12, 1.0)], ) @@ -690,7 +781,10 @@ class TestSemanticTools(unittest.TestCase): dep_path = "4-5-10-11-12" self.assertEqual(lc_deps[dep_path]["latency"], 5.0) self.assertEqual( - [(iform.line_number, lat) for iform, lat in lc_deps[dep_path]["dependencies"]], + [ + (iform.line_number, lat) + for iform, lat in lc_deps[dep_path]["dependencies"] + ], [(4, 1.0), (5, 1.0), (10, 1.0), (11, 1.0), (12, 1.0)], ) @@ -698,13 +792,17 @@ class TestSemanticTools(unittest.TestCase): lcd_id = "8" lcd_id2 = "5" dg = KernelDG( - self.kernel_x86, self.parser_x86_att, self.machine_model_csx, self.semantics_csx + self.kernel_x86, + self.parser_x86_att, + self.machine_model_csx, + self.semantics_csx, ) lc_deps = dg.get_loopcarried_dependencies() # self.assertEqual(len(lc_deps), 2) # ID 8 self.assertEqual( - lc_deps[lcd_id]["root"], dg.dg.nodes(data=True)[int(lcd_id)]["instruction_form"] + lc_deps[lcd_id]["root"], + dg.dg.nodes(data=True)[int(lcd_id)]["instruction_form"], ) self.assertEqual(len(lc_deps[lcd_id]["dependencies"]), 1) self.assertEqual( @@ -737,7 +835,8 @@ class TestSemanticTools(unittest.TestCase): # self.assertEqual(len(lc_deps), 2) # ID 8 self.assertEqual( - lc_deps[lcd_id]["root"], dg.dg.nodes(data=True)[int(lcd_id)]["instruction_form"] + lc_deps[lcd_id]["root"], + dg.dg.nodes(data=True)[int(lcd_id)]["instruction_form"], ) self.assertEqual(len(lc_deps[lcd_id]["dependencies"]), 1) self.assertEqual( @@ -792,7 +891,9 @@ class TestSemanticTools(unittest.TestCase): instr_form_r_c = self.parser_x86_att.parse_line("vmovsd %xmm0, (%r15,%rcx,8)") self.semantics_csx.normalize_instruction_form(instr_form_r_c) self.semantics_csx.assign_src_dst(instr_form_r_c) - instr_form_non_r_c = self.parser_x86_att.parse_line("movl %xmm0, (%r15,%rax,8)") + instr_form_non_r_c = self.parser_x86_att.parse_line( + "movl %xmm0, (%r15,%rax,8)" + ) self.semantics_csx.normalize_instruction_form(instr_form_non_r_c) self.semantics_csx.assign_src_dst(instr_form_non_r_c) instr_form_w_c = self.parser_x86_att.parse_line("movi $0x05ACA, %rcx") @@ -830,20 +931,28 @@ class TestSemanticTools(unittest.TestCase): reg_rcx = RegisterOperand(name="rcx") reg_ymm1 = RegisterOperand(name="ymm1") - instr_form_r_c = self.parser_x86_intel.parse_line("vmovsd QWORD PTR [r15+rcx*8], xmm0") + instr_form_r_c = self.parser_x86_intel.parse_line( + "vmovsd QWORD PTR [r15+rcx*8], xmm0" + ) self.semantics_csx_intel.normalize_instruction_form(instr_form_r_c) self.semantics_csx_intel.assign_src_dst(instr_form_r_c) - instr_form_non_r_c = self.parser_x86_intel.parse_line("mov QWORD PTR [r15+rax*8], xmm0") + instr_form_non_r_c = self.parser_x86_intel.parse_line( + "mov QWORD PTR [r15+rax*8], xmm0" + ) self.semantics_csx_intel.normalize_instruction_form(instr_form_non_r_c) self.semantics_csx_intel.assign_src_dst(instr_form_non_r_c) instr_form_w_c = self.parser_x86_intel.parse_line("mov rcx, H05ACA") self.semantics_csx_intel.normalize_instruction_form(instr_form_w_c) self.semantics_csx_intel.assign_src_dst(instr_form_w_c) - instr_form_rw_ymm_1 = self.parser_x86_intel.parse_line("vinsertf128 ymm1, ymm0, xmm1, 1") + instr_form_rw_ymm_1 = self.parser_x86_intel.parse_line( + "vinsertf128 ymm1, ymm0, xmm1, 1" + ) self.semantics_csx_intel.normalize_instruction_form(instr_form_rw_ymm_1) self.semantics_csx_intel.assign_src_dst(instr_form_rw_ymm_1) - instr_form_rw_ymm_2 = self.parser_x86_intel.parse_line("vinsertf128 ymm1, ymm1, xmm0, 1") + instr_form_rw_ymm_2 = self.parser_x86_intel.parse_line( + "vinsertf128 ymm1, ymm1, xmm0, 1" + ) self.semantics_csx_intel.normalize_instruction_form(instr_form_rw_ymm_2) self.semantics_csx_intel.assign_src_dst(instr_form_rw_ymm_2) instr_form_r_ymm = self.parser_x86_intel.parse_line("vmovapd ymm0, ymm1") @@ -886,7 +995,9 @@ class TestSemanticTools(unittest.TestCase): instr_form_w_1 = self.parser_AArch64.parse_line("ldr d1, [x1, #:got_lo12:q2c]") self.semantics_tx2.normalize_instruction_form(instr_form_w_1) self.semantics_tx2.assign_src_dst(instr_form_w_1) - instr_form_non_w_1 = self.parser_AArch64.parse_line("ldr x1, [x1, #:got_lo12:q2c]") + instr_form_non_w_1 = self.parser_AArch64.parse_line( + "ldr x1, [x1, #:got_lo12:q2c]" + ) self.semantics_tx2.normalize_instruction_form(instr_form_non_w_1) self.semantics_tx2.assign_src_dst(instr_form_non_w_1) instr_form_rw_1 = self.parser_AArch64.parse_line("fmul v1.2d, v1.2d, v0.2d") @@ -938,11 +1049,15 @@ class TestSemanticTools(unittest.TestCase): with self.assertRaises(ValueError): MachineModel() with self.assertRaises(ValueError): - MachineModel(arch="CSX", path_to_yaml=os.path.join(self.MODULE_DATA_DIR, "csx.yml")) + MachineModel( + arch="CSX", path_to_yaml=os.path.join(self.MODULE_DATA_DIR, "csx.yml") + ) with self.assertRaises(FileNotFoundError): MachineModel(arch="THE_MACHINE") with self.assertRaises(FileNotFoundError): - MachineModel(path_to_yaml=os.path.join(self.MODULE_DATA_DIR, "THE_MACHINE.yml")) + MachineModel( + path_to_yaml=os.path.join(self.MODULE_DATA_DIR, "THE_MACHINE.yml") + ) def test_MachineModel_getter(self): sample_operands = [ @@ -953,8 +1068,12 @@ class TestSemanticTools(unittest.TestCase): scale=8, ) ] - self.assertIsNone(self.machine_model_csx.get_instruction("GETRESULT", sample_operands)) - self.assertIsNone(self.machine_model_tx2.get_instruction("GETRESULT", sample_operands)) + self.assertIsNone( + self.machine_model_csx.get_instruction("GETRESULT", sample_operands) + ) + self.assertIsNone( + self.machine_model_tx2.get_instruction("GETRESULT", sample_operands) + ) self.assertEqual(self.machine_model_csx.get_arch(), "csx") self.assertEqual(self.machine_model_tx2.get_arch(), "tx2") diff --git a/validation/build_and_run.py b/validation/build_and_run.py index 379e16d..2e84e58 100755 --- a/validation/build_and_run.py +++ b/validation/build_and_run.py @@ -97,20 +97,20 @@ arch_info = { "cflags": { "icc": { "Ofast": ( - "-Ofast -fno-alias -xCORE-AVX512 -qopt-zmm-usage=high -nolib-inline " - "-ffreestanding -falign-loops" + "-Ofast -fno-alias -xCORE-AVX512 -fno-alias -nolib-inline -ffreestanding " + "-falign-loops" ).split(), "O3": ( - "-O3 -fno-alias -xCORE-AVX512 -qopt-zmm-usage=high -nolib-inline " - "-ffreestanding -falign-loops" + "-O3 -fno-alias -xCORE-AVX512 -fno-alias -nolib-inline -ffreestanding " + "-falign-loops" ).split(), "O2": ( - "-O2 -fno-alias -xCORE-AVX512 -qopt-zmm-usage=high -nolib-inline " - "-ffreestanding -falign-loops" + "-O2 -fno-alias -xCORE-AVX512 -fno-alias -nolib-inline -ffreestanding " + "-falign-loops" ).split(), "O1": ( - "-O1 -fno-alias -xCORE-AVX512 -qopt-zmm-usage=high -nolib-inline " - "-ffreestanding -falign-loops" + "-O1 -fno-alias -xCORE-AVX512 -fno-alias -nolib-inline -ffreestanding " + "-falign-loops" ).split(), }, "clang": { @@ -316,11 +316,21 @@ arch_info = { }, "icx": { "Ofast": ( - "-Ofast -xCORE-AVX512 -fno-alias -nolib-inline -ffreestanding -falign-loops" + "-Ofast -xCORE-AVX512 -fno-alias -nolib-inline -ffreestanding " + "-falign-loops" + ).split(), + "O3": ( + "-O3 -xCORE-AVX512 -fno-alias -nolib-inline -ffreestanding " + "-falign-loops" + ).split(), + "O2": ( + "-O2 -xCORE-AVX512 -fno-alias -nolib-inline -ffreestanding " + "-falign-loops" + ).split(), + "O1": ( + "-O1 -xCORE-AVX512 -fno-alias -nolib-inline -ffreestanding " + "-falign-loops" ).split(), - "O3": "-O3 -xCORE-AVX512 -fno-alias -nolib-inline -ffreestanding -falign-loops".split(), - "O2": "-O2 -xCORE-AVX512 -fno-alias -nolib-inline -ffreestanding -falign-loops".split(), - "O1": "-O1 -xCORE-AVX512 -fno-alias -nolib-inline -ffreestanding -falign-loops".split(), }, }, }, @@ -391,10 +401,22 @@ arch_info = { "O1": "-O1 -msve-vector-bits=128 -march=armv9-a+sve2 -ffreestanding".split(), }, "armclang": { - "Ofast": "-Ofast -target aarch64-unknown-linux-gnu -march=armv9-a+sve2 -mcpu=neoverse-v2 -ffreestanding".split(), - "O3": "-O3 -target aarch64-unknown-linux-gnu -march=armv9-a+sve2 -mcpu=neoverse-v2 -ffreestanding".split(), - "O2": "-O2 -target aarch64-unknown-linux-gnu -march=armv9-a+sve2 -mcpu=neoverse-v2 -ffreestanding".split(), - "O1": "-O1 -target aarch64-unknown-linux-gnu -march=armv9-a+sve2 -mcpu=neoverse-v2 -ffreestanding".split(), + "Ofast": ( + "-Ofast -target aarch64-unknown-linux-gnu -march=armv9-a+sve2 " + "-mcpu=neoverse-v2 -ffreestanding" + ).split(), + "O3": ( + "-O3 -target aarch64-unknown-linux-gnu -march=armv9-a+sve2 " + "-mcpu=neoverse-v2 -ffreestanding" + ).split(), + "O2": ( + "-O2 -target aarch64-unknown-linux-gnu -march=armv9-a+sve2 " + "-mcpu=neoverse-v2 -ffreestanding" + ).split(), + "O1": ( + "-O1 -target aarch64-unknown-linux-gnu -march=armv9-a+sve2 " + "-mcpu=neoverse-v2 -ffreestanding" + ).split(), }, }, }, @@ -620,19 +642,23 @@ def build_mark_run_all_kernels(measurements=True, osaca=True, iaca=True, llvm_mc if llvm_mca and ainfo["LLVM-MCA"] is not None: print("LLVM-MCA", end="", flush=True) if not row.get("LLVM-MCA_ports"): - row["LLVM-MCA_raw"] = llvm_mca_analyse_instrumented_assembly( - marked_asmfile, - micro_architecture=ainfo["LLVM-MCA"], - isa=ainfo["isa"], + row["LLVM-MCA_raw"] = ( + llvm_mca_analyse_instrumented_assembly( + marked_asmfile, + micro_architecture=ainfo["LLVM-MCA"], + isa=ainfo["isa"], + ) ) row["LLVM-MCA_ports"] = { k: v / (row["pointer_increment"] / row["element_size"]) for k, v in row["LLVM-MCA_raw"]["port cycles"].items() } - row["LLVM-MCA_prediction"] = row["LLVM-MCA_raw"]["throughput"] / ( - row["pointer_increment"] / row["element_size"] + row["LLVM-MCA_prediction"] = row["LLVM-MCA_raw"][ + "throughput" + ] / (row["pointer_increment"] / row["element_size"]) + row["LLVM-MCA_throughput"] = max( + row["LLVM-MCA_ports"].values() ) - row["LLVM-MCA_throughput"] = max(row["LLVM-MCA_ports"].values()) row["LLVM-MCA_cp"] = row["LLVM-MCA_raw"]["cp_latency"] / ( row["pointer_increment"] / row["element_size"] ) @@ -699,7 +725,9 @@ def scalingrun(kernel_exec, total_iterations=25000000, lengths=range(8, 1 * 1024 # TODO use arch specific events and grooup r, o = perfctr(chain([kernel_exec], map(str, parameters)), 1, group="L2") global_infos = {} - for m in [re.match(r"(:?([a-z_\-0-9]+):)?([a-z]+): ([a-z\_\-0-9]+)", line) for line in o]: + for m in [ + re.match(r"(:?([a-z_\-0-9]+):)?([a-z]+): ([a-z\_\-0-9]+)", line) for line in o + ]: if m is not None: try: v = int(m.group(4)) @@ -773,7 +801,9 @@ def mark(asm_path, compiler, cflags, isa, overwrite=False): # Compile marked assembly to object for IACA marked_obj = Path(asm_path).with_suffix(".marked.o") if not marked_obj.exists(): - check_call([compiler] + cflags + ["-c", str(marked_asm_path), "-o", str(marked_obj)]) + check_call( + [compiler] + cflags + ["-c", str(marked_asm_path), "-o", str(marked_obj)] + ) return str(marked_asm_path), str(marked_obj), pointer_increment, overwrite @@ -805,18 +835,28 @@ def build_kernel( raise ValueError("Must build, but not allowed.") if not Path(f"{build_path}/dummy.o").exists(): - check_call([compiler] + cflags + ["-c", "kernels/dummy.c", "-o", f"{build_path}/dummy.o"]) + check_call( + [compiler] + + cflags + + ["-c", "kernels/dummy.c", "-o", f"{build_path}/dummy.o"] + ) if not Path(f"{build_path}/compiler_version").exists(): # Document compiler version with open(f"{build_path}/compiler_version", "w") as f: - f.write(check_output([compiler, "--version"], encoding="utf8", stderr=STDOUT)) + f.write( + check_output([compiler, "--version"], encoding="utf8", stderr=STDOUT) + ) if overwrite: # build object + assembly - check_call([compiler] + cflags + ["-c", f"kernels/{kernel}.c", "-o", kernel_object]) check_call( - [compiler] + cflags + ["-c", f"kernels/{kernel}.c", "-S", "-o", kernel_assembly] + [compiler] + cflags + ["-c", f"kernels/{kernel}.c", "-o", kernel_object] + ) + check_call( + [compiler] + + cflags + + ["-c", f"kernels/{kernel}.c", "-S", "-o", kernel_assembly] ) # build main and link executable @@ -974,7 +1014,10 @@ def main(): llvm_mca = False build_mark_run_all_kernels( - measurements="--no-measurements" not in sys.argv, iaca=False, osaca=True, llvm_mca=llvm_mca + measurements="--no-measurements" not in sys.argv, + iaca=False, + osaca=True, + llvm_mca=llvm_mca, ) sys.exit()