support for overriding cache sets

This commit is contained in:
Andreas Abel
2019-12-30 22:34:38 +01:00
parent e3629ead9b
commit d70535f210
5 changed files with 117 additions and 45 deletions

View File

@@ -383,9 +383,16 @@ def getAddresses(level, wayID, cacheSetList, cBox=1, clearHL=True):
raise ValueError('invalid level') raise ValueError('invalid level')
# removes ?s and !s # removes ?s and !s, and returns the part before the first '_'
def getBlockName(blockStr): def getBlockName(blockStr):
return re.sub('[?!]', '', blockStr) return re.sub('[?!]', '', blockStr.split('_')[0])
# removes ?s and !s, and returns the part after the last '_' (as int); returns None if there is no '_'
def getBlockSet(blockStr):
if not '_' in blockStr:
return None
return int(re.match('\d+', blockStr.split('_')[-1]).group())
def parseCacheSetsStr(level, clearHL, cacheSetsStr): def parseCacheSetsStr(level, clearHL, cacheSetsStr):
@@ -399,7 +406,7 @@ def parseCacheSetsStr(level, clearHL, cacheSetsStr):
cacheSetList.append(int(s)) cacheSetList.append(int(s))
else: else:
nSets = getCacheInfo(level).nSets nSets = getCacheInfo(level).nSets
if level > 1 and clearHL: if level > 1 and clearHL and not (level == 3 and getCacheInfo(3).nSlices is not None):
nHLSets = getCacheInfo(level-1).nSets nHLSets = getCacheInfo(level-1).nSets
cacheSetList = range(nHLSets, nSets) cacheSetList = range(nHLSets, nSets)
else: else:
@@ -407,19 +414,34 @@ def parseCacheSetsStr(level, clearHL, cacheSetsStr):
return cacheSetList return cacheSetList
def findCacheSetForCode(cacheSetList, level):
nSets = getCacheInfo(level).nSets
sortedCacheSetList = sorted(cacheSetList)
sortedCacheSetList += [sortedCacheSetList[0] + nSets]
maxDist = 1
bestSet = 0
for i in range(len(sortedCacheSetList)-1):
dist = sortedCacheSetList[i+1] - sortedCacheSetList[i]
if dist > maxDist:
maxDist = dist
bestSet = (sortedCacheSetList[i] + 1) % nSets
return bestSet
def getAllUsedCacheSets(cacheSetList, seq, initSeq=''):
cacheSetOverrideList = [s for s in set(map(getBlockSet, initSeq.split()+seq.split())) if s is not None]
return sorted(set(cacheSetList + cacheSetOverrideList))
AddressList = namedtuple('AddressList', 'addresses exclude flush wbinvd') AddressList = namedtuple('AddressList', 'addresses exclude flush wbinvd')
# cacheSets=None means do access in all sets def getCodeForCacheExperiment(level, seq, initSeq, cacheSetList, cBox, clearHL, wbinvd):
# in this case, the first nL1Sets many sets of L2 will be reserved for clearing L1 allUsedSets = getAllUsedCacheSets(cacheSetList, seq, initSeq)
# if wbinvd is set, wbinvd will be called before initSeq
def runCacheExperiment(level, seq, initSeq='', cacheSets=None, cBox=1, clearHL=True, loop=1, wbinvd=False, nMeasurements=10, warmUpCount=1, agg='avg'):
lineSize = getCacheInfo(1).lineSize
cacheSetList = parseCacheSetsStr(level, clearHL, cacheSets)
clearHLAddrList = None clearHLAddrList = None
if (clearHL and level > 1): if (clearHL and level > 1):
clearHLAddrList = AddressList(getClearHLAddresses(level, cacheSetList, cBox), True, False, False) clearHLAddrList = AddressList(getClearHLAddresses(level, allUsedSets, cBox), True, False, False)
initAddressLists = [] initAddressLists = []
seqAddressLists = [] seqAddressLists = []
@@ -432,29 +454,49 @@ def runCacheExperiment(level, seq, initSeq='', cacheSets=None, cBox=1, clearHL=T
addrLists.append(AddressList([], True, False, True)) addrLists.append(AddressList([], True, False, True))
continue continue
overrideSet = getBlockSet(seqEl)
wayID = nameToID.setdefault(name, len(nameToID)) wayID = nameToID.setdefault(name, len(nameToID))
exclude = not '?' in seqEl exclude = not '?' in seqEl
flush = '!' in seqEl flush = '!' in seqEl
addresses = getAddresses(level, wayID, cacheSetList, cBox=cBox, clearHL=clearHL) s = [overrideSet] if overrideSet is not None else cacheSetList
addresses = getAddresses(level, wayID, s, cBox=cBox, clearHL=clearHL)
if clearHLAddrList is not None and not flush: if clearHLAddrList is not None and not flush:
addrLists.append(clearHLAddrList) addrLists.append(clearHLAddrList)
addrLists.append(AddressList(addresses, exclude, flush, False)) addrLists.append(AddressList(addresses, exclude, flush, False))
ec = getCodeForAddressLists(seqAddressLists, initAddressLists, wbinvd)
log.debug('\nInitAddresses: ' + str(initAddressLists)) log.debug('\nInitAddresses: ' + str(initAddressLists))
log.debug('\nSeqAddresses: ' + str(seqAddressLists)) log.debug('\nSeqAddresses: ' + str(seqAddressLists))
return getCodeForAddressLists(seqAddressLists, initAddressLists, wbinvd)
def runCacheExperimentCode(code, initCode, oneTimeInitCode, loop, warmUpCount, codeOffset, nMeasurements, agg):
resetNanoBench()
setNanoBenchParameters(config=getDefaultCacheConfig(), msrConfig=getDefaultCacheMSRConfig(), nMeasurements=nMeasurements, unrollCount=1, loopCount=loop,
warmUpCount=warmUpCount, aggregateFunction=agg, basicMode=True, noMem=True, codeOffset=codeOffset, verbose=None)
return runNanoBench(code=code, init=initCode, oneTimeInit=oneTimeInitCode)
# cacheSets=None means do access in all sets
# in this case, the first nL1Sets many sets of L2 will be reserved for clearing L1
# if wbinvd is set, wbinvd will be called before initSeq
def runCacheExperiment(level, seq, initSeq='', cacheSets=None, cBox=1, clearHL=True, loop=1, wbinvd=False, nMeasurements=10, warmUpCount=1, codeSet=None,
agg='avg'):
cacheSetList = parseCacheSetsStr(level, clearHL, cacheSets)
ec = getCodeForCacheExperiment(level, seq, initSeq=initSeq, cacheSetList=cacheSetList, cBox=cBox, clearHL=clearHL, wbinvd=wbinvd)
log.debug('\nOneTimeInit: ' + ec.oneTimeInit) log.debug('\nOneTimeInit: ' + ec.oneTimeInit)
log.debug('\nInit: ' + ec.init) log.debug('\nInit: ' + ec.init)
log.debug('\nCode: ' + ec.code) log.debug('\nCode: ' + ec.code)
resetNanoBench() lineSize = getCacheInfo(1).lineSize
setNanoBenchParameters(config=getDefaultCacheConfig(), msrConfig=getDefaultCacheMSRConfig(), nMeasurements=nMeasurements, unrollCount=1, loopCount=loop, allUsedSets = getAllUsedCacheSets(cacheSetList, seq, initSeq)
warmUpCount=warmUpCount, aggregateFunction=agg, basicMode=True, noMem=True, verbose=None) codeOffset = lineSize * (codeSet if codeSet is not None else findCacheSetForCode(allUsedSets, level))
return runNanoBench(code=ec.code, init=ec.init, oneTimeInit=ec.oneTimeInit) return runCacheExperimentCode(ec.code, ec.init, ec.oneTimeInit, loop, warmUpCount, codeOffset, nMeasurements, agg)
def printNB(nb_result): def printNB(nb_result):
@@ -603,7 +645,7 @@ def getAgesOfBlocks(blocks, level, seq, initSeq='', maxAge=None, cacheSets=None,
newBlocks = getUnusedBlockNames(nNewBlocks, seq+initSeq, 'N') newBlocks = getUnusedBlockNames(nNewBlocks, seq+initSeq, 'N')
curSeq += ' '.join(newBlocks) + ' ' + block + '?' curSeq += ' '.join(newBlocks) + ' ' + block + '?'
nb = runCacheExperiment(level, curSeq, initSeq=initSeq, cacheSets=cacheSets, cBox=cBox, clearHL=clearHL, loop=0, wbinvd=wbinvd, nMeasurements=nMeasurements) nb = runCacheExperiment(level, curSeq, initSeq=initSeq, cacheSets=cacheSets, cBox=cBox, clearHL=clearHL, loop=0, wbinvd=wbinvd, nMeasurements=nMeasurements, agg=agg)
if returnNbResults: nbResults[block].append(nb) if returnNbResults: nbResults[block].append(nb)
hitEvent = 'L' + str(level) + '_HIT' hitEvent = 'L' + str(level) + '_HIT'

View File

@@ -33,9 +33,8 @@ def main():
if args.sim: if args.sim:
policyClass = cacheSim.AllPolicies[args.sim] policyClass = cacheSim.AllPolicies[args.sim]
setCount = len(parseCacheSetsStr(args.level, (not args.noClearHL), args.sets))
seq = args.seq_init + (' ' + args.seq) * args.loop seq = args.seq_init + (' ' + args.seq) * args.loop
hits = cacheSim.getHits(seq, policyClass, args.simAssoc, setCount) / args.loop hits = cacheSim.getHits(seq, policyClass, args.simAssoc, args.sets) / args.loop
print 'Hits: ' + str(hits) print 'Hits: ' + str(hits)
else: else:
nb = runCacheExperiment(args.level, args.seq, initSeq=args.seq_init, cacheSets=args.sets, cBox=args.cBox, clearHL=(not args.noClearHL), loop=args.loop, nb = runCacheExperiment(args.level, args.seq, initSeq=args.seq_init, cacheSets=args.sets, cBox=args.cBox, clearHL=(not args.noClearHL), loop=args.loop,

View File

@@ -131,6 +131,12 @@ class LRU_PLRU4Sim(ReplPolicySim):
self.PLRUOrdered = [curPLRU] + [plru for plru in self.PLRUOrdered if plru!=curPLRU] self.PLRUOrdered = [curPLRU] + [plru for plru in self.PLRUOrdered if plru!=curPLRU]
return hit return hit
def flush(self, block):
for plru in self.PLRUs:
if block in plru.blocks:
plru.flush(block)
class QLRUSim(ReplPolicySim): class QLRUSim(ReplPolicySim):
def __init__(self, assoc, hitFunc, missFunc, replIdxFunc, updFunc, updOnMissOnly=False): def __init__(self, assoc, hitFunc, missFunc, replIdxFunc, updFunc, updOnMissOnly=False):
super(QLRUSim, self).__init__(assoc) super(QLRUSim, self).__init__(assoc)
@@ -298,18 +304,41 @@ AllRandPolicies = dict(AllRandQLRUVariants.items() + AllRandPLRUVariants.items()
AllPolicies = dict(AllDetPolicies.items() + AllRandPolicies.items()) AllPolicies = dict(AllDetPolicies.items() + AllRandPolicies.items())
def getHits(seq, policySimClass, assoc, nSets): def parseCacheSetsStrSim(cacheSetsStr):
hits = 0 if cacheSetsStr is None:
policySims = [policySimClass(assoc) for _ in range(0, nSets)] raise ValueError('no cache sets specified')
cacheSetList = []
for s in cacheSetsStr.split(','):
if '-' in s:
first, last = s.split('-')[:2]
cacheSetList += range(int(first), int(last)+1)
else:
cacheSetList.append(int(s))
return cacheSetList
def getHits(seq, policySimClass, assoc, cacheSetStr):
cacheSetList = parseCacheSetsStrSim(cacheSetStr)
allUsedSets = getAllUsedCacheSets(cacheSetList, seq)
policySims = {s: policySimClass(assoc) for s in allUsedSets}
hits = 0
for blockStr in seq.split(): for blockStr in seq.split():
blockName = getBlockName(blockStr) blockName = getBlockName(blockStr)
if '!' in blockStr: if blockName == '<wbinvd>':
for policySim in policySims: policySims = {s: policySimClass(assoc) for s in allUsedSets}
policySim.flush(blockName) continue
else:
for policySim in policySims: overrideSet = getBlockSet(blockStr)
hit = policySim.acc(blockName) sets = [overrideSet] if overrideSet is not None else cacheSetList
for s in sets:
if '!' in blockStr:
policySims[s].flush(blockName)
else:
hit = policySims[s].acc(blockName)
if '?' in blockStr: if '?' in blockStr:
hits += int(hit) hits += int(hit)
return hits return hits
@@ -320,7 +349,7 @@ def getAges(blocks, seq, policySimClass, assoc):
for block in blocks: for block in blocks:
for i in count(0): for i in count(0):
curSeq = seq + ' ' + ' '.join('N' + str(n) for n in range(0,i)) + ' ' + block + '?' curSeq = seq + ' ' + ' '.join('N' + str(n) for n in range(0,i)) + ' ' + block + '?'
if getHits(curSeq, policySimClass, assoc, 1) == 0: if getHits(curSeq, policySimClass, assoc, '0') == 0:
ages[block] = i ages[block] = i
break break
return ages return ages
@@ -332,7 +361,7 @@ def getGraph(blocks, seq, policySimClass, assoc, maxAge, nSets=1, nRep=1, agg="m
trace = [] trace = []
for i in range(0, maxAge): for i in range(0, maxAge):
curSeq = seq + ' ' + ' '.join('N' + str(n) for n in range(0,i)) + ' ' + block + '?' curSeq = seq + ' ' + ' '.join('N' + str(n) for n in range(0,i)) + ' ' + block + '?'
hits = [getHits(curSeq, policySimClass, assoc, nSets) for _ in range(0, nRep)] hits = [getHits(curSeq, policySimClass, assoc, '0-'+str(nSets-1)) for _ in range(0, nRep)]
if agg == "med": if agg == "med":
aggValue = median(hits) aggValue = median(hits)
elif agg == "min": elif agg == "min":

View File

@@ -27,9 +27,9 @@ def main():
logging.basicConfig(stream=sys.stdout, format='%(message)s', level=logging.getLevelName(args.logLevel)) logging.basicConfig(stream=sys.stdout, format='%(message)s', level=logging.getLevelName(args.logLevel))
if args.sim: if args.sim:
policyClass = cacheSim.Policies[args.sim] policyClass = cacheSim.AllPolicies[args.sim]
seq = re.sub('[?!]', '', ' '.join([args.seq_init, args.seq])).strip() + '?' seq = re.sub('[?!]', '', ' '.join([args.seq_init, args.seq])).strip() + '?'
hits = cacheSim.getHits(policyClass(args.simAssoc), seq) hits = cacheSim.getHits(seq, policyClass, args.simAssoc, args.sets)
if hits > 0: if hits > 0:
print 'HIT' print 'HIT'
exit(1) exit(1)

View File

@@ -1,6 +1,7 @@
#!/usr/bin/python #!/usr/bin/python
import argparse import argparse
import random import random
import sys
from numpy import median from numpy import median
@@ -17,13 +18,11 @@ def getActualHits(seq, level, cacheSets, cBox, nMeasurements=10):
def findSmallCounterexample(policy, initSeq, level, sets, cBox, assoc, seq, nMeasurements): def findSmallCounterexample(policy, initSeq, level, sets, cBox, assoc, seq, nMeasurements):
setCount = len(parseCacheSetsStr(level, True, sets))
seqSplit = seq.split() seqSplit = seq.split()
for seqPrefix in [seqSplit[:i] for i in range(assoc+1, len(seqSplit)+1)]: for seqPrefix in [seqSplit[:i] for i in range(assoc+1, len(seqSplit)+1)]:
seq = initSeq + ' '.join(seqPrefix) seq = initSeq + ' '.join(seqPrefix)
actual = getActualHits(seq, level, sets, cBox, nMeasurements) actual = getActualHits(seq, level, sets, cBox, nMeasurements)
sim = cacheSim.getHits(seq, cacheSim.AllPolicies[policy], assoc, setCount) sim = cacheSim.getHits(seq, cacheSim.AllPolicies[policy], assoc, sets)
print 'seq:' + seq + ', actual: ' + str(actual) + ', sim: ' + str(sim) print 'seq:' + seq + ', actual: ' + str(actual) + ', sim: ' + str(sim)
if sim != actual: if sim != actual:
break break
@@ -32,7 +31,7 @@ def findSmallCounterexample(policy, initSeq, level, sets, cBox, assoc, seq, nMea
tmpPrefix = seqPrefix[:i] + seqPrefix[(i+1):] tmpPrefix = seqPrefix[:i] + seqPrefix[(i+1):]
seq = initSeq + ' '.join(tmpPrefix) seq = initSeq + ' '.join(tmpPrefix)
actual = getActualHits(seq, level, sets, cBox, nMeasurements) actual = getActualHits(seq, level, sets, cBox, nMeasurements)
sim = cacheSim.getHits(seq, cacheSim.AllPolicies[policy], assoc, setCount) sim = cacheSim.getHits(seq, cacheSim.AllPolicies[policy], assoc, sets)
print 'seq:' + seq + ', actual: ' + str(actual) + ', sim: ' + str(sim) print 'seq:' + seq + ', actual: ' + str(actual) + ', sim: ' + str(sim)
if sim != actual: if sim != actual:
seqPrefix = tmpPrefix seqPrefix = tmpPrefix
@@ -62,6 +61,7 @@ def main():
parser.add_argument("-sets", help="Cache sets (if not specified, all cache sets are used)") parser.add_argument("-sets", help="Cache sets (if not specified, all cache sets are used)")
parser.add_argument("-cBox", help="cBox (default: 0)", type=int) parser.add_argument("-cBox", help="cBox (default: 0)", type=int)
parser.add_argument("-nMeasurements", help="Number of measurements", type=int, default=3) parser.add_argument("-nMeasurements", help="Number of measurements", type=int, default=3)
parser.add_argument("-rep", help="Number of repetitions of each experiment (Default: 1)", type=int, default=1)
parser.add_argument("-findCtrEx", help="Tries to find a small counterexample for each policy (only available for deterministic policies)", action='store_true') parser.add_argument("-findCtrEx", help="Tries to find a small counterexample for each policy (only available for deterministic policies)", action='store_true')
parser.add_argument("-policies", help="Comma-separated list of policies to consider (Default: all deterministic policies)") parser.add_argument("-policies", help="Comma-separated list of policies to consider (Default: all deterministic policies)")
parser.add_argument("-best", help="Find the best matching policy (Default: abort if no policy agrees with all results)", action='store_true') parser.add_argument("-best", help="Find the best matching policy (Default: abort if no policy agrees with all results)", action='store_true')
@@ -83,6 +83,8 @@ def main():
elif args.allQLRUVariants: elif args.allQLRUVariants:
policies = sorted(set(cacheSim.CommonPolicies.keys())|set(cacheSim.AllDetQLRUVariants.keys())) policies = sorted(set(cacheSim.CommonPolicies.keys())|set(cacheSim.AllDetQLRUVariants.keys()))
elif args.randPolicies: elif args.randPolicies:
if args.rep > 1:
sys.exit('rep > 1 not supported for random policies')
policies = sorted(cacheSim.AllRandPolicies.keys()) policies = sorted(cacheSim.AllRandPolicies.keys())
if args.assoc: if args.assoc:
@@ -94,8 +96,6 @@ def main():
if args.cBox: if args.cBox:
cBox = args.cBox cBox = args.cBox
setCount = len(parseCacheSetsStr(args.level, True, args.sets))
title = cpuid.cpu_name(cpuid.CPUID()) + ', Level: ' + str(args.level) + (', CBox: ' + str(cBox) if args.cBox else '') title = cpuid.cpu_name(cpuid.CPUID()) + ', Level: ' + str(args.level) + (', CBox: ' + str(cBox) if args.cBox else '')
html = ['<html>', '<head>', '<title>' + title + '</title>', '</head>', '<body>'] html = ['<html>', '<head>', '<title>' + title + '</title>', '</head>', '<body>']
@@ -117,25 +117,27 @@ def main():
print fullSeq print fullSeq
html += ['<tr><td>' + fullSeq + '</td>'] html += ['<tr><td>' + fullSeq + '</td>']
actual = getActualHits(fullSeq, args.level, args.sets, cBox, args.nMeasurements) actualHits = set([getActualHits(fullSeq, args.level, args.sets, cBox, args.nMeasurements) for _ in range(0, args.rep)])
html += ['<td>' + str(actual) + '</td>'] html += ['<td>' + ('{' if len(actualHits) > 1 else '') + ', '.join(map(str, sorted(actualHits))) + ('}' if len(actualHits) > 1 else '') + '</td>']
outp = '' outp = ''
for p in policies: for p in policies:
if not args.randPolicies: if not args.randPolicies:
sim = cacheSim.getHits(fullSeq, cacheSim.AllPolicies[p], assoc, setCount) sim = cacheSim.getHits(fullSeq, cacheSim.AllPolicies[p], assoc, args.sets)
if sim != actual: if sim not in actualHits:
possiblePolicies.discard(p) possiblePolicies.discard(p)
dists[p] += 1 dists[p] += 1
color = 'red' color = 'red'
if args.findCtrEx and not p in counterExamples: if args.findCtrEx and not p in counterExamples:
counterExamples[p] = findSmallCounterexample(p, ((args.initSeq + ' ') if args.initSeq else ''), args.level, args.sets, cBox, assoc, seq, counterExamples[p] = findSmallCounterexample(p, ((args.initSeq + ' ') if args.initSeq else ''), args.level, args.sets, cBox, assoc, seq,
args.nMeasurements) args.nMeasurements)
elif len(actualHits) > 1:
color = 'yellow'
else: else:
color = 'green' color = 'green'
else: else:
sim = median(sum(cacheSim.getHits(fullSeq, cacheSim.AllPolicies[p], assoc, setCount) for _ in range(0, args.nMeasurements))) sim = median(sum(cacheSim.getHits(fullSeq, cacheSim.AllPolicies[p], assoc, args.sets) for _ in range(0, args.nMeasurements)))
dist = (sim - actual) ** 2 dist = (sim - actual) ** 2
dists[p] += dist dists[p] += dist