mirror of
https://github.com/RRZE-HPC/OSACA.git
synced 2026-01-05 10:40:06 +01:00
enhanced by optimal throughput analysis
This commit is contained in:
@@ -3,10 +3,7 @@
|
||||
import re
|
||||
from datetime import datetime as dt
|
||||
|
||||
from ruamel import yaml
|
||||
|
||||
from osaca import utils
|
||||
from osaca.semantics import INSTR_FLAGS, KernelDG, ArchSemantics
|
||||
from osaca.semantics import INSTR_FLAGS, ArchSemantics, KernelDG, MachineModel
|
||||
|
||||
|
||||
class Frontend(object):
|
||||
@@ -19,11 +16,10 @@ class Frontend(object):
|
||||
self._arch = arch
|
||||
if arch:
|
||||
self._arch = arch.lower()
|
||||
with open(utils.find_file(self._arch + '.yml'), 'r') as f:
|
||||
self._data = yaml.load(f, Loader=yaml.Loader)
|
||||
self._machine_model = MachineModel(arch=arch)
|
||||
elif path_to_yaml:
|
||||
with open(path_to_yaml, 'r') as f:
|
||||
self._data = yaml.load(f, Loader=yaml.Loader)
|
||||
self._machine_model = MachineModel(path_to_yaml=path_to_yaml)
|
||||
self._arch = self._machine_model.get_arch()
|
||||
|
||||
def _is_comment(self, instruction_form):
|
||||
return instruction_form['comment'] is not None and instruction_form['instruction'] is None
|
||||
@@ -45,7 +41,9 @@ class Frontend(object):
|
||||
for instruction_form in kernel:
|
||||
line = '{:4d} {} {} {}'.format(
|
||||
instruction_form['line_number'],
|
||||
self._get_port_pressure(instruction_form['port_pressure'], port_len, sep_list),
|
||||
self._get_port_pressure(
|
||||
instruction_form['port_pressure'], port_len, separator=sep_list
|
||||
),
|
||||
self._get_flag_symbols(instruction_form['flags'])
|
||||
if instruction_form['instruction'] is not None
|
||||
else ' ',
|
||||
@@ -57,7 +55,7 @@ class Frontend(object):
|
||||
print(line)
|
||||
print()
|
||||
tp_sum = ArchSemantics.get_throughput_sum(kernel)
|
||||
print(lineno_filler + self._get_port_pressure(tp_sum, port_len, ' '))
|
||||
print(lineno_filler + self._get_port_pressure(tp_sum, port_len, separator=' '))
|
||||
|
||||
def print_latency_analysis(self, cp_kernel, separator='|'):
|
||||
print('\n\nLatency Analysis Report\n' + '-----------------------')
|
||||
@@ -149,9 +147,13 @@ class Frontend(object):
|
||||
if show_cmnts is False and self._is_comment(instruction_form):
|
||||
continue
|
||||
line_number = instruction_form['line_number']
|
||||
used_ports = [list(uops[1]) for uops in instruction_form['port_uops']]
|
||||
used_ports = list(set([p for uops_ports in used_ports for p in uops_ports]))
|
||||
line = '{:4d} {}{} {} {}'.format(
|
||||
line_number,
|
||||
self._get_port_pressure(instruction_form['port_pressure'], port_len, sep_list),
|
||||
self._get_port_pressure(
|
||||
instruction_form['port_pressure'], port_len, used_ports, sep_list
|
||||
),
|
||||
self._get_lcd_cp_ports(
|
||||
instruction_form['line_number'],
|
||||
cp_kernel if line_number in cp_lines else None,
|
||||
@@ -169,7 +171,7 @@ class Frontend(object):
|
||||
cp_sum = sum([x['latency_cp'] for x in cp_kernel])
|
||||
print(
|
||||
lineno_filler
|
||||
+ self._get_port_pressure(tp_sum, port_len, ' ')
|
||||
+ self._get_port_pressure(tp_sum, port_len, separator=' ')
|
||||
+ ' {:^6} {:^6}'.format(cp_sum, lcd_sum)
|
||||
)
|
||||
|
||||
@@ -179,9 +181,9 @@ class Frontend(object):
|
||||
|
||||
def _get_separator_list(self, separator, separator_2=' '):
|
||||
separator_list = []
|
||||
for i in range(len(self._data['ports']) - 1):
|
||||
match_1 = re.search(r'\d+', self._data['ports'][i])
|
||||
match_2 = re.search(r'\d+', self._data['ports'][i + 1])
|
||||
for i in range(len(self._machine_model.get_ports()) - 1):
|
||||
match_1 = re.search(r'\d+', self._machine_model.get_ports()[i])
|
||||
match_2 = re.search(r'\d+', self._machine_model.get_ports()[i + 1])
|
||||
if match_1 is not None and match_2 is not None and match_1.group() == match_2.group():
|
||||
separator_list.append(separator_2)
|
||||
else:
|
||||
@@ -198,12 +200,12 @@ class Frontend(object):
|
||||
string_result += ' ' if len(string_result) == 0 else ''
|
||||
return string_result
|
||||
|
||||
def _get_port_pressure(self, ports, port_len, separator='|'):
|
||||
def _get_port_pressure(self, ports, port_len, used_ports=[], separator='|'):
|
||||
if not isinstance(separator, list):
|
||||
separator = [separator for x in ports]
|
||||
string_result = '{} '.format(separator[-1])
|
||||
for i in range(len(ports)):
|
||||
if float(ports[i]) == 0.0:
|
||||
if float(ports[i]) == 0.0 and self._machine_model.get_ports()[i] not in used_ports:
|
||||
string_result += port_len[i] * ' ' + ' {} '.format(separator[i])
|
||||
continue
|
||||
left_len = len(str(float(ports[i])).split('.')[0])
|
||||
@@ -226,7 +228,7 @@ class Frontend(object):
|
||||
return '{} {:>4} {} {:>4} {}'.format(separator, lat_cp, separator, lat_lcd, separator)
|
||||
|
||||
def _get_max_port_len(self, kernel):
|
||||
port_len = [4 for x in self._data['ports']]
|
||||
port_len = [4 for x in self._machine_model.get_ports()]
|
||||
for instruction_form in kernel:
|
||||
for i, port in enumerate(instruction_form['port_pressure']):
|
||||
if len('{:.2f}'.format(port)) > port_len[i]:
|
||||
@@ -238,7 +240,7 @@ class Frontend(object):
|
||||
separator_list = self._get_separator_list(separator, '-')
|
||||
for i, length in enumerate(port_len):
|
||||
substr = '{:^' + str(length + 2) + 's}'
|
||||
string_result += substr.format(self._data['ports'][i]) + separator_list[i]
|
||||
string_result += substr.format(self._machine_model.get_ports()[i]) + separator_list[i]
|
||||
return string_result
|
||||
|
||||
def _print_header_report(self):
|
||||
|
||||
@@ -59,6 +59,12 @@ def create_parser():
|
||||
type=str,
|
||||
help='Define architecture (SNB, IVB, HSW, BDW, SKX, CSX, ZEN1, TX2).',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--fixed',
|
||||
action='store_true',
|
||||
help='Run the throughput analysis with fixed probabilities for all suitable ports per '
|
||||
'instruction. Otherwise, OSACA will print out the optimal port utilization for the kernel.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--db-check',
|
||||
dest='check_db',
|
||||
@@ -183,6 +189,9 @@ def inspect(args):
|
||||
machine_model = MachineModel(arch=arch)
|
||||
semantics = ArchSemantics(machine_model)
|
||||
semantics.add_semantics(kernel)
|
||||
# Do optimal schedule for kernel throughput if wished
|
||||
if not args.fixed:
|
||||
semantics.assign_optimal_throughput(kernel)
|
||||
|
||||
# Create DiGraphs
|
||||
kernel_graph = KernelDG(kernel, parser, machine_model)
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
|
||||
import warnings
|
||||
from functools import reduce
|
||||
from operator import itemgetter
|
||||
|
||||
from .isa_semantics import INSTR_FLAGS, ISASemantics
|
||||
from .hw_model import MachineModel
|
||||
from .isa_semantics import INSTR_FLAGS, ISASemantics
|
||||
|
||||
|
||||
class ArchSemantics(ISASemantics):
|
||||
@@ -21,6 +22,54 @@ class ArchSemantics(ISASemantics):
|
||||
if self._machine_model.has_hidden_loads():
|
||||
self.set_hidden_loads(kernel)
|
||||
|
||||
def assign_optimal_throughput(self, kernel):
    """
    Rebalance the port pressure of every instruction form towards an even port load.

    The kernel is walked back to front. For each uop of each instruction form,
    pressure is moved in steps of ``INC`` cycles from the eligible port with the
    highest kernel-wide throughput sum to the one with the lowest. The
    ``'port_pressure'`` lists of the instruction forms are modified in place.

    :param list kernel: instruction forms providing ``'port_uops'`` (entries indexed
        as ``(cycles, ports)``) and ``'port_pressure'`` (one value per machine port)
    """
    INC = 0.01
    port_list = self._machine_model.get_ports()
    kernel.reverse()
    try:
        for instruction_form in kernel:
            for uop in instruction_form['port_uops']:
                cycles = uop[0]
                ports = list(uop[1])
                indices = [port_list.index(p) for p in ports]
                # check if port sum of used ports for uop are unbalanced
                port_sums = self._to_list(
                    itemgetter(*indices)(self.get_throughput_sum(kernel))
                )
                instr_ports = self._to_list(
                    itemgetter(*indices)(instruction_form['port_pressure'])
                )
                if len(set(port_sums)) <= 1:
                    continue
                # balance ports; derive the step count from INC and convert it to an
                # int so that fractional (float) cycle counts do not break range()
                for _ in range(int(round(cycles / INC))):
                    if len(instr_ports) < 2:
                        # a single remaining port cannot be balanced any further
                        break
                    instr_ports[port_sums.index(max(port_sums))] -= INC
                    instr_ports[port_sums.index(min(port_sums))] += INC
                    self._itemsetter(*indices)(
                        instruction_form['port_pressure'], *instr_ports
                    )
                    # check if min port is zero
                    if round(min(instr_ports), 2) <= 0:
                        # if port_pressure is not exactly 0.00, add the residual to
                        # the former port
                        if min(instr_ports) != 0.0:
                            instr_ports[port_sums.index(min(port_sums))] += min(instr_ports)
                            self._itemsetter(*indices)(
                                instruction_form['port_pressure'], *instr_ports
                            )
                        zero_index = [
                            p
                            for p in indices
                            if round(instruction_form['port_pressure'][p], 2) == 0
                        ][0]
                        instruction_form['port_pressure'][zero_index] = 0.0
                        # Remove from further balancing
                        indices = [
                            p for p in indices if instruction_form['port_pressure'][p] > 0
                        ]
                        instr_ports = self._to_list(
                            itemgetter(*indices)(instruction_form['port_pressure'])
                        )
                    # the sums change with every step, so refresh them before the next one
                    port_sums = self._to_list(
                        itemgetter(*indices)(self.get_throughput_sum(kernel))
                    )
    finally:
        # always hand the kernel back in its original order, even on errors
        kernel.reverse()
|
||||
|
||||
def set_hidden_loads(self, kernel):
|
||||
loads = [instr for instr in kernel if INSTR_FLAGS.HAS_LD in instr['flags']]
|
||||
stores = [instr for instr in kernel if INSTR_FLAGS.HAS_ST in instr['flags']]
|
||||
@@ -70,6 +119,7 @@ class ArchSemantics(ISASemantics):
|
||||
latency = 0.0
|
||||
latency_wo_load = latency
|
||||
instruction_form['port_pressure'] = [0.0 for i in range(port_number)]
|
||||
instruction_form['port_uops'] = []
|
||||
else:
|
||||
instruction_data = self._machine_model.get_instruction(
|
||||
instruction_form['instruction'], instruction_form['operands']
|
||||
@@ -80,6 +130,7 @@ class ArchSemantics(ISASemantics):
|
||||
port_pressure = self._machine_model.average_port_pressure(
|
||||
instruction_data['port_pressure']
|
||||
)
|
||||
instruction_form['port_uops'] = instruction_data['port_pressure']
|
||||
try:
|
||||
assert isinstance(port_pressure, list)
|
||||
assert len(port_pressure) == port_number
|
||||
@@ -93,6 +144,7 @@ class ArchSemantics(ISASemantics):
|
||||
+ 'Please check entry for:\n {}'.format(instruction_form)
|
||||
)
|
||||
instruction_form['port_pressure'] = [0.0 for i in range(port_number)]
|
||||
instruction_form['port_uops'] = []
|
||||
flags.append(INSTR_FLAGS.TP_UNKWN)
|
||||
if throughput is None:
|
||||
# assume 0 cy and mark as unknown
|
||||
@@ -124,14 +176,15 @@ class ArchSemantics(ISASemantics):
|
||||
for op in operands['operand_list']
|
||||
if 'register' in op
|
||||
]
|
||||
load_port_uops = self._machine_model.get_load_throughput(
|
||||
[
|
||||
x['memory']
|
||||
for x in instruction_form['operands']['source']
|
||||
if 'memory' in x
|
||||
][0]
|
||||
)
|
||||
load_port_pressure = self._machine_model.average_port_pressure(
|
||||
self._machine_model.get_load_throughput(
|
||||
[
|
||||
x['memory']
|
||||
for x in instruction_form['operands']['source']
|
||||
if 'memory' in x
|
||||
][0]
|
||||
)
|
||||
load_port_uops
|
||||
)
|
||||
if 'load_throughput_multiplier' in self._machine_model:
|
||||
multiplier = self._machine_model['load_throughput_multiplier'][
|
||||
@@ -155,12 +208,17 @@ class ArchSemantics(ISASemantics):
|
||||
),
|
||||
)
|
||||
]
|
||||
instruction_form['port_uops'] = (
|
||||
instruction_data_reg['port_pressure'] + load_port_uops
|
||||
)
|
||||
|
||||
if assign_unknown:
|
||||
# --> mark as unknown and assume 0 cy for latency/throughput
|
||||
throughput = 0.0
|
||||
latency = 0.0
|
||||
latency_wo_load = latency
|
||||
instruction_form['port_pressure'] = [0.0 for i in range(port_number)]
|
||||
instruction_form['port_uops'] = []
|
||||
flags += [INSTR_FLAGS.TP_UNKWN, INSTR_FLAGS.LT_UNKWN]
|
||||
# flatten flag list
|
||||
flags = list(set(flags))
|
||||
@@ -221,6 +279,27 @@ class ArchSemantics(ISASemantics):
|
||||
port_pressure[index] = 0.0
|
||||
return port_pressure
|
||||
|
||||
def _itemsetter(self, *items):
|
||||
if len(items) == 1:
|
||||
item = items[0]
|
||||
|
||||
def g(obj, value):
|
||||
obj[item] = value
|
||||
|
||||
else:
|
||||
|
||||
def g(obj, *values):
|
||||
for item, value in zip(items, values):
|
||||
obj[item] = value
|
||||
|
||||
return g
|
||||
|
||||
def _to_list(self, obj):
|
||||
if isinstance(obj, tuple):
|
||||
return list(obj)
|
||||
else:
|
||||
return [obj]
|
||||
|
||||
@staticmethod
|
||||
def get_throughput_sum(kernel):
|
||||
tp_sum = reduce(
|
||||
|
||||
Reference in New Issue
Block a user