initial version of several example programs using the tweepy twitter API
[twitter-api-cdsw] / oauth / oauth.py
1 """
2 The MIT License
3
4 Copyright (c) 2007 Leah Culver
5
6 Permission is hereby granted, free of charge, to any person obtaining a copy
7 of this software and associated documentation files (the "Software"), to deal
8 in the Software without restriction, including without limitation the rights
9 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 copies of the Software, and to permit persons to whom the Software is
11 furnished to do so, subject to the following conditions:
12
13 The above copyright notice and this permission notice shall be included in
14 all copies or substantial portions of the Software.
15
16 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 THE SOFTWARE.
23 """
24
25 import cgi
26 import urllib
27 import time
28 import random
29 import urlparse
30 import hmac
31 import binascii
32
33
34 VERSION = '1.0' # Hi Blaine!
35 HTTP_METHOD = 'GET'
36 SIGNATURE_METHOD = 'PLAINTEXT'
37
38
39 class OAuthError(RuntimeError):
40     """Generic exception class."""
41     def __init__(self, message='OAuth error occured.'):
42         self.message = message
43
44 def build_authenticate_header(realm=''):
45     """Optional WWW-Authenticate header (401 error)"""
46     return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}
47
48 def escape(s):
49     """Escape a URL including any /."""
50     return urllib.quote(s, safe='~')
51
52 def _utf8_str(s):
53     """Convert unicode to utf-8."""
54     if isinstance(s, unicode):
55         return s.encode("utf-8")
56     else:
57         return str(s)
58
59 def generate_timestamp():
60     """Get seconds since epoch (UTC)."""
61     return int(time.time())
62
63 def generate_nonce(length=8):
64     """Generate pseudorandom number."""
65     return ''.join([str(random.randint(0, 9)) for i in range(length)])
66
67 def generate_verifier(length=8):
68     """Generate pseudorandom number."""
69     return ''.join([str(random.randint(0, 9)) for i in range(length)])
70
71
72 class OAuthConsumer(object):
73     """Consumer of OAuth authentication.
74
75     OAuthConsumer is a data type that represents the identity of the Consumer
76     via its shared secret with the Service Provider.
77
78     """
79     key = None
80     secret = None
81
82     def __init__(self, key, secret):
83         self.key = key
84         self.secret = secret
85
86
87 class OAuthToken(object):
88     """OAuthToken is a data type that represents an End User via either an access
89     or request token.
90     
91     key -- the token
92     secret -- the token secret
93
94     """
95     key = None
96     secret = None
97     callback = None
98     callback_confirmed = None
99     verifier = None
100
101     def __init__(self, key, secret):
102         self.key = key
103         self.secret = secret
104
105     def set_callback(self, callback):
106         self.callback = callback
107         self.callback_confirmed = 'true'
108
109     def set_verifier(self, verifier=None):
110         if verifier is not None:
111             self.verifier = verifier
112         else:
113             self.verifier = generate_verifier()
114
115     def get_callback_url(self):
116         if self.callback and self.verifier:
117             # Append the oauth_verifier.
118             parts = urlparse.urlparse(self.callback)
119             scheme, netloc, path, params, query, fragment = parts[:6]
120             if query:
121                 query = '%s&oauth_verifier=%s' % (query, self.verifier)
122             else:
123                 query = 'oauth_verifier=%s' % self.verifier
124             return urlparse.urlunparse((scheme, netloc, path, params,
125                 query, fragment))
126         return self.callback
127
128     def to_string(self):
129         data = {
130             'oauth_token': self.key,
131             'oauth_token_secret': self.secret,
132         }
133         if self.callback_confirmed is not None:
134             data['oauth_callback_confirmed'] = self.callback_confirmed
135         return urllib.urlencode(data)
136  
137     def from_string(s):
138         """ Returns a token from something like:
139         oauth_token_secret=xxx&oauth_token=xxx
140         """
141         params = cgi.parse_qs(s, keep_blank_values=False)
142         key = params['oauth_token'][0]
143         secret = params['oauth_token_secret'][0]
144         token = OAuthToken(key, secret)
145         try:
146             token.callback_confirmed = params['oauth_callback_confirmed'][0]
147         except KeyError:
148             pass # 1.0, no callback confirmed.
149         return token
150     from_string = staticmethod(from_string)
151
152     def __str__(self):
153         return self.to_string()
154
155
156 class OAuthRequest(object):
157     """OAuthRequest represents the request and can be serialized.
158
159     OAuth parameters:
160         - oauth_consumer_key 
161         - oauth_token
162         - oauth_signature_method
163         - oauth_signature 
164         - oauth_timestamp 
165         - oauth_nonce
166         - oauth_version
167         - oauth_verifier
168         ... any additional parameters, as defined by the Service Provider.
169     """
170     parameters = None # OAuth parameters.
171     http_method = HTTP_METHOD
172     http_url = None
173     version = VERSION
174
175     def __init__(self, http_method=HTTP_METHOD, http_url=None, parameters=None):
176         self.http_method = http_method
177         self.http_url = http_url
178         self.parameters = parameters or {}
179
180     def set_parameter(self, parameter, value):
181         self.parameters[parameter] = value
182
183     def get_parameter(self, parameter):
184         try:
185             return self.parameters[parameter]
186         except:
187             raise OAuthError('Parameter not found: %s' % parameter)
188
189     def _get_timestamp_nonce(self):
190         return self.get_parameter('oauth_timestamp'), self.get_parameter(
191             'oauth_nonce')
192
193     def get_nonoauth_parameters(self):
194         """Get any non-OAuth parameters."""
195         parameters = {}
196         for k, v in self.parameters.iteritems():
197             # Ignore oauth parameters.
198             if k.find('oauth_') < 0:
199                 parameters[k] = v
200         return parameters
201
202     def to_header(self, realm=''):
203         """Serialize as a header for an HTTPAuth request."""
204         auth_header = 'OAuth realm="%s"' % realm
205         # Add the oauth parameters.
206         if self.parameters:
207             for k, v in self.parameters.iteritems():
208                 if k[:6] == 'oauth_':
209                     auth_header += ', %s="%s"' % (k, escape(str(v)))
210         return {'Authorization': auth_header}
211
212     def to_postdata(self):
213         """Serialize as post data for a POST request."""
214         return '&'.join(['%s=%s' % (escape(str(k)), escape(str(v))) \
215             for k, v in self.parameters.iteritems()])
216
217     def to_url(self):
218         """Serialize as a URL for a GET request."""
219         return '%s?%s' % (self.get_normalized_http_url(), self.to_postdata())
220
221     def get_normalized_parameters(self):
222         """Return a string that contains the parameters that must be signed."""
223         params = self.parameters
224         try:
225             # Exclude the signature if it exists.
226             del params['oauth_signature']
227         except:
228             pass
229         # Escape key values before sorting.
230         key_values = [(escape(_utf8_str(k)), escape(_utf8_str(v))) \
231             for k,v in params.items()]
232         # Sort lexicographically, first after key, then after value.
233         key_values.sort()
234         # Combine key value pairs into a string.
235         return '&'.join(['%s=%s' % (k, v) for k, v in key_values])
236
237     def get_normalized_http_method(self):
238         """Uppercases the http method."""
239         return self.http_method.upper()
240
241     def get_normalized_http_url(self):
242         """Parses the URL and rebuilds it to be scheme://host/path."""
243         parts = urlparse.urlparse(self.http_url)
244         scheme, netloc, path = parts[:3]
245         # Exclude default port numbers.
246         if scheme == 'http' and netloc[-3:] == ':80':
247             netloc = netloc[:-3]
248         elif scheme == 'https' and netloc[-4:] == ':443':
249             netloc = netloc[:-4]
250         return '%s://%s%s' % (scheme, netloc, path)
251
252     def sign_request(self, signature_method, consumer, token):
253         """Set the signature parameter to the result of build_signature."""
254         # Set the signature method.
255         self.set_parameter('oauth_signature_method',
256             signature_method.get_name())
257         # Set the signature.
258         self.set_parameter('oauth_signature',
259             self.build_signature(signature_method, consumer, token))
260
261     def build_signature(self, signature_method, consumer, token):
262         """Calls the build signature method within the signature method."""
263         return signature_method.build_signature(self, consumer, token)
264
265     def from_request(http_method, http_url, headers=None, parameters=None,
266             query_string=None):
267         """Combines multiple parameter sources."""
268         if parameters is None:
269             parameters = {}
270
271         # Headers
272         if headers and 'Authorization' in headers:
273             auth_header = headers['Authorization']
274             # Check that the authorization header is OAuth.
275             if auth_header[:6] == 'OAuth ':
276                 auth_header = auth_header[6:]
277                 try:
278                     # Get the parameters from the header.
279                     header_params = OAuthRequest._split_header(auth_header)
280                     parameters.update(header_params)
281                 except:
282                     raise OAuthError('Unable to parse OAuth parameters from '
283                         'Authorization header.')
284
285         # GET or POST query string.
286         if query_string:
287             query_params = OAuthRequest._split_url_string(query_string)
288             parameters.update(query_params)
289
290         # URL parameters.
291         param_str = urlparse.urlparse(http_url)[4] # query
292         url_params = OAuthRequest._split_url_string(param_str)
293         parameters.update(url_params)
294
295         if parameters:
296             return OAuthRequest(http_method, http_url, parameters)
297
298         return None
299     from_request = staticmethod(from_request)
300
301     def from_consumer_and_token(oauth_consumer, token=None,
302             callback=None, verifier=None, http_method=HTTP_METHOD,
303             http_url=None, parameters=None):
304         if not parameters:
305             parameters = {}
306
307         defaults = {
308             'oauth_consumer_key': oauth_consumer.key,
309             'oauth_timestamp': generate_timestamp(),
310             'oauth_nonce': generate_nonce(),
311             'oauth_version': OAuthRequest.version,
312         }
313
314         defaults.update(parameters)
315         parameters = defaults
316
317         if token:
318             parameters['oauth_token'] = token.key
319             if token.callback:
320                 parameters['oauth_callback'] = token.callback
321             # 1.0a support for verifier.
322             if verifier:
323                 parameters['oauth_verifier'] = verifier
324         elif callback:
325             # 1.0a support for callback in the request token request.
326             parameters['oauth_callback'] = callback
327
328         return OAuthRequest(http_method, http_url, parameters)
329     from_consumer_and_token = staticmethod(from_consumer_and_token)
330
331     def from_token_and_callback(token, callback=None, http_method=HTTP_METHOD,
332             http_url=None, parameters=None):
333         if not parameters:
334             parameters = {}
335
336         parameters['oauth_token'] = token.key
337
338         if callback:
339             parameters['oauth_callback'] = callback
340
341         return OAuthRequest(http_method, http_url, parameters)
342     from_token_and_callback = staticmethod(from_token_and_callback)
343
344     def _split_header(header):
345         """Turn Authorization: header into parameters."""
346         params = {}
347         parts = header.split(',')
348         for param in parts:
349             # Ignore realm parameter.
350             if param.find('realm') > -1:
351                 continue
352             # Remove whitespace.
353             param = param.strip()
354             # Split key-value.
355             param_parts = param.split('=', 1)
356             # Remove quotes and unescape the value.
357             params[param_parts[0]] = urllib.unquote(param_parts[1].strip('\"'))
358         return params
359     _split_header = staticmethod(_split_header)
360
361     def _split_url_string(param_str):
362         """Turn URL string into parameters."""
363         parameters = cgi.parse_qs(param_str, keep_blank_values=False)
364         for k, v in parameters.iteritems():
365             parameters[k] = urllib.unquote(v[0])
366         return parameters
367     _split_url_string = staticmethod(_split_url_string)
368
369 class OAuthServer(object):
370     """A worker to check the validity of a request against a data store."""
371     timestamp_threshold = 300 # In seconds, five minutes.
372     version = VERSION
373     signature_methods = None
374     data_store = None
375
376     def __init__(self, data_store=None, signature_methods=None):
377         self.data_store = data_store
378         self.signature_methods = signature_methods or {}
379
380     def set_data_store(self, data_store):
381         self.data_store = data_store
382
383     def get_data_store(self):
384         return self.data_store
385
386     def add_signature_method(self, signature_method):
387         self.signature_methods[signature_method.get_name()] = signature_method
388         return self.signature_methods
389
390     def fetch_request_token(self, oauth_request):
391         """Processes a request_token request and returns the
392         request token on success.
393         """
394         try:
395             # Get the request token for authorization.
396             token = self._get_token(oauth_request, 'request')
397         except OAuthError:
398             # No token required for the initial token request.
399             version = self._get_version(oauth_request)
400             consumer = self._get_consumer(oauth_request)
401             try:
402                 callback = self.get_callback(oauth_request)
403             except OAuthError:
404                 callback = None # 1.0, no callback specified.
405             self._check_signature(oauth_request, consumer, None)
406             # Fetch a new token.
407             token = self.data_store.fetch_request_token(consumer, callback)
408         return token
409
410     def fetch_access_token(self, oauth_request):
411         """Processes an access_token request and returns the
412         access token on success.
413         """
414         version = self._get_version(oauth_request)
415         consumer = self._get_consumer(oauth_request)
416         try:
417             verifier = self._get_verifier(oauth_request)
418         except OAuthError:
419             verifier = None
420         # Get the request token.
421         token = self._get_token(oauth_request, 'request')
422         self._check_signature(oauth_request, consumer, token)
423         new_token = self.data_store.fetch_access_token(consumer, token, verifier)
424         return new_token
425
426     def verify_request(self, oauth_request):
427         """Verifies an api call and checks all the parameters."""
428         # -> consumer and token
429         version = self._get_version(oauth_request)
430         consumer = self._get_consumer(oauth_request)
431         # Get the access token.
432         token = self._get_token(oauth_request, 'access')
433         self._check_signature(oauth_request, consumer, token)
434         parameters = oauth_request.get_nonoauth_parameters()
435         return consumer, token, parameters
436
437     def authorize_token(self, token, user):
438         """Authorize a request token."""
439         return self.data_store.authorize_request_token(token, user)
440
441     def get_callback(self, oauth_request):
442         """Get the callback URL."""
443         return oauth_request.get_parameter('oauth_callback')
444  
445     def build_authenticate_header(self, realm=''):
446         """Optional support for the authenticate header."""
447         return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}
448
449     def _get_version(self, oauth_request):
450         """Verify the correct version request for this server."""
451         try:
452             version = oauth_request.get_parameter('oauth_version')
453         except:
454             version = VERSION
455         if version and version != self.version:
456             raise OAuthError('OAuth version %s not supported.' % str(version))
457         return version
458
459     def _get_signature_method(self, oauth_request):
460         """Figure out the signature with some defaults."""
461         try:
462             signature_method = oauth_request.get_parameter(
463                 'oauth_signature_method')
464         except:
465             signature_method = SIGNATURE_METHOD
466         try:
467             # Get the signature method object.
468             signature_method = self.signature_methods[signature_method]
469         except:
470             signature_method_names = ', '.join(self.signature_methods.keys())
471             raise OAuthError('Signature method %s not supported try one of the '
472                 'following: %s' % (signature_method, signature_method_names))
473
474         return signature_method
475
476     def _get_consumer(self, oauth_request):
477         consumer_key = oauth_request.get_parameter('oauth_consumer_key')
478         consumer = self.data_store.lookup_consumer(consumer_key)
479         if not consumer:
480             raise OAuthError('Invalid consumer.')
481         return consumer
482
483     def _get_token(self, oauth_request, token_type='access'):
484         """Try to find the token for the provided request token key."""
485         token_field = oauth_request.get_parameter('oauth_token')
486         token = self.data_store.lookup_token(token_type, token_field)
487         if not token:
488             raise OAuthError('Invalid %s token: %s' % (token_type, token_field))
489         return token
490     
491     def _get_verifier(self, oauth_request):
492         return oauth_request.get_parameter('oauth_verifier')
493
494     def _check_signature(self, oauth_request, consumer, token):
495         timestamp, nonce = oauth_request._get_timestamp_nonce()
496         self._check_timestamp(timestamp)
497         self._check_nonce(consumer, token, nonce)
498         signature_method = self._get_signature_method(oauth_request)
499         try:
500             signature = oauth_request.get_parameter('oauth_signature')
501         except:
502             raise OAuthError('Missing signature.')
503         # Validate the signature.
504         valid_sig = signature_method.check_signature(oauth_request, consumer,
505             token, signature)
506         if not valid_sig:
507             key, base = signature_method.build_signature_base_string(
508                 oauth_request, consumer, token)
509             raise OAuthError('Invalid signature. Expected signature base '
510                 'string: %s' % base)
511         built = signature_method.build_signature(oauth_request, consumer, token)
512
513     def _check_timestamp(self, timestamp):
514         """Verify that timestamp is recentish."""
515         timestamp = int(timestamp)
516         now = int(time.time())
517         lapsed = now - timestamp
518         if lapsed > self.timestamp_threshold:
519             raise OAuthError('Expired timestamp: given %d and now %s has a '
520                 'greater difference than threshold %d' %
521                 (timestamp, now, self.timestamp_threshold))
522
523     def _check_nonce(self, consumer, token, nonce):
524         """Verify that the nonce is uniqueish."""
525         nonce = self.data_store.lookup_nonce(consumer, token, nonce)
526         if nonce:
527             raise OAuthError('Nonce already used: %s' % str(nonce))
528
529
530 class OAuthClient(object):
531     """OAuthClient is a worker to attempt to execute a request."""
532     consumer = None
533     token = None
534
535     def __init__(self, oauth_consumer, oauth_token):
536         self.consumer = oauth_consumer
537         self.token = oauth_token
538
539     def get_consumer(self):
540         return self.consumer
541
542     def get_token(self):
543         return self.token
544
545     def fetch_request_token(self, oauth_request):
546         """-> OAuthToken."""
547         raise NotImplementedError
548
549     def fetch_access_token(self, oauth_request):
550         """-> OAuthToken."""
551         raise NotImplementedError
552
553     def access_resource(self, oauth_request):
554         """-> Some protected resource."""
555         raise NotImplementedError
556
557
558 class OAuthDataStore(object):
559     """A database abstraction used to lookup consumers and tokens."""
560
561     def lookup_consumer(self, key):
562         """-> OAuthConsumer."""
563         raise NotImplementedError
564
565     def lookup_token(self, oauth_consumer, token_type, token_token):
566         """-> OAuthToken."""
567         raise NotImplementedError
568
569     def lookup_nonce(self, oauth_consumer, oauth_token, nonce):
570         """-> OAuthToken."""
571         raise NotImplementedError
572
573     def fetch_request_token(self, oauth_consumer, oauth_callback):
574         """-> OAuthToken."""
575         raise NotImplementedError
576
577     def fetch_access_token(self, oauth_consumer, oauth_token, oauth_verifier):
578         """-> OAuthToken."""
579         raise NotImplementedError
580
581     def authorize_request_token(self, oauth_token, user):
582         """-> OAuthToken."""
583         raise NotImplementedError
584
585
586 class OAuthSignatureMethod(object):
587     """A strategy class that implements a signature method."""
588     def get_name(self):
589         """-> str."""
590         raise NotImplementedError
591
592     def build_signature_base_string(self, oauth_request, oauth_consumer, oauth_token):
593         """-> str key, str raw."""
594         raise NotImplementedError
595
596     def build_signature(self, oauth_request, oauth_consumer, oauth_token):
597         """-> str."""
598         raise NotImplementedError
599
600     def check_signature(self, oauth_request, consumer, token, signature):
601         built = self.build_signature(oauth_request, consumer, token)
602         return built == signature
603
604
605 class OAuthSignatureMethod_HMAC_SHA1(OAuthSignatureMethod):
606
607     def get_name(self):
608         return 'HMAC-SHA1'
609         
610     def build_signature_base_string(self, oauth_request, consumer, token):
611         sig = (
612             escape(oauth_request.get_normalized_http_method()),
613             escape(oauth_request.get_normalized_http_url()),
614             escape(oauth_request.get_normalized_parameters()),
615         )
616
617         key = '%s&' % escape(consumer.secret)
618         if token:
619             key += escape(token.secret)
620         raw = '&'.join(sig)
621         return key, raw
622
623     def build_signature(self, oauth_request, consumer, token):
624         """Builds the base signature string."""
625         key, raw = self.build_signature_base_string(oauth_request, consumer,
626             token)
627
628         # HMAC object.
629         try:
630             import hashlib # 2.5
631             hashed = hmac.new(key, raw, hashlib.sha1)
632         except:
633             import sha # Deprecated
634             hashed = hmac.new(key, raw, sha)
635
636         # Calculate the digest base 64.
637         return binascii.b2a_base64(hashed.digest())[:-1]
638
639
640 class OAuthSignatureMethod_PLAINTEXT(OAuthSignatureMethod):
641
642     def get_name(self):
643         return 'PLAINTEXT'
644
645     def build_signature_base_string(self, oauth_request, consumer, token):
646         """Concatenates the consumer key and secret."""
647         sig = '%s&' % escape(consumer.secret)
648         if token:
649             sig = sig + escape(token.secret)
650         return sig, sig
651
652     def build_signature(self, oauth_request, consumer, token):
653         key, raw = self.build_signature_base_string(oauth_request, consumer,
654             token)
655         return key

Benjamin Mako Hill || Want to submit a patch?