Reverted the encoding fix that Tommy made in favor of a different crazy hack
[twitter-api-cdsw] / tweepy / streaming.py
# Tweepy
# Copyright 2009-2010 Joshua Roesslein
# See LICENSE for details.

# App Engine users: https://developers.google.com/appengine/docs/python/sockets/#making_httplib_use_sockets

from __future__ import absolute_import, print_function

import logging
import ssl
import sys
from threading import Thread
from time import sleep

import requests
import six
from requests.exceptions import Timeout

from tweepy.models import Status
from tweepy.api import API
from tweepy.error import TweepError
from tweepy.utils import import_simplejson
# JSON codec used for decoding stream messages, as picked by
# tweepy.utils.import_simplejson().
json = import_simplejson()

# Streaming API version embedded in every endpoint path (e.g. "/1.1/user.json").
STREAM_VERSION = "1.1"


class StreamListener(object):
    """Base class for handling messages delivered by a :class:`Stream`.

    ``on_data`` decodes each raw message and dispatches it to one of the
    ``on_*`` hooks below.  The default hooks do nothing; override the ones
    you need.  When the hook that handled a message returns ``False``,
    ``on_data`` returns ``False`` as well, which asks the stream to stop.
    """

    def __init__(self, api=None):
        # Handed to Status.parse when building model objects; a default
        # API() is only constructed when no api is supplied.
        self.api = api or API()

    def on_connect(self):
        """Called once connected to streaming server.

        This will be invoked once a successful response
        is received from the server. Allows the listener
        to perform some work prior to entering the read loop.
        """
        pass

    def on_data(self, raw_data):
        """Called when raw data is received from connection.

        Override this method if you wish to manually handle
        the stream data. Return False to stop stream and close connection.

        :param raw_data: one JSON-encoded message read from the stream.
        :return: ``False`` if the hook that handled the message returned
            ``False``; ``None`` otherwise (including unknown messages).
        """
        data = json.loads(raw_data)

        if 'in_reply_to_status_id' in data:
            # The presence of this key is what identifies a status payload.
            status = Status.parse(self.api, data)
            if self.on_status(status) is False:
                return False
        elif 'delete' in data:
            delete = data['delete']['status']
            if self.on_delete(delete['id'], delete['user_id']) is False:
                return False
        elif 'event' in data:
            status = Status.parse(self.api, data)
            if self.on_event(status) is False:
                return False
        elif 'direct_message' in data:
            status = Status.parse(self.api, data)
            if self.on_direct_message(status) is False:
                return False
        elif 'friends' in data:
            if self.on_friends(data['friends']) is False:
                return False
        elif 'limit' in data:
            if self.on_limit(data['limit']['track']) is False:
                return False
        elif 'disconnect' in data:
            if self.on_disconnect(data['disconnect']) is False:
                return False
        elif 'warning' in data:
            if self.on_warning(data['warning']) is False:
                return False
        else:
            # Lazy %-formatting: str() on a non-ASCII unicode payload raises
            # UnicodeEncodeError under Python 2.
            logging.error("Unknown message type: %s", raw_data)

    def keep_alive(self):
        """Called when a keep-alive arrived"""
        return

    def on_status(self, status):
        """Called when a new status arrives"""
        return

    def on_exception(self, exception):
        """Called when an unhandled exception occurs."""
        return

    def on_delete(self, status_id, user_id):
        """Called when a delete notice arrives for a status"""
        return

    def on_event(self, status):
        """Called when a new event arrives"""
        return

    def on_direct_message(self, status):
        """Called when a new direct message arrives"""
        return

    def on_friends(self, friends):
        """Called when a friends list arrives.

        friends is a list that contains user_id
        """
        return

    def on_limit(self, track):
        """Called when a limitation notice arrives"""
        return

    def on_error(self, status_code):
        """Called when a non-200 status code is returned.

        The default returns ``False``, which makes the stream stop instead of
        retrying.
        """
        return False

    def on_timeout(self):
        """Called when stream connection times out"""
        return

    def on_disconnect(self, notice):
        """Called when twitter sends a disconnect notice

        Disconnect codes are listed here:
        https://dev.twitter.com/docs/streaming-apis/messages#Disconnect_messages_disconnect
        """
        return

    def on_warning(self, notice):
        """Called when a disconnection warning message arrives"""
        return


class ReadBuffer(object):
    """Buffer data from the response in a smarter way than httplib/requests can.

    Tweets are roughly in the 2-12kb range, averaging around 3kb.
    Requests/urllib3/httplib/socket all use socket.read, which blocks
    until enough data is returned. On some systems (eg google appengine), socket
    reads are quite slow. To combat this latency we can read big chunks,
    but the blocking part means we won't get results until enough tweets
    have arrived. That may not be a big deal for high throughput systems.
    For low throughput systems we don't want to sacrifice latency, so we
    use small chunks so it can read the length and the tweet in 2 read calls.

    Data is kept as raw bytes and only decoded (with ``encoding``) once a
    complete line or length-delimited message has been taken off the buffer.
    Lengths are therefore counted in bytes, and a multibyte character split
    across two socket reads is never decoded in halves.
    """

    def __init__(self, stream, chunk_size, encoding='utf-8'):
        # ``stream`` must expose ``read(n)`` returning bytes and a ``closed``
        # attribute (e.g. ``requests.Response.raw``).
        self._stream = stream
        self._buffer = b''
        self._chunk_size = chunk_size
        self._encoding = encoding

    def read_len(self, length):
        """Return the next ``length`` bytes, decoded to text.

        Returns ``u''`` if the stream closes before ``length`` bytes are
        available; any partial data stays buffered.
        """
        while len(self._buffer) < length and not self._stream.closed:
            wanted = max(self._chunk_size, length - len(self._buffer))
            # ``or b''`` guards against read() returning None on a closed fp.
            self._buffer += self._stream.read(wanted) or b''
        if len(self._buffer) < length:
            return u''
        return self._pop(length)

    def read_line(self, sep='\n'):
        """Return text up to and including the next ``sep``.

        ``sep`` may be text (encoded with the buffer's encoding) or bytes.
        Returns ``u''`` if the stream closes before a separator arrives.
        """
        if not isinstance(sep, bytes):
            sep = sep.encode(self._encoding)
        start = 0
        while True:
            loc = self._buffer.find(sep, start)
            if loc >= 0:
                return self._pop(loc + len(sep))
            if self._stream.closed:
                return u''
            # A multi-byte separator may straddle the chunk boundary, so
            # resume the search just before the old end of the buffer.
            start = max(0, len(self._buffer) - len(sep) + 1)
            self._buffer += self._stream.read(self._chunk_size) or b''

    def _pop(self, length):
        """Remove the first ``length`` bytes from the buffer and decode them."""
        head, self._buffer = self._buffer[:length], self._buffer[length:]
        return head.decode(self._encoding)


class Stream(object):
    """Connect to a Twitter streaming endpoint and feed messages to a listener.

    Recognised keyword ``options`` (all optional): ``timeout``,
    ``retry_count``, ``retry_time``, ``retry_420``, ``retry_time_cap``,
    ``snooze_time``, ``snooze_time_cap``, ``chunk_size``, ``verify`` and
    ``headers``.

    Each endpoint method (``userstream``, ``firehose``, ``retweet``,
    ``sample``, ``filter``, ``sitestream``) configures the request and then
    runs the connection loop: in the calling thread, or in a new
    :class:`threading.Thread` when ``is_async`` is true.  ``async`` became a
    reserved word in Python 3.7, so the flag is named ``is_async``.  The old
    ``async=...`` keyword is still accepted where the interpreter allows it.
    """

    host = 'stream.twitter.com'

    def __init__(self, auth, listener, **options):
        self.auth = auth
        self.listener = listener
        self.running = False
        self.timeout = options.get("timeout", 300.0)
        self.retry_count = options.get("retry_count")
        # values according to
        # https://dev.twitter.com/docs/streaming-apis/connecting#Reconnecting
        self.retry_time_start = options.get("retry_time", 5.0)
        self.retry_420_start = options.get("retry_420", 60.0)
        self.retry_time_cap = options.get("retry_time_cap", 320.0)
        self.snooze_time_step = options.get("snooze_time", 0.25)
        self.snooze_time_cap = options.get("snooze_time_cap", 16)

        # The default socket.read size. Default to less than half the size of
        # a tweet so that it reads tweets with the minimal latency of 2 reads
        # per tweet. Values higher than ~1kb will increase latency by waiting
        # for more data to arrive but may also increase throughput by doing
        # fewer socket read calls.
        self.chunk_size = options.get("chunk_size", 512)

        self.verify = options.get("verify", True)

        self.api = API()
        self.headers = options.get("headers") or {}
        self.new_session()
        self.body = None
        self.retry_time = self.retry_time_start
        self.snooze_time = self.snooze_time_step

    def new_session(self):
        """Replace ``self.session`` with a fresh requests session.

        The session uses ``self.headers`` (the same dict object) and starts
        with no query parameters.
        """
        self.session = requests.Session()
        self.session.headers = self.headers
        self.session.params = None

    def _run(self):
        """Connect to ``self.host`` + ``self.url`` and process the stream.

        Non-200 responses are reported to ``listener.on_error`` and retried
        with exponentially growing back-off (at least ``retry_420`` seconds
        for HTTP 420).  Timeouts are reported to ``listener.on_timeout`` and
        retried after a linearly growing snooze.  Any other exception stops
        the loop, is passed to ``listener.on_exception`` and is then re-raised
        with its original traceback.
        """
        url = "https://%s%s" % (self.host, self.url)

        error_counter = 0
        resp = None
        exc_info = None
        while self.running:
            if self.retry_count is not None and error_counter > self.retry_count:
                # quit if error count greater than retry count
                break
            if resp is not None:
                # Release the previous attempt's connection before
                # reconnecting; streamed responses hold it until closed.
                resp.close()
            try:
                auth = self.auth.apply_auth()
                resp = self.session.request('POST',
                                            url,
                                            data=self.body,
                                            timeout=self.timeout,
                                            stream=True,
                                            auth=auth,
                                            verify=self.verify)
                if resp.status_code != 200:
                    if self.listener.on_error(resp.status_code) is False:
                        break
                    error_counter += 1
                    if resp.status_code == 420:
                        self.retry_time = max(self.retry_420_start,
                                              self.retry_time)
                    sleep(self.retry_time)
                    self.retry_time = min(self.retry_time * 2,
                                          self.retry_time_cap)
                else:
                    error_counter = 0
                    self.retry_time = self.retry_time_start
                    self.snooze_time = self.snooze_time_step
                    self.listener.on_connect()
                    self._read_loop(resp)
            except (Timeout, ssl.SSLError) as exc:
                # An SSLError can still surface when using Requests; only a
                # "timed out" SSLError is treated as a timeout, any other is
                # fatal like every other exception.
                if isinstance(exc, ssl.SSLError):
                    if not (exc.args and 'timed out' in str(exc.args[0])):
                        exc_info = sys.exc_info()
                        break
                if self.listener.on_timeout() is False:
                    break
                if self.running is False:
                    break
                sleep(self.snooze_time)
                self.snooze_time = min(self.snooze_time + self.snooze_time_step,
                                       self.snooze_time_cap)
            except Exception:
                # any other exception is fatal, so kill loop
                exc_info = sys.exc_info()
                break

        # cleanup
        self.running = False
        # ``if resp:`` would test Response.ok and skip closing error replies.
        if resp is not None:
            resp.close()

        self.new_session()

        if exc_info is not None:
            # call a handler first so that the exception can be logged.
            self.listener.on_exception(exc_info[1])
            # A bare ``raise`` here fails on Python 3 (no active exception).
            six.reraise(*exc_info)

    def _data(self, data):
        """Pass one decoded message to the listener; stop if it returns False."""
        if self.listener.on_data(data) is False:
            self.running = False

    def _read_loop(self, resp):
        """Read length-delimited messages from ``resp`` until stopped or closed.

        Each message is preceded by a line holding its length.  Blank lines
        are keep-alives and are reported via ``listener.keep_alive()``.
        """
        buf = ReadBuffer(resp.raw, self.chunk_size)

        while self.running and not resp.raw.closed:
            length = 0
            while not resp.raw.closed:
                line = buf.read_line()
                if not line:
                    # read_line() returns nothing only once the stream has
                    # closed; stop rather than calling .strip() on None.
                    break
                line = line.strip()
                if not line:
                    self.listener.keep_alive()  # keep-alive new lines are expected
                elif line.isdigit():
                    length = int(line)
                    break
                else:
                    raise TweepError('Expecting length, unexpected value found')

            next_status_obj = buf.read_len(length)
            # An empty read means the stream closed before a message arrived.
            if self.running and next_status_obj:
                self._data(next_status_obj)

        if resp.raw.closed:
            self.on_closed(resp)

    def _start(self, is_async):
        """Mark the stream running and run ``_run`` here or in a new thread."""
        self.running = True
        if is_async:
            self._thread = Thread(target=self._run)
            self._thread.start()
        else:
            self._run()

    def on_closed(self, resp):
        """ Called when the response has been closed by Twitter """
        pass

    @staticmethod
    def _resolve_async(is_async, kwargs):
        """Return the effective async flag, honouring the legacy ``async`` keyword.

        :param is_async: value of the ``is_async`` parameter.
        :param kwargs: the caller's extra keyword arguments (consumed).
        :raises TypeError: if any keyword other than ``async`` was supplied.
        """
        if 'async' in kwargs:
            is_async = kwargs.pop('async')
        if kwargs:
            raise TypeError('unexpected keyword argument(s): %s'
                            % ', '.join(sorted(kwargs)))
        return is_async

    def userstream(self,
                   stall_warnings=False,
                   _with=None,
                   replies=None,
                   track=None,
                   locations=None,
                   is_async=False,
                   encoding='utf8',
                   **kwargs):
        """Connect to ``/1.1/user.json`` on userstream.twitter.com.

        :raises TweepError: if already connected or ``locations`` does not
            have a multiple of 4 values.
        """
        is_async = self._resolve_async(is_async, kwargs)
        # Check before touching any state so a running stream is not altered.
        if self.running:
            raise TweepError('Stream object already connected!')
        self.session.params = {'delimited': 'length'}
        self.body = None
        self.url = '/%s/user.json' % STREAM_VERSION
        self.host = 'userstream.twitter.com'
        if stall_warnings:
            self.session.params['stall_warnings'] = stall_warnings
        if _with:
            self.session.params['with'] = _with
        if replies:
            self.session.params['replies'] = replies
        if locations:
            if len(locations) % 4 != 0:
                raise TweepError("Wrong number of locations points, "
                                 "it has to be a multiple of 4")
            self.session.params['locations'] = ','.join(['%.2f' % loc for loc in locations])
        if track:
            self.session.params['track'] = u','.join(track).encode(encoding)

        self._start(is_async)

    def firehose(self, count=None, is_async=False, **kwargs):
        """Connect to ``/1.1/statuses/firehose.json``.

        :param count: optional backfill count, sent as a query parameter.
        """
        is_async = self._resolve_async(is_async, kwargs)
        if self.running:
            raise TweepError('Stream object already connected!')
        self.session.params = {'delimited': 'length'}
        self.body = None
        self.url = '/%s/statuses/firehose.json' % STREAM_VERSION
        self.host = 'stream.twitter.com'
        if count:
            # A query parameter, not '&count=' glued onto a path with no '?'.
            self.session.params['count'] = count
        self._start(is_async)

    def retweet(self, is_async=False, **kwargs):
        """Connect to ``/1.1/statuses/retweet.json``."""
        is_async = self._resolve_async(is_async, kwargs)
        if self.running:
            raise TweepError('Stream object already connected!')
        self.session.params = {'delimited': 'length'}
        self.body = None
        self.url = '/%s/statuses/retweet.json' % STREAM_VERSION
        self.host = 'stream.twitter.com'
        self._start(is_async)

    def sample(self, is_async=False, languages=None, **kwargs):
        """Connect to ``/1.1/statuses/sample.json``, optionally filtered by language."""
        is_async = self._resolve_async(is_async, kwargs)
        if self.running:
            raise TweepError('Stream object already connected!')
        self.session.params = {'delimited': 'length'}
        self.body = None
        self.url = '/%s/statuses/sample.json' % STREAM_VERSION
        self.host = 'stream.twitter.com'
        if languages:
            self.session.params['language'] = ','.join(map(str, languages))
        self._start(is_async)

    def filter(self, follow=None, track=None, is_async=False, locations=None,
               stall_warnings=False, languages=None, encoding='utf8', **kwargs):
        """Connect to ``/1.1/statuses/filter.json`` with a form-encoded POST body.

        :raises TweepError: if already connected or ``locations`` does not
            have a multiple of 4 values.
        """
        is_async = self._resolve_async(is_async, kwargs)
        if self.running:
            raise TweepError('Stream object already connected!')
        self.body = {}
        self.session.headers['Content-type'] = "application/x-www-form-urlencoded"
        self.url = '/%s/statuses/filter.json' % STREAM_VERSION
        if follow:
            self.body['follow'] = u','.join(follow).encode(encoding)
        if track:
            self.body['track'] = u','.join(track).encode(encoding)
        if locations:
            if len(locations) % 4 != 0:
                raise TweepError("Wrong number of locations points, "
                                 "it has to be a multiple of 4")
            self.body['locations'] = u','.join(['%.4f' % loc for loc in locations])
        if stall_warnings:
            self.body['stall_warnings'] = stall_warnings
        if languages:
            self.body['language'] = u','.join(map(str, languages))
        self.session.params = {'delimited': 'length'}
        self.host = 'stream.twitter.com'
        self._start(is_async)

    def sitestream(self, follow, stall_warnings=False,
                   with_='user', replies=False, is_async=False, **kwargs):
        """Connect to ``/1.1/site.json`` for the given ``follow`` user ids."""
        is_async = self._resolve_async(is_async, kwargs)
        if self.running:
            raise TweepError('Stream object already connected!')
        self.body = {}
        self.url = '/%s/site.json' % STREAM_VERSION
        # Reset a host possibly left behind by userstream(); this matches
        # what a fresh Stream uses.  NOTE(review): Twitter documented site
        # streams on sitestream.twitter.com; confirm the intended host.
        self.host = 'stream.twitter.com'
        self.body['follow'] = u','.join(map(six.text_type, follow))
        self.body['delimited'] = 'length'
        if stall_warnings:
            self.body['stall_warnings'] = stall_warnings
        if with_:
            self.body['with'] = with_
        if replies:
            self.body['replies'] = replies
        self._start(is_async)

    def disconnect(self):
        """Ask the connection loop to stop; does nothing if not running."""
        if self.running is False:
            return
        self.running = False

Benjamin Mako Hill || Want to submit a patch?