"""Unicode-aware replacements for the standard streams on the Windows console.

The raw stream classes talk to the console through ``ReadConsoleW`` and
``WriteConsoleW`` and therefore exchange UTF-16-LE encoded bytes.  They are
wrapped in ``io.BufferedReader``/``io.BufferedWriter`` and
``io.TextIOWrapper`` objects, and optionally in ``TextTranscodingWrapper``
objects that advertise a different ``encoding``.  ``enable()`` installs the
resulting objects as ``sys.stdin``/``sys.stdout``/``sys.stderr`` and
``disable()`` restores the interpreter's original streams.
"""
# Recovered from a git blobdiff of win_unicode_console/streams.py.

import io
import sys
import time
from ctypes import WinDLL, byref, c_ulong, get_last_error, set_last_error, windll

from win_unicode_console.buffer import get_buffer


# ``use_last_error=True`` makes ctypes snapshot the thread's last-error value
# immediately after every call made through this library object.  Reading the
# snapshot with ``get_last_error()`` is reliable, whereas a separate foreign
# call to ``GetLastError`` may observe a value clobbered by the interpreter.
kernel32 = WinDLL("kernel32", use_last_error=True)
GetStdHandle = kernel32.GetStdHandle
ReadConsoleW = kernel32.ReadConsoleW
WriteConsoleW = kernel32.WriteConsoleW
GetLastError = kernel32.GetLastError  # kept for backward compatibility; unused below


ERROR_SUCCESS = 0
ERROR_NOT_ENOUGH_MEMORY = 8
ERROR_OPERATION_ABORTED = 995

STDIN_HANDLE = GetStdHandle(-10)
STDOUT_HANDLE = GetStdHandle(-11)
STDERR_HANDLE = GetStdHandle(-12)

STDIN_FILENO = 0
STDOUT_FILENO = 1
STDERR_FILENO = 2

EOF = b"\x1a"

MAX_BYTES_WRITTEN = 32767  # arbitrary because WriteConsoleW ability to write big buffers depends on heap usage


class ReprMixin:
    """Provide a ``repr`` that shows the ``name`` and ``encoding`` attributes."""

    def __repr__(self):
        modname = self.__class__.__module__
        clsname = self.__class__.__qualname__
        attributes = []
        for name in ["name", "encoding"]:
            try:
                value = getattr(self, name)
            except AttributeError:
                # Not every stream class defines both attributes; skip the missing ones.
                pass
            else:
                attributes.append("{}={}".format(name, repr(value)))

        return "<{}.{} {}>".format(modname, clsname, " ".join(attributes))


class WindowsConsoleRawIOBase(ReprMixin, io.RawIOBase):
    """Common base of the raw console streams.

    Stores the stream name, the console handle passed to the Win32 console
    functions, and the file descriptor number reported by ``fileno()``.
    """

    def __init__(self, name, handle, fileno):
        self.name = name
        self.handle = handle
        self.file_no = fileno

    def fileno(self):
        """Return the file descriptor number given to the constructor."""
        return self.file_no

    def isatty(self):
        """Return ``True``; the stream is always reported as a terminal."""
        super().isatty()  # for close check in default implementation
        return True


class WindowsConsoleRawReader(WindowsConsoleRawIOBase):
    """Raw readable stream that fills buffers using ``ReadConsoleW``."""

    def readable(self):
        return True

    def readinto(self, b):
        """Read UTF-16-LE code units from the console into ``b``.

        Returns the number of bytes stored in ``b``, or 0 when ``b`` is empty
        or when the first element of the buffer equals ``EOF`` (``b"\\x1a"``).

        Raises:
            ValueError: if ``len(b)`` is odd (code units are two bytes each).
            OSError: if ``ReadConsoleW`` reports failure.
        """
        bytes_to_be_read = len(b)
        if not bytes_to_be_read:
            return 0
        elif bytes_to_be_read % 2:
            raise ValueError("cannot read odd number of bytes from UTF-16-LE encoded console")

        # presumably a ctypes view sharing memory with ``b`` -- see win_unicode_console.buffer
        buffer = get_buffer(b, writable=True)
        code_units_to_be_read = bytes_to_be_read // 2
        code_units_read = c_ulong()

        # Reset the error code first so that a value left behind by an
        # unrelated earlier call cannot be mistaken for the result of this one.
        set_last_error(ERROR_SUCCESS)
        retval = ReadConsoleW(self.handle, buffer, code_units_to_be_read, byref(code_units_read), None)
        last_error = get_last_error()
        if last_error == ERROR_OPERATION_ABORTED:
            time.sleep(0.1)  # wait for KeyboardInterrupt
        if not retval:
            raise OSError("Windows error {}".format(last_error))

        if buffer[0] == EOF:
            return 0
        else:
            return 2 * code_units_read.value


class WindowsConsoleRawWriter(WindowsConsoleRawIOBase):
    """Raw writable stream that emits buffers using ``WriteConsoleW``."""

    def writable(self):
        return True

    @staticmethod
    def _error_message(errno):
        """Return the ``OSError`` message for Windows error code ``errno``."""
        if errno == ERROR_SUCCESS:
            return "Windows error {} (ERROR_SUCCESS); zero bytes written on nonzero input, probably just one byte given".format(errno)
        elif errno == ERROR_NOT_ENOUGH_MEMORY:
            return "Windows error {} (ERROR_NOT_ENOUGH_MEMORY); try to lower `win_unicode_console.streams.MAX_BYTES_WRITTEN`".format(errno)
        else:
            return "Windows error {}".format(errno)

    def write(self, b):
        """Write UTF-16-LE encoded bytes ``b`` to the console.

        At most ``MAX_BYTES_WRITTEN`` bytes, rounded down to whole code units,
        are passed to ``WriteConsoleW`` per call.  Returns the number of bytes
        actually written, which may be less than ``len(b)``.

        Raises:
            OSError: if nothing was written although ``b`` is not empty.
        """
        bytes_to_be_written = len(b)
        buffer = get_buffer(b)
        code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2
        code_units_written = c_ulong()

        # Reset the error code so the message chosen below describes this call
        # and not a stale failure of an earlier one.
        set_last_error(ERROR_SUCCESS)
        WriteConsoleW(self.handle, buffer, code_units_to_be_written, byref(code_units_written), None)
        bytes_written = 2 * code_units_written.value

        # fixes both infinite loop of io.BufferedWriter.flush() on when the buffer has odd length
        # and situation when WriteConsoleW refuses to write lesser that MAX_BYTES_WRITTEN bytes
        if bytes_written == 0 != bytes_to_be_written:
            raise OSError(self._error_message(get_last_error()))
        else:
            return bytes_written


class TextTranscodingWrapper(ReprMixin, io.TextIOBase):
    """Text stream that delegates to ``base`` but reports its own ``encoding``.

    Every operation is forwarded unchanged to the wrapped text stream; only
    the ``encoding`` attribute differs from the one of ``base``.
    """

    encoding = None

    def __init__(self, base, encoding):
        self.base = base
        self.encoding = encoding

    @property
    def errors(self):
        return self.base.errors

    @property
    def line_buffering(self):
        return self.base.line_buffering

    def seekable(self):
        return self.base.seekable()

    def readable(self):
        return self.base.readable()

    def writable(self):
        return self.base.writable()

    def flush(self):
        self.base.flush()

    def close(self):
        self.base.close()

    @property
    def closed(self):
        return self.base.closed

    @property
    def name(self):
        return self.base.name

    def fileno(self):
        return self.base.fileno()

    def isatty(self):
        return self.base.isatty()

    def write(self, s):
        return self.base.write(s)

    def tell(self):
        return self.base.tell()

    def truncate(self, pos=None):
        return self.base.truncate(pos)

    def seek(self, cookie, whence=0):
        return self.base.seek(cookie, whence)

    def read(self, size=None):
        return self.base.read(size)

    def __next__(self):
        return next(self.base)

    def readline(self, size=-1):
        return self.base.readline(size)

    @property
    def newlines(self):
        return self.base.newlines


# NOTE(review): the stream names are empty strings here; they may have been
# "<stdin>"-style names that were lost in extraction -- confirm against upstream.
stdin_raw = WindowsConsoleRawReader("", STDIN_HANDLE, STDIN_FILENO)
stdout_raw = WindowsConsoleRawWriter("", STDOUT_HANDLE, STDOUT_FILENO)
stderr_raw = WindowsConsoleRawWriter("", STDERR_HANDLE, STDERR_FILENO)

stdin_text = io.TextIOWrapper(io.BufferedReader(stdin_raw), encoding="utf-16-le", line_buffering=True)
stdout_text = io.TextIOWrapper(io.BufferedWriter(stdout_raw), encoding="utf-16-le", line_buffering=True)
stderr_text = io.TextIOWrapper(io.BufferedWriter(stderr_raw), encoding="utf-16-le", line_buffering=True)

stdin_text_transcoded = TextTranscodingWrapper(stdin_text, encoding="utf-8")
stdout_text_transcoded = TextTranscodingWrapper(stdout_text, encoding="utf-8")
stderr_text_transcoded = TextTranscodingWrapper(stderr_text, encoding="utf-8")


def disable():
    """Flush the current standard streams and restore the original ones.

    Streams that are ``None`` (``check_stream`` mentions IDLE as an example)
    are skipped instead of raising ``AttributeError``.
    """
    for stream in (sys.stdin, sys.stdout, sys.stderr):
        if stream is not None:
            stream.flush()
    sys.stdin = sys.__stdin__
    sys.stdout = sys.__stdout__
    sys.stderr = sys.__stderr__


def check_stream(stream, fileno):
    """Return whether ``stream`` may be replaced by a console stream.

    A stream qualifies when it is ``None`` or when it is a terminal whose
    file descriptor equals ``fileno``; a qualifying non-``None`` stream is
    flushed before ``True`` is returned.
    """
    if stream is None:  # e.g. with IDLE
        return True

    try:
        _fileno = stream.fileno()
    except io.UnsupportedOperation:
        return False
    else:
        if _fileno == fileno and stream.isatty():
            stream.flush()
            return True
        else:
            return False


def enable_reader(*, transcode=True):
    """Replace ``sys.stdin`` with the console reader if it passes ``check_stream``."""
    # transcoding because Python tokenizer cannot handle UTF-16
    if check_stream(sys.stdin, STDIN_FILENO):
        if transcode:
            sys.stdin = stdin_text_transcoded
        else:
            sys.stdin = stdin_text


def enable_writer(*, transcode=True):
    """Replace ``sys.stdout`` with the console writer if it passes ``check_stream``."""
    if check_stream(sys.stdout, STDOUT_FILENO):
        if transcode:
            sys.stdout = stdout_text_transcoded
        else:
            sys.stdout = stdout_text


def enable_error_writer(*, transcode=True):
    """Replace ``sys.stderr`` with the console writer if it passes ``check_stream``."""
    if check_stream(sys.stderr, STDERR_FILENO):
        if transcode:
            sys.stderr = stderr_text_transcoded
        else:
            sys.stderr = stderr_text


enablers = {"stdin": enable_reader, "stdout": enable_writer, "stderr": enable_error_writer}


def enable(streams=("stdin", "stdout", "stderr"), *, transcode=frozenset(enablers.keys())):
    """Install the console streams named in ``streams``.

    Args:
        streams: iterable of stream names, each a key of ``enablers``.
        transcode: ``True`` (all streams), ``False`` (none), or a collection
            of stream names for which the transcoding wrapper is installed.

    Raises:
        ValueError: if ``streams`` or ``transcode`` contains a name that is
            not a key of ``enablers``.
    """
    if transcode is True:
        transcode = enablers.keys()
    elif transcode is False:
        transcode = set()

    if not set(streams) | set(transcode) <= enablers.keys():
        raise ValueError("invalid stream names")

    for stream in streams:
        enablers[stream](transcode=(stream in transcode))