diff --git a/osaca/semantics/hw_model.py b/osaca/semantics/hw_model.py index b696e2a..ef6a61d 100755 --- a/osaca/semantics/hw_model.py +++ b/osaca/semantics/hw_model.py @@ -526,7 +526,7 @@ class MachineModel(object): if 'register' in operand: if i_operand['class'] != 'register': return False - return self._is_x86_reg_type(i_operand['name'], operand['register']) + return self._is_x86_reg_type(i_operand, operand['register'], consider_masking=True) # memory if 'memory' in operand: if i_operand['class'] != 'memory': @@ -546,7 +546,9 @@ class MachineModel(object): ) for key in operand_attributes: try: - if operand_1[key] != operand_2[key] and not any([x == self.WILDCARD for x in [operand_1[key], operand_2[key]]]): + if operand_1[key] != operand_2[key] and not any( + [x == self.WILDCARD for x in [operand_1[key], operand_2[key]]] + ): return False except KeyError: return False @@ -573,8 +575,9 @@ class MachineModel(object): return False return True - def _is_x86_reg_type(self, i_reg_name, reg): + def _is_x86_reg_type(self, i_reg, reg, consider_masking=False): """Check if register type match.""" + i_reg_name = i_reg if not consider_masking else i_reg['name'] # check for wildcards if i_reg_name == self.WILDCARD or reg['name'] == self.WILDCARD: return True @@ -582,6 +585,33 @@ class MachineModel(object): parser_x86 = ParserX86ATT() if parser_x86.is_vector_register(reg): if reg['name'].rstrip(string.digits).lower() == i_reg_name: + # Consider masking and zeroing for AVX512 + if consider_masking: + mask_ok = zero_ok = True + if 'mask' in reg or 'mask' in i_reg: + # one instruction is missing the masking while the other has it + mask_ok = False + # check for wildcard + if ( + ( + 'mask' in reg + and reg['mask'].rstrip(string.digits).lower() == i_reg.get('mask') + ) + or reg.get('mask') == self.WILDCARD + or i_reg.get('mask') == self.WILDCARD + ): + mask_ok = True + if bool('zeroing' in reg) ^ bool('zeroing' in i_reg): + # one instruction is missing zeroing while the other has it + zero_ok = False + # check for wildcard + if ( + i_reg.get('zeroing') == self.WILDCARD + or reg.get('zeroing') == self.WILDCARD + ): + zero_ok = True + if not mask_ok or not zero_ok: + return False return True else: if i_reg_name == 'gpr':