Merge pull request #3 from guyrt/master
[twitter-api-cdsw-solutions] / tweepy / streaming.py
index be233ffe0ef158bf7fa732976b2abfe6c18285f9..faf42eacfa3724e73d2c1cbead877e53dba00deb 100644 (file)
@@ -2,18 +2,25 @@
 # Copyright 2009-2010 Joshua Roesslein
 # See LICENSE for details.
 
+# Appengine users: https://developers.google.com/appengine/docs/python/sockets/#making_httplib_use_sockets
+
+from __future__ import absolute_import, print_function
+
 import logging
-import httplib
-from socket import timeout
+import requests
+from requests.exceptions import Timeout
 from threading import Thread
 from time import sleep
+
+import six
+
 import ssl
 
 from tweepy.models import Status
 from tweepy.api import API
 from tweepy.error import TweepError
 
-from tweepy.utils import import_simplejson, urlencode_noplus
+from tweepy.utils import import_simplejson
 json = import_simplejson()
 
 STREAM_VERSION = '1.1'
@@ -57,15 +64,25 @@ class StreamListener(object):
             status = Status.parse(self.api, data)
             if self.on_direct_message(status) is False:
                 return False
+        elif 'friends' in data:
+            if self.on_friends(data['friends']) is False:
+                return False
         elif 'limit' in data:
             if self.on_limit(data['limit']['track']) is False:
                 return False
         elif 'disconnect' in data:
             if self.on_disconnect(data['disconnect']) is False:
                 return False
+        elif 'warning' in data:
+            if self.on_warning(data['warning']) is False:
+                return False
         else:
             logging.error("Unknown message type: " + str(raw_data))
 
+    def keep_alive(self):
+        """Called when a keep-alive arrived"""
+        return
+
     def on_status(self, status):
         """Called when a new status arrives"""
         return
@@ -86,8 +103,15 @@ class StreamListener(object):
         """Called when a new direct message arrives"""
         return
 
+    def on_friends(self, friends):
+        """Called when a friends list arrives.
+
+        friends is a list that contains user_id
+        """
+        return
+
     def on_limit(self, track):
-        """Called when a limitation notice arrvies"""
+        """Called when a limitation notice arrives"""
         return
 
     def on_error(self, status_code):
@@ -106,6 +130,51 @@ class StreamListener(object):
         """
         return
 
+    def on_warning(self, notice):
+        """Called when a disconnection warning message arrives"""
+        return
+
+
+class ReadBuffer(object):
+    """Buffer data from the response in a smarter way than httplib/requests can.
+
+    Tweets are roughly in the 2-12kb range, averaging around 3kb.
+    Requests/urllib3/httplib/socket all use socket.read, which blocks
+    until enough data is returned. On some systems (eg google appengine), socket
+    reads are quite slow. To combat this latency we can read big chunks,
+    but the blocking part means we won't get results until enough tweets
+    have arrived. That may not be a big deal for high throughput systems.
+    For low throughput systems we don't want to sacrafice latency, so we
+    use small chunks so it can read the length and the tweet in 2 read calls.
+    """
+
+    def __init__(self, stream, chunk_size):
+        self._stream = stream
+        self._buffer = u""
+        self._chunk_size = chunk_size
+
+    def read_len(self, length):
+        while not self._stream.closed:
+            if len(self._buffer) >= length:
+                return self._pop(length)
+            read_len = max(self._chunk_size, length - len(self._buffer))
+            self._buffer += self._stream.read(read_len).decode("ascii")
+
+    def read_line(self, sep='\n'):
+        start = 0
+        while not self._stream.closed:
+            loc = self._buffer.find(sep, start)
+            if loc >= 0:
+                return self._pop(loc + len(sep))
+            else:
+                start = len(self._buffer)
+            self._buffer += self._stream.read(self._chunk_size).decode("ascii")
+
+    def _pop(self, length):
+        r = self._buffer[:length]
+        self._buffer = self._buffer[length:]
+        return r
+
 
 class Stream(object):
 
@@ -117,82 +186,99 @@ class Stream(object):
         self.running = False
         self.timeout = options.get("timeout", 300.0)
         self.retry_count = options.get("retry_count")
-        # values according to https://dev.twitter.com/docs/streaming-apis/connecting#Reconnecting
+        # values according to
+        # https://dev.twitter.com/docs/streaming-apis/connecting#Reconnecting
         self.retry_time_start = options.get("retry_time", 5.0)
         self.retry_420_start = options.get("retry_420", 60.0)
         self.retry_time_cap = options.get("retry_time_cap", 320.0)
         self.snooze_time_step = options.get("snooze_time", 0.25)
         self.snooze_time_cap = options.get("snooze_time_cap", 16)
-        self.buffer_size = options.get("buffer_size",  1500)
-        if options.get("secure", True):
-            self.scheme = "https"
-        else:
-            self.scheme = "http"
+
+        # The default socket.read size. Default to less than half the size of
+        # a tweet so that it reads tweets with the minimal latency of 2 reads
+        # per tweet. Values higher than ~1kb will increase latency by waiting
+        # for more data to arrive but may also increase throughput by doing
+        # fewer socket read calls.
+        self.chunk_size = options.get("chunk_size",  512)
+
+        self.verify = options.get("verify", True)
 
         self.api = API()
         self.headers = options.get("headers") or {}
-        self.parameters = None
+        self.new_session()
         self.body = None
         self.retry_time = self.retry_time_start
         self.snooze_time = self.snooze_time_step
 
+    def new_session(self):
+        self.session = requests.Session()
+        self.session.headers = self.headers
+        self.session.params = None
+
     def _run(self):
         # Authenticate
-        url = "%s://%s%s" % (self.scheme, self.host, self.url)
+        url = "https://%s%s" % (self.host, self.url)
 
         # Connect and process the stream
         error_counter = 0
-        conn = None
+        resp = None
         exception = None
         while self.running:
-            if self.retry_count is not None and error_counter > self.retry_count:
-                # quit if error count greater than retry count
-                break
+            if self.retry_count is not None:
+                if error_counter > self.retry_count:
+                    # quit if error count greater than retry count
+                    break
             try:
-                if self.scheme == "http":
-                    conn = httplib.HTTPConnection(self.host, timeout=self.timeout)
-                else:
-                    conn = httplib.HTTPSConnection(self.host, timeout=self.timeout)
-                self.auth.apply_auth(url, 'POST', self.headers, self.parameters)
-                conn.connect()
-                conn.request('POST', self.url, self.body, headers=self.headers)
-                resp = conn.getresponse()
-                if resp.status != 200:
-                    if self.listener.on_error(resp.status) is False:
+                auth = self.auth.apply_auth()
+                resp = self.session.request('POST',
+                                            url,
+                                            data=self.body,
+                                            timeout=self.timeout,
+                                            stream=True,
+                                            auth=auth,
+                                            verify=self.verify)
+                if resp.status_code != 200:
+                    if self.listener.on_error(resp.status_code) is False:
                         break
                     error_counter += 1
-                    if resp.status == 420:
-                        self.retry_time = max(self.retry_420_start, self.retry_time)
+                    if resp.status_code == 420:
+                        self.retry_time = max(self.retry_420_start,
+                                              self.retry_time)
                     sleep(self.retry_time)
-                    self.retry_time = min(self.retry_time * 2, self.retry_time_cap)
+                    self.retry_time = min(self.retry_time * 2,
+                                          self.retry_time_cap)
                 else:
                     error_counter = 0
                     self.retry_time = self.retry_time_start
                     self.snooze_time = self.snooze_time_step
                     self.listener.on_connect()
                     self._read_loop(resp)
-            except (timeout, ssl.SSLError) as exc:
+            except (Timeout, ssl.SSLError) as exc:
+                # This is still necessary, as a SSLError can actually be
+                # thrown when using Requests
                 # If it's not time out treat it like any other exception
-                if isinstance(exc, ssl.SSLError) and not (exc.args and 'timed out' in str(exc.args[0])):
-                    exception = exc
-                    break
-
-                if self.listener.on_timeout() == False:
+                if isinstance(exc, ssl.SSLError):
+                    if not (exc.args and 'timed out' in str(exc.args[0])):
+                        exception = exc
+                        break
+                if self.listener.on_timeout() is False:
                     break
                 if self.running is False:
                     break
-                conn.close()
                 sleep(self.snooze_time)
                 self.snooze_time = min(self.snooze_time + self.snooze_time_step,
                                        self.snooze_time_cap)
-            except Exception as exception:
-                # any other exception is fatal, so kill loop
-                break
+            # except Exception as exc:
+            #     exception = exc
+            #     # any other exception is fatal, so kill loop
+            #     break
 
         # cleanup
         self.running = False
-        if conn:
-            conn.close()
+        if resp:
+            resp.close()
+
+        self.new_session()
 
         if exception:
             # call a handler first so that the exception can be logged.
@@ -204,34 +290,58 @@ class Stream(object):
             self.running = False
 
     def _read_loop(self, resp):
+        buf = ReadBuffer(resp.raw, self.chunk_size)
+
+        while self.running and not resp.raw.closed:
+            length = 0
+            while not resp.raw.closed:
+                line = buf.read_line().strip()
+                if not line:
+                    self.listener.keep_alive()  # keep-alive new lines are expected
+                elif line.isdigit():
+                    length = int(line)
+                    break
+                else:
+                    raise TweepError('Expecting length, unexpected value found')
 
-        while self.running and not resp.isclosed():
-
-            # Note: keep-alive newlines might be inserted before each length value.
-            # read until we get a digit...
-            c = '\n'
-            while c == '\n' and self.running and not resp.isclosed():
-                c = resp.read(1)
-            delimited_string = c
-
-            # read rest of delimiter length..
-            d = ''
-            while d != '\n' and self.running and not resp.isclosed():
-                d = resp.read(1)
-                delimited_string += d
-
-            # read the next twitter status object
-            if delimited_string.strip().isdigit():
-                next_status_obj = resp.read( int(delimited_string) )
+            next_status_obj = buf.read_len(length)
+            if self.running:
                 self._data(next_status_obj)
 
-        if resp.isclosed():
+            # # Note: keep-alive newlines might be inserted before each length value.
+            # # read until we get a digit...
+            # c = b'\n'
+            # for c in resp.iter_content(decode_unicode=True):
+            #     if c == b'\n':
+            #         continue
+            #     break
+            #
+            # delimited_string = c
+            #
+            # # read rest of delimiter length..
+            # d = b''
+            # for d in resp.iter_content(decode_unicode=True):
+            #     if d != b'\n':
+            #         delimited_string += d
+            #         continue
+            #     break
+            #
+            # # read the next twitter status object
+            # if delimited_string.decode('utf-8').strip().isdigit():
+            #     status_id = int(delimited_string)
+            #     next_status_obj = resp.raw.read(status_id)
+            #     if self.running:
+            #         self._data(next_status_obj.decode('utf-8'))
+
+
+        if resp.raw.closed:
             self.on_closed(resp)
 
     def _start(self, async):
         self.running = True
         if async:
-            Thread(target=self._run).start()
+            self._thread = Thread(target=self._run)
+            self._thread.start()
         else:
             self._run()
 
@@ -239,81 +349,101 @@ class Stream(object):
         """ Called when the response has been closed by Twitter """
         pass
 
-    def userstream(self, stall_warnings=False, _with=None, replies=None,
-            track=None, locations=None, async=False, encoding='utf8'):
-        self.parameters = {'delimited': 'length'}
+    def userstream(self,
+                   stall_warnings=False,
+                   _with=None,
+                   replies=None,
+                   track=None,
+                   locations=None,
+                   async=False,
+                   encoding='utf8'):
+        self.session.params = {'delimited': 'length'}
         if self.running:
             raise TweepError('Stream object already connected!')
-        self.url = '/%s/user.json?delimited=length' % STREAM_VERSION
-        self.host='userstream.twitter.com'
+        self.url = '/%s/user.json' % STREAM_VERSION
+        self.host = 'userstream.twitter.com'
         if stall_warnings:
-            self.parameters['stall_warnings'] = stall_warnings
+            self.session.params['stall_warnings'] = stall_warnings
         if _with:
-            self.parameters['with'] = _with
+            self.session.params['with'] = _with
         if replies:
-            self.parameters['replies'] = replies
+            self.session.params['replies'] = replies
         if locations and len(locations) > 0:
-            assert len(locations) % 4 == 0
-            self.parameters['locations'] = ','.join(['%.2f' % l for l in locations])
+            if len(locations) % 4 != 0:
+                raise TweepError("Wrong number of locations points, "
+                                 "it has to be a multiple of 4")
+            self.session.params['locations'] = ','.join(['%.2f' % l for l in locations])
         if track:
-            encoded_track = [s.encode(encoding) for s in track]
-            self.parameters['track'] = ','.join(encoded_track)
-        self.body = urlencode_noplus(self.parameters)
+            self.session.params['track'] = u','.join(track).encode(encoding)
+
         self._start(async)
 
     def firehose(self, count=None, async=False):
-        self.parameters = {'delimited': 'length'}
+        self.session.params = {'delimited': 'length'}
         if self.running:
             raise TweepError('Stream object already connected!')
-        self.url = '/%s/statuses/firehose.json?delimited=length' % STREAM_VERSION
+        self.url = '/%s/statuses/firehose.json' % STREAM_VERSION
         if count:
             self.url += '&count=%s' % count
         self._start(async)
 
     def retweet(self, async=False):
-        self.parameters = {'delimited': 'length'}
+        self.session.params = {'delimited': 'length'}
         if self.running:
             raise TweepError('Stream object already connected!')
-        self.url = '/%s/statuses/retweet.json?delimited=length' % STREAM_VERSION
+        self.url = '/%s/statuses/retweet.json' % STREAM_VERSION
         self._start(async)
 
-    def sample(self, count=None, async=False):
-        self.parameters = {'delimited': 'length'}
+    def sample(self, async=False, languages=None):
+        self.session.params = {'delimited': 'length'}
         if self.running:
             raise TweepError('Stream object already connected!')
-        self.url = '/%s/statuses/sample.json?delimited=length' % STREAM_VERSION
-        if count:
-            self.url += '&count=%s' % count
+        self.url = '/%s/statuses/sample.json' % STREAM_VERSION
+        if languages:
+            self.session.params['language'] = ','.join(map(str, languages))
         self._start(async)
 
     def filter(self, follow=None, track=None, async=False, locations=None,
-               count=None, stall_warnings=False, languages=None, encoding='utf8'):
-        self.parameters = {}
-        self.headers['Content-type'] = "application/x-www-form-urlencoded"
+               stall_warnings=False, languages=None, encoding='utf8'):
+        self.body = {}
+        self.session.headers['Content-type'] = "application/x-www-form-urlencoded"
         if self.running:
             raise TweepError('Stream object already connected!')
-        self.url = '/%s/statuses/filter.json?delimited=length' % STREAM_VERSION
+        self.url = '/%s/statuses/filter.json' % STREAM_VERSION
         if follow:
-            encoded_follow = [s.encode(encoding) for s in follow]
-            self.parameters['follow'] = ','.join(encoded_follow)
+            self.body['follow'] = u','.join(follow).encode(encoding)
         if track:
-            encoded_track = [s.encode(encoding) for s in track]
-            self.parameters['track'] = ','.join(encoded_track)
+            self.body['track'] = u','.join(track).encode(encoding)
         if locations and len(locations) > 0:
-            assert len(locations) % 4 == 0
-            self.parameters['locations'] = ','.join(['%.4f' % l for l in locations])
-        if count:
-            self.parameters['count'] = count
+            if len(locations) % 4 != 0:
+                raise TweepError("Wrong number of locations points, "
+                                 "it has to be a multiple of 4")
+            self.body['locations'] = u','.join(['%.4f' % l for l in locations])
         if stall_warnings:
-            self.parameters['stall_warnings'] = stall_warnings
+            self.body['stall_warnings'] = stall_warnings
         if languages:
-            self.parameters['language'] = ','.join(map(str, languages))
-        self.body = urlencode_noplus(self.parameters)
-        self.parameters['delimited'] = 'length'
+            self.body['language'] = u','.join(map(str, languages))
+        self.session.params = {'delimited': 'length'}
+        self.host = 'stream.twitter.com'
+        self._start(async)
+
+    def sitestream(self, follow, stall_warnings=False,
+                   with_='user', replies=False, async=False):
+        self.body = {}
+        if self.running:
+            raise TweepError('Stream object already connected!')
+        self.url = '/%s/site.json' % STREAM_VERSION
+        self.body['follow'] = u','.join(map(six.text_type, follow))
+        self.body['delimited'] = 'length'
+        if stall_warnings:
+            self.body['stall_warnings'] = stall_warnings
+        if with_:
+            self.body['with'] = with_
+        if replies:
+            self.body['replies'] = replies
         self._start(async)
 
     def disconnect(self):
         if self.running is False:
             return
         self.running = False
-

Benjamin Mako Hill || Want to submit a patch?