Enhanced LCD analysis by making it parallel and added a timeout flag

This commit is contained in:
JanLJL
2021-04-19 00:04:03 +02:00
parent 3d580960b6
commit 5b95f1f909
3 changed files with 174 additions and 72 deletions

View File

@@ -163,6 +163,7 @@ class Frontend(object):
ignore_unknown=False,
arch_warning=False,
length_warning=False,
lcd_warning=False,
verbose=False,
):
"""
@@ -176,17 +177,19 @@ class Frontend(object):
:param ignore_unknown: flag for ignore warning if performance data is missing, defaults to
`False`
:type ignore_unknown: boolean, optional
:param print_arch_warning: flag for additional user warning to specify micro-arch
:type print_arch_warning: boolean, optional
:param print_length_warning: flag for additional user warning to specify kernel length with
:param arch_warning: flag for additional user warning to specify micro-arch
:type arch_warning: boolean, optional
:param length_warning: flag for additional user warning to specify kernel length with
--lines
:type print_length_warning: boolean, optional
:type length_warning: boolean, optional
:param lcd_warning: flag for additional user warning due to LCD analysis timed out
:type lcd_warning: boolean, optional
:param verbose: flag for verbosity level, defaults to False
:type verbose: boolean, optional
"""
return (
self._header_report()
+ self._user_warnings(arch_warning, length_warning)
+ self._user_warnings_header(arch_warning, length_warning)
+ self._symbol_map()
+ self.combined_view(
kernel,
@@ -194,6 +197,7 @@ class Frontend(object):
kernel_dg.get_loopcarried_dependencies(),
ignore_unknown,
)
+ self._user_warnings_footer(lcd_warning)
+ self.loopcarried_dependencies(kernel_dg.get_loopcarried_dependencies())
)
@@ -236,8 +240,9 @@ class Frontend(object):
if dep_dict:
longest_lcd = max(dep_dict, key=lambda ln: dep_dict[ln]['latency'])
lcd_sum = dep_dict[longest_lcd]['latency']
lcd_lines = {instr["line_number"]: lat
for instr, lat in dep_dict[longest_lcd]["dependencies"]}
lcd_lines = {
instr["line_number"]: lat for instr, lat in dep_dict[longest_lcd]["dependencies"]
}
s += headline_str.format(headline) + "\n"
s += (
@@ -311,18 +316,24 @@ class Frontend(object):
).format(amount, "-" * len(str(amount)))
return s
def _user_warnings(self, arch_warning, length_warning):
def _user_warnings_header(self, arch_warning, length_warning):
"""Returns warning texts for giving the user more insight in what he is doing."""
dashed_line = (
"-------------------------------------------------------------------------"
"------------------------\n"
)
arch_text = (
"WARNING: No micro-architecture was specified and a default uarch was used.\n"
" Specify the uarch with --arch. See --help for more information.\n"
"-------------------------- WARNING: No micro-architecture was specified "
"-------------------------\n"
" A default uarch for this particular ISA was used. Specify "
"the uarch with --arch.\n See --help for more information.\n" + dashed_line
)
length_text = (
"WARNING: You are analyzing a large amount of instruction forms. Analysis "
"across loops/block boundaries often do not make much sense.\n"
" Specify the kernel length with --length. See --help for more "
"information.\n"
" If this is intentional, you can safely ignore this message.\n"
"----------------- WARNING: You are analyzing a large amount of instruction forms "
"----------------\n Analysis across loops/block boundaries often do not make"
" much sense.\n Specify the kernel length with --length. See --help for more "
"information.\n If this is intentional, you can safely ignore this message.\n"
+ dashed_line
)
warnings = ""
@@ -331,6 +342,24 @@ class Frontend(object):
warnings += "\n"
return warnings
def _user_warnings_footer(self, lcd_warning):
"""Returns warning texts for giving the user more insight in what he is doing."""
dashed_line = (
"-------------------------------------------------------------------------"
"------------------------\n"
)
lcd_text = (
"-------------------------------- WARNING: LCD analysis timed out "
"-------------------------------\n While searching for all dependency chains"
" the analysis timed out.\n Decrease the number of instructions or set the "
"timeout threshold with --lcd-timeout.\n See --help for more "
"information.\n" + dashed_line
)
warnings = "\n"
warnings += lcd_text if lcd_warning else ""
warnings += "\n"
return warnings
def _get_separator_list(self, separator, separator_2=" "):
"""Creates column view for seperators in the TP/combined view."""
separator_list = []

View File

@@ -146,6 +146,15 @@ def create_parser(parser=None):
action="store_true",
help="Ignore if instructions cannot be found in the data file and print analysis anyway.",
)
parser.add_argument(
"--lcd-timeout",
dest="lcd_timeout",
metavar="SECONDS",
type=int,
default=10,
help="Set timeout in seconds for LCD analysis. After timeout, OSACA will continue"
" its analysis with the dependency paths found up to this point. Defaults to 10.",
)
parser.add_argument(
"--verbose", "-v", action="count", default=0, help="Increases verbosity level."
)
@@ -303,7 +312,7 @@ def inspect(args, output_file=sys.stdout):
semantics.assign_optimal_throughput(kernel)
# Create DiGrahps
kernel_graph = KernelDG(kernel, parser, machine_model, semantics)
kernel_graph = KernelDG(kernel, parser, machine_model, semantics, args.lcd_timeout)
if args.dotpath is not None:
kernel_graph.export_graph(args.dotpath if args.dotpath != "." else None)
# Print analysis
@@ -315,6 +324,7 @@ def inspect(args, output_file=sys.stdout):
ignore_unknown=ignore_unknown,
arch_warning=print_arch_warning,
length_warning=print_length_warning,
lcd_warning=kernel_graph.timed_out,
verbose=verbose,
),
file=output_file,

View File

@@ -1,22 +1,39 @@
#!/usr/bin/env python3
import copy
from itertools import chain, product
import time
from collections import defaultdict
from itertools import accumulate, chain, product
from multiprocessing import Manager, Process, cpu_count
import networkx as nx
from osaca.parser import AttrDict
from osaca.semantics import INSTR_FLAGS, MachineModel, ArchSemantics
from osaca.semantics import INSTR_FLAGS, ArchSemantics, MachineModel
class KernelDG(nx.DiGraph):
def __init__(self, parsed_kernel, parser, hw_model: MachineModel, semantics: ArchSemantics):
# threshold for checking dependency graph sequential or in parallel
INSTRUCTION_THRESHOLD = 50
def __init__(
self, parsed_kernel, parser, hw_model: MachineModel, semantics: ArchSemantics, timeout=10
):
self.timed_out = False
self.kernel = parsed_kernel
self.parser = parser
self.model = hw_model
self.arch_sem = semantics
self.dg = self.create_DG(self.kernel)
self.loopcarried_deps = self.check_for_loopcarried_dep(self.kernel)
self.loopcarried_deps = self.check_for_loopcarried_dep(self.kernel, timeout)
def _extend_path(self, dst_list, kernel, dg, offset):
for instr in kernel:
generator_path = nx.algorithms.simple_paths.all_simple_paths(
dg, instr.line_number, instr.line_number + offset
)
tmp_list = list(generator_path)
dst_list.extend(tmp_list)
# print('Thread [{}-{}] done'.format(kernel[0]['line_number'], kernel[-1]['line_number']))
def create_DG(self, kernel):
"""
@@ -65,17 +82,19 @@ class KernelDG(nx.DiGraph):
dg.nodes[dep["line_number"]]["instruction_form"] = dep
return dg
def check_for_loopcarried_dep(self, kernel):
def check_for_loopcarried_dep(self, kernel, timeout=10):
"""
Try to find loop-carried dependencies in given kernel.
:param kernel: Parsed asm kernel with assigned semantic information
:type kernel: list
:param timeout: Timeout in seconds for parallel execution, defaults
to `10`. Set to `0` for no timeout
:type timeout: int
:returns: `dict` -- dependency dictionary with all cyclic LCDs
"""
# increase line number for second kernel loop
offset = max(1000, max([i.line_number for i in kernel]))
first_line_no = kernel[0].line_number
tmp_kernel = [] + kernel
for orig_iform in kernel:
temp_iform = copy.copy(orig_iform)
@@ -86,34 +105,72 @@ class KernelDG(nx.DiGraph):
# build cyclic loop-carried dependencies
loopcarried_deps = []
paths = []
for instr in kernel:
paths.append(nx.algorithms.simple_paths.all_simple_paths(
dg, instr.line_number, instr.line_number + offset))
all_paths = []
klen = len(kernel)
if klen >= self.INSTRUCTION_THRESHOLD:
# parallel execution with static scheduling
num_cores = cpu_count()
workload = int((klen - 1) / num_cores) + 1
starts = [tid * workload for tid in range(num_cores)]
ends = [min((tid + 1) * workload, klen) for tid in range(num_cores)]
instrs = [kernel[s:e] for s, e in zip(starts, ends)]
with Manager() as manager:
all_paths = manager.list()
processes = [
Process(target=self._extend_path, args=(all_paths, instr_section, dg, offset))
for instr_section in instrs
]
for p in processes:
p.start()
start_time = time.time()
while time.time() - start_time <= timeout:
if any(p.is_alive() for p in processes):
time.sleep(0.2)
else:
# all procs done
for p in processes:
p.join()
break
else:
self.timed_out = True
# terminate running processes
for p in processes:
if p.is_alive():
p.kill()
p.join()
all_paths = list(all_paths)
else:
# sequential execution to avoid overhead when analyzing smaller kernels
for instr in kernel:
all_paths.extend(
nx.algorithms.simple_paths.all_simple_paths(
dg, instr.line_number, instr.line_number + offset
)
)
paths_set = set()
for path_gen in paths:
for path in path_gen:
lat_sum = 0.0
# extend path by edge bound latencies (e.g., store-to-load latency)
lat_path = []
for s, d in nx.utils.pairwise(path):
edge_lat = dg.edges[s, d]['latency']
# map source node back to original line numbers
if s >= offset:
s -= offset
lat_path.append((s, edge_lat))
lat_sum += edge_lat
if d >= offset:
d -= offset
lat_path.sort()
for path in all_paths:
lat_sum = 0.0
# extend path by edge bound latencies (e.g., store-to-load latency)
lat_path = []
for s, d in nx.utils.pairwise(path):
edge_lat = dg.edges[s, d]['latency']
# map source node back to original line numbers
if s >= offset:
s -= offset
lat_path.append((s, edge_lat))
lat_sum += edge_lat
if d >= offset:
d -= offset
lat_path.sort()
# Ignore duplicate paths which differ only in the root node
if tuple(lat_path) in paths_set:
continue
paths_set.add(tuple(lat_path))
# Ignore duplicate paths which differ only in the root node
if tuple(lat_path) in paths_set:
continue
paths_set.add(tuple(lat_path))
loopcarried_deps.append((lat_sum, lat_path))
loopcarried_deps.append((lat_sum, lat_path))
loopcarried_deps.sort(reverse=True)
# map lcd back to nodes
@@ -121,8 +178,10 @@ class KernelDG(nx.DiGraph):
for lat_sum, involved_lines in loopcarried_deps:
loopcarried_deps_dict[involved_lines[0][0]] = {
"root": self._get_node_by_lineno(involved_lines[0][0]),
"dependencies": [(self._get_node_by_lineno(ln), lat) for ln, lat in involved_lines],
"latency": lat_sum
"dependencies": [
(self._get_node_by_lineno(ln), lat) for ln, lat in involved_lines
],
"latency": lat_sum,
}
return loopcarried_deps_dict
@@ -168,9 +227,7 @@ class KernelDG(nx.DiGraph):
# split to DAG
raise NotImplementedError("Kernel is cyclic.")
def find_depending(
self, instruction_form, instructions, flag_dependencies=False
):
def find_depending(self, instruction_form, instructions, flag_dependencies=False):
"""
Find instructions in `instructions` depending on a given instruction form's results.
@@ -190,15 +247,15 @@ class KernelDG(nx.DiGraph):
# TODO instructions before must be considered as well, if they update registers
# not used by insruction_form. E.g., validation/build/A64FX/gcc/O1/gs-2d-5pt.marked.s
register_changes = self._update_reg_changes(instruction_form)
#print("FROM", instruction_form.line, register_changes)
# print("FROM", instruction_form.line, register_changes)
for i, instr_form in enumerate(instructions):
self._update_reg_changes(instr_form, register_changes)
#print(" TO", instr_form.line, register_changes)
# print(" TO", instr_form.line, register_changes)
if "register" in dst:
# read of register
if self.is_read(dst.register, instr_form) and not (
dst.get("pre_indexed", False) or
dst.get("post_indexed", False)):
dst.get("pre_indexed", False) or dst.get("post_indexed", False)
):
yield instr_form, []
# write to register -> abort
if self.is_written(dst.register, instr_form):
@@ -215,10 +272,10 @@ class KernelDG(nx.DiGraph):
if "pre_indexed" in dst.memory:
if self.is_written(dst.memory.base, instr_form):
break
#if dst.memory.base:
# if dst.memory.base:
# if self.is_read(dst.memory.base, instr_form):
# yield instr_form, []
#if dst.memory.index:
# if dst.memory.index:
# if self.is_read(dst.memory.index, instr_form):
# yield instr_form, []
if "post_indexed" in dst.memory:
@@ -226,7 +283,7 @@ class KernelDG(nx.DiGraph):
if self.is_written(dst.memory.base, instr_form):
break
# TODO record register changes
# (e.g., mov, leaadd, sub, inc, dec) in instructions[:i]
# (e.g., mov, leaadd, sub, inc, dec) in instructions[:i]
# and pass to is_memload and is_memstore to consider relevance.
# load from same location (presumed)
if self.is_memload(dst.memory, instr_form, register_changes):
@@ -286,7 +343,9 @@ class KernelDG(nx.DiGraph):
if src.memory.base is not None:
is_read = self.parser.is_reg_dependend_of(register, src.memory.base) or is_read
if src.memory.index is not None:
is_read = self.parser.is_reg_dependend_of(register, src.memory.index) or is_read
is_read = (
self.parser.is_reg_dependend_of(register, src.memory.index) or is_read
)
# Check also if read in destination memory address
for dst in chain(
instruction_form.semantic_operands.destination,
@@ -296,7 +355,9 @@ class KernelDG(nx.DiGraph):
if dst.memory.base is not None:
is_read = self.parser.is_reg_dependend_of(register, dst.memory.base) or is_read
if dst.memory.index is not None:
is_read = self.parser.is_reg_dependend_of(register, dst.memory.index) or is_read
is_read = (
self.parser.is_reg_dependend_of(register, dst.memory.index) or is_read
)
return is_read
def is_memload(self, mem, instruction_form, register_changes={}):
@@ -319,36 +380,38 @@ class KernelDG(nx.DiGraph):
addr_change -= int(mem.offset.value, 0)
if mem.base and src.base:
base_change = register_changes.get(
src.base.get('prefix', '')+src.base.name,
{'name': src.base.get('prefix', '')+src.base.name, 'value': 0})
src.base.get('prefix', '') + src.base.name,
{'name': src.base.get('prefix', '') + src.base.name, 'value': 0},
)
if base_change is None:
# Unknown change occurred
continue
if mem.base.get('prefix', '')+mem.base['name'] != base_change['name']:
if mem.base.get('prefix', '') + mem.base['name'] != base_change['name']:
# base registers do not match
continue
addr_change += base_change['value']
elif mem.base or src.base:
# base registers do not match
continue
# base registers do not match
continue
if mem.index and src.index:
index_change = register_changes.get(
src.index.get('prefix', '')+src.index.name,
{'name': src.index.get('prefix', '')+src.index.name, 'value': 0})
src.index.get('prefix', '') + src.index.name,
{'name': src.index.get('prefix', '') + src.index.name, 'value': 0},
)
if index_change is None:
# Unknown change occurred
continue
if mem.scale != src.scale:
# scale factors do not match
continue
if mem.index.get('prefix', '')+mem.index['name'] != index_change['name']:
if mem.index.get('prefix', '') + mem.index['name'] != index_change['name']:
# index registers do not match
continue
addr_change += index_change['value'] * src.scale
elif mem.index or src.index:
# index registers do not match
continue
#if instruction_form.line_number == 3:
# index registers do not match
continue
# if instruction_form.line_number == 3:
if addr_change == 0:
return True
return False