added win_unicode_console module
[twitter-api-cdsw] / win_unicode_console / streams.py
diff --git a/win_unicode_console/streams.py b/win_unicode_console/streams.py
new file mode 100644 (file)
index 0000000..be6927d
--- /dev/null
@@ -0,0 +1,260 @@
+
+from ctypes import byref, windll, c_ulong
+
+from win_unicode_console.buffer import get_buffer
+
+import io
+import sys
+import time
+
+
+kernel32 = windll.kernel32
+
+# Win32 console functions, called through ctypes without argtypes/restype declarations.
+# NOTE(review): with no restype, ctypes treats every result as a C int — confirm that is
+# adequate for the HANDLE values GetStdHandle returns on 64-bit Windows.
+GetStdHandle, ReadConsoleW, WriteConsoleW, GetLastError = (
+       kernel32.GetStdHandle,
+       kernel32.ReadConsoleW,
+       kernel32.WriteConsoleW,
+       kernel32.GetLastError,
+)
+
+# Win32 error codes the stream classes distinguish
+ERROR_SUCCESS, ERROR_NOT_ENOUGH_MEMORY, ERROR_OPERATION_ABORTED = 0, 8, 995
+
+# -10 / -11 / -12 are the Win32 STD_INPUT_HANDLE / STD_OUTPUT_HANDLE / STD_ERROR_HANDLE selectors
+STDIN_HANDLE, STDOUT_HANDLE, STDERR_HANDLE = GetStdHandle(-10), GetStdHandle(-11), GetStdHandle(-12)
+
+# Descriptor numbers the console streams report from fileno()
+STDIN_FILENO, STDOUT_FILENO, STDERR_FILENO = 0, 1, 2
+
+# Ctrl+Z (0x1A) at the start of the data read signals end of console input
+EOF = b"\x1a"
+
+# Cap on bytes handed to one WriteConsoleW call. The value is arbitrary, because how large
+# a buffer WriteConsoleW manages to write depends on heap usage.
+MAX_BYTES_WRITTEN = 32767
+
+
+class ReprMixin:
+       """Mixin giving streams a repr of the form <module.Class name='...' encoding='...'>.
+       
+       Only the attributes among ``name`` and ``encoding`` that can be read from
+       the instance (no AttributeError) are included.
+       """
+       def __repr__(self):
+               cls = self.__class__
+               missing = object()
+               shown = []
+               for attr in ("name", "encoding"):
+                       value = getattr(self, attr, missing)
+                       if value is not missing:
+                               shown.append("{}={!r}".format(attr, value))
+               return "<{}.{} {}>".format(cls.__module__, cls.__qualname__, " ".join(shown))
+
+
+class WindowsConsoleRawIOBase(ReprMixin, io.RawIOBase):
+       """Common base for raw binary streams bound to a Windows console handle.
+       
+       name   -- display name shown by repr (e.g. "<stdout>")
+       handle -- console handle the subclasses pass to the Win32 console functions
+       fileno -- descriptor number returned by fileno()
+       """
+       def __init__(self, name, handle, fileno):
+               self.name, self.handle, self.file_no = name, handle, fileno
+       
+       def fileno(self):
+               """Return the descriptor number given at construction."""
+               return self.file_no
+       
+       def isatty(self):
+               """Report a terminal; the inherited call is kept only for its closed-stream check."""
+               super().isatty()        # result discarded; raises if the stream is closed
+               return True
+
+class WindowsConsoleRawReader(WindowsConsoleRawIOBase):
+       """Readable raw stream that fetches UTF-16-LE code units from the console with ReadConsoleW."""
+       
+       def readable(self):
+               return True
+       
+       def readinto(self, b):
+               """Read console input into the writable buffer *b*.
+               
+               Returns the number of bytes stored (always even), or 0 at end of input,
+               which is signalled by a Ctrl+Z (U+001A) as the first character read.
+               Raises ValueError if len(b) is odd and OSError if ReadConsoleW fails.
+               """
+               bytes_to_be_read = len(b)
+               if not bytes_to_be_read:
+                       return 0
+               elif bytes_to_be_read % 2:
+                       raise ValueError("cannot read odd number of bytes from UTF-16-LE encoded console")
+               
+               buffer = get_buffer(b, writable=True)
+               code_units_to_be_read = bytes_to_be_read // 2
+               code_units_read = c_ulong()
+               
+               # The abort code is checked independently of retval (Ctrl+C can interrupt the read),
+               # so clear any stale error first: a leftover ERROR_OPERATION_ABORTED from an earlier
+               # call could otherwise trigger a needless sleep after a successful read.
+               kernel32.SetLastError(ERROR_SUCCESS)
+               retval = ReadConsoleW(self.handle, buffer, code_units_to_be_read, byref(code_units_read), None)
+               last_error = GetLastError()     # capture once, before later calls can overwrite it
+               if last_error == ERROR_OPERATION_ABORTED:
+                       time.sleep(0.1) # wait for KeyboardInterrupt
+               if not retval:
+                       raise OSError("Windows error {}".format(last_error))
+               
+               # Compare the whole first code unit (bytes 1A 00). Checking only the low byte would
+               # also treat characters such as U+011A (Ě), U+021A (Ț) or U+041A (К) as end of input,
+               # because their UTF-16-LE encodings start with 0x1A as well.
+               if buffer[0] == EOF and buffer[1] == b"\x00":
+                       return 0
+               return 2 * code_units_read.value
+
+class WindowsConsoleRawWriter(WindowsConsoleRawIOBase):
+       """Writable raw stream that sends UTF-16-LE bytes to the console with WriteConsoleW."""
+       
+       def writable(self):
+               return True
+       
+       @staticmethod
+       def _error_message(errno):
+               """Return the OSError text for Windows error *errno*, with hints for ERROR_SUCCESS and ERROR_NOT_ENOUGH_MEMORY."""
+               hinted_templates = {
+                       ERROR_SUCCESS: "Windows error {} (ERROR_SUCCESS); zero bytes written on nonzero input, probably just one byte given",
+                       ERROR_NOT_ENOUGH_MEMORY: "Windows error {} (ERROR_NOT_ENOUGH_MEMORY); try to lower `win_unicode_console.streams.MAX_BYTES_WRITTEN`",
+               }
+               return hinted_templates.get(errno, "Windows error {}").format(errno)
+       
+       def write(self, b):
+               """Write a prefix of the UTF-16-LE bytes *b* to the console.
+               
+               At most MAX_BYTES_WRITTEN bytes, rounded down to whole code units, are passed
+               to WriteConsoleW; returns the number of bytes actually written.
+               Raises OSError when nothing was written from non-empty input.
+               """
+               total = len(b)
+               data = get_buffer(b)
+               units_requested = min(total, MAX_BYTES_WRITTEN) // 2
+               units_written = c_ulong()
+               
+               WriteConsoleW(self.handle, data, units_requested, byref(units_written), None)
+               written = units_written.value * 2
+               
+               # Writing nothing from non-empty input is an error. This stops io.BufferedWriter.flush()
+               # from looping forever on an odd-length buffer, and reports the case where WriteConsoleW
+               # refuses to write a chunk smaller than MAX_BYTES_WRITTEN bytes.
+               if total and not written:
+                       raise OSError(self._error_message(GetLastError()))
+               return written
+
+class TextTranscodingWrapper(ReprMixin, io.TextIOBase):
+       """Text stream that forwards every operation to *base* but reports *encoding*.
+       
+       No data is re-encoded: str objects pass to and from *base* unchanged; only the
+       advertised ``encoding`` attribute differs from the wrapped stream.
+       """
+       # NOTE(review): presumably a plain class attribute so that it shadows the read-only
+       # ``encoding`` inherited from io.TextIOBase and __init__ can set it per instance.
+       encoding = None
+       
+       def __init__(self, base, encoding):
+               self.base = base
+               self.encoding = encoding
+       
+       # --- attributes mirrored from the wrapped stream ---
+       
+       @property
+       def name(self):
+               return self.base.name
+       
+       @property
+       def closed(self):
+               return self.base.closed
+       
+       @property
+       def errors(self):
+               return self.base.errors
+       
+       @property
+       def newlines(self):
+               return self.base.newlines
+       
+       @property
+       def line_buffering(self):
+               return self.base.line_buffering
+       
+       # --- capability queries ---
+       
+       def readable(self):
+               return self.base.readable()
+       
+       def writable(self):
+               return self.base.writable()
+       
+       def seekable(self):
+               return self.base.seekable()
+       
+       def isatty(self):
+               return self.base.isatty()
+       
+       def fileno(self):
+               return self.base.fileno()
+       
+       # --- I/O operations ---
+       
+       def read(self, size=None):
+               return self.base.read(size)
+       
+       def readline(self, size=-1):
+               return self.base.readline(size)
+       
+       def __next__(self):
+               return next(self.base)
+       
+       def write(self, s):
+               return self.base.write(s)
+       
+       def flush(self):
+               self.base.flush()
+       
+       def seek(self, cookie, whence=0):
+               return self.base.seek(cookie, whence)
+       
+       def tell(self):
+               return self.base.tell()
+       
+       def truncate(self, pos=None):
+               return self.base.truncate(pos)
+       
+       def close(self):
+               self.base.close()
+
+
+# Raw UTF-16-LE byte streams bound to the three console handles
+stdin_raw, stdout_raw, stderr_raw = (
+       WindowsConsoleRawReader("<stdin>", STDIN_HANDLE, STDIN_FILENO),
+       WindowsConsoleRawWriter("<stdout>", STDOUT_HANDLE, STDOUT_FILENO),
+       WindowsConsoleRawWriter("<stderr>", STDERR_HANDLE, STDERR_FILENO),
+)
+
+# Buffered, line-buffered text streams that decode/encode UTF-16-LE over the raw streams
+stdin_text, stdout_text, stderr_text = (
+       io.TextIOWrapper(io.BufferedReader(stdin_raw), encoding="utf-16-le", line_buffering=True),
+       io.TextIOWrapper(io.BufferedWriter(stdout_raw), encoding="utf-16-le", line_buffering=True),
+       io.TextIOWrapper(io.BufferedWriter(stderr_raw), encoding="utf-16-le", line_buffering=True),
+)
+
+# The same text streams, wrapped so that they report "utf-8" as their encoding
+stdin_text_transcoded, stdout_text_transcoded, stderr_text_transcoded = (
+       TextTranscodingWrapper(stdin_text, encoding="utf-8"),
+       TextTranscodingWrapper(stdout_text, encoding="utf-8"),
+       TextTranscodingWrapper(stderr_text, encoding="utf-8"),
+)
+
+
+def disable():
+       """Flush the current standard streams and restore the interpreter's original ones.
+       
+       Standard streams that are currently None are skipped instead of failing with
+       AttributeError on ``None.flush()``.
+       """
+       for stream in (sys.stdin, sys.stdout, sys.stderr):
+               if stream is not None:
+                       stream.flush()
+       sys.stdin = sys.__stdin__
+       sys.stdout = sys.__stdout__
+       sys.stderr = sys.__stderr__
+
+def check_stream(stream, fileno):
+       """Return True if *stream* may be replaced by the console stream for *fileno*.
+       
+       That holds when *stream* is None, or when its fileno() equals *fileno* and
+       isatty() is true; in the latter case *stream* is flushed before returning True.
+       Streams whose fileno()/isatty() cannot be queried (missing method, closed
+       stream, unsupported operation) are left alone and give False.
+       """
+       if stream is None:      # e.g. with IDLE
+               return True
+       
+       try:
+               is_console = stream.fileno() == fileno and stream.isatty()
+       except (AttributeError, OSError, ValueError):
+               # io.UnsupportedOperation derives from OSError and ValueError; a closed file raises
+               # ValueError; AttributeError covers stream replacements lacking fileno()/isatty().
+               return False
+       
+       if not is_console:
+               return False
+       stream.flush()
+       return True
+       
+def enable_reader(*, transcode=True):
+       """Point sys.stdin at the console reader if check_stream accepts the current sys.stdin.
+       
+       With transcode=True (the default) the "utf-8"-reporting wrapper is installed, because
+       the Python tokenizer cannot handle UTF-16; otherwise the UTF-16-LE text stream is used.
+       """
+       if not check_stream(sys.stdin, STDIN_FILENO):
+               return
+       sys.stdin = stdin_text_transcoded if transcode else stdin_text
+
+def enable_writer(*, transcode=True):
+       """Point sys.stdout at the console writer if check_stream accepts the current sys.stdout.
+       
+       transcode selects the "utf-8"-reporting wrapper (default) or the plain UTF-16-LE text stream.
+       """
+       if not check_stream(sys.stdout, STDOUT_FILENO):
+               return
+       sys.stdout = stdout_text_transcoded if transcode else stdout_text
+
+def enable_error_writer(*, transcode=True):
+       """Point sys.stderr at the console writer if check_stream accepts the current sys.stderr.
+       
+       transcode selects the "utf-8"-reporting wrapper (default) or the plain UTF-16-LE text stream.
+       """
+       if not check_stream(sys.stderr, STDERR_FILENO):
+               return
+       sys.stderr = stderr_text_transcoded if transcode else stderr_text
+
+# Standard-stream name -> function installing its console replacement
+enablers = dict(stdin=enable_reader, stdout=enable_writer, stderr=enable_error_writer)
+
+def enable(streams=("stdin", "stdout", "stderr"), *, transcode=frozenset(enablers.keys())):
+       """Install console replacements for the named standard streams.
+       
+       streams   -- names (keys of ``enablers``) of the streams to replace; a single
+                    name may also be given as a plain string
+       transcode -- True (all names), False (none), or the names whose enabler is called
+                    with transcode=True; a single name may also be given as a plain string
+       
+       Raises ValueError if any given name is not a key of ``enablers``.
+       """
+       # A bare string would otherwise be iterated character by character and rejected.
+       if isinstance(streams, str):
+               streams = (streams,)
+       if transcode is True:
+               transcode = enablers.keys()
+       elif transcode is False:
+               transcode = set()
+       elif isinstance(transcode, str):
+               transcode = {transcode}
+       
+       if not (set(streams) | set(transcode)) <= enablers.keys():
+               raise ValueError("invalid stream names")
+       
+       for stream in streams:
+               enablers[stream](transcode=(stream in transcode))
+

Benjamin Mako Hill || Want to submit a patch?