enhanced TP scheduling

This commit is contained in:
JanLJL
2020-07-06 18:49:46 +02:00
parent ce8c3ff9ab
commit 0e77b7bc9a
2 changed files with 22 additions and 5 deletions

View File

@@ -53,9 +53,15 @@ class ArchSemantics(ISASemantics):
)
if len(set(port_sums)) > 1:
# balance ports
for _ in range(cycles * 100):
instr_ports[port_sums.index(max(port_sums))] -= INC
instr_ports[port_sums.index(min(port_sums))] += INC
# init list for keeping track of the current change
differences = [cycles / len(ports) for p in ports]
for _ in range(int(cycles * (1 / INC))):
max_port_idx = port_sums.index(max(port_sums))
min_port_idx = port_sums.index(min(port_sums))
instr_ports[max_port_idx] -= INC
instr_ports[min_port_idx] += INC
differences[max_port_idx] -= INC
differences[min_port_idx] += INC
# instr_ports = [round(p, 2) for p in instr_ports]
self._itemsetter(*indices)(instruction_form['port_pressure'], *instr_ports)
# check if min port is zero
@@ -63,7 +69,11 @@ class ArchSemantics(ISASemantics):
# if port_pressure is not exactly 0.00, add the residual to
# the former port
if min(instr_ports) != 0.0:
instr_ports[port_sums.index(min(port_sums))] += min(instr_ports)
min_port_idx = port_sums.index(min(port_sums))
instr_ports[min_port_idx] += min(instr_ports)
differences[min_port_idx] += min(instr_ports)
# we don't need to decrease difference for other port, just delete it
del differences[instr_ports.index(min(instr_ports))]
self._itemsetter(*indices)(
instruction_form['port_pressure'], *instr_ports
)
@@ -80,6 +90,14 @@ class ArchSemantics(ISASemantics):
instr_ports = self._to_list(
itemgetter(*indices)(instruction_form['port_pressure'])
)
# never remove more than the fixed utilization per uop and port, i.e., cycles/len(ports)
if round(min(differences), 2) <= 0:
# don't worry if port_pressure isn't exactly 0 and just
# remove from further balancing
indices = [p for p in indices if differences[indices.index(p)] > 0]
instr_ports = self._to_list(
itemgetter(*indices)(instruction_form['port_pressure'])
)
port_sums = self._to_list(
itemgetter(*indices)(self.get_throughput_sum(kernel))
)