462 lines
14 KiB
Python
462 lines
14 KiB
Python
# Copyright (c) Twisted Matrix Laboratories.
|
|
# See LICENSE for details.
|
|
|
|
"""
|
|
Tests for L{twisted.cred}, now with 30% more starch.
|
|
"""
|
|
|
|
|
|
from binascii import hexlify, unhexlify
|
|
|
|
from zope.interface import Interface, implementer
|
|
|
|
from twisted.cred import checkers, credentials, error, portal
|
|
from twisted.internet import defer
|
|
from twisted.python import components
|
|
from twisted.python.versions import Version
|
|
from twisted.trial import unittest
|
|
|
|
# crypt is optional: when it cannot be imported, ``crypt`` stays None and
# the crypt-based tests further down are skipped.
crypt = None
try:
    from crypt import crypt as _crypt
except ImportError:
    pass
else:
    crypt = _crypt
|
|
|
|
|
|
# The Twisted version in which UsernameHashedPassword is first deprecated.
# Passed to getDeprecatedModuleAttribute together with the module and
# attribute name whenever a test fetches
# twisted.cred.credentials.UsernameHashedPassword.
_uhpVersion = Version("Twisted", 21, 2, 0)
|
|
|
|
|
|
class ITestable(Interface):
    """
    An interface for a theoretical protocol.

    It declares no attributes; the tests request it from the portal so the
    realm adapts each avatar to it.
    """
|
|
|
|
|
|
class TestAvatar:
    """
    A test avatar that records whether it has been logged in and out.

    @ivar name: The avatar ID this avatar was created for.
    @ivar loggedIn: C{True} once L{login} has been called.
    @ivar loggedOut: C{True} once L{logout} has been called.
    """

    def __init__(self, name):
        self.name = name
        self.loggedIn = False
        self.loggedOut = False

    def login(self):
        """
        Mark this avatar as logged in.

        @raise AssertionError: If this avatar is already logged in.
        """
        # Raise explicitly instead of using an assert statement, so the
        # double-login check still runs when Python is started with -O.
        if self.loggedIn:
            raise AssertionError(f"avatar {self.name!r} is already logged in")
        self.loggedIn = True

    def logout(self):
        """
        Mark this avatar as logged out.
        """
        self.loggedOut = True
|
|
|
|
|
|
@implementer(ITestable)
class Testable(components.Adapter):
    """
    A theoretical protocol for testing: adapts a L{TestAvatar} to
    L{ITestable}, with the adapted avatar available as C{self.original}.
    """


# Make ITestable(avatar) produce a Testable wrapper for any TestAvatar.
components.registerAdapter(Testable, TestAvatar, ITestable)
|
|
|
|
|
|
class IDerivedCredentials(credentials.IUsernamePassword):
    """
    A sub-interface of L{credentials.IUsernamePassword}, used to check that a
    portal accepts credentials providing an interface derived from one its
    checker was registered for.
    """
|
|
|
|
|
|
@implementer(IDerivedCredentials, ITestable)
class DerivedCredentials:
    """
    Username/password credentials that provide L{IDerivedCredentials} (and
    L{ITestable}) rather than L{credentials.IUsernamePassword} directly.

    @ivar username: The username presented.
    @ivar password: The plaintext password presented.
    """

    def __init__(self, username, password):
        self.username, self.password = username, password

    def checkPassword(self, password):
        """
        @param password: The password to compare against.
        @return: Whether C{password} equals the presented password.
        """
        matches = password == self.password
        return matches
|
|
|
|
|
|
@implementer(portal.IRealm)
class TestRealm:
    """
    A basic test realm that creates one L{TestAvatar} per avatar ID and
    hands the same avatar back on later requests.

    @ivar avatars: Mapping of avatar ID to the L{TestAvatar} created for it.
    """

    def __init__(self):
        self.avatars = {}

    def requestAvatar(self, avatarId, mind, *interfaces):
        """
        Log in the avatar for C{avatarId}, creating it on first use.

        @param avatarId: The avatar ID produced by a credentials checker.
        @param mind: Ignored.
        @param interfaces: The requested interfaces; only the first is used.

        @return: A 3-tuple of the first requested interface, the avatar
            adapted to that interface, and the avatar's C{logout} method.
        """
        try:
            avatar = self.avatars[avatarId]
        except KeyError:
            avatar = self.avatars[avatarId] = TestAvatar(avatarId)
        avatar.login()
        requested = interfaces[0]
        return (requested, requested(avatar), avatar.logout)
|
|
|
|
|
|
class CredTests(unittest.TestCase):
    """
    Tests for the meat of L{twisted.cred} -- realms, portals, avatars, and
    checkers.
    """

    def setUp(self):
        self.realm = TestRealm()
        self.portal = portal.Portal(self.realm)
        self.checker = checkers.InMemoryUsernamePasswordDatabaseDontUse()
        self.checker.addUser(b"bob", b"hello")
        self.portal.registerChecker(self.checker)

    def test_listCheckers(self):
        """
        The checkers in a portal can check only certain types of credentials.
        Since this portal has
        L{checkers.InMemoryUsernamePasswordDatabaseDontUse} registered, it
        lists exactly L{credentials.IUsernamePassword} and
        L{credentials.IUsernameHashedPassword} as the credentials interfaces
        it can check.
        """
        expected = [credentials.IUsernamePassword, credentials.IUsernameHashedPassword]
        got = self.portal.listCredentialsInterfaces()
        self.assertEqual(sorted(got), sorted(expected))

    def _assertLoginSucceeds(self, creds):
        """
        Log in with C{creds}, requesting L{ITestable}, and verify the result.

        The portal must return L{ITestable}, an implementation providing it,
        and a logout callable that marks the underlying avatar as logged out.

        @param creds: Credentials the portal is expected to accept.
        """
        iface, impl, logout = self.successResultOf(
            self.portal.login(creds, self, ITestable)
        )

        # whitebox
        self.assertEqual(iface, ITestable)
        self.assertTrue(iface.providedBy(impl), f"{impl} does not implement {iface}")

        # greybox
        self.assertTrue(impl.original.loggedIn)
        self.assertFalse(impl.original.loggedOut)
        logout()
        self.assertTrue(impl.original.loggedOut)

    def _assertLoginRejected(self, creds):
        """
        Log in with C{creds} and verify that it fails with exactly
        L{error.UnauthorizedLogin}.

        @param creds: Credentials the portal is expected to reject.
        """
        # failureResultOf itself fails the test if the Deferred succeeded.
        failure = self.failureResultOf(self.portal.login(creds, self, ITestable))
        self.assertEqual(error.UnauthorizedLogin, failure.type)

    def test_basicLogin(self):
        """
        Calling C{login} on a portal with correct credentials and an interface
        that the portal's realm supports works.
        """
        self._assertLoginSucceeds(credentials.UsernamePassword(b"bob", b"hello"))

    def test_derivedInterface(self):
        """
        Logging in with correct derived credentials and an interface
        that the portal's realm supports works.
        """
        self._assertLoginSucceeds(DerivedCredentials(b"bob", b"hello"))

    def test_failedLoginPassword(self):
        """
        Calling C{login} with incorrect credentials (in this case a wrong
        password) causes L{error.UnauthorizedLogin} to be raised.
        """
        self._assertLoginRejected(credentials.UsernamePassword(b"bob", b"h3llo"))

    def test_failedLoginName(self):
        """
        Calling C{login} with incorrect credentials (in this case no known
        user) causes L{error.UnauthorizedLogin} to be raised.
        """
        self._assertLoginRejected(credentials.UsernamePassword(b"jay", b"hello"))
|
|
|
|
|
|
class OnDiskDatabaseTests(unittest.TestCase):
    """
    Tests for L{checkers.FilePasswordDB} reading a plaintext
    C{username:password} database file.
    """

    users = [
        (b"user1", b"pass1"),
        (b"user2", b"pass2"),
        (b"user3", b"pass3"),
    ]

    def setUp(self):
        self.dbfile = self.mktemp()
        with open(self.dbfile, "wb") as f:
            for u, p in self.users:
                f.write(u + b":" + p + b"\n")

    def test_getUserNonexistentDatabase(self):
        """
        A missing db file will cause a permanent rejection of authorization
        attempts.
        """
        # mktemp() returns a unique path that has not been created, unlike a
        # fixed relative filename that might exist in the working directory.
        self.db = checkers.FilePasswordDB(self.mktemp())

        self.assertRaises(error.UnauthorizedLogin, self.db.getUser, b"user")

    def testUserLookup(self):
        """
        With the default case-sensitive matching, C{getUser} returns the
        C{(username, password)} pair for each known user and raises
        L{KeyError} for the upper-cased username.
        """
        self.db = checkers.FilePasswordDB(self.dbfile)
        for u, p in self.users:
            self.assertRaises(KeyError, self.db.getUser, u.upper())
            self.assertEqual(self.db.getUser(u), (u, p))

    def testCaseInSensitivity(self):
        """
        With C{caseSensitive=False}, an upper-cased username finds the
        stored C{(username, password)} pair.
        """
        self.db = checkers.FilePasswordDB(self.dbfile, caseSensitive=False)
        for u, p in self.users:
            self.assertEqual(self.db.getUser(u.upper()), (u, p))

    def _assertAvatarIds(self, creds):
        """
        Check every credential in C{creds} against C{self.db}.

        @param creds: One credentials object per entry in C{users}, in order.
        @return: A L{Deferred} that fires once the avatar IDs have been
            compared with the usernames in C{users}.
        """
        d = defer.gatherResults(
            [defer.maybeDeferred(self.db.requestAvatarId, c) for c in creds]
        )
        d.addCallback(self.assertEqual, [u for u, _ in self.users])
        return d

    def testRequestAvatarId(self):
        """
        L{credentials.UsernamePassword} credentials for each user yield that
        user's name as the avatar ID.
        """
        self.db = checkers.FilePasswordDB(self.dbfile)
        creds = [credentials.UsernamePassword(u, p) for u, p in self.users]
        return self._assertAvatarIds(creds)

    def testRequestAvatarId_hashed(self):
        """
        The deprecated C{UsernameHashedPassword} credentials, built from the
        stored passwords, also yield each user's name as the avatar ID.
        """
        self.db = checkers.FilePasswordDB(self.dbfile)
        UsernameHashedPassword = self.getDeprecatedModuleAttribute(
            "twisted.cred.credentials", "UsernameHashedPassword", _uhpVersion
        )
        creds = [UsernameHashedPassword(u, p) for u, p in self.users]
        return self._assertAvatarIds(creds)
|
|
|
|
|
|
class HashedPasswordOnDiskDatabaseTests(unittest.TestCase):
    """
    Tests for L{checkers.FilePasswordDB} configured with a C{hash} callable,
    backed by a database whose passwords are stored crypt()-hashed.
    """

    users = [
        (b"user1", b"pass1"),
        (b"user2", b"pass2"),
        (b"user3", b"pass3"),
    ]

    def setUp(self):
        dbfile = self.mktemp()
        self.db = checkers.FilePasswordDB(dbfile, hash=self.hash)
        with open(dbfile, "wb") as f:
            for u, p in self.users:
                # Each entry is salted with the first two bytes of its username.
                f.write(u + b":" + self.hash(u, p, u[:2]) + b"\n")

        r = TestRealm()
        self.port = portal.Portal(r)
        self.port.registerChecker(self.db)

    def hash(self, u: "bytes | None", p: bytes, s: bytes) -> bytes:
        """
        Hash a password with the C{crypt} function.

        Used both to populate the database and as the C{hash} callable
        given to L{checkers.FilePasswordDB}.

        @param u: The username; unused, and C{None} in some callers.
        @param p: The plaintext password, ASCII-encoded.
        @param s: The salt passed through to C{crypt}, ASCII-encoded.
        @return: The ASCII-encoded C{crypt} output.
        """
        hashed_password = crypt(p.decode("ascii"), s.decode("ascii"))  # type: ignore[misc]
        return hashed_password.encode("ascii")

    def testGoodCredentials(self):
        """
        Correct plaintext credentials are checked against the hashed database
        and yield each user's name as the avatar ID.
        """
        goodCreds = [credentials.UsernamePassword(u, p) for u, p in self.users]
        d = defer.gatherResults([self.db.requestAvatarId(c) for c in goodCreds])
        d.addCallback(self.assertEqual, [u for u, _ in self.users])
        return d

    def testGoodCredentials_login(self):
        """
        Logging in through the portal with correct credentials produces
        avatars named after each user.
        """
        goodCreds = [credentials.UsernamePassword(u, p) for u, p in self.users]
        d = defer.gatherResults(
            [self.port.login(c, None, ITestable) for c in goodCreds]
        )
        d.addCallback(
            lambda results: [avatar.original.name for _, avatar, _ in results]
        )
        d.addCallback(self.assertEqual, [u for u, _ in self.users])
        return d

    def testBadCredentials(self):
        """
        Logging in with a wrong password fails with
        L{error.UnauthorizedLogin} for every user.
        """
        badCreds = [
            credentials.UsernamePassword(u, b"wrong password") for u, _ in self.users
        ]
        return self._assertAllLoginsFail(badCreds, error.UnauthorizedLogin)

    def testHashedCredentials(self):
        """
        Logging in with the deprecated C{UsernameHashedPassword} credentials,
        carrying the already-hashed password, fails with
        L{error.UnhandledCredentials} for every user.
        """
        UsernameHashedPassword = self.getDeprecatedModuleAttribute(
            "twisted.cred.credentials", "UsernameHashedPassword", _uhpVersion
        )
        hashedCreds = [
            UsernameHashedPassword(u, self.hash(None, p, u[:2])) for u, p in self.users
        ]
        return self._assertAllLoginsFail(hashedCreds, error.UnhandledCredentials)

    def _assertAllLoginsFail(self, creds, *expectedFailures):
        """
        Attempt a portal login with each of C{creds} and require that every
        attempt fails with one of C{expectedFailures}.

        @return: A L{Deferred} that fires once all results have been checked.
        """
        d = defer.DeferredList(
            [self.port.login(c, None, ITestable) for c in creds],
            consumeErrors=True,
        )
        d.addCallback(self._assertFailures, *expectedFailures)
        return d

    def _assertFailures(self, failures, *expectedFailures):
        """
        Assert that every C{(flag, result)} pair from a L{defer.DeferredList}
        is a failure of one of C{expectedFailures}.

        @raise: The original failure (via C{trap}) if it is of another type.
        """
        for flag, failure in failures:
            self.assertEqual(flag, defer.FAILURE)
            failure.trap(*expectedFailures)
        return None

    if crypt is None:
        skip = "crypt module not available"
|
|
|
|
|
|
class CheckersMixin:
    """
    L{unittest.TestCase} mixin for testing that some checkers accept
    and deny specified credentials.

    Subclasses must provide
      - C{getCheckers} which returns an iterable of
        L{checkers.ICredentialsChecker} providers
      - C{getGoodCredentials} which returns an iterable of 2-tuples of
        credential to check and avatarId to expect.
      - C{getBadCredentials} which returns an iterable of credentials
        which are expected to be unauthorized.
    """

    @defer.inlineCallbacks
    def test_positive(self):
        """
        The given credentials are accepted by all the checkers, and give
        the expected C{avatarID}s
        """
        for chk in self.getCheckers():
            for cred, avatarId in self.getGoodCredentials():
                r = yield chk.requestAvatarId(cred)
                self.assertEqual(r, avatarId)

    @defer.inlineCallbacks
    def test_negative(self):
        """
        The given credentials are rejected by all the checkers with
        L{error.UnauthorizedLogin}.
        """
        for chk in self.getCheckers():
            for cred in self.getBadCredentials():
                d = chk.requestAvatarId(cred)
                yield self.assertFailure(d, error.UnauthorizedLogin)
|
|
|
|
|
|
class HashlessFilePasswordDBMixin:
    """
    Supplies L{CheckersMixin} with L{checkers.FilePasswordDB} checkers,
    backed by database files in three layouts that store passwords as-is.

    Subclasses may override C{diskHash} (applied to passwords written to
    disk, and used to build the checker's C{hash}), C{networkHash` (applied
    to passwords placed in the presented credentials) and C{credClass}.
    """

    credClass = credentials.UsernamePassword
    diskHash = None

    @staticmethod
    def networkHash(x: bytes) -> bytes:
        return x

    _validCredentials = [
        (b"user1", b"password1"),
        (b"user2", b"password2"),
        (b"user3", b"password3"),
    ]

    def getGoodCredentials(self):
        """
        Yield C{(credentials, expectedAvatarId)} for every valid user.
        """
        for username, password in self._validCredentials:
            presented = self.credClass(username, self.networkHash(password))
            yield presented, username

    def getBadCredentials(self):
        """
        Yield credentials with a wrong password or an unknown username.
        """
        wrongPairs = [
            (b"user1", b"password3"),
            (b"user2", b"password1"),
            (b"bloof", b"blarf"),
        ]
        for username, password in wrongPairs:
            yield self.credClass(username, self.networkHash(password))

    def _writeDatabase(self, lines):
        """
        Write each of C{lines}, newline-terminated, to a fresh temporary file.

        @return: The path of the file written.
        """
        dbPath = self.mktemp()
        with open(dbPath, "wb") as dbFile:
            for line in lines:
                dbFile.write(line + b"\n")
        return dbPath

    def getCheckers(self):
        """
        Yield L{checkers.FilePasswordDB} instances, with and without caching,
        for three database layouts that all hold C{_validCredentials}.
        """
        toDisk = self.diskHash or (lambda x: x)
        hashCheck = self.diskHash and (
            lambda username, password, stored: self.diskHash(password)
        )

        for cache in (True, False):
            # Layout 1: "username:password" with the default delimiter/fields.
            dbPath = self._writeDatabase(
                u + b":" + toDisk(p) for u, p in self._validCredentials
            )
            yield checkers.FilePasswordDB(dbPath, cache=cache, hash=hashCheck)

            # Layout 2: space-delimited, password in field 0, username in
            # field 3.
            dbPath = self._writeDatabase(
                toDisk(p) + b" dingle dongle " + u for u, p in self._validCredentials
            )
            yield checkers.FilePasswordDB(
                dbPath, b" ", 3, 0, cache=cache, hash=hashCheck
            )

            # Layout 3: comma-delimited, title-cased username in field 2 and
            # password in field 4; the positional False disables case
            # sensitivity so the lowercase credentials still match.
            dbPath = self._writeDatabase(
                b"zip,zap," + u.title() + b",zup," + toDisk(p)
                for u, p in self._validCredentials
            )
            yield checkers.FilePasswordDB(
                dbPath, b",", 2, 4, False, cache=cache, hash=hashCheck
            )
|
|
|
|
|
|
class LocallyHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin):
    """
    Variant whose database stores hex-encoded passwords, with the checker
    hex-encoding presented passwords before comparing.
    """

    @staticmethod
    def diskHash(x):
        encoded = hexlify(x)
        return encoded
|
|
|
|
|
|
class NetworkHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin):
    """
    Variant whose presented credentials carry hex-encoded passwords, which
    the credentials decode before comparing with the plaintext on disk.
    """

    @staticmethod
    def networkHash(x: bytes) -> bytes:
        encoded = hexlify(x)
        return encoded

    class credClass(credentials.UsernamePassword):
        """
        L{credentials.UsernamePassword} whose C{password} attribute holds
        the hex encoding of the real password.
        """

        def checkPassword(self, password):
            decoded = unhexlify(self.password)
            return decoded == password
|
|
|
|
|
|
class HashlessFilePasswordDBCheckerTests(
    HashlessFilePasswordDBMixin, CheckersMixin, unittest.TestCase
):
    """
    L{CheckersMixin} tests for L{checkers.FilePasswordDB} with plaintext
    passwords both on disk and in the presented credentials.
    """
|
|
|
|
|
|
class LocallyHashedFilePasswordDBCheckerTests(
    LocallyHashedFilePasswordDBMixin, CheckersMixin, unittest.TestCase
):
    """
    L{CheckersMixin} tests for L{checkers.FilePasswordDB} with hex-encoded
    passwords stored on disk.
    """
|
|
|
|
|
|
class NetworkHashedFilePasswordDBCheckerTests(
    NetworkHashedFilePasswordDBMixin, CheckersMixin, unittest.TestCase
):
    """
    L{CheckersMixin} tests for L{checkers.FilePasswordDB} with hex-encoded
    passwords in the presented credentials.
    """
|
|
|
|
|
|
class UsernameHashedPasswordTests(unittest.TestCase):
    """
    UsernameHashedPassword is a deprecated class that is functionally
    equivalent to UsernamePassword.
    """

    def test_deprecation(self):
        """
        Fetching UsernameHashedPassword from L{twisted.cred.credentials} is
        reported as deprecated in C{_uhpVersion}, with UsernamePassword named
        as the replacement.
        """
        replacement = "Use twisted.cred.credentials.UsernamePassword instead."
        self.getDeprecatedModuleAttribute(
            "twisted.cred.credentials",
            "UsernameHashedPassword",
            _uhpVersion,
            replacement,
        )
|