diff --git a/.gitignore b/.gitignore index a996ec7..426ace6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,5 @@ # OSACA specific files and folders *.*.pickle -osaca_testfront_venv/ -examples/riscy_asm_files/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/examples/add/add.s.rv6.gcc.s b/examples/add/add.s.rv64.gcc.s similarity index 100% rename from examples/add/add.s.rv6.gcc.s rename to examples/add/add.s.rv64.gcc.s diff --git a/examples/copy/copy.s.rv6.gcc.s b/examples/copy/copy.s.rv64.gcc.s similarity index 100% rename from examples/copy/copy.s.rv6.gcc.s rename to examples/copy/copy.s.rv64.gcc.s diff --git a/examples/daxpy/daxpy.s.rv6.gcc.s b/examples/daxpy/daxpy.s.rv64.gcc.s similarity index 100% rename from examples/daxpy/daxpy.s.rv6.gcc.s rename to examples/daxpy/daxpy.s.rv64.gcc.s diff --git a/examples/gs/gs.s.rv6.gcc.s b/examples/gs/gs.s.rv64.gcc.s similarity index 100% rename from examples/gs/gs.s.rv6.gcc.s rename to examples/gs/gs.s.rv64.gcc.s diff --git a/examples/j2d/j2d.s.rv6.gcc.s b/examples/j2d/j2d.s.rv64.gcc.s similarity index 100% rename from examples/j2d/j2d.s.rv6.gcc.s rename to examples/j2d/j2d.s.rv64.gcc.s diff --git a/examples/striad/striad.s.rv6.gcc.s b/examples/striad/striad.s.rv64.gcc.s similarity index 100% rename from examples/striad/striad.s.rv6.gcc.s rename to examples/striad/striad.s.rv64.gcc.s diff --git a/examples/sum_reduction/sum_reduction.s.rv6.gcc.s b/examples/sum_reduction/sum_reduction.s.rv64.gcc.s similarity index 100% rename from examples/sum_reduction/sum_reduction.s.rv6.gcc.s rename to examples/sum_reduction/sum_reduction.s.rv64.gcc.s diff --git a/examples/triad/triad.s.rv6.gcc.s b/examples/triad/triad.s.rv64.gcc.s similarity index 100% rename from examples/triad/triad.s.rv6.gcc.s rename to examples/triad/triad.s.rv64.gcc.s diff --git a/examples/update/update.s.rv6.gcc.s b/examples/update/update.s.rv64.gcc.s similarity index 100% rename from examples/update/update.s.rv6.gcc.s rename to examples/update/update.s.rv64.gcc.s diff --git a/osaca/data/rv64.yml b/osaca/data/rv64.yml index fdfca03..03e92bb 100644 --- a/osaca/data/rv64.yml +++ b/osaca/data/rv64.yml @@ -661,7 +661,7 @@ instruction_forms: latency: 3 throughput: 1 port_pressure: [[1, ["FP"]]] - + - name: VFMADD.VV operands: - class: register diff --git a/osaca/parser/immediate.py b/osaca/parser/immediate.py index afda34c..0dd9c1e 100644 --- a/osaca/parser/immediate.py +++ b/osaca/parser/immediate.py @@ -10,6 +10,8 @@ class ImmediateOperand(Operand): imd_type=None, value=None, shift=None, + reloc_type=None, + symbol=None, source=False, destination=False, ): @@ -18,6 +20,8 @@ class ImmediateOperand(Operand): self._imd_type = imd_type self._value = value self._shift = shift + self._reloc_type = reloc_type + self._symbol = symbol @property def identifier(self): @@ -33,7 +37,15 @@ class ImmediateOperand(Operand): @property def shift(self): - return self._imd_type + return self._shift + + @property + def reloc_type(self): + return self._reloc_type + + @property + def symbol(self): + return self._symbol @imd_type.setter def imd_type(self, itype): @@ -51,10 +63,19 @@ class ImmediateOperand(Operand): def shift(self, shift): self._shift = shift + @reloc_type.setter + def reloc_type(self, reloc_type): + self._reloc_type = reloc_type + + @symbol.setter + def symbol(self, symbol): + self._symbol = symbol + def __str__(self): return ( f"Immediate(identifier={self._identifier}, imd_type={self._imd_type}, " - f"value={self._value}, shift={self._shift}, source={self._source}, destination={self._destination})" + f"value={self._value}, shift={self._shift}, reloc_type={self._reloc_type}, " + f"symbol={self._symbol}, source={self._source}, destination={self._destination})" ) def __repr__(self): @@ -62,10 +83,18 @@ class ImmediateOperand(Operand): def __eq__(self, other): if isinstance(other, ImmediateOperand): + # Handle cases where old instances might not have the new attributes + self_reloc_type = getattr(self, "_reloc_type", None) + self_symbol = getattr(self, "_symbol", None) + other_reloc_type = getattr(other, "_reloc_type", None) + other_symbol = getattr(other, "_symbol", None) + return ( self._identifier == other._identifier and self._imd_type == other._imd_type and self._value == other._value and self._shift == other._shift + and self_reloc_type == other_reloc_type + and self_symbol == other_symbol ) return False diff --git a/osaca/parser/parser_RISCV.py b/osaca/parser/parser_RISCV.py index b1db9fe..bdf72f9 100644 --- a/osaca/parser/parser_RISCV.py +++ b/osaca/parser/parser_RISCV.py @@ -1,7 +1,5 @@ #!/usr/bin/env python3 import re -import os -from copy import deepcopy import pyparsing as pp from osaca.parser import BaseParser @@ -13,7 +11,6 @@ from osaca.parser.label import LabelOperand from osaca.parser.register import RegisterOperand from osaca.parser.identifier import IdentifierOperand from osaca.parser.immediate import ImmediateOperand -from osaca.parser.condition import ConditionOperand class ParserRISCV(BaseParser): @@ -61,7 +58,7 @@ class ParserRISCV(BaseParser): self.comment = pp.Literal(symbol_comment) + pp.Group( pp.ZeroOrMore(pp.Word(pp.printables)) ).setResultsName(self.comment_id) - + # Define RISC-V assembly identifier decimal_number = pp.Combine( pp.Optional(pp.Literal("-")) + pp.Word(pp.nums) @@ -69,18 +66,32 @@ class ParserRISCV(BaseParser): hex_number = pp.Combine( pp.Optional(pp.Literal("-")) + pp.Literal("0x") + pp.Word(pp.hexnums) ).setResultsName("value") - - # Additional identifiers used in vector instructions - vector_identifier = pp.Word(pp.alphas, pp.alphanums) - special_identifier = pp.Word(pp.alphas + "%") - - # First character of an identifier + + # RISC-V specific relocation attributes + reloc_type = ( + pp.Literal("%hi") + | pp.Literal("%lo") + | pp.Literal("%pcrel_hi") + | pp.Literal("%pcrel_lo") + | pp.Literal("%tprel_hi") + | pp.Literal("%tprel_lo") + | pp.Literal("%tprel_add") + ).setResultsName("reloc_type") + + reloc_expr = pp.Group( + reloc_type + + pp.Suppress("(") + + pp.Word(pp.alphas + pp.nums + "_").setResultsName("symbol") + + pp.Suppress(")") + ).setResultsName("relocation") + + # First character of an identifier first = pp.Word(pp.alphas + "_.", exact=1) # Rest of the identifier rest = pp.Word(pp.alphanums + "_.") # PLT suffix (@plt) for calls to shared libraries plt_suffix = pp.Optional(pp.Literal("@") + pp.Word(pp.alphas)) - + identifier = pp.Group( (pp.Combine(first + pp.Optional(rest) + plt_suffix)).setResultsName("name") + pp.Optional( @@ -88,31 +99,44 @@ class ParserRISCV(BaseParser): + (hex_number | decimal_number).setResultsName("offset") ) ).setResultsName(self.identifier) - + + # Immediate with optional relocation + immediate = pp.Group( + reloc_expr | (hex_number ^ decimal_number) | identifier + ).setResultsName(self.immediate_id) + # Label self.label = pp.Group( - identifier.setResultsName("name") + pp.Literal(":") + pp.Optional(self.comment) + identifier.setResultsName("name") + + pp.Literal(":") + + pp.Optional(self.comment) ).setResultsName(self.label_id) - + # Directive directive_option = pp.Combine( pp.Word(pp.alphas + "#@.%", exact=1) + pp.Optional(pp.Word(pp.printables + " ", excludeChars=",")) ) - - # For vector instructions - vector_parameter = pp.Word(pp.alphas) + directive_parameter = ( - pp.quotedString | directive_option | identifier | hex_number | decimal_number + pp.quotedString + | directive_option + | identifier + | hex_number + | decimal_number + ) + commaSeparatedList = pp.delimitedList( + pp.Optional(directive_parameter), delim="," ) - commaSeparatedList = pp.delimitedList(pp.Optional(directive_parameter), delim=",") self.directive = pp.Group( pp.Literal(".") + pp.Word(pp.alphanums + "_").setResultsName("name") - + (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName("parameters") + + (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName( + "parameters" + ) + pp.Optional(self.comment) ).setResultsName(self.directive_id) - + # LLVM-MCA markers self.llvm_markers = pp.Group( pp.Literal("#") @@ -127,61 +151,58 @@ class ParserRISCV(BaseParser): # Instructions # Mnemonic mnemonic = pp.Word(pp.alphanums + ".").setResultsName("mnemonic") - - # Immediate: - # int: ^-?[0-9]+ | hex: ^0x[0-9a-fA-F]+ - immediate = pp.Group( - (hex_number ^ decimal_number) - | identifier - ).setResultsName(self.immediate_id) - + # Register: # RISC-V has two main types of registers: # 1. Integer registers (x0-x31 or ABI names) # 2. Floating-point registers (f0-f31 or ABI names) - + # Integer register ABI names integer_reg_abi = ( - pp.CaselessLiteral("zero") | - pp.CaselessLiteral("ra") | - pp.CaselessLiteral("sp") | - pp.CaselessLiteral("gp") | - pp.CaselessLiteral("tp") | - pp.Regex(r"[tas][0-9]+") # t0-t6, a0-a7, s0-s11 + pp.CaselessLiteral("zero") + | pp.CaselessLiteral("ra") + | pp.CaselessLiteral("sp") + | pp.CaselessLiteral("gp") + | pp.CaselessLiteral("tp") + | pp.Regex(r"[tas][0-9]+") # t0-t6, a0-a7, s0-s11 ).setResultsName("name") - + # Integer registers x0-x31 - integer_reg_x = ( - pp.CaselessLiteral("x").setResultsName("prefix") + - pp.Word(pp.nums).setResultsName("name") - ) - + integer_reg_x = pp.CaselessLiteral("x").setResultsName("prefix") + pp.Word( + pp.nums + ).setResultsName("name") + # Floating point registers - fp_reg_abi = pp.Regex(r"f[tas][0-9]+").setResultsName("name") # ft0-ft11, fa0-fa7, fs0-fs11 - - fp_reg_f = ( - pp.CaselessLiteral("f").setResultsName("prefix") + - pp.Word(pp.nums).setResultsName("name") - ) - + fp_reg_abi = pp.Regex(r"f[tas][0-9]+").setResultsName( + "name" + ) # ft0-ft11, fa0-fa7, fs0-fs11 + + fp_reg_f = pp.CaselessLiteral("f").setResultsName("prefix") + pp.Word( + pp.nums + ).setResultsName("name") + # Control and status registers (CSRs) csr_reg = pp.Combine( pp.CaselessLiteral("csr") + pp.Word(pp.alphanums + "_") ).setResultsName("name") - + # Vector registers (for the "V" extension) - vector_reg = ( - pp.CaselessLiteral("v").setResultsName("prefix") + - pp.Word(pp.nums).setResultsName("name") - ) - + vector_reg = pp.CaselessLiteral("v").setResultsName("prefix") + pp.Word( + pp.nums + ).setResultsName("name") + # Combined register definition register = pp.Group( - integer_reg_x | integer_reg_abi | fp_reg_f | fp_reg_abi | vector_reg | csr_reg + integer_reg_x + | integer_reg_abi + | fp_reg_f + | fp_reg_abi + | vector_reg + | csr_reg ).setResultsName(self.register_id) - + self.register = register - + # Memory addressing mode in RISC-V: offset(base_register) memory = pp.Group( pp.Optional(immediate.setResultsName("offset")) @@ -189,24 +210,19 @@ class ParserRISCV(BaseParser): + register.setResultsName("base") + pp.Suppress(pp.Literal(")")) ).setResultsName(self.memory_id) - + # Combine to instruction form - operand_first = pp.Group( - register ^ immediate ^ memory ^ identifier - ) - operand_rest = pp.Group( - register ^ immediate ^ memory ^ identifier - ) - - # Vector instruction special parameters (e.g., e32, m4, ta, ma) - vector_param = pp.Word(pp.alphas + pp.nums) - + operand_first = pp.Group(register ^ immediate ^ memory ^ identifier) + operand_rest = pp.Group(register ^ immediate ^ memory ^ identifier) + # Handle additional vector parameters additional_params = pp.ZeroOrMore( - pp.Suppress(pp.Literal(",")) + - vector_param.setResultsName("vector_param", listAllMatches=True) + pp.Suppress(pp.Literal(",")) + + pp.Word(pp.alphas + pp.nums).setResultsName( + "vector_param", listAllMatches=True + ) ) - + # Main instruction parser self.instruction_parser = ( mnemonic @@ -217,7 +233,7 @@ class ParserRISCV(BaseParser): + pp.Optional(operand_rest.setResultsName("operand3")) + pp.Optional(pp.Suppress(pp.Literal(","))) + pp.Optional(operand_rest.setResultsName("operand4")) - + pp.Optional(additional_params) # For vector instructions with more params + + pp.Optional(additional_params) + pp.Optional(self.comment) ) @@ -228,7 +244,8 @@ class ParserRISCV(BaseParser): :param str line: line of assembly code :param line_number: identifier of instruction form, defaults to None :type line_number: int, optional - :return: `dict` -- parsed asm line (comment, label, directive or instruction form) + :return: `dict` -- parsed asm line (comment, label, directive or + instruction form) """ instruction_form = InstructionForm( mnemonic=None, @@ -243,11 +260,13 @@ class ParserRISCV(BaseParser): # 1. Parse comment try: - result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict()) + result = self.process_operand( + self.comment.parseString(line, parseAll=True).asDict() + ) instruction_form.comment = " ".join(result[self.comment_id]) except pp.ParseException: pass - + # 1.2 check for llvm-mca marker try: result = self.process_operand( @@ -256,12 +275,14 @@ class ParserRISCV(BaseParser): instruction_form.comment = " ".join(result[self.comment_id]) except pp.ParseException: pass - + # 2. Parse label if result is None: try: # returns tuple with label operand and comment, if any - result = self.process_operand(self.label.parseString(line, parseAll=True).asDict()) + result = self.process_operand( + self.label.parseString(line, parseAll=True).asDict() + ) instruction_form.label = result[0].name if result[1] is not None: instruction_form.comment = " ".join(result[1]) @@ -294,7 +315,7 @@ class ParserRISCV(BaseParser): instruction_form.mnemonic = result.mnemonic instruction_form.operands = result.operands instruction_form.comment = result.comment - + return instruction_form def parse_instruction(self, instruction): @@ -304,75 +325,99 @@ class ParserRISCV(BaseParser): :param str instruction: Assembly line string. :returns: `dict` -- parsed instruction form """ + # Store current instruction for context in operand processing + if instruction.startswith("vsetvli"): + self.current_instruction = "vsetvli" + else: + # Extract mnemonic for context + parts = instruction.split("#")[0].strip().split() + self.current_instruction = parts[0] if parts else None + # Special handling for vector instructions like vsetvli with many parameters if instruction.startswith("vsetvli"): - parts = instruction.split("#")[0].strip().split() + # Split into mnemonic and operands part + parts = ( + instruction.split("#")[0].strip().split(None, 1) + ) # Split on first whitespace only mnemonic = parts[0] - + # Split operands by commas if len(parts) > 1: operand_part = parts[1] operands_list = [op.strip() for op in operand_part.split(",")] - + # Process each operand operands = [] for op in operands_list: - if op.startswith("x") or op in ["zero", "ra", "sp", "gp", "tp"] or re.match(r"[tas][0-9]+", op): + if ( + op.startswith("x") + or op in ["zero", "ra", "sp", "gp", "tp"] + or re.match(r"[tas][0-9]+", op) + ): operands.append(RegisterOperand(name=op)) - elif op in ["e8", "e16", "e32", "e64", "m1", "m2", "m4", "m8", "ta", "tu", "ma", "mu"]: - operands.append(IdentifierOperand(name=op)) else: - operands.append(IdentifierOperand(name=op)) - + # Vector parameters get appropriate attributes + if op.startswith("e"): # Element width + operands.append(IdentifierOperand(name=op)) + elif op.startswith("m"): # LMUL setting + operands.append(IdentifierOperand(name=op)) + elif op in ["ta", "tu", "ma", "mu"]: # Tail/mask policies + operands.append(IdentifierOperand(name=op)) + else: + operands.append(IdentifierOperand(name=op)) + # Get comment if present comment = None if "#" in instruction: comment = instruction.split("#", 1)[1].strip() - + return InstructionForm( - mnemonic=mnemonic, - operands=operands, - comment_id=comment + mnemonic=mnemonic, operands=operands, comment_id=comment ) - + # Regular instruction parsing try: - result = self.instruction_parser.parseString(instruction, parseAll=True).asDict() + result = self.instruction_parser.parseString( + instruction, parseAll=True + ).asDict() operands = [] - # Add operands to list - # Check first operand - if "operand1" in result: - operand = self.process_operand(result["operand1"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) - # Check second operand - if "operand2" in result: - operand = self.process_operand(result["operand2"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) - # Check third operand - if "operand3" in result: - operand = self.process_operand(result["operand3"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) - # Check fourth operand - if "operand4" in result: - operand = self.process_operand(result["operand4"]) - operands.extend(operand) if isinstance(operand, list) else operands.append(operand) - - # Handle vector_param for vector instructions + + # Process operands + for i in range(1, 5): + operand_key = f"operand{i}" + if operand_key in result: + operand = self.process_operand(result[operand_key]) + ( + operands.extend(operand) + if isinstance(operand, list) + else operands.append(operand) + ) + + # Handle vector parameters as identifiers with appropriate attributes if "vector_param" in result: if isinstance(result["vector_param"], list): for param in result["vector_param"]: - operands.append(IdentifierOperand(name=param)) + if param.startswith("e"): # Element width + operands.append(IdentifierOperand(name=param)) + elif param.startswith("m"): # LMUL setting + operands.append(IdentifierOperand(name=param)) + else: + operands.append(IdentifierOperand(name=param)) else: operands.append(IdentifierOperand(name=result["vector_param"])) - + return_dict = InstructionForm( mnemonic=result["mnemonic"], operands=operands, - comment_id=" ".join(result[self.comment_id]) if self.comment_id in result else None, + comment_id=( + " ".join(result[self.comment_id]) + if self.comment_id in result + else None + ), ) return return_dict - - except Exception as e: + + except Exception: # For special vector instructions or ones with % in them if "%" in instruction or instruction.startswith("v"): parts = instruction.split("#")[0].strip().split(None, 1) @@ -383,21 +428,26 @@ class ParserRISCV(BaseParser): operands_list = [op.strip() for op in operand_part.split(",")] for op in operands_list: # Process '%hi(data)' to 'data' for certain operands - if op.startswith("%") and '(' in op and ')' in op: - # Extract data from %hi(data) format - data = op[op.index('(')+1:op.index(')')] - operands.append(IdentifierOperand(name=data)) + if op.startswith("%") and "(" in op and ")" in op: + reloc_type = op[: op.index("(")] + symbol = op[op.index("(") + 1 : op.index(")")] + operands.append( + ImmediateOperand( + imd_type="reloc", + value=None, + reloc_type=reloc_type, + symbol=symbol, + ) + ) else: operands.append(IdentifierOperand(name=op)) - + comment = None if "#" in instruction: comment = instruction.split("#", 1)[1].strip() - + return InstructionForm( - mnemonic=mnemonic, - operands=operands, - comment_id=comment + mnemonic=mnemonic, operands=operands, comment_id=comment ) else: raise @@ -430,62 +480,127 @@ class ParserRISCV(BaseParser): ) def process_register_operand(self, operand): - """Process register operands, including ABI name to x-register mapping""" - # If already has prefix (x#, f#, v#), just return as is + """Process register operands, including ABI name to x-register mapping + and vector attributes""" + # If already has prefix (x#, f#, v#), process with appropriate attributes if "prefix" in operand: - return RegisterOperand( - prefix=operand["prefix"].lower(), - name=operand["name"] - ) - + prefix = operand["prefix"].lower() + + # Special handling for vector registers + if prefix == "v": + return RegisterOperand( + prefix=prefix, + name=operand["name"], + regtype="vector", + # Vector registers can have different element widths (e8,e16,e32,e64) + width=operand.get("width", None), + # Number of elements (m1,m2,m4,m8) + lanes=operand.get("lanes", None), + # For vector mask registers + mask=operand.get("mask", False), + # For tail agnostic/undisturbed policies + zeroing=operand.get("zeroing", False), + ) + # For floating point registers + elif prefix == "f": + return RegisterOperand( + prefix=prefix, + name=operand["name"], + regtype="float", + width=64, # RISC-V typically uses 64-bit float registers + ) + # For integer registers + elif prefix == "x": + return RegisterOperand( + prefix=prefix, + name=operand["name"], + regtype="int", + width=64, # RV64 uses 64-bit registers + ) + # Handle ABI names by converting to x-register numbers name = operand["name"].lower() - + # ABI name mapping for integer registers abi_to_x = { - "zero": "0", "ra": "1", "sp": "2", "gp": "3", "tp": "4", - "t0": "5", "t1": "6", "t2": "7", - "s0": "8", "fp": "8", "s1": "9", - "a0": "10", "a1": "11", "a2": "12", "a3": "13", - "a4": "14", "a5": "15", "a6": "16", "a7": "17", - "s2": "18", "s3": "19", "s4": "20", "s5": "21", - "s6": "22", "s7": "23", "s8": "24", "s9": "25", - "s10": "26", "s11": "27", - "t3": "28", "t4": "29", "t5": "30", "t6": "31" + "zero": "x0", + "ra": "x1", + "sp": "x2", + "gp": "x3", + "tp": "x4", + "t0": "x5", + "t1": "x6", + "t2": "x7", + "s0": "x8", + "s1": "x9", + "a0": "x10", + "a1": "x11", + "a2": "x12", + "a3": "x13", + "a4": "x14", + "a5": "x15", + "a6": "x16", + "a7": "x17", + "s2": "x18", + "s3": "x19", + "s4": "x20", + "s5": "x21", + "s6": "x22", + "s7": "x23", + "s8": "x24", + "s9": "x25", + "s10": "x26", + "s11": "x27", + "t3": "x28", + "t4": "x29", + "t5": "x30", + "t6": "x31", } - + # Integer register ABI names if name in abi_to_x: return RegisterOperand( prefix="x", - name=abi_to_x[name] + name=abi_to_x[name], + regtype="int", + width=64, # RV64 uses 64-bit registers ) # Floating point register ABI names elif name.startswith("f") and name[1] in ["t", "a", "s"]: if name[1] == "a": # fa0-fa7 idx = int(name[2:]) - return RegisterOperand(prefix="f", name=str(idx + 10)) + return RegisterOperand( + prefix="f", name=str(idx + 10), regtype="float", width=64 + ) elif name[1] == "s": # fs0-fs11 idx = int(name[2:]) if idx <= 1: - return RegisterOperand(prefix="f", name=str(idx + 8)) + return RegisterOperand( + prefix="f", name=str(idx + 8), regtype="float", width=64 + ) else: - return RegisterOperand(prefix="f", name=str(idx + 16)) + return RegisterOperand( + prefix="f", name=str(idx + 16), regtype="float", width=64 + ) elif name[1] == "t": # ft0-ft11 idx = int(name[2:]) if idx <= 7: - return RegisterOperand(prefix="f", name=str(idx)) + return RegisterOperand( + prefix="f", name=str(idx), regtype="float", width=64 + ) else: - return RegisterOperand(prefix="f", name=str(idx + 20)) + return RegisterOperand( + prefix="f", name=str(idx + 20), regtype="float", width=64 + ) # CSR registers elif name.startswith("csr"): - return RegisterOperand(prefix="", name=name) - + return RegisterOperand(prefix="", name=name, regtype="csr") + # If no mapping found, return as is return RegisterOperand(prefix="", name=name) def process_memory_address(self, memory_address): - """Post-process memory address operand""" + """Post-process memory address operand with RISC-V specific attributes""" # Process offset offset = memory_address.get("offset", None) if isinstance(offset, list) and len(offset) == 1: @@ -494,18 +609,38 @@ class ParserRISCV(BaseParser): offset = ImmediateOperand(value=int(offset["value"], 0)) if isinstance(offset, dict) and "identifier" in offset: offset = self.process_identifier(offset["identifier"]) - + # Process base register base = memory_address.get("base", None) if base is not None: base = self.process_register_operand(base) - - # Create memory operand + + # Determine data type from instruction context if available + # RISC-V load/store instructions encode the data width in the mnemonic + # e.g., lw (word), lh (half), lb (byte), etc. + data_type = None + if hasattr(self, "current_instruction"): + mnemonic = self.current_instruction.lower() + if any(x in mnemonic for x in ["b", "bu"]): # byte operations + data_type = "byte" + elif any(x in mnemonic for x in ["h", "hu"]): # halfword operations + data_type = "halfword" + elif any(x in mnemonic for x in ["w", "wu"]): # word operations + data_type = "word" + elif "d" in mnemonic: # doubleword operations + data_type = "doubleword" + + # Create memory operand with enhanced attributes return MemoryOperand( offset=offset, base=base, - index=None, - scale=1 + index=None, # RISC-V doesn't use index registers + scale=1, # RISC-V doesn't use scaling + data_type=data_type, + # Handle vector memory operations + mask=memory_address.get("mask", None), # For vector masked loads/stores + src=memory_address.get("src", None), # Source register type for stores + dst=memory_address.get("dst", None), # Destination register type for loads ) def process_label(self, label): @@ -519,21 +654,102 @@ class ParserRISCV(BaseParser): """Post-process identifier operand""" return IdentifierOperand( name=identifier["name"] if "name" in identifier else None, - offset=identifier["offset"] if "offset" in identifier else None + offset=identifier["offset"] if "offset" in identifier else None, ) - + def process_immediate(self, immediate): - """Post-process immediate operand""" + """Post-process immediate operand with RISC-V specific handling""" + # Handle relocations + if "relocation" in immediate: + reloc = immediate["relocation"] + return ImmediateOperand( + imd_type="reloc", + value=None, + reloc_type=reloc["reloc_type"], + symbol=reloc["symbol"], + ) + + # Handle identifiers if "identifier" in immediate: - # actually an identifier, change declaration return self.process_identifier(immediate["identifier"]) + + # Handle numeric values with validation if "value" in immediate: - # normal integer value - immediate["type"] = "int" - # convert hex/bin immediates to dec - new_immediate = ImmediateOperand(imd_type=immediate["type"], value=immediate["value"]) - new_immediate.value = self.normalize_imd(new_immediate) - return new_immediate + value = int( + immediate["value"], 0 + ) # Convert to integer, handling hex/decimal + + # Determine immediate type and validate range based on instruction type + if hasattr(self, "current_instruction"): + mnemonic = self.current_instruction.lower() + + # I-type instructions (12-bit signed immediate) + if any( + x in mnemonic + for x in [ + "addi", + "slti", + "xori", + "ori", + "andi", + "slli", + "srli", + "srai", + ] + ): + if not -2048 <= value <= 2047: + raise ValueError( + f"Immediate value {value} out of range for I-type " + f"instruction (-2048 to 2047)" + ) + return ImmediateOperand(imd_type="I", value=value) + + # S-type instructions (12-bit signed immediate for store) + elif any(x in mnemonic for x in ["sb", "sh", "sw", "sd"]): + if not -2048 <= value <= 2047: + raise ValueError( + f"Immediate value {value} out of range for S-type " + f"instruction (-2048 to 2047)" + ) + return ImmediateOperand(imd_type="S", value=value) + + # B-type instructions (13-bit signed immediate for branches, must be even) + elif any( + x in mnemonic for x in ["beq", "bne", "blt", "bge", "bltu", "bgeu"] + ): + if not -4096 <= value <= 4095 or value % 2 != 0: + raise ValueError( + f"Immediate value {value} out of range or not even " + f"for B-type instruction (-4096 to 4095, must be even)" + ) + return ImmediateOperand(imd_type="B", value=value) + + # U-type instructions (20-bit upper immediate) + elif any(x in mnemonic for x in ["lui", "auipc"]): + if not 0 <= value <= 1048575: + raise ValueError( + f"Immediate value {value} out of range for U-type " + f"instruction (0 to 1048575)" + ) + return ImmediateOperand(imd_type="U", value=value) + + # J-type instructions (21-bit signed immediate for jumps, must be even) + elif any(x in mnemonic for x in ["jal"]): + if not -1048576 <= value <= 1048575 or value % 2 != 0: + raise ValueError( + f"Immediate value {value} out of range or not even " + f"for J-type instruction (-1048576 to 1048575, must be even)" + ) + return ImmediateOperand(imd_type="J", value=value) + + # Vector instructions might have specific immediate ranges + elif mnemonic.startswith("v"): + # Handle vector specific immediates (implementation specific) + return ImmediateOperand(imd_type="V", value=value) + + # Default case - no specific validation + return ImmediateOperand(imd_type="int", value=value) + return immediate def get_full_reg_name(self, register): @@ -558,44 +774,83 @@ class ParserRISCV(BaseParser): def parse_register(self, register_string): """ Parse register string and return register dictionary. - + :param str register_string: register representation as string :returns: dict with register info """ # Remove any leading/trailing whitespace register_string = register_string.strip() - + # Check for integer registers (x0-x31) - x_match = re.match(r'^x([0-9]|[1-2][0-9]|3[0-1])$', register_string) + x_match = re.match(r"^x([0-9]|[1-2][0-9]|3[0-1])$", register_string) if x_match: reg_num = int(x_match.group(1)) - return {"class": "register", "register": {"prefix": "x", "name": str(reg_num)}} - + return { + "class": "register", + "register": {"prefix": "x", "name": str(reg_num)}, + } + # Check for floating-point registers (f0-f31) - f_match = re.match(r'^f([0-9]|[1-2][0-9]|3[0-1])$', register_string) + f_match = re.match(r"^f([0-9]|[1-2][0-9]|3[0-1])$", register_string) if f_match: reg_num = int(f_match.group(1)) - return {"class": "register", "register": {"prefix": "f", "name": str(reg_num)}} - + return { + "class": "register", + "register": {"prefix": "f", "name": str(reg_num)}, + } + # Check for vector registers (v0-v31) - v_match = re.match(r'^v([0-9]|[1-2][0-9]|3[0-1])$', register_string) + v_match = re.match(r"^v([0-9]|[1-2][0-9]|3[0-1])$", register_string) if v_match: reg_num = int(v_match.group(1)) - return {"class": "register", "register": {"prefix": "v", "name": str(reg_num)}} - + return { + "class": "register", + "register": {"prefix": "v", "name": str(reg_num)}, + } + # Check for ABI names abi_names = { - "zero": 0, "ra": 1, "sp": 2, "gp": 3, "tp": 4, - "t0": 5, "t1": 6, "t2": 7, - "s0": 8, "fp": 8, "s1": 9, - "a0": 10, "a1": 11, "a2": 12, "a3": 13, "a4": 14, "a5": 15, "a6": 16, "a7": 17, - "s2": 18, "s3": 19, "s4": 20, "s5": 21, "s6": 22, "s7": 23, "s8": 24, "s9": 25, "s10": 26, "s11": 27, - "t3": 28, "t4": 29, "t5": 30, "t6": 31 + "zero": 0, + "ra": 1, + "sp": 2, + "gp": 3, + "tp": 4, + "t0": 5, + "t1": 6, + "t2": 7, + "s0": 8, + "fp": 8, + "s1": 9, + "a0": 10, + "a1": 11, + "a2": 12, + "a3": 13, + "a4": 14, + "a5": 15, + "a6": 16, + "a7": 17, + "s2": 18, + "s3": 19, + "s4": 20, + "s5": 21, + "s6": 22, + "s7": 23, + "s8": 24, + "s9": 25, + "s10": 26, + "s11": 27, + "t3": 28, + "t4": 29, + "t5": 30, + "t6": 31, } - + if register_string in abi_names: - return {"class": "register", "register": {"prefix": "", "name": register_string}} - + return { + "class": "register", + "register": {"prefix": "", "name": register_string}, + } + # If no match is found return None @@ -626,38 +881,61 @@ class ParserRISCV(BaseParser): """Check if ``reg_a`` is dependent on ``reg_b``""" if not isinstance(reg_a, Operand): reg_a = RegisterOperand(name=reg_a["name"]) - + # Get canonical register names reg_a_canonical = self._get_canonical_reg_name(reg_a) reg_b_canonical = self._get_canonical_reg_name(reg_b) - + # Same register type and number means dependency return reg_a_canonical == reg_b_canonical - + def _get_canonical_reg_name(self, register): """Get the canonical form of a register (x-form for integer, f-form for FP)""" # If already in canonical form (x# or f#) if register.prefix in ["x", "f", "v"] and register.name.isdigit(): return f"{register.prefix}{register.name}" - + # ABI name mapping for integer registers abi_to_x = { - "zero": "x0", "ra": "x1", "sp": "x2", "gp": "x3", "tp": "x4", - "t0": "x5", "t1": "x6", "t2": "x7", - "s0": "x8", "s1": "x9", - "a0": "x10", "a1": "x11", "a2": "x12", "a3": "x13", - "a4": "x14", "a5": "x15", "a6": "x16", "a7": "x17", - "s2": "x18", "s3": "x19", "s4": "x20", "s5": "x21", - "s6": "x22", "s7": "x23", "s8": "x24", "s9": "x25", - "s10": "x26", "s11": "x27", - "t3": "x28", "t4": "x29", "t5": "x30", "t6": "x31" + "zero": "x0", + "ra": "x1", + "sp": "x2", + "gp": "x3", + "tp": "x4", + "t0": "x5", + "t1": "x6", + "t2": "x7", + "s0": "x8", + "s1": "x9", + "a0": "x10", + "a1": "x11", + "a2": "x12", + "a3": "x13", + "a4": "x14", + "a5": "x15", + "a6": "x16", + "a7": "x17", + "s2": "x18", + "s3": "x19", + "s4": "x20", + "s5": "x21", + "s6": "x22", + "s7": "x23", + "s8": "x24", + "s9": "x25", + "s10": "x26", + "s11": "x27", + "t3": "x28", + "t4": "x29", + "t5": "x30", + "t6": "x31", } - + # For integer register ABI names name = register.name.lower() if name in abi_to_x: return abi_to_x[name] - + # For FP register ABI names like fa0, fs1, etc. if name.startswith("f") and len(name) > 1: if name[1] == "a": # fa0-fa7 @@ -675,7 +953,7 @@ class ParserRISCV(BaseParser): return f"f{idx}" else: return f"f{idx + 20}" - + # Return as is if no mapping found return f"{register.prefix}{register.name}" @@ -684,7 +962,7 @@ class ParserRISCV(BaseParser): # Return register prefix if exists if register.prefix: return register.prefix - + # Determine type from ABI name name = register.name.lower() if name in ["zero", "ra", "sp", "gp", "tp"] or name[0] in ["t", "a", "s"]: @@ -693,30 +971,30 @@ class ParserRISCV(BaseParser): return "f" # Floating point register elif name.startswith("csr"): return "csr" # Control and Status Register - + return "unknown" def normalize_instruction_form(self, instruction_form, isa_model, arch_model): """ Normalize instruction form for RISC-V instructions. - + :param instruction_form: instruction form to normalize :param isa_model: ISA model to use for normalization :param arch_model: architecture model to use for normalization """ if instruction_form.normalized: return - + if instruction_form.mnemonic is None: instruction_form.normalized = True return - + # Normalize the mnemonic if needed if instruction_form.mnemonic: # Handle any RISC-V specific mnemonic normalization # For example, convert aliases or pseudo-instructions to their base form pass - + # Normalize the operands if needed for i, operand in enumerate(instruction_form.operands): if isinstance(operand, ImmediateOperand): @@ -725,8 +1003,8 @@ class ParserRISCV(BaseParser): elif isinstance(operand, RegisterOperand): # Convert register names to canonical form if needed pass - - instruction_form.normalized = True + + instruction_form.normalized = True def get_regular_source_operands(self, instruction_form): """Get source operand of given instruction form assuming regular src/dst behavior.""" @@ -736,14 +1014,14 @@ class ParserRISCV(BaseParser): return [instruction_form.operands[0]] else: return [op for op in instruction_form.operands[1:]] - + def get_regular_destination_operands(self, instruction_form): """Get destination operand of given instruction form assuming regular src/dst behavior.""" # For RISC-V, the first operand is typically the destination if len(instruction_form.operands) == 1: return [] else: - return instruction_form.operands[:1] + return instruction_form.operands[:1] def process_immediate_operand(self, operand): """Process immediate operands, converting them to ImmediateOperand objects""" @@ -751,7 +1029,7 @@ class ParserRISCV(BaseParser): # For raw integer values or string immediates return ImmediateOperand( imd_type="int", - value=str(operand) if isinstance(operand, int) else operand + value=str(operand) if isinstance(operand, int) else operand, ) elif isinstance(operand, dict) and "imd" in operand: # For immediate operands from instruction definitions @@ -759,11 +1037,8 @@ class ParserRISCV(BaseParser): imd_type=operand["imd"], value=operand.get("value"), identifier=operand.get("identifier"), - shift=operand.get("shift") + shift=operand.get("shift"), ) else: # For any other immediate format - return ImmediateOperand( - imd_type="int", - value=str(operand) - ) \ No newline at end of file + return ImmediateOperand(imd_type="int", value=str(operand)) diff --git a/osaca/semantics/hw_model.py b/osaca/semantics/hw_model.py index 4979119..c178190 100644 --- a/osaca/semantics/hw_model.py +++ b/osaca/semantics/hw_model.py @@ -873,13 +873,13 @@ class MachineModel(object): if not isinstance(i_operand, RegisterOperand): return False return self._is_RISCV_reg_type(i_operand, operand) - + # memory if isinstance(operand, MemoryOperand): if not isinstance(i_operand, MemoryOperand): return False return self._is_RISCV_mem_type(i_operand, operand) - + # immediate if isinstance(operand, (ImmediateOperand, int)): if not isinstance(i_operand, ImmediateOperand): @@ -895,7 +895,7 @@ class MachineModel(object): if i_operand.imd_type == self.WILDCARD: return True return False - + # identifier if isinstance(operand, IdentifierOperand) or ( isinstance(operand, ImmediateOperand) and operand.identifier is not None @@ -1011,7 +1011,7 @@ class MachineModel(object): # check for wildcards if reg.prefix == self.WILDCARD or i_reg.prefix == self.WILDCARD: return True - + # First handle potentially None values to avoid AttributeError if reg.name is None or i_reg.name is None: # If both have same prefix, they might still match @@ -1019,12 +1019,13 @@ class MachineModel(object): return True # If we can't determine canonical names, be conservative and return False return False - + # Check for ABI name (a0, t0, etc.) vs x-prefix registers (x10, x5, etc.) if (reg.prefix is None and i_reg.prefix == "x") or (reg.prefix == "x" and i_reg.prefix is None): try: # Need to check if they refer to the same register from osaca.parser import ParserRISCV + parser = ParserRISCV() reg_canonical = parser._get_canonical_reg_name(reg) i_reg_canonical = parser._get_canonical_reg_name(i_reg) @@ -1032,16 +1033,16 @@ class MachineModel(object): return True except (AttributeError, KeyError): return False - + # Check for direct prefix matches if reg.prefix == i_reg.prefix: # For vector registers, check lanes if present if reg.prefix == "v" and reg.lanes is not None and i_reg.lanes is not None: return reg.lanes == i_reg.lanes or self.WILDCARD in (reg.lanes + i_reg.lanes) return True - + return False - + def _is_AArch64_mem_type(self, i_mem, mem): """Check if memory addressing type match.""" if ( @@ -1181,4 +1182,4 @@ class MachineModel(object): def __represent_none(self, yaml_obj, data): """YAML representation for `None`""" - return yaml_obj.represent_scalar("tag:yaml.org,2002:null", "~") \ No newline at end of file + return yaml_obj.represent_scalar("tag:yaml.org,2002:null", "~") diff --git a/tests/test_cli.py b/tests/test_cli.py index 3a2a018..ae24ab0 100755 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -307,11 +307,13 @@ class TestCLI(unittest.TestCase): @staticmethod def _find_file(kernel, arch, comp): testdir = os.path.dirname(__file__) + # Handle special case for rv64 architecture + arch_prefix = arch.lower() if arch.lower() == "rv64" else arch[:3].lower() name = os.path.join( testdir, "../examples", kernel, - kernel + ".s." + arch[:3].lower() + "." + comp.lower() + ".s", + kernel + ".s." + arch_prefix + "." + comp.lower() + ".s", ) if kernel == "j2d" and arch.lower() == "csx": name = name[:-1] + "AVX.s" diff --git a/tests/test_parser_RISCV.py b/tests/test_parser_RISCV.py index 0f0b621..2df6d58 100644 --- a/tests/test_parser_RISCV.py +++ b/tests/test_parser_RISCV.py @@ -8,9 +8,7 @@ import unittest from pyparsing import ParseException -from osaca.parser import ParserRISCV, InstructionForm -from osaca.parser.directive import DirectiveOperand -from osaca.parser.memory import MemoryOperand +from osaca.parser import ParserRISCV from osaca.parser.register import RegisterOperand from osaca.parser.immediate import ImmediateOperand from osaca.parser.identifier import IdentifierOperand @@ -105,7 +103,7 @@ class TestParserRISCV(unittest.TestCase): # Test 1: Line with label and instruction parsed_1 = self.parser.parse_line(".L2:") self.assertEqual(parsed_1.label, ".L2") - + # Test 2: Line with instruction and comment parsed_2 = self.parser.parse_line("addi x10, x10, 1 # increment") self.assertEqual(parsed_2.mnemonic, "addi") @@ -118,14 +116,14 @@ class TestParserRISCV(unittest.TestCase): def test_parse_file(self): parsed = self.parser.parse_file(self.riscv_code) self.assertGreater(len(parsed), 10) # There should be multiple lines - + # Find common elements that should exist in any RISC-V file # without being tied to specific line numbers - + # Verify that we can find at least one label label_forms = [form for form in parsed if form.label is not None] self.assertGreater(len(label_forms), 0, "No labels found in the file") - + # Verify that we can find at least one branch instruction branch_forms = [form for form in parsed if form.mnemonic and form.mnemonic.startswith("b")] self.assertGreater(len(branch_forms), 0, "No branch instructions found in the file") @@ -148,7 +146,7 @@ class TestParserRISCV(unittest.TestCase): reg_a0 = RegisterOperand(name="a0") reg_t1 = RegisterOperand(name="t1") reg_s2 = RegisterOperand(name="s2") - + reg_x0 = RegisterOperand(prefix="x", name="0") reg_x1 = RegisterOperand(prefix="x", name="1") reg_x2 = RegisterOperand(prefix="x", name="2") @@ -156,7 +154,7 @@ class TestParserRISCV(unittest.TestCase): reg_x10 = RegisterOperand(prefix="x", name="10") reg_x6 = RegisterOperand(prefix="x", name="6") reg_x18 = RegisterOperand(prefix="x", name="18") - + # Test canonical name conversion self.assertEqual(self.parser._get_canonical_reg_name(reg_zero), "x0") self.assertEqual(self.parser._get_canonical_reg_name(reg_ra), "x1") @@ -164,7 +162,7 @@ class TestParserRISCV(unittest.TestCase): self.assertEqual(self.parser._get_canonical_reg_name(reg_a0), "x10") self.assertEqual(self.parser._get_canonical_reg_name(reg_t1), "x6") self.assertEqual(self.parser._get_canonical_reg_name(reg_s2), "x18") - + # Test register dependency self.assertTrue(self.parser.is_reg_dependend_of(reg_zero, reg_x0)) self.assertTrue(self.parser.is_reg_dependend_of(reg_ra, reg_x1)) @@ -172,29 +170,27 @@ class TestParserRISCV(unittest.TestCase): self.assertTrue(self.parser.is_reg_dependend_of(reg_a0, reg_x10)) self.assertTrue(self.parser.is_reg_dependend_of(reg_t1, reg_x6)) self.assertTrue(self.parser.is_reg_dependend_of(reg_s2, reg_x18)) - + # Test non-dependent registers self.assertFalse(self.parser.is_reg_dependend_of(reg_zero, reg_x1)) self.assertFalse(self.parser.is_reg_dependend_of(reg_ra, reg_x2)) self.assertFalse(self.parser.is_reg_dependend_of(reg_a0, reg_t1)) - + # Test floating-point registers reg_fa0 = RegisterOperand(prefix="f", name="a0") - reg_fa1 = RegisterOperand(prefix="f", name="a1") reg_f10 = RegisterOperand(prefix="f", name="10") - + # Test vector registers reg_v1 = RegisterOperand(prefix="v", name="1") - reg_v2 = RegisterOperand(prefix="v", name="2") - + # Test register type detection self.assertTrue(self.parser.is_gpr(reg_a0)) self.assertTrue(self.parser.is_gpr(reg_x5)) self.assertTrue(self.parser.is_gpr(reg_sp)) - + self.assertFalse(self.parser.is_gpr(reg_fa0)) self.assertFalse(self.parser.is_gpr(reg_f10)) - + self.assertTrue(self.parser.is_vector_register(reg_v1)) self.assertFalse(self.parser.is_vector_register(reg_x10)) self.assertFalse(self.parser.is_vector_register(reg_fa0))