Auth helper
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -115,5 +115,4 @@ dmypy.json
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
token.json
|
||||
sync_settings.json
|
||||
conf/
|
||||
|
||||
@@ -1,47 +1,55 @@
|
||||
import os
|
||||
import yaml
|
||||
from requests_oauthlib import OAuth2Session
|
||||
import logging
|
||||
import json
|
||||
import sys
|
||||
import webbrowser
|
||||
from boxsdk import OAuth2
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
|
||||
# This is necessary for testing with non-HTTPS localhost
|
||||
# Remove this if deploying to production
|
||||
os.environ['OAUTHLIB_INSECURE_TRANSPORT'] = '1'
|
||||
|
||||
# This is necessary because Azure does not guarantee
|
||||
# to return scopes in the same case and order as requested
|
||||
os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] = '1'
|
||||
os.environ['OAUTHLIB_IGNORE_SCOPE_CHANGE'] = '1'
|
||||
|
||||
# Load the oauth_settings.yml file
|
||||
stream = open('oauth_settings.yml', 'r')
|
||||
settings = yaml.load(stream, Loader=yaml.BaseLoader)
|
||||
authorize_url = '{0}{1}'.format(settings['authority'], settings['authorize_endpoint'])
|
||||
token_url = '{0}{1}'.format(settings['authority'], settings['token_endpoint'])
|
||||
|
||||
# Method to generate a sign-in url
|
||||
from .const import *
|
||||
|
||||
|
||||
def get_sign_in_url():
|
||||
# Initialize the OAuth client
|
||||
aad_auth = OAuth2Session(settings['app_id'],
|
||||
scope=settings['scopes'],
|
||||
redirect_uri=settings['redirect'])
|
||||
|
||||
sign_in_url, state = aad_auth.authorization_url(authorize_url, prompt='login')
|
||||
|
||||
return sign_in_url, state
|
||||
|
||||
# Method to exchange auth code for access token
|
||||
def store_tokens(access_token: str, refresh_token: str):
|
||||
logging.info('saving new access token')
|
||||
with open(TOKEN_FILE, 'w') as outfile:
|
||||
json.dump({'access_token': access_token, 'refresh_token': refresh_token}, outfile, indent=4, sort_keys=True)
|
||||
|
||||
|
||||
def get_token_from_code(callback_url, expected_state):
|
||||
# Initialize the OAuth client
|
||||
aad_auth = OAuth2Session(settings['app_id'],
|
||||
state=expected_state,
|
||||
scope=settings['scopes'],
|
||||
redirect_uri=settings['redirect'])
|
||||
def init_oauth():
|
||||
with open(SECRETS_FILE) as f:
|
||||
secrets = json.load(f)
|
||||
client_id = secrets['client_id']
|
||||
client_secret = secrets['client_secret']
|
||||
logging.debug('Load OAuth2 secret success')
|
||||
|
||||
token = aad_auth.fetch_token(token_url,
|
||||
client_secret=settings['app_secret'],
|
||||
authorization_response=callback_url)
|
||||
try:
|
||||
with open(TOKEN_FILE) as f:
|
||||
tokens = json.load(f)
|
||||
return OAuth2(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
store_tokens=store_tokens,
|
||||
access_token=tokens['access_token'],
|
||||
refresh_token=tokens['access_token']
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug('Fail to load OAuth2 token file', e)
|
||||
|
||||
return token
|
||||
oauth = OAuth2(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
store_tokens=store_tokens
|
||||
)
|
||||
|
||||
auth_url, csrf_token = oauth.get_authorization_url('http://localhost:8000')
|
||||
print('Initiating login at', auth_url)
|
||||
webbrowser.open(auth_url, new=2)
|
||||
|
||||
print('After logging in, please paste the entire callback URL (such as http://localhost:8000/......)')
|
||||
try:
|
||||
callback_url = urlparse(input('Paste here: '))
|
||||
callback_url_params = parse_qs(callback_url.query)
|
||||
oauth.authenticate(callback_url_params['code'][0])
|
||||
return oauth
|
||||
except Exception as e:
|
||||
print(f'Error parsing URL: {type(e).__name__}: {e}')
|
||||
sys.exit(1)
|
||||
|
||||
3
src/const.py
Normal file
3
src/const.py
Normal file
@@ -0,0 +1,3 @@
|
||||
SECRETS_FILE = 'conf/secrets.json'
|
||||
TOKEN_FILE = 'conf/token.json'
|
||||
SETTING_FILE = 'conf/sync_settings.json'
|
||||
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
def setup_logger():
|
||||
root = logging.getLogger()
|
||||
root.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
from requests_oauthlib import OAuth2Session
|
||||
|
||||
from .auth_helper import settings, token_url
|
||||
|
||||
|
||||
class TokenManager:
|
||||
|
||||
def __init__(self, filename, token=None):
|
||||
self.lock = threading.Lock()
|
||||
self.filename = filename
|
||||
|
||||
if token is not None:
|
||||
self.__save_token(token)
|
||||
else:
|
||||
self.__load_token()
|
||||
|
||||
def __save_token(self, new_token):
|
||||
self.token = new_token
|
||||
with open(self.filename, 'w') as outfile:
|
||||
json.dump(new_token, outfile, indent=4, sort_keys=True)
|
||||
|
||||
def __load_token(self):
|
||||
with open(self.filename) as f:
|
||||
self.token = json.load(f)
|
||||
|
||||
def get_token(self):
|
||||
with self.lock:
|
||||
now = time.time()
|
||||
# Subtract 5 minutes from expiration to account for clock skew
|
||||
expire_time = self.token['expires_at'] - 300
|
||||
|
||||
if now >= expire_time: # Refresh the token
|
||||
logging.warning('Refreshing OAuth2 Token')
|
||||
aad_auth = OAuth2Session(settings['app_id'],
|
||||
token=self.token,
|
||||
scope=settings['scopes'],
|
||||
redirect_uri=settings['redirect'])
|
||||
|
||||
refresh_params = {
|
||||
'client_id': settings['app_id'],
|
||||
'client_secret': settings['app_secret'],
|
||||
}
|
||||
new_token = aad_auth.refresh_token(token_url, **refresh_params)
|
||||
|
||||
self.__save_token(new_token)
|
||||
return new_token
|
||||
|
||||
else:
|
||||
return self.token
|
||||
Reference in New Issue
Block a user