Some more stuff.

This commit is contained in:
pleroy
2025-01-04 18:28:57 +01:00
parent aeda9b1d33
commit 56fbe1d172

View File

@@ -8,6 +8,7 @@ from multiprocessing import Manager, Process, cpu_count
import networkx as nx import networkx as nx
from osaca.semantics import INSTR_FLAGS, ArchSemantics, MachineModel from osaca.semantics import INSTR_FLAGS, ArchSemantics, MachineModel
from osaca.parser.instruction_form import InstructionForm
from osaca.parser.memory import MemoryOperand from osaca.parser.memory import MemoryOperand
from osaca.parser.register import RegisterOperand from osaca.parser.register import RegisterOperand
from osaca.parser.immediate import ImmediateOperand from osaca.parser.immediate import ImmediateOperand
@@ -86,7 +87,10 @@ class KernelDG(nx.DiGraph):
load_line_number = KernelDG.get_load_line_number(instruction_form.line_number) load_line_number = KernelDG.get_load_line_number(instruction_form.line_number)
loads[instruction_form.line_number] = load_line_number loads[instruction_form.line_number] = load_line_number
dg.add_node(load_line_number) dg.add_node(load_line_number)
dg.nodes[load_line_number]["instruction_form"] = instruction_form dg.nodes[load_line_number]["instruction_form"] = InstructionForm(
line=instruction_form.line,
line_number=load_line_number
)
# and set LD latency as edge weight # and set LD latency as edge weight
dg.add_edge( dg.add_edge(
load_line_number, load_line_number,
@@ -94,16 +98,11 @@ class KernelDG(nx.DiGraph):
latency=instruction_form.latency - instruction_form.latency_wo_load, latency=instruction_form.latency - instruction_form.latency_wo_load,
) )
#TODO comments #TODO comments
#print("LOADS", loads) print("LOADS", loads)
for i, instruction_form in enumerate(kernel): for i, instruction_form in enumerate(kernel):
for dep, dep_flags in self.find_depending( for dep, dep_flags in self.find_depending(
instruction_form, kernel[i + 1 :], flag_dependencies instruction_form, kernel[i + 1 :], flag_dependencies
): ):
#if instruction_form.mnemonic == 'shl':
# print("IF", instruction_form)
# print("DEP", dep)
# print("DF", dep_flags)
# print(instruction_form.line_number,"\t",dep.line_number,"\n")
edge_weight = ( edge_weight = (
instruction_form.latency instruction_form.latency
if "mem_dep" in dep_flags or instruction_form.latency_wo_load is None if "mem_dep" in dep_flags or instruction_form.latency_wo_load is None
@@ -121,6 +120,7 @@ class KernelDG(nx.DiGraph):
latency=edge_weight, latency=edge_weight,
) )
else: else:
#print("DEP", instruction_form.line_number, dep.line_number)
dg.add_edge( dg.add_edge(
instruction_form.line_number, instruction_form.line_number,
dep.line_number, dep.line_number,
@@ -234,26 +234,17 @@ class KernelDG(nx.DiGraph):
for lat_sum, involved_lines in loopcarried_deps: for lat_sum, involved_lines in loopcarried_deps:
dict_key = "-".join([str(il[0]) for il in involved_lines]) dict_key = "-".join([str(il[0]) for il in involved_lines])
loopcarried_deps_dict[dict_key] = { loopcarried_deps_dict[dict_key] = {
"root": self._get_node_by_lineno(involved_lines[0][0]), "root": self._get_node_by_lineno(dg, involved_lines[0][0]),
"dependencies": [ "dependencies": [
(self._get_node_by_lineno(ln), lat) for ln, lat in involved_lines (self._get_node_by_lineno(dg, ln), lat) for ln, lat in involved_lines
], ],
"latency": lat_sum, "latency": lat_sum,
} }
return loopcarried_deps_dict return loopcarried_deps_dict
def _get_node_by_lineno(self, lineno, kernel=None, all=False): def _get_node_by_lineno(self, dg, lineno):
"""Return instruction form with line number ``lineno`` from kernel""" """Return instruction form with line number ``lineno`` from dg"""
#print(lineno) return dg.nodes[lineno]["instruction_form"]
if kernel is None:
kernel = self.kernel
result = [instr for instr in kernel
if instr.line_number == lineno
or KernelDG.get_load_line_number(instr.line_number) == lineno]
if not all:
return result[0]
else:
return result
def get_critical_path(self): def get_critical_path(self):
"""Find and return critical path after the creation of a directed graph.""" """Find and return critical path after the creation of a directed graph."""
@@ -262,21 +253,21 @@ class KernelDG(nx.DiGraph):
longest_path = nx.algorithms.dag.dag_longest_path(self.dg, weight="latency") longest_path = nx.algorithms.dag.dag_longest_path(self.dg, weight="latency")
# TODO verify that we can remove the next two lince due to earlier initialization # TODO verify that we can remove the next two lince due to earlier initialization
for line_number in longest_path: for line_number in longest_path:
self._get_node_by_lineno(int(line_number)).latency_cp = 0 self._get_node_by_lineno(self.dg, line_number).latency_cp = 0
# set cp latency to instruction # set cp latency to instruction
path_latency = 0.0 path_latency = 0.0
for s, d in nx.utils.pairwise(longest_path): for s, d in nx.utils.pairwise(longest_path):
node = self._get_node_by_lineno(int(s)) node = self._get_node_by_lineno(self.dg, s)
node.latency_cp = self.dg.edges[(s, d)]["latency"] node.latency_cp = self.dg.edges[(s, d)]["latency"]
path_latency += node.latency_cp path_latency += node.latency_cp
# add latency for last instruction # add latency for last instruction
node = self._get_node_by_lineno(int(longest_path[-1])) node = self._get_node_by_lineno(self.dg, longest_path[-1])
node.latency_cp = node.latency node.latency_cp = node.latency
if max_latency_instr.latency > path_latency: if max_latency_instr.latency > path_latency:
max_latency_instr.latency_cp = float(max_latency_instr.latency) max_latency_instr.latency_cp = float(max_latency_instr.latency)
return [max_latency_instr] return [max_latency_instr]
else: else:
return [x for x in self.kernel if x.line_number in longest_path] return [self._get_node_by_lineno(self.dg, x) for x in longest_path]
else: else:
# split to DAG # split to DAG
raise NotImplementedError("Kernel is cyclic.") raise NotImplementedError("Kernel is cyclic.")
@@ -571,6 +562,7 @@ class KernelDG(nx.DiGraph):
lcd_line_numbers = {} lcd_line_numbers = {}
for dep in lcd: for dep in lcd:
lcd_line_numbers[dep] = [x.line_number for x, lat in lcd[dep]["dependencies"]] lcd_line_numbers[dep] = [x.line_number for x, lat in lcd[dep]["dependencies"]]
print("LCDLN", lcd_line_numbers)
# create LCD edges # create LCD edges
for dep in lcd_line_numbers: for dep in lcd_line_numbers:
@@ -583,6 +575,7 @@ class KernelDG(nx.DiGraph):
# add label to edges # add label to edges
for e in graph.edges: for e in graph.edges:
print("EDGE", e)
graph.edges[e]["label"] = graph.edges[e]["latency"] graph.edges[e]["label"] = graph.edges[e]["latency"]
# add CP values to graph # add CP values to graph
@@ -596,20 +589,12 @@ class KernelDG(nx.DiGraph):
graph.nodes[n]["style"] = "bold" graph.nodes[n]["style"] = "bold"
graph.nodes[n]["penwidth"] = 4 graph.nodes[n]["penwidth"] = 4
print("CPLN", cp_line_numbers)
# Make critical path edges bold. # Make critical path edges bold.
for e in graph.edges: for u, v in zip(cp_line_numbers[:-1], cp_line_numbers[1:]):
if ( graph.edges[u, v]["style"] = "bold"
graph.nodes[e[0]]["instruction_form"].line_number in cp_line_numbers graph.edges[u, v]["penwidth"] = 3
and graph.nodes[e[1]]["instruction_form"].line_number in cp_line_numbers
and e[0] < e[1]
):
bold_edge = True
for i in range(e[0] + 1, e[1]):
if i in cp_line_numbers:
bold_edge = False
if bold_edge:
graph.edges[e]["style"] = "bold"
graph.edges[e]["penwidth"] = 3
# Color the cycles created by loop-carried dependencies, longest first, never recoloring # Color the cycles created by loop-carried dependencies, longest first, never recoloring
# any node or edge, so that the longest LCD and most long chains that are involved in the # any node or edge, so that the longest LCD and most long chains that are involved in the
@@ -639,6 +624,7 @@ class KernelDG(nx.DiGraph):
# Dont introduce a color just for an edge. # Dont introduce a color just for an edge.
if not color: if not color:
color = colors_used color = colors_used
print("EC", u, v, color)
edge_colors[u, v] = color edge_colors[u, v] = color
max_color = min(11, colors_used) max_color = min(11, colors_used)
colorscheme = f"spectral{max(3, max_color)}" colorscheme = f"spectral{max(3, max_color)}"