mirror of
https://github.com/RRZE-HPC/OSACA.git
synced 2026-01-06 19:20:07 +01:00
added DiGraph creation and more tests
This commit is contained in:
@@ -2,14 +2,15 @@
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from .hw_model import MachineModel
|
||||
|
||||
|
||||
class KernelDG(nx.DiGraph):
|
||||
def __init__(self, parsed_kernel, parser, hw_model):
|
||||
def __init__(self, parsed_kernel, parser, hw_model: MachineModel):
|
||||
self.kernel = parsed_kernel
|
||||
self.parser = parser
|
||||
self.model = hw_model
|
||||
|
||||
# self.dag = self.create_DG()
|
||||
self.dg = self.create_DG()
|
||||
|
||||
def check_for_loop(self, kernel):
|
||||
raise NotImplementedError
|
||||
@@ -17,26 +18,32 @@ class KernelDG(nx.DiGraph):
|
||||
def create_DG(self):
|
||||
# 1. go through kernel instruction forms (as vertices)
|
||||
# 2. find edges (to dependend further instruction)
|
||||
# 3. get LT/TP value and set as edge weight
|
||||
dag = nx.DiGraph()
|
||||
for i, instruction in enumerate(self.kernel):
|
||||
throughput = self.model.get_throughput(instruction)
|
||||
latency = self.model.get_latency(instruction)
|
||||
for dep in self.find_depending(instruction, self.kernel[i + 1:]):
|
||||
dag.add_edge(
|
||||
instruction.line_number,
|
||||
dep.line_number,
|
||||
latency=latency,
|
||||
thorughput=throughput,
|
||||
# 3. get LT value and set as edge weight
|
||||
# 4. add instr forms as node attribute
|
||||
dg = nx.DiGraph()
|
||||
for i, instruction_form in enumerate(self.kernel):
|
||||
for dep in self.find_depending(instruction_form, self.kernel[i + 1:]):
|
||||
dg.add_edge(
|
||||
instruction_form['line_number'],
|
||||
dep['line_number'],
|
||||
latency=instruction_form['latency'],
|
||||
)
|
||||
dg.nodes[instruction_form['line_number']]['instruction_form'] = instruction_form
|
||||
dg.nodes[dep['line_number']]['instruction_form'] = dep
|
||||
return dg
|
||||
|
||||
def find_depending(self, instruction_form, kernel):
|
||||
for dst in instruction_form.operands.destination:
|
||||
if instruction_form.operands is None:
|
||||
return
|
||||
for dst in instruction_form.operands.destination + instruction_form.operands.src_dst:
|
||||
if 'register' in dst:
|
||||
# Check for read of register until overwrite
|
||||
for instr_form in kernel:
|
||||
if self.is_read(dst.register, instr_form):
|
||||
yield instr_form
|
||||
if self.is_written(dst.register, instr_form):
|
||||
# operand in src_dst list
|
||||
break
|
||||
elif self.is_written(dst.register, instr_form):
|
||||
break
|
||||
elif 'memory' in dst:
|
||||
@@ -46,12 +53,26 @@ class KernelDG(nx.DiGraph):
|
||||
for instr_form in kernel:
|
||||
if self.is_read(dst.memory.base, instr_form):
|
||||
yield instr_form
|
||||
if self.is_written(dst.memory.base, instr_form):
|
||||
# operand in src_dst list
|
||||
break
|
||||
elif self.is_written(dst.memory.base, instr_form):
|
||||
break
|
||||
|
||||
def get_dependent_instruction_forms(self, instr_form=None, line_number=None):
|
||||
"""
|
||||
Returns iterator
|
||||
"""
|
||||
if not instr_form and not line_number:
|
||||
raise ValueError('Either instruction form or line_number required.')
|
||||
line_number = line_number if line_number else instr_form['line_number']
|
||||
if self.dg.has_node(line_number):
|
||||
return self.dg.successors(line_number)
|
||||
return iter([])
|
||||
|
||||
def is_read(self, register, instruction_form):
|
||||
is_read = False
|
||||
for src in instruction_form.operands.source:
|
||||
for src in instruction_form.operands.source + instruction_form.operands.src_dst:
|
||||
if 'register' in src:
|
||||
is_read = self.parser.is_reg_dependend_of(register, src.register) or is_read
|
||||
if 'memory' in src:
|
||||
@@ -62,7 +83,7 @@ class KernelDG(nx.DiGraph):
|
||||
self.parser.is_reg_dependend_of(register, src.memory.index) or is_read
|
||||
)
|
||||
# Check also if read in destination memory address
|
||||
for dst in instruction_form.operands.destination:
|
||||
for dst in instruction_form.operands.destination + instruction_form.operands.src_dst:
|
||||
if 'memory' in dst:
|
||||
if dst.memory.base is not None:
|
||||
is_read = self.parser.is_reg_dependend_of(register, dst.memory.base) or is_read
|
||||
@@ -74,7 +95,7 @@ class KernelDG(nx.DiGraph):
|
||||
|
||||
def is_written(self, register, instruction_form):
|
||||
is_written = False
|
||||
for dst in instruction_form.operands.destination:
|
||||
for dst in instruction_form.operands.destination + instruction_form.operands.src_dst:
|
||||
if 'register' in dst:
|
||||
is_written = self.parser.is_reg_dependend_of(register, dst.register) or is_written
|
||||
if 'memory' in dst:
|
||||
@@ -83,7 +104,7 @@ class KernelDG(nx.DiGraph):
|
||||
self.parser.is_reg_dependend_of(register, dst.memory.base) or is_written
|
||||
)
|
||||
# Check also for possible pre- or post-indexing in memory addresses
|
||||
for src in instruction_form.operands.source:
|
||||
for src in instruction_form.operands.source + instruction_form.operands.src_dst:
|
||||
if 'memory' in src:
|
||||
if 'pre_indexed' in src.memory or 'post_indexed' in src.memory:
|
||||
is_written = (
|
||||
|
||||
Reference in New Issue
Block a user