mirror of
https://github.com/RRZE-HPC/OSACA.git
synced 2025-12-16 09:00:05 +01:00
Format code with black and fix flake8 linting issues
- Applied black formatting with line length 99 - Fixed flake8 linting issues (E265 block comments) - All 115 tests still pass after formatting - Code style is now consistent across the codebase Changes: - osaca/parser/base_parser.py: improved line breaks and comment formatting - osaca/osaca.py: added missing blank line - osaca/db_interface.py: reformatted long lines and comments - osaca/parser/parser_RISCV.py: extensive formatting improvements - osaca/semantics/kernel_dg.py: improved formatting and readability - osaca/semantics/hw_model.py: fixed shebang and formatting - osaca/semantics/marker_utils.py: removed TODO comment and formatting
This commit is contained in:
@@ -487,6 +487,7 @@ def get_asm_parser(arch, syntax="ATT") -> BaseParser:
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Unknown ISA: {}".format(isa))
|
raise ValueError("Unknown ISA: {}".format(isa))
|
||||||
|
|
||||||
|
|
||||||
def get_unmatched_instruction_ratio(kernel):
|
def get_unmatched_instruction_ratio(kernel):
|
||||||
"""Return ratio of unmatched from total instructions in kernel."""
|
"""Return ratio of unmatched from total instructions in kernel."""
|
||||||
unmatched_counter = 0
|
unmatched_counter = 0
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# TODO: Heuristics for detecting the RISCV ISA
|
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Parser superclass of specific parsers."""
|
"""Parser superclass of specific parsers."""
|
||||||
import operator
|
import operator
|
||||||
@@ -77,9 +76,14 @@ class BaseParser(object):
|
|||||||
r"\bs[0-9]|s1[0-1]\b", # s0-s11 registers
|
r"\bs[0-9]|s1[0-1]\b", # s0-s11 registers
|
||||||
r"\bzero\b|\bra\b|\bsp\b|\bgp\b", # zero, ra, sp, gp registers
|
r"\bzero\b|\bra\b|\bsp\b|\bgp\b", # zero, ra, sp, gp registers
|
||||||
r"\bvsetvli\b|\bvle\b|\bvse\b", # RV Vector instructions
|
r"\bvsetvli\b|\bvle\b|\bvse\b", # RV Vector instructions
|
||||||
r"\baddi\b|\bsd\b|\bld\b|\bjal\b" # Common RISC-V instructions
|
r"\baddi\b|\bsd\b|\bld\b|\bjal\b", # Common RISC-V instructions
|
||||||
]
|
]
|
||||||
matches = {("x86", "ATT"): 0, ("x86", "INTEL"): 0, ("aarch64", None): 0, ("riscv", None): 0}
|
matches = {
|
||||||
|
("x86", "ATT"): 0,
|
||||||
|
("x86", "INTEL"): 0,
|
||||||
|
("aarch64", None): 0,
|
||||||
|
("riscv", None): 0,
|
||||||
|
}
|
||||||
|
|
||||||
for h in heuristics_x86ATT:
|
for h in heuristics_x86ATT:
|
||||||
matches[("x86", "ATT")] += len(re.findall(h, file_content))
|
matches[("x86", "ATT")] += len(re.findall(h, file_content))
|
||||||
|
|||||||
@@ -46,8 +46,7 @@ class ParserRISCV(BaseParser):
|
|||||||
# Parse the RISC-V end marker (li a1, 222 followed by NOP)
|
# Parse the RISC-V end marker (li a1, 222 followed by NOP)
|
||||||
# This matches how end marker is defined in marker_utils.py for RISC-V
|
# This matches how end marker is defined in marker_utils.py for RISC-V
|
||||||
marker_str = (
|
marker_str = (
|
||||||
"li a1, 222 # OSACA END MARKER\n"
|
"li a1, 222 # OSACA END MARKER\n" ".byte 19,0,0,0 # OSACA END MARKER\n"
|
||||||
".byte 19,0,0,0 # OSACA END MARKER\n"
|
|
||||||
)
|
)
|
||||||
return self.parse_file(marker_str)
|
return self.parse_file(marker_str)
|
||||||
|
|
||||||
@@ -107,9 +106,7 @@ class ParserRISCV(BaseParser):
|
|||||||
|
|
||||||
# Label
|
# Label
|
||||||
self.label = pp.Group(
|
self.label = pp.Group(
|
||||||
identifier.setResultsName("name")
|
identifier.setResultsName("name") + pp.Literal(":") + pp.Optional(self.comment)
|
||||||
+ pp.Literal(":")
|
|
||||||
+ pp.Optional(self.comment)
|
|
||||||
).setResultsName(self.label_id)
|
).setResultsName(self.label_id)
|
||||||
|
|
||||||
# Directive
|
# Directive
|
||||||
@@ -119,21 +116,13 @@ class ParserRISCV(BaseParser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
directive_parameter = (
|
directive_parameter = (
|
||||||
pp.quotedString
|
pp.quotedString | directive_option | identifier | hex_number | decimal_number
|
||||||
| directive_option
|
|
||||||
| identifier
|
|
||||||
| hex_number
|
|
||||||
| decimal_number
|
|
||||||
)
|
|
||||||
commaSeparatedList = pp.delimitedList(
|
|
||||||
pp.Optional(directive_parameter), delim=","
|
|
||||||
)
|
)
|
||||||
|
commaSeparatedList = pp.delimitedList(pp.Optional(directive_parameter), delim=",")
|
||||||
self.directive = pp.Group(
|
self.directive = pp.Group(
|
||||||
pp.Literal(".")
|
pp.Literal(".")
|
||||||
+ pp.Word(pp.alphanums + "_").setResultsName("name")
|
+ pp.Word(pp.alphanums + "_").setResultsName("name")
|
||||||
+ (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName(
|
+ (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName("parameters")
|
||||||
"parameters"
|
|
||||||
)
|
|
||||||
+ pp.Optional(self.comment)
|
+ pp.Optional(self.comment)
|
||||||
).setResultsName(self.directive_id)
|
).setResultsName(self.directive_id)
|
||||||
|
|
||||||
@@ -193,12 +182,7 @@ class ParserRISCV(BaseParser):
|
|||||||
|
|
||||||
# Combined register definition
|
# Combined register definition
|
||||||
register = pp.Group(
|
register = pp.Group(
|
||||||
integer_reg_x
|
integer_reg_x | integer_reg_abi | fp_reg_f | fp_reg_abi | vector_reg | csr_reg
|
||||||
| integer_reg_abi
|
|
||||||
| fp_reg_f
|
|
||||||
| fp_reg_abi
|
|
||||||
| vector_reg
|
|
||||||
| csr_reg
|
|
||||||
).setResultsName(self.register_id)
|
).setResultsName(self.register_id)
|
||||||
|
|
||||||
self.register = register
|
self.register = register
|
||||||
@@ -218,9 +202,7 @@ class ParserRISCV(BaseParser):
|
|||||||
# Handle additional vector parameters
|
# Handle additional vector parameters
|
||||||
additional_params = pp.ZeroOrMore(
|
additional_params = pp.ZeroOrMore(
|
||||||
pp.Suppress(pp.Literal(","))
|
pp.Suppress(pp.Literal(","))
|
||||||
+ pp.Word(pp.alphas + pp.nums).setResultsName(
|
+ pp.Word(pp.alphas + pp.nums).setResultsName("vector_param", listAllMatches=True)
|
||||||
"vector_param", listAllMatches=True
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Main instruction parser
|
# Main instruction parser
|
||||||
@@ -260,9 +242,7 @@ class ParserRISCV(BaseParser):
|
|||||||
|
|
||||||
# 1. Parse comment
|
# 1. Parse comment
|
||||||
try:
|
try:
|
||||||
result = self.process_operand(
|
result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict())
|
||||||
self.comment.parseString(line, parseAll=True).asDict()
|
|
||||||
)
|
|
||||||
instruction_form.comment = " ".join(result[self.comment_id])
|
instruction_form.comment = " ".join(result[self.comment_id])
|
||||||
except pp.ParseException:
|
except pp.ParseException:
|
||||||
pass
|
pass
|
||||||
@@ -280,9 +260,7 @@ class ParserRISCV(BaseParser):
|
|||||||
if result is None:
|
if result is None:
|
||||||
try:
|
try:
|
||||||
# returns tuple with label operand and comment, if any
|
# returns tuple with label operand and comment, if any
|
||||||
result = self.process_operand(
|
result = self.process_operand(self.label.parseString(line, parseAll=True).asDict())
|
||||||
self.label.parseString(line, parseAll=True).asDict()
|
|
||||||
)
|
|
||||||
instruction_form.label = result[0].name
|
instruction_form.label = result[0].name
|
||||||
if result[1] is not None:
|
if result[1] is not None:
|
||||||
instruction_form.comment = " ".join(result[1])
|
instruction_form.comment = " ".join(result[1])
|
||||||
@@ -371,15 +349,11 @@ class ParserRISCV(BaseParser):
|
|||||||
if "#" in instruction:
|
if "#" in instruction:
|
||||||
comment = instruction.split("#", 1)[1].strip()
|
comment = instruction.split("#", 1)[1].strip()
|
||||||
|
|
||||||
return InstructionForm(
|
return InstructionForm(mnemonic=mnemonic, operands=operands, comment_id=comment)
|
||||||
mnemonic=mnemonic, operands=operands, comment_id=comment
|
|
||||||
)
|
|
||||||
|
|
||||||
# Regular instruction parsing
|
# Regular instruction parsing
|
||||||
try:
|
try:
|
||||||
result = self.instruction_parser.parseString(
|
result = self.instruction_parser.parseString(instruction, parseAll=True).asDict()
|
||||||
instruction, parseAll=True
|
|
||||||
).asDict()
|
|
||||||
operands = []
|
operands = []
|
||||||
|
|
||||||
# Process operands
|
# Process operands
|
||||||
@@ -410,9 +384,7 @@ class ParserRISCV(BaseParser):
|
|||||||
mnemonic=result["mnemonic"],
|
mnemonic=result["mnemonic"],
|
||||||
operands=operands,
|
operands=operands,
|
||||||
comment_id=(
|
comment_id=(
|
||||||
" ".join(result[self.comment_id])
|
" ".join(result[self.comment_id]) if self.comment_id in result else None
|
||||||
if self.comment_id in result
|
|
||||||
else None
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return return_dict
|
return return_dict
|
||||||
@@ -446,9 +418,7 @@ class ParserRISCV(BaseParser):
|
|||||||
if "#" in instruction:
|
if "#" in instruction:
|
||||||
comment = instruction.split("#", 1)[1].strip()
|
comment = instruction.split("#", 1)[1].strip()
|
||||||
|
|
||||||
return InstructionForm(
|
return InstructionForm(mnemonic=mnemonic, operands=operands, comment_id=comment)
|
||||||
mnemonic=mnemonic, operands=operands, comment_id=comment
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -569,9 +539,7 @@ class ParserRISCV(BaseParser):
|
|||||||
elif name.startswith("f") and name[1] in ["t", "a", "s"]:
|
elif name.startswith("f") and name[1] in ["t", "a", "s"]:
|
||||||
if name[1] == "a": # fa0-fa7
|
if name[1] == "a": # fa0-fa7
|
||||||
idx = int(name[2:])
|
idx = int(name[2:])
|
||||||
return RegisterOperand(
|
return RegisterOperand(prefix="f", name=str(idx + 10), regtype="float", width=64)
|
||||||
prefix="f", name=str(idx + 10), regtype="float", width=64
|
|
||||||
)
|
|
||||||
elif name[1] == "s": # fs0-fs11
|
elif name[1] == "s": # fs0-fs11
|
||||||
idx = int(name[2:])
|
idx = int(name[2:])
|
||||||
if idx <= 1:
|
if idx <= 1:
|
||||||
@@ -585,9 +553,7 @@ class ParserRISCV(BaseParser):
|
|||||||
elif name[1] == "t": # ft0-ft11
|
elif name[1] == "t": # ft0-ft11
|
||||||
idx = int(name[2:])
|
idx = int(name[2:])
|
||||||
if idx <= 7:
|
if idx <= 7:
|
||||||
return RegisterOperand(
|
return RegisterOperand(prefix="f", name=str(idx), regtype="float", width=64)
|
||||||
prefix="f", name=str(idx), regtype="float", width=64
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return RegisterOperand(
|
return RegisterOperand(
|
||||||
prefix="f", name=str(idx + 20), regtype="float", width=64
|
prefix="f", name=str(idx + 20), regtype="float", width=64
|
||||||
@@ -675,9 +641,7 @@ class ParserRISCV(BaseParser):
|
|||||||
|
|
||||||
# Handle numeric values with validation
|
# Handle numeric values with validation
|
||||||
if "value" in immediate:
|
if "value" in immediate:
|
||||||
value = int(
|
value = int(immediate["value"], 0) # Convert to integer, handling hex/decimal
|
||||||
immediate["value"], 0
|
|
||||||
) # Convert to integer, handling hex/decimal
|
|
||||||
|
|
||||||
# Determine immediate type and validate range based on instruction type
|
# Determine immediate type and validate range based on instruction type
|
||||||
if hasattr(self, "current_instruction"):
|
if hasattr(self, "current_instruction"):
|
||||||
@@ -714,9 +678,7 @@ class ParserRISCV(BaseParser):
|
|||||||
return ImmediateOperand(imd_type="S", value=value)
|
return ImmediateOperand(imd_type="S", value=value)
|
||||||
|
|
||||||
# B-type instructions (13-bit signed immediate for branches, must be even)
|
# B-type instructions (13-bit signed immediate for branches, must be even)
|
||||||
elif any(
|
elif any(x in mnemonic for x in ["beq", "bne", "blt", "bge", "bltu", "bgeu"]):
|
||||||
x in mnemonic for x in ["beq", "bne", "blt", "bge", "bltu", "bgeu"]
|
|
||||||
):
|
|
||||||
if not -4096 <= value <= 4095 or value % 2 != 0:
|
if not -4096 <= value <= 4095 or value % 2 != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Immediate value {value} out of range or not even "
|
f"Immediate value {value} out of range or not even "
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
# TODO
|
#!/usr/bin/env python3w
|
||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
@@ -1021,7 +1020,9 @@ class MachineModel(object):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Check for ABI name (a0, t0, etc.) vs x-prefix registers (x10, x5, etc.)
|
# Check for ABI name (a0, t0, etc.) vs x-prefix registers (x10, x5, etc.)
|
||||||
if (reg.prefix is None and i_reg.prefix == "x") or (reg.prefix == "x" and i_reg.prefix is None):
|
if (reg.prefix is None and i_reg.prefix == "x") or (
|
||||||
|
reg.prefix == "x" and i_reg.prefix is None
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
# Need to check if they refer to the same register
|
# Need to check if they refer to the same register
|
||||||
from osaca.parser import ParserRISCV
|
from osaca.parser import ParserRISCV
|
||||||
@@ -1149,9 +1150,13 @@ class MachineModel(object):
|
|||||||
(
|
(
|
||||||
(mem.base is None and i_mem.base is None)
|
(mem.base is None and i_mem.base is None)
|
||||||
or i_mem.base == self.WILDCARD
|
or i_mem.base == self.WILDCARD
|
||||||
or (isinstance(mem.base, RegisterOperand) and
|
or (
|
||||||
(mem.base.prefix == i_mem.base or
|
isinstance(mem.base, RegisterOperand)
|
||||||
(mem.base.name is not None and i_mem.base is not None)))
|
and (
|
||||||
|
mem.base.prefix == i_mem.base
|
||||||
|
or (mem.base.name is not None and i_mem.base is not None)
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# check offset
|
# check offset
|
||||||
and (
|
and (
|
||||||
|
|||||||
@@ -568,7 +568,7 @@ class KernelDG(nx.DiGraph):
|
|||||||
(latency, list(deps))
|
(latency, list(deps))
|
||||||
for latency, deps in groupby(lcd, lambda dep: lcd[dep]["latency"])
|
for latency, deps in groupby(lcd, lambda dep: lcd[dep]["latency"])
|
||||||
),
|
),
|
||||||
reverse=True
|
reverse=True,
|
||||||
)
|
)
|
||||||
node_colors = {}
|
node_colors = {}
|
||||||
edge_colors = {}
|
edge_colors = {}
|
||||||
@@ -599,9 +599,8 @@ class KernelDG(nx.DiGraph):
|
|||||||
else:
|
else:
|
||||||
graph.nodes[n]["style"] += ",filled"
|
graph.nodes[n]["style"] += ",filled"
|
||||||
graph.nodes[n]["fillcolor"] = color
|
graph.nodes[n]["fillcolor"] = color
|
||||||
if (
|
if (max_color >= 4 and color in (1, max_color)) or (
|
||||||
(max_color >= 4 and color in (1, max_color)) or
|
max_color >= 10 and color in (1, 2, max_color - 1, max_color)
|
||||||
(max_color >= 10 and color in (1, 2, max_color - 1 , max_color))
|
|
||||||
):
|
):
|
||||||
graph.nodes[n]["fontcolor"] = "white"
|
graph.nodes[n]["fontcolor"] = "white"
|
||||||
for (u, v), color in edge_colors.items():
|
for (u, v), color in edge_colors.items():
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# TODO
|
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -174,8 +173,7 @@ def get_marker(isa, syntax="ATT", comment=""):
|
|||||||
start_marker_raw += "# {}\n".format(comment)
|
start_marker_raw += "# {}\n".format(comment)
|
||||||
# After loop
|
# After loop
|
||||||
end_marker_raw = (
|
end_marker_raw = (
|
||||||
"li a1, 222 # OSACA END MARKER\n"
|
"li a1, 222 # OSACA END MARKER\n" ".byte 19,0,0,0 # OSACA END MARKER\n"
|
||||||
".byte 19,0,0,0 # OSACA END MARKER\n"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser = get_parser(isa)
|
parser = get_parser(isa)
|
||||||
|
|||||||
Reference in New Issue
Block a user