Files
OSACA/tests/test_parser_RISCV.py
Metehan Dundar 074118dee0 RISC-V: Update parser to use x-register names, add vector and FP instructions, fix tests
- Modified RISC-V parser to use x-register names instead of ABI names
- Added new vector instructions (vsetvli, vle8.v, vse8.v, vfmacc.vv, vfmul.vf)
- Added floating point instructions (fmul.d)
- Added unconditional jump instruction (j)
- Updated tests to match new register naming convention
- Added new RISC-V example files
- Updated .gitignore to exclude test environment and old examples
2025-06-30 00:28:52 +02:00

261 lines
11 KiB
Python

#!/usr/bin/env python3
"""
Unit tests for RISC-V assembly parser
"""
import os
import unittest
from pyparsing import ParseException
from osaca.parser import ParserRISCV, InstructionForm
from osaca.parser.directive import DirectiveOperand
from osaca.parser.memory import MemoryOperand
from osaca.parser.register import RegisterOperand
from osaca.parser.immediate import ImmediateOperand
from osaca.parser.identifier import IdentifierOperand
class TestParserRISCV(unittest.TestCase):
    """Unit tests for ``ParserRISCV``.

    Covers the comment, label and directive grammars, instruction and line
    parsing, whole-file parsing of the ``kernel_riscv.s`` sample, register
    helpers and immediate normalization.
    """

    @classmethod
    def setUpClass(cls):
        """Create one shared parser and read the sample kernel once for all tests."""
        cls.parser = ParserRISCV()
        with open(cls._find_file("kernel_riscv.s")) as f:
            cls.riscv_code = f.read()

    ##################
    # Test
    ##################

    def test_comment_parser(self):
        """Comments are returned without the leading ``#`` and with whitespace collapsed."""
        self.assertEqual(self._get_comment(self.parser, "# some comments"), "some comments")
        self.assertEqual(
            self._get_comment(self.parser, "\t\t# RISC-V comment \t end \t"), "RISC-V comment end"
        )
        self.assertEqual(
            self._get_comment(self.parser, "\t## comment ## comment"),
            "# comment ## comment",
        )

    def test_label_parser(self):
        """Labels are recognized, with or without a trailing comment; directives are not labels."""
        # Test common label patterns from kernel_riscv.s
        self.assertEqual(self._get_label(self.parser, "saxpy_golden:")[0].name, "saxpy_golden")
        self.assertEqual(self._get_label(self.parser, ".L4:")[0].name, ".L4")

        # Parse the commented label once and check both the label and its comment.
        label_with_comment = self._get_label(self.parser, ".L25:\t\t\t# Return")
        self.assertEqual(label_with_comment[0].name, ".L25")
        self.assertEqual(" ".join(label_with_comment[1]), "Return")

        with self.assertRaises(ParseException):
            self._get_label(self.parser, "\t.word 1113498583")

    def test_directive_parser(self):
        """Directive names and parameters are extracted for several directive shapes."""
        text_directive = self._get_directive(self.parser, "\t.text")[0]
        self.assertEqual(text_directive.name, "text")
        self.assertEqual(len(text_directive.parameters), 0)

        word_directive = self._get_directive(self.parser, "\t.word\t1113498583")[0]
        self.assertEqual(word_directive.name, "word")
        self.assertEqual(len(word_directive.parameters), 1)
        self.assertEqual(word_directive.parameters[0], "1113498583")

        # Test string directive
        string_directive = self._get_directive(self.parser, '.string "fail, %f=!%f\\n"')[0]
        self.assertEqual(string_directive.name, "string")
        self.assertEqual(string_directive.parameters[0], '"fail, %f=!%f\\n"')

        # Test set directive
        set_directive = self._get_directive(self.parser, "\t.set\t.LANCHOR0,. + 0")[0]
        self.assertEqual(set_directive.name, "set")
        self.assertEqual(len(set_directive.parameters), 2)

    def test_parse_instruction(self):
        """Test parsing of instructions."""
        # Test 1: Simple instruction
        parsed_1 = self.parser.parse_instruction("addi x10, x10, 1")
        self.assertEqual(parsed_1.mnemonic, "addi")
        self.assertEqual(len(parsed_1.operands), 3)
        self.assertEqual(parsed_1.operands[0].name, "10")
        self.assertEqual(parsed_1.operands[1].name, "10")
        self.assertEqual(parsed_1.operands[2].value, 1)

        # Test 2: Vector instruction
        parsed_2 = self.parser.parse_instruction("vle64.v v1, (x11)")
        self.assertEqual(parsed_2.mnemonic, "vle64.v")
        self.assertEqual(len(parsed_2.operands), 2)
        self.assertEqual(parsed_2.operands[0].name, "1")
        self.assertEqual(parsed_2.operands[1].base.name, "11")

        # Test 3: Floating point instruction
        parsed_3 = self.parser.parse_instruction("fmul.d f10, f10, f11")
        self.assertEqual(parsed_3.mnemonic, "fmul.d")
        self.assertEqual(len(parsed_3.operands), 3)
        self.assertEqual(parsed_3.operands[0].name, "10")
        self.assertEqual(parsed_3.operands[1].name, "10")
        self.assertEqual(parsed_3.operands[2].name, "11")

    def test_parse_line(self):
        """Test parsing of complete lines."""
        # Test 1: Line containing only a label
        parsed_1 = self.parser.parse_line(".L2:")
        self.assertEqual(parsed_1.label, ".L2")

        # Test 2: Line with instruction and comment
        parsed_2 = self.parser.parse_line("addi x10, x10, 1 # increment")
        self.assertEqual(parsed_2.mnemonic, "addi")
        self.assertEqual(len(parsed_2.operands), 3)
        self.assertEqual(parsed_2.operands[0].name, "10")
        self.assertEqual(parsed_2.operands[1].name, "10")
        self.assertEqual(parsed_2.operands[2].value, 1)
        self.assertEqual(parsed_2.comment, "increment")

    def test_parse_file(self):
        """The sample kernel parses into labels, branches, memory accesses and directives."""
        parsed = self.parser.parse_file(self.riscv_code)
        self.assertGreater(len(parsed), 10)  # There should be multiple lines

        # Find common elements that should exist in any RISC-V file
        # without being tied to specific line numbers

        # Verify that we can find at least one label
        label_forms = [form for form in parsed if form.label is not None]
        self.assertGreater(len(label_forms), 0, "No labels found in the file")

        # Verify that we can find at least one branch instruction
        branch_forms = [form for form in parsed if form.mnemonic and form.mnemonic.startswith("b")]
        self.assertGreater(len(branch_forms), 0, "No branch instructions found in the file")

        # Verify that we can find at least one store/load instruction
        # (heuristic: any mnemonic beginning with "s" or "l")
        mem_forms = [
            form for form in parsed if form.mnemonic and form.mnemonic.startswith(("s", "l"))
        ]
        self.assertGreater(len(mem_forms), 0, "No memory instructions found in the file")

        # Verify that we can find at least one directive
        directive_forms = [form for form in parsed if form.directive is not None]
        self.assertGreater(len(directive_forms), 0, "No directives found in the file")

    def test_register_dependency(self):
        """ABI names map to x-registers, and register-class predicates tell GPR/FP/vector apart."""
        # Test ABI name to register number mapping
        reg_zero = RegisterOperand(name="zero")
        reg_ra = RegisterOperand(name="ra")
        reg_sp = RegisterOperand(name="sp")
        reg_a0 = RegisterOperand(name="a0")
        reg_t1 = RegisterOperand(name="t1")
        reg_s2 = RegisterOperand(name="s2")
        reg_x0 = RegisterOperand(prefix="x", name="0")
        reg_x1 = RegisterOperand(prefix="x", name="1")
        reg_x2 = RegisterOperand(prefix="x", name="2")
        reg_x5 = RegisterOperand(prefix="x", name="5")  # used by the is_gpr check below
        reg_x10 = RegisterOperand(prefix="x", name="10")
        reg_x6 = RegisterOperand(prefix="x", name="6")
        reg_x18 = RegisterOperand(prefix="x", name="18")

        # Test canonical name conversion
        self.assertEqual(self.parser._get_canonical_reg_name(reg_zero), "x0")
        self.assertEqual(self.parser._get_canonical_reg_name(reg_ra), "x1")
        self.assertEqual(self.parser._get_canonical_reg_name(reg_sp), "x2")
        self.assertEqual(self.parser._get_canonical_reg_name(reg_a0), "x10")
        self.assertEqual(self.parser._get_canonical_reg_name(reg_t1), "x6")
        self.assertEqual(self.parser._get_canonical_reg_name(reg_s2), "x18")

        # Test register dependency
        self.assertTrue(self.parser.is_reg_dependend_of(reg_zero, reg_x0))
        self.assertTrue(self.parser.is_reg_dependend_of(reg_ra, reg_x1))
        self.assertTrue(self.parser.is_reg_dependend_of(reg_sp, reg_x2))
        self.assertTrue(self.parser.is_reg_dependend_of(reg_a0, reg_x10))
        self.assertTrue(self.parser.is_reg_dependend_of(reg_t1, reg_x6))
        self.assertTrue(self.parser.is_reg_dependend_of(reg_s2, reg_x18))

        # Test non-dependent registers
        self.assertFalse(self.parser.is_reg_dependend_of(reg_zero, reg_x1))
        self.assertFalse(self.parser.is_reg_dependend_of(reg_ra, reg_x2))
        self.assertFalse(self.parser.is_reg_dependend_of(reg_a0, reg_t1))

        # Test floating-point registers
        reg_fa0 = RegisterOperand(prefix="f", name="a0")
        reg_f10 = RegisterOperand(prefix="f", name="10")

        # Test vector registers
        reg_v1 = RegisterOperand(prefix="v", name="1")

        # Test register type detection
        self.assertTrue(self.parser.is_gpr(reg_a0))
        self.assertTrue(self.parser.is_gpr(reg_x5))
        self.assertTrue(self.parser.is_gpr(reg_sp))
        self.assertFalse(self.parser.is_gpr(reg_fa0))
        self.assertFalse(self.parser.is_gpr(reg_f10))
        self.assertTrue(self.parser.is_vector_register(reg_v1))
        self.assertFalse(self.parser.is_vector_register(reg_x10))
        self.assertFalse(self.parser.is_vector_register(reg_fa0))

    def test_normalize_imd(self):
        """Immediates normalize to ints; identifiers pass through unchanged."""
        imd_decimal = ImmediateOperand(value="42")
        imd_hex = ImmediateOperand(value="0x2A")
        imd_negative = ImmediateOperand(value="-12")
        identifier = IdentifierOperand(name="function")
        self.assertEqual(self.parser.normalize_imd(imd_decimal), 42)
        self.assertEqual(self.parser.normalize_imd(imd_hex), 42)
        self.assertEqual(self.parser.normalize_imd(imd_negative), -12)
        self.assertEqual(self.parser.normalize_imd(identifier), identifier)

    def test_memory_operand_parsing(self):
        """Test parsing of memory operands."""
        # Test 1: Basic memory operand
        parsed1 = self.parser.parse_instruction("vle8.v v1, (x11)")
        self.assertEqual(parsed1.mnemonic, "vle8.v")
        self.assertEqual(len(parsed1.operands), 2)
        self.assertEqual(parsed1.operands[0].name, "1")
        self.assertEqual(parsed1.operands[1].base.name, "11")
        self.assertIsNone(parsed1.operands[1].offset)
        self.assertIsNone(parsed1.operands[1].index)

        # Test 2: Memory operand with offset
        parsed2 = self.parser.parse_instruction("vle8.v v1, 8(x11)")
        self.assertEqual(parsed2.mnemonic, "vle8.v")
        self.assertEqual(len(parsed2.operands), 2)
        self.assertEqual(parsed2.operands[0].name, "1")
        self.assertEqual(parsed2.operands[1].base.name, "11")
        self.assertEqual(parsed2.operands[1].offset.value, 8)
        self.assertIsNone(parsed2.operands[1].index)

    ##################
    # Helper functions
    ##################

    def _get_comment(self, parser, comment):
        """Parse ``comment`` with the comment grammar and return its tokens joined by spaces."""
        parsed = parser.comment.parseString(comment, parseAll=True).asDict()
        return " ".join(parser.process_operand(parsed)["comment"])

    def _get_label(self, parser, label):
        """Parse ``label`` with the label grammar and return the processed result."""
        return parser.process_operand(parser.label.parseString(label, parseAll=True).asDict())

    def _get_directive(self, parser, directive):
        """Parse ``directive`` with the directive grammar and return the processed result."""
        return parser.process_operand(
            parser.directive.parseString(directive, parseAll=True).asDict()
        )

    @staticmethod
    def _find_file(name):
        """Return the path of ``name`` inside the ``test_files`` directory next to this module.

        Raises ``AssertionError`` if the file does not exist.
        """
        testdir = os.path.dirname(__file__)
        name = os.path.join(testdir, "test_files", name)
        assert os.path.exists(name), "test file not found: {}".format(name)
        return name
if __name__ == "__main__":
    # When run as a script, execute only this module's test case and print
    # one line per test (verbosity=2).
    suite = unittest.TestLoader().loadTestsFromTestCase(TestParserRISCV)
    unittest.TextTestRunner(verbosity=2).run(suite)