catscan is now petscan
[harrypotter-wikipedia-cdsw] / win_unicode_console / streams.py
1
import io
import sys
import time
from ctypes import WinDLL, byref, c_ulong, get_last_error, windll

from win_unicode_console.buffer import get_buffer
8 import time
9
10
# Win32 console API bound through ctypes. With use_last_error=True, ctypes
# saves the thread's last-error code immediately after each call made through
# this DLL object, and ctypes.get_last_error() returns that saved copy.
# Calling the real GetLastError through ctypes is unreliable: other Win32
# calls made by ctypes or the interpreter in between can overwrite the code.
kernel32 = WinDLL("kernel32", use_last_error=True)
GetStdHandle = kernel32.GetStdHandle
ReadConsoleW = kernel32.ReadConsoleW
WriteConsoleW = kernel32.WriteConsoleW
# Kept under the Win32 name so existing callers of GetLastError() keep working.
GetLastError = get_last_error


# Windows system error codes checked by the raw console streams.
ERROR_SUCCESS = 0
ERROR_NOT_ENOUGH_MEMORY = 8
ERROR_OPERATION_ABORTED = 995

# -10 / -11 / -12 are STD_INPUT_HANDLE / STD_OUTPUT_HANDLE / STD_ERROR_HANDLE.
# NOTE(review): no restype is declared, so ctypes returns the HANDLE as a C int.
STDIN_HANDLE = GetStdHandle(-10)
STDOUT_HANDLE = GetStdHandle(-11)
STDERR_HANDLE = GetStdHandle(-12)

STDIN_FILENO = 0
STDOUT_FILENO = 1
STDERR_FILENO = 2

# Ctrl+Z (0x1A): the byte the reader treats as the console's end-of-input mark.
EOF = b"\x1a"

# Upper bound on the bytes passed to one WriteConsoleW call. The value is
# arbitrary: how big a buffer WriteConsoleW accepts depends on heap usage.
MAX_BYTES_WRITTEN = 32767
33
34
class ReprMixin:
    """Mixin providing a ``<module.QualName name=... encoding=...>`` repr.

    Only those of the ``name`` and ``encoding`` attributes that can be read
    without AttributeError are listed; the others are left out.
    """

    def __repr__(self):
        cls = self.__class__
        shown = []
        for attr in ("name", "encoding"):
            try:
                value = getattr(self, attr)
            except AttributeError:
                continue
            shown.append("{}={}".format(attr, repr(value)))
        return "<{}.{} {}>".format(cls.__module__, cls.__qualname__, " ".join(shown))
49
50
class WindowsConsoleRawIOBase(ReprMixin, io.RawIOBase):
    """Common base for raw streams attached to a Windows console handle.

    Args:
        name: stream name shown in repr(), e.g. "<stdin>".
        handle: console handle passed to the Win32 console functions.
        fileno: value reported by fileno().
    """

    def __init__(self, name, handle, fileno):
        # The descriptor is stored as ``file_no`` because ``fileno`` is the method.
        self.name, self.handle, self.file_no = name, handle, fileno

    def fileno(self):
        return self.file_no

    def isatty(self):
        # The inherited implementation performs the closed-stream check
        # (raising on a closed stream); its result is discarded because a
        # console handle is always reported as a tty.
        super().isatty()
        return True
62                 return True
63
class WindowsConsoleRawReader(WindowsConsoleRawIOBase):
    """Raw console input stream returning UTF-16-LE bytes read by ReadConsoleW."""

    def readable(self):
        return True

    def readinto(self, b):
        """Read console input into the writable buffer *b*.

        At most len(b) // 2 UTF-16 code units are requested in a single
        ReadConsoleW call.

        Returns:
            The number of bytes stored in *b* (two per code unit), or 0 when
            *b* is empty or the input begins with Ctrl+Z (end of input).

        Raises:
            ValueError: if len(b) is odd.
            OSError: if ReadConsoleW returns a failure status.
        """
        bytes_to_be_read = len(b)
        if not bytes_to_be_read:
            return 0
        elif bytes_to_be_read % 2:
            raise ValueError("cannot read odd number of bytes from UTF-16-LE encoded console")

        buffer = get_buffer(b, writable=True)
        code_units_to_be_read = bytes_to_be_read // 2
        code_units_read = c_ulong()

        retval = ReadConsoleW(self.handle, buffer, code_units_to_be_read, byref(code_units_read), None)
        # Read the error code once, right after the call, so the sleep below
        # cannot change the code reported in the OSError.
        last_error = GetLastError()
        if last_error == ERROR_OPERATION_ABORTED:
            # The read was aborted (presumably by Ctrl+C).
            time.sleep(0.1) # wait for KeyboardInterrupt
        if not retval:
            raise OSError("Windows error {}".format(last_error))

        code_units = code_units_read.value
        # End of input is the code unit U+001A, i.e. the byte pair 1A 00 in
        # UTF-16-LE. Checking only the first byte would also treat input that
        # starts with e.g. U+011A ('\u011a', bytes 1A 01) as end of input.
        if code_units and buffer[0] == EOF and buffer[1] == b"\x00":
            return 0
        return 2 * code_units
89
class WindowsConsoleRawWriter(WindowsConsoleRawIOBase):
    """Raw console output stream sending UTF-16-LE bytes through WriteConsoleW."""

    def writable(self):
        return True

    @staticmethod
    def _error_message(errno):
        """Return the OSError message for a write that stored zero bytes.

        Known error codes get an explanatory hint; others only show the code.
        """
        templates = {
            ERROR_SUCCESS: "Windows error {} (ERROR_SUCCESS); zero bytes written on nonzero input, probably just one byte given",
            ERROR_NOT_ENOUGH_MEMORY: "Windows error {} (ERROR_NOT_ENOUGH_MEMORY); try to lower `win_unicode_console.streams.MAX_BYTES_WRITTEN`",
        }
        return templates.get(errno, "Windows error {}").format(errno)

    def write(self, b):
        """Write as much of *b* as a single WriteConsoleW call accepts.

        At most MAX_BYTES_WRITTEN bytes are offered, rounded down to whole
        UTF-16 code units.

        Returns:
            The number of bytes actually written.

        Raises:
            OSError: if nothing was written although *b* was not empty.
        """
        total = len(b)
        data = get_buffer(b)
        units_requested = min(total, MAX_BYTES_WRITTEN) // 2
        units_written = c_ulong()

        WriteConsoleW(self.handle, data, units_requested, byref(units_written), None)
        written = units_written.value * 2

        # Raising here breaks the otherwise infinite retry loop of
        # io.BufferedWriter.flush() on an odd-length buffer, and reports the
        # case where WriteConsoleW refuses to write even fewer than
        # MAX_BYTES_WRITTEN bytes.
        if total and not written:
            raise OSError(self._error_message(GetLastError()))
        return written
118
class TextTranscodingWrapper(ReprMixin, io.TextIOBase):
    """Text stream that forwards every operation to the text stream *base*.

    The only difference from *base* is the ``encoding`` attribute, which
    reports the *encoding* given here. ``str`` data is passed to and from
    *base* unchanged.
    """

    # Class-level default; presumably required so that __init__ can assign
    # ``self.encoding`` over io.TextIOBase's read-only ``encoding`` attribute.
    # NOTE(review): confirm before removing.
    encoding = None

    def __init__(self, base, encoding):
        self.base = base
        self.encoding = encoding

    # Attributes mirrored from the base stream.

    @property
    def name(self):
        return self.base.name

    @property
    def errors(self):
        return self.base.errors

    @property
    def newlines(self):
        return self.base.newlines

    @property
    def line_buffering(self):
        return self.base.line_buffering

    @property
    def closed(self):
        return self.base.closed

    # Capability queries.

    def readable(self):
        return self.base.readable()

    def writable(self):
        return self.base.writable()

    def seekable(self):
        return self.base.seekable()

    def isatty(self):
        return self.base.isatty()

    def fileno(self):
        return self.base.fileno()

    # Reading.

    def read(self, size=None):
        return self.base.read(size)

    def readline(self, size=-1):
        return self.base.readline(size)

    def __next__(self):
        return next(self.base)

    # Writing and lifecycle.

    def write(self, s):
        return self.base.write(s)

    def flush(self):
        self.base.flush()

    def close(self):
        self.base.close()

    # Positioning.

    def tell(self):
        return self.base.tell()

    def seek(self, cookie, whence=0):
        return self.base.seek(cookie, whence)

    def truncate(self, pos=None):
        return self.base.truncate(pos)
187
188
# Raw, unbuffered streams bound to the three console handles.
stdin_raw = WindowsConsoleRawReader(name="<stdin>", handle=STDIN_HANDLE, fileno=STDIN_FILENO)
stdout_raw = WindowsConsoleRawWriter(name="<stdout>", handle=STDOUT_HANDLE, fileno=STDOUT_FILENO)
stderr_raw = WindowsConsoleRawWriter(name="<stderr>", handle=STDERR_HANDLE, fileno=STDERR_FILENO)

# Buffered, line-buffered text layers that encode/decode UTF-16-LE.
stdin_text = io.TextIOWrapper(
    io.BufferedReader(stdin_raw),
    encoding="utf-16-le",
    line_buffering=True,
)
stdout_text = io.TextIOWrapper(
    io.BufferedWriter(stdout_raw),
    encoding="utf-16-le",
    line_buffering=True,
)
stderr_text = io.TextIOWrapper(
    io.BufferedWriter(stderr_raw),
    encoding="utf-16-le",
    line_buffering=True,
)

# Same text streams, but reporting "utf-8" as their encoding.
stdin_text_transcoded = TextTranscodingWrapper(stdin_text, encoding="utf-8")
stdout_text_transcoded = TextTranscodingWrapper(stdout_text, encoding="utf-8")
stderr_text_transcoded = TextTranscodingWrapper(stderr_text, encoding="utf-8")
200
201
def disable():
    """Restore the interpreter's original standard streams.

    The current sys.stdin / sys.stdout / sys.stderr are flushed first. Streams
    that are None (as can happen e.g. under pythonw) are skipped rather than
    raising AttributeError. Then sys.stdin, sys.stdout and sys.stderr are reset
    to sys.__stdin__, sys.__stdout__ and sys.__stderr__.
    """
    for stream in (sys.stdin, sys.stdout, sys.stderr):
        if stream is not None:
            stream.flush()
    sys.stdin = sys.__stdin__
    sys.stdout = sys.__stdout__
    sys.stderr = sys.__stderr__
209
def check_stream(stream, fileno):
    """Decide whether *stream* may be replaced by the console stream for *fileno*.

    Returns True when *stream* is None (e.g. with IDLE). Also returns True when
    *stream* reports file descriptor *fileno* and is a tty; in that case the
    stream is flushed first. Returns False otherwise, including for streams
    that cannot report a file descriptor at all.
    """
    if stream is None:  # e.g. with IDLE
        return True

    try:
        _fileno = stream.fileno()
    except (AttributeError, OSError, ValueError):
        # io.UnsupportedOperation (a subclass of both OSError and ValueError),
        # a closed stream (ValueError), or an object without fileno() at all:
        # not the console, so leave it alone instead of crashing.
        return False

    if _fileno == fileno and stream.isatty():
        stream.flush()
        return True
    return False
224         
def enable_reader(*, transcode=True):
    """Install the console reader as sys.stdin if sys.stdin passes check_stream.

    With *transcode* true (the default), the wrapper reporting "utf-8" is
    installed, because the Python tokenizer cannot handle UTF-16; otherwise
    the plain UTF-16-LE text stream is used.
    """
    if not check_stream(sys.stdin, STDIN_FILENO):
        return
    sys.stdin = stdin_text_transcoded if transcode else stdin_text
232
def enable_writer(*, transcode=True):
    """Install the console writer as sys.stdout if sys.stdout passes check_stream.

    *transcode* selects the wrapper reporting "utf-8" (default) instead of
    the plain UTF-16-LE text stream.
    """
    if not check_stream(sys.stdout, STDOUT_FILENO):
        return
    sys.stdout = stdout_text_transcoded if transcode else stdout_text
239
def enable_error_writer(*, transcode=True):
    """Install the console writer as sys.stderr if sys.stderr passes check_stream.

    *transcode* selects the wrapper reporting "utf-8" (default) instead of
    the plain UTF-16-LE text stream.
    """
    if not check_stream(sys.stderr, STDERR_FILENO):
        return
    sys.stderr = stderr_text_transcoded if transcode else stderr_text
246
# Stream name -> function installing the console replacement for that stream.
enablers = {"stdin": enable_reader, "stdout": enable_writer, "stderr": enable_error_writer}

def enable(streams=("stdin", "stdout", "stderr"), *, transcode=frozenset(enablers.keys())):
    """Replace the selected standard streams with Windows console streams.

    Args:
        streams: iterable of stream names to enable, each a key of ``enablers``.
        transcode: iterable of stream names whose enabler is called with
            ``transcode=True``; ``True`` means all names, ``False`` none.

    Raises:
        ValueError: if *streams* or *transcode* contains a name that is not a
            key of ``enablers``.
    """
    # Both arguments are traversed twice (validation, then the loop), so
    # materialize them first; a one-shot iterator such as a generator would
    # otherwise be exhausted by the validation and nothing would be enabled.
    streams = tuple(streams)
    if transcode is True:
        transcode = frozenset(enablers)
    elif transcode is False:
        transcode = frozenset()
    else:
        transcode = frozenset(transcode)

    if not set(streams) | transcode <= enablers.keys():
        raise ValueError("invalid stream names")

    for stream in streams:
        enablers[stream](transcode=(stream in transcode))
260

Benjamin Mako Hill || Want to submit a patch?