support for overriding cache sets

This commit is contained in:
Andreas Abel
2019-12-30 22:34:38 +01:00
parent e3629ead9b
commit d70535f210
5 changed files with 117 additions and 45 deletions

View File

@@ -383,9 +383,16 @@ def getAddresses(level, wayID, cacheSetList, cBox=1, clearHL=True):
raise ValueError('invalid level')
# removes ?s and !s
# removes ?s and !s, and returns the part before the first '_'
def getBlockName(blockStr):
return re.sub('[?!]', '', blockStr)
return re.sub('[?!]', '', blockStr.split('_')[0])
# removes ?s and !s, and returns the part after the last '_' (as int); returns None if there is no '_'
def getBlockSet(blockStr):
if not '_' in blockStr:
return None
return int(re.match('\d+', blockStr.split('_')[-1]).group())
def parseCacheSetsStr(level, clearHL, cacheSetsStr):
@@ -399,7 +406,7 @@ def parseCacheSetsStr(level, clearHL, cacheSetsStr):
cacheSetList.append(int(s))
else:
nSets = getCacheInfo(level).nSets
if level > 1 and clearHL:
if level > 1 and clearHL and not (level == 3 and getCacheInfo(3).nSlices is not None):
nHLSets = getCacheInfo(level-1).nSets
cacheSetList = range(nHLSets, nSets)
else:
@@ -407,19 +414,34 @@ def parseCacheSetsStr(level, clearHL, cacheSetsStr):
return cacheSetList
def findCacheSetForCode(cacheSetList, level):
nSets = getCacheInfo(level).nSets
sortedCacheSetList = sorted(cacheSetList)
sortedCacheSetList += [sortedCacheSetList[0] + nSets]
maxDist = 1
bestSet = 0
for i in range(len(sortedCacheSetList)-1):
dist = sortedCacheSetList[i+1] - sortedCacheSetList[i]
if dist > maxDist:
maxDist = dist
bestSet = (sortedCacheSetList[i] + 1) % nSets
return bestSet
def getAllUsedCacheSets(cacheSetList, seq, initSeq=''):
cacheSetOverrideList = [s for s in set(map(getBlockSet, initSeq.split()+seq.split())) if s is not None]
return sorted(set(cacheSetList + cacheSetOverrideList))
AddressList = namedtuple('AddressList', 'addresses exclude flush wbinvd')
# cacheSets=None means do access in all sets
# in this case, the first nL1Sets many sets of L2 will be reserved for clearing L1
# if wbinvd is set, wbinvd will be called before initSeq
def runCacheExperiment(level, seq, initSeq='', cacheSets=None, cBox=1, clearHL=True, loop=1, wbinvd=False, nMeasurements=10, warmUpCount=1, agg='avg'):
lineSize = getCacheInfo(1).lineSize
cacheSetList = parseCacheSetsStr(level, clearHL, cacheSets)
def getCodeForCacheExperiment(level, seq, initSeq, cacheSetList, cBox, clearHL, wbinvd):
allUsedSets = getAllUsedCacheSets(cacheSetList, seq, initSeq)
clearHLAddrList = None
if (clearHL and level > 1):
clearHLAddrList = AddressList(getClearHLAddresses(level, cacheSetList, cBox), True, False, False)
clearHLAddrList = AddressList(getClearHLAddresses(level, allUsedSets, cBox), True, False, False)
initAddressLists = []
seqAddressLists = []
@@ -432,29 +454,49 @@ def runCacheExperiment(level, seq, initSeq='', cacheSets=None, cBox=1, clearHL=T
addrLists.append(AddressList([], True, False, True))
continue
overrideSet = getBlockSet(seqEl)
wayID = nameToID.setdefault(name, len(nameToID))
exclude = not '?' in seqEl
flush = '!' in seqEl
addresses = getAddresses(level, wayID, cacheSetList, cBox=cBox, clearHL=clearHL)
s = [overrideSet] if overrideSet is not None else cacheSetList
addresses = getAddresses(level, wayID, s, cBox=cBox, clearHL=clearHL)
if clearHLAddrList is not None and not flush:
addrLists.append(clearHLAddrList)
addrLists.append(AddressList(addresses, exclude, flush, False))
ec = getCodeForAddressLists(seqAddressLists, initAddressLists, wbinvd)
log.debug('\nInitAddresses: ' + str(initAddressLists))
log.debug('\nSeqAddresses: ' + str(seqAddressLists))
return getCodeForAddressLists(seqAddressLists, initAddressLists, wbinvd)
def runCacheExperimentCode(code, initCode, oneTimeInitCode, loop, warmUpCount, codeOffset, nMeasurements, agg):
resetNanoBench()
setNanoBenchParameters(config=getDefaultCacheConfig(), msrConfig=getDefaultCacheMSRConfig(), nMeasurements=nMeasurements, unrollCount=1, loopCount=loop,
warmUpCount=warmUpCount, aggregateFunction=agg, basicMode=True, noMem=True, codeOffset=codeOffset, verbose=None)
return runNanoBench(code=code, init=initCode, oneTimeInit=oneTimeInitCode)
# cacheSets=None means do access in all sets
# in this case, the first nL1Sets many sets of L2 will be reserved for clearing L1
# if wbinvd is set, wbinvd will be called before initSeq
def runCacheExperiment(level, seq, initSeq='', cacheSets=None, cBox=1, clearHL=True, loop=1, wbinvd=False, nMeasurements=10, warmUpCount=1, codeSet=None,
agg='avg'):
cacheSetList = parseCacheSetsStr(level, clearHL, cacheSets)
ec = getCodeForCacheExperiment(level, seq, initSeq=initSeq, cacheSetList=cacheSetList, cBox=cBox, clearHL=clearHL, wbinvd=wbinvd)
log.debug('\nOneTimeInit: ' + ec.oneTimeInit)
log.debug('\nInit: ' + ec.init)
log.debug('\nCode: ' + ec.code)
resetNanoBench()
setNanoBenchParameters(config=getDefaultCacheConfig(), msrConfig=getDefaultCacheMSRConfig(), nMeasurements=nMeasurements, unrollCount=1, loopCount=loop,
warmUpCount=warmUpCount, aggregateFunction=agg, basicMode=True, noMem=True, verbose=None)
lineSize = getCacheInfo(1).lineSize
allUsedSets = getAllUsedCacheSets(cacheSetList, seq, initSeq)
codeOffset = lineSize * (codeSet if codeSet is not None else findCacheSetForCode(allUsedSets, level))
return runNanoBench(code=ec.code, init=ec.init, oneTimeInit=ec.oneTimeInit)
return runCacheExperimentCode(ec.code, ec.init, ec.oneTimeInit, loop, warmUpCount, codeOffset, nMeasurements, agg)
def printNB(nb_result):
@@ -603,7 +645,7 @@ def getAgesOfBlocks(blocks, level, seq, initSeq='', maxAge=None, cacheSets=None,
newBlocks = getUnusedBlockNames(nNewBlocks, seq+initSeq, 'N')
curSeq += ' '.join(newBlocks) + ' ' + block + '?'
nb = runCacheExperiment(level, curSeq, initSeq=initSeq, cacheSets=cacheSets, cBox=cBox, clearHL=clearHL, loop=0, wbinvd=wbinvd, nMeasurements=nMeasurements)
nb = runCacheExperiment(level, curSeq, initSeq=initSeq, cacheSets=cacheSets, cBox=cBox, clearHL=clearHL, loop=0, wbinvd=wbinvd, nMeasurements=nMeasurements, agg=agg)
if returnNbResults: nbResults[block].append(nb)
hitEvent = 'L' + str(level) + '_HIT'

View File

@@ -33,9 +33,8 @@ def main():
if args.sim:
policyClass = cacheSim.AllPolicies[args.sim]
setCount = len(parseCacheSetsStr(args.level, (not args.noClearHL), args.sets))
seq = args.seq_init + (' ' + args.seq) * args.loop
hits = cacheSim.getHits(seq, policyClass, args.simAssoc, setCount) / args.loop
hits = cacheSim.getHits(seq, policyClass, args.simAssoc, args.sets) / args.loop
print 'Hits: ' + str(hits)
else:
nb = runCacheExperiment(args.level, args.seq, initSeq=args.seq_init, cacheSets=args.sets, cBox=args.cBox, clearHL=(not args.noClearHL), loop=args.loop,

View File

@@ -131,6 +131,12 @@ class LRU_PLRU4Sim(ReplPolicySim):
self.PLRUOrdered = [curPLRU] + [plru for plru in self.PLRUOrdered if plru!=curPLRU]
return hit
def flush(self, block):
for plru in self.PLRUs:
if block in plru.blocks:
plru.flush(block)
class QLRUSim(ReplPolicySim):
def __init__(self, assoc, hitFunc, missFunc, replIdxFunc, updFunc, updOnMissOnly=False):
super(QLRUSim, self).__init__(assoc)
@@ -298,18 +304,41 @@ AllRandPolicies = dict(AllRandQLRUVariants.items() + AllRandPLRUVariants.items()
AllPolicies = dict(AllDetPolicies.items() + AllRandPolicies.items())
def getHits(seq, policySimClass, assoc, nSets):
hits = 0
policySims = [policySimClass(assoc) for _ in range(0, nSets)]
def parseCacheSetsStrSim(cacheSetsStr):
if cacheSetsStr is None:
raise ValueError('no cache sets specified')
cacheSetList = []
for s in cacheSetsStr.split(','):
if '-' in s:
first, last = s.split('-')[:2]
cacheSetList += range(int(first), int(last)+1)
else:
cacheSetList.append(int(s))
return cacheSetList
def getHits(seq, policySimClass, assoc, cacheSetStr):
cacheSetList = parseCacheSetsStrSim(cacheSetStr)
allUsedSets = getAllUsedCacheSets(cacheSetList, seq)
policySims = {s: policySimClass(assoc) for s in allUsedSets}
hits = 0
for blockStr in seq.split():
blockName = getBlockName(blockStr)
if '!' in blockStr:
for policySim in policySims:
policySim.flush(blockName)
else:
for policySim in policySims:
hit = policySim.acc(blockName)
if blockName == '<wbinvd>':
policySims = {s: policySimClass(assoc) for s in allUsedSets}
continue
overrideSet = getBlockSet(blockStr)
sets = [overrideSet] if overrideSet is not None else cacheSetList
for s in sets:
if '!' in blockStr:
policySims[s].flush(blockName)
else:
hit = policySims[s].acc(blockName)
if '?' in blockStr:
hits += int(hit)
return hits
@@ -320,7 +349,7 @@ def getAges(blocks, seq, policySimClass, assoc):
for block in blocks:
for i in count(0):
curSeq = seq + ' ' + ' '.join('N' + str(n) for n in range(0,i)) + ' ' + block + '?'
if getHits(curSeq, policySimClass, assoc, 1) == 0:
if getHits(curSeq, policySimClass, assoc, '0') == 0:
ages[block] = i
break
return ages
@@ -332,7 +361,7 @@ def getGraph(blocks, seq, policySimClass, assoc, maxAge, nSets=1, nRep=1, agg="m
trace = []
for i in range(0, maxAge):
curSeq = seq + ' ' + ' '.join('N' + str(n) for n in range(0,i)) + ' ' + block + '?'
hits = [getHits(curSeq, policySimClass, assoc, nSets) for _ in range(0, nRep)]
hits = [getHits(curSeq, policySimClass, assoc, '0-'+str(nSets-1)) for _ in range(0, nRep)]
if agg == "med":
aggValue = median(hits)
elif agg == "min":

View File

@@ -27,9 +27,9 @@ def main():
logging.basicConfig(stream=sys.stdout, format='%(message)s', level=logging.getLevelName(args.logLevel))
if args.sim:
policyClass = cacheSim.Policies[args.sim]
policyClass = cacheSim.AllPolicies[args.sim]
seq = re.sub('[?!]', '', ' '.join([args.seq_init, args.seq])).strip() + '?'
hits = cacheSim.getHits(policyClass(args.simAssoc), seq)
hits = cacheSim.getHits(seq, policyClass, args.simAssoc, args.sets)
if hits > 0:
print 'HIT'
exit(1)

View File

@@ -1,6 +1,7 @@
#!/usr/bin/python
import argparse
import random
import sys
from numpy import median
@@ -17,13 +18,11 @@ def getActualHits(seq, level, cacheSets, cBox, nMeasurements=10):
def findSmallCounterexample(policy, initSeq, level, sets, cBox, assoc, seq, nMeasurements):
setCount = len(parseCacheSetsStr(level, True, sets))
seqSplit = seq.split()
for seqPrefix in [seqSplit[:i] for i in range(assoc+1, len(seqSplit)+1)]:
seq = initSeq + ' '.join(seqPrefix)
actual = getActualHits(seq, level, sets, cBox, nMeasurements)
sim = cacheSim.getHits(seq, cacheSim.AllPolicies[policy], assoc, setCount)
sim = cacheSim.getHits(seq, cacheSim.AllPolicies[policy], assoc, sets)
print 'seq:' + seq + ', actual: ' + str(actual) + ', sim: ' + str(sim)
if sim != actual:
break
@@ -32,7 +31,7 @@ def findSmallCounterexample(policy, initSeq, level, sets, cBox, assoc, seq, nMea
tmpPrefix = seqPrefix[:i] + seqPrefix[(i+1):]
seq = initSeq + ' '.join(tmpPrefix)
actual = getActualHits(seq, level, sets, cBox, nMeasurements)
sim = cacheSim.getHits(seq, cacheSim.AllPolicies[policy], assoc, setCount)
sim = cacheSim.getHits(seq, cacheSim.AllPolicies[policy], assoc, sets)
print 'seq:' + seq + ', actual: ' + str(actual) + ', sim: ' + str(sim)
if sim != actual:
seqPrefix = tmpPrefix
@@ -62,6 +61,7 @@ def main():
parser.add_argument("-sets", help="Cache sets (if not specified, all cache sets are used)")
parser.add_argument("-cBox", help="cBox (default: 0)", type=int)
parser.add_argument("-nMeasurements", help="Number of measurements", type=int, default=3)
parser.add_argument("-rep", help="Number of repetitions of each experiment (Default: 1)", type=int, default=1)
parser.add_argument("-findCtrEx", help="Tries to find a small counterexample for each policy (only available for deterministic policies)", action='store_true')
parser.add_argument("-policies", help="Comma-separated list of policies to consider (Default: all deterministic policies)")
parser.add_argument("-best", help="Find the best matching policy (Default: abort if no policy agrees with all results)", action='store_true')
@@ -83,6 +83,8 @@ def main():
elif args.allQLRUVariants:
policies = sorted(set(cacheSim.CommonPolicies.keys())|set(cacheSim.AllDetQLRUVariants.keys()))
elif args.randPolicies:
if args.rep > 1:
sys.exit('rep > 1 not supported for random policies')
policies = sorted(cacheSim.AllRandPolicies.keys())
if args.assoc:
@@ -94,8 +96,6 @@ def main():
if args.cBox:
cBox = args.cBox
setCount = len(parseCacheSetsStr(args.level, True, args.sets))
title = cpuid.cpu_name(cpuid.CPUID()) + ', Level: ' + str(args.level) + (', CBox: ' + str(cBox) if args.cBox else '')
html = ['<html>', '<head>', '<title>' + title + '</title>', '</head>', '<body>']
@@ -117,25 +117,27 @@ def main():
print fullSeq
html += ['<tr><td>' + fullSeq + '</td>']
actual = getActualHits(fullSeq, args.level, args.sets, cBox, args.nMeasurements)
html += ['<td>' + str(actual) + '</td>']
actualHits = set([getActualHits(fullSeq, args.level, args.sets, cBox, args.nMeasurements) for _ in range(0, args.rep)])
html += ['<td>' + ('{' if len(actualHits) > 1 else '') + ', '.join(map(str, sorted(actualHits))) + ('}' if len(actualHits) > 1 else '') + '</td>']
outp = ''
for p in policies:
if not args.randPolicies:
sim = cacheSim.getHits(fullSeq, cacheSim.AllPolicies[p], assoc, setCount)
sim = cacheSim.getHits(fullSeq, cacheSim.AllPolicies[p], assoc, args.sets)
if sim != actual:
if sim not in actualHits:
possiblePolicies.discard(p)
dists[p] += 1
color = 'red'
if args.findCtrEx and not p in counterExamples:
counterExamples[p] = findSmallCounterexample(p, ((args.initSeq + ' ') if args.initSeq else ''), args.level, args.sets, cBox, assoc, seq,
args.nMeasurements)
elif len(actualHits) > 1:
color = 'yellow'
else:
color = 'green'
else:
sim = median(sum(cacheSim.getHits(fullSeq, cacheSim.AllPolicies[p], assoc, setCount) for _ in range(0, args.nMeasurements)))
sim = median(sum(cacheSim.getHits(fullSeq, cacheSim.AllPolicies[p], assoc, args.sets) for _ in range(0, args.nMeasurements)))
dist = (sim - actual) ** 2
dists[p] += dist