From 7e546d970fb0626143b7c05d3e8fed06278d5c48 Mon Sep 17 00:00:00 2001
From: Metehan Dundar
Date: Tue, 4 Mar 2025 00:44:38 +0100
Subject: [PATCH] Parser for RISCV is implemented and tested with a simple kernel.

 Changes to be committed:
    modified:   osaca/parser/__init__.py
    new file:   osaca/parser/parser_RISCV.py
    new file:   tests/test_files/kernel_riscv.s
    new file:   tests/test_parser_RISCV.py
---
 osaca/parser/__init__.py        |   4 +
 osaca/parser/parser_RISCV.py    | 639 ++++++++++++++++++++++++++++++++
 tests/test_files/kernel_riscv.s | 120 ++++++
 tests/test_parser_RISCV.py      | 321 ++++++++++++++++
 4 files changed, 1084 insertions(+)
 create mode 100644 osaca/parser/parser_RISCV.py
 create mode 100644 tests/test_files/kernel_riscv.s
 create mode 100644 tests/test_parser_RISCV.py

diff --git a/osaca/parser/__init__.py b/osaca/parser/__init__.py
index 3b5e8ba..96218c7 100644
--- a/osaca/parser/__init__.py
+++ b/osaca/parser/__init__.py
@@ -7,6 +7,7 @@ Only the parser below will be exported, so please add new parsers to __all__.
 from .base_parser import BaseParser
 from .parser_x86att import ParserX86ATT
 from .parser_AArch64 import ParserAArch64
+from .parser_RISCV import ParserRISCV
 from .instruction_form import InstructionForm
 from .operand import Operand
 
@@ -16,6 +17,7 @@ __all__ = [
     "BaseParser",
     "ParserX86ATT",
     "ParserAArch64",
+    "ParserRISCV",
     "get_parser",
 ]
 
@@ -25,5 +27,7 @@ def get_parser(isa):
         return ParserX86ATT()
     elif isa.lower() == "aarch64":
         return ParserAArch64()
+    elif isa.lower() == "riscv":
+        return ParserRISCV()
     else:
         raise ValueError("Unknown ISA {!r}.".format(isa))
diff --git a/osaca/parser/parser_RISCV.py b/osaca/parser/parser_RISCV.py
new file mode 100644
index 0000000..8efda9a
--- /dev/null
+++ b/osaca/parser/parser_RISCV.py
@@ -0,0 +1,639 @@
+#!/usr/bin/env python3
+import re
+import os
+import logging
+from copy import deepcopy
+import pyparsing as pp
+
+from osaca.parser import BaseParser
+from osaca.parser.instruction_form import InstructionForm
+from osaca.parser.operand import Operand
+from osaca.parser.directive import DirectiveOperand
+from osaca.parser.memory import MemoryOperand
+from osaca.parser.label import LabelOperand
+from osaca.parser.register import RegisterOperand
+from osaca.parser.identifier import IdentifierOperand
+from osaca.parser.immediate import ImmediateOperand
+from osaca.parser.condition import ConditionOperand
+
+logger = logging.getLogger(__name__)
+
+
+class ParserRISCV(BaseParser):
+    """Parser for RISC-V assembly (integer, FP, CSR and vector register operands)."""
+
+    _instance = None
+
+    # Singleton pattern, as this is created very many times
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super(ParserRISCV, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self):
+        super().__init__()
+        self.isa = "riscv"
+
+    def construct_parser(self):
+        """Create parser for RISC-V ISA."""
+        # Comment - RISC-V uses # for comments
+        symbol_comment = "#"
+        self.comment = pp.Literal(symbol_comment) + pp.Group(
+            pp.ZeroOrMore(pp.Word(pp.printables))
+        ).setResultsName(self.comment_id)
+
+        # Define RISC-V assembly identifier
+        decimal_number = pp.Combine(
+            pp.Optional(pp.Literal("-")) + pp.Word(pp.nums)
+        ).setResultsName("value")
+        hex_number = pp.Combine(
+            pp.Optional(pp.Literal("-")) + pp.Literal("0x") + pp.Word(pp.hexnums)
+        ).setResultsName("value")
+
+        first = pp.Word(pp.alphas + "_.", exact=1)
+        rest = pp.Word(pp.alphanums + "_.")
+        identifier = pp.Group(
+            pp.Combine(first + pp.Optional(rest)).setResultsName("name")
+            + pp.Optional(
+                pp.Suppress(pp.Literal("+"))
+                + (hex_number | decimal_number).setResultsName("offset")
+            )
+        ).setResultsName(self.identifier)
+
+        # Label
+        self.label = pp.Group(
+            identifier.setResultsName("name") + pp.Literal(":") + pp.Optional(self.comment)
+        ).setResultsName(self.label_id)
+
+        # Directive
+        directive_option = pp.Combine(
+            pp.Word(pp.alphas + "#@.%", exact=1)
+            + pp.Optional(pp.Word(pp.printables + " ", excludeChars=","))
+        )
+
+        directive_parameter = (
+            pp.quotedString | directive_option | identifier | hex_number | decimal_number
+        )
+        commaSeparatedList = pp.delimitedList(pp.Optional(directive_parameter), delim=",")
+        self.directive = pp.Group(
+            pp.Literal(".")
+            + pp.Word(pp.alphanums + "_").setResultsName("name")
+            + (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName("parameters")
+            + pp.Optional(self.comment)
+        ).setResultsName(self.directive_id)
+
+        # LLVM-MCA markers
+        self.llvm_markers = pp.Group(
+            pp.Literal("#")
+            + pp.Combine(
+                pp.CaselessLiteral("LLVM-MCA-")
+                + (pp.CaselessLiteral("BEGIN") | pp.CaselessLiteral("END"))
+            )
+            + pp.Optional(self.comment)
+        ).setResultsName(self.comment_id)
+
+        ##############################
+        # Instructions
+        # Mnemonic
+        mnemonic = pp.Word(pp.alphanums + ".").setResultsName("mnemonic")
+
+        # Immediate:
+        # int: ^-?[0-9]+ | hex: ^0x[0-9a-fA-F]+
+        immediate = pp.Group(
+            (hex_number ^ decimal_number)
+            | identifier
+        ).setResultsName(self.immediate_id)
+
+        # Register:
+        # RISC-V has two main types of registers:
+        # 1. Integer registers (x0-x31 or ABI names)
+        # 2. Floating-point registers (f0-f31 or ABI names)
+
+        # Integer register ABI names
+        integer_reg_abi = (
+            pp.CaselessLiteral("zero") |
+            pp.CaselessLiteral("ra") |
+            pp.CaselessLiteral("sp") |
+            pp.CaselessLiteral("gp") |
+            pp.CaselessLiteral("tp") |
+            pp.CaselessLiteral("fp") |  # frame pointer, ABI alias of s0 (x8)
+            pp.Regex(r"[tas][0-9]+")  # t0-t6, a0-a7, s0-s11
+        ).setResultsName("name")
+
+        # Integer registers x0-x31
+        integer_reg_x = (
+            pp.CaselessLiteral("x").setResultsName("prefix")
+            + pp.Word(pp.nums).setResultsName("name")
+        )
+
+        # Floating point registers
+        fp_reg_abi = pp.Regex(r"f[tas][0-9]+").setResultsName("name")  # ft0-ft11, fa0-fa7, fs0-fs11
+
+        fp_reg_f = (
+            pp.CaselessLiteral("f").setResultsName("prefix")
+            + pp.Word(pp.nums).setResultsName("name")
+        )
+
+        # Control and status registers (CSRs)
+        csr_reg = pp.Combine(
+            pp.CaselessLiteral("csr") + pp.Word(pp.alphanums + "_")
+        ).setResultsName("name")
+
+        # Vector registers (for the "V" extension)
+        vector_reg = (
+            pp.CaselessLiteral("v").setResultsName("prefix")
+            + pp.Word(pp.nums).setResultsName("name")
+        )
+
+        # Combined register definition
+        register = pp.Group(
+            integer_reg_x | integer_reg_abi | fp_reg_f | fp_reg_abi | vector_reg | csr_reg
+        ).setResultsName(self.register_id)
+
+        self.register = register
+
+        # Memory addressing mode in RISC-V: offset(base_register)
+        memory = pp.Group(
+            pp.Optional(immediate.setResultsName("offset"))
+            + pp.Suppress(pp.Literal("("))
+            + register.setResultsName("base")
+            + pp.Suppress(pp.Literal(")"))
+        ).setResultsName(self.memory_id)
+
+        # Combine to instruction form
+        operand_first = pp.Group(
+            register ^ immediate ^ memory ^ identifier
+        )
+        operand_rest = pp.Group(
+            register ^ immediate ^ memory ^ identifier
+        )
+
+        # Vector instruction special parameters (e.g., e32, m4, ta, ma)
+        vector_param = pp.Word(pp.alphas + pp.nums)
+
+        # Handle additional vector parameters
+        additional_params = pp.ZeroOrMore(
+            pp.Suppress(pp.Literal(","))
+            + vector_param.setResultsName("vector_param", listAllMatches=True)
+        )
+
+        # Main instruction parser
+        self.instruction_parser = (
+            mnemonic
+            + pp.Optional(operand_first.setResultsName("operand1"))
+            + pp.Optional(pp.Suppress(pp.Literal(",")))
+            + pp.Optional(operand_rest.setResultsName("operand2"))
+            + pp.Optional(pp.Suppress(pp.Literal(",")))
+            + pp.Optional(operand_rest.setResultsName("operand3"))
+            + pp.Optional(pp.Suppress(pp.Literal(",")))
+            + pp.Optional(operand_rest.setResultsName("operand4"))
+            + pp.Optional(additional_params)  # For vector instructions with more params
+            + pp.Optional(self.comment)
+        )
+
+    def parse_line(self, line, line_number=None):
+        """
+        Parse line and return instruction form.
+
+        :param str line: line of assembly code
+        :param line_number: identifier of instruction form, defaults to None
+        :type line_number: int, optional
+        :return: `dict` -- parsed asm line (comment, label, directive or instruction form)
+        """
+        instruction_form = InstructionForm(
+            mnemonic=None,
+            operands=[],
+            directive_id=None,
+            comment_id=None,
+            label_id=None,
+            line=line,
+            line_number=line_number,
+        )
+        result = None
+
+        # 1. Parse comment
+        try:
+            result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict())
+            instruction_form.comment = " ".join(result[self.comment_id])
+        except pp.ParseException:
+            pass
+
+        # 1.2 check for llvm-mca marker
+        try:
+            result = self.process_operand(
+                self.llvm_markers.parseString(line, parseAll=True).asDict()
+            )
+            instruction_form.comment = " ".join(result[self.comment_id])
+        except pp.ParseException:
+            pass
+
+        # 2. Parse label
+        if result is None:
+            try:
+                # returns tuple with label operand and comment, if any
+                result = self.process_operand(self.label.parseString(line, parseAll=True).asDict())
+                instruction_form.label = result[0].name
+                if result[1] is not None:
+                    instruction_form.comment = " ".join(result[1])
+            except pp.ParseException:
+                pass
+
+        # 3. Parse directive
+        if result is None:
+            try:
+                # returns directive with label operand and comment, if any
+                result = self.process_operand(
+                    self.directive.parseString(line, parseAll=True).asDict()
+                )
+                instruction_form.directive = DirectiveOperand(
+                    name=result[0].name, parameters=result[0].parameters
+                )
+                if result[1] is not None:
+                    instruction_form.comment = " ".join(result[1])
+            except pp.ParseException:
+                pass
+
+        # 4. Parse instruction
+        if result is None:
+            try:
+                result = self.parse_instruction(line)
+            except (pp.ParseException, KeyError) as e:
+                raise ValueError(
+                    "Unable to parse {!r} on line {}".format(line, line_number)
+                ) from e
+            instruction_form.mnemonic = result.mnemonic
+            instruction_form.operands = result.operands
+            instruction_form.comment = result.comment
+
+        return instruction_form
+
+    def parse_instruction(self, instruction):
+        """
+        Parse instruction in asm line.
+
+        :param str instruction: Assembly line string.
+        :returns: `dict` -- parsed instruction form
+        """
+        # Special handling for vector instructions like vsetvli with many parameters
+        code, sep, trailing = instruction.partition("#")
+        code = code.strip()
+        if code.startswith("vsetvli"):
+            # Split only once so that all comma separated operands stay together
+            parts = code.split(None, 1)
+            # Bound up front: a bare mnemonic must not raise a NameError below
+            operands = []
+            operand_strings = parts[1].split(",") if len(parts) > 1 else []
+            for op in (op_string.strip() for op_string in operand_strings):
+                if not op:
+                    continue
+                try:
+                    # Reuse the register grammar to match the regular parsing path
+                    operands.append(
+                        self.process_operand(
+                            self.register.parseString(op, parseAll=True).asDict()
+                        )
+                    )
+                except pp.ParseException:
+                    # vtype settings such as e32, m4, ta, ma
+                    operands.append(IdentifierOperand(name=op))
+
+            return InstructionForm(
+                mnemonic=parts[0],
+                operands=operands,
+                comment_id=trailing.strip() if sep else None,
+            )
+
+        # Regular instruction parsing
+        try:
+            result = self.instruction_parser.parseString(instruction, parseAll=True).asDict()
+            operands = []
+            # Add operands to list
+            # Check first operand
+            if "operand1" in result:
+                operand = self.process_operand(result["operand1"])
+                operands.extend(operand) if isinstance(operand, list) else operands.append(operand)
+            # Check second operand
+            if "operand2" in result:
+                operand = self.process_operand(result["operand2"])
+                operands.extend(operand) if isinstance(operand, list) else operands.append(operand)
+            # Check third operand
+            if "operand3" in result:
+                operand = self.process_operand(result["operand3"])
+                operands.extend(operand) if isinstance(operand, list) else operands.append(operand)
+            # Check fourth operand
+            if "operand4" in result:
+                operand = self.process_operand(result["operand4"])
+                operands.extend(operand) if isinstance(operand, list) else operands.append(operand)
+
+            # Handle vector_param for vector instructions
+            if "vector_param" in result:
+                if isinstance(result["vector_param"], list):
+                    for param in result["vector_param"]:
+                        operands.append(IdentifierOperand(name=param))
+                else:
+                    operands.append(IdentifierOperand(name=result["vector_param"]))
+
+            return_dict = InstructionForm(
+                mnemonic=result["mnemonic"],
+                operands=operands,
+                comment_id=" ".join(result[self.comment_id]) if self.comment_id in result else None,
+            )
+            return return_dict
+
+        except Exception as e:
+            logger.debug(f"Error parsing instruction: {instruction} - {str(e)}")
+            # For special vector instructions or ones with % in them
+            if "%" in instruction or instruction.lstrip().startswith("v"):
+                parts = instruction.split("#")[0].strip().split(None, 1)
+                mnemonic = parts[0]
+                operands = []
+                if len(parts) > 1:
+                    operand_part = parts[1]
+                    operands_list = [op.strip() for op in operand_part.split(",")]
+                    for op in operands_list:
+                        # Process '%hi(data)' to 'data' for certain operands
+                        if op.startswith("%") and '(' in op and ')' in op:
+                            # Extract data from %hi(data) format
+                            data = op[op.index('(')+1:op.index(')')]
+                            operands.append(IdentifierOperand(name=data))
+                        else:
+                            operands.append(IdentifierOperand(name=op))
+
+                comment = None
+                if "#" in instruction:
+                    comment = instruction.split("#", 1)[1].strip()
+
+                return InstructionForm(
+                    mnemonic=mnemonic,
+                    operands=operands,
+                    comment_id=comment
+                )
+            else:
+                raise
+
+    def process_operand(self, operand):
+        """Post-process operand"""
+        # structure memory addresses
+        if self.memory_id in operand:
+            return self.process_memory_address(operand[self.memory_id])
+        # add value attribute to immediates
+        if self.immediate_id in operand:
+            return self.process_immediate(operand[self.immediate_id])
+        if self.label_id in operand:
+            return self.process_label(operand[self.label_id])
+        if self.identifier in operand:
+            return self.process_identifier(operand[self.identifier])
+        if self.register_id in operand:
+            return self.process_register_operand(operand[self.register_id])
+        if self.directive_id in operand:
+            return self.process_directive_operand(operand[self.directive_id])
+        return operand
+
+    def process_directive_operand(self, operand):
+        """Return a ``(DirectiveOperand, comment or None)`` tuple for a parsed directive"""
+        return (
+            DirectiveOperand(
+                name=operand["name"],
+                parameters=operand["parameters"] if "parameters" in operand else [],
+            ),
+            operand["comment"] if "comment" in operand else None,
+        )
+
+    def process_register_operand(self, operand):
+        """Process register operands, including ABI name to x-register mapping"""
+        # Handle ABI names by adding the appropriate prefix
+        if "prefix" not in operand:
+            name = operand["name"].lower()
+            # Integer register ABI names
+            if name in ["zero", "ra", "sp", "gp", "tp", "fp"] or name[0] in ["t", "a", "s"]:
+                prefix = "x"
+            # Floating point register ABI names
+            elif name[0] == "f" and name[1] in ["t", "a", "s"]:
+                prefix = "f"
+            # CSR registers
+            elif name.startswith("csr"):
+                prefix = ""
+            else:
+                prefix = ""
+
+            return RegisterOperand(
+                prefix=prefix,
+                name=name
+            )
+        else:
+            return RegisterOperand(
+                prefix=operand["prefix"].lower(),
+                name=operand["name"]
+            )
+
+    def process_memory_address(self, memory_address):
+        """Post-process memory address operand"""
+        # Process offset
+        offset = memory_address.get("offset", None)
+        if isinstance(offset, list) and len(offset) == 1:
+            offset = offset[0]
+        if offset is not None and "value" in offset:
+            offset = ImmediateOperand(value=int(offset["value"], 0))
+        if isinstance(offset, dict) and "identifier" in offset:
+            offset = self.process_identifier(offset["identifier"])
+
+        # Process base register
+        base = memory_address.get("base", None)
+        if base is not None:
+            base = self.process_register_operand(base)
+
+        # Create memory operand
+        return MemoryOperand(
+            offset=offset,
+            base=base,
+            index=None,
+            scale=1
+        )
+
+    def process_label(self, label):
+        """Post-process label asm line"""
+        return (
+            LabelOperand(name=label["name"]["name"]),
+            label["comment"] if self.comment_id in label else None,
+        )
+
+    def process_identifier(self, identifier):
+        """Post-process identifier operand"""
+        return IdentifierOperand(
+            name=identifier["name"] if "name" in identifier else None,
+            offset=identifier["offset"] if "offset" in identifier else None
+        )
+
+    def process_immediate(self, immediate):
+        """Post-process immediate operand"""
+        if "identifier" in immediate:
+            # actually an identifier, change declaration
+            return self.process_identifier(immediate["identifier"])
+        if "value" in immediate:
+            # normal integer value
+            immediate["type"] = "int"
+            # convert hex/bin immediates to dec
+            new_immediate = ImmediateOperand(imd_type=immediate["type"], value=immediate["value"])
+            new_immediate.value = self.normalize_imd(new_immediate)
+            return new_immediate
+        return immediate
+
+    def get_full_reg_name(self, register):
+        """Return one register name string including all attributes"""
+        if register.prefix and register.name:
+            return register.prefix + str(register.name)
+        return str(register.name)
+
+    def normalize_imd(self, imd):
+        """Normalize immediate to decimal based representation"""
+        if isinstance(imd, IdentifierOperand):
+            return imd
+        elif imd.value is not None:
+            if isinstance(imd.value, str):
+                # hex or bin, return decimal
+                return int(imd.value, 0)
+            else:
+                return imd.value
+        # identifier
+        return imd
+
+    def parse_register(self, register_string):
+        """
+        Parse register string and return register dictionary.
+
+        :param str register_string: register representation as string
+        :returns: dict with register info
+        """
+        # Remove any leading/trailing whitespace
+        register_string = register_string.strip()
+
+        # Check for integer registers (x0-x31)
+        x_match = re.match(r'^x([0-9]|[1-2][0-9]|3[0-1])$', register_string)
+        if x_match:
+            reg_num = int(x_match.group(1))
+            return {"class": "register", "register": {"prefix": "x", "name": str(reg_num)}}
+
+        # Check for floating-point registers (f0-f31)
+        f_match = re.match(r'^f([0-9]|[1-2][0-9]|3[0-1])$', register_string)
+        if f_match:
+            reg_num = int(f_match.group(1))
+            return {"class": "register", "register": {"prefix": "f", "name": str(reg_num)}}
+
+        # Check for vector registers (v0-v31)
+        v_match = re.match(r'^v([0-9]|[1-2][0-9]|3[0-1])$', register_string)
+        if v_match:
+            reg_num = int(v_match.group(1))
+            return {"class": "register", "register": {"prefix": "v", "name": str(reg_num)}}
+
+        # Check for ABI names
+        abi_names = {
+            "zero": 0, "ra": 1, "sp": 2, "gp": 3, "tp": 4,
+            "t0": 5, "t1": 6, "t2": 7,
+            "s0": 8, "fp": 8, "s1": 9,
+            "a0": 10, "a1": 11, "a2": 12, "a3": 13, "a4": 14, "a5": 15, "a6": 16, "a7": 17,
+            "s2": 18, "s3": 19, "s4": 20, "s5": 21, "s6": 22, "s7": 23, "s8": 24, "s9": 25, "s10": 26, "s11": 27,
+            "t3": 28, "t4": 29, "t5": 30, "t6": 31
+        }
+
+        if register_string in abi_names:
+            return {"class": "register", "register": {"prefix": "", "name": register_string}}
+
+        # If no match is found
+        return None
+
+    def is_gpr(self, register):
+        """Check if register is a general purpose register"""
+        # Integer registers: x0-x31 or ABI names
+        if register.prefix == "x":
+            return True
+        if not register.prefix and register.name in ["zero", "ra", "sp", "gp", "tp", "fp"]:
+            return True
+        if not register.prefix and register.name[0] in ["t", "a", "s"]:
+            return True
+        return False
+
+    def is_vector_register(self, register):
+        """Check if register is a vector register"""
+        # Vector registers: v0-v31
+        if register.prefix == "v":
+            return True
+        return False
+
+    def is_flag_dependend_of(self, flag_a, flag_b):
+        """Check if ``flag_a`` is dependent on ``flag_b``"""
+        # RISC-V doesn't have explicit flags like x86 or AArch64
+        return flag_a.name == flag_b.name
+
+    def is_reg_dependend_of(self, reg_a, reg_b):
+        """Check if ``reg_a`` is dependent on ``reg_b``"""
+        if not isinstance(reg_a, Operand):
+            reg_a = RegisterOperand(name=reg_a["name"])
+
+        # Get canonical register names
+        reg_a_canonical = self._get_canonical_reg_name(reg_a)
+        reg_b_canonical = self._get_canonical_reg_name(reg_b)
+
+        # Same register type and number means dependency
+        return reg_a_canonical == reg_b_canonical
+
+    def _get_canonical_reg_name(self, register):
+        """Get the canonical form of a register (x-form for integer, f-form for FP)"""
+        # If already in canonical form (x# or f#)
+        if register.prefix in ["x", "f", "v"] and register.name.isdigit():
+            return f"{register.prefix}{register.name}"
+
+        # ABI name mapping for integer registers
+        abi_to_x = {
+            "zero": "x0", "ra": "x1", "sp": "x2", "gp": "x3", "tp": "x4",
+            "t0": "x5", "t1": "x6", "t2": "x7",
+            "s0": "x8", "fp": "x8", "s1": "x9",
+            "a0": "x10", "a1": "x11", "a2": "x12", "a3": "x13",
+            "a4": "x14", "a5": "x15", "a6": "x16", "a7": "x17",
+            "s2": "x18", "s3": "x19", "s4": "x20", "s5": "x21",
+            "s6": "x22", "s7": "x23", "s8": "x24", "s9": "x25",
+            "s10": "x26", "s11": "x27",
+            "t3": "x28", "t4": "x29", "t5": "x30", "t6": "x31"
+        }
+
+        # For integer register ABI names
+        name = register.name.lower()
+        if name in abi_to_x:
+            return abi_to_x[name]
+
+        # For FP register ABI names like fa0, fs1, etc. (only with a numeric suffix)
+        if name.startswith("f") and name[2:].isdigit():
+            if name[1] == "a":  # fa0-fa7
+                idx = int(name[2:])
+                return f"f{idx + 10}"
+            elif name[1] == "s":  # fs0-fs11
+                idx = int(name[2:])
+                if idx <= 1:
+                    return f"f{idx + 8}"
+                else:
+                    return f"f{idx + 16}"
+            elif name[1] == "t":  # ft0-ft11
+                idx = int(name[2:])
+                if idx <= 7:
+                    return f"f{idx}"
+                else:
+                    return f"f{idx + 20}"
+
+        # Return as is if no mapping found (a missing prefix must not render as "None")
+        return f"{register.prefix or ''}{register.name}"
+
+    def get_reg_type(self, register):
+        """Get register type"""
+        # Return register prefix if exists
+        if register.prefix:
+            return register.prefix
+
+        # Determine type from ABI name
+        name = register.name.lower()
+        if name in ["zero", "ra", "sp", "gp", "tp", "fp"] or name[0] in ["t", "a", "s"]:
+            return "x"  # Integer register
+        elif name.startswith("f"):
+            return "f"  # Floating point register
+        elif name.startswith("csr"):
+            return "csr"  # Control and Status Register
+
+        return "unknown"
\ No newline at end of file
diff --git a/tests/test_files/kernel_riscv.s b/tests/test_files/kernel_riscv.s
new file mode 100644
index 0000000..cb6845e
--- /dev/null
+++ b/tests/test_files/kernel_riscv.s
@@ -0,0 +1,120 @@
+# Basic RISC-V test kernel with various instructions
+
+.text
+.globl vector_add
+.align 2
+
+# Example of a basic function
+vector_add:
+    # Prologue
+    addi sp, sp, -16
+    sw ra, 12(sp)
+    sw s0, 8(sp)
+    addi s0, sp, 16
+
+    # Setup
+    mv a3, a0
+    lw a0, 0(a0) # Load first element
+    lw a4, 0(a1) # Load second element
+    add a0, a0, a4 # Add elements
+    sw a0, 0(a2) # Store to result array
+
+    # Integer operations
+    addi t0, zero, 10
+    addi t1, zero, 5
+    add t2, t0, t1
+    sub t3, t0, t1
+    and t4, t0, t1
+    or t5, t0, t1
+    xor t6, t0, t1
+    sll a0, t0, t1
+    srl a1, t0, t1
+    sra a2, t0, t1
+
+    # Memory operations
+    lw a0, 8(sp)
+    sw a1, 4(sp)
+    lbu a2, 1(sp)
+    sb a3, 0(sp)
+    lh a4, 2(sp)
+    sh a5, 2(sp)
+
+    # Branch and jump instructions
+    beq t0, t1, skip
+    bne t0, t1, continue
+    jal ra, function
+    jalr t0, 0(ra)
+
+.L1: # Loop Header
+    beq t0, t1, .L2
+    addi t0, t0, 1
+    j .L1
+
+.L2:
+    # Floating point operations
+    flw fa0, 0(a0)
+    flw fa1, 4(a0)
+    fadd.s fa2, fa0, fa1
+    fsub.s fa3, fa0, fa1
+    fmv.x.w a0, fa0
+    fmv.w.x fa4, a0
+
+    # CSR operations
+    csrr t0, mstatus
+    csrw mtvec, t0
+    csrs mie, t0
+    csrc mip, t0
+
+    # Vector instructions (RVV)
+    vsetvli t0, a0, e32, m4, ta, ma
+    vle32.v v0, (a0)
+    vle32.v v4, (a1)
+    vadd.vv v8, v0, v4
+    vse32.v v8, (a2)
+
+    # Atomic operations
+    lr.w t0, (a0)
+    sc.w t1, t2, (a0)
+    amoswap.w t3, t4, (a0)
+    amoadd.w t5, t6, (a0)
+
+    # Multiply/divide instructions
+    mul t0, t1, t2
+    mulh t3, t4, t5
+    div t0, t1, t2
+    rem t3, t4, t5
+
+    # Pseudo-instructions
+    li t0, 1234
+    la t1, data
+    li a0, %hi(data)
+    addi a1, a0, %lo(data)
+
+skip:
+    # Skip destination
+    addi t2, zero, 20
+
+continue:
+    # Continue destination
+    addi t3, zero, 30
+
+function:
+    # Function destination
+    addi a0, zero, 0
+    ret
+
+    # Epilogue
+    lw ra, 12(sp)
+    lw s0, 8(sp)
+    addi sp, sp, 16
+    ret
+
+.data
+.align 4
+data:
+    .word 0x12345678
+    .byte 0x01, 0x02, 0x03, 0x04
+    .half 0xABCD, 0xEF01
+    .float 3.14159
+    .space 16
+    .ascii "RISC-V Test String"
\ No newline at end of file
diff --git a/tests/test_parser_RISCV.py b/tests/test_parser_RISCV.py
new file mode 100644
index 0000000..e5a696c
--- /dev/null
+++ b/tests/test_parser_RISCV.py
@@ -0,0 +1,321 @@
+#!/usr/bin/env python3
+"""
+Unit tests for RISC-V assembly parser
+"""
+
+import os
+import unittest
+
+from pyparsing import ParseException
+
+from osaca.parser import ParserRISCV, InstructionForm
+from osaca.parser.directive import DirectiveOperand
+from osaca.parser.memory import MemoryOperand
+from osaca.parser.register import RegisterOperand
+from osaca.parser.immediate import ImmediateOperand
+from osaca.parser.identifier import IdentifierOperand
+
+
+class TestParserRISCV(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        self.parser = ParserRISCV()
+        with open(self._find_file("kernel_riscv.s")) as f:
+            self.riscv_code = f.read()
+
+    ##################
+    # Test
+    ##################
+
+    def test_comment_parser(self):
+        self.assertEqual(self._get_comment(self.parser, "# some comments"), "some comments")
+        self.assertEqual(
+            self._get_comment(self.parser, "\t\t# RISC-V comment \t end \t"), "RISC-V comment end"
+        )
+        self.assertEqual(
+            self._get_comment(self.parser, "\t## comment ## comment"),
+            "# comment ## comment",
+        )
+
+    def test_label_parser(self):
+        self.assertEqual(self._get_label(self.parser, "main:")[0].name, "main")
+        self.assertEqual(self._get_label(self.parser, "loop_start:")[0].name, "loop_start")
+        self.assertEqual(self._get_label(self.parser, ".L1:\t\t\t# comment")[0].name, ".L1")
+        self.assertEqual(
+            " ".join(self._get_label(self.parser, ".L1:\t\t\t# comment")[1]),
+            "comment",
+        )
+        with self.assertRaises(ParseException):
+            self._get_label(self.parser, "\t.cfi_startproc")
+
+    def test_directive_parser(self):
+        self.assertEqual(self._get_directive(self.parser, "\t.text")[0].name, "text")
+        self.assertEqual(len(self._get_directive(self.parser, "\t.text")[0].parameters), 0)
+        self.assertEqual(self._get_directive(self.parser, "\t.align\t4")[0].name, "align")
+        self.assertEqual(
+            len(self._get_directive(self.parser, "\t.align\t4")[0].parameters), 1
+        )
+        self.assertEqual(
+            self._get_directive(self.parser, "\t.align\t4")[0].parameters[0], "4"
+        )
+        self.assertEqual(
+            self._get_directive(self.parser, " .byte 100,103,144 # IACA START")[
+                0
+            ].name,
+            "byte",
+        )
+        self.assertEqual(
+            self._get_directive(self.parser, " .byte 100,103,144 # IACA START")[
+                0
+            ].parameters[2],
+            "144",
+        )
+        self.assertEqual(
+            " ".join(
+                self._get_directive(self.parser, " .byte 100,103,144 # IACA START")[1]
+            ),
+            "IACA START",
+        )
+
+    def test_parse_instruction(self):
+        instr1 = "addi t0, zero, 1"
+        instr2 = "lw a0, 8(sp)"
+        instr3 = "beq t0, t1, loop_start"
+        instr4 = "lui a0, %hi(data)"
+        instr5 = "sw ra, -4(sp)"
+        instr6 = "jal ra, function"
+
+        parsed_1 = self.parser.parse_instruction(instr1)
+        parsed_2 = self.parser.parse_instruction(instr2)
+        parsed_3 = self.parser.parse_instruction(instr3)
+        parsed_4 = self.parser.parse_instruction(instr4)
+        parsed_5 = self.parser.parse_instruction(instr5)
+        parsed_6 = self.parser.parse_instruction(instr6)
+
+        # Verify addi instruction
+        self.assertEqual(parsed_1.mnemonic, "addi")
+        self.assertEqual(parsed_1.operands[0].name, "t0")
+        self.assertEqual(parsed_1.operands[1].name, "zero")
+        self.assertEqual(parsed_1.operands[2].value, 1)
+
+        # Verify lw instruction
+        self.assertEqual(parsed_2.mnemonic, "lw")
+        self.assertEqual(parsed_2.operands[0].name, "a0")
+        self.assertEqual(parsed_2.operands[1].offset.value, 8)
+        self.assertEqual(parsed_2.operands[1].base.name, "sp")
+
+        # Verify beq instruction
+        self.assertEqual(parsed_3.mnemonic, "beq")
+        self.assertEqual(parsed_3.operands[0].name, "t0")
+        self.assertEqual(parsed_3.operands[1].name, "t1")
+        self.assertEqual(parsed_3.operands[2].name, "loop_start")
+
+        # Verify lui instruction with high bits relocation
+        self.assertEqual(parsed_4.mnemonic, "lui")
+        self.assertEqual(parsed_4.operands[0].name, "a0")
+        self.assertEqual(parsed_4.operands[1].name, "data")
+
+        # Verify sw instruction with negative offset
+        self.assertEqual(parsed_5.mnemonic, "sw")
+        self.assertEqual(parsed_5.operands[0].name, "ra")
+        self.assertEqual(parsed_5.operands[1].offset.value, -4)
+        self.assertEqual(parsed_5.operands[1].base.name, "sp")
+
+        # Verify jal instruction
+        self.assertEqual(parsed_6.mnemonic, "jal")
+        self.assertEqual(parsed_6.operands[0].name, "ra")
+        self.assertEqual(parsed_6.operands[1].name, "function")
+
+    def test_parse_line(self):
+        line_comment = "# -- Begin main"
+        line_label = ".LBB0_1: # Loop Header"
+        line_directive = ".cfi_def_cfa sp, 0"
+        line_instruction = "addi sp, sp, -16 # allocate stack frame"
+
+        instruction_form_1 = InstructionForm(
+            mnemonic=None,
+            operands=[],
+            directive_id=None,
+            comment_id="-- Begin main",
+            label_id=None,
+            line="# -- Begin main",
+            line_number=1,
+        )
+
+        instruction_form_2 = InstructionForm(
+            mnemonic=None,
+            operands=[],
+            directive_id=None,
+            comment_id="Loop Header",
+            label_id=".LBB0_1",
+            line=".LBB0_1: # Loop Header",
+            line_number=2,
+        )
+
+        instruction_form_3 = InstructionForm(
+            mnemonic=None,
+            operands=[],
+            directive_id=DirectiveOperand(name="cfi_def_cfa", parameters=["sp", "0"]),
+            comment_id=None,
+            label_id=None,
+            line=".cfi_def_cfa sp, 0",
+            line_number=3,
+        )
+
+        instruction_form_4 = InstructionForm(
+            mnemonic="addi",
+            operands=[
+                RegisterOperand(prefix="x", name="sp"),
+                RegisterOperand(prefix="x", name="sp"),
+                ImmediateOperand(value=-16, imd_type="int"),
+            ],
+            directive_id=None,
+            comment_id="allocate stack frame",
+            label_id=None,
+            line="addi sp, sp, -16 # allocate stack frame",
+            line_number=4,
+        )
+
+        parsed_1 = self.parser.parse_line(line_comment, 1)
+        parsed_2 = self.parser.parse_line(line_label, 2)
+        parsed_3 = self.parser.parse_line(line_directive, 3)
+        parsed_4 = self.parser.parse_line(line_instruction, 4)
+
+        self.assertEqual(parsed_1.comment, instruction_form_1.comment)
+        self.assertEqual(parsed_2.label, instruction_form_2.label)
+        self.assertEqual(parsed_3.directive.name, instruction_form_3.directive.name)
+        self.assertEqual(parsed_3.directive.parameters, instruction_form_3.directive.parameters)
+        self.assertEqual(parsed_4.mnemonic, instruction_form_4.mnemonic)
+        self.assertEqual(parsed_4.operands[0].name, instruction_form_4.operands[0].name)
+        self.assertEqual(parsed_4.operands[2].value, instruction_form_4.operands[2].value)
+        self.assertEqual(parsed_4.comment, instruction_form_4.comment)
+
+    def test_parse_file(self):
+        parsed = self.parser.parse_file(self.riscv_code)
+        self.assertEqual(parsed[0].line_number, 1)
+        self.assertGreater(len(parsed), 80)  # More than 80 lines should be parsed
+
+        # Test parsing specific parts of the file
+        # Find vector_add label
+        vector_add_idx = next((i for i, instr in enumerate(parsed) if instr.label == "vector_add"), None)
+        self.assertIsNotNone(vector_add_idx)
+
+        # Find floating-point instructions
+        flw_idx = next((i for i, instr in enumerate(parsed) if instr.mnemonic == "flw"), None)
+        self.assertIsNotNone(flw_idx)
+
+        # Find vector instructions
+        vle_idx = next((i for i, instr in enumerate(parsed) if instr.mnemonic and instr.mnemonic.startswith("vle")), None)
+        self.assertIsNotNone(vle_idx)
+
+        # Find CSR instructions
+        csr_idx = next((i for i, instr in enumerate(parsed) if instr.mnemonic == "csrr"), None)
+        self.assertIsNotNone(csr_idx)
+
+    def test_register_mapping(self):
+        # Test ABI name to register number mapping
+        reg_zero = RegisterOperand(name="zero")
+        reg_ra = RegisterOperand(name="ra")
+        reg_sp = RegisterOperand(name="sp")
+        reg_a0 = RegisterOperand(name="a0")
+        reg_t1 = RegisterOperand(name="t1")
+        reg_s2 = RegisterOperand(name="s2")
+
+        reg_x0 = RegisterOperand(prefix="x", name="0")
+        reg_x1 = RegisterOperand(prefix="x", name="1")
+        reg_x2 = RegisterOperand(prefix="x", name="2")
+        reg_x10 = RegisterOperand(prefix="x", name="10")
+        reg_x6 = RegisterOperand(prefix="x", name="6")
+        reg_x18 = RegisterOperand(prefix="x", name="18")
+
+        # Test canonical name conversion
+        self.assertEqual(self.parser._get_canonical_reg_name(reg_zero), "x0")
+        self.assertEqual(self.parser._get_canonical_reg_name(reg_ra), "x1")
+        self.assertEqual(self.parser._get_canonical_reg_name(reg_sp), "x2")
+        self.assertEqual(self.parser._get_canonical_reg_name(reg_a0), "x10")
+        self.assertEqual(self.parser._get_canonical_reg_name(reg_t1), "x6")
+        self.assertEqual(self.parser._get_canonical_reg_name(reg_s2), "x18")
+
+        # Test register dependency
+        self.assertTrue(self.parser.is_reg_dependend_of(reg_zero, reg_x0))
+        self.assertTrue(self.parser.is_reg_dependend_of(reg_ra, reg_x1))
+        self.assertTrue(self.parser.is_reg_dependend_of(reg_sp, reg_x2))
+        self.assertTrue(self.parser.is_reg_dependend_of(reg_a0, reg_x10))
+        self.assertTrue(self.parser.is_reg_dependend_of(reg_t1, reg_x6))
+        self.assertTrue(self.parser.is_reg_dependend_of(reg_s2, reg_x18))
+
+        # Test non-dependent registers
+        self.assertFalse(self.parser.is_reg_dependend_of(reg_zero, reg_x1))
+        self.assertFalse(self.parser.is_reg_dependend_of(reg_ra, reg_x2))
+        self.assertFalse(self.parser.is_reg_dependend_of(reg_a0, reg_t1))
+
+    def test_normalize_imd(self):
+        imd_decimal = ImmediateOperand(value="42")
+        imd_hex = ImmediateOperand(value="0x2A")
+        imd_negative = ImmediateOperand(value="-12")
+        identifier = IdentifierOperand(name="function")
+
+        self.assertEqual(self.parser.normalize_imd(imd_decimal), 42)
+        self.assertEqual(self.parser.normalize_imd(imd_hex), 42)
+        self.assertEqual(self.parser.normalize_imd(imd_negative), -12)
+        self.assertEqual(self.parser.normalize_imd(identifier), identifier)
+
+    def test_is_gpr(self):
+        # Test integer registers
+        reg_x5 = RegisterOperand(prefix="x", name="5")
+        reg_t0 = RegisterOperand(name="t0")
+        reg_sp = RegisterOperand(name="sp")
+
+        # Test floating point registers
+        reg_f10 = RegisterOperand(prefix="f", name="10")
+        reg_fa0 = RegisterOperand(name="fa0")
+
+        # Test vector registers
+        reg_v3 = RegisterOperand(prefix="v", name="3")
+
+        self.assertTrue(self.parser.is_gpr(reg_x5))
+        self.assertTrue(self.parser.is_gpr(reg_t0))
+        self.assertTrue(self.parser.is_gpr(reg_sp))
+
+        self.assertFalse(self.parser.is_gpr(reg_f10))
+        self.assertFalse(self.parser.is_gpr(reg_fa0))
+        self.assertFalse(self.parser.is_gpr(reg_v3))
+
+    def test_is_vector_register(self):
+        reg_v3 = RegisterOperand(prefix="v", name="3")
+        reg_x5 = RegisterOperand(prefix="x", name="5")
+        reg_f10 = RegisterOperand(prefix="f", name="10")
+
+        self.assertTrue(self.parser.is_vector_register(reg_v3))
+        self.assertFalse(self.parser.is_vector_register(reg_x5))
+        self.assertFalse(self.parser.is_vector_register(reg_f10))
+
+    ##################
+    # Helper functions
+    ##################
+    def _get_comment(self, parser, comment):
+        return " ".join(
+            parser.process_operand(parser.comment.parseString(comment, parseAll=True).asDict())[
+                "comment"
+            ]
+        )
+
+    def _get_label(self, parser, label):
+        return parser.process_operand(parser.label.parseString(label, parseAll=True).asDict())
+
+    def _get_directive(self, parser, directive):
+        return parser.process_operand(
+            parser.directive.parseString(directive, parseAll=True).asDict()
+        )
+
+    @staticmethod
+    def _find_file(name):
+        testdir = os.path.dirname(__file__)
+        name = os.path.join(testdir, "test_files", name)
+        assert os.path.exists(name)
+        return name
+
+
+if __name__ == "__main__":
+    suite = unittest.TestLoader().loadTestsFromTestCase(TestParserRISCV)
+    unittest.TextTestRunner(verbosity=2).run(suite)
\ No newline at end of file