import requests_oauthlib (a dependency)
[yelp-api-cdsw] / requests_oauthlib / oauth2_session.py
1 from __future__ import unicode_literals
2 import os
3 import requests
4 from oauthlib.common import log, generate_token, urldecode
5 from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError
6 from oauthlib.oauth2 import TokenExpiredError
7
8 from .utils import is_secure_transport
9
10
11 class TokenUpdated(Warning):
12     def __init__(self, token):
13         super(TokenUpdated, self).__init__()
14         self.token = token
15
16
17 class OAuth2Session(requests.Session):
18     """Versatile OAuth 2 extension to :class:`requests.Session`.
19
20     Supports any grant type adhering to :class:`oauthlib.oauth2.Client` spec
21     including the four core OAuth 2 grants.
22
23     Can be used to create authorization urls, fetch tokens and access protected
24     resources using the :class:`requests.Session` interface you are used to.
25
26     - :class:`oauthlib.oauth2.WebApplicationClient` (default): Authorization Code Grant
27     - :class:`oauthlib.oauth2.MobileApplicationClient`: Implicit Grant
28     - :class:`oauthlib.oauth2.LegacyApplicationClient`: Password Credentials Grant
29     - :class:`oauthlib.oauth2.BackendApplicationClient`: Client Credentials Grant
30
31     Note that the only time you will be using Implicit Grant from python is if
32     you are driving a user agent able to obtain URL fragments.
33     """
34
35     def __init__(self, client_id=None, client=None, auto_refresh_url=None,
36             auto_refresh_kwargs=None, scope=None, redirect_uri=None, token=None,
37             state=None, state_generator=None, token_updater=None, **kwargs):
38         """Construct a new OAuth 2 client session.
39
40         :param client_id: Client id obtained during registration
41         :param client: :class:`oauthlib.oauth2.Client` to be used. Default is
42                        WebApplicationClient which is useful for any
43                        hosted application but not mobile or desktop.
44         :param scope: List of scopes you wish to request access to
45         :param redirect_uri: Redirect URI you registered as callback
46         :param token: Token dictionary, must include access_token
47                       and token_type.
48         :param state: State string used to prevent CSRF. This will be given
49                       when creating the authorization url and must be supplied
50                       when parsing the authorization response.
51                       Can be either a string or a no argument callable.
52         :auto_refresh_url: Refresh token endpoint URL, must be HTTPS. Supply
53                            this if you wish the client to automatically refresh
54                            your access tokens.
55         :auto_refresh_kwargs: Extra arguments to pass to the refresh token
56                               endpoint.
57         :token_updater: Method with one argument, token, to be used to update
58                         your token databse on automatic token refresh. If not
59                         set a TokenUpdated warning will be raised when a token
60                         has been refreshed. This warning will carry the token
61                         in its token argument.
62         :param kwargs: Arguments to pass to the Session constructor.
63         """
64         super(OAuth2Session, self).__init__(**kwargs)
65         self.client_id = client_id or client.client_id
66         self.scope = scope
67         self.redirect_uri = redirect_uri
68         self.token = token or {}
69         self.state = state or generate_token
70         self._state = None
71         self.auto_refresh_url = auto_refresh_url
72         self.auto_refresh_kwargs = auto_refresh_kwargs or {}
73         self.token_updater = token_updater
74         self._client = client or WebApplicationClient(client_id, token=token)
75         self._client._populate_attributes(token or {})
76
77         # Allow customizations for non compliant providers through various
78         # hooks to adjust requests and responses.
79         self.compliance_hook = {
80             'access_token_response': set([]),
81             'refresh_token_response': set([]),
82             'protected_request': set([]),
83         }
84
85     def new_state(self):
86         """Generates a state string to be used in authorizations."""
87         try:
88             self._state = self.state()
89             log.debug('Generated new state %s.', self._state)
90         except TypeError:
91             self._state = self.state
92             log.debug('Re-using previously supplied state %s.', self._state)
93         return self._state
94
95     def authorization_url(self, url, **kwargs):
96         """Form an authorization URL.
97
98         :param url: Authorization endpoint url, must be HTTPS.
99         :param kwargs: Extra parameters to include.
100         :return: authorization_url, state
101         """
102         state = self.new_state()
103         return self._client.prepare_request_uri(url,
104                 redirect_uri=self.redirect_uri,
105                 scope=self.scope,
106                 state=state,
107                 **kwargs), state
108
109     def fetch_token(self, token_url, code=None, authorization_response=None,
110             body='', auth=None, username=None, password=None, **kwargs):
111         """Generic method for fetching an access token from the token endpoint.
112
113         If you are using the MobileApplicationClient you will want to use
114         token_from_fragment instead of fetch_token.
115
116         :param token_url: Token endpoint URL, must use HTTPS.
117         :param code: Authorization code (used by WebApplicationClients).
118         :param authorization_response: Authorization response URL, the callback
119                                        URL of the request back to you. Used by
120                                        WebApplicationClients instead of code.
121         :param body: Optional application/x-www-form-urlencoded body to add the
122                      include in the token request. Prefer kwargs over body.
123         :param auth: An auth tuple or method as accepted by requests.
124         :param username: Username used by LegacyApplicationClients.
125         :param password: Password used by LegacyApplicationClients.
126         :param kwargs: Extra parameters to include in the token request.
127         :return: A token dict
128         """
129         if not is_secure_transport(token_url):
130             raise InsecureTransportError()
131
132         if not code and authorization_response:
133             self._client.parse_request_uri_response(authorization_response,
134                     state=self._state)
135             code = self._client.code
136         elif not code and isinstance(self._client, WebApplicationClient):
137             code = self._client.code
138             if not code:
139                 raise ValueError('Please supply either code or '
140                                  'authorization_code parameters.')
141
142
143         body = self._client.prepare_request_body(code=code, body=body,
144                 redirect_uri=self.redirect_uri, username=username,
145                 password=password, **kwargs)
146         # (ib-lundgren) All known, to me, token requests use POST.
147         r = self.post(token_url, data=dict(urldecode(body)),
148             headers={'Accept': 'application/json'}, auth=auth)
149         log.debug('Prepared fetch token request body %s', body)
150         log.debug('Request to fetch token completed with status %s.',
151                   r.status_code)
152         log.debug('Response headers were %s and content %s.',
153                   r.headers, r.text)
154         log.debug('Invoking %d token response hooks.',
155                   len(self.compliance_hook['access_token_response']))
156         for hook in self.compliance_hook['access_token_response']:
157             log.debug('Invoking hook %s.', hook)
158             r = hook(r)
159
160         self._client.parse_request_body_response(r.text, scope=self.scope)
161         self.token = self._client.token
162         log.debug('Obtained token %s.', self.token)
163         return self.token
164
165     def token_from_fragment(self, authorization_response):
166         """Parse token from the URI fragment, used by MobileApplicationClients.
167
168         :param authorization_response: The full URL of the redirect back to you
169         :return: A token dict
170         """
171         self._client.parse_request_uri_response(authorization_response,
172                 state=self._state)
173         self.token = self._client.token
174         return self.token
175
176     def refresh_token(self, token_url, refresh_token=None, body='', auth=None,
177                       **kwargs):
178         """Fetch a new access token using a refresh token.
179
180         :param token_url: The token endpoint, must be HTTPS.
181         :param refresh_token: The refresh_token to use.
182         :param body: Optional application/x-www-form-urlencoded body to add the
183                      include in the token request. Prefer kwargs over body.
184         :param auth: An auth tuple or method as accepted by requests.
185         :param kwargs: Extra parameters to include in the token request.
186         :return: A token dict
187         """
188         if not token_url:
189             raise ValueError('No token endpoint set for auto_refresh.')
190
191         if not is_secure_transport(token_url):
192             raise InsecureTransportError()
193
194         # Need to nullify token to prevent it from being added to the request
195         refresh_token = refresh_token or self.token.get('refresh_token')
196         self.token = {}
197
198         log.debug('Adding auto refresh key word arguments %s.',
199                   self.auto_refresh_kwargs)
200         kwargs.update(self.auto_refresh_kwargs)
201         body = self._client.prepare_refresh_body(body=body,
202                 refresh_token=refresh_token, scope=self.scope, **kwargs)
203         log.debug('Prepared refresh token request body %s', body)
204         r = self.post(token_url, data=dict(urldecode(body)), auth=auth)
205         log.debug('Request to refresh token completed with status %s.',
206                   r.status_code)
207         log.debug('Response headers were %s and content %s.',
208                   r.headers, r.text)
209         log.debug('Invoking %d token response hooks.',
210                   len(self.compliance_hook['refresh_token_response']))
211         for hook in self.compliance_hook['refresh_token_response']:
212             log.debug('Invoking hook %s.', hook)
213             r = hook(r)
214
215         self.token = self._client.parse_request_body_response(r.text, scope=self.scope)
216         if not 'refresh_token' in self.token:
217             log.debug('No new refresh token given. Re-using old.')
218             self.token['refresh_token'] = refresh_token
219         return self.token
220
221     def request(self, method, url, data=None, headers=None, **kwargs):
222         """Intercept all requests and add the OAuth 2 token if present."""
223         if not is_secure_transport(url):
224             raise InsecureTransportError()
225         if self.token:
226             log.debug('Invoking %d protected resource request hooks.',
227                       len(self.compliance_hook['protected_request']))
228             for hook in self.compliance_hook['protected_request']:
229                 log.debug('Invoking hook %s.', hook)
230                 url, headers, data = hook(url, headers, data)
231
232             log.debug('Adding token %s to request.', self.token)
233             try:
234                 url, headers, data = self._client.add_token(url,
235                         http_method=method, body=data, headers=headers)
236             # Attempt to retrieve and save new access token if expired
237             except TokenExpiredError:
238                 if self.auto_refresh_url:
239                     log.debug('Auto refresh is set, attempting to refresh at %s.',
240                               self.auto_refresh_url)
241                     token = self.refresh_token(self.auto_refresh_url)
242                     if self.token_updater:
243                         log.debug('Updating token to %s using %s.',
244                                   token, self.token_updater)
245                         self.token_updater(token)
246                         url, headers, data = self._client.add_token(url,
247                                 http_method=method, body=data, headers=headers)
248                     else:
249                         raise TokenUpdated(token)
250                 else:
251                     raise
252
253         log.debug('Requesting url %s using method %s.', url, method)
254         log.debug('Supplying headers %s and data %s', headers, data)
255         log.debug('Passing through key word arguments %s.', kwargs)
256         return super(OAuth2Session, self).request(method, url,
257                 headers=headers, data=data, **kwargs)
258
259     def register_compliance_hook(self, hook_type, hook):
260         """Register a hook for request/response tweaking.
261
262         Available hooks are:
263             access_token_response invoked before token parsing.
264             refresh_token_response invoked before refresh token parsing.
265             protected_request invoked before making a request.
266
267         If you find a new hook is needed please send a GitHub PR request
268         or open an issue.
269         """
270         if hook_type not in self.compliance_hook:
271             raise ValueError('Hook type %s is not in %s.',
272                              hook_type, self.compliance_hook)
273         self.compliance_hook[hook_type].add(hook)

Benjamin Mako Hill || Want to submit a patch?