Parser for RISCV is implemented and tested with a

simple kernel.

Changes to be committed:
	modified:   osaca/parser/__init__.py
	new file:   osaca/parser/parser_RISCV.py
	new file:   tests/test_files/kernel_riscv.s
	new file:   tests/test_parser_RISCV.py
This commit is contained in:
Metehan Dundar
2025-03-04 00:44:38 +01:00
parent dffea6d066
commit 7e546d970f
4 changed files with 1088 additions and 0 deletions

View File

@@ -7,6 +7,7 @@ Only the parser below will be exported, so please add new parsers to __all__.
from .base_parser import BaseParser
from .parser_x86att import ParserX86ATT
from .parser_AArch64 import ParserAArch64
from .parser_RISCV import ParserRISCV
from .instruction_form import InstructionForm
from .operand import Operand
@@ -16,6 +17,7 @@ __all__ = [
"BaseParser",
"ParserX86ATT",
"ParserAArch64",
"ParserRISCV",
"get_parser",
]
@@ -25,5 +27,7 @@ def get_parser(isa):
return ParserX86ATT()
elif isa.lower() == "aarch64":
return ParserAArch64()
elif isa.lower() == "riscv":
return ParserRISCV()
else:
raise ValueError("Unknown ISA {!r}.".format(isa))

View File

@@ -0,0 +1,643 @@
#!/usr/bin/env python3
import re
import os
import logging
from copy import deepcopy
import pyparsing as pp
logger = logging.getLogger(__name__)
from osaca.parser import BaseParser
from osaca.parser.instruction_form import InstructionForm
from osaca.parser.operand import Operand
from osaca.parser.directive import DirectiveOperand
from osaca.parser.memory import MemoryOperand
from osaca.parser.label import LabelOperand
from osaca.parser.register import RegisterOperand
from osaca.parser.identifier import IdentifierOperand
from osaca.parser.immediate import ImmediateOperand
from osaca.parser.condition import ConditionOperand
class ParserRISCV(BaseParser):
_instance = None
# Singleton pattern, as this is created very many times
def __new__(cls):
if cls._instance is None:
cls._instance = super(ParserRISCV, cls).__new__(cls)
return cls._instance
def __init__(self):
super().__init__()
self.isa = "riscv"
def construct_parser(self):
"""Create parser for RISC-V ISA."""
# Comment - RISC-V uses # for comments
symbol_comment = "#"
self.comment = pp.Literal(symbol_comment) + pp.Group(
pp.ZeroOrMore(pp.Word(pp.printables))
).setResultsName(self.comment_id)
# Define RISC-V assembly identifier
decimal_number = pp.Combine(
pp.Optional(pp.Literal("-")) + pp.Word(pp.nums)
).setResultsName("value")
hex_number = pp.Combine(
pp.Optional(pp.Literal("-")) + pp.Literal("0x") + pp.Word(pp.hexnums)
).setResultsName("value")
# Additional identifiers used in vector instructions
vector_identifier = pp.Word(pp.alphas, pp.alphanums)
special_identifier = pp.Word(pp.alphas + "%")
first = pp.Word(pp.alphas + "_.", exact=1)
rest = pp.Word(pp.alphanums + "_.")
identifier = pp.Group(
pp.Combine(first + pp.Optional(rest)).setResultsName("name")
+ pp.Optional(
pp.Suppress(pp.Literal("+"))
+ (hex_number | decimal_number).setResultsName("offset")
)
).setResultsName(self.identifier)
# Label
self.label = pp.Group(
identifier.setResultsName("name") + pp.Literal(":") + pp.Optional(self.comment)
).setResultsName(self.label_id)
# Directive
directive_option = pp.Combine(
pp.Word(pp.alphas + "#@.%", exact=1)
+ pp.Optional(pp.Word(pp.printables + " ", excludeChars=","))
)
# For vector instructions
vector_parameter = pp.Word(pp.alphas)
directive_parameter = (
pp.quotedString | directive_option | identifier | hex_number | decimal_number
)
commaSeparatedList = pp.delimitedList(pp.Optional(directive_parameter), delim=",")
self.directive = pp.Group(
pp.Literal(".")
+ pp.Word(pp.alphanums + "_").setResultsName("name")
+ (pp.OneOrMore(directive_parameter) ^ commaSeparatedList).setResultsName("parameters")
+ pp.Optional(self.comment)
).setResultsName(self.directive_id)
# LLVM-MCA markers
self.llvm_markers = pp.Group(
pp.Literal("#")
+ pp.Combine(
pp.CaselessLiteral("LLVM-MCA-")
+ (pp.CaselessLiteral("BEGIN") | pp.CaselessLiteral("END"))
)
+ pp.Optional(self.comment)
).setResultsName(self.comment_id)
##############################
# Instructions
# Mnemonic
mnemonic = pp.Word(pp.alphanums + ".").setResultsName("mnemonic")
# Immediate:
# int: ^-?[0-9]+ | hex: ^0x[0-9a-fA-F]+
immediate = pp.Group(
(hex_number ^ decimal_number)
| identifier
).setResultsName(self.immediate_id)
# Register:
# RISC-V has two main types of registers:
# 1. Integer registers (x0-x31 or ABI names)
# 2. Floating-point registers (f0-f31 or ABI names)
# Integer register ABI names
integer_reg_abi = (
pp.CaselessLiteral("zero") |
pp.CaselessLiteral("ra") |
pp.CaselessLiteral("sp") |
pp.CaselessLiteral("gp") |
pp.CaselessLiteral("tp") |
pp.Regex(r"[tas][0-9]+") # t0-t6, a0-a7, s0-s11
).setResultsName("name")
# Integer registers x0-x31
integer_reg_x = (
pp.CaselessLiteral("x").setResultsName("prefix") +
pp.Word(pp.nums).setResultsName("name")
)
# Floating point registers
fp_reg_abi = pp.Regex(r"f[tas][0-9]+").setResultsName("name") # ft0-ft11, fa0-fa7, fs0-fs11
fp_reg_f = (
pp.CaselessLiteral("f").setResultsName("prefix") +
pp.Word(pp.nums).setResultsName("name")
)
# Control and status registers (CSRs)
csr_reg = pp.Combine(
pp.CaselessLiteral("csr") + pp.Word(pp.alphanums + "_")
).setResultsName("name")
# Vector registers (for the "V" extension)
vector_reg = (
pp.CaselessLiteral("v").setResultsName("prefix") +
pp.Word(pp.nums).setResultsName("name")
)
# Combined register definition
register = pp.Group(
integer_reg_x | integer_reg_abi | fp_reg_f | fp_reg_abi | vector_reg | csr_reg
).setResultsName(self.register_id)
self.register = register
# Memory addressing mode in RISC-V: offset(base_register)
memory = pp.Group(
pp.Optional(immediate.setResultsName("offset"))
+ pp.Suppress(pp.Literal("("))
+ register.setResultsName("base")
+ pp.Suppress(pp.Literal(")"))
).setResultsName(self.memory_id)
# Combine to instruction form
operand_first = pp.Group(
register ^ immediate ^ memory ^ identifier
)
operand_rest = pp.Group(
register ^ immediate ^ memory ^ identifier
)
# Vector instruction special parameters (e.g., e32, m4, ta, ma)
vector_param = pp.Word(pp.alphas + pp.nums)
# Handle additional vector parameters
additional_params = pp.ZeroOrMore(
pp.Suppress(pp.Literal(",")) +
vector_param.setResultsName("vector_param", listAllMatches=True)
)
# Main instruction parser
self.instruction_parser = (
mnemonic
+ pp.Optional(operand_first.setResultsName("operand1"))
+ pp.Optional(pp.Suppress(pp.Literal(",")))
+ pp.Optional(operand_rest.setResultsName("operand2"))
+ pp.Optional(pp.Suppress(pp.Literal(",")))
+ pp.Optional(operand_rest.setResultsName("operand3"))
+ pp.Optional(pp.Suppress(pp.Literal(",")))
+ pp.Optional(operand_rest.setResultsName("operand4"))
+ pp.Optional(additional_params) # For vector instructions with more params
+ pp.Optional(self.comment)
)
def parse_line(self, line, line_number=None):
"""
Parse line and return instruction form.
:param str line: line of assembly code
:param line_number: identifier of instruction form, defaults to None
:type line_number: int, optional
:return: `dict` -- parsed asm line (comment, label, directive or instruction form)
"""
instruction_form = InstructionForm(
mnemonic=None,
operands=[],
directive_id=None,
comment_id=None,
label_id=None,
line=line,
line_number=line_number,
)
result = None
# 1. Parse comment
try:
result = self.process_operand(self.comment.parseString(line, parseAll=True).asDict())
instruction_form.comment = " ".join(result[self.comment_id])
except pp.ParseException:
pass
# 1.2 check for llvm-mca marker
try:
result = self.process_operand(
self.llvm_markers.parseString(line, parseAll=True).asDict()
)
instruction_form.comment = " ".join(result[self.comment_id])
except pp.ParseException:
pass
# 2. Parse label
if result is None:
try:
# returns tuple with label operand and comment, if any
result = self.process_operand(self.label.parseString(line, parseAll=True).asDict())
instruction_form.label = result[0].name
if result[1] is not None:
instruction_form.comment = " ".join(result[1])
except pp.ParseException:
pass
# 3. Parse directive
if result is None:
try:
# returns directive with label operand and comment, if any
result = self.process_operand(
self.directive.parseString(line, parseAll=True).asDict()
)
instruction_form.directive = DirectiveOperand(
name=result[0].name, parameters=result[0].parameters
)
if result[1] is not None:
instruction_form.comment = " ".join(result[1])
except pp.ParseException:
pass
# 4. Parse instruction
if result is None:
try:
result = self.parse_instruction(line)
except (pp.ParseException, KeyError) as e:
raise ValueError(
"Unable to parse {!r} on line {}".format(line, line_number)
) from e
instruction_form.mnemonic = result.mnemonic
instruction_form.operands = result.operands
instruction_form.comment = result.comment
return instruction_form
def parse_instruction(self, instruction):
"""
Parse instruction in asm line.
:param str instruction: Assembly line string.
:returns: `dict` -- parsed instruction form
"""
# Special handling for vector instructions like vsetvli with many parameters
if instruction.startswith("vsetvli"):
parts = instruction.split("#")[0].strip().split()
mnemonic = parts[0]
# Split operands by commas
if len(parts) > 1:
operand_part = parts[1]
operands_list = [op.strip() for op in operand_part.split(",")]
# Process each operand
operands = []
for op in operands_list:
if op.startswith("x") or op in ["zero", "ra", "sp", "gp", "tp"] or re.match(r"[tas][0-9]+", op):
operands.append(RegisterOperand(name=op))
elif op in ["e8", "e16", "e32", "e64", "m1", "m2", "m4", "m8", "ta", "tu", "ma", "mu"]:
operands.append(IdentifierOperand(name=op))
else:
operands.append(IdentifierOperand(name=op))
# Get comment if present
comment = None
if "#" in instruction:
comment = instruction.split("#", 1)[1].strip()
return InstructionForm(
mnemonic=mnemonic,
operands=operands,
comment_id=comment
)
# Regular instruction parsing
try:
result = self.instruction_parser.parseString(instruction, parseAll=True).asDict()
operands = []
# Add operands to list
# Check first operand
if "operand1" in result:
operand = self.process_operand(result["operand1"])
operands.extend(operand) if isinstance(operand, list) else operands.append(operand)
# Check second operand
if "operand2" in result:
operand = self.process_operand(result["operand2"])
operands.extend(operand) if isinstance(operand, list) else operands.append(operand)
# Check third operand
if "operand3" in result:
operand = self.process_operand(result["operand3"])
operands.extend(operand) if isinstance(operand, list) else operands.append(operand)
# Check fourth operand
if "operand4" in result:
operand = self.process_operand(result["operand4"])
operands.extend(operand) if isinstance(operand, list) else operands.append(operand)
# Handle vector_param for vector instructions
if "vector_param" in result:
if isinstance(result["vector_param"], list):
for param in result["vector_param"]:
operands.append(IdentifierOperand(name=param))
else:
operands.append(IdentifierOperand(name=result["vector_param"]))
return_dict = InstructionForm(
mnemonic=result["mnemonic"],
operands=operands,
comment_id=" ".join(result[self.comment_id]) if self.comment_id in result else None,
)
return return_dict
except Exception as e:
logger.debug(f"Error parsing instruction: {instruction} - {str(e)}")
# For special vector instructions or ones with % in them
if "%" in instruction or instruction.startswith("v"):
parts = instruction.split("#")[0].strip().split(None, 1)
mnemonic = parts[0]
operands = []
if len(parts) > 1:
operand_part = parts[1]
operands_list = [op.strip() for op in operand_part.split(",")]
for op in operands_list:
# Process '%hi(data)' to 'data' for certain operands
if op.startswith("%") and '(' in op and ')' in op:
# Extract data from %hi(data) format
data = op[op.index('(')+1:op.index(')')]
operands.append(IdentifierOperand(name=data))
else:
operands.append(IdentifierOperand(name=op))
comment = None
if "#" in instruction:
comment = instruction.split("#", 1)[1].strip()
return InstructionForm(
mnemonic=mnemonic,
operands=operands,
comment_id=comment
)
else:
raise
def process_operand(self, operand):
"""Post-process operand"""
# structure memory addresses
if self.memory_id in operand:
return self.process_memory_address(operand[self.memory_id])
# add value attribute to immediates
if self.immediate_id in operand:
return self.process_immediate(operand[self.immediate_id])
if self.label_id in operand:
return self.process_label(operand[self.label_id])
if self.identifier in operand:
return self.process_identifier(operand[self.identifier])
if self.register_id in operand:
return self.process_register_operand(operand[self.register_id])
if self.directive_id in operand:
return self.process_directive_operand(operand[self.directive_id])
return operand
def process_directive_operand(self, operand):
return (
DirectiveOperand(
name=operand["name"],
parameters=operand["parameters"],
),
operand["comment"] if "comment" in operand else None,
)
def process_register_operand(self, operand):
"""Process register operands, including ABI name to x-register mapping"""
# Handle ABI names by adding the appropriate prefix
if "prefix" not in operand:
name = operand["name"].lower()
# Integer register ABI names
if name in ["zero", "ra", "sp", "gp", "tp"] or name[0] in ["t", "a", "s"]:
prefix = "x"
# Floating point register ABI names
elif name[0] == "f" and name[1] in ["t", "a", "s"]:
prefix = "f"
# CSR registers
elif name.startswith("csr"):
prefix = ""
else:
prefix = ""
return RegisterOperand(
prefix=prefix,
name=name
)
else:
return RegisterOperand(
prefix=operand["prefix"].lower(),
name=operand["name"]
)
def process_memory_address(self, memory_address):
"""Post-process memory address operand"""
# Process offset
offset = memory_address.get("offset", None)
if isinstance(offset, list) and len(offset) == 1:
offset = offset[0]
if offset is not None and "value" in offset:
offset = ImmediateOperand(value=int(offset["value"], 0))
if isinstance(offset, dict) and "identifier" in offset:
offset = self.process_identifier(offset["identifier"])
# Process base register
base = memory_address.get("base", None)
if base is not None:
base = self.process_register_operand(base)
# Create memory operand
return MemoryOperand(
offset=offset,
base=base,
index=None,
scale=1
)
def process_label(self, label):
"""Post-process label asm line"""
return (
LabelOperand(name=label["name"]["name"]),
label["comment"] if self.comment_id in label else None,
)
def process_identifier(self, identifier):
"""Post-process identifier operand"""
return IdentifierOperand(
name=identifier["name"] if "name" in identifier else None,
offset=identifier["offset"] if "offset" in identifier else None
)
def process_immediate(self, immediate):
"""Post-process immediate operand"""
if "identifier" in immediate:
# actually an identifier, change declaration
return self.process_identifier(immediate["identifier"])
if "value" in immediate:
# normal integer value
immediate["type"] = "int"
# convert hex/bin immediates to dec
new_immediate = ImmediateOperand(imd_type=immediate["type"], value=immediate["value"])
new_immediate.value = self.normalize_imd(new_immediate)
return new_immediate
return immediate
def get_full_reg_name(self, register):
"""Return one register name string including all attributes"""
if register.prefix and register.name:
return register.prefix + str(register.name)
return str(register.name)
def normalize_imd(self, imd):
"""Normalize immediate to decimal based representation"""
if isinstance(imd, IdentifierOperand):
return imd
elif imd.value is not None:
if isinstance(imd.value, str):
# hex or bin, return decimal
return int(imd.value, 0)
else:
return imd.value
# identifier
return imd
def parse_register(self, register_string):
"""
Parse register string and return register dictionary.
:param str register_string: register representation as string
:returns: dict with register info
"""
# Remove any leading/trailing whitespace
register_string = register_string.strip()
# Check for integer registers (x0-x31)
x_match = re.match(r'^x([0-9]|[1-2][0-9]|3[0-1])$', register_string)
if x_match:
reg_num = int(x_match.group(1))
return {"class": "register", "register": {"prefix": "x", "name": str(reg_num)}}
# Check for floating-point registers (f0-f31)
f_match = re.match(r'^f([0-9]|[1-2][0-9]|3[0-1])$', register_string)
if f_match:
reg_num = int(f_match.group(1))
return {"class": "register", "register": {"prefix": "f", "name": str(reg_num)}}
# Check for vector registers (v0-v31)
v_match = re.match(r'^v([0-9]|[1-2][0-9]|3[0-1])$', register_string)
if v_match:
reg_num = int(v_match.group(1))
return {"class": "register", "register": {"prefix": "v", "name": str(reg_num)}}
# Check for ABI names
abi_names = {
"zero": 0, "ra": 1, "sp": 2, "gp": 3, "tp": 4,
"t0": 5, "t1": 6, "t2": 7,
"s0": 8, "fp": 8, "s1": 9,
"a0": 10, "a1": 11, "a2": 12, "a3": 13, "a4": 14, "a5": 15, "a6": 16, "a7": 17,
"s2": 18, "s3": 19, "s4": 20, "s5": 21, "s6": 22, "s7": 23, "s8": 24, "s9": 25, "s10": 26, "s11": 27,
"t3": 28, "t4": 29, "t5": 30, "t6": 31
}
if register_string in abi_names:
return {"class": "register", "register": {"prefix": "", "name": register_string}}
# If no match is found
return None
def is_gpr(self, register):
"""Check if register is a general purpose register"""
# Integer registers: x0-x31 or ABI names
if register.prefix == "x":
return True
if not register.prefix and register.name in ["zero", "ra", "sp", "gp", "tp"]:
return True
if not register.prefix and register.name[0] in ["t", "a", "s"]:
return True
return False
def is_vector_register(self, register):
"""Check if register is a vector register"""
# Vector registers: v0-v31
if register.prefix == "v":
return True
return False
def is_flag_dependend_of(self, flag_a, flag_b):
"""Check if ``flag_a`` is dependent on ``flag_b``"""
# RISC-V doesn't have explicit flags like x86 or AArch64
return flag_a.name == flag_b.name
def is_reg_dependend_of(self, reg_a, reg_b):
"""Check if ``reg_a`` is dependent on ``reg_b``"""
if not isinstance(reg_a, Operand):
reg_a = RegisterOperand(name=reg_a["name"])
# Get canonical register names
reg_a_canonical = self._get_canonical_reg_name(reg_a)
reg_b_canonical = self._get_canonical_reg_name(reg_b)
# Same register type and number means dependency
return reg_a_canonical == reg_b_canonical
def _get_canonical_reg_name(self, register):
"""Get the canonical form of a register (x-form for integer, f-form for FP)"""
# If already in canonical form (x# or f#)
if register.prefix in ["x", "f", "v"] and register.name.isdigit():
return f"{register.prefix}{register.name}"
# ABI name mapping for integer registers
abi_to_x = {
"zero": "x0", "ra": "x1", "sp": "x2", "gp": "x3", "tp": "x4",
"t0": "x5", "t1": "x6", "t2": "x7",
"s0": "x8", "s1": "x9",
"a0": "x10", "a1": "x11", "a2": "x12", "a3": "x13",
"a4": "x14", "a5": "x15", "a6": "x16", "a7": "x17",
"s2": "x18", "s3": "x19", "s4": "x20", "s5": "x21",
"s6": "x22", "s7": "x23", "s8": "x24", "s9": "x25",
"s10": "x26", "s11": "x27",
"t3": "x28", "t4": "x29", "t5": "x30", "t6": "x31"
}
# For integer register ABI names
name = register.name.lower()
if name in abi_to_x:
return abi_to_x[name]
# For FP register ABI names like fa0, fs1, etc.
if name.startswith("f") and len(name) > 1:
if name[1] == "a": # fa0-fa7
idx = int(name[2:])
return f"f{idx + 10}"
elif name[1] == "s": # fs0-fs11
idx = int(name[2:])
if idx <= 1:
return f"f{idx + 8}"
else:
return f"f{idx + 16}"
elif name[1] == "t": # ft0-ft11
idx = int(name[2:])
if idx <= 7:
return f"f{idx}"
else:
return f"f{idx + 20}"
# Return as is if no mapping found
return f"{register.prefix}{register.name}"
def get_reg_type(self, register):
"""Get register type"""
# Return register prefix if exists
if register.prefix:
return register.prefix
# Determine type from ABI name
name = register.name.lower()
if name in ["zero", "ra", "sp", "gp", "tp"] or name[0] in ["t", "a", "s"]:
return "x" # Integer register
elif name.startswith("f"):
return "f" # Floating point register
elif name.startswith("csr"):
return "csr" # Control and Status Register
return "unknown"

View File

@@ -0,0 +1,120 @@
# Basic RISC-V test kernel with various instructions
.text
.globl vector_add
.align 2
# Example of a basic function
vector_add:
# Prologue
addi sp, sp, -16
sw ra, 12(sp)
sw s0, 8(sp)
addi s0, sp, 16
# Setup
mv a3, a0
lw a0, 0(a0) # Load first element
lw a4, 0(a1) # Load second element
add a0, a0, a4 # Add elements
sw a0, 0(a2) # Store to result array
# Integer operations
addi t0, zero, 10
addi t1, zero, 5
add t2, t0, t1
sub t3, t0, t1
and t4, t0, t1
or t5, t0, t1
xor t6, t0, t1
sll a0, t0, t1
srl a1, t0, t1
sra a2, t0, t1
# Memory operations
lw a0, 8(sp)
sw a1, 4(sp)
lbu a2, 1(sp)
sb a3, 0(sp)
lh a4, 2(sp)
sh a5, 2(sp)
# Branch and jump instructions
beq t0, t1, skip
bne t0, t1, continue
jal ra, function
jalr t0, 0(ra)
.L1: # Loop Header
beq t0, t1, .L2
addi t0, t0, 1
j .L1
.L2:
# Floating point operations
flw fa0, 0(a0)
flw fa1, 4(a0)
fadd.s fa2, fa0, fa1
fsub.s fa3, fa0, fa1
fmv.x.w a0, fa0
fmv.w.x fa4, a0
# CSR operations
csrr t0, mstatus
csrw mtvec, t0
csrs mie, t0
csrc mip, t0
# Vector instructions (RVV)
vsetvli t0, a0, e32, m4, ta, ma
vle32.v v0, (a0)
vle32.v v4, (a1)
vadd.vv v8, v0, v4
vse32.v v8, (a2)
# Atomic operations
lr.w t0, (a0)
sc.w t1, t2, (a0)
amoswap.w t3, t4, (a0)
amoadd.w t5, t6, (a0)
# Multiply/divide instructions
mul t0, t1, t2
mulh t3, t4, t5
div t0, t1, t2
rem t3, t4, t5
# Pseudo-instructions
li t0, 1234
la t1, data
li a0, %hi(data)
addi a1, a0, %lo(data)
skip:
# Skip destination
addi t2, zero, 20
continue:
# Continue destination
addi t3, zero, 30
function:
# Function destination
addi a0, zero, 0
ret
# Epilogue
lw ra, 12(sp)
lw s0, 8(sp)
addi sp, sp, 16
ret
.data
.align 4
data:
.word 0x12345678
.byte 0x01, 0x02, 0x03, 0x04
.half 0xABCD, 0xEF01
.float 3.14159
.space 16
.ascii "RISC-V Test String"

321
tests/test_parser_RISCV.py Normal file
View File

@@ -0,0 +1,321 @@
#!/usr/bin/env python3
"""
Unit tests for RISC-V assembly parser
"""
import os
import unittest
from pyparsing import ParseException
from osaca.parser import ParserRISCV, InstructionForm
from osaca.parser.directive import DirectiveOperand
from osaca.parser.memory import MemoryOperand
from osaca.parser.register import RegisterOperand
from osaca.parser.immediate import ImmediateOperand
from osaca.parser.identifier import IdentifierOperand
class TestParserRISCV(unittest.TestCase):
@classmethod
def setUpClass(self):
self.parser = ParserRISCV()
with open(self._find_file("kernel_riscv.s")) as f:
self.riscv_code = f.read()
##################
# Test
##################
def test_comment_parser(self):
self.assertEqual(self._get_comment(self.parser, "# some comments"), "some comments")
self.assertEqual(
self._get_comment(self.parser, "\t\t# RISC-V comment \t end \t"), "RISC-V comment end"
)
self.assertEqual(
self._get_comment(self.parser, "\t## comment ## comment"),
"# comment ## comment",
)
def test_label_parser(self):
self.assertEqual(self._get_label(self.parser, "main:")[0].name, "main")
self.assertEqual(self._get_label(self.parser, "loop_start:")[0].name, "loop_start")
self.assertEqual(self._get_label(self.parser, ".L1:\t\t\t# comment")[0].name, ".L1")
self.assertEqual(
" ".join(self._get_label(self.parser, ".L1:\t\t\t# comment")[1]),
"comment",
)
with self.assertRaises(ParseException):
self._get_label(self.parser, "\t.cfi_startproc")
def test_directive_parser(self):
self.assertEqual(self._get_directive(self.parser, "\t.text")[0].name, "text")
self.assertEqual(len(self._get_directive(self.parser, "\t.text")[0].parameters), 0)
self.assertEqual(self._get_directive(self.parser, "\t.align\t4")[0].name, "align")
self.assertEqual(
len(self._get_directive(self.parser, "\t.align\t4")[0].parameters), 1
)
self.assertEqual(
self._get_directive(self.parser, "\t.align\t4")[0].parameters[0], "4"
)
self.assertEqual(
self._get_directive(self.parser, " .byte 100,103,144 # IACA START")[
0
].name,
"byte",
)
self.assertEqual(
self._get_directive(self.parser, " .byte 100,103,144 # IACA START")[
0
].parameters[2],
"144",
)
self.assertEqual(
" ".join(
self._get_directive(self.parser, " .byte 100,103,144 # IACA START")[1]
),
"IACA START",
)
def test_parse_instruction(self):
instr1 = "addi t0, zero, 1"
instr2 = "lw a0, 8(sp)"
instr3 = "beq t0, t1, loop_start"
instr4 = "lui a0, %hi(data)"
instr5 = "sw ra, -4(sp)"
instr6 = "jal ra, function"
parsed_1 = self.parser.parse_instruction(instr1)
parsed_2 = self.parser.parse_instruction(instr2)
parsed_3 = self.parser.parse_instruction(instr3)
parsed_4 = self.parser.parse_instruction(instr4)
parsed_5 = self.parser.parse_instruction(instr5)
parsed_6 = self.parser.parse_instruction(instr6)
# Verify addi instruction
self.assertEqual(parsed_1.mnemonic, "addi")
self.assertEqual(parsed_1.operands[0].name, "t0")
self.assertEqual(parsed_1.operands[1].name, "zero")
self.assertEqual(parsed_1.operands[2].value, 1)
# Verify lw instruction
self.assertEqual(parsed_2.mnemonic, "lw")
self.assertEqual(parsed_2.operands[0].name, "a0")
self.assertEqual(parsed_2.operands[1].offset.value, 8)
self.assertEqual(parsed_2.operands[1].base.name, "sp")
# Verify beq instruction
self.assertEqual(parsed_3.mnemonic, "beq")
self.assertEqual(parsed_3.operands[0].name, "t0")
self.assertEqual(parsed_3.operands[1].name, "t1")
self.assertEqual(parsed_3.operands[2].name, "loop_start")
# Verify lui instruction with high bits relocation
self.assertEqual(parsed_4.mnemonic, "lui")
self.assertEqual(parsed_4.operands[0].name, "a0")
self.assertEqual(parsed_4.operands[1].name, "data")
# Verify sw instruction with negative offset
self.assertEqual(parsed_5.mnemonic, "sw")
self.assertEqual(parsed_5.operands[0].name, "ra")
self.assertEqual(parsed_5.operands[1].offset.value, -4)
self.assertEqual(parsed_5.operands[1].base.name, "sp")
# Verify jal instruction
self.assertEqual(parsed_6.mnemonic, "jal")
self.assertEqual(parsed_6.operands[0].name, "ra")
self.assertEqual(parsed_6.operands[1].name, "function")
def test_parse_line(self):
line_comment = "# -- Begin main"
line_label = ".LBB0_1: # Loop Header"
line_directive = ".cfi_def_cfa sp, 0"
line_instruction = "addi sp, sp, -16 # allocate stack frame"
instruction_form_1 = InstructionForm(
mnemonic=None,
operands=[],
directive_id=None,
comment_id="-- Begin main",
label_id=None,
line="# -- Begin main",
line_number=1,
)
instruction_form_2 = InstructionForm(
mnemonic=None,
operands=[],
directive_id=None,
comment_id="Loop Header",
label_id=".LBB0_1",
line=".LBB0_1: # Loop Header",
line_number=2,
)
instruction_form_3 = InstructionForm(
mnemonic=None,
operands=[],
directive_id=DirectiveOperand(name="cfi_def_cfa", parameters=["sp", "0"]),
comment_id=None,
label_id=None,
line=".cfi_def_cfa sp, 0",
line_number=3,
)
instruction_form_4 = InstructionForm(
mnemonic="addi",
operands=[
RegisterOperand(prefix="x", name="sp"),
RegisterOperand(prefix="x", name="sp"),
ImmediateOperand(value=-16, imd_type="int"),
],
directive_id=None,
comment_id="allocate stack frame",
label_id=None,
line="addi sp, sp, -16 # allocate stack frame",
line_number=4,
)
parsed_1 = self.parser.parse_line(line_comment, 1)
parsed_2 = self.parser.parse_line(line_label, 2)
parsed_3 = self.parser.parse_line(line_directive, 3)
parsed_4 = self.parser.parse_line(line_instruction, 4)
self.assertEqual(parsed_1.comment, instruction_form_1.comment)
self.assertEqual(parsed_2.label, instruction_form_2.label)
self.assertEqual(parsed_3.directive.name, instruction_form_3.directive.name)
self.assertEqual(parsed_3.directive.parameters, instruction_form_3.directive.parameters)
self.assertEqual(parsed_4.mnemonic, instruction_form_4.mnemonic)
self.assertEqual(parsed_4.operands[0].name, instruction_form_4.operands[0].name)
self.assertEqual(parsed_4.operands[2].value, instruction_form_4.operands[2].value)
self.assertEqual(parsed_4.comment, instruction_form_4.comment)
def test_parse_file(self):
parsed = self.parser.parse_file(self.riscv_code)
self.assertEqual(parsed[0].line_number, 1)
self.assertGreater(len(parsed), 80) # More than 80 lines should be parsed
# Test parsing specific parts of the file
# Find vector_add label
vector_add_idx = next((i for i, instr in enumerate(parsed) if instr.label == "vector_add"), None)
self.assertIsNotNone(vector_add_idx)
# Find floating-point instructions
flw_idx = next((i for i, instr in enumerate(parsed) if instr.mnemonic == "flw"), None)
self.assertIsNotNone(flw_idx)
# Find vector instructions
vle_idx = next((i for i, instr in enumerate(parsed) if instr.mnemonic and instr.mnemonic.startswith("vle")), None)
self.assertIsNotNone(vle_idx)
# Find CSR instructions
csr_idx = next((i for i, instr in enumerate(parsed) if instr.mnemonic == "csrr"), None)
self.assertIsNotNone(csr_idx)
def test_register_mapping(self):
# Test ABI name to register number mapping
reg_zero = RegisterOperand(name="zero")
reg_ra = RegisterOperand(name="ra")
reg_sp = RegisterOperand(name="sp")
reg_a0 = RegisterOperand(name="a0")
reg_t1 = RegisterOperand(name="t1")
reg_s2 = RegisterOperand(name="s2")
reg_x0 = RegisterOperand(prefix="x", name="0")
reg_x1 = RegisterOperand(prefix="x", name="1")
reg_x2 = RegisterOperand(prefix="x", name="2")
reg_x10 = RegisterOperand(prefix="x", name="10")
reg_x6 = RegisterOperand(prefix="x", name="6")
reg_x18 = RegisterOperand(prefix="x", name="18")
# Test canonical name conversion
self.assertEqual(self.parser._get_canonical_reg_name(reg_zero), "x0")
self.assertEqual(self.parser._get_canonical_reg_name(reg_ra), "x1")
self.assertEqual(self.parser._get_canonical_reg_name(reg_sp), "x2")
self.assertEqual(self.parser._get_canonical_reg_name(reg_a0), "x10")
self.assertEqual(self.parser._get_canonical_reg_name(reg_t1), "x6")
self.assertEqual(self.parser._get_canonical_reg_name(reg_s2), "x18")
# Test register dependency
self.assertTrue(self.parser.is_reg_dependend_of(reg_zero, reg_x0))
self.assertTrue(self.parser.is_reg_dependend_of(reg_ra, reg_x1))
self.assertTrue(self.parser.is_reg_dependend_of(reg_sp, reg_x2))
self.assertTrue(self.parser.is_reg_dependend_of(reg_a0, reg_x10))
self.assertTrue(self.parser.is_reg_dependend_of(reg_t1, reg_x6))
self.assertTrue(self.parser.is_reg_dependend_of(reg_s2, reg_x18))
# Test non-dependent registers
self.assertFalse(self.parser.is_reg_dependend_of(reg_zero, reg_x1))
self.assertFalse(self.parser.is_reg_dependend_of(reg_ra, reg_x2))
self.assertFalse(self.parser.is_reg_dependend_of(reg_a0, reg_t1))
def test_normalize_imd(self):
imd_decimal = ImmediateOperand(value="42")
imd_hex = ImmediateOperand(value="0x2A")
imd_negative = ImmediateOperand(value="-12")
identifier = IdentifierOperand(name="function")
self.assertEqual(self.parser.normalize_imd(imd_decimal), 42)
self.assertEqual(self.parser.normalize_imd(imd_hex), 42)
self.assertEqual(self.parser.normalize_imd(imd_negative), -12)
self.assertEqual(self.parser.normalize_imd(identifier), identifier)
def test_is_gpr(self):
# Test integer registers
reg_x5 = RegisterOperand(prefix="x", name="5")
reg_t0 = RegisterOperand(name="t0")
reg_sp = RegisterOperand(name="sp")
# Test floating point registers
reg_f10 = RegisterOperand(prefix="f", name="10")
reg_fa0 = RegisterOperand(name="fa0")
# Test vector registers
reg_v3 = RegisterOperand(prefix="v", name="3")
self.assertTrue(self.parser.is_gpr(reg_x5))
self.assertTrue(self.parser.is_gpr(reg_t0))
self.assertTrue(self.parser.is_gpr(reg_sp))
self.assertFalse(self.parser.is_gpr(reg_f10))
self.assertFalse(self.parser.is_gpr(reg_fa0))
self.assertFalse(self.parser.is_gpr(reg_v3))
def test_is_vector_register(self):
reg_v3 = RegisterOperand(prefix="v", name="3")
reg_x5 = RegisterOperand(prefix="x", name="5")
reg_f10 = RegisterOperand(prefix="f", name="10")
self.assertTrue(self.parser.is_vector_register(reg_v3))
self.assertFalse(self.parser.is_vector_register(reg_x5))
self.assertFalse(self.parser.is_vector_register(reg_f10))
##################
# Helper functions
##################
def _get_comment(self, parser, comment):
return " ".join(
parser.process_operand(parser.comment.parseString(comment, parseAll=True).asDict())[
"comment"
]
)
def _get_label(self, parser, label):
return parser.process_operand(parser.label.parseString(label, parseAll=True).asDict())
def _get_directive(self, parser, directive):
return parser.process_operand(
parser.directive.parseString(directive, parseAll=True).asDict()
)
@staticmethod
def _find_file(name):
testdir = os.path.dirname(__file__)
name = os.path.join(testdir, "test_files", name)
assert os.path.exists(name)
return name
if __name__ == "__main__":
suite = unittest.TestLoader().loadTestsFromTestCase(TestParserRISCV)
unittest.TextTestRunner(verbosity=2).run(suite)