Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/jwt/jwks_client.py: 21%

64 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2023-07-17 14:22 -0600

1import json 

2import urllib.request 

3from functools import lru_cache 

4from typing import Any, List, Optional 

5from urllib.error import URLError 

6 

7from .api_jwk import PyJWK, PyJWKSet 

8from .api_jwt import decode_complete as decode_token 

9from .exceptions import PyJWKClientError 

10from .jwk_set_cache import JWKSetCache 

11 

12 

class PyJWKClient:
    """Fetch a JSON Web Key Set (JWKS) from a URL and look up signing keys in it.

    The fetched key set can optionally be kept in a ``JWKSetCache``
    (``cache_jwk_set``), and results of ``get_signing_key`` can optionally be
    memoized per ``kid`` with an LRU cache (``cache_keys``).
    """

    # Class-level default so instances that bypass ``__init__`` (e.g. subclasses
    # that do not call ``super().__init__``) still see "no explicit timeout".
    timeout: Optional[float] = None

    def __init__(
        self,
        uri: str,
        cache_keys: bool = False,
        max_cached_keys: int = 16,
        cache_jwk_set: bool = True,
        lifespan: int = 300,
        *,
        timeout: Optional[float] = None,
    ):
        """Create a client for the JWKS endpoint at ``uri``.

        :param uri: URL the key set is downloaded from.
        :param cache_keys: if true, wrap ``get_signing_key`` in an LRU cache of
            size ``max_cached_keys`` (per instance).
        :param max_cached_keys: ``maxsize`` for that LRU cache.
        :param cache_jwk_set: if true, keep the fetched key set in a
            ``JWKSetCache`` built with ``lifespan``.
        :param lifespan: lifespan handed to ``JWKSetCache`` (default 300, i.e.
            5 minutes if the unit is seconds); must be > 0 when
            ``cache_jwk_set`` is true.
        :param timeout: keyword-only; seconds to wait for the HTTP request. When
            ``None`` (default), ``urllib``'s default timeout behaviour is kept.
        :raises PyJWKClientError: if ``cache_jwk_set`` is true and
            ``lifespan <= 0``.
        """
        self.uri = uri
        self.timeout = timeout
        self.jwk_set_cache: Optional[JWKSetCache] = None

        if cache_jwk_set:
            # Init jwt set cache with default or given lifespan.
            # Default lifespan is 300 seconds (5 minutes).
            if lifespan <= 0:
                raise PyJWKClientError(
                    f'Lifespan must be greater than 0, the input is "{lifespan}"'
                )
            self.jwk_set_cache = JWKSetCache(lifespan)

        if cache_keys:
            # Cache signing keys
            # Ignore mypy (https://github.com/python/mypy/issues/2427)
            self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key)  # type: ignore

    def fetch_data(self) -> Any:
        """Download ``self.uri`` and return its JSON-decoded body.

        When a JWKS cache is configured, the ``finally`` clause always hands
        the decoded value to ``JWKSetCache.put``. On failure that value is
        ``None`` (NOTE(review): presumably this clears the cached set;
        confirm in ``jwk_set_cache``).

        :returns: the decoded JSON document.
        :raises PyJWKClientError: if ``urlopen`` raises ``URLError``; the
            original error is chained as ``__cause__``. JSON decoding errors
            propagate unchanged.
        """
        jwk_set: Any = None
        # Pass ``timeout`` only when configured: an explicit ``timeout=None``
        # would make the socket fully blocking instead of keeping urllib's
        # default (which honours ``socket.setdefaulttimeout``).
        open_kwargs = {} if self.timeout is None else {"timeout": self.timeout}
        try:
            with urllib.request.urlopen(self.uri, **open_kwargs) as response:
                jwk_set = json.load(response)
        except URLError as e:
            raise PyJWKClientError(
                f'Fail to fetch data from the url, err: "{e}"'
            ) from e
        else:
            return jwk_set
        finally:
            if self.jwk_set_cache is not None:
                self.jwk_set_cache.put(jwk_set)

    def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
        """Return the key set, from the cache when possible.

        :param refresh: if true, skip the cache and always fetch from ``uri``.
        """
        data = None
        if self.jwk_set_cache is not None and not refresh:
            data = self.jwk_set_cache.get()

        # Cache disabled, empty/expired, or refresh requested: go to the network.
        if data is None:
            data = self.fetch_data()

        return PyJWKSet.from_dict(data)

    def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
        """Return the keys usable for signature verification.

        A key qualifies when its ``use`` is ``"sig"`` or absent and it has a
        truthy ``kid``.

        :param refresh: forwarded to ``get_jwk_set``.
        :raises PyJWKClientError: if no key qualifies.
        """
        jwk_set = self.get_jwk_set(refresh)
        signing_keys = [
            jwk_set_key
            for jwk_set_key in jwk_set.keys
            if jwk_set_key.public_key_use in ["sig", None] and jwk_set_key.key_id
        ]

        if not signing_keys:
            raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")

        return signing_keys

    def get_signing_key(self, kid: str) -> PyJWK:
        """Return the signing key whose key id equals ``kid``.

        Looks in the current (possibly cached) key set first. On a miss, the key
        set is fetched again once with ``refresh=True`` and searched again.

        :raises PyJWKClientError: if no signing key matches ``kid``.
        """
        if not kid:
            # ``get_signing_keys`` drops every key without a truthy ``kid``, so
            # an empty/missing kid can never match. Fail fast instead of
            # forcing a key-set refetch for every such token.
            raise PyJWKClientError(
                f'Unable to find a signing key that matches: "{kid}"'
            )

        signing_keys = self.get_signing_keys()
        signing_key = self.match_kid(signing_keys, kid)

        if not signing_key:
            # If no matching signing key from the jwk set, refresh the jwk set and try again.
            signing_keys = self.get_signing_keys(refresh=True)
            signing_key = self.match_kid(signing_keys, kid)

            if not signing_key:
                raise PyJWKClientError(
                    f'Unable to find a signing key that matches: "{kid}"'
                )

        return signing_key

    def get_signing_key_from_jwt(self, token: str) -> PyJWK:
        """Return the signing key named by the ``kid`` header of ``token``.

        The token is decoded with ``verify_signature`` disabled, only to read
        its header; no verification happens here.
        """
        unverified = decode_token(token, options={"verify_signature": False})
        header = unverified["header"]
        return self.get_signing_key(header.get("kid"))

    @staticmethod
    def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]:
        """Return the first key in ``signing_keys`` whose ``key_id`` equals ``kid``, else ``None``."""
        return next((key for key in signing_keys if key.key_id == kid), None)