Merge imports

This commit is contained in:
pleroy
2025-01-04 15:55:33 +01:00
parent 33fd0a0352
commit aeda9b1d33

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import copy
from enum import Enum
import time
from itertools import chain, groupby
from multiprocessing import Manager, Process, cpu_count
@@ -17,6 +18,11 @@ class KernelDG(nx.DiGraph):
# threshold for checking dependency graph sequential or in parallel
INSTRUCTION_THRESHOLD = 50
class ReadKind(Enum):
    """Classification of how an instruction form reads a given register."""

    # The instruction form does not read the register at all.
    NOT_A_READ = 0
    # The register is read and the read was flagged as feeding a load,
    # i.e. the instruction form has a memory source operand.
    READ_FOR_LOAD = 1
    # The register is read, but the read was not flagged as feeding a load.
    OTHER_READ = 2
def __init__(
self,
parsed_kernel,
@@ -46,6 +52,11 @@ class KernelDG(nx.DiGraph):
dst_list.extend(tmp_list)
# print('Thread [{}-{}] done'.format(kernel[0]['line_number'], kernel[-1]['line_number']))
@staticmethod
def get_load_line_number(line_number):
    """Return the graph node key used for the separate load node of an instruction.

    :param line_number: line number of the instruction form owning the load
    :returns: ``line_number`` shifted by a fixed fractional offset
    """
    # The size of the offset is irrelevant, but it must be exactly
    # representable in binary floating point (0.125 == 2**-3). The key is
    # recomputed and compared with ``==`` elsewhere, so an offset such as 0.1,
    # which is not exactly representable, would be exposed to rounding issues.
    return line_number + 0.125
def create_DG(self, kernel, flag_dependencies=False):
"""
Create directed graph from given kernel
@@ -61,26 +72,37 @@ class KernelDG(nx.DiGraph):
# 2. find edges (to further, dependent instructions)
# 3. get LT value and set as edge weight
dg = nx.DiGraph()
loads = {}
for i, instruction_form in enumerate(kernel):
dg.add_node(instruction_form.line_number)
dg.nodes[instruction_form.line_number]["instruction_form"] = instruction_form
# add load as separate node if existent
load_line_number = None
if (
INSTR_FLAGS.HAS_LD in instruction_form.flags
and INSTR_FLAGS.LD not in instruction_form.flags
):
# add new node
dg.add_node(instruction_form.line_number + 0.1)
dg.nodes[instruction_form.line_number + 0.1]["instruction_form"] = instruction_form
load_line_number = KernelDG.get_load_line_number(instruction_form.line_number)
loads[instruction_form.line_number] = load_line_number
dg.add_node(load_line_number)
dg.nodes[load_line_number]["instruction_form"] = instruction_form
# and set LD latency as edge weight
dg.add_edge(
instruction_form.line_number + 0.1,
load_line_number,
instruction_form.line_number,
latency=instruction_form.latency - instruction_form.latency_wo_load,
)
#TODO comments
#print("LOADS", loads)
for i, instruction_form in enumerate(kernel):
for dep, dep_flags in self.find_depending(
instruction_form, kernel[i + 1 :], flag_dependencies
):
#if instruction_form.mnemonic == 'shl':
# print("IF", instruction_form)
# print("DEP", dep)
# print("DF", dep_flags)
# print(instruction_form.line_number,"\t",dep.line_number,"\n")
edge_weight = (
instruction_form.latency
@@ -91,11 +113,19 @@ class KernelDG(nx.DiGraph):
edge_weight += self.model.get("store_to_load_forward_latency", 0)
if "p_indexed" in dep_flags and self.model is not None:
edge_weight = self.model.get("p_index_latency", 1)
dg.add_edge(
instruction_form.line_number,
dep.line_number,
latency=edge_weight,
)
if "for_load" in dep_flags and self.model is not None and dep.line_number in loads:
#print("LOADDEP", instruction_form.line_number, loads[dep.line_number])
dg.add_edge(
instruction_form.line_number,
loads[dep.line_number],
latency=edge_weight,
)
else:
dg.add_edge(
instruction_form.line_number,
dep.line_number,
latency=edge_weight,
)
dg.nodes[dep.line_number]["instruction_form"] = dep
return dg
@@ -214,9 +244,12 @@ class KernelDG(nx.DiGraph):
def _get_node_by_lineno(self, lineno, kernel=None, all=False):
"""Return instruction form with line number ``lineno`` from kernel"""
#print(lineno)
if kernel is None:
kernel = self.kernel
result = [instr for instr in kernel if instr.line_number == lineno]
result = [instr for instr in kernel
if instr.line_number == lineno
or KernelDG.get_load_line_number(instr.line_number) == lineno]
if not all:
return result[0]
else:
@@ -284,15 +317,18 @@ class KernelDG(nx.DiGraph):
# print(" TO", instr_form.line, register_changes)
if isinstance(dst, RegisterOperand):
# read of register
if self.is_read(dst, instr_form):
read_kind = self._read_kind(dst, instr_form)
if read_kind != KernelDG.ReadKind.NOT_A_READ:
dep_flags = []
if (
dst.pre_indexed
or dst.post_indexed
or (isinstance(dst.post_indexed, dict))
):
yield instr_form, ["p_indexed"]
else:
yield instr_form, []
dep_flags = ["p_indexed"]
if read_kind == KernelDG.ReadKind.READ_FOR_LOAD:
dep_flags += ["for_load"]
yield instr_form, dep_flags
# write to register -> abort
if self.is_written(dst, instr_form):
break
@@ -363,11 +399,12 @@ class KernelDG(nx.DiGraph):
return self.dg.successors(line_number)
return iter([])
def is_read(self, register, instruction_form):
"""Check if instruction form reads from given register"""
def _read_kind(self, register, instruction_form):
"""Check if instruction form reads from given register. Returns a ReadKind."""
is_read = False
for_load = False
if instruction_form.semantic_operands is None:
return is_read
return KernelDG.ReadKind.NOT_A_READ
for src in chain(
instruction_form.semantic_operands["source"],
instruction_form.semantic_operands["src_dst"],
@@ -377,10 +414,15 @@ class KernelDG(nx.DiGraph):
if isinstance(src, FlagOperand):
is_read = self.parser.is_flag_dependend_of(register, src) or is_read
if isinstance(src, MemoryOperand):
#print("1", is_read)
if src.base is not None:
is_read = self.parser.is_reg_dependend_of(register, src.base) or is_read
#print("2", is_read)
if src.index is not None and isinstance(src.index, RegisterOperand):
is_read = self.parser.is_reg_dependend_of(register, src.index) or is_read
#print("3", is_read)
#print("FORLOAD", register, src)
for_load = True
# Check also if read in destination memory address
for dst in chain(
instruction_form.semantic_operands["destination"],
@@ -391,7 +433,16 @@ class KernelDG(nx.DiGraph):
is_read = self.parser.is_reg_dependend_of(register, dst.base) or is_read
if dst.index is not None:
is_read = self.parser.is_reg_dependend_of(register, dst.index) or is_read
return is_read
if is_read:
if for_load:
return KernelDG.ReadKind.READ_FOR_LOAD
else:
return KernelDG.ReadKind.OTHER_READ
else:
return KernelDG.ReadKind.NOT_A_READ
def is_read(self, register, instruction_form):
    """Check if instruction form reads from given register.

    Thin boolean wrapper around ``_read_kind``: any classification other
    than ``ReadKind.NOT_A_READ`` counts as a read.
    """
    return self._read_kind(register, instruction_form) != KernelDG.ReadKind.NOT_A_READ
def is_memload(self, mem, instruction_form, register_changes={}):
"""Check if instruction form loads from given location, assuming register_changes"""