mirror of
https://github.com/RRZE-HPC/OSACA.git
synced 2026-01-06 19:20:07 +01:00
350 lines
16 KiB
Python
350 lines
16 KiB
Python
#!/usr/bin/env python3
"""
Unit tests for RISC-V assembly parser
"""
|
|
|
|
import os
|
|
import unittest
|
|
|
|
from pyparsing import ParseException
|
|
|
|
from osaca.parser import ParserRISCV, InstructionForm
|
|
from osaca.parser.directive import DirectiveOperand
|
|
from osaca.parser.memory import MemoryOperand
|
|
from osaca.parser.register import RegisterOperand
|
|
from osaca.parser.immediate import ImmediateOperand
|
|
from osaca.parser.identifier import IdentifierOperand
|
|
|
|
|
|
class TestParserRISCV(unittest.TestCase):
    """Unit tests for ``ParserRISCV``.

    One parser instance and the contents of ``test_files/kernel_riscv.s`` are
    created once in ``setUpClass`` and shared by all test methods.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared parser and read the RISC-V kernel fixture."""
        cls.parser = ParserRISCV()
        with open(cls._find_file("kernel_riscv.s")) as f:
            cls.riscv_code = f.read()

    ##################
    # Test
    ##################

    def test_comment_parser(self):
        # Comment text is returned without the leading '#', joined with single spaces.
        self.assertEqual(self._get_comment(self.parser, "# some comments"), "some comments")
        self.assertEqual(
            self._get_comment(self.parser, "\t\t# RISC-V comment \t end \t"),
            "RISC-V comment end",
        )
        # Only the first '#' is the comment marker; later ones belong to the text.
        self.assertEqual(
            self._get_comment(self.parser, "\t## comment ## comment"),
            "# comment ## comment",
        )

    def test_label_parser(self):
        # Test common label patterns from kernel_riscv.s
        self.assertEqual(self._get_label(self.parser, "saxpy_golden:")[0].name, "saxpy_golden")
        self.assertEqual(self._get_label(self.parser, ".L4:")[0].name, ".L4")
        self.assertEqual(self._get_label(self.parser, ".L25:\t\t\t# Return")[0].name, ".L25")
        # The second element of the result carries the trailing comment tokens.
        self.assertEqual(
            " ".join(self._get_label(self.parser, ".L25:\t\t\t# Return")[1]),
            "Return",
        )
        # A directive is not a label and must be rejected by the label grammar.
        with self.assertRaises(ParseException):
            self._get_label(self.parser, "\t.word 1113498583")

    def test_directive_parser(self):
        self.assertEqual(self._get_directive(self.parser, "\t.text")[0].name, "text")
        self.assertEqual(len(self._get_directive(self.parser, "\t.text")[0].parameters), 0)
        self.assertEqual(self._get_directive(self.parser, "\t.word\t1113498583")[0].name, "word")
        self.assertEqual(
            len(self._get_directive(self.parser, "\t.word\t1113498583")[0].parameters), 1
        )
        self.assertEqual(
            self._get_directive(self.parser, "\t.word\t1113498583")[0].parameters[0],
            "1113498583",
        )
        # Test string directive
        self.assertEqual(
            self._get_directive(self.parser, '.string "fail, %f=!%f\\n"')[0].name, "string"
        )
        self.assertEqual(
            self._get_directive(self.parser, '.string "fail, %f=!%f\\n"')[0].parameters[0],
            '"fail, %f=!%f\\n"',
        )
        # Test set directive
        self.assertEqual(
            self._get_directive(self.parser, "\t.set\t.LANCHOR0,. + 0")[0].name, "set"
        )
        self.assertEqual(
            len(self._get_directive(self.parser, "\t.set\t.LANCHOR0,. + 0")[0].parameters), 2
        )

    def test_parse_instruction(self):
        # Use generic RISC-V instructions for testing, not tied to a specific file
        instr1 = "beq a0,zero,.L12"  # Branch instruction
        instr2 = "vsetvli a5,zero,e32,m1,ta,ma"  # Vector instruction
        instr3 = "vle32.v v1,0(a1)"  # Vector load instruction
        instr4 = "fmadd.s fa5,fa0,fa5,fa4"  # Floating-point instruction
        instr5 = "addi sp,sp,-64"  # Integer immediate instruction
        instr6 = "csrr a4,vlenb"  # CSR instruction
        instr7 = "ret"  # Return instruction
        instr8 = "lui a0,%hi(data)"  # Load upper immediate with relocation
        instr9 = "sw ra,-4(sp)"  # Store with negative offset

        parsed_1 = self.parser.parse_instruction(instr1)
        parsed_2 = self.parser.parse_instruction(instr2)
        parsed_3 = self.parser.parse_instruction(instr3)
        parsed_4 = self.parser.parse_instruction(instr4)
        parsed_5 = self.parser.parse_instruction(instr5)
        parsed_6 = self.parser.parse_instruction(instr6)
        parsed_7 = self.parser.parse_instruction(instr7)
        parsed_8 = self.parser.parse_instruction(instr8)
        parsed_9 = self.parser.parse_instruction(instr9)

        # Verify branch instruction
        self.assertEqual(parsed_1.mnemonic, "beq")
        self.assertEqual(len(parsed_1.operands), 3)
        self.assertIsInstance(parsed_1.operands[0], RegisterOperand)
        self.assertEqual(parsed_1.operands[0].name, "a0")
        self.assertIsInstance(parsed_1.operands[1], RegisterOperand)
        self.assertEqual(parsed_1.operands[1].name, "zero")
        self.assertIsInstance(parsed_1.operands[2], IdentifierOperand)
        self.assertEqual(parsed_1.operands[2].name, ".L12")

        # Verify vector configuration instruction
        self.assertEqual(parsed_2.mnemonic, "vsetvli")
        self.assertEqual(len(parsed_2.operands), 6)  # Verify correct operand count
        self.assertEqual(parsed_2.operands[0].name, "a5")
        self.assertEqual(parsed_2.operands[1].name, "zero")

        # Verify vector load instruction
        self.assertEqual(parsed_3.mnemonic, "vle32.v")
        self.assertEqual(len(parsed_3.operands), 2)
        self.assertEqual(parsed_3.operands[0].prefix, "v")
        self.assertEqual(parsed_3.operands[0].name, "1")
        self.assertIsInstance(parsed_3.operands[1], MemoryOperand)
        self.assertEqual(parsed_3.operands[1].base.name, "a1")

        # Verify floating-point instruction
        self.assertEqual(parsed_4.mnemonic, "fmadd.s")
        self.assertEqual(len(parsed_4.operands), 4)
        self.assertEqual(parsed_4.operands[0].prefix, "f")

        # Verify integer immediate instruction
        self.assertEqual(parsed_5.mnemonic, "addi")
        self.assertEqual(len(parsed_5.operands), 3)
        self.assertEqual(parsed_5.operands[0].name, "sp")
        self.assertEqual(parsed_5.operands[1].name, "sp")
        self.assertIsInstance(parsed_5.operands[2], ImmediateOperand)
        self.assertEqual(parsed_5.operands[2].value, -64)

        # Verify CSR instruction
        self.assertEqual(parsed_6.mnemonic, "csrr")
        self.assertEqual(len(parsed_6.operands), 2)
        self.assertEqual(parsed_6.operands[0].name, "a4")
        self.assertEqual(parsed_6.operands[1].name, "vlenb")

        # Verify return instruction
        self.assertEqual(parsed_7.mnemonic, "ret")
        self.assertEqual(len(parsed_7.operands), 0)

        # Verify load upper immediate with relocation
        self.assertEqual(parsed_8.mnemonic, "lui")
        self.assertEqual(len(parsed_8.operands), 2)
        self.assertEqual(parsed_8.operands[0].name, "a0")
        self.assertEqual(parsed_8.operands[1].name, "data")

        # Verify store with negative offset
        self.assertEqual(parsed_9.mnemonic, "sw")
        self.assertEqual(len(parsed_9.operands), 2)
        self.assertEqual(parsed_9.operands[0].name, "ra")
        self.assertIsInstance(parsed_9.operands[1], MemoryOperand)
        self.assertEqual(parsed_9.operands[1].base.name, "sp")
        self.assertEqual(parsed_9.operands[1].offset.value, -4)

    def test_parse_line(self):
        # Use generic RISC-V lines for testing
        line_label = "saxpy_golden:"
        line_branch = " beq a0,zero,.L12"
        line_memory = " vle32.v v1,0(a1)"
        line_directive = " .word 1113498583"
        line_with_comment = " ret # Return from function"

        # The second argument is the line number handed to the parser.
        parsed_1 = self.parser.parse_line(line_label, 1)
        parsed_2 = self.parser.parse_line(line_branch, 2)
        parsed_3 = self.parser.parse_line(line_memory, 3)
        parsed_4 = self.parser.parse_line(line_directive, 4)
        parsed_5 = self.parser.parse_line(line_with_comment, 5)

        # Verify label parsing
        self.assertEqual(parsed_1.label, "saxpy_golden")
        self.assertIsNone(parsed_1.mnemonic)

        # Verify branch instruction parsing
        self.assertEqual(parsed_2.mnemonic, "beq")
        self.assertEqual(len(parsed_2.operands), 3)
        self.assertEqual(parsed_2.operands[0].name, "a0")
        self.assertEqual(parsed_2.operands[1].name, "zero")
        self.assertEqual(parsed_2.operands[2].name, ".L12")

        # Verify memory instruction parsing
        self.assertEqual(parsed_3.mnemonic, "vle32.v")
        self.assertEqual(len(parsed_3.operands), 2)
        self.assertEqual(parsed_3.operands[0].prefix, "v")
        self.assertEqual(parsed_3.operands[0].name, "1")
        self.assertIsInstance(parsed_3.operands[1], MemoryOperand)

        # Verify directive parsing
        self.assertIsNone(parsed_4.mnemonic)
        self.assertEqual(parsed_4.directive.name, "word")
        self.assertEqual(parsed_4.directive.parameters[0], "1113498583")

        # Verify comment parsing
        self.assertEqual(parsed_5.mnemonic, "ret")
        self.assertEqual(parsed_5.comment, "Return from function")

    def test_parse_file(self):
        parsed = self.parser.parse_file(self.riscv_code)
        self.assertGreater(len(parsed), 10)  # There should be multiple lines

        # Find common elements that should exist in any RISC-V file
        # without being tied to specific line numbers

        # Verify that we can find at least one label
        label_forms = [form for form in parsed if form.label is not None]
        self.assertGreater(len(label_forms), 0, "No labels found in the file")

        # Verify that we can find at least one branch instruction
        branch_forms = [
            form for form in parsed if form.mnemonic and form.mnemonic.startswith("b")
        ]
        self.assertGreater(len(branch_forms), 0, "No branch instructions found in the file")

        # Verify that we can find at least one store/load instruction
        mem_forms = [
            form for form in parsed if form.mnemonic and form.mnemonic.startswith(("s", "l"))
        ]
        self.assertGreater(len(mem_forms), 0, "No memory instructions found in the file")

        # Verify that we can find at least one directive
        directive_forms = [form for form in parsed if form.directive is not None]
        self.assertGreater(len(directive_forms), 0, "No directives found in the file")

    def test_register_dependency(self):
        # Test ABI name to register number mapping
        reg_zero = RegisterOperand(name="zero")
        reg_ra = RegisterOperand(name="ra")
        reg_sp = RegisterOperand(name="sp")
        reg_a0 = RegisterOperand(name="a0")
        reg_t1 = RegisterOperand(name="t1")
        reg_s2 = RegisterOperand(name="s2")

        reg_x0 = RegisterOperand(prefix="x", name="0")
        reg_x1 = RegisterOperand(prefix="x", name="1")
        reg_x2 = RegisterOperand(prefix="x", name="2")
        reg_x5 = RegisterOperand(prefix="x", name="5")  # Define reg_x5 for use in tests below
        reg_x10 = RegisterOperand(prefix="x", name="10")
        reg_x6 = RegisterOperand(prefix="x", name="6")
        reg_x18 = RegisterOperand(prefix="x", name="18")

        # Test canonical name conversion
        self.assertEqual(self.parser._get_canonical_reg_name(reg_zero), "x0")
        self.assertEqual(self.parser._get_canonical_reg_name(reg_ra), "x1")
        self.assertEqual(self.parser._get_canonical_reg_name(reg_sp), "x2")
        self.assertEqual(self.parser._get_canonical_reg_name(reg_a0), "x10")
        self.assertEqual(self.parser._get_canonical_reg_name(reg_t1), "x6")
        self.assertEqual(self.parser._get_canonical_reg_name(reg_s2), "x18")

        # Test register dependency
        self.assertTrue(self.parser.is_reg_dependend_of(reg_zero, reg_x0))
        self.assertTrue(self.parser.is_reg_dependend_of(reg_ra, reg_x1))
        self.assertTrue(self.parser.is_reg_dependend_of(reg_sp, reg_x2))
        self.assertTrue(self.parser.is_reg_dependend_of(reg_a0, reg_x10))
        self.assertTrue(self.parser.is_reg_dependend_of(reg_t1, reg_x6))
        self.assertTrue(self.parser.is_reg_dependend_of(reg_s2, reg_x18))

        # Test non-dependent registers
        self.assertFalse(self.parser.is_reg_dependend_of(reg_zero, reg_x1))
        self.assertFalse(self.parser.is_reg_dependend_of(reg_ra, reg_x2))
        self.assertFalse(self.parser.is_reg_dependend_of(reg_a0, reg_t1))

        # Test floating-point registers
        reg_fa0 = RegisterOperand(prefix="f", name="a0")
        reg_f10 = RegisterOperand(prefix="f", name="10")

        # Test vector registers
        reg_v1 = RegisterOperand(prefix="v", name="1")

        # Test register type detection
        self.assertTrue(self.parser.is_gpr(reg_a0))
        self.assertTrue(self.parser.is_gpr(reg_x5))
        self.assertTrue(self.parser.is_gpr(reg_sp))

        self.assertFalse(self.parser.is_gpr(reg_fa0))
        self.assertFalse(self.parser.is_gpr(reg_f10))

        self.assertTrue(self.parser.is_vector_register(reg_v1))
        self.assertFalse(self.parser.is_vector_register(reg_x10))
        self.assertFalse(self.parser.is_vector_register(reg_fa0))

    def test_normalize_imd(self):
        imd_decimal = ImmediateOperand(value="42")
        imd_hex = ImmediateOperand(value="0x2A")
        imd_negative = ImmediateOperand(value="-12")
        identifier = IdentifierOperand(name="function")

        self.assertEqual(self.parser.normalize_imd(imd_decimal), 42)
        self.assertEqual(self.parser.normalize_imd(imd_hex), 42)
        self.assertEqual(self.parser.normalize_imd(imd_negative), -12)
        # Non-immediate operands are expected to be handed back unchanged.
        self.assertEqual(self.parser.normalize_imd(identifier), identifier)

    def test_memory_operand_parsing(self):
        # Test memory operand parsing with different offsets and base registers

        # Parse memory operands from real instructions
        instr1 = "vle32.v v1,0(a1)"
        instr2 = "lw a0,8(sp)"
        instr3 = "sw ra,-4(sp)"

        parsed1 = self.parser.parse_instruction(instr1)
        parsed2 = self.parser.parse_instruction(instr2)
        parsed3 = self.parser.parse_instruction(instr3)

        # Verify memory operands
        self.assertIsInstance(parsed1.operands[1], MemoryOperand)
        self.assertEqual(parsed1.operands[1].base.name, "a1")
        self.assertEqual(parsed1.operands[1].offset.value, 0)

        self.assertIsInstance(parsed2.operands[1], MemoryOperand)
        self.assertEqual(parsed2.operands[1].base.name, "sp")
        self.assertEqual(parsed2.operands[1].offset.value, 8)

        self.assertIsInstance(parsed3.operands[1], MemoryOperand)
        self.assertEqual(parsed3.operands[1].base.name, "sp")
        self.assertEqual(parsed3.operands[1].offset.value, -4)

    ##################
    # Helper functions
    ##################
    def _get_comment(self, parser, comment):
        """Parse ``comment`` with the parser's comment grammar and return its text.

        The whole input must match (``parseAll=True``). The ``"comment"`` entry
        of the processed parse result is joined with single spaces.
        """
        return " ".join(
            parser.process_operand(parser.comment.parseString(comment, parseAll=True).asDict())[
                "comment"
            ]
        )

    def _get_label(self, parser, label):
        """Parse ``label`` with the parser's label grammar and return the processed result.

        Raises ``ParseException`` when the whole input does not match the grammar.
        """
        return parser.process_operand(parser.label.parseString(label, parseAll=True).asDict())

    def _get_directive(self, parser, directive):
        """Parse ``directive`` with the parser's directive grammar and return the processed result."""
        return parser.process_operand(
            parser.directive.parseString(directive, parseAll=True).asDict()
        )

    @staticmethod
    def _find_file(name):
        """Return the path of ``name`` inside the ``test_files`` directory next to this module.

        Fails with ``AssertionError`` if the resulting path does not exist.
        """
        testdir = os.path.dirname(__file__)
        name = os.path.join(testdir, "test_files", name)
        assert os.path.exists(name)
        return name
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed directly, run only this module's test case with per-test output.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestParserRISCV)
    unittest.TextTestRunner(verbosity=2).run(suite)