X-Git-Url: https://projects.mako.cc/source/twitter-api-cdsw-solutions/blobdiff_plain/4c43c9724c25c3d919cb535973559fab1f1c0a7b..HEAD:/tweepy/streaming.py diff --git a/tweepy/streaming.py b/tweepy/streaming.py index be233ff..9b246bd 100644 --- a/tweepy/streaming.py +++ b/tweepy/streaming.py @@ -2,18 +2,25 @@ # Copyright 2009-2010 Joshua Roesslein # See LICENSE for details. +# Appengine users: https://developers.google.com/appengine/docs/python/sockets/#making_httplib_use_sockets + +from __future__ import absolute_import, print_function + import logging -import httplib -from socket import timeout +import requests +from requests.exceptions import Timeout from threading import Thread from time import sleep + +import six + import ssl from tweepy.models import Status from tweepy.api import API from tweepy.error import TweepError -from tweepy.utils import import_simplejson, urlencode_noplus +from tweepy.utils import import_simplejson json = import_simplejson() STREAM_VERSION = '1.1' @@ -57,15 +64,25 @@ class StreamListener(object): status = Status.parse(self.api, data) if self.on_direct_message(status) is False: return False + elif 'friends' in data: + if self.on_friends(data['friends']) is False: + return False elif 'limit' in data: if self.on_limit(data['limit']['track']) is False: return False elif 'disconnect' in data: if self.on_disconnect(data['disconnect']) is False: return False + elif 'warning' in data: + if self.on_warning(data['warning']) is False: + return False else: logging.error("Unknown message type: " + str(raw_data)) + def keep_alive(self): + """Called when a keep-alive arrived""" + return + def on_status(self, status): """Called when a new status arrives""" return @@ -86,8 +103,15 @@ class StreamListener(object): """Called when a new direct message arrives""" return + def on_friends(self, friends): + """Called when a friends list arrives. + + friends is a list that contains user_id + """ + return + def on_limit(self, track): - """Called when a limitation notice arrvies""" + """Called when a limitation notice arrives""" return def on_error(self, status_code): @@ -106,6 +130,51 @@ class StreamListener(object): """ return + def on_warning(self, notice): + """Called when a disconnection warning message arrives""" + return + + +class ReadBuffer(object): + """Buffer data from the response in a smarter way than httplib/requests can. + + Tweets are roughly in the 2-12kb range, averaging around 3kb. + Requests/urllib3/httplib/socket all use socket.read, which blocks + until enough data is returned. On some systems (eg google appengine), socket + reads are quite slow. To combat this latency we can read big chunks, + but the blocking part means we won't get results until enough tweets + have arrived. That may not be a big deal for high throughput systems. + For low throughput systems we don't want to sacrafice latency, so we + use small chunks so it can read the length and the tweet in 2 read calls. + """ + + def __init__(self, stream, chunk_size): + self._stream = stream + self._buffer = u"" + self._chunk_size = chunk_size + + def read_len(self, length): + while not self._stream.closed: + if len(self._buffer) >= length: + return self._pop(length) + read_len = max(self._chunk_size, length - len(self._buffer)) + self._buffer += self._stream.read(read_len).decode("ascii") + + def read_line(self, sep='\n'): + start = 0 + while not self._stream.closed: + loc = self._buffer.find(sep, start) + if loc >= 0: + return self._pop(loc + len(sep)) + else: + start = len(self._buffer) + self._buffer += self._stream.read(self._chunk_size).decode("ascii") + + def _pop(self, length): + r = self._buffer[:length] + self._buffer = self._buffer[length:] + return r + class Stream(object): @@ -117,82 +186,99 @@ class Stream(object): self.running = False self.timeout = options.get("timeout", 300.0) self.retry_count = options.get("retry_count") - # values according to https://dev.twitter.com/docs/streaming-apis/connecting#Reconnecting + # values according to + # https://dev.twitter.com/docs/streaming-apis/connecting#Reconnecting self.retry_time_start = options.get("retry_time", 5.0) self.retry_420_start = options.get("retry_420", 60.0) self.retry_time_cap = options.get("retry_time_cap", 320.0) self.snooze_time_step = options.get("snooze_time", 0.25) self.snooze_time_cap = options.get("snooze_time_cap", 16) - self.buffer_size = options.get("buffer_size", 1500) - if options.get("secure", True): - self.scheme = "https" - else: - self.scheme = "http" + + # The default socket.read size. Default to less than half the size of + # a tweet so that it reads tweets with the minimal latency of 2 reads + # per tweet. Values higher than ~1kb will increase latency by waiting + # for more data to arrive but may also increase throughput by doing + # fewer socket read calls. + self.chunk_size = options.get("chunk_size", 512) + + self.verify = options.get("verify", True) self.api = API() self.headers = options.get("headers") or {} - self.parameters = None + self.new_session() self.body = None self.retry_time = self.retry_time_start self.snooze_time = self.snooze_time_step + def new_session(self): + self.session = requests.Session() + self.session.headers = self.headers + self.session.params = None + def _run(self): # Authenticate - url = "%s://%s%s" % (self.scheme, self.host, self.url) + url = "https://%s%s" % (self.host, self.url) # Connect and process the stream error_counter = 0 - conn = None + resp = None exception = None while self.running: - if self.retry_count is not None and error_counter > self.retry_count: - # quit if error count greater than retry count - break + if self.retry_count is not None: + if error_counter > self.retry_count: + # quit if error count greater than retry count + break try: - if self.scheme == "http": - conn = httplib.HTTPConnection(self.host, timeout=self.timeout) - else: - conn = httplib.HTTPSConnection(self.host, timeout=self.timeout) - self.auth.apply_auth(url, 'POST', self.headers, self.parameters) - conn.connect() - conn.request('POST', self.url, self.body, headers=self.headers) - resp = conn.getresponse() - if resp.status != 200: - if self.listener.on_error(resp.status) is False: + auth = self.auth.apply_auth() + resp = self.session.request('POST', + url, + data=self.body, + timeout=self.timeout, + stream=True, + auth=auth, + verify=self.verify) + if resp.status_code != 200: + if self.listener.on_error(resp.status_code) is False: break error_counter += 1 - if resp.status == 420: - self.retry_time = max(self.retry_420_start, self.retry_time) + if resp.status_code == 420: + self.retry_time = max(self.retry_420_start, + self.retry_time) sleep(self.retry_time) - self.retry_time = min(self.retry_time * 2, self.retry_time_cap) + self.retry_time = min(self.retry_time * 2, + self.retry_time_cap) else: error_counter = 0 self.retry_time = self.retry_time_start self.snooze_time = self.snooze_time_step self.listener.on_connect() self._read_loop(resp) - except (timeout, ssl.SSLError) as exc: + except (Timeout, ssl.SSLError) as exc: + # This is still necessary, as a SSLError can actually be + # thrown when using Requests # If it's not time out treat it like any other exception - if isinstance(exc, ssl.SSLError) and not (exc.args and 'timed out' in str(exc.args[0])): - exception = exc - break - - if self.listener.on_timeout() == False: + if isinstance(exc, ssl.SSLError): + if not (exc.args and 'timed out' in str(exc.args[0])): + exception = exc + break + if self.listener.on_timeout() is False: break if self.running is False: break - conn.close() sleep(self.snooze_time) self.snooze_time = min(self.snooze_time + self.snooze_time_step, self.snooze_time_cap) - except Exception as exception: + except Exception as exc: + exception = exc # any other exception is fatal, so kill loop break # cleanup self.running = False - if conn: - conn.close() + if resp: + resp.close() + + self.new_session() if exception: # call a handler first so that the exception can be logged. @@ -204,34 +290,58 @@ class Stream(object): self.running = False def _read_loop(self, resp): + buf = ReadBuffer(resp.raw, self.chunk_size) + + while self.running and not resp.raw.closed: + length = 0 + while not resp.raw.closed: + line = buf.read_line().strip() + if not line: + self.listener.keep_alive() # keep-alive new lines are expected + elif line.isdigit(): + length = int(line) + break + else: + raise TweepError('Expecting length, unexpected value found') - while self.running and not resp.isclosed(): - - # Note: keep-alive newlines might be inserted before each length value. - # read until we get a digit... - c = '\n' - while c == '\n' and self.running and not resp.isclosed(): - c = resp.read(1) - delimited_string = c - - # read rest of delimiter length.. - d = '' - while d != '\n' and self.running and not resp.isclosed(): - d = resp.read(1) - delimited_string += d - - # read the next twitter status object - if delimited_string.strip().isdigit(): - next_status_obj = resp.read( int(delimited_string) ) + next_status_obj = buf.read_len(length) + if self.running: self._data(next_status_obj) - if resp.isclosed(): + # # Note: keep-alive newlines might be inserted before each length value. + # # read until we get a digit... + # c = b'\n' + # for c in resp.iter_content(decode_unicode=True): + # if c == b'\n': + # continue + # break + # + # delimited_string = c + # + # # read rest of delimiter length.. + # d = b'' + # for d in resp.iter_content(decode_unicode=True): + # if d != b'\n': + # delimited_string += d + # continue + # break + # + # # read the next twitter status object + # if delimited_string.decode('utf-8').strip().isdigit(): + # status_id = int(delimited_string) + # next_status_obj = resp.raw.read(status_id) + # if self.running: + # self._data(next_status_obj.decode('utf-8')) + + + if resp.raw.closed: self.on_closed(resp) def _start(self, async): self.running = True if async: - Thread(target=self._run).start() + self._thread = Thread(target=self._run) + self._thread.start() else: self._run() @@ -239,81 +349,101 @@ class Stream(object): """ Called when the response has been closed by Twitter """ pass - def userstream(self, stall_warnings=False, _with=None, replies=None, - track=None, locations=None, async=False, encoding='utf8'): - self.parameters = {'delimited': 'length'} + def userstream(self, + stall_warnings=False, + _with=None, + replies=None, + track=None, + locations=None, + async=False, + encoding='utf8'): + self.session.params = {'delimited': 'length'} if self.running: raise TweepError('Stream object already connected!') - self.url = '/%s/user.json?delimited=length' % STREAM_VERSION - self.host='userstream.twitter.com' + self.url = '/%s/user.json' % STREAM_VERSION + self.host = 'userstream.twitter.com' if stall_warnings: - self.parameters['stall_warnings'] = stall_warnings + self.session.params['stall_warnings'] = stall_warnings if _with: - self.parameters['with'] = _with + self.session.params['with'] = _with if replies: - self.parameters['replies'] = replies + self.session.params['replies'] = replies if locations and len(locations) > 0: - assert len(locations) % 4 == 0 - self.parameters['locations'] = ','.join(['%.2f' % l for l in locations]) + if len(locations) % 4 != 0: + raise TweepError("Wrong number of locations points, " + "it has to be a multiple of 4") + self.session.params['locations'] = ','.join(['%.2f' % l for l in locations]) if track: - encoded_track = [s.encode(encoding) for s in track] - self.parameters['track'] = ','.join(encoded_track) - self.body = urlencode_noplus(self.parameters) + self.session.params['track'] = u','.join(track).encode(encoding) + self._start(async) def firehose(self, count=None, async=False): - self.parameters = {'delimited': 'length'} + self.session.params = {'delimited': 'length'} if self.running: raise TweepError('Stream object already connected!') - self.url = '/%s/statuses/firehose.json?delimited=length' % STREAM_VERSION + self.url = '/%s/statuses/firehose.json' % STREAM_VERSION if count: self.url += '&count=%s' % count self._start(async) def retweet(self, async=False): - self.parameters = {'delimited': 'length'} + self.session.params = {'delimited': 'length'} if self.running: raise TweepError('Stream object already connected!') - self.url = '/%s/statuses/retweet.json?delimited=length' % STREAM_VERSION + self.url = '/%s/statuses/retweet.json' % STREAM_VERSION self._start(async) - def sample(self, count=None, async=False): - self.parameters = {'delimited': 'length'} + def sample(self, async=False, languages=None): + self.session.params = {'delimited': 'length'} if self.running: raise TweepError('Stream object already connected!') - self.url = '/%s/statuses/sample.json?delimited=length' % STREAM_VERSION - if count: - self.url += '&count=%s' % count + self.url = '/%s/statuses/sample.json' % STREAM_VERSION + if languages: + self.session.params['language'] = ','.join(map(str, languages)) self._start(async) def filter(self, follow=None, track=None, async=False, locations=None, - count=None, stall_warnings=False, languages=None, encoding='utf8'): - self.parameters = {} - self.headers['Content-type'] = "application/x-www-form-urlencoded" + stall_warnings=False, languages=None, encoding='utf8'): + self.body = {} + self.session.headers['Content-type'] = "application/x-www-form-urlencoded" if self.running: raise TweepError('Stream object already connected!') - self.url = '/%s/statuses/filter.json?delimited=length' % STREAM_VERSION + self.url = '/%s/statuses/filter.json' % STREAM_VERSION if follow: - encoded_follow = [s.encode(encoding) for s in follow] - self.parameters['follow'] = ','.join(encoded_follow) + self.body['follow'] = u','.join(follow).encode(encoding) if track: - encoded_track = [s.encode(encoding) for s in track] - self.parameters['track'] = ','.join(encoded_track) + self.body['track'] = u','.join(track).encode(encoding) if locations and len(locations) > 0: - assert len(locations) % 4 == 0 - self.parameters['locations'] = ','.join(['%.4f' % l for l in locations]) - if count: - self.parameters['count'] = count + if len(locations) % 4 != 0: + raise TweepError("Wrong number of locations points, " + "it has to be a multiple of 4") + self.body['locations'] = u','.join(['%.4f' % l for l in locations]) if stall_warnings: - self.parameters['stall_warnings'] = stall_warnings + self.body['stall_warnings'] = stall_warnings if languages: - self.parameters['language'] = ','.join(map(str, languages)) - self.body = urlencode_noplus(self.parameters) - self.parameters['delimited'] = 'length' + self.body['language'] = u','.join(map(str, languages)) + self.session.params = {'delimited': 'length'} + self.host = 'stream.twitter.com' + self._start(async) + + def sitestream(self, follow, stall_warnings=False, + with_='user', replies=False, async=False): + self.body = {} + if self.running: + raise TweepError('Stream object already connected!') + self.url = '/%s/site.json' % STREAM_VERSION + self.body['follow'] = u','.join(map(six.text_type, follow)) + self.body['delimited'] = 'length' + if stall_warnings: + self.body['stall_warnings'] = stall_warnings + if with_: + self.body['with'] = with_ + if replies: + self.body['replies'] = replies self._start(async) def disconnect(self): if self.running is False: return self.running = False -