Ensure needed imports from typing are included in type-stubs

Leverages the `writeSection` machinery, with a tweak to specify to add a
new section to the beginning of a file, after the header. This ensures
the required imports gets updated (and also only imported once per file)
if new imports are needed for type-hints. Hint: there's a few more to come.
This commit is contained in:
lojack5
2023-10-17 16:01:10 -06:00
parent d303548d43
commit a28de82bbb

View File

@@ -77,6 +77,12 @@ header_pyi = """\
"""
typing_imports = """\
from __future__ import annotations
from typing import Any
"""
#---------------------------------------------------------------------------
def piIgnored(obj):
@@ -112,18 +118,21 @@ class PiWrapperGenerator(generators.WrapperGeneratorBase, FixWxPrefix):
if not SKIP_PI_FILE:
_checkAndWriteHeader(destFile_pi, header_pi, module.docstring)
self.writeSection(destFile_pi, 'typing-imports', typing_imports, at_end=False)
self.writeSection(destFile_pi, module.name, stream.getvalue())
if not SKIP_PYI_FILE:
_checkAndWriteHeader(destFile_pyi, header_pyi, module.docstring)
self.writeSection(destFile_pyi, 'typing-imports', typing_imports, at_end=False)
self.writeSection(destFile_pyi, module.name, stream.getvalue())
def writeSection(self, destFile, sectionName, sectionText):
def writeSection(self, destFile, sectionName, sectionText, at_end = True):
"""
Read all the lines from destFile, remove those currently between
begin/end markers for sectionName (if any), and write the lines back
to the file with the new text in sectionText.
`at_end` determines where in the file the section is added when missing
"""
sectionBeginLine = -1
sectionEndLine = -1
@@ -139,10 +148,23 @@ class PiWrapperGenerator(generators.WrapperGeneratorBase, FixWxPrefix):
sectionEndLine = idx
if sectionBeginLine == -1:
# not there already, add to the end
lines.append(sectionBeginMarker + '\n')
lines.append(sectionText)
lines.append(sectionEndMarker + '\n')
if at_end:
# not there already, add to the end
lines.append(sectionBeginMarker + '\n')
lines.append(sectionText)
lines.append(sectionEndMarker + '\n')
else:
# not there already, add to the beginning
# Skip the header
idx = 0
for idx, line in enumerate(lines):
if not line.startswith('#'):
break
lines[idx+1:idx+1] = [
sectionBeginMarker + '\n',
sectionText,
sectionEndMarker + '\n',
]
else:
# replace the existing lines
lines[sectionBeginLine+1:sectionEndLine] = [sectionText]