diff --git a/etgtools/pi_generator.py b/etgtools/pi_generator.py index 42a540f3..4845969c 100644 --- a/etgtools/pi_generator.py +++ b/etgtools/pi_generator.py @@ -77,6 +77,12 @@ header_pyi = """\ """ +typing_imports = """\ +from __future__ import annotations +from typing import Any + +""" + #--------------------------------------------------------------------------- def piIgnored(obj): @@ -112,18 +118,21 @@ class PiWrapperGenerator(generators.WrapperGeneratorBase, FixWxPrefix): if not SKIP_PI_FILE: _checkAndWriteHeader(destFile_pi, header_pi, module.docstring) + self.writeSection(destFile_pi, 'typing-imports', typing_imports, at_end=False) self.writeSection(destFile_pi, module.name, stream.getvalue()) if not SKIP_PYI_FILE: _checkAndWriteHeader(destFile_pyi, header_pyi, module.docstring) + self.writeSection(destFile_pyi, 'typing-imports', typing_imports, at_end=False) self.writeSection(destFile_pyi, module.name, stream.getvalue()) - def writeSection(self, destFile, sectionName, sectionText): + def writeSection(self, destFile, sectionName, sectionText, at_end = True): """ Read all the lines from destFile, remove those currently between begin/end markers for sectionName (if any), and write the lines back to the file with the new text in sectionText. + `at_end` determines where in the file the section is added when missing """ sectionBeginLine = -1 sectionEndLine = -1 @@ -139,10 +148,23 @@ class PiWrapperGenerator(generators.WrapperGeneratorBase, FixWxPrefix): sectionEndLine = idx if sectionBeginLine == -1: - # not there already, add to the end - lines.append(sectionBeginMarker + '\n') - lines.append(sectionText) - lines.append(sectionEndMarker + '\n') + if at_end: + # not there already, add to the end + lines.append(sectionBeginMarker + '\n') + lines.append(sectionText) + lines.append(sectionEndMarker + '\n') + else: + # not there already, add to the beginning + # Skip the header + idx = 0 + for idx, line in enumerate(lines): + if not line.startswith('#'): + break + lines[idx+1:idx+1] = [ + sectionBeginMarker + '\n', + sectionText, + sectionEndMarker + '\n', + ] else: # replace the existing lines lines[sectionBeginLine+1:sectionEndLine] = [sectionText]