Handle content-type header charset value for streaming API
[twitter-api-cdsw] / tweepy / streaming.py
index faf42eacfa3724e73d2c1cbead877e53dba00deb..ad7944c4a3aee981623949b66fff8a75a989e7ea 100644 (file)
@@ -7,6 +7,7 @@
 from __future__ import absolute_import, print_function
 
 import logging
+import re
 import requests
 from requests.exceptions import Timeout
 from threading import Thread
@@ -148,19 +149,26 @@ class ReadBuffer(object):
     use small chunks so it can read the length and the tweet in 2 read calls.
     """
 
-    def __init__(self, stream, chunk_size):
+    def __init__(self, stream, chunk_size, encoding='utf-8'):
         self._stream = stream
-        self._buffer = u""
+        self._buffer = six.b('')
         self._chunk_size = chunk_size
+        self._encoding = encoding
 
     def read_len(self, length):
         while not self._stream.closed:
             if len(self._buffer) >= length:
                 return self._pop(length)
             read_len = max(self._chunk_size, length - len(self._buffer))
-            self._buffer += self._stream.read(read_len).decode("ascii")
+            self._buffer += self._stream.read(read_len)
 
-    def read_line(self, sep='\n'):
+    def read_line(self, sep=six.b('\n')):
+        """Read the data stream until a given separator is found (default \\n)
+
+        :param sep: Separator to read until. Must be of the bytes type (str in python 2,
+            bytes in python 3)
+        :return: The str of the data read until sep
+        """
         start = 0
         while not self._stream.closed:
             loc = self._buffer.find(sep, start)
@@ -168,12 +176,12 @@ class ReadBuffer(object):
                 return self._pop(loc + len(sep))
             else:
                 start = len(self._buffer)
-            self._buffer += self._stream.read(self._chunk_size).decode("ascii")
+            self._buffer += self._stream.read(self._chunk_size)
 
     def _pop(self, length):
         r = self._buffer[:length]
         self._buffer = self._buffer[length:]
-        return r
+        return r.decode(self._encoding)
 
 
 class Stream(object):
@@ -268,10 +276,10 @@ class Stream(object):
                 sleep(self.snooze_time)
                 self.snooze_time = min(self.snooze_time + self.snooze_time_step,
                                        self.snooze_time_cap)
-            except Exception as exc:
-                exception = exc
-                # any other exception is fatal, so kill loop
-                break
+            except Exception as exc:
+                exception = exc
+                # any other exception is fatal, so kill loop
+                break
 
         # cleanup
         self.running = False
@@ -283,14 +291,21 @@ class Stream(object):
         if exception:
             # call a handler first so that the exception can be logged.
             self.listener.on_exception(exception)
-            raise
+            raise exception
 
     def _data(self, data):
         if self.listener.on_data(data) is False:
             self.running = False
 
     def _read_loop(self, resp):
-        buf = ReadBuffer(resp.raw, self.chunk_size)
+        charset = resp.headers.get('content-type', default='')
+        enc_search = re.search('charset=(?P<enc>\S*)', charset)
+        if enc_search is not None:
+            encoding = enc_search.group('enc')
+        else:
+            encoding = 'utf-8'
+
+        buf = ReadBuffer(resp.raw, self.chunk_size, encoding=encoding)
 
         while self.running and not resp.raw.closed:
             length = 0
@@ -404,7 +419,7 @@ class Stream(object):
         self._start(async)
 
     def filter(self, follow=None, track=None, async=False, locations=None,
-               stall_warnings=False, languages=None, encoding='utf8'):
+               stall_warnings=False, languages=None, encoding='utf8', filter_level=None):
         self.body = {}
         self.session.headers['Content-type'] = "application/x-www-form-urlencoded"
         if self.running:
@@ -423,6 +438,8 @@ class Stream(object):
             self.body['stall_warnings'] = stall_warnings
         if languages:
             self.body['language'] = u','.join(map(str, languages))
+        if filter_level:
+            self.body['filter_level'] = unicode(filter_level, encoding)
         self.session.params = {'delimited': 'length'}
         self.host = 'stream.twitter.com'
         self._start(async)

Benjamin Mako Hill || Want to submit a patch?