#!/usr/bin/env python3

import copy
import time
from itertools import chain
from multiprocessing import Manager, Process, cpu_count

import networkx as nx

from osaca.semantics import INSTR_FLAGS, ArchSemantics, MachineModel
from osaca.parser.memory import MemoryOperand
from osaca.parser.register import RegisterOperand
from osaca.parser.immediate import ImmediateOperand
from osaca.parser.flag import FlagOperand


class KernelDG(nx.DiGraph):
    # threshold for deciding whether to check the dependency graph sequentially or in parallel
    INSTRUCTION_THRESHOLD = 50

    def __init__(
        self,
        parsed_kernel,
        parser,
        hw_model: MachineModel,
        semantics: ArchSemantics,
        timeout=10,
        flag_dependencies=False,
    ):
        self.timed_out = False
        self.kernel = parsed_kernel
        self.parser = parser
        self.model = hw_model
        self.arch_sem = semantics
        self.dg = self.create_DG(self.kernel, flag_dependencies)
        self.loopcarried_deps = self.check_for_loopcarried_dep(
            self.kernel, timeout, flag_dependencies
        )

    @classmethod
    def _extend_path(cls, dst_list, kernel, dg, offset):
        for instr in kernel:
            generator_path = nx.algorithms.simple_paths.all_simple_paths(
                dg, instr.line_number, instr.line_number + offset
            )
            tmp_list = list(generator_path)
            dst_list.extend(tmp_list)
        # print('Thread [{}-{}] done'.format(kernel[0]['line_number'], kernel[-1]['line_number']))

    def create_DG(self, kernel, flag_dependencies=False):
        """
        Create directed graph from given kernel

        :param kernel: Parsed asm kernel with assigned semantic information
        :type kernel: list
        :param flag_dependencies: indicating if dependencies of flags should be considered,
            defaults to `False`
        :type flag_dependencies: boolean, optional
        :returns: :class:`~nx.DiGraph` -- directed graph object
        """
        # 1. go through kernel instruction forms and add them as node attribute
        # 2. find edges (to depending later instructions)
        # 3. get LT value and set as edge weight
        dg = nx.DiGraph()
        for i, instruction_form in enumerate(kernel):
            dg.add_node(instruction_form.line_number)
            dg.nodes[instruction_form.line_number]["instruction_form"] = instruction_form
            # add load as separate node if existent
            if (
                INSTR_FLAGS.HAS_LD in instruction_form.flags
                and INSTR_FLAGS.LD not in instruction_form.flags
            ):
                # add new node
                dg.add_node(instruction_form.line_number + 0.1)
                dg.nodes[instruction_form.line_number + 0.1]["instruction_form"] = instruction_form
                # and set LD latency as edge weight
                dg.add_edge(
                    instruction_form.line_number + 0.1,
                    instruction_form.line_number,
                    latency=instruction_form.latency - instruction_form.latency_wo_load,
                )
            for dep, dep_flags in self.find_depending(
                instruction_form, kernel[i + 1 :], flag_dependencies
            ):
                # print(instruction_form.line_number,"\t",dep.line_number,"\n")
                edge_weight = (
                    instruction_form.latency
                    if "mem_dep" in dep_flags or instruction_form.latency_wo_load is None
                    else instruction_form.latency_wo_load
                )
                if "storeload_dep" in dep_flags and self.model is not None:
                    edge_weight += self.model.get("store_to_load_forward_latency", 0)
                if "p_indexed" in dep_flags and self.model is not None:
                    edge_weight = self.model.get("p_index_latency", 1)
                dg.add_edge(
                    instruction_form.line_number,
                    dep.line_number,
                    latency=edge_weight,
                )

                dg.nodes[dep.line_number]["instruction_form"] = dep
        return dg

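    # Illustrative note (editorial sketch, not part of the upstream implementation):
    # for a hypothetical two-instruction kernel with assumed latencies
    #
    #   1: vaddpd (%rax), %ymm0, %ymm0   # memory source -> HAS_LD, but not a pure load
    #   2: vmulpd %ymm0, %ymm1, %ymm1    # consumes %ymm0
    #
    # create_DG produces nodes keyed by source line number plus an extra LOAD node
    # at line_number + 0.1:
    #
    #   nodes: 1, 1.1, 2
    #   edges: 1.1 -> 1  latency = latency - latency_wo_load   (load part of line 1)
    #          1   -> 2  latency = latency_wo_load of line 1   (register dependency)
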
    def check_for_loopcarried_dep(self, kernel, timeout=10, flag_dependencies=False):
        """
        Try to find loop-carried dependencies in given kernel.

        :param kernel: Parsed asm kernel with assigned semantic information
        :type kernel: list
        :param timeout: Timeout in seconds for parallel execution, defaults
            to `10`. Set to `-1` for no timeout
        :type timeout: int
        :returns: `dict` -- dependency dictionary with all cyclic LCDs
        """
        # increase line number for second kernel loop
        offset = max(1000, max([i.line_number for i in kernel]))
        tmp_kernel = [] + kernel
        for orig_iform in kernel:
            temp_iform = copy.copy(orig_iform)
            temp_iform.line_number += offset
            tmp_kernel.append(temp_iform)
        # get dependency graph
        dg = self.create_DG(tmp_kernel, flag_dependencies)

        # build cyclic loop-carried dependencies
        loopcarried_deps = []
        all_paths = []

        klen = len(kernel)
        if klen >= self.INSTRUCTION_THRESHOLD:
            # parallel execution with static scheduling
            num_cores = cpu_count()
            workload = int((klen - 1) / num_cores) + 1
            starts = [tid * workload for tid in range(num_cores)]
            ends = [min((tid + 1) * workload, klen) for tid in range(num_cores)]
            instrs = [kernel[s:e] for s, e in zip(starts, ends)]
            with Manager() as manager:
                all_paths = manager.list()
                processes = [
                    Process(
                        target=KernelDG._extend_path,
                        args=(all_paths, instr_section, dg, offset),
                    )
                    for instr_section in instrs
                ]
                for p in processes:
                    p.start()
                if timeout == -1:
                    # no timeout
                    for p in processes:
                        p.join()
                else:
                    start_time = time.time()
                    while time.time() - start_time <= timeout:
                        if any(p.is_alive() for p in processes):
                            time.sleep(0.2)
                        else:
                            # all procs done
                            for p in processes:
                                p.join()
                            break
                    else:
                        self.timed_out = True
                        # terminate running processes
                        for p in processes:
                            if p.is_alive():
                                p.kill()
                            p.join()
                all_paths = list(all_paths)
        else:
            # sequential execution to avoid overhead when analyzing smaller kernels
            for instr in kernel:
                all_paths.extend(
                    nx.algorithms.simple_paths.all_simple_paths(
                        dg, instr.line_number, instr.line_number + offset
                    )
                )

        paths_set = set()
        for path in all_paths:
            lat_sum = 0.0
            # annotate path with edge-bound latencies (e.g., store-to-load latency)
            lat_path = []
            for s, d in nx.utils.pairwise(path):
                edge_lat = dg.edges[s, d]["latency"]
                # map source node back to original line numbers
                if s > offset:
                    s -= offset
                lat_path.append((s, edge_lat))
                lat_sum += edge_lat
                if d > offset:
                    d -= offset
            lat_path.sort()

            # Ignore duplicate paths which differ only in the root node
            if tuple(lat_path) in paths_set:
                continue
            paths_set.add(tuple(lat_path))

            loopcarried_deps.append((lat_sum, lat_path))
        loopcarried_deps.sort(reverse=True)

        # map lcd back to nodes
        loopcarried_deps_dict = {}
        for lat_sum, involved_lines in loopcarried_deps:
            dict_key = "-".join([str(il[0]) for il in involved_lines])
            loopcarried_deps_dict[dict_key] = {
                "root": self._get_node_by_lineno(involved_lines[0][0]),
                "dependencies": [
                    (self._get_node_by_lineno(ln), lat) for ln, lat in involved_lines
                ],
                "latency": lat_sum,
            }
        return loopcarried_deps_dict

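    # Illustrative note (editorial sketch, not part of the upstream implementation):
    # the loop-carried dependency search works on a "doubled" kernel. Every instruction
    # is copied with its line number shifted by `offset`, so a dependency chain that
    # wraps around the loop shows up as a simple path from node n to node n + offset,
    # e.g. 3 -> 5 -> 1003 for a kernel whose line 5 feeds line 3 of the next iteration.
    # Summing the edge latencies along such a path gives the per-iteration LCD latency.
    # A stand-alone toy example of this technique is sketched in the __main__ block at
    # the bottom of this file.
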
    def _get_node_by_lineno(self, lineno, kernel=None, all=False):
        """Return instruction form with line number ``lineno`` from kernel"""
        if kernel is None:
            kernel = self.kernel
        result = [instr for instr in kernel if instr.line_number == lineno]
        if not all:
            return result[0]
        else:
            return result

    def get_critical_path(self):
        """Find and return critical path after the creation of a directed graph."""
        max_latency_instr = max(self.kernel, key=lambda k: k.latency)
        if nx.algorithms.dag.is_directed_acyclic_graph(self.dg):
            longest_path = nx.algorithms.dag.dag_longest_path(self.dg, weight="latency")
            # TODO verify that we can remove the next two lines due to earlier initialization
            for line_number in longest_path:
                self._get_node_by_lineno(int(line_number)).latency_cp = 0
            # set cp latency to instruction
            path_latency = 0.0
            for s, d in nx.utils.pairwise(longest_path):
                node = self._get_node_by_lineno(int(s))
                node.latency_cp = self.dg.edges[(s, d)]["latency"]
                path_latency += node.latency_cp
            # add latency for last instruction
            node = self._get_node_by_lineno(int(longest_path[-1]))
            node.latency_cp = node.latency
            if max_latency_instr.latency > path_latency:
                max_latency_instr.latency_cp = float(max_latency_instr.latency)
                return [max_latency_instr]
            else:
                return [x for x in self.kernel if x.line_number in longest_path]
        else:
            # split to DAG
            raise NotImplementedError("Kernel is cyclic.")

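    # Illustrative note (editorial sketch, not part of the upstream implementation):
    # `dag_longest_path(..., weight="latency")` returns the line-number sequence of
    # the latency-weighted longest path. The `latency_cp` assigned above is each
    # node's outgoing edge latency along that path, plus the full latency of the last
    # instruction; if a single instruction's latency exceeds the summed path latency,
    # that instruction alone is reported as the critical path.
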
    def get_loopcarried_dependencies(self):
        """
        Return all LCDs from kernel (after :func:`~KernelDG.check_for_loopcarried_dep` was run)
        """
        if nx.algorithms.dag.is_directed_acyclic_graph(self.dg):
            return self.loopcarried_deps
        else:
            # split to DAG
            raise NotImplementedError("Kernel is cyclic.")

    def find_depending(self, instruction_form, instructions, flag_dependencies=False):
        """
        Find instructions in `instructions` depending on a given instruction form's results.

        :param dict instruction_form: instruction form to check for dependencies
        :param list instructions: instructions to check
        :param flag_dependencies: indicating if dependencies of flags should be considered,
            defaults to `False`
        :type flag_dependencies: boolean, optional
        :returns: iterator of all directly depending instruction forms and their dependency flags
        """
        if instruction_form.semantic_operands is None:
            return
        for dst in chain(
            instruction_form.semantic_operands["destination"],
            instruction_form.semantic_operands["src_dst"],
        ):
            # TODO instructions before must be considered as well, if they update registers
            # not used by instruction_form. E.g., validation/build/A64FX/gcc/O1/gs-2d-5pt.marked.s
            register_changes = self._update_reg_changes(instruction_form)
            # print("FROM", instruction_form.line, register_changes)
            for i, instr_form in enumerate(instructions):
                self._update_reg_changes(instr_form, register_changes)
                # print(" TO", instr_form.line, register_changes)
                if isinstance(dst, RegisterOperand):
                    # read of register
                    if self.is_read(dst, instr_form):
                        if (
                            dst.pre_indexed
                            or dst.post_indexed
                            or (isinstance(dst.post_indexed, dict))
                        ):
                            yield instr_form, ["p_indexed"]
                        else:
                            yield instr_form, []
                    # write to register -> abort
                    if self.is_written(dst, instr_form):
                        break
                if isinstance(dst, FlagOperand) and flag_dependencies:
                    # read of flag
                    if self.is_read(dst, instr_form):
                        yield instr_form, []
                    # write to flag -> abort
                    if self.is_written(dst, instr_form):
                        break
                if isinstance(dst, MemoryOperand):
                    # base register is altered during memory access
                    if dst.pre_indexed:
                        if self.is_written(dst.base, instr_form):
                            break
                    # if dst.memory.base:
                    #     if self.is_read(dst.memory.base, instr_form):
                    #         yield instr_form, []
                    # if dst.memory.index:
                    #     if self.is_read(dst.memory.index, instr_form):
                    #         yield instr_form, []
                    if dst.post_indexed:
                        # Check for read of base register until overwrite
                        if self.is_written(dst.base, instr_form):
                            break
                    # TODO record register changes
                    # (e.g., mov, lea, add, sub, inc, dec) in instructions[:i]
                    # and pass to is_memload and is_memstore to consider relevance.
                    # load from same location (presumed)
                    if self.is_memload(dst, instr_form, register_changes):
                        yield instr_form, ["storeload_dep"]
                    # store to same location (presumed)
                    if self.is_memstore(dst, instr_form, register_changes):
                        break
                self._update_reg_changes(instr_form, register_changes, only_postindexed=True)

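    # Illustrative note (editorial sketch, not part of the upstream implementation):
    # the flag lists yielded above are interpreted by create_DG when weighting edges:
    #   []                -> plain register/flag dependency, weighted with latency_wo_load
    #   ["p_indexed"]     -> dependency on a pre-/post-indexed address register, weighted
    #                        with the machine model's "p_index_latency" (default 1)
    #   ["storeload_dep"] -> presumed store-to-load forwarding; the machine model's
    #                        "store_to_load_forward_latency" (default 0) is added
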
    def _update_reg_changes(self, iform, reg_state=None, only_postindexed=False):
        if self.arch_sem is None:
            # This analysis requires semantics to be available
            return {}
        if reg_state is None:
            reg_state = {}
        for reg, change in self.arch_sem.get_reg_changes(iform, only_postindexed).items():
            if change is None or reg_state.get(reg, {}) is None:
                reg_state[reg] = None
            else:
                reg_state.setdefault(reg, {"name": reg, "value": 0})
                if change["name"] != reg:
                    # renaming occurred, overwrite value with up-to-now change of source register
                    reg_state[reg]["name"] = change["name"]
                    src_reg_state = reg_state.get(change["name"], {"value": 0})
                    if src_reg_state is None:
                        # original register's state was changed beyond reconstruction
                        reg_state[reg] = None
                        continue
                    reg_state[reg]["value"] = src_reg_state["value"]
                reg_state[reg]["value"] += change["value"]
        return reg_state

    def get_dependent_instruction_forms(self, instr_form=None, line_number=None):
        """
        Return an iterator over the line numbers of all instruction forms directly
        depending on the given instruction form or line number.
        """
        if not instr_form and not line_number:
            raise ValueError("Either instruction form or line_number required.")
        line_number = line_number if line_number else instr_form["line_number"]
        if self.dg.has_node(line_number):
            return self.dg.successors(line_number)
        return iter([])

    def is_read(self, register, instruction_form):
        """Check if instruction form reads from given register"""
        is_read = False
        if instruction_form.semantic_operands is None:
            return is_read
        for src in chain(
            instruction_form.semantic_operands["source"],
            instruction_form.semantic_operands["src_dst"],
        ):
            if isinstance(src, RegisterOperand):
                is_read = self.parser.is_reg_dependend_of(register, src) or is_read
            if isinstance(src, FlagOperand):
                is_read = self.parser.is_flag_dependend_of(register, src) or is_read
            if isinstance(src, MemoryOperand):
                if src.base is not None:
                    is_read = self.parser.is_reg_dependend_of(register, src.base) or is_read
                if src.index is not None and isinstance(src.index, RegisterOperand):
                    is_read = self.parser.is_reg_dependend_of(register, src.index) or is_read
        # Check also if read in destination memory address
        for dst in chain(
            instruction_form.semantic_operands["destination"],
            instruction_form.semantic_operands["src_dst"],
        ):
            if isinstance(dst, MemoryOperand):
                if dst.base is not None:
                    is_read = self.parser.is_reg_dependend_of(register, dst.base) or is_read
                if dst.index is not None:
                    is_read = self.parser.is_reg_dependend_of(register, dst.index) or is_read
        return is_read

    def is_memload(self, mem, instruction_form, register_changes={}):
        """Check if instruction form loads from given location, assuming register_changes"""
        if instruction_form.semantic_operands is None:
            return False
        for src in chain(
            instruction_form.semantic_operands["source"],
            instruction_form.semantic_operands["src_dst"],
        ):
            # Here we check for mem dependencies only
            if not isinstance(src, MemoryOperand):
                continue
            # src = src.memory

            # determine absolute address change
            addr_change = 0
            if isinstance(src.offset, ImmediateOperand) and src.offset.value is not None:
                addr_change += src.offset.value
            if isinstance(mem.offset, ImmediateOperand) and mem.offset.value is not None:
                addr_change -= mem.offset.value
            if mem.base and src.base:
                base_change = register_changes.get(
                    (src.base.prefix if src.base.prefix is not None else "") + src.base.name,
                    {
                        "name": (src.base.prefix if src.base.prefix is not None else "")
                        + src.base.name,
                        "value": 0,
                    },
                )
                if base_change is None:
                    # Unknown change occurred
                    continue
                # parenthesized so the (prefix + name) string is compared as a whole
                if (
                    (mem.base.prefix if mem.base.prefix is not None else "") + mem.base.name
                    != base_change["name"]
                ):
                    # base registers do not match
                    continue
                addr_change += base_change["value"]
            elif mem.base or src.base:
                # base registers do not match
                continue
            if mem.index and src.index:
                index_change = register_changes.get(
                    (src.index.prefix if src.index.prefix is not None else "") + src.index.name,
                    {
                        "name": (src.index.prefix if src.index.prefix is not None else "")
                        + src.index.name,
                        "value": 0,
                    },
                )
                if index_change is None:
                    # Unknown change occurred
                    continue
                if mem.scale != src.scale:
                    # scale factors do not match
                    continue
                if (
                    (mem.index.prefix if mem.index.prefix is not None else "") + mem.index.name
                    != index_change["name"]
                ):
                    # index registers do not match
                    continue
                addr_change += index_change["value"] * src.scale
            elif mem.index or src.index:
                # index registers do not match
                continue
            # if instruction_form.line_number == 3:
            if addr_change == 0:
                return True
        return False

    def is_written(self, register, instruction_form):
        """Check if instruction form writes in given register"""
        is_written = False
        if instruction_form.semantic_operands is None:
            return is_written
        for dst in chain(
            instruction_form.semantic_operands["destination"],
            instruction_form.semantic_operands["src_dst"],
        ):
            if isinstance(dst, RegisterOperand):
                is_written = self.parser.is_reg_dependend_of(register, dst) or is_written
            if isinstance(dst, FlagOperand):
                is_written = self.parser.is_flag_dependend_of(register, dst) or is_written
            if isinstance(dst, MemoryOperand):
                if dst.pre_indexed or dst.post_indexed:
                    is_written = self.parser.is_reg_dependend_of(register, dst.base) or is_written
        # Check also for possible pre- or post-indexing in memory addresses
        for src in chain(
            instruction_form.semantic_operands["source"],
            instruction_form.semantic_operands["src_dst"],
        ):
            if isinstance(src, MemoryOperand):
                if src.pre_indexed or src.post_indexed:
                    is_written = self.parser.is_reg_dependend_of(register, src.base) or is_written
        return is_written

    def is_memstore(self, mem, instruction_form, register_changes={}):
        """Check if instruction form stores to given location, assuming unchanged registers"""
        is_store = False
        if instruction_form.semantic_operands is None:
            return is_store
        for dst in chain(
            instruction_form.semantic_operands["destination"],
            instruction_form.semantic_operands["src_dst"],
        ):
            if isinstance(dst, MemoryOperand):
                is_store = mem == dst or is_store
        return is_store

    def export_graph(self, filepath=None):
        """
        Export graph with highlighted CP and LCDs as DOT file. Writes it to 'osaca_dg.dot'
        if no other path is given.

        :param filepath: path to write DOT file, defaults to None.
        :type filepath: str, optional
        """
        graph = copy.deepcopy(self.dg)
        cp = self.get_critical_path()
        cp_line_numbers = [x.line_number for x in cp]
        lcd = self.get_loopcarried_dependencies()
        lcd_line_numbers = {}
        for dep in lcd:
            lcd_line_numbers[dep] = [x.line_number for x, lat in lcd[dep]["dependencies"]]
        # add color scheme
        graph.graph["node"] = {"colorscheme": "spectral9"}
        graph.graph["edge"] = {"colorscheme": "spectral9"}
        min_color = 2
        available_colors = 8

        # create LCD edges
        for dep in lcd_line_numbers:
            min_line_number = min(lcd_line_numbers[dep])
            max_line_number = max(lcd_line_numbers[dep])
            graph.add_edge(min_line_number, max_line_number, dir="back")
            graph.edges[min_line_number, max_line_number]["latency"] = [
                lat for x, lat in lcd[dep]["dependencies"] if x.line_number == max_line_number
            ]

        # add label to edges
        for e in graph.edges:
            graph.edges[e]["label"] = graph.edges[e]["latency"]

        # add CP values to graph
        for n in cp:
            graph.nodes[n.line_number]["instruction_form"].latency_cp = n.latency_cp

        # Make the critical path bold.
        for n in graph.nodes:
            if n in cp_line_numbers:
                # graph.nodes[n]['color'] = 1
                graph.nodes[n]["style"] = "bold"
                graph.nodes[n]["penwidth"] = 4

        # Make critical path edges bold.
        for e in graph.edges:
            if (
                graph.nodes[e[0]]["instruction_form"].line_number in cp_line_numbers
                and graph.nodes[e[1]]["instruction_form"].line_number in cp_line_numbers
                and e[0] < e[1]
            ):
                bold_edge = True
                for i in range(e[0] + 1, e[1]):
                    if i in cp_line_numbers:
                        bold_edge = False
                if bold_edge:
                    graph.edges[e]["style"] = "bold"
                    graph.edges[e]["penwidth"] = 3

        # Color the cycles created by loop-carried dependencies, longest first, never recoloring
        # any node, so that the longest LCD and most long chains that are involved in the loop are
        # legible.
        for i, dep in enumerate(sorted(lcd, key=lambda dep: -lcd[dep]["latency"])):
            # For cycles that are broken by already-colored (longer) cycles, the color need not be
            # the same for each yet-uncolored arc.
            # Do not use the same color for such an arc as for the cycles that delimit it. This is
            # always possible with 3 colors, as each arc is only adjacent to the preceding and
            # following interrupting cycles.
            # Since we color edges as well as nodes, there would be room for a more interesting
            # graph coloring problem: we could avoid having unrelated arcs with the same color
            # meeting at the same vertex, and retain the same color between arcs of the same cycle
            # that are interrupted by a single vertex. We mostly ignore this problem.

            # The longest cycle will always have color 1, the second longest cycle will always have
            # color 2 except where it overlaps with the longest cycle, etc.; for arcs that are
            # part of short cycles, the colors will be less predictable.
            default_color = min_color + i % available_colors
            arc = []
            arc_source = lcd_line_numbers[dep][-1]
            arcs = []
            for n in lcd_line_numbers[dep]:
                if "fillcolor" in graph.nodes[n]:
                    arcs.append((arc, (arc_source, n)))
                    arc = []
                    arc_source = n
                else:
                    arc.append(n)
            if not arcs:  # Unconstrained cycle.
                arcs.append((arc, tuple()))
            else:
                arcs.append((arc, (arc_source, lcd_line_numbers[dep][0])))
            # Try to color the whole cycle with its default color, then with a single color, then
            # with different colors by arc, preferring the default.
            forbidden_colors = set(
                graph.nodes[n]["fillcolor"]
                for arc, extremities in arcs
                for n in extremities
                if "fillcolor" in graph.nodes[n]
            )
            global_color = None
            if default_color not in forbidden_colors:
                global_color = default_color
            elif len(forbidden_colors) < available_colors:
                global_color = next(
                    c
                    for c in range(min_color, min_color + available_colors + 1)
                    if c not in forbidden_colors
                )
            for arc, extremities in arcs:
                if global_color:
                    color = global_color
                else:
                    color = default_color
                    while color in (graph.nodes[n].get("fillcolor") for n in extremities):
                        color = min_color + (color + 1) % available_colors
                for n in arc:
                    if "style" not in graph.nodes[n]:
                        graph.nodes[n]["style"] = "filled"
                    else:
                        graph.nodes[n]["style"] += ",filled"
                    graph.nodes[n]["fillcolor"] = color
                if extremities:
                    (source, sink) = extremities
                else:
                    source = sink = arc[0]
                    arc = arc[1:]
                for u, v in zip([source] + arc, arc + [sink]):
                    # The backward edge of the cycle is represented as the corresponding forward
                    # edge with the attribute dir=back.
                    edge = graph.edges[v, u] if (v, u) in graph.edges else graph.edges[u, v]
                    if arc:
                        if "color" in edge:
                            raise AssertionError(
                                f"Recoloring {u}->{v} in arc ({source}) {arc} ({sink}) of {dep}"
                            )
                    edge["color"] = color

        # rename node from [idx] to [idx mnemonic] and add shape
        mapping = {}
        for n in graph.nodes:
            if int(n) != n:
                mapping[n] = "{}: LOAD".format(int(n))
                graph.nodes[n]["fontname"] = "italic"
                graph.nodes[n]["fontsize"] = 11.0
            else:
                node = graph.nodes[n]["instruction_form"]
                if node.mnemonic is not None:
                    mapping[n] = "{}: {}".format(n, node.mnemonic)
                else:
                    label = "label" if node.label is not None else None
                    label = "directive" if node.directive is not None else label
                    label = "comment" if node.comment is not None and label is None else label
                    mapping[n] = "{}: {}".format(n, label)
                    graph.nodes[n]["fontname"] = "italic"
                    graph.nodes[n]["fontsize"] = 11.0
            graph.nodes[n]["shape"] = "rectangle"

        nx.relabel.relabel_nodes(graph, mapping, copy=False)
        if filepath:
            nx.drawing.nx_agraph.write_dot(graph, filepath)
        else:
            nx.drawing.nx_agraph.write_dot(graph, "osaca_dg.dot")
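

if __name__ == "__main__":
    # Stand-alone toy demo (editorial sketch, not part of the upstream OSACA code):
    # it replays the graph technique behind KernelDG.check_for_loopcarried_dep and
    # KernelDG.get_critical_path on a hand-built networkx graph, without a parsed
    # kernel. Node ids stand for source line numbers, edge weights for latencies;
    # all latency values below are made up for illustration.
    offset = 1000
    toy = nx.DiGraph()
    # intra-iteration dependencies of a three-instruction "kernel", added once for the
    # original line numbers and once for the copy shifted by `offset`
    for u, v, lat in [(1, 2, 4.0), (2, 3, 3.0)]:
        toy.add_edge(u, v, latency=lat)
        toy.add_edge(u + offset, v + offset, latency=lat)
    # cross-iteration dependency: line 3 feeds line 1 of the next iteration
    toy.add_edge(3, 1 + offset, latency=5.0)

    # a loop-carried dependency is a simple path from a node to its shifted copy
    for start in (1, 2, 3):
        for path in nx.algorithms.simple_paths.all_simple_paths(toy, start, start + offset):
            lat_sum = sum(toy.edges[s, d]["latency"] for s, d in nx.utils.pairwise(path))
            lines = [n - offset if n > offset else n for n in path]
            # the three start nodes report the same 12-cycle LCD, only rooted differently,
            # which is why check_for_loopcarried_dep deduplicates paths by their sorted
            # (line, latency) tuples
            print("LCD through lines", lines, "with latency", lat_sum)

    # the critical path of a single iteration is the latency-weighted longest path
    one_iteration = toy.subgraph([1, 2, 3])
    print("critical path:", nx.algorithms.dag.dag_longest_path(one_iteration, weight="latency"))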