Format code with black and fix flake8 linting issues

- Applied black formatting with line length 99
- Fixed flake8 linting issues (E265 block comments)
- All 115 tests still pass after formatting
- Code style is now consistent across the codebase

Changes:
- osaca/parser/base_parser.py: improved line breaks and comment formatting
- osaca/osaca.py: added missing blank line
- osaca/db_interface.py: reformatted long lines and comments
- osaca/parser/parser_RISCV.py: extensive formatting improvements
- osaca/semantics/kernel_dg.py: improved formatting and readability
- osaca/semantics/hw_model.py: fixed shebang and formatting
- osaca/semantics/marker_utils.py: removed TODO comment and formatting
This commit is contained in:
Metehan Dundar
2025-07-11 22:28:29 +02:00
parent ebf76caa18
commit a8fca2afdb
7 changed files with 58 additions and 89 deletions

View File

@@ -412,20 +412,20 @@ def _check_sanity_arch_db(arch_mm, isa_mm, internet_check=True):
suspicious_prefixes_x86 = ["vfm", "fm"] suspicious_prefixes_x86 = ["vfm", "fm"]
suspicious_prefixes_arm = ["fml", "ldp", "stp", "str"] suspicious_prefixes_arm = ["fml", "ldp", "stp", "str"]
suspicious_prefixes_riscv = [ suspicious_prefixes_riscv = [
"vse", # Vector store (register is source, memory is destination) "vse", # Vector store (register is source, memory is destination)
"vfmacc", # Vector FMA with accumulation (first operand is both source and destination) "vfmacc", # Vector FMA with accumulation (first operand is both source and destination)
"vfmadd", # Vector FMA with addition (first operand is implicitly both source and destination) "vfmadd", # Vector FMA with addition (first operand is implicitly both source and destination)
"vset", # Vector configuration (complex operand pattern) "vset", # Vector configuration (complex operand pattern)
"csrs", # CSR Set (first operand is both source and destination) "csrs", # CSR Set (first operand is both source and destination)
"csrc", # CSR Clear (first operand is both source and destination) "csrc", # CSR Clear (first operand is both source and destination)
"csrsi", # CSR Set Immediate (first operand is both source and destination) "csrsi", # CSR Set Immediate (first operand is both source and destination)
"csrci", # CSR Clear Immediate (first operand is both source and destination) "csrci", # CSR Clear Immediate (first operand is both source and destination)
"amo", # Atomic memory operations (read-modify-write to memory) "amo", # Atomic memory operations (read-modify-write to memory)
"lr", # Load-Reserved (part of atomic operations) "lr", # Load-Reserved (part of atomic operations)
"sc", # Store-Conditional (part of atomic operations) "sc", # Store-Conditional (part of atomic operations)
"czero", # Conditional zero instructions (Zicond extension) "czero", # Conditional zero instructions (Zicond extension)
] ]
# Default to empty list if ISA not recognized # Default to empty list if ISA not recognized
suspicious_prefixes = [] suspicious_prefixes = []

View File

@@ -487,6 +487,7 @@ def get_asm_parser(arch, syntax="ATT") -> BaseParser:
else: else:
raise ValueError("Unknown ISA: {}".format(isa)) raise ValueError("Unknown ISA: {}".format(isa))
def get_unmatched_instruction_ratio(kernel): def get_unmatched_instruction_ratio(kernel):
"""Return ratio of unmatched from total instructions in kernel.""" """Return ratio of unmatched from total instructions in kernel."""
unmatched_counter = 0 unmatched_counter = 0

View File

@@ -1,4 +1,3 @@
# TODO: Heuristics for detecting the RISCV ISA
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Parser superclass of specific parsers.""" """Parser superclass of specific parsers."""
import operator import operator
@@ -72,14 +71,19 @@ class BaseParser(object):
# 3) check for RISC-V registers (x0-x31, a0-a7, t0-t6, s0-s11) and instructions # 3) check for RISC-V registers (x0-x31, a0-a7, t0-t6, s0-s11) and instructions
heuristics_riscv = [ heuristics_riscv = [
r"\bx[0-9]|x[1-2][0-9]|x3[0-1]\b", # x0-x31 registers r"\bx[0-9]|x[1-2][0-9]|x3[0-1]\b", # x0-x31 registers
r"\ba[0-7]\b", # a0-a7 registers r"\ba[0-7]\b", # a0-a7 registers
r"\bt[0-6]\b", # t0-t6 registers r"\bt[0-6]\b", # t0-t6 registers
r"\bs[0-9]|s1[0-1]\b", # s0-s11 registers r"\bs[0-9]|s1[0-1]\b", # s0-s11 registers
r"\bzero\b|\bra\b|\bsp\b|\bgp\b", # zero, ra, sp, gp registers r"\bzero\b|\bra\b|\bsp\b|\bgp\b", # zero, ra, sp, gp registers
r"\bvsetvli\b|\bvle\b|\bvse\b", # RV Vector instructions r"\bvsetvli\b|\bvle\b|\bvse\b", # RV Vector instructions
r"\baddi\b|\bsd\b|\bld\b|\bjal\b" # Common RISC-V instructions r"\baddi\b|\bsd\b|\bld\b|\bjal\b", # Common RISC-V instructions
] ]
matches = {("x86", "ATT"): 0, ("x86", "INTEL"): 0, ("aarch64", None): 0, ("riscv", None): 0} matches = {
("x86", "ATT"): 0,
("x86", "INTEL"): 0,
("aarch64", None): 0,
("riscv", None): 0,
}
for h in heuristics_x86ATT: for h in heuristics_x86ATT:
matches[("x86", "ATT")] += len(re.findall(h, file_content)) matches[("x86", "ATT")] += len(re.findall(h, file_content))

View File

@@ -46,8 +46,7 @@ class ParserRISCV(BaseParser):
# Parse the RISC-V end marker (li a1, 222 followed by NOP) # Parse the RISC-V end marker (li a1, 222 followed by NOP)
# This matches how end marker is defined in marker_utils.py for RISC-V # This matches how end marker is defined in marker_utils.py for RISC-V
marker_str = ( marker_str = (
"li a1, 222 # OSACA END MARKER\n" "li a1, 222 # OSACA END MARKER\n" ".byte 19,0,0,0 # OSACA END MARKER\n"
".byte 19,0,0,0 # OSACA END MARKER\n"
) )
return self.parse_file(marker_str) return self.parse_file(marker_str)
@@ -107,9 +106,7 @@ class ParserRISCV(BaseParser):
# Label # Label
self.label = pp.Group( self.label = pp.Group(
identifier.setResultsName("name") identifier.setResultsName("name") + pp.Literal(":") + pp.Optional(self.comment)
+ pp.Literal(":")
+ pp.Optional(self.comment)
).setResultsName(self.label_id) ).setResultsName(self.label_id)
# Directive # Directive
@@ -119,21 +116,13 @@ class ParserRISCV(BaseParser):
) )
directive_parameter = ( directive_parameter = (
pp.quotedString pp.quotedString | directive_option | identifier | hex_number | decimal_number
| directive_option
| identifier
| hex_number
| decimal_number
)
commaSeparatedList = pp.delimitedList(
pp.Optional(directive_parameter), delim=","
) )
commaSeparatedList = pp.delimitedList(pp.Optional(directive_parameter), delim=",")
self.directive = pp.Group( self.directive = pp.Group(
pp.Literal(".") pp.Literal(".")
+ pp.Word(pp.alphanums + "_").setResultsName("name") + pp.Word(pp.alphanums + "_").setResultsName("name")
+ (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName( + (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName("parameters")
"parameters"
)
+ pp.Optional(self.comment) + pp.Optional(self.comment)
).setResultsName(self.directive_id) ).setResultsName(self.directive_id)
@@ -193,12 +182,7 @@ class ParserRISCV(BaseParser):
# Combined register definition # Combined register definition
register = pp.Group( register = pp.Group(
integer_reg_x integer_reg_x | integer_reg_abi | fp_reg_f | fp_reg_abi | vector_reg | csr_reg
| integer_reg_abi
| fp_reg_f
| fp_reg_abi
| vector_reg
| csr_reg
).setResultsName(self.register_id) ).setResultsName(self.register_id)
self.register = register self.register = register
@@ -218,9 +202,7 @@ class ParserRISCV(BaseParser):
# Handle additional vector parameters # Handle additional vector parameters
additional_params = pp.ZeroOrMore( additional_params = pp.ZeroOrMore(
pp.Suppress(pp.Literal(",")) pp.Suppress(pp.Literal(","))
+ pp.Word(pp.alphas + pp.nums).setResultsName( + pp.Word(pp.alphas + pp.nums).setResultsName("vector_param", listAllMatches=True)
"vector_param", listAllMatches=True
)
) )
# Main instruction parser # Main instruction parser
@@ -260,9 +242,7 @@ class ParserRISCV(BaseParser):
# 1. Parse comment # 1. Parse comment
try: try:
result = self.process_operand( result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict())
self.comment.parseString(line, parseAll=True).asDict()
)
instruction_form.comment = " ".join(result[self.comment_id]) instruction_form.comment = " ".join(result[self.comment_id])
except pp.ParseException: except pp.ParseException:
pass pass
@@ -280,9 +260,7 @@ class ParserRISCV(BaseParser):
if result is None: if result is None:
try: try:
# returns tuple with label operand and comment, if any # returns tuple with label operand and comment, if any
result = self.process_operand( result = self.process_operand(self.label.parseString(line, parseAll=True).asDict())
self.label.parseString(line, parseAll=True).asDict()
)
instruction_form.label = result[0].name instruction_form.label = result[0].name
if result[1] is not None: if result[1] is not None:
instruction_form.comment = " ".join(result[1]) instruction_form.comment = " ".join(result[1])
@@ -371,15 +349,11 @@ class ParserRISCV(BaseParser):
if "#" in instruction: if "#" in instruction:
comment = instruction.split("#", 1)[1].strip() comment = instruction.split("#", 1)[1].strip()
return InstructionForm( return InstructionForm(mnemonic=mnemonic, operands=operands, comment_id=comment)
mnemonic=mnemonic, operands=operands, comment_id=comment
)
# Regular instruction parsing # Regular instruction parsing
try: try:
result = self.instruction_parser.parseString( result = self.instruction_parser.parseString(instruction, parseAll=True).asDict()
instruction, parseAll=True
).asDict()
operands = [] operands = []
# Process operands # Process operands
@@ -410,9 +384,7 @@ class ParserRISCV(BaseParser):
mnemonic=result["mnemonic"], mnemonic=result["mnemonic"],
operands=operands, operands=operands,
comment_id=( comment_id=(
" ".join(result[self.comment_id]) " ".join(result[self.comment_id]) if self.comment_id in result else None
if self.comment_id in result
else None
), ),
) )
return return_dict return return_dict
@@ -446,9 +418,7 @@ class ParserRISCV(BaseParser):
if "#" in instruction: if "#" in instruction:
comment = instruction.split("#", 1)[1].strip() comment = instruction.split("#", 1)[1].strip()
return InstructionForm( return InstructionForm(mnemonic=mnemonic, operands=operands, comment_id=comment)
mnemonic=mnemonic, operands=operands, comment_id=comment
)
else: else:
raise raise
@@ -569,9 +539,7 @@ class ParserRISCV(BaseParser):
elif name.startswith("f") and name[1] in ["t", "a", "s"]: elif name.startswith("f") and name[1] in ["t", "a", "s"]:
if name[1] == "a": # fa0-fa7 if name[1] == "a": # fa0-fa7
idx = int(name[2:]) idx = int(name[2:])
return RegisterOperand( return RegisterOperand(prefix="f", name=str(idx + 10), regtype="float", width=64)
prefix="f", name=str(idx + 10), regtype="float", width=64
)
elif name[1] == "s": # fs0-fs11 elif name[1] == "s": # fs0-fs11
idx = int(name[2:]) idx = int(name[2:])
if idx <= 1: if idx <= 1:
@@ -585,9 +553,7 @@ class ParserRISCV(BaseParser):
elif name[1] == "t": # ft0-ft11 elif name[1] == "t": # ft0-ft11
idx = int(name[2:]) idx = int(name[2:])
if idx <= 7: if idx <= 7:
return RegisterOperand( return RegisterOperand(prefix="f", name=str(idx), regtype="float", width=64)
prefix="f", name=str(idx), regtype="float", width=64
)
else: else:
return RegisterOperand( return RegisterOperand(
prefix="f", name=str(idx + 20), regtype="float", width=64 prefix="f", name=str(idx + 20), regtype="float", width=64
@@ -675,9 +641,7 @@ class ParserRISCV(BaseParser):
# Handle numeric values with validation # Handle numeric values with validation
if "value" in immediate: if "value" in immediate:
value = int( value = int(immediate["value"], 0) # Convert to integer, handling hex/decimal
immediate["value"], 0
) # Convert to integer, handling hex/decimal
# Determine immediate type and validate range based on instruction type # Determine immediate type and validate range based on instruction type
if hasattr(self, "current_instruction"): if hasattr(self, "current_instruction"):
@@ -714,9 +678,7 @@ class ParserRISCV(BaseParser):
return ImmediateOperand(imd_type="S", value=value) return ImmediateOperand(imd_type="S", value=value)
# B-type instructions (13-bit signed immediate for branches, must be even) # B-type instructions (13-bit signed immediate for branches, must be even)
elif any( elif any(x in mnemonic for x in ["beq", "bne", "blt", "bge", "bltu", "bgeu"]):
x in mnemonic for x in ["beq", "bne", "blt", "bge", "bltu", "bgeu"]
):
if not -4096 <= value <= 4095 or value % 2 != 0: if not -4096 <= value <= 4095 or value % 2 != 0:
raise ValueError( raise ValueError(
f"Immediate value {value} out of range or not even " f"Immediate value {value} out of range or not even "

View File

@@ -1,5 +1,4 @@
# TODO #!/usr/bin/env python3w
#!/usr/bin/env python3
import hashlib import hashlib
import os import os
@@ -1021,7 +1020,9 @@ class MachineModel(object):
return False return False
# Check for ABI name (a0, t0, etc.) vs x-prefix registers (x10, x5, etc.) # Check for ABI name (a0, t0, etc.) vs x-prefix registers (x10, x5, etc.)
if (reg.prefix is None and i_reg.prefix == "x") or (reg.prefix == "x" and i_reg.prefix is None): if (reg.prefix is None and i_reg.prefix == "x") or (
reg.prefix == "x" and i_reg.prefix is None
):
try: try:
# Need to check if they refer to the same register # Need to check if they refer to the same register
from osaca.parser import ParserRISCV from osaca.parser import ParserRISCV
@@ -1149,9 +1150,13 @@ class MachineModel(object):
( (
(mem.base is None and i_mem.base is None) (mem.base is None and i_mem.base is None)
or i_mem.base == self.WILDCARD or i_mem.base == self.WILDCARD
or (isinstance(mem.base, RegisterOperand) and or (
(mem.base.prefix == i_mem.base or isinstance(mem.base, RegisterOperand)
(mem.base.name is not None and i_mem.base is not None))) and (
mem.base.prefix == i_mem.base
or (mem.base.name is not None and i_mem.base is not None)
)
)
) )
# check offset # check offset
and ( and (

View File

@@ -568,7 +568,7 @@ class KernelDG(nx.DiGraph):
(latency, list(deps)) (latency, list(deps))
for latency, deps in groupby(lcd, lambda dep: lcd[dep]["latency"]) for latency, deps in groupby(lcd, lambda dep: lcd[dep]["latency"])
), ),
reverse=True reverse=True,
) )
node_colors = {} node_colors = {}
edge_colors = {} edge_colors = {}
@@ -591,17 +591,16 @@ class KernelDG(nx.DiGraph):
edge_colors[u, v] = color edge_colors[u, v] = color
max_color = min(11, colors_used) max_color = min(11, colors_used)
colorscheme = f"spectral{max(3, max_color)}" colorscheme = f"spectral{max(3, max_color)}"
graph.graph["node"] = {"colorscheme" : colorscheme} graph.graph["node"] = {"colorscheme": colorscheme}
graph.graph["edge"] = {"colorscheme" : colorscheme} graph.graph["edge"] = {"colorscheme": colorscheme}
for n, color in node_colors.items(): for n, color in node_colors.items():
if "style" not in graph.nodes[n]: if "style" not in graph.nodes[n]:
graph.nodes[n]["style"] = "filled" graph.nodes[n]["style"] = "filled"
else: else:
graph.nodes[n]["style"] += ",filled" graph.nodes[n]["style"] += ",filled"
graph.nodes[n]["fillcolor"] = color graph.nodes[n]["fillcolor"] = color
if ( if (max_color >= 4 and color in (1, max_color)) or (
(max_color >= 4 and color in (1, max_color)) or max_color >= 10 and color in (1, 2, max_color - 1, max_color)
(max_color >= 10 and color in (1, 2, max_color - 1 , max_color))
): ):
graph.nodes[n]["fontcolor"] = "white" graph.nodes[n]["fontcolor"] = "white"
for (u, v), color in edge_colors.items(): for (u, v), color in edge_colors.items():

View File

@@ -1,4 +1,3 @@
# TODO
#!/usr/bin/env python3 #!/usr/bin/env python3
from collections import OrderedDict from collections import OrderedDict
from enum import Enum from enum import Enum
@@ -174,8 +173,7 @@ def get_marker(isa, syntax="ATT", comment=""):
start_marker_raw += "# {}\n".format(comment) start_marker_raw += "# {}\n".format(comment)
# After loop # After loop
end_marker_raw = ( end_marker_raw = (
"li a1, 222 # OSACA END MARKER\n" "li a1, 222 # OSACA END MARKER\n" ".byte 19,0,0,0 # OSACA END MARKER\n"
".byte 19,0,0,0 # OSACA END MARKER\n"
) )
parser = get_parser(isa) parser = get_parser(isa)