]> projects.mako.cc - twitter-api-cdsw/blob - requests_oauthlib/oauth2_session.py
Merge pull request #3 from guyrt/master
[twitter-api-cdsw] / requests_oauthlib / oauth2_session.py
1 from __future__ import unicode_literals
2
3 import logging
4
5 from oauthlib.common import generate_token, urldecode
6 from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError
7 from oauthlib.oauth2 import TokenExpiredError, is_secure_transport
8 import requests
9
10 log = logging.getLogger(__name__)
11
12
13 class TokenUpdated(Warning):
14     def __init__(self, token):
15         super(TokenUpdated, self).__init__()
16         self.token = token
17
18
19 class OAuth2Session(requests.Session):
20     """Versatile OAuth 2 extension to :class:`requests.Session`.
21
22     Supports any grant type adhering to :class:`oauthlib.oauth2.Client` spec
23     including the four core OAuth 2 grants.
24
25     Can be used to create authorization urls, fetch tokens and access protected
26     resources using the :class:`requests.Session` interface you are used to.
27
28     - :class:`oauthlib.oauth2.WebApplicationClient` (default): Authorization Code Grant
29     - :class:`oauthlib.oauth2.MobileApplicationClient`: Implicit Grant
30     - :class:`oauthlib.oauth2.LegacyApplicationClient`: Password Credentials Grant
31     - :class:`oauthlib.oauth2.BackendApplicationClient`: Client Credentials Grant
32
33     Note that the only time you will be using Implicit Grant from python is if
34     you are driving a user agent able to obtain URL fragments.
35     """
36
37     def __init__(self, client_id=None, client=None, auto_refresh_url=None,
38             auto_refresh_kwargs=None, scope=None, redirect_uri=None, token=None,
39             state=None, token_updater=None, **kwargs):
40         """Construct a new OAuth 2 client session.
41
42         :param client_id: Client id obtained during registration
43         :param client: :class:`oauthlib.oauth2.Client` to be used. Default is
44                        WebApplicationClient which is useful for any
45                        hosted application but not mobile or desktop.
46         :param scope: List of scopes you wish to request access to
47         :param redirect_uri: Redirect URI you registered as callback
48         :param token: Token dictionary, must include access_token
49                       and token_type.
50         :param state: State string used to prevent CSRF. This will be given
51                       when creating the authorization url and must be supplied
52                       when parsing the authorization response.
53                       Can be either a string or a no argument callable.
54         :auto_refresh_url: Refresh token endpoint URL, must be HTTPS. Supply
55                            this if you wish the client to automatically refresh
56                            your access tokens.
57         :auto_refresh_kwargs: Extra arguments to pass to the refresh token
58                               endpoint.
59         :token_updater: Method with one argument, token, to be used to update
60                         your token databse on automatic token refresh. If not
61                         set a TokenUpdated warning will be raised when a token
62                         has been refreshed. This warning will carry the token
63                         in its token argument.
64         :param kwargs: Arguments to pass to the Session constructor.
65         """
66         super(OAuth2Session, self).__init__(**kwargs)
67         self.client_id = client_id or client.client_id
68         self.scope = scope
69         self.redirect_uri = redirect_uri
70         self.token = token or {}
71         self.state = state or generate_token
72         self._state = state
73         self.auto_refresh_url = auto_refresh_url
74         self.auto_refresh_kwargs = auto_refresh_kwargs or {}
75         self.token_updater = token_updater
76         self._client = client or WebApplicationClient(client_id, token=token)
77         self._client._populate_attributes(token or {})
78
79         # Allow customizations for non compliant providers through various
80         # hooks to adjust requests and responses.
81         self.compliance_hook = {
82             'access_token_response': set([]),
83             'refresh_token_response': set([]),
84             'protected_request': set([]),
85         }
86
87     def new_state(self):
88         """Generates a state string to be used in authorizations."""
89         try:
90             self._state = self.state()
91             log.debug('Generated new state %s.', self._state)
92         except TypeError:
93             self._state = self.state
94             log.debug('Re-using previously supplied state %s.', self._state)
95         return self._state
96
97     @property
98     def authorized(self):
99         """Boolean that indicates whether this session has an OAuth token
100         or not. If `self.authorized` is True, you can reasonably expect
101         OAuth-protected requests to the resource to succeed. If
102         `self.authorized` is False, you need the user to go through the OAuth
103         authentication dance before OAuth-protected requests to the resource
104         will succeed.
105         """
106         return bool(self._client.access_token)
107
108     def authorization_url(self, url, state=None, **kwargs):
109         """Form an authorization URL.
110
111         :param url: Authorization endpoint url, must be HTTPS.
112         :param state: An optional state string for CSRF protection. If not
113                       given it will be generated for you.
114         :param kwargs: Extra parameters to include.
115         :return: authorization_url, state
116         """
117         state = state or self.new_state()
118         return self._client.prepare_request_uri(url,
119                 redirect_uri=self.redirect_uri,
120                 scope=self.scope,
121                 state=state,
122                 **kwargs), state
123
124     def fetch_token(self, token_url, code=None, authorization_response=None,
125             body='', auth=None, username=None, password=None, method='POST',
126             timeout=None, headers=None, verify=True, **kwargs):
127         """Generic method for fetching an access token from the token endpoint.
128
129         If you are using the MobileApplicationClient you will want to use
130         token_from_fragment instead of fetch_token.
131
132         :param token_url: Token endpoint URL, must use HTTPS.
133         :param code: Authorization code (used by WebApplicationClients).
134         :param authorization_response: Authorization response URL, the callback
135                                        URL of the request back to you. Used by
136                                        WebApplicationClients instead of code.
137         :param body: Optional application/x-www-form-urlencoded body to add the
138                      include in the token request. Prefer kwargs over body.
139         :param auth: An auth tuple or method as accepted by requests.
140         :param username: Username used by LegacyApplicationClients.
141         :param password: Password used by LegacyApplicationClients.
142         :param method: The HTTP method used to make the request. Defaults
143                        to POST, but may also be GET. Other methods should
144                        be added as needed.
145         :param headers: Dict to default request headers with.
146         :param timeout: Timeout of the request in seconds.
147         :param verify: Verify SSL certificate.
148         :param kwargs: Extra parameters to include in the token request.
149         :return: A token dict
150         """
151         if not is_secure_transport(token_url):
152             raise InsecureTransportError()
153
154         if not code and authorization_response:
155             self._client.parse_request_uri_response(authorization_response,
156                     state=self._state)
157             code = self._client.code
158         elif not code and isinstance(self._client, WebApplicationClient):
159             code = self._client.code
160             if not code:
161                 raise ValueError('Please supply either code or '
162                                  'authorization_code parameters.')
163
164
165         body = self._client.prepare_request_body(code=code, body=body,
166                 redirect_uri=self.redirect_uri, username=username,
167                 password=password, **kwargs)
168
169         headers = headers or {
170             'Accept': 'application/json',
171             'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8',
172         }
173         if method.upper() == 'POST':
174             r = self.post(token_url, data=dict(urldecode(body)),
175                 timeout=timeout, headers=headers, auth=auth,
176                 verify=verify)
177             log.debug('Prepared fetch token request body %s', body)
178         elif method.upper() == 'GET':
179             # if method is not 'POST', switch body to querystring and GET
180             r = self.get(token_url, params=dict(urldecode(body)),
181                 timeout=timeout, headers=headers, auth=auth,
182                 verify=verify)
183             log.debug('Prepared fetch token request querystring %s', body)
184         else:
185             raise ValueError('The method kwarg must be POST or GET.')
186
187         log.debug('Request to fetch token completed with status %s.',
188                   r.status_code)
189         log.debug('Request headers were %s', r.request.headers)
190         log.debug('Request body was %s', r.request.body)
191         log.debug('Response headers were %s and content %s.',
192                   r.headers, r.text)
193         log.debug('Invoking %d token response hooks.',
194                   len(self.compliance_hook['access_token_response']))
195         for hook in self.compliance_hook['access_token_response']:
196             log.debug('Invoking hook %s.', hook)
197             r = hook(r)
198
199         r.raise_for_status()
200
201         self._client.parse_request_body_response(r.text, scope=self.scope)
202         self.token = self._client.token
203         log.debug('Obtained token %s.', self.token)
204         return self.token
205
206     def token_from_fragment(self, authorization_response):
207         """Parse token from the URI fragment, used by MobileApplicationClients.
208
209         :param authorization_response: The full URL of the redirect back to you
210         :return: A token dict
211         """
212         self._client.parse_request_uri_response(authorization_response,
213                 state=self._state)
214         self.token = self._client.token
215         return self.token
216
217     def refresh_token(self, token_url, refresh_token=None, body='', auth=None,
218                       timeout=None, verify=True, **kwargs):
219         """Fetch a new access token using a refresh token.
220
221         :param token_url: The token endpoint, must be HTTPS.
222         :param refresh_token: The refresh_token to use.
223         :param body: Optional application/x-www-form-urlencoded body to add the
224                      include in the token request. Prefer kwargs over body.
225         :param auth: An auth tuple or method as accepted by requests.
226         :param timeout: Timeout of the request in seconds.
227         :param verify: Verify SSL certificate.
228         :param kwargs: Extra parameters to include in the token request.
229         :return: A token dict
230         """
231         if not token_url:
232             raise ValueError('No token endpoint set for auto_refresh.')
233
234         if not is_secure_transport(token_url):
235             raise InsecureTransportError()
236
237         # Need to nullify token to prevent it from being added to the request
238         refresh_token = refresh_token or self.token.get('refresh_token')
239         self.token = {}
240
241         log.debug('Adding auto refresh key word arguments %s.',
242                   self.auto_refresh_kwargs)
243         kwargs.update(self.auto_refresh_kwargs)
244         body = self._client.prepare_refresh_body(body=body,
245                 refresh_token=refresh_token, scope=self.scope, **kwargs)
246         log.debug('Prepared refresh token request body %s', body)
247         r = self.post(token_url, data=dict(urldecode(body)), auth=auth,
248                       timeout=timeout, verify=verify)
249         log.debug('Request to refresh token completed with status %s.',
250                   r.status_code)
251         log.debug('Response headers were %s and content %s.',
252                   r.headers, r.text)
253         log.debug('Invoking %d token response hooks.',
254                   len(self.compliance_hook['refresh_token_response']))
255         for hook in self.compliance_hook['refresh_token_response']:
256             log.debug('Invoking hook %s.', hook)
257             r = hook(r)
258
259         self.token = self._client.parse_request_body_response(r.text, scope=self.scope)
260         if not 'refresh_token' in self.token:
261             log.debug('No new refresh token given. Re-using old.')
262             self.token['refresh_token'] = refresh_token
263         return self.token
264
265     def request(self, method, url, data=None, headers=None, **kwargs):
266         """Intercept all requests and add the OAuth 2 token if present."""
267         if not is_secure_transport(url):
268             raise InsecureTransportError()
269         if self.token:
270             log.debug('Invoking %d protected resource request hooks.',
271                       len(self.compliance_hook['protected_request']))
272             for hook in self.compliance_hook['protected_request']:
273                 log.debug('Invoking hook %s.', hook)
274                 url, headers, data = hook(url, headers, data)
275
276             log.debug('Adding token %s to request.', self.token)
277             try:
278                 url, headers, data = self._client.add_token(url,
279                         http_method=method, body=data, headers=headers)
280             # Attempt to retrieve and save new access token if expired
281             except TokenExpiredError:
282                 if self.auto_refresh_url:
283                     log.debug('Auto refresh is set, attempting to refresh at %s.',
284                               self.auto_refresh_url)
285                     token = self.refresh_token(self.auto_refresh_url)
286                     if self.token_updater:
287                         log.debug('Updating token to %s using %s.',
288                                   token, self.token_updater)
289                         self.token_updater(token)
290                         url, headers, data = self._client.add_token(url,
291                                 http_method=method, body=data, headers=headers)
292                     else:
293                         raise TokenUpdated(token)
294                 else:
295                     raise
296
297         log.debug('Requesting url %s using method %s.', url, method)
298         log.debug('Supplying headers %s and data %s', headers, data)
299         log.debug('Passing through key word arguments %s.', kwargs)
300         return super(OAuth2Session, self).request(method, url,
301                 headers=headers, data=data, **kwargs)
302
303     def register_compliance_hook(self, hook_type, hook):
304         """Register a hook for request/response tweaking.
305
306         Available hooks are:
307             access_token_response invoked before token parsing.
308             refresh_token_response invoked before refresh token parsing.
309             protected_request invoked before making a request.
310
311         If you find a new hook is needed please send a GitHub PR request
312         or open an issue.
313         """
314         if hook_type not in self.compliance_hook:
315             raise ValueError('Hook type %s is not in %s.',
316                              hook_type, self.compliance_hook)
317         self.compliance_hook[hook_type].add(hook)

Benjamin Mako Hill || Want to submit a patch?