[auth] added "debug flag" check and wrote relative tests close #12
This commit is contained in:
parent
ed54127c81
commit
d944d45bbf
|
@ -30,12 +30,34 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from os import environ
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from starlette.authentication import (
|
from starlette.authentication import (
|
||||||
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
|
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
|
||||||
UnauthenticatedUser)
|
UnauthenticatedUser)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger('halfapi')
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ..conf import PRODUCTION
|
||||||
|
except ImportError:
|
||||||
|
logger.warning('Could not import PRODUCTION variable from conf module,'\
|
||||||
|
' using HALFAPI_PROD environment variable')
|
||||||
|
PRODUCTION = environ.get('HALFAPI_PROD') or False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ..conf import SECRET
|
||||||
|
except ImportError:
|
||||||
|
logger.warning('Could not import SECRET variable from conf module,'\
|
||||||
|
' using HALFAPI_SECRET environment variable')
|
||||||
|
SECRET = environ.get('HALFAPI_SECRET', False)
|
||||||
|
if not SECRET:
|
||||||
|
raise Exception('Missing HALFAPI_SECRET variable')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class JWTUser(BaseUser):
|
class JWTUser(BaseUser):
|
||||||
def __init__(self, id: UUID, token: str, payload: dict) -> None:
|
def __init__(self, id: UUID, token: str, payload: dict) -> None:
|
||||||
|
@ -64,13 +86,14 @@ class JWTUser(BaseUser):
|
||||||
|
|
||||||
|
|
||||||
class JWTAuthenticationBackend(AuthenticationBackend):
|
class JWTAuthenticationBackend(AuthenticationBackend):
|
||||||
def __init__(self, secret_key: str, algorithm: str = 'HS256', prefix: str = 'JWT', name: str = 'name'):
|
def __init__(self, secret_key: str = SECRET,
|
||||||
|
algorithm: str = 'HS256', prefix: str = 'JWT'):
|
||||||
|
|
||||||
if secret_key is None:
|
if secret_key is None:
|
||||||
raise Exception('Missing secret_key argument for JWTAuthenticationBackend')
|
raise Exception('Missing secret_key argument for JWTAuthenticationBackend')
|
||||||
self.secret_key = secret_key
|
self.secret_key = secret_key
|
||||||
self.algorithm = algorithm
|
self.algorithm = algorithm
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.id = id
|
|
||||||
|
|
||||||
async def authenticate(self, request):
|
async def authenticate(self, request):
|
||||||
if "Authorization" not in request.headers:
|
if "Authorization" not in request.headers:
|
||||||
|
@ -78,7 +101,15 @@ class JWTAuthenticationBackend(AuthenticationBackend):
|
||||||
|
|
||||||
token = request.headers["Authorization"]
|
token = request.headers["Authorization"]
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, key=self.secret_key, algorithms=self.algorithm)
|
payload = jwt.decode(token,
|
||||||
|
key=self.secret_key,
|
||||||
|
algorithms=self.algorithm,
|
||||||
|
verify=True)
|
||||||
|
|
||||||
|
if PRODUCTION and 'debug' in payload.keys():
|
||||||
|
raise AuthenticationError(
|
||||||
|
'Trying to connect using *DEBUG* token in *PRODUCTION* mode')
|
||||||
|
|
||||||
except jwt.InvalidTokenError as e:
|
except jwt.InvalidTokenError as e:
|
||||||
raise AuthenticationError(str(e))
|
raise AuthenticationError(str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -1,17 +1,33 @@
|
||||||
|
import os
|
||||||
import jwt
|
import jwt
|
||||||
import requests
|
from requests import Request
|
||||||
import pytest
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
import json
|
import json
|
||||||
from json.decoder import JSONDecodeError
|
from json.decoder import JSONDecodeError
|
||||||
import sys
|
import sys
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
from starlette.testclient import TestClient
|
from uuid import uuid4, UUID
|
||||||
|
|
||||||
from halfapi.app import app
|
from starlette.testclient import TestClient
|
||||||
from halfapi.lib.jwt_middleware import (JWTUser, JWTAuthenticationBackend,
|
from starlette.authentication import (
|
||||||
|
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
|
||||||
|
UnauthenticatedUser)
|
||||||
|
|
||||||
|
|
||||||
|
#from halfapi.app import app
|
||||||
|
os.environ['HALFAPI_PROD'] = 'True'
|
||||||
|
os.environ['HALFAPI_SECRET'] = 'randomsecret'
|
||||||
|
|
||||||
|
from halfapi.lib.jwt_middleware import (PRODUCTION, SECRET,
|
||||||
|
JWTUser, JWTAuthenticationBackend,
|
||||||
JWTWebSocketAuthenticationBackend)
|
JWTWebSocketAuthenticationBackend)
|
||||||
|
|
||||||
|
def test_constants():
|
||||||
|
assert PRODUCTION == bool(os.environ['HALFAPI_PROD'])
|
||||||
|
assert SECRET == os.environ['HALFAPI_SECRET']
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def token():
|
def token():
|
||||||
# This fixture needs to have a running auth-lirmm on 127.0.0.1:3000
|
# This fixture needs to have a running auth-lirmm on 127.0.0.1:3000
|
||||||
|
@ -33,6 +49,16 @@ def token():
|
||||||
|
|
||||||
return res['token']
|
return res['token']
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def token_builder():
|
||||||
|
yield jwt.encode({
|
||||||
|
'name':'xxx',
|
||||||
|
'id': str(uuid4())},
|
||||||
|
key=SECRET
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def token_dirser():
|
def token_dirser():
|
||||||
# This fixture needs to have a running auth-lirmm on 127.0.0.1:3000
|
# This fixture needs to have a running auth-lirmm on 127.0.0.1:3000
|
||||||
|
@ -55,6 +81,7 @@ def token_dirser():
|
||||||
return res['token']
|
return res['token']
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
def test_token(token):
|
def test_token(token):
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
||||||
|
@ -90,3 +117,28 @@ def test_labopers(token, token_dirser):
|
||||||
})
|
})
|
||||||
|
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_JWTUser():
|
||||||
|
uid = uuid4()
|
||||||
|
token = '{}'
|
||||||
|
payload = {}
|
||||||
|
user = JWTUser(uid, token, payload)
|
||||||
|
assert user.id == uid
|
||||||
|
assert user.token == token
|
||||||
|
assert user.payload == payload
|
||||||
|
assert user.is_authenticated == True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_JWTAuthenticationBackend(token_builder):
|
||||||
|
backend = JWTAuthenticationBackend()
|
||||||
|
assert backend.secret_key == SECRET
|
||||||
|
|
||||||
|
req = Request(
|
||||||
|
headers={
|
||||||
|
'Authorization': token_builder
|
||||||
|
})
|
||||||
|
|
||||||
|
credentials, user = await backend.authenticate(req)
|
||||||
|
assert type(user) == JWTUser
|
||||||
|
assert type(credentials) == AuthCredentials
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
import os
|
||||||
|
import jwt
|
||||||
|
from requests import Request
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
import json
|
||||||
|
from json.decoder import JSONDecodeError
|
||||||
|
import sys
|
||||||
|
from hashlib import sha256
|
||||||
|
from base64 import b64decode
|
||||||
|
from uuid import uuid4, UUID
|
||||||
|
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
from starlette.authentication import (
|
||||||
|
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials,
|
||||||
|
UnauthenticatedUser)
|
||||||
|
|
||||||
|
|
||||||
|
#from halfapi.app import app
|
||||||
|
os.environ['HALFAPI_PROD'] = ''
|
||||||
|
os.environ['HALFAPI_SECRET'] = 'randomsecret'
|
||||||
|
|
||||||
|
from halfapi.lib.jwt_middleware import (PRODUCTION, SECRET,
|
||||||
|
JWTUser, JWTAuthenticationBackend,
|
||||||
|
JWTWebSocketAuthenticationBackend)
|
||||||
|
|
||||||
|
def test_constants():
|
||||||
|
assert PRODUCTION == bool(os.environ['HALFAPI_PROD'])
|
||||||
|
assert SECRET == os.environ['HALFAPI_SECRET']
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def token_debug_builder():
|
||||||
|
yield jwt.encode({
|
||||||
|
'name':'xxx',
|
||||||
|
'id': str(uuid4()),
|
||||||
|
'debug': True},
|
||||||
|
key=SECRET
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_JWTAuthenticationBackend_debug(token_debug_builder):
|
||||||
|
backend = JWTAuthenticationBackend()
|
||||||
|
|
||||||
|
req = Request(
|
||||||
|
headers={
|
||||||
|
'Authorization': token_debug_builder
|
||||||
|
})
|
||||||
|
|
||||||
|
auth = await backend.authenticate(req)
|
||||||
|
assert(len(auth) == 2)
|
||||||
|
assert type(auth[0]) == AuthCredentials
|
||||||
|
assert type(auth[1]) == JWTUser
|
Loading…
Reference in New Issue