diff --git a/osaca/db_interface.py b/osaca/db_interface.py index ce02418..95c592a 100644 --- a/osaca/db_interface.py +++ b/osaca/db_interface.py @@ -412,20 +412,20 @@ def _check_sanity_arch_db(arch_mm, isa_mm, internet_check=True): suspicious_prefixes_x86 = ["vfm", "fm"] suspicious_prefixes_arm = ["fml", "ldp", "stp", "str"] suspicious_prefixes_riscv = [ - "vse", # Vector store (register is source, memory is destination) + "vse", # Vector store (register is source, memory is destination) "vfmacc", # Vector FMA with accumulation (first operand is both source and destination) "vfmadd", # Vector FMA with addition (first operand is implicitly both source and destination) - "vset", # Vector configuration (complex operand pattern) - "csrs", # CSR Set (first operand is both source and destination) - "csrc", # CSR Clear (first operand is both source and destination) - "csrsi", # CSR Set Immediate (first operand is both source and destination) - "csrci", # CSR Clear Immediate (first operand is both source and destination) - "amo", # Atomic memory operations (read-modify-write to memory) - "lr", # Load-Reserved (part of atomic operations) - "sc", # Store-Conditional (part of atomic operations) - "czero", # Conditional zero instructions (Zicond extension) + "vset", # Vector configuration (complex operand pattern) + "csrs", # CSR Set (first operand is both source and destination) + "csrc", # CSR Clear (first operand is both source and destination) + "csrsi", # CSR Set Immediate (first operand is both source and destination) + "csrci", # CSR Clear Immediate (first operand is both source and destination) + "amo", # Atomic memory operations (read-modify-write to memory) + "lr", # Load-Reserved (part of atomic operations) + "sc", # Store-Conditional (part of atomic operations) + "czero", # Conditional zero instructions (Zicond extension) ] - + # Default to empty list if ISA not recognized suspicious_prefixes = [] diff --git a/osaca/osaca.py b/osaca/osaca.py index 755face..7b6523b 
100644 --- a/osaca/osaca.py +++ b/osaca/osaca.py @@ -487,6 +487,7 @@ def get_asm_parser(arch, syntax="ATT") -> BaseParser: else: raise ValueError("Unknown ISA: {}".format(isa)) + def get_unmatched_instruction_ratio(kernel): """Return ratio of unmatched from total instructions in kernel.""" unmatched_counter = 0 diff --git a/osaca/parser/base_parser.py b/osaca/parser/base_parser.py index 7606933..53efd43 100644 --- a/osaca/parser/base_parser.py +++ b/osaca/parser/base_parser.py @@ -1,4 +1,3 @@ -# TODO: Heuristics for detecting the RISCV ISA #!/usr/bin/env python3 """Parser superclass of specific parsers.""" import operator @@ -72,14 +71,19 @@ class BaseParser(object): # 3) check for RISC-V registers (x0-x31, a0-a7, t0-t6, s0-s11) and instructions heuristics_riscv = [ r"\bx[0-9]|x[1-2][0-9]|x3[0-1]\b", # x0-x31 registers - r"\ba[0-7]\b", # a0-a7 registers - r"\bt[0-6]\b", # t0-t6 registers - r"\bs[0-9]|s1[0-1]\b", # s0-s11 registers - r"\bzero\b|\bra\b|\bsp\b|\bgp\b", # zero, ra, sp, gp registers - r"\bvsetvli\b|\bvle\b|\bvse\b", # RV Vector instructions - r"\baddi\b|\bsd\b|\bld\b|\bjal\b" # Common RISC-V instructions + r"\ba[0-7]\b", # a0-a7 registers + r"\bt[0-6]\b", # t0-t6 registers + r"\bs[0-9]|s1[0-1]\b", # s0-s11 registers + r"\bzero\b|\bra\b|\bsp\b|\bgp\b", # zero, ra, sp, gp registers + r"\bvsetvli\b|\bvle\b|\bvse\b", # RV Vector instructions + r"\baddi\b|\bsd\b|\bld\b|\bjal\b", # Common RISC-V instructions ] - matches = {("x86", "ATT"): 0, ("x86", "INTEL"): 0, ("aarch64", None): 0, ("riscv", None): 0} + matches = { + ("x86", "ATT"): 0, + ("x86", "INTEL"): 0, + ("aarch64", None): 0, + ("riscv", None): 0, + } for h in heuristics_x86ATT: matches[("x86", "ATT")] += len(re.findall(h, file_content)) diff --git a/osaca/parser/parser_RISCV.py b/osaca/parser/parser_RISCV.py index bdf72f9..db1bcce 100644 --- a/osaca/parser/parser_RISCV.py +++ b/osaca/parser/parser_RISCV.py @@ -46,8 +46,7 @@ class ParserRISCV(BaseParser): # Parse the RISC-V end marker (li a1, 222 
followed by NOP) # This matches how end marker is defined in marker_utils.py for RISC-V marker_str = ( - "li a1, 222 # OSACA END MARKER\n" - ".byte 19,0,0,0 # OSACA END MARKER\n" + "li a1, 222 # OSACA END MARKER\n" ".byte 19,0,0,0 # OSACA END MARKER\n" ) return self.parse_file(marker_str) @@ -107,9 +106,7 @@ class ParserRISCV(BaseParser): # Label self.label = pp.Group( - identifier.setResultsName("name") - + pp.Literal(":") - + pp.Optional(self.comment) + identifier.setResultsName("name") + pp.Literal(":") + pp.Optional(self.comment) ).setResultsName(self.label_id) # Directive @@ -119,21 +116,13 @@ class ParserRISCV(BaseParser): ) directive_parameter = ( - pp.quotedString - | directive_option - | identifier - | hex_number - | decimal_number - ) - commaSeparatedList = pp.delimitedList( - pp.Optional(directive_parameter), delim="," + pp.quotedString | directive_option | identifier | hex_number | decimal_number ) + commaSeparatedList = pp.delimitedList(pp.Optional(directive_parameter), delim=",") self.directive = pp.Group( pp.Literal(".") + pp.Word(pp.alphanums + "_").setResultsName("name") - + (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName( - "parameters" - ) + + (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName("parameters") + pp.Optional(self.comment) ).setResultsName(self.directive_id) @@ -193,12 +182,7 @@ class ParserRISCV(BaseParser): # Combined register definition register = pp.Group( - integer_reg_x - | integer_reg_abi - | fp_reg_f - | fp_reg_abi - | vector_reg - | csr_reg + integer_reg_x | integer_reg_abi | fp_reg_f | fp_reg_abi | vector_reg | csr_reg ).setResultsName(self.register_id) self.register = register @@ -218,9 +202,7 @@ class ParserRISCV(BaseParser): # Handle additional vector parameters additional_params = pp.ZeroOrMore( pp.Suppress(pp.Literal(",")) - + pp.Word(pp.alphas + pp.nums).setResultsName( - "vector_param", listAllMatches=True - ) + + pp.Word(pp.alphas + pp.nums).setResultsName("vector_param", 
listAllMatches=True) ) # Main instruction parser @@ -260,9 +242,7 @@ class ParserRISCV(BaseParser): # 1. Parse comment try: - result = self.process_operand( - self.comment.parseString(line, parseAll=True).asDict() - ) + result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict()) instruction_form.comment = " ".join(result[self.comment_id]) except pp.ParseException: pass @@ -280,9 +260,7 @@ class ParserRISCV(BaseParser): if result is None: try: # returns tuple with label operand and comment, if any - result = self.process_operand( - self.label.parseString(line, parseAll=True).asDict() - ) + result = self.process_operand(self.label.parseString(line, parseAll=True).asDict()) instruction_form.label = result[0].name if result[1] is not None: instruction_form.comment = " ".join(result[1]) @@ -371,15 +349,11 @@ class ParserRISCV(BaseParser): if "#" in instruction: comment = instruction.split("#", 1)[1].strip() - return InstructionForm( - mnemonic=mnemonic, operands=operands, comment_id=comment - ) + return InstructionForm(mnemonic=mnemonic, operands=operands, comment_id=comment) # Regular instruction parsing try: - result = self.instruction_parser.parseString( - instruction, parseAll=True - ).asDict() + result = self.instruction_parser.parseString(instruction, parseAll=True).asDict() operands = [] # Process operands @@ -410,9 +384,7 @@ class ParserRISCV(BaseParser): mnemonic=result["mnemonic"], operands=operands, comment_id=( - " ".join(result[self.comment_id]) - if self.comment_id in result - else None + " ".join(result[self.comment_id]) if self.comment_id in result else None ), ) return return_dict @@ -446,9 +418,7 @@ class ParserRISCV(BaseParser): if "#" in instruction: comment = instruction.split("#", 1)[1].strip() - return InstructionForm( - mnemonic=mnemonic, operands=operands, comment_id=comment - ) + return InstructionForm(mnemonic=mnemonic, operands=operands, comment_id=comment) else: raise @@ -569,9 +539,7 @@ class 
ParserRISCV(BaseParser): elif name.startswith("f") and name[1] in ["t", "a", "s"]: if name[1] == "a": # fa0-fa7 idx = int(name[2:]) - return RegisterOperand( - prefix="f", name=str(idx + 10), regtype="float", width=64 - ) + return RegisterOperand(prefix="f", name=str(idx + 10), regtype="float", width=64) elif name[1] == "s": # fs0-fs11 idx = int(name[2:]) if idx <= 1: @@ -585,9 +553,7 @@ class ParserRISCV(BaseParser): elif name[1] == "t": # ft0-ft11 idx = int(name[2:]) if idx <= 7: - return RegisterOperand( - prefix="f", name=str(idx), regtype="float", width=64 - ) + return RegisterOperand(prefix="f", name=str(idx), regtype="float", width=64) else: return RegisterOperand( prefix="f", name=str(idx + 20), regtype="float", width=64 @@ -675,9 +641,7 @@ class ParserRISCV(BaseParser): # Handle numeric values with validation if "value" in immediate: - value = int( - immediate["value"], 0 - ) # Convert to integer, handling hex/decimal + value = int(immediate["value"], 0) # Convert to integer, handling hex/decimal # Determine immediate type and validate range based on instruction type if hasattr(self, "current_instruction"): @@ -714,9 +678,7 @@ class ParserRISCV(BaseParser): return ImmediateOperand(imd_type="S", value=value) # B-type instructions (13-bit signed immediate for branches, must be even) - elif any( - x in mnemonic for x in ["beq", "bne", "blt", "bge", "bltu", "bgeu"] - ): + elif any(x in mnemonic for x in ["beq", "bne", "blt", "bge", "bltu", "bgeu"]): if not -4096 <= value <= 4095 or value % 2 != 0: raise ValueError( f"Immediate value {value} out of range or not even " diff --git a/osaca/semantics/hw_model.py b/osaca/semantics/hw_model.py index c178190..29b00b0 100644 --- a/osaca/semantics/hw_model.py +++ b/osaca/semantics/hw_model.py @@ -1,5 +1,4 @@ -# TODO -#!/usr/bin/env python3 +#!/usr/bin/env python3 import hashlib import os @@ -1021,7 +1020,9 @@ class MachineModel(object): return False # Check for ABI name (a0, t0, etc.) 
vs x-prefix registers (x10, x5, etc.) - if (reg.prefix is None and i_reg.prefix == "x") or (reg.prefix == "x" and i_reg.prefix is None): + if (reg.prefix is None and i_reg.prefix == "x") or ( + reg.prefix == "x" and i_reg.prefix is None + ): try: # Need to check if they refer to the same register from osaca.parser import ParserRISCV @@ -1149,9 +1150,13 @@ class MachineModel(object): ( (mem.base is None and i_mem.base is None) or i_mem.base == self.WILDCARD - or (isinstance(mem.base, RegisterOperand) and - (mem.base.prefix == i_mem.base or - (mem.base.name is not None and i_mem.base is not None))) + or ( + isinstance(mem.base, RegisterOperand) + and ( + mem.base.prefix == i_mem.base + or (mem.base.name is not None and i_mem.base is not None) + ) + ) ) # check offset and ( diff --git a/osaca/semantics/kernel_dg.py b/osaca/semantics/kernel_dg.py index 2dd46fb..c2a0611 100644 --- a/osaca/semantics/kernel_dg.py +++ b/osaca/semantics/kernel_dg.py @@ -568,7 +568,7 @@ class KernelDG(nx.DiGraph): (latency, list(deps)) for latency, deps in groupby(lcd, lambda dep: lcd[dep]["latency"]) ), - reverse=True + reverse=True, ) node_colors = {} edge_colors = {} @@ -591,17 +591,16 @@ class KernelDG(nx.DiGraph): edge_colors[u, v] = color max_color = min(11, colors_used) colorscheme = f"spectral{max(3, max_color)}" - graph.graph["node"] = {"colorscheme" : colorscheme} - graph.graph["edge"] = {"colorscheme" : colorscheme} + graph.graph["node"] = {"colorscheme": colorscheme} + graph.graph["edge"] = {"colorscheme": colorscheme} for n, color in node_colors.items(): if "style" not in graph.nodes[n]: graph.nodes[n]["style"] = "filled" else: graph.nodes[n]["style"] += ",filled" graph.nodes[n]["fillcolor"] = color - if ( - (max_color >= 4 and color in (1, max_color)) or - (max_color >= 10 and color in (1, 2, max_color - 1 , max_color)) + if (max_color >= 4 and color in (1, max_color)) or ( + max_color >= 10 and color in (1, 2, max_color - 1, max_color) ): graph.nodes[n]["fontcolor"] = "white" 
for (u, v), color in edge_colors.items(): diff --git a/osaca/semantics/marker_utils.py b/osaca/semantics/marker_utils.py index 50f6c88..7982e07 100644 --- a/osaca/semantics/marker_utils.py +++ b/osaca/semantics/marker_utils.py @@ -1,4 +1,3 @@ -# TODO #!/usr/bin/env python3 from collections import OrderedDict from enum import Enum @@ -174,8 +173,7 @@ def get_marker(isa, syntax="ATT", comment=""): start_marker_raw += "# {}\n".format(comment) # After loop end_marker_raw = ( - "li a1, 222 # OSACA END MARKER\n" - ".byte 19,0,0,0 # OSACA END MARKER\n" + "li a1, 222 # OSACA END MARKER\n" ".byte 19,0,0,0 # OSACA END MARKER\n" ) parser = get_parser(isa)