Auth helper
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -115,5 +115,4 @@ dmypy.json
|
|||||||
# Pyre type checker
|
# Pyre type checker
|
||||||
.pyre/
|
.pyre/
|
||||||
|
|
||||||
token.json
|
conf/
|
||||||
sync_settings.json
|
|
||||||
|
|||||||
@@ -1,47 +1,55 @@
|
|||||||
import os
|
import logging
|
||||||
import yaml
|
import json
|
||||||
from requests_oauthlib import OAuth2Session
|
import sys
|
||||||
|
import webbrowser
|
||||||
|
from boxsdk import OAuth2
|
||||||
|
from urllib.parse import urlparse, parse_qs
|
||||||
|
|
||||||
# This is necessary for testing with non-HTTPS localhost
|
from .const import *
|
||||||
# Remove this if deploying to production
|
|
||||||
os.environ['OAUTHLIB_INSECURE_TRANSPORT'] = '1'
|
|
||||||
|
|
||||||
# This is necessary because Azure does not guarantee
|
|
||||||
# to return scopes in the same case and order as requested
|
|
||||||
os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] = '1'
|
|
||||||
os.environ['OAUTHLIB_IGNORE_SCOPE_CHANGE'] = '1'
|
|
||||||
|
|
||||||
# Load the oauth_settings.yml file
|
|
||||||
stream = open('oauth_settings.yml', 'r')
|
|
||||||
settings = yaml.load(stream, Loader=yaml.BaseLoader)
|
|
||||||
authorize_url = '{0}{1}'.format(settings['authority'], settings['authorize_endpoint'])
|
|
||||||
token_url = '{0}{1}'.format(settings['authority'], settings['token_endpoint'])
|
|
||||||
|
|
||||||
# Method to generate a sign-in url
|
|
||||||
|
|
||||||
|
|
||||||
def get_sign_in_url():
|
def store_tokens(access_token: str, refresh_token: str):
|
||||||
# Initialize the OAuth client
|
logging.info('saving new access token')
|
||||||
aad_auth = OAuth2Session(settings['app_id'],
|
with open(TOKEN_FILE, 'w') as outfile:
|
||||||
scope=settings['scopes'],
|
json.dump({'access_token': access_token, 'refresh_token': refresh_token}, outfile, indent=4, sort_keys=True)
|
||||||
redirect_uri=settings['redirect'])
|
|
||||||
|
|
||||||
sign_in_url, state = aad_auth.authorization_url(authorize_url, prompt='login')
|
|
||||||
|
|
||||||
return sign_in_url, state
|
|
||||||
|
|
||||||
# Method to exchange auth code for access token
|
|
||||||
|
|
||||||
|
|
||||||
def get_token_from_code(callback_url, expected_state):
|
def init_oauth():
|
||||||
# Initialize the OAuth client
|
with open(SECRETS_FILE) as f:
|
||||||
aad_auth = OAuth2Session(settings['app_id'],
|
secrets = json.load(f)
|
||||||
state=expected_state,
|
client_id = secrets['client_id']
|
||||||
scope=settings['scopes'],
|
client_secret = secrets['client_secret']
|
||||||
redirect_uri=settings['redirect'])
|
logging.debug('Load OAuth2 secret success')
|
||||||
|
|
||||||
token = aad_auth.fetch_token(token_url,
|
try:
|
||||||
client_secret=settings['app_secret'],
|
with open(TOKEN_FILE) as f:
|
||||||
authorization_response=callback_url)
|
tokens = json.load(f)
|
||||||
|
return OAuth2(
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret=client_secret,
|
||||||
|
store_tokens=store_tokens,
|
||||||
|
access_token=tokens['access_token'],
|
||||||
|
refresh_token=tokens['access_token']
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.debug('Fail to load OAuth2 token file', e)
|
||||||
|
|
||||||
return token
|
oauth = OAuth2(
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret=client_secret,
|
||||||
|
store_tokens=store_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
auth_url, csrf_token = oauth.get_authorization_url('http://localhost:8000')
|
||||||
|
print('Initiating login at', auth_url)
|
||||||
|
webbrowser.open(auth_url, new=2)
|
||||||
|
|
||||||
|
print('After logging in, please paste the entire callback URL (such as http://localhost:8000/......)')
|
||||||
|
try:
|
||||||
|
callback_url = urlparse(input('Paste here: '))
|
||||||
|
callback_url_params = parse_qs(callback_url.query)
|
||||||
|
oauth.authenticate(callback_url_params['code'][0])
|
||||||
|
return oauth
|
||||||
|
except Exception as e:
|
||||||
|
print(f'Error parsing URL: {type(e).__name__}: {e}')
|
||||||
|
sys.exit(1)
|
||||||
|
|||||||
3
src/const.py
Normal file
3
src/const.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
SECRETS_FILE = 'conf/secrets.json'
|
||||||
|
TOKEN_FILE = 'conf/token.json'
|
||||||
|
SETTING_FILE = 'conf/sync_settings.json'
|
||||||
16
src/setup.py
16
src/setup.py
@@ -1,11 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
root = logging.getLogger()
|
|
||||||
root.setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
handler = logging.StreamHandler(sys.stdout)
|
def setup_logger():
|
||||||
handler.setLevel(logging.DEBUG)
|
root = logging.getLogger()
|
||||||
formatter = logging.Formatter('%(levelname)s - %(message)s')
|
root.setLevel(logging.WARNING)
|
||||||
handler.setFormatter(formatter)
|
|
||||||
root.addHandler(handler)
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
handler.setLevel(logging.DEBUG)
|
||||||
|
formatter = logging.Formatter('%(levelname)s - %(message)s')
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
root.addHandler(handler)
|
||||||
|
|||||||
@@ -1,53 +0,0 @@
|
|||||||
import json
|
|
||||||
import time
|
|
||||||
import threading
|
|
||||||
import logging
|
|
||||||
from requests_oauthlib import OAuth2Session
|
|
||||||
|
|
||||||
from .auth_helper import settings, token_url
|
|
||||||
|
|
||||||
|
|
||||||
class TokenManager:
|
|
||||||
|
|
||||||
def __init__(self, filename, token=None):
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
self.filename = filename
|
|
||||||
|
|
||||||
if token is not None:
|
|
||||||
self.__save_token(token)
|
|
||||||
else:
|
|
||||||
self.__load_token()
|
|
||||||
|
|
||||||
def __save_token(self, new_token):
|
|
||||||
self.token = new_token
|
|
||||||
with open(self.filename, 'w') as outfile:
|
|
||||||
json.dump(new_token, outfile, indent=4, sort_keys=True)
|
|
||||||
|
|
||||||
def __load_token(self):
|
|
||||||
with open(self.filename) as f:
|
|
||||||
self.token = json.load(f)
|
|
||||||
|
|
||||||
def get_token(self):
|
|
||||||
with self.lock:
|
|
||||||
now = time.time()
|
|
||||||
# Subtract 5 minutes from expiration to account for clock skew
|
|
||||||
expire_time = self.token['expires_at'] - 300
|
|
||||||
|
|
||||||
if now >= expire_time: # Refresh the token
|
|
||||||
logging.warning('Refreshing OAuth2 Token')
|
|
||||||
aad_auth = OAuth2Session(settings['app_id'],
|
|
||||||
token=self.token,
|
|
||||||
scope=settings['scopes'],
|
|
||||||
redirect_uri=settings['redirect'])
|
|
||||||
|
|
||||||
refresh_params = {
|
|
||||||
'client_id': settings['app_id'],
|
|
||||||
'client_secret': settings['app_secret'],
|
|
||||||
}
|
|
||||||
new_token = aad_auth.refresh_token(token_url, **refresh_params)
|
|
||||||
|
|
||||||
self.__save_token(new_token)
|
|
||||||
return new_token
|
|
||||||
|
|
||||||
else:
|
|
||||||
return self.token
|
|
||||||
Reference in New Issue
Block a user