diff --git a/osaca/semantics/kernel_dg.py b/osaca/semantics/kernel_dg.py index 2dd46fb..c135b02 100644 --- a/osaca/semantics/kernel_dg.py +++ b/osaca/semantics/kernel_dg.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import copy +from enum import Enum import time from itertools import chain, groupby from multiprocessing import Manager, Process, cpu_count @@ -17,6 +18,11 @@ class KernelDG(nx.DiGraph): # threshold for checking dependency graph sequential or in parallel INSTRUCTION_THRESHOLD = 50 + class ReadKind(Enum): + NOT_A_READ = 0 + READ_FOR_LOAD = 1 + OTHER_READ = 2 + def __init__( self, parsed_kernel, @@ -46,6 +52,11 @@ class KernelDG(nx.DiGraph): dst_list.extend(tmp_list) # print('Thread [{}-{}] done'.format(kernel[0]['line_number'], kernel[-1]['line_number'])) + @staticmethod + def get_load_line_number(line_number): + # The offset is irrelevant, but it must be a machine number to avoid silly rounding issues. + return line_number + 0.125 + def create_DG(self, kernel, flag_dependencies=False): """ Create directed graph from given kernel @@ -61,26 +72,37 @@ class KernelDG(nx.DiGraph): # 2. find edges (to dependend further instruction) # 3. get LT value and set as edge weight dg = nx.DiGraph() + loads = {} for i, instruction_form in enumerate(kernel): dg.add_node(instruction_form.line_number) dg.nodes[instruction_form.line_number]["instruction_form"] = instruction_form # add load as separate node if existent + load_line_number = None if ( INSTR_FLAGS.HAS_LD in instruction_form.flags and INSTR_FLAGS.LD not in instruction_form.flags ): # add new node - dg.add_node(instruction_form.line_number + 0.1) - dg.nodes[instruction_form.line_number + 0.1]["instruction_form"] = instruction_form + load_line_number = KernelDG.get_load_line_number(instruction_form.line_number) + loads[instruction_form.line_number] = load_line_number + dg.add_node(load_line_number) + dg.nodes[load_line_number]["instruction_form"] = instruction_form # and set LD latency as edge weight dg.add_edge( - instruction_form.line_number + 0.1, + load_line_number, instruction_form.line_number, latency=instruction_form.latency - instruction_form.latency_wo_load, ) + #TODO comments + #print("LOADS", loads) + for i, instruction_form in enumerate(kernel): for dep, dep_flags in self.find_depending( instruction_form, kernel[i + 1 :], flag_dependencies ): + #if instruction_form.mnemonic == 'shl': + # print("IF", instruction_form) + # print("DEP", dep) + # print("DF", dep_flags) # print(instruction_form.line_number,"\t",dep.line_number,"\n") edge_weight = ( instruction_form.latency @@ -91,11 +113,19 @@ class KernelDG(nx.DiGraph): edge_weight += self.model.get("store_to_load_forward_latency", 0) if "p_indexed" in dep_flags and self.model is not None: edge_weight = self.model.get("p_index_latency", 1) - dg.add_edge( - instruction_form.line_number, - dep.line_number, - latency=edge_weight, - ) + if "for_load" in dep_flags and self.model is not None and dep.line_number in loads: + #print("LOADDEP", instruction_form.line_number, loads[dep.line_number]) + dg.add_edge( + instruction_form.line_number, + loads[dep.line_number], + latency=edge_weight, + ) + else: + dg.add_edge( + instruction_form.line_number, + dep.line_number, + latency=edge_weight, + ) dg.nodes[dep.line_number]["instruction_form"] = dep return dg @@ -214,9 +244,12 @@ class KernelDG(nx.DiGraph): def _get_node_by_lineno(self, lineno, kernel=None, all=False): """Return instruction form with line number ``lineno`` from kernel""" + #print(lineno) if kernel is None: kernel = self.kernel - result = [instr for instr in kernel if instr.line_number == lineno] + result = [instr for instr in kernel + if instr.line_number == lineno + or KernelDG.get_load_line_number(instr.line_number) == lineno] if not all: return result[0] else: @@ -284,15 +317,18 @@ class KernelDG(nx.DiGraph): # print(" TO", instr_form.line, register_changes) if isinstance(dst, RegisterOperand): # read of register - if self.is_read(dst, instr_form): + read_kind = self._read_kind(dst, instr_form) + if read_kind != KernelDG.ReadKind.NOT_A_READ: + dep_flags = [] if ( dst.pre_indexed or dst.post_indexed or (isinstance(dst.post_indexed, dict)) ): - yield instr_form, ["p_indexed"] - else: - yield instr_form, [] + dep_flags = ["p_indexed"] + if read_kind == KernelDG.ReadKind.READ_FOR_LOAD: + dep_flags += ["for_load"] + yield instr_form, dep_flags # write to register -> abort if self.is_written(dst, instr_form): break @@ -363,11 +399,12 @@ class KernelDG(nx.DiGraph): return self.dg.successors(line_number) return iter([]) - def is_read(self, register, instruction_form): - """Check if instruction form reads from given register""" + def _read_kind(self, register, instruction_form): + """Check if instruction form reads from given register. Returns a ReadKind.""" is_read = False + for_load = False if instruction_form.semantic_operands is None: - return is_read + return KernelDG.ReadKind.NOT_A_READ for src in chain( instruction_form.semantic_operands["source"], instruction_form.semantic_operands["src_dst"], @@ -377,10 +414,15 @@ class KernelDG(nx.DiGraph): if isinstance(src, FlagOperand): is_read = self.parser.is_flag_dependend_of(register, src) or is_read if isinstance(src, MemoryOperand): + #print("1", is_read) if src.base is not None: is_read = self.parser.is_reg_dependend_of(register, src.base) or is_read + #print("2", is_read) if src.index is not None and isinstance(src.index, RegisterOperand): is_read = self.parser.is_reg_dependend_of(register, src.index) or is_read + #print("3", is_read) + #print("FORLOAD", register, src) + for_load = True # Check also if read in destination memory address for dst in chain( instruction_form.semantic_operands["destination"], @@ -391,7 +433,16 @@ class KernelDG(nx.DiGraph): is_read = self.parser.is_reg_dependend_of(register, dst.base) or is_read if dst.index is not None: is_read = self.parser.is_reg_dependend_of(register, dst.index) or is_read - return is_read + if is_read: + if for_load: + return KernelDG.ReadKind.READ_FOR_LOAD + else: + return KernelDG.ReadKind.OTHER_READ + else: + return KernelDG.ReadKind.NOT_A_READ + + def is_read(self, register, instruction_form): + return self._read_kind(register, instruction_form) != KernelDG.ReadKind.NOT_A_READ def is_memload(self, mem, instruction_form, register_changes={}): """Check if instruction form loads from given location, assuming register_changes"""