From 152360bad2c9cdcb17cf8b522541db3db63126c0 Mon Sep 17 00:00:00 2001 From: JanLJL Date: Mon, 19 Apr 2021 00:04:03 +0200 Subject: [PATCH] enhanced LCD analysis by making it parallel and added timeout flag --- osaca/frontend.py | 59 +++++++++--- osaca/osaca.py | 12 ++- osaca/semantics/kernel_dg.py | 175 ++++++++++++++++++++++++----------- 3 files changed, 174 insertions(+), 72 deletions(-) diff --git a/osaca/frontend.py b/osaca/frontend.py index f9dc030..65ca702 100755 --- a/osaca/frontend.py +++ b/osaca/frontend.py @@ -163,6 +163,7 @@ class Frontend(object): ignore_unknown=False, arch_warning=False, length_warning=False, + lcd_warning=False, verbose=False, ): """ @@ -176,17 +177,19 @@ class Frontend(object): :param ignore_unknown: flag for ignore warning if performance data is missing, defaults to `False` :type ignore_unknown: boolean, optional - :param print_arch_warning: flag for additional user warning to specify micro-arch - :type print_arch_warning: boolean, optional - :param print_length_warning: flag for additional user warning to specify kernel length with + :param arch_warning: flag for additional user warning to specify micro-arch + :type arch_warning: boolean, optional + :param length_warning: flag for additional user warning to specify kernel length with --lines - :type print_length_warning: boolean, optional + :type length_warning: boolean, optional + :param lcd_warning: flag for additional user warning due to LCD analysis timed out + :type lcd_warning: boolean, optional :param verbose: flag for verbosity level, defaults to False :type verbose: boolean, optional """ return ( self._header_report() - + self._user_warnings(arch_warning, length_warning) + + self._user_warnings_header(arch_warning, length_warning) + self._symbol_map() + self.combined_view( kernel, @@ -194,6 +197,7 @@ class Frontend(object): kernel_dg.get_loopcarried_dependencies(), ignore_unknown, ) + + self._user_warnings_footer(lcd_warning) + self.loopcarried_dependencies(kernel_dg.get_loopcarried_dependencies()) ) @@ -236,8 +240,9 @@ class Frontend(object): if dep_dict: longest_lcd = max(dep_dict, key=lambda ln: dep_dict[ln]['latency']) lcd_sum = dep_dict[longest_lcd]['latency'] - lcd_lines = {instr["line_number"]: lat - for instr, lat in dep_dict[longest_lcd]["dependencies"]} + lcd_lines = { + instr["line_number"]: lat for instr, lat in dep_dict[longest_lcd]["dependencies"] + } s += headline_str.format(headline) + "\n" s += ( @@ -311,18 +316,24 @@ class Frontend(object): ).format(amount, "-" * len(str(amount))) return s - def _user_warnings(self, arch_warning, length_warning): + def _user_warnings_header(self, arch_warning, length_warning): """Returns warning texts for giving the user more insight in what he is doing.""" + dashed_line = ( + "-------------------------------------------------------------------------" + "------------------------\n" + ) arch_text = ( - "WARNING: No micro-architecture was specified and a default uarch was used.\n" - " Specify the uarch with --arch. See --help for more information.\n" + "-------------------------- WARNING: No micro-architecture was specified " + "-------------------------\n" + " A default uarch for this particular ISA was used. Specify " + "the uarch with --arch.\n See --help for more information.\n" + dashed_line ) length_text = ( - "WARNING: You are analyzing a large amount of instruction forms. Analysis " - "across loops/block boundaries often do not make much sense.\n" - " Specify the kernel length with --length. See --help for more " - "information.\n" - " If this is intentional, you can safely ignore this message.\n" + "----------------- WARNING: You are analyzing a large amount of instruction forms " + "----------------\n Analysis across loops/block boundaries often do not make" + " much sense.\n Specify the kernel length with --length. See --help for more " + "information.\n If this is intentional, you can safely ignore this message.\n" + + dashed_line ) warnings = "" @@ -331,6 +342,24 @@ class Frontend(object): warnings += "\n" return warnings + def _user_warnings_footer(self, lcd_warning): + """Returns warning texts for giving the user more insight in what he is doing.""" + dashed_line = ( + "-------------------------------------------------------------------------" + "------------------------\n" + ) + lcd_text = ( + "-------------------------------- WARNING: LCD analysis timed out " + "-------------------------------\n While searching for all dependency chains" + " the analysis timed out.\n Decrease the number of instructions or set the " + "timeout threshold with --lcd-timeout.\n See --help for more " + "information.\n" + dashed_line + ) + warnings = "\n" + warnings += lcd_text if lcd_warning else "" + warnings += "\n" + return warnings + def _get_separator_list(self, separator, separator_2=" "): """Creates column view for seperators in the TP/combined view.""" separator_list = [] diff --git a/osaca/osaca.py b/osaca/osaca.py index 40b25d9..b97bcb7 100755 --- a/osaca/osaca.py +++ b/osaca/osaca.py @@ -146,6 +146,15 @@ def create_parser(parser=None): action="store_true", help="Ignore if instructions cannot be found in the data file and print analysis anyway.", ) + parser.add_argument( + "--lcd-timeout", + dest="lcd_timeout", + metavar="SECONDS", + type=int, + default=10, + help="Set timeout in seconds for LCD analysis. After timeout, OSACA will continue" + " its analysis with the dependency paths found up to this point. Defaults to 10.", + ) parser.add_argument( "--verbose", "-v", action="count", default=0, help="Increases verbosity level." ) @@ -303,7 +312,7 @@ def inspect(args, output_file=sys.stdout): semantics.assign_optimal_throughput(kernel) # Create DiGrahps - kernel_graph = KernelDG(kernel, parser, machine_model, semantics) + kernel_graph = KernelDG(kernel, parser, machine_model, semantics, args.lcd_timeout) if args.dotpath is not None: kernel_graph.export_graph(args.dotpath if args.dotpath != "." else None) # Print analysis @@ -315,6 +324,7 @@ def inspect(args, output_file=sys.stdout): ignore_unknown=ignore_unknown, arch_warning=print_arch_warning, length_warning=print_length_warning, + lcd_warning=kernel_graph.timed_out, verbose=verbose, ), file=output_file, diff --git a/osaca/semantics/kernel_dg.py b/osaca/semantics/kernel_dg.py index dd6d5c3..94a5dfc 100755 --- a/osaca/semantics/kernel_dg.py +++ b/osaca/semantics/kernel_dg.py @@ -1,22 +1,39 @@ #!/usr/bin/env python3 import copy -from itertools import chain, product +import time from collections import defaultdict +from itertools import accumulate, chain, product +from multiprocessing import Manager, Process, cpu_count import networkx as nx - from osaca.parser import AttrDict -from osaca.semantics import INSTR_FLAGS, MachineModel, ArchSemantics +from osaca.semantics import INSTR_FLAGS, ArchSemantics, MachineModel + class KernelDG(nx.DiGraph): - def __init__(self, parsed_kernel, parser, hw_model: MachineModel, semantics: ArchSemantics): + # threshold for checking dependency graph sequential or in parallel + INSTRUCTION_THRESHOLD = 50 + + def __init__( + self, parsed_kernel, parser, hw_model: MachineModel, semantics: ArchSemantics, timeout=10 + ): + self.timed_out = False self.kernel = parsed_kernel self.parser = parser self.model = hw_model self.arch_sem = semantics self.dg = self.create_DG(self.kernel) - self.loopcarried_deps = self.check_for_loopcarried_dep(self.kernel) + self.loopcarried_deps = self.check_for_loopcarried_dep(self.kernel, timeout) + + def _extend_path(self, dst_list, kernel, dg, offset): + for instr in kernel: + generator_path = nx.algorithms.simple_paths.all_simple_paths( + dg, instr.line_number, instr.line_number + offset + ) + tmp_list = list(generator_path) + dst_list.extend(tmp_list) + # print('Thread [{}-{}] done'.format(kernel[0]['line_number'], kernel[-1]['line_number'])) def create_DG(self, kernel): """ @@ -65,17 +82,19 @@ class KernelDG(nx.DiGraph): dg.nodes[dep["line_number"]]["instruction_form"] = dep return dg - def check_for_loopcarried_dep(self, kernel): + def check_for_loopcarried_dep(self, kernel, timeout=10): """ Try to find loop-carried dependencies in given kernel. :param kernel: Parsed asm kernel with assigned semantic information :type kernel: list + :param timeout: Timeout in seconds for parallel execution, defaults + to `10`. Set to `0` for no timeout + :type timeout: int :returns: `dict` -- dependency dictionary with all cyclic LCDs """ # increase line number for second kernel loop offset = max(1000, max([i.line_number for i in kernel])) - first_line_no = kernel[0].line_number tmp_kernel = [] + kernel for orig_iform in kernel: temp_iform = copy.copy(orig_iform) @@ -86,34 +105,72 @@ class KernelDG(nx.DiGraph): # build cyclic loop-carried dependencies loopcarried_deps = [] - paths = [] - for instr in kernel: - paths.append(nx.algorithms.simple_paths.all_simple_paths( - dg, instr.line_number, instr.line_number + offset)) + all_paths = [] + + klen = len(kernel) + if klen >= self.INSTRUCTION_THRESHOLD: + # parallel execution with static scheduling + num_cores = cpu_count() + workload = int((klen - 1) / num_cores) + 1 + starts = [tid * workload for tid in range(num_cores)] + ends = [min((tid + 1) * workload, klen) for tid in range(num_cores)] + instrs = [kernel[s:e] for s, e in zip(starts, ends)] + with Manager() as manager: + all_paths = manager.list() + processes = [ + Process(target=self._extend_path, args=(all_paths, instr_section, dg, offset)) + for instr_section in instrs + ] + for p in processes: + p.start() + start_time = time.time() + while time.time() - start_time <= timeout: + if any(p.is_alive() for p in processes): + time.sleep(0.2) + else: + # all procs done + for p in processes: + p.join() + break + else: + self.timed_out = True + # terminate running processes + for p in processes: + if p.is_alive(): + p.kill() + p.join() + all_paths = list(all_paths) + else: + # sequential execution to avoid overhead when analyzing smaller kernels + for instr in kernel: + all_paths.extend( + nx.algorithms.simple_paths.all_simple_paths( + dg, instr.line_number, instr.line_number + offset + ) + ) paths_set = set() - for path_gen in paths: - for path in path_gen: - lat_sum = 0.0 - # extend path by edge bound latencies (e.g., store-to-load latency) - lat_path = [] - for s, d in nx.utils.pairwise(path): - edge_lat = dg.edges[s, d]['latency'] - # map source node back to original line numbers - if s >= offset: - s -= offset - lat_path.append((s, edge_lat)) - lat_sum += edge_lat - if d >= offset: - d -= offset - lat_path.sort() + for path in all_paths: + lat_sum = 0.0 + # extend path by edge bound latencies (e.g., store-to-load latency) + lat_path = [] + for s, d in nx.utils.pairwise(path): + edge_lat = dg.edges[s, d]['latency'] + # map source node back to original line numbers + if s >= offset: + s -= offset + lat_path.append((s, edge_lat)) + lat_sum += edge_lat + if d >= offset: + d -= offset + lat_path.sort() - # Ignore duplicate paths which differ only in the root node - if tuple(lat_path) in paths_set: - continue - paths_set.add(tuple(lat_path)) + # Ignore duplicate paths which differ only in the root node + if tuple(lat_path) in paths_set: + continue + paths_set.add(tuple(lat_path)) - loopcarried_deps.append((lat_sum, lat_path)) + loopcarried_deps.append((lat_sum, lat_path)) loopcarried_deps.sort(reverse=True) # map lcd back to nodes @@ -121,8 +178,10 @@ class KernelDG(nx.DiGraph): for lat_sum, involved_lines in loopcarried_deps: loopcarried_deps_dict[involved_lines[0][0]] = { "root": self._get_node_by_lineno(involved_lines[0][0]), - "dependencies": [(self._get_node_by_lineno(ln), lat) for ln, lat in involved_lines], - "latency": lat_sum + "dependencies": [ + (self._get_node_by_lineno(ln), lat) for ln, lat in involved_lines + ], + "latency": lat_sum, } return loopcarried_deps_dict @@ -168,9 +227,7 @@ class KernelDG(nx.DiGraph): # split to DAG raise NotImplementedError("Kernel is cyclic.") - def find_depending( - self, instruction_form, instructions, flag_dependencies=False - ): + def find_depending(self, instruction_form, instructions, flag_dependencies=False): """ Find instructions in `instructions` depending on a given instruction form's results. @@ -190,15 +247,15 @@ class KernelDG(nx.DiGraph): # TODO instructions before must be considered as well, if they update registers # not used by insruction_form. E.g., validation/build/A64FX/gcc/O1/gs-2d-5pt.marked.s register_changes = self._update_reg_changes(instruction_form) - #print("FROM", instruction_form.line, register_changes) + # print("FROM", instruction_form.line, register_changes) for i, instr_form in enumerate(instructions): self._update_reg_changes(instr_form, register_changes) - #print(" TO", instr_form.line, register_changes) + # print(" TO", instr_form.line, register_changes) if "register" in dst: # read of register if self.is_read(dst.register, instr_form) and not ( - dst.get("pre_indexed", False) or - dst.get("post_indexed", False)): + dst.get("pre_indexed", False) or dst.get("post_indexed", False) + ): yield instr_form, [] # write to register -> abort if self.is_written(dst.register, instr_form): @@ -215,10 +272,10 @@ class KernelDG(nx.DiGraph): if "pre_indexed" in dst.memory: if self.is_written(dst.memory.base, instr_form): break - #if dst.memory.base: + # if dst.memory.base: # if self.is_read(dst.memory.base, instr_form): # yield instr_form, [] - #if dst.memory.index: + # if dst.memory.index: # if self.is_read(dst.memory.index, instr_form): # yield instr_form, [] if "post_indexed" in dst.memory: @@ -226,7 +283,7 @@ class KernelDG(nx.DiGraph): if self.is_written(dst.memory.base, instr_form): break # TODO record register changes - # (e.g., mov, leaadd, sub, inc, dec) in instructions[:i] + # (e.g., mov, leaadd, sub, inc, dec) in instructions[:i] # and pass to is_memload and is_memstore to consider relevance. # load from same location (presumed) if self.is_memload(dst.memory, instr_form, register_changes): @@ -286,7 +343,9 @@ class KernelDG(nx.DiGraph): if src.memory.base is not None: is_read = self.parser.is_reg_dependend_of(register, src.memory.base) or is_read if src.memory.index is not None: - is_read = self.parser.is_reg_dependend_of(register, src.memory.index) or is_read + is_read = ( + self.parser.is_reg_dependend_of(register, src.memory.index) or is_read + ) # Check also if read in destination memory address for dst in chain( instruction_form.semantic_operands.destination, @@ -296,7 +355,9 @@ class KernelDG(nx.DiGraph): if dst.memory.base is not None: is_read = self.parser.is_reg_dependend_of(register, dst.memory.base) or is_read if dst.memory.index is not None: - is_read = self.parser.is_reg_dependend_of(register, dst.memory.index) or is_read + is_read = ( + self.parser.is_reg_dependend_of(register, dst.memory.index) or is_read + ) return is_read def is_memload(self, mem, instruction_form, register_changes={}): @@ -319,36 +380,38 @@ class KernelDG(nx.DiGraph): addr_change -= int(mem.offset.value, 0) if mem.base and src.base: base_change = register_changes.get( - src.base.get('prefix', '')+src.base.name, - {'name': src.base.get('prefix', '')+src.base.name, 'value': 0}) + src.base.get('prefix', '') + src.base.name, + {'name': src.base.get('prefix', '') + src.base.name, 'value': 0}, + ) if base_change is None: # Unknown change occurred continue - if mem.base.get('prefix', '')+mem.base['name'] != base_change['name']: + if mem.base.get('prefix', '') + mem.base['name'] != base_change['name']: # base registers do not match continue addr_change += base_change['value'] elif mem.base or src.base: - # base registers do not match - continue + # base registers do not match + continue if mem.index and src.index: index_change = register_changes.get( - src.index.get('prefix', '')+src.index.name, - {'name': src.index.get('prefix', '')+src.index.name, 'value': 0}) + src.index.get('prefix', '') + src.index.name, + {'name': src.index.get('prefix', '') + src.index.name, 'value': 0}, + ) if index_change is None: # Unknown change occurred continue if mem.scale != src.scale: # scale factors do not match continue - if mem.index.get('prefix', '')+mem.index['name'] != index_change['name']: + if mem.index.get('prefix', '') + mem.index['name'] != index_change['name']: # index registers do not match continue addr_change += index_change['value'] * src.scale elif mem.index or src.index: - # index registers do not match - continue - #if instruction_form.line_number == 3: + # index registers do not match + continue + # if instruction_form.line_number == 3: if addr_change == 0: return True return False