Initial commit

This commit is contained in:
2020-01-28 14:59:07 -06:00
commit 3bb2fdfad6
108 changed files with 24266 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
__version__ = "1.4.8"

Binary file not shown.

View File

@@ -0,0 +1,466 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import json
KEY_GROUP_LIST = "GGGroups"
KEY_GROUP_ID = "GGGroupId"
KEY_CORE_LIST = "Cores"
KEY_CORE_ARN = "thingArn"
KEY_CA_LIST = "CAs"
KEY_CONNECTIVITY_INFO_LIST = "Connectivity"
KEY_CONNECTIVITY_INFO_ID = "Id"
KEY_HOST_ADDRESS = "HostAddress"
KEY_PORT_NUMBER = "PortNumber"
KEY_METADATA = "Metadata"
class ConnectivityInfo(object):
"""
Class the stores one set of the connectivity information.
This is the data model for easy access to the discovery information from the discovery request function call. No
need to call directly from user scripts.
"""
def __init__(self, id, host, port, metadata):
self._id = id
self._host = host
self._port = port
self._metadata = metadata
@property
def id(self):
"""
Connectivity Information Id.
"""
return self._id
@property
def host(self):
"""
Host address.
"""
return self._host
@property
def port(self):
"""
Port number.
"""
return self._port
@property
def metadata(self):
"""
Metadata string.
"""
return self._metadata
class CoreConnectivityInfo(object):
"""
Class that stores the connectivity information for a Greengrass core.
This is the data model for easy access to the discovery information from the discovery request function call. No
need to call directly from user scripts.
"""
def __init__(self, coreThingArn, groupId):
self._core_thing_arn = coreThingArn
self._group_id = groupId
self._connectivity_info_dict = dict()
@property
def coreThingArn(self):
"""
Thing arn for this Greengrass core.
"""
return self._core_thing_arn
@property
def groupId(self):
"""
Greengrass group id that this Greengrass core belongs to.
"""
return self._group_id
@property
def connectivityInfoList(self):
"""
The list of connectivity information that this Greengrass core has.
"""
return list(self._connectivity_info_dict.values())
def getConnectivityInfo(self, id):
"""
**Description**
Used for quickly accessing a certain set of connectivity information by id.
**Syntax**
.. code:: python
myCoreConnectivityInfo.getConnectivityInfo("CoolId")
**Parameters**
*id* - The id for the desired connectivity information.
**Return**
:code:`AWSIoTPythonSDK.core.greengrass.discovery.models.ConnectivityInfo` object.
"""
return self._connectivity_info_dict.get(id)
def appendConnectivityInfo(self, connectivityInfo):
"""
**Description**
Used for adding a new set of connectivity information to the list for this Greengrass core. This is used by the
SDK internally. No need to call directly from user scripts.
**Syntax**
.. code:: python
myCoreConnectivityInfo.appendConnectivityInfo(newInfo)
**Parameters**
*connectivityInfo* - :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.ConnectivityInfo` object.
**Returns**
None
"""
self._connectivity_info_dict[connectivityInfo.id] = connectivityInfo
class GroupConnectivityInfo(object):
"""
Class that stores the connectivity information for a specific Greengrass group.
This is the data model for easy access to the discovery information from the discovery request function call. No
need to call directly from user scripts.
"""
def __init__(self, groupId):
self._group_id = groupId
self._core_connectivity_info_dict = dict()
self._ca_list = list()
@property
def groupId(self):
"""
Id for this Greengrass group.
"""
return self._group_id
@property
def coreConnectivityInfoList(self):
"""
A list of Greengrass cores
(:code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo` object) that belong to this
Greengrass group.
"""
return list(self._core_connectivity_info_dict.values())
@property
def caList(self):
"""
A list of CA content strings for this Greengrass group.
"""
return self._ca_list
def getCoreConnectivityInfo(self, coreThingArn):
"""
**Description**
Used to retrieve the corresponding :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo`
object by core thing arn.
**Syntax**
.. code:: python
myGroupConnectivityInfo.getCoreConnectivityInfo("YourOwnArnString")
**Parameters**
coreThingArn - Thing arn for the desired Greengrass core.
**Returns**
:code:`AWSIoTPythonSDK.core.greengrass.discovery.CoreConnectivityInfo` object.
"""
return self._core_connectivity_info_dict.get(coreThingArn)
def appendCoreConnectivityInfo(self, coreConnectivityInfo):
"""
**Description**
Used to append new core connectivity information to this group connectivity information. This is used by the
SDK internally. No need to call directly from user scripts.
**Syntax**
.. code:: python
myGroupConnectivityInfo.appendCoreConnectivityInfo(newCoreConnectivityInfo)
**Parameters**
*coreConnectivityInfo* - :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo` object.
**Returns**
None
"""
self._core_connectivity_info_dict[coreConnectivityInfo.coreThingArn] = coreConnectivityInfo
def appendCa(self, ca):
"""
**Description**
Used to append new CA content string to this group connectivity information. This is used by the SDK internally.
No need to call directly from user scripts.
**Syntax**
.. code:: python
myGroupConnectivityInfo.appendCa("CaContentString")
**Parameters**
*ca* - Group CA content string.
**Returns**
None
"""
self._ca_list.append(ca)
class DiscoveryInfo(object):
"""
Class that stores the discovery information coming back from the discovery request.
This is the data model for easy access to the discovery information from the discovery request function call. No
need to call directly from user scripts.
"""
def __init__(self, rawJson):
self._raw_json = rawJson
@property
def rawJson(self):
"""
JSON response string that contains the discovery information. This is reserved in case users want to do
some process by themselves.
"""
return self._raw_json
def getAllCores(self):
"""
**Description**
Used to retrieve the list of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo`
object for this discovery information. The retrieved cores could be from different Greengrass groups. This is
designed for uses who want to iterate through all available cores at the same time, regardless of which group
those cores are in.
**Syntax**
.. code:: python
myDiscoveryInfo.getAllCores()
**Parameters**
None
**Returns**
List of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivtyInfo` object.
"""
groups_list = self.getAllGroups()
core_list = list()
for group in groups_list:
core_list.extend(group.coreConnectivityInfoList)
return core_list
def getAllCas(self):
"""
**Description**
Used to retrieve the list of :code:`(groupId, caContent)` pair for this discovery information. The retrieved
pairs could be from different Greengrass groups. This is designed for users who want to iterate through all
available cores/groups/CAs at the same time, regardless of which group those CAs belong to.
**Syntax**
.. code:: python
myDiscoveryInfo.getAllCas()
**Parameters**
None
**Returns**
List of :code:`(groupId, caContent)` string pair, where :code:`caContent` is the CA content string and
:code:`groupId` is the group id that this CA belongs to.
"""
group_list = self.getAllGroups()
ca_list = list()
for group in group_list:
for ca in group.caList:
ca_list.append((group.groupId, ca))
return ca_list
def getAllGroups(self):
"""
**Description**
Used to retrieve the list of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.GroupConnectivityInfo`
object for this discovery information. This is designed for users who want to iterate through all available
groups that this Greengrass aware device (GGAD) belongs to.
**Syntax**
.. code:: python
myDiscoveryInfo.getAllGroups()
**Parameters**
None
**Returns**
List of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.GroupConnectivityInfo` object.
"""
groups_dict = self.toObjectAtGroupLevel()
return list(groups_dict.values())
def toObjectAtGroupLevel(self):
"""
**Description**
Used to get a dictionary of Greengrass group discovery information, with group id string as key and the
corresponding :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.GroupConnectivityInfo` object as the
value. This is designed for users who know exactly which group, which core and which set of connectivity info
they want to use for the Greengrass aware device to connect.
**Syntax**
.. code:: python
# Get to the targeted connectivity information for a specific core in a specific group
groupLevelDiscoveryInfoObj = myDiscoveryInfo.toObjectAtGroupLevel()
groupConnectivityInfoObj = groupLevelDiscoveryInfoObj.toObjectAtGroupLevel("IKnowMyGroupId")
coreConnectivityInfoObj = groupConnectivityInfoObj.getCoreConnectivityInfo("IKnowMyCoreThingArn")
connectivityInfo = coreConnectivityInfoObj.getConnectivityInfo("IKnowMyConnectivityInfoSetId")
# Now retrieve the detailed information
caList = groupConnectivityInfoObj.caList
host = connectivityInfo.host
port = connectivityInfo.port
metadata = connectivityInfo.metadata
# Actual connecting logic follows...
"""
groups_object = json.loads(self._raw_json)
groups_dict = dict()
for group_object in groups_object[KEY_GROUP_LIST]:
group_info = self._decode_group_info(group_object)
groups_dict[group_info.groupId] = group_info
return groups_dict
def _decode_group_info(self, group_object):
group_id = group_object[KEY_GROUP_ID]
group_info = GroupConnectivityInfo(group_id)
for core in group_object[KEY_CORE_LIST]:
core_info = self._decode_core_info(core, group_id)
group_info.appendCoreConnectivityInfo(core_info)
for ca in group_object[KEY_CA_LIST]:
group_info.appendCa(ca)
return group_info
def _decode_core_info(self, core_object, group_id):
core_info = CoreConnectivityInfo(core_object[KEY_CORE_ARN], group_id)
for connectivity_info_object in core_object[KEY_CONNECTIVITY_INFO_LIST]:
connectivity_info = ConnectivityInfo(connectivity_info_object[KEY_CONNECTIVITY_INFO_ID],
connectivity_info_object[KEY_HOST_ADDRESS],
connectivity_info_object[KEY_PORT_NUMBER],
connectivity_info_object.get(KEY_METADATA,''))
core_info.appendConnectivityInfo(connectivity_info)
return core_info

View File

@@ -0,0 +1,426 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryInvalidRequestException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryUnauthorizedException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryDataNotFoundException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryThrottlingException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryFailure
from AWSIoTPythonSDK.core.greengrass.discovery.models import DiscoveryInfo
from AWSIoTPythonSDK.core.protocol.connection.alpn import SSLContextBuilder
import re
import sys
import ssl
import time
import errno
import logging
import socket
import platform
if platform.system() == 'Windows':
EAGAIN = errno.WSAEWOULDBLOCK
else:
EAGAIN = errno.EAGAIN
class DiscoveryInfoProvider(object):
REQUEST_TYPE_PREFIX = "GET "
PAYLOAD_PREFIX = "/greengrass/discover/thing/"
PAYLOAD_SUFFIX = " HTTP/1.1\r\n" # Space in the front
HOST_PREFIX = "Host: "
HOST_SUFFIX = "\r\n\r\n"
HTTP_PROTOCOL = r"HTTP/1.1 "
CONTENT_LENGTH = r"content-length: "
CONTENT_LENGTH_PATTERN = CONTENT_LENGTH + r"([0-9]+)\r\n"
HTTP_RESPONSE_CODE_PATTERN = HTTP_PROTOCOL + r"([0-9]+) "
HTTP_SC_200 = "200"
HTTP_SC_400 = "400"
HTTP_SC_401 = "401"
HTTP_SC_404 = "404"
HTTP_SC_429 = "429"
LOW_LEVEL_RC_COMPLETE = 0
LOW_LEVEL_RC_TIMEOUT = -1
_logger = logging.getLogger(__name__)
def __init__(self, caPath="", certPath="", keyPath="", host="", port=8443, timeoutSec=120):
"""
The class that provides functionality to perform a Greengrass discovery process to the cloud.
Users can perform Greengrass discovery process for a specific Greengrass aware device to retrieve
connectivity/identity information of Greengrass cores within the same group.
**Syntax**
.. code:: python
from AWSIoTPythonSDK.core.greengrass.discovery.providers import DiscoveryInfoProvider
# Create a discovery information provider
myDiscoveryInfoProvider = DiscoveryInfoProvider()
# Create a discovery information provider with custom configuration
myDiscoveryInfoProvider = DiscoveryInfoProvider(caPath=myCAPath, certPath=myCertPath, keyPath=myKeyPath, host=myHost, timeoutSec=myTimeoutSec)
**Parameters**
*caPath* - Path to read the root CA file.
*certPath* - Path to read the certificate file.
*keyPath* - Path to read the private key file.
*host* - String that denotes the host name of the user-specific AWS IoT endpoint.
*port* - Integer that denotes the port number to connect to. For discovery purpose, it is 8443 by default.
*timeoutSec* - Time out configuration in seconds to consider a discovery request sending/response waiting has
been timed out.
**Returns**
AWSIoTPythonSDK.core.greengrass.discovery.providers.DiscoveryInfoProvider object
"""
self._ca_path = caPath
self._cert_path = certPath
self._key_path = keyPath
self._host = host
self._port = port
self._timeout_sec = timeoutSec
self._expected_exception_map = {
self.HTTP_SC_400 : DiscoveryInvalidRequestException(),
self.HTTP_SC_401 : DiscoveryUnauthorizedException(),
self.HTTP_SC_404 : DiscoveryDataNotFoundException(),
self.HTTP_SC_429 : DiscoveryThrottlingException()
}
def configureEndpoint(self, host, port=8443):
"""
**Description**
Used to configure the host address and port number for the discovery request to hit. Should be called before
the discovery request happens.
**Syntax**
.. code:: python
# Using default port configuration, 8443
myDiscoveryInfoProvider.configureEndpoint(host="prefix.iot.us-east-1.amazonaws.com")
# Customize port configuration
myDiscoveryInfoProvider.configureEndpoint(host="prefix.iot.us-east-1.amazonaws.com", port=8888)
**Parameters**
*host* - String that denotes the host name of the user-specific AWS IoT endpoint.
*port* - Integer that denotes the port number to connect to. For discovery purpose, it is 8443 by default.
**Returns**
None
"""
self._host = host
self._port = port
def configureCredentials(self, caPath, certPath, keyPath):
"""
**Description**
Used to configure the credentials for discovery request. Should be called before the discovery request happens.
**Syntax**
.. code:: python
myDiscoveryInfoProvider.configureCredentials("my/ca/path", "my/cert/path", "my/key/path")
**Parameters**
*caPath* - Path to read the root CA file.
*certPath* - Path to read the certificate file.
*keyPath* - Path to read the private key file.
**Returns**
None
"""
self._ca_path = caPath
self._cert_path = certPath
self._key_path = keyPath
def configureTimeout(self, timeoutSec):
"""
**Description**
Used to configure the time out in seconds for discovery request sending/response waiting. Should be called before
the discovery request happens.
**Syntax**
.. code:: python
# Configure the time out for discovery to be 10 seconds
myDiscoveryInfoProvider.configureTimeout(10)
**Parameters**
*timeoutSec* - Time out configuration in seconds to consider a discovery request sending/response waiting has
been timed out.
**Returns**
None
"""
self._timeout_sec = timeoutSec
def discover(self, thingName):
"""
**Description**
Perform the discovery request for the given Greengrass aware device thing name.
**Syntax**
.. code:: python
myDiscoveryInfoProvider.discover(thingName="myGGAD")
**Parameters**
*thingName* - Greengrass aware device thing name.
**Returns**
:code:`AWSIoTPythonSDK.core.greengrass.discovery.models.DiscoveryInfo` object.
"""
self._logger.info("Starting discover request...")
self._logger.info("Endpoint: " + self._host + ":" + str(self._port))
self._logger.info("Target thing: " + thingName)
sock = self._create_tcp_connection()
ssl_sock = self._create_ssl_connection(sock)
self._raise_on_timeout(self._send_discovery_request(ssl_sock, thingName))
status_code, response_body = self._receive_discovery_response(ssl_sock)
return self._raise_if_not_200(status_code, response_body)
def _create_tcp_connection(self):
self._logger.debug("Creating tcp connection...")
try:
if (sys.version_info[0] == 2 and sys.version_info[1] < 7) or (sys.version_info[0] == 3 and sys.version_info[1] < 2):
sock = socket.create_connection((self._host, self._port))
else:
sock = socket.create_connection((self._host, self._port), source_address=("", 0))
return sock
except socket.error as err:
if err.errno != errno.EINPROGRESS and err.errno != errno.EWOULDBLOCK and err.errno != EAGAIN:
raise
self._logger.debug("Created tcp connection.")
def _create_ssl_connection(self, sock):
self._logger.debug("Creating ssl connection...")
ssl_protocol_version = ssl.PROTOCOL_SSLv23
if self._port == 443:
ssl_context = SSLContextBuilder()\
.with_ca_certs(self._ca_path)\
.with_cert_key_pair(self._cert_path, self._key_path)\
.with_cert_reqs(ssl.CERT_REQUIRED)\
.with_check_hostname(True)\
.with_ciphers(None)\
.with_alpn_protocols(['x-amzn-http-ca'])\
.build()
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=self._host, do_handshake_on_connect=False)
ssl_sock.do_handshake()
else:
ssl_sock = ssl.wrap_socket(sock,
certfile=self._cert_path,
keyfile=self._key_path,
ca_certs=self._ca_path,
cert_reqs=ssl.CERT_REQUIRED,
ssl_version=ssl_protocol_version)
self._logger.debug("Matching host name...")
if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 2):
self._tls_match_hostname(ssl_sock)
else:
ssl.match_hostname(ssl_sock.getpeercert(), self._host)
return ssl_sock
def _tls_match_hostname(self, ssl_sock):
try:
cert = ssl_sock.getpeercert()
except AttributeError:
# the getpeercert can throw Attribute error: object has no attribute 'peer_certificate'
# Don't let that crash the whole client. See also: http://bugs.python.org/issue13721
raise ssl.SSLError('Not connected')
san = cert.get('subjectAltName')
if san:
have_san_dns = False
for (key, value) in san:
if key == 'DNS':
have_san_dns = True
if self._host_matches_cert(self._host.lower(), value.lower()) == True:
return
if key == 'IP Address':
have_san_dns = True
if value.lower() == self._host.lower():
return
if have_san_dns:
# Only check subject if subjectAltName dns not found.
raise ssl.SSLError('Certificate subject does not match remote hostname.')
subject = cert.get('subject')
if subject:
for ((key, value),) in subject:
if key == 'commonName':
if self._host_matches_cert(self._host.lower(), value.lower()) == True:
return
raise ssl.SSLError('Certificate subject does not match remote hostname.')
def _host_matches_cert(self, host, cert_host):
if cert_host[0:2] == "*.":
if cert_host.count("*") != 1:
return False
host_match = host.split(".", 1)[1]
cert_match = cert_host.split(".", 1)[1]
if host_match == cert_match:
return True
else:
return False
else:
if host == cert_host:
return True
else:
return False
def _send_discovery_request(self, ssl_sock, thing_name):
request = self.REQUEST_TYPE_PREFIX + \
self.PAYLOAD_PREFIX + \
thing_name + \
self.PAYLOAD_SUFFIX + \
self.HOST_PREFIX + \
self._host + ":" + str(self._port) + \
self.HOST_SUFFIX
self._logger.debug("Sending discover request: " + request)
start_time = time.time()
desired_length_to_write = len(request)
actual_length_written = 0
while True:
try:
length_written = ssl_sock.write(request.encode("utf-8"))
actual_length_written += length_written
except socket.error as err:
if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE:
pass
if actual_length_written == desired_length_to_write:
return self.LOW_LEVEL_RC_COMPLETE
if start_time + self._timeout_sec < time.time():
return self.LOW_LEVEL_RC_TIMEOUT
def _receive_discovery_response(self, ssl_sock):
self._logger.debug("Receiving discover response header...")
rc1, response_header = self._receive_until(ssl_sock, self._got_two_crlfs)
status_code, body_length = self._handle_discovery_response_header(rc1, response_header.decode("utf-8"))
self._logger.debug("Receiving discover response body...")
rc2, response_body = self._receive_until(ssl_sock, self._got_enough_bytes, body_length)
response_body = self._handle_discovery_response_body(rc2, response_body.decode("utf-8"))
return status_code, response_body
def _receive_until(self, ssl_sock, criteria_function, extra_data=None):
start_time = time.time()
response = bytearray()
number_bytes_read = 0
while True: # Python does not have do-while
try:
response.append(self._convert_to_int_py3(ssl_sock.read(1)))
number_bytes_read += 1
except socket.error as err:
if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE:
pass
if criteria_function((number_bytes_read, response, extra_data)):
return self.LOW_LEVEL_RC_COMPLETE, response
if start_time + self._timeout_sec < time.time():
return self.LOW_LEVEL_RC_TIMEOUT, response
def _convert_to_int_py3(self, input_char):
try:
return ord(input_char)
except:
return input_char
def _got_enough_bytes(self, data):
number_bytes_read, response, target_length = data
return number_bytes_read == int(target_length)
def _got_two_crlfs(self, data):
number_bytes_read, response, extra_data_unused = data
number_of_crlf = 2
has_enough_bytes = number_bytes_read > number_of_crlf * 2 - 1
if has_enough_bytes:
end_of_received = response[number_bytes_read - number_of_crlf * 2 : number_bytes_read]
expected_end_of_response = b"\r\n" * number_of_crlf
return end_of_received == expected_end_of_response
else:
return False
def _handle_discovery_response_header(self, rc, response):
self._raise_on_timeout(rc)
http_status_code_matcher = re.compile(self.HTTP_RESPONSE_CODE_PATTERN)
http_status_code_matched_groups = http_status_code_matcher.match(response)
content_length_matcher = re.compile(self.CONTENT_LENGTH_PATTERN)
content_length_matched_groups = content_length_matcher.search(response)
return http_status_code_matched_groups.group(1), content_length_matched_groups.group(1)
def _handle_discovery_response_body(self, rc, response):
self._raise_on_timeout(rc)
return response
def _raise_on_timeout(self, rc):
if rc == self.LOW_LEVEL_RC_TIMEOUT:
raise DiscoveryTimeoutException()
def _raise_if_not_200(self, status_code, response_body): # response_body here is str in Py3
if status_code != self.HTTP_SC_200:
expected_exception = self._expected_exception_map.get(status_code)
if expected_exception:
raise expected_exception
else:
raise DiscoveryFailure(response_body)
return DiscoveryInfo(response_body)

View File

@@ -0,0 +1,156 @@
# /*
# * Copyright 2010-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import json
_BASE_THINGS_TOPIC = "$aws/things/"
_NOTIFY_OPERATION = "notify"
_NOTIFY_NEXT_OPERATION = "notify-next"
_GET_OPERATION = "get"
_START_NEXT_OPERATION = "start-next"
_WILDCARD_OPERATION = "+"
_UPDATE_OPERATION = "update"
_ACCEPTED_REPLY = "accepted"
_REJECTED_REPLY = "rejected"
_WILDCARD_REPLY = "#"
#Members of this enum are tuples
_JOB_ID_REQUIRED_INDEX = 1
_JOB_OPERATION_INDEX = 2
_STATUS_KEY = 'status'
_STATUS_DETAILS_KEY = 'statusDetails'
_EXPECTED_VERSION_KEY = 'expectedVersion'
_EXEXCUTION_NUMBER_KEY = 'executionNumber'
_INCLUDE_JOB_EXECUTION_STATE_KEY = 'includeJobExecutionState'
_INCLUDE_JOB_DOCUMENT_KEY = 'includeJobDocument'
_CLIENT_TOKEN_KEY = 'clientToken'
_STEP_TIMEOUT_IN_MINUTES_KEY = 'stepTimeoutInMinutes'
#The type of job topic.
class jobExecutionTopicType(object):
JOB_UNRECOGNIZED_TOPIC = (0, False, '')
JOB_GET_PENDING_TOPIC = (1, False, _GET_OPERATION)
JOB_START_NEXT_TOPIC = (2, False, _START_NEXT_OPERATION)
JOB_DESCRIBE_TOPIC = (3, True, _GET_OPERATION)
JOB_UPDATE_TOPIC = (4, True, _UPDATE_OPERATION)
JOB_NOTIFY_TOPIC = (5, False, _NOTIFY_OPERATION)
JOB_NOTIFY_NEXT_TOPIC = (6, False, _NOTIFY_NEXT_OPERATION)
JOB_WILDCARD_TOPIC = (7, False, _WILDCARD_OPERATION)
#Members of this enum are tuples
_JOB_SUFFIX_INDEX = 1
#The type of reply topic, or #JOB_REQUEST_TYPE for topics that are not replies.
class jobExecutionTopicReplyType(object):
JOB_UNRECOGNIZED_TOPIC_TYPE = (0, '')
JOB_REQUEST_TYPE = (1, '')
JOB_ACCEPTED_REPLY_TYPE = (2, '/' + _ACCEPTED_REPLY)
JOB_REJECTED_REPLY_TYPE = (3, '/' + _REJECTED_REPLY)
JOB_WILDCARD_REPLY_TYPE = (4, '/' + _WILDCARD_REPLY)
_JOB_STATUS_INDEX = 1
class jobExecutionStatus(object):
JOB_EXECUTION_STATUS_NOT_SET = (0, None)
JOB_EXECUTION_QUEUED = (1, 'QUEUED')
JOB_EXECUTION_IN_PROGRESS = (2, 'IN_PROGRESS')
JOB_EXECUTION_FAILED = (3, 'FAILED')
JOB_EXECUTION_SUCCEEDED = (4, 'SUCCEEDED')
JOB_EXECUTION_CANCELED = (5, 'CANCELED')
JOB_EXECUTION_REJECTED = (6, 'REJECTED')
JOB_EXECUTION_UNKNOWN_STATUS = (99, None)
def _getExecutionStatus(jobStatus):
try:
return jobStatus[_JOB_STATUS_INDEX]
except KeyError:
return None
def _isWithoutJobIdTopicType(srcJobExecTopicType):
return (srcJobExecTopicType == jobExecutionTopicType.JOB_GET_PENDING_TOPIC or srcJobExecTopicType == jobExecutionTopicType.JOB_START_NEXT_TOPIC
or srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_TOPIC or srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC)
class thingJobManager:
def __init__(self, thingName, clientToken = None):
self._thingName = thingName
self._clientToken = clientToken
def getJobTopic(self, srcJobExecTopicType, srcJobExecTopicReplyType=jobExecutionTopicReplyType.JOB_REQUEST_TYPE, jobId=None):
if self._thingName is None:
return None
#Verify topics that only support request type, actually have request type specified for reply
if (srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_TOPIC or srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) and srcJobExecTopicReplyType != jobExecutionTopicReplyType.JOB_REQUEST_TYPE:
return None
#Verify topics that explicitly do not want a job ID do not have one specified
if (jobId is not None and _isWithoutJobIdTopicType(srcJobExecTopicType)):
return None
#Verify job ID is present if the topic requires one
if jobId is None and srcJobExecTopicType[_JOB_ID_REQUIRED_INDEX]:
return None
#Ensure the job operation is a non-empty string
if srcJobExecTopicType[_JOB_OPERATION_INDEX] == '':
return None
if srcJobExecTopicType[_JOB_ID_REQUIRED_INDEX]:
return '{0}{1}/jobs/{2}/{3}{4}'.format(_BASE_THINGS_TOPIC, self._thingName, str(jobId), srcJobExecTopicType[_JOB_OPERATION_INDEX], srcJobExecTopicReplyType[_JOB_SUFFIX_INDEX])
elif srcJobExecTopicType == jobExecutionTopicType.JOB_WILDCARD_TOPIC:
return '{0}{1}/jobs/#'.format(_BASE_THINGS_TOPIC, self._thingName)
else:
return '{0}{1}/jobs/{2}{3}'.format(_BASE_THINGS_TOPIC, self._thingName, srcJobExecTopicType[_JOB_OPERATION_INDEX], srcJobExecTopicReplyType[_JOB_SUFFIX_INDEX])
def serializeJobExecutionUpdatePayload(self, status, statusDetails=None, expectedVersion=0, executionNumber=0, includeJobExecutionState=False, includeJobDocument=False, stepTimeoutInMinutes=None):
executionStatus = _getExecutionStatus(status)
if executionStatus is None:
return None
payload = {_STATUS_KEY: executionStatus}
if statusDetails:
payload[_STATUS_DETAILS_KEY] = statusDetails
if expectedVersion > 0:
payload[_EXPECTED_VERSION_KEY] = str(expectedVersion)
if executionNumber > 0:
payload[_EXEXCUTION_NUMBER_KEY] = str(executionNumber)
if includeJobExecutionState:
payload[_INCLUDE_JOB_EXECUTION_STATE_KEY] = True
if includeJobDocument:
payload[_INCLUDE_JOB_DOCUMENT_KEY] = True
if self._clientToken is not None:
payload[_CLIENT_TOKEN_KEY] = self._clientToken
if stepTimeoutInMinutes is not None:
payload[_STEP_TIMEOUT_IN_MINUTES_KEY] = stepTimeoutInMinutes
return json.dumps(payload)
def serializeDescribeJobExecutionPayload(self, executionNumber=0, includeJobDocument=True):
payload = {_INCLUDE_JOB_DOCUMENT_KEY: includeJobDocument}
if executionNumber > 0:
payload[_EXEXCUTION_NUMBER_KEY] = executionNumber
if self._clientToken is not None:
payload[_CLIENT_TOKEN_KEY] = self._clientToken
return json.dumps(payload)
def serializeStartNextPendingJobExecutionPayload(self, statusDetails=None, stepTimeoutInMinutes=None):
payload = {}
if self._clientToken is not None:
payload[_CLIENT_TOKEN_KEY] = self._clientToken
if statusDetails is not None:
payload[_STATUS_DETAILS_KEY] = statusDetails
if stepTimeoutInMinutes is not None:
payload[_STEP_TIMEOUT_IN_MINUTES_KEY] = stepTimeoutInMinutes
return json.dumps(payload)
def serializeClientTokenPayload(self):
return json.dumps({_CLIENT_TOKEN_KEY: self._clientToken}) if self._clientToken is not None else '{}'

View File

@@ -0,0 +1,63 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
try:
import ssl
except:
ssl = None
class SSLContextBuilder(object):
def __init__(self):
self.check_supportability()
self._ssl_context = ssl.create_default_context()
def check_supportability(self):
if ssl is None:
raise RuntimeError("This platform has no SSL/TLS.")
if not hasattr(ssl, "SSLContext"):
raise NotImplementedError("This platform does not support SSLContext. Python 2.7.10+/3.5+ is required.")
if not hasattr(ssl.SSLContext, "set_alpn_protocols"):
raise NotImplementedError("This platform does not support ALPN as TLS extensions. Python 2.7.10+/3.5+ is required.")
def with_ca_certs(self, ca_certs):
self._ssl_context.load_verify_locations(ca_certs)
return self
def with_cert_key_pair(self, cert_file, key_file):
self._ssl_context.load_cert_chain(cert_file, key_file)
return self
def with_cert_reqs(self, cert_reqs):
self._ssl_context.verify_mode = cert_reqs
return self
def with_check_hostname(self, check_hostname):
self._ssl_context.check_hostname = check_hostname
return self
def with_ciphers(self, ciphers):
if ciphers is not None:
self._ssl_context.set_ciphers(ciphers) # set_ciphers() does not allow None input. Use default (do nothing) if None
return self
def with_alpn_protocols(self, alpn_protocols):
self._ssl_context.set_alpn_protocols(alpn_protocols)
return self
def build(self):
return self._ssl_context

View File

@@ -0,0 +1,699 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
# This class implements the progressive backoff logic for auto-reconnect.
# It manages the reconnect wait time for the current reconnect, controling
# when to increase it and when to reset it.
import re
import sys
import ssl
import errno
import struct
import socket
import base64
import time
import threading
import logging
import os
from datetime import datetime
import hashlib
import hmac
from AWSIoTPythonSDK.exception.AWSIoTExceptions import ClientError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import wssNoKeyInEnvironmentError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import wssHandShakeError
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
try:
from urllib.parse import quote # Python 3+
except ImportError:
from urllib import quote
# INI config file handling
try:
from configparser import ConfigParser # Python 3+
from configparser import NoOptionError
from configparser import NoSectionError
except ImportError:
from ConfigParser import ConfigParser
from ConfigParser import NoOptionError
from ConfigParser import NoSectionError
class ProgressiveBackOffCore:
# Logger
_logger = logging.getLogger(__name__)
def __init__(self, srcBaseReconnectTimeSecond=1, srcMaximumReconnectTimeSecond=32, srcMinimumConnectTimeSecond=20):
# The base reconnection time in seconds, default 1
self._baseReconnectTimeSecond = srcBaseReconnectTimeSecond
# The maximum reconnection time in seconds, default 32
self._maximumReconnectTimeSecond = srcMaximumReconnectTimeSecond
# The minimum time in milliseconds that a connection must be maintained in order to be considered stable
# Default 20
self._minimumConnectTimeSecond = srcMinimumConnectTimeSecond
# Current backOff time in seconds, init to equal to 0
self._currentBackoffTimeSecond = 1
# Handler for timer
self._resetBackoffTimer = None
# For custom progressiveBackoff timing configuration
def configTime(self, srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond):
if srcBaseReconnectTimeSecond < 0 or srcMaximumReconnectTimeSecond < 0 or srcMinimumConnectTimeSecond < 0:
self._logger.error("init: Negative time configuration detected.")
raise ValueError("Negative time configuration detected.")
if srcBaseReconnectTimeSecond >= srcMinimumConnectTimeSecond:
self._logger.error("init: Min connect time should be bigger than base reconnect time.")
raise ValueError("Min connect time should be bigger than base reconnect time.")
self._baseReconnectTimeSecond = srcBaseReconnectTimeSecond
self._maximumReconnectTimeSecond = srcMaximumReconnectTimeSecond
self._minimumConnectTimeSecond = srcMinimumConnectTimeSecond
self._currentBackoffTimeSecond = 1
# Block the reconnect logic for _currentBackoffTimeSecond
# Update the currentBackoffTimeSecond for the next reconnect
# Cancel the in-waiting timer for resetting backOff time
# This should get called only when a disconnect/reconnect happens
def backOff(self):
self._logger.debug("backOff: current backoff time is: " + str(self._currentBackoffTimeSecond) + " sec.")
if self._resetBackoffTimer is not None:
# Cancel the timer
self._resetBackoffTimer.cancel()
# Block the reconnect logic
time.sleep(self._currentBackoffTimeSecond)
# Update the backoff time
if self._currentBackoffTimeSecond == 0:
# This is the first attempt to connect, set it to base
self._currentBackoffTimeSecond = self._baseReconnectTimeSecond
else:
# r_cur = min(2^n*r_base, r_max)
self._currentBackoffTimeSecond = min(self._maximumReconnectTimeSecond, self._currentBackoffTimeSecond * 2)
# Start the timer for resetting _currentBackoffTimeSecond
# Will be cancelled upon calling backOff
def startStableConnectionTimer(self):
self._resetBackoffTimer = threading.Timer(self._minimumConnectTimeSecond,
self._connectionStableThenResetBackoffTime)
self._resetBackoffTimer.start()
def stopStableConnectionTimer(self):
if self._resetBackoffTimer is not None:
# Cancel the timer
self._resetBackoffTimer.cancel()
# Timer callback to reset _currentBackoffTimeSecond
# If the connection is stable for longer than _minimumConnectTimeSecond,
# reset the currentBackoffTimeSecond to _baseReconnectTimeSecond
def _connectionStableThenResetBackoffTime(self):
self._logger.debug(
"stableConnection: Resetting the backoff time to: " + str(self._baseReconnectTimeSecond) + " sec.")
self._currentBackoffTimeSecond = self._baseReconnectTimeSecond
class SigV4Core:
_logger = logging.getLogger(__name__)
def __init__(self):
self._aws_access_key_id = ""
self._aws_secret_access_key = ""
self._aws_session_token = ""
self._credentialConfigFilePath = "~/.aws/credentials"
def setIAMCredentials(self, srcAWSAccessKeyID, srcAWSSecretAccessKey, srcAWSSessionToken):
self._aws_access_key_id = srcAWSAccessKeyID
self._aws_secret_access_key = srcAWSSecretAccessKey
self._aws_session_token = srcAWSSessionToken
def _createAmazonDate(self):
# Returned as a unicode string in Py3.x
amazonDate = []
currentTime = datetime.utcnow()
YMDHMS = currentTime.strftime('%Y%m%dT%H%M%SZ')
YMD = YMDHMS[0:YMDHMS.index('T')]
amazonDate.append(YMD)
amazonDate.append(YMDHMS)
return amazonDate
def _sign(self, key, message):
# Returned as a utf-8 byte string in Py3.x
return hmac.new(key, message.encode('utf-8'), hashlib.sha256).digest()
def _getSignatureKey(self, key, dateStamp, regionName, serviceName):
# Returned as a utf-8 byte string in Py3.x
kDate = self._sign(('AWS4' + key).encode('utf-8'), dateStamp)
kRegion = self._sign(kDate, regionName)
kService = self._sign(kRegion, serviceName)
kSigning = self._sign(kService, 'aws4_request')
return kSigning
def _checkIAMCredentials(self):
# Check custom config
ret = self._checkKeyInCustomConfig()
# Check environment variables
if not ret:
ret = self._checkKeyInEnv()
# Check files
if not ret:
ret = self._checkKeyInFiles()
# All credentials returned as unicode strings in Py3.x
return ret
def _checkKeyInEnv(self):
ret = dict()
self._aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')
self._aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
self._aws_session_token = os.environ.get('AWS_SESSION_TOKEN')
if self._aws_access_key_id is not None and self._aws_secret_access_key is not None:
ret["aws_access_key_id"] = self._aws_access_key_id
ret["aws_secret_access_key"] = self._aws_secret_access_key
# We do not necessarily need session token...
if self._aws_session_token is not None:
ret["aws_session_token"] = self._aws_session_token
self._logger.debug("IAM credentials from env var.")
return ret
def _checkKeyInINIDefault(self, srcConfigParser, sectionName):
ret = dict()
# Check aws_access_key_id and aws_secret_access_key
try:
ret["aws_access_key_id"] = srcConfigParser.get(sectionName, "aws_access_key_id")
ret["aws_secret_access_key"] = srcConfigParser.get(sectionName, "aws_secret_access_key")
except NoOptionError:
self._logger.warn("Cannot find IAM keyID/secretKey in credential file.")
# We do not continue searching if we cannot even get IAM id/secret right
if len(ret) == 2:
# Check aws_session_token, optional
try:
ret["aws_session_token"] = srcConfigParser.get(sectionName, "aws_session_token")
except NoOptionError:
self._logger.debug("No AWS Session Token found.")
return ret
def _checkKeyInFiles(self):
credentialFile = None
credentialConfig = None
ret = dict()
# Should be compatible with aws cli default credential configuration
# *NIX/Windows
try:
# See if we get the file
credentialConfig = ConfigParser()
credentialFilePath = os.path.expanduser(self._credentialConfigFilePath) # Is it compatible with windows? \/
credentialConfig.read(credentialFilePath)
# Now we have the file, start looking for credentials...
# 'default' section
ret = self._checkKeyInINIDefault(credentialConfig, "default")
if not ret:
# 'DEFAULT' section
ret = self._checkKeyInINIDefault(credentialConfig, "DEFAULT")
self._logger.debug("IAM credentials from file.")
except IOError:
self._logger.debug("No IAM credential configuration file in " + credentialFilePath)
except NoSectionError:
self._logger.error("Cannot find IAM 'default' section.")
return ret
def _checkKeyInCustomConfig(self):
ret = dict()
if self._aws_access_key_id != "" and self._aws_secret_access_key != "":
ret["aws_access_key_id"] = self._aws_access_key_id
ret["aws_secret_access_key"] = self._aws_secret_access_key
# We do not necessarily need session token...
if self._aws_session_token != "":
ret["aws_session_token"] = self._aws_session_token
self._logger.debug("IAM credentials from custom config.")
return ret
def createWebsocketEndpoint(self, host, port, region, method, awsServiceName, path):
# Return the endpoint as unicode string in 3.x
# Gather all the facts
amazonDate = self._createAmazonDate()
amazonDateSimple = amazonDate[0] # Unicode in 3.x
amazonDateComplex = amazonDate[1] # Unicode in 3.x
allKeys = self._checkIAMCredentials() # Unicode in 3.x
if not self._hasCredentialsNecessaryForWebsocket(allKeys):
raise wssNoKeyInEnvironmentError()
else:
# Because of self._hasCredentialsNecessaryForWebsocket(...), keyID and secretKey should not be None from here
keyID = allKeys["aws_access_key_id"]
secretKey = allKeys["aws_secret_access_key"]
# amazonDateSimple and amazonDateComplex are guaranteed not to be None
queryParameters = "X-Amz-Algorithm=AWS4-HMAC-SHA256" + \
"&X-Amz-Credential=" + keyID + "%2F" + amazonDateSimple + "%2F" + region + "%2F" + awsServiceName + "%2Faws4_request" + \
"&X-Amz-Date=" + amazonDateComplex + \
"&X-Amz-Expires=86400" + \
"&X-Amz-SignedHeaders=host" # Unicode in 3.x
hashedPayload = hashlib.sha256(str("").encode('utf-8')).hexdigest() # Unicode in 3.x
# Create the string to sign
signedHeaders = "host"
canonicalHeaders = "host:" + host + "\n"
canonicalRequest = method + "\n" + path + "\n" + queryParameters + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedPayload # Unicode in 3.x
hashedCanonicalRequest = hashlib.sha256(str(canonicalRequest).encode('utf-8')).hexdigest() # Unicoede in 3.x
stringToSign = "AWS4-HMAC-SHA256\n" + amazonDateComplex + "\n" + amazonDateSimple + "/" + region + "/" + awsServiceName + "/aws4_request\n" + hashedCanonicalRequest # Unicode in 3.x
# Sign it
signingKey = self._getSignatureKey(secretKey, amazonDateSimple, region, awsServiceName)
signature = hmac.new(signingKey, (stringToSign).encode("utf-8"), hashlib.sha256).hexdigest()
# generate url
url = "wss://" + host + ":" + str(port) + path + '?' + queryParameters + "&X-Amz-Signature=" + signature
# See if we have STS token, if we do, add it
awsSessionTokenCandidate = allKeys.get("aws_session_token")
if awsSessionTokenCandidate is not None and len(awsSessionTokenCandidate) != 0:
aws_session_token = allKeys["aws_session_token"]
url += "&X-Amz-Security-Token=" + quote(aws_session_token.encode("utf-8")) # Unicode in 3.x
self._logger.debug("createWebsocketEndpoint: Websocket URL: " + url)
return url
def _hasCredentialsNecessaryForWebsocket(self, allKeys):
awsAccessKeyIdCandidate = allKeys.get("aws_access_key_id")
awsSecretAccessKeyCandidate = allKeys.get("aws_secret_access_key")
# None value is NOT considered as valid entries
validEntries = awsAccessKeyIdCandidate is not None and awsAccessKeyIdCandidate is not None
if validEntries:
# Empty value is NOT considered as valid entries
validEntries &= (len(awsAccessKeyIdCandidate) != 0 and len(awsSecretAccessKeyCandidate) != 0)
return validEntries
# This is an internal class that buffers the incoming bytes into an
# internal buffer until it gets the full desired length of bytes.
# At that time, this bufferedReader will be reset.
# *Error handling:
# For retry errors (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, EAGAIN),
# leave them to the paho _packet_read for further handling (ignored and try
# again when data is available.
# For other errors, leave them to the paho _packet_read for error reporting.
class _BufferedReader:
_sslSocket = None
_internalBuffer = None
_remainedLength = -1
_bufferingInProgress = False
def __init__(self, sslSocket):
self._sslSocket = sslSocket
self._internalBuffer = bytearray()
self._bufferingInProgress = False
def _reset(self):
self._internalBuffer = bytearray()
self._remainedLength = -1
self._bufferingInProgress = False
def read(self, numberOfBytesToBeBuffered):
if not self._bufferingInProgress: # If last read is completed...
self._remainedLength = numberOfBytesToBeBuffered
self._bufferingInProgress = True # Now we start buffering a new length of bytes
while self._remainedLength > 0: # Read in a loop, always try to read in the remained length
# If the data is temporarily not available, socket.error will be raised and catched by paho
dataChunk = self._sslSocket.read(self._remainedLength)
# There is a chance where the server terminates the connection without closing the socket.
# If that happens, let's raise an exception and enter the reconnect flow.
if not dataChunk:
raise socket.error(errno.ECONNABORTED, 0)
self._internalBuffer.extend(dataChunk) # Buffer the data
self._remainedLength -= len(dataChunk) # Update the remained length
# The requested length of bytes is buffered, recover the context and return it
# Otherwise error should be raised
ret = self._internalBuffer
self._reset()
return ret # This should always be bytearray
# This is the internal class that sends requested data out chunk by chunk according
# to the availablity of the socket write operation. If the requested bytes of data
# (after encoding) needs to be sent out in separate socket write operations (most
# probably be interrupted by the error socket.error (errno = ssl.SSL_ERROR_WANT_WRITE).)
# , the write pointer is stored to ensure that the continued bytes will be sent next
# time this function gets called.
# *Error handling:
# For retry errors (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, EAGAIN),
# leave them to the paho _packet_read for further handling (ignored and try
# again when data is available.
# For other errors, leave them to the paho _packet_read for error reporting.
class _BufferedWriter:
_sslSocket = None
_internalBuffer = None
_writingInProgress = False
_requestedDataLength = -1
def __init__(self, sslSocket):
self._sslSocket = sslSocket
self._internalBuffer = bytearray()
self._writingInProgress = False
self._requestedDataLength = -1
def _reset(self):
self._internalBuffer = bytearray()
self._writingInProgress = False
self._requestedDataLength = -1
# Input data for this function needs to be an encoded wss frame
# Always request for packet[pos=0:] (raw MQTT data)
def write(self, encodedData, payloadLength):
# encodedData should always be bytearray
# Check if we have a frame that is partially sent
if not self._writingInProgress:
self._internalBuffer = encodedData
self._writingInProgress = True
self._requestedDataLength = payloadLength
# Now, write as much as we can
lengthWritten = self._sslSocket.write(self._internalBuffer)
self._internalBuffer = self._internalBuffer[lengthWritten:]
# This MQTT packet has been sent out in a wss frame, completely
if len(self._internalBuffer) == 0:
ret = self._requestedDataLength
self._reset()
return ret
# This socket write is half-baked...
else:
return 0 # Ensure that the 'pos' inside the MQTT packet never moves since we have not finished the transmission of this encoded frame
class SecuredWebSocketCore:
# Websocket Constants
_OP_CONTINUATION = 0x0
_OP_TEXT = 0x1
_OP_BINARY = 0x2
_OP_CONNECTION_CLOSE = 0x8
_OP_PING = 0x9
_OP_PONG = 0xa
# Websocket Connect Status
_WebsocketConnectInit = -1
_WebsocketDisconnected = 1
_logger = logging.getLogger(__name__)
def __init__(self, socket, hostAddress, portNumber, AWSAccessKeyID="", AWSSecretAccessKey="", AWSSessionToken=""):
self._connectStatus = self._WebsocketConnectInit
# Handlers
self._sslSocket = socket
self._sigV4Handler = self._createSigV4Core()
self._sigV4Handler.setIAMCredentials(AWSAccessKeyID, AWSSecretAccessKey, AWSSessionToken)
# Endpoint Info
self._hostAddress = hostAddress
self._portNumber = portNumber
# Section Flags
self._hasOpByte = False
self._hasPayloadLengthFirst = False
self._hasPayloadLengthExtended = False
self._hasMaskKey = False
self._hasPayload = False
# Properties for current websocket frame
self._isFIN = False
self._RSVBits = None
self._opCode = None
self._needMaskKey = False
self._payloadLengthBytesLength = 1
self._payloadLength = 0
self._maskKey = None
self._payloadDataBuffer = bytearray() # Once the whole wss connection is lost, there is no need to keep the buffered payload
try:
self._handShake(hostAddress, portNumber)
except wssNoKeyInEnvironmentError: # Handle SigV4 signing and websocket handshaking errors
raise ValueError("No Access Key/KeyID Error")
except wssHandShakeError:
raise ValueError("Websocket Handshake Error")
except ClientError as e:
raise ValueError(e.message)
# Now we have a socket with secured websocket...
self._bufferedReader = _BufferedReader(self._sslSocket)
self._bufferedWriter = _BufferedWriter(self._sslSocket)
def _createSigV4Core(self):
return SigV4Core()
def _generateMaskKey(self):
return bytearray(os.urandom(4))
# os.urandom returns ascii str in 2.x, converted to bytearray
# os.urandom returns bytes in 3.x, converted to bytearray
def _reset(self): # Reset the context for wss frame reception
# Control info
self._hasOpByte = False
self._hasPayloadLengthFirst = False
self._hasPayloadLengthExtended = False
self._hasMaskKey = False
self._hasPayload = False
# Frame Info
self._isFIN = False
self._RSVBits = None
self._opCode = None
self._needMaskKey = False
self._payloadLengthBytesLength = 1
self._payloadLength = 0
self._maskKey = None
# Never reset the payloadData since we might have fragmented MQTT data from the pervious frame
def _generateWSSKey(self):
return base64.b64encode(os.urandom(128)) # Bytes
def _verifyWSSResponse(self, response, clientKey):
# Check if it is a 101 response
rawResponse = response.strip().lower()
if b"101 switching protocols" not in rawResponse or b"upgrade: websocket" not in rawResponse or b"connection: upgrade" not in rawResponse:
return False
# Parse out the sec-websocket-accept
WSSAcceptKeyIndex = response.strip().index(b"sec-websocket-accept: ") + len(b"sec-websocket-accept: ")
rawSecWebSocketAccept = response.strip()[WSSAcceptKeyIndex:].split(b"\r\n")[0].strip()
# Verify the WSSAcceptKey
return self._verifyWSSAcceptKey(rawSecWebSocketAccept, clientKey)
def _verifyWSSAcceptKey(self, srcAcceptKey, clientKey):
GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
verifyServerAcceptKey = base64.b64encode((hashlib.sha1(clientKey + GUID)).digest()) # Bytes
return srcAcceptKey == verifyServerAcceptKey
def _handShake(self, hostAddress, portNumber):
CRLF = "\r\n"
IOT_ENDPOINT_PATTERN = r"^[0-9a-zA-Z]+(\.ats|-ats)?\.iot\.(.*)\.amazonaws\..*"
matched = re.compile(IOT_ENDPOINT_PATTERN, re.IGNORECASE).match(hostAddress)
if not matched:
raise ClientError("Invalid endpoint pattern for wss: %s" % hostAddress)
region = matched.group(2)
signedURL = self._sigV4Handler.createWebsocketEndpoint(hostAddress, portNumber, region, "GET", "iotdata", "/mqtt")
# Now we got a signedURL
path = signedURL[signedURL.index("/mqtt"):]
# Assemble HTTP request headers
Method = "GET " + path + " HTTP/1.1" + CRLF
Host = "Host: " + hostAddress + CRLF
Connection = "Connection: " + "Upgrade" + CRLF
Upgrade = "Upgrade: " + "websocket" + CRLF
secWebSocketVersion = "Sec-WebSocket-Version: " + "13" + CRLF
rawSecWebSocketKey = self._generateWSSKey() # Bytes
secWebSocketKey = "sec-websocket-key: " + rawSecWebSocketKey.decode('utf-8') + CRLF # Should be randomly generated...
secWebSocketProtocol = "Sec-WebSocket-Protocol: " + "mqttv3.1" + CRLF
secWebSocketExtensions = "Sec-WebSocket-Extensions: " + "permessage-deflate; client_max_window_bits" + CRLF
# Send the HTTP request
# Ensure that we are sending bytes, not by any chance unicode string
handshakeBytes = Method + Host + Connection + Upgrade + secWebSocketVersion + secWebSocketProtocol + secWebSocketExtensions + secWebSocketKey + CRLF
handshakeBytes = handshakeBytes.encode('utf-8')
self._sslSocket.write(handshakeBytes)
# Read it back (Non-blocking socket)
timeStart = time.time()
wssHandshakeResponse = bytearray()
while len(wssHandshakeResponse) == 0:
try:
wssHandshakeResponse += self._sslSocket.read(1024) # Response is always less than 1024 bytes
except socket.error as err:
if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE:
if time.time() - timeStart > self._getTimeoutSec():
raise err # We make sure that reconnect gets retried in Paho upon a wss reconnect response timeout
else:
raise err
# Verify response
# Now both wssHandshakeResponse and rawSecWebSocketKey are byte strings
if not self._verifyWSSResponse(wssHandshakeResponse, rawSecWebSocketKey):
raise wssHandShakeError()
else:
pass
def _getTimeoutSec(self):
return DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
# Used to create a single wss frame
# Assume that the maximum length of a MQTT packet never exceeds the maximum length
# for a wss frame. Therefore, the FIN bit for the encoded frame will always be 1.
# Frames are encoded as BINARY frames.
def _encodeFrame(self, rawPayload, opCode, masked=1):
ret = bytearray()
# Op byte
opByte = 0x80 | opCode # Always a FIN, no RSV bits
ret.append(opByte)
# Payload Length bytes
maskBit = masked
payloadLength = len(rawPayload)
if payloadLength <= 125:
ret.append((maskBit << 7) | payloadLength)
elif payloadLength <= 0xffff: # 16-bit unsigned int
ret.append((maskBit << 7) | 126)
ret.extend(struct.pack("!H", payloadLength))
elif payloadLength <= 0x7fffffffffffffff: # 64-bit unsigned int (most significant bit must be 0)
ret.append((maskBit << 7) | 127)
ret.extend(struct.pack("!Q", payloadLength))
else: # Overflow
raise ValueError("Exceeds the maximum number of bytes for a single websocket frame.")
if maskBit == 1:
# Mask key bytes
maskKey = self._generateMaskKey()
ret.extend(maskKey)
# Mask the payload
payloadBytes = bytearray(rawPayload)
if maskBit == 1:
for i in range(0, payloadLength):
payloadBytes[i] ^= maskKey[i % 4]
ret.extend(payloadBytes)
# Return the assembled wss frame
return ret
# Used for the wss client to close a wss connection
# Create and send a masked wss closing frame
def _closeWssConnection(self):
# Frames sent from client to server must be masked
self._sslSocket.write(self._encodeFrame(b"", self._OP_CONNECTION_CLOSE, masked=1))
# Used for the wss client to respond to a wss PING from server
# Create and send a masked PONG frame
def _sendPONG(self):
# Frames sent from client to server must be masked
self._sslSocket.write(self._encodeFrame(b"", self._OP_PONG, masked=1))
# Override sslSocket read. Always read from the wss internal payload buffer, which
# contains the masked MQTT packet. This read will decode ONE wss frame every time
# and load in the payload for MQTT _packet_read. At any time, MQTT _packet_read
# should be able to read a complete MQTT packet from the payload (buffered per wss
# frame payload). If the MQTT packet is break into separate wss frames, different
# chunks will be buffered in separate frames and MQTT _packet_read will not be able
# to collect a complete MQTT packet to operate on until the necessary payload is
# fully buffered.
# If the requested number of bytes are not available, SSL_ERROR_WANT_READ will be
# raised to trigger another call of _packet_read when the data is available again.
def read(self, numberOfBytes):
# Check if we have enough data for paho
# _payloadDataBuffer will not be empty ony when the payload of a new wss frame
# has been unmasked.
if len(self._payloadDataBuffer) >= numberOfBytes:
ret = self._payloadDataBuffer[0:numberOfBytes]
self._payloadDataBuffer = self._payloadDataBuffer[numberOfBytes:]
# struct.unpack(fmt, string) # Py2.x
# struct.unpack(fmt, buffer) # Py3.x
# Here ret is always in bytes (buffer interface)
if sys.version_info[0] < 3: # Py2.x
ret = str(ret)
return ret
# Emmm, We don't. Try to buffer from the socket (It's a new wss frame).
if not self._hasOpByte: # Check if we need to buffer OpByte
opByte = self._bufferedReader.read(1)
self._isFIN = (opByte[0] & 0x80) == 0x80
self._RSVBits = (opByte[0] & 0x70)
self._opCode = (opByte[0] & 0x0f)
self._hasOpByte = True # Finished buffering opByte
# Check if any of the RSV bits are set, if so, close the connection
# since client never sends negotiated extensions
if self._RSVBits != 0x0:
self._closeWssConnection()
self._connectStatus = self._WebsocketDisconnected
self._payloadDataBuffer = bytearray()
raise socket.error(ssl.SSL_ERROR_WANT_READ, "RSV bits set with NO negotiated extensions.")
if not self._hasPayloadLengthFirst: # Check if we need to buffer First Payload Length byte
payloadLengthFirst = self._bufferedReader.read(1)
self._hasPayloadLengthFirst = True # Finished buffering first byte of payload length
self._needMaskKey = (payloadLengthFirst[0] & 0x80) == 0x80
payloadLengthFirstByteArray = bytearray()
payloadLengthFirstByteArray.extend(payloadLengthFirst)
self._payloadLength = (payloadLengthFirstByteArray[0] & 0x7f)
if self._payloadLength == 126:
self._payloadLengthBytesLength = 2
self._hasPayloadLengthExtended = False # Force to buffer the extended
elif self._payloadLength == 127:
self._payloadLengthBytesLength = 8
self._hasPayloadLengthExtended = False # Force to buffer the extended
else: # _payloadLength <= 125:
self._hasPayloadLengthExtended = True # No need to buffer extended payload length
if not self._hasPayloadLengthExtended: # Check if we need to buffer Extended Payload Length bytes
payloadLengthExtended = self._bufferedReader.read(self._payloadLengthBytesLength)
self._hasPayloadLengthExtended = True
if sys.version_info[0] < 3:
payloadLengthExtended = str(payloadLengthExtended)
if self._payloadLengthBytesLength == 2:
self._payloadLength = struct.unpack("!H", payloadLengthExtended)[0]
else: # _payloadLengthBytesLength == 8
self._payloadLength = struct.unpack("!Q", payloadLengthExtended)[0]
if self._needMaskKey: # Response from server is masked, close the connection
self._closeWssConnection()
self._connectStatus = self._WebsocketDisconnected
self._payloadDataBuffer = bytearray()
raise socket.error(ssl.SSL_ERROR_WANT_READ, "Server response masked, closing connection and try again.")
if not self._hasPayload: # Check if we need to buffer the payload
payloadForThisFrame = self._bufferedReader.read(self._payloadLength)
self._hasPayload = True
# Client side should never received a masked packet from the server side
# Unmask it as needed
#if self._needMaskKey:
# for i in range(0, self._payloadLength):
# payloadForThisFrame[i] ^= self._maskKey[i % 4]
# Append it to the internal payload buffer
self._payloadDataBuffer.extend(payloadForThisFrame)
# Now we have the complete wss frame, reset the context
# Check to see if it is a wss closing frame
if self._opCode == self._OP_CONNECTION_CLOSE:
self._connectStatus = self._WebsocketDisconnected
self._payloadDataBuffer = bytearray() # Ensure that once the wss closing frame comes, we have nothing to read and start all over again
# Check to see if it is a wss PING frame
if self._opCode == self._OP_PING:
self._sendPONG() # Nothing more to do here, if the transmission of the last wssMQTT packet is not finished, it will continue
self._reset()
# Check again if we have enough data for paho
if len(self._payloadDataBuffer) >= numberOfBytes:
ret = self._payloadDataBuffer[0:numberOfBytes]
self._payloadDataBuffer = self._payloadDataBuffer[numberOfBytes:]
# struct.unpack(fmt, string) # Py2.x
# struct.unpack(fmt, buffer) # Py3.x
# Here ret is always in bytes (buffer interface)
if sys.version_info[0] < 3: # Py2.x
ret = str(ret)
return ret
else: # Fragmented MQTT packets in separate wss frames
raise socket.error(ssl.SSL_ERROR_WANT_READ, "Not a complete MQTT packet payload within this wss frame.")
def write(self, bytesToBeSent):
# When there is a disconnection, select will report a TypeError which triggers the reconnect.
# In reconnect, Paho will set the socket object (mocked by wss) to None, blocking other ops
# before a connection is re-established.
# This 'low-level' socket write op should always be able to write to plain socket.
# Error reporting is performed by Python socket itself.
# Wss closing frame handling is performed in the wss read.
return self._bufferedWriter.write(self._encodeFrame(bytesToBeSent, self._OP_BINARY, 1), len(bytesToBeSent))
def close(self):
if self._sslSocket is not None:
self._sslSocket.close()
self._sslSocket = None
def getpeercert(self):
return self._sslSocket.getpeercert()
def getSSLSocket(self):
if self._connectStatus != self._WebsocketDisconnected:
return self._sslSocket
else:
return None # Leave the sslSocket to Paho to close it. (_ssl.close() -> wssCore.close())

View File

@@ -0,0 +1,244 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import ssl
import logging
from threading import Lock
from numbers import Number
import AWSIoTPythonSDK.core.protocol.paho.client as mqtt
from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
class ClientStatus(object):
IDLE = 0
CONNECT = 1
RESUBSCRIBE = 2
DRAINING = 3
STABLE = 4
USER_DISCONNECT = 5
ABNORMAL_DISCONNECT = 6
class ClientStatusContainer(object):
def __init__(self):
self._status = ClientStatus.IDLE
def get_status(self):
return self._status
def set_status(self, status):
if ClientStatus.USER_DISCONNECT == self._status: # If user requests to disconnect, no status updates other than user connect
if ClientStatus.CONNECT == status:
self._status = status
else:
self._status = status
class InternalAsyncMqttClient(object):
_logger = logging.getLogger(__name__)
def __init__(self, client_id, clean_session, protocol, use_wss):
self._paho_client = self._create_paho_client(client_id, clean_session, None, protocol, use_wss)
self._use_wss = use_wss
self._event_callback_map_lock = Lock()
self._event_callback_map = dict()
def _create_paho_client(self, client_id, clean_session, user_data, protocol, use_wss):
self._logger.debug("Initializing MQTT layer...")
return mqtt.Client(client_id, clean_session, user_data, protocol, use_wss)
# TODO: Merge credentials providers configuration into one
def set_cert_credentials_provider(self, cert_credentials_provider):
# History issue from Yun SDK where AR9331 embedded Linux only have Python 2.7.3
# pre-installed. In this version, TLSv1_2 is not even an option.
# SSLv23 is a work-around which selects the highest TLS version between the client
# and service. If user installs opensslv1.0.1+, this option will work fine for Mutual
# Auth.
# Note that we cannot force TLSv1.2 for Mutual Auth. in Python 2.7.3 and TLS support
# in Python only starts from Python2.7.
# See also: https://docs.python.org/2/library/ssl.html#ssl.PROTOCOL_SSLv23
if self._use_wss:
ca_path = cert_credentials_provider.get_ca_path()
self._paho_client.tls_set(ca_certs=ca_path, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23)
else:
ca_path = cert_credentials_provider.get_ca_path()
cert_path = cert_credentials_provider.get_cert_path()
key_path = cert_credentials_provider.get_key_path()
self._paho_client.tls_set(ca_certs=ca_path,certfile=cert_path, keyfile=key_path,
cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23)
def set_iam_credentials_provider(self, iam_credentials_provider):
self._paho_client.configIAMCredentials(iam_credentials_provider.get_access_key_id(),
iam_credentials_provider.get_secret_access_key(),
iam_credentials_provider.get_session_token())
def set_endpoint_provider(self, endpoint_provider):
self._endpoint_provider = endpoint_provider
def configure_last_will(self, topic, payload, qos, retain=False):
self._paho_client.will_set(topic, payload, qos, retain)
def configure_alpn_protocols(self, alpn_protocols):
self._paho_client.config_alpn_protocols(alpn_protocols)
def clear_last_will(self):
self._paho_client.will_clear()
def set_username_password(self, username, password=None):
self._paho_client.username_pw_set(username, password)
def set_socket_factory(self, socket_factory):
self._paho_client.socket_factory_set(socket_factory)
def configure_reconnect_back_off(self, base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec):
self._paho_client.setBackoffTiming(base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec)
def connect(self, keep_alive_sec, ack_callback=None):
host = self._endpoint_provider.get_host()
port = self._endpoint_provider.get_port()
with self._event_callback_map_lock:
self._logger.debug("Filling in fixed event callbacks: CONNACK, DISCONNECT, MESSAGE")
self._event_callback_map[FixedEventMids.CONNACK_MID] = self._create_combined_on_connect_callback(ack_callback)
self._event_callback_map[FixedEventMids.DISCONNECT_MID] = self._create_combined_on_disconnect_callback(None)
self._event_callback_map[FixedEventMids.MESSAGE_MID] = self._create_converted_on_message_callback()
rc = self._paho_client.connect(host, port, keep_alive_sec)
if MQTT_ERR_SUCCESS == rc:
self.start_background_network_io()
return rc
def start_background_network_io(self):
self._logger.debug("Starting network I/O thread...")
self._paho_client.loop_start()
def stop_background_network_io(self):
self._logger.debug("Stopping network I/O thread...")
self._paho_client.loop_stop()
def disconnect(self, ack_callback=None):
with self._event_callback_map_lock:
rc = self._paho_client.disconnect()
if MQTT_ERR_SUCCESS == rc:
self._logger.debug("Filling in custom disconnect event callback...")
combined_on_disconnect_callback = self._create_combined_on_disconnect_callback(ack_callback)
self._event_callback_map[FixedEventMids.DISCONNECT_MID] = combined_on_disconnect_callback
return rc
def _create_combined_on_connect_callback(self, ack_callback):
def combined_on_connect_callback(mid, data):
self.on_online()
if ack_callback:
ack_callback(mid, data)
return combined_on_connect_callback
def _create_combined_on_disconnect_callback(self, ack_callback):
def combined_on_disconnect_callback(mid, data):
self.on_offline()
if ack_callback:
ack_callback(mid, data)
return combined_on_disconnect_callback
def _create_converted_on_message_callback(self):
def converted_on_message_callback(mid, data):
self.on_message(data)
return converted_on_message_callback
# For client online notification
def on_online(self):
pass
# For client offline notification
def on_offline(self):
pass
# For client message reception notification
def on_message(self, message):
pass
def publish(self, topic, payload, qos, retain=False, ack_callback=None):
with self._event_callback_map_lock:
rc, mid = self._paho_client.publish(topic, payload, qos, retain)
if MQTT_ERR_SUCCESS == rc and qos > 0 and ack_callback:
self._logger.debug("Filling in custom puback (QoS>0) event callback...")
self._event_callback_map[mid] = ack_callback
return rc, mid
def subscribe(self, topic, qos, ack_callback=None):
with self._event_callback_map_lock:
rc, mid = self._paho_client.subscribe(topic, qos)
if MQTT_ERR_SUCCESS == rc and ack_callback:
self._logger.debug("Filling in custom suback event callback...")
self._event_callback_map[mid] = ack_callback
return rc, mid
def unsubscribe(self, topic, ack_callback=None):
with self._event_callback_map_lock:
rc, mid = self._paho_client.unsubscribe(topic)
if MQTT_ERR_SUCCESS == rc and ack_callback:
self._logger.debug("Filling in custom unsuback event callback...")
self._event_callback_map[mid] = ack_callback
return rc, mid
def register_internal_event_callbacks(self, on_connect, on_disconnect, on_publish, on_subscribe, on_unsubscribe, on_message):
self._logger.debug("Registering internal event callbacks to MQTT layer...")
self._paho_client.on_connect = on_connect
self._paho_client.on_disconnect = on_disconnect
self._paho_client.on_publish = on_publish
self._paho_client.on_subscribe = on_subscribe
self._paho_client.on_unsubscribe = on_unsubscribe
self._paho_client.on_message = on_message
def unregister_internal_event_callbacks(self):
self._logger.debug("Unregistering internal event callbacks from MQTT layer...")
self._paho_client.on_connect = None
self._paho_client.on_disconnect = None
self._paho_client.on_publish = None
self._paho_client.on_subscribe = None
self._paho_client.on_unsubscribe = None
self._paho_client.on_message = None
def invoke_event_callback(self, mid, data=None):
with self._event_callback_map_lock:
event_callback = self._event_callback_map.get(mid)
# For invoking the event callback, we do not need to acquire the lock
if event_callback:
self._logger.debug("Invoking custom event callback...")
if data is not None:
event_callback(mid=mid, data=data)
else:
event_callback(mid=mid)
if isinstance(mid, Number): # Do NOT remove callbacks for CONNACK/DISCONNECT/MESSAGE
self._logger.debug("This custom event callback is for pub/sub/unsub, removing it after invocation...")
with self._event_callback_map_lock:
del self._event_callback_map[mid]
def remove_event_callback(self, mid):
with self._event_callback_map_lock:
if mid in self._event_callback_map:
self._logger.debug("Removing custom event callback...")
del self._event_callback_map[mid]
def clean_up_event_callbacks(self):
with self._event_callback_map_lock:
self._event_callback_map.clear()
def get_event_callback_map(self):
return self._event_callback_map

View File

@@ -0,0 +1,20 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC = 30
DEFAULT_OPERATION_TIMEOUT_SEC = 5
DEFAULT_DRAINING_INTERNAL_SEC = 0.5
METRICS_PREFIX = "?SDK=Python&Version="
ALPN_PROTCOLS = "x-amzn-mqtt-ca"

View File

@@ -0,0 +1,29 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
class EventTypes(object):
CONNACK = 0
DISCONNECT = 1
PUBACK = 2
SUBACK = 3
UNSUBACK = 4
MESSAGE = 5
class FixedEventMids(object):
CONNACK_MID = "CONNECTED"
DISCONNECT_MID = "DISCONNECTED"
MESSAGE_MID = "MESSAGE"
QUEUED_MID = "QUEUED"

View File

@@ -0,0 +1,87 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import logging
from AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes
class AppendResults(object):
APPEND_FAILURE_QUEUE_FULL = -1
APPEND_FAILURE_QUEUE_DISABLED = -2
APPEND_SUCCESS = 0
class OfflineRequestQueue(list):
_logger = logging.getLogger(__name__)
def __init__(self, max_size, drop_behavior=DropBehaviorTypes.DROP_NEWEST):
if not isinstance(max_size, int) or not isinstance(drop_behavior, int):
self._logger.error("init: MaximumSize/DropBehavior must be integer.")
raise TypeError("MaximumSize/DropBehavior must be integer.")
if drop_behavior != DropBehaviorTypes.DROP_OLDEST and drop_behavior != DropBehaviorTypes.DROP_NEWEST:
self._logger.error("init: Drop behavior not supported.")
raise ValueError("Drop behavior not supported.")
list.__init__([])
self._drop_behavior = drop_behavior
# When self._maximumSize > 0, queue is limited
# When self._maximumSize == 0, queue is disabled
# When self._maximumSize < 0. queue is infinite
self._max_size = max_size
def _is_enabled(self):
return self._max_size != 0
def _need_drop_messages(self):
# Need to drop messages when:
# 1. Queue is limited and full
# 2. Queue is disabled
is_queue_full = len(self) >= self._max_size
is_queue_limited = self._max_size > 0
is_queue_disabled = not self._is_enabled()
return (is_queue_full and is_queue_limited) or is_queue_disabled
def set_behavior_drop_newest(self):
self._drop_behavior = DropBehaviorTypes.DROP_NEWEST
def set_behavior_drop_oldest(self):
self._drop_behavior = DropBehaviorTypes.DROP_OLDEST
# Override
# Append to a queue with a limited size.
# Return APPEND_SUCCESS if the append is successful
# Return APPEND_FAILURE_QUEUE_FULL if the append failed because the queue is full
# Return APPEND_FAILURE_QUEUE_DISABLED if the append failed because the queue is disabled
def append(self, data):
ret = AppendResults.APPEND_SUCCESS
if self._is_enabled():
if self._need_drop_messages():
# We should drop the newest
if DropBehaviorTypes.DROP_NEWEST == self._drop_behavior:
self._logger.warn("append: Full queue. Drop the newest: " + str(data))
ret = AppendResults.APPEND_FAILURE_QUEUE_FULL
# We should drop the oldest
else:
current_oldest = super(OfflineRequestQueue, self).pop(0)
self._logger.warn("append: Full queue. Drop the oldest: " + str(current_oldest))
super(OfflineRequestQueue, self).append(data)
ret = AppendResults.APPEND_FAILURE_QUEUE_FULL
else:
self._logger.debug("append: Add new element: " + str(data))
super(OfflineRequestQueue, self).append(data)
else:
self._logger.debug("append: Queue is disabled. Drop the message: " + str(data))
ret = AppendResults.APPEND_FAILURE_QUEUE_DISABLED
return ret

View File

@@ -0,0 +1,27 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
class RequestTypes(object):
CONNECT = 0
DISCONNECT = 1
PUBLISH = 2
SUBSCRIBE = 3
UNSUBSCRIBE = 4
class QueueableRequest(object):
def __init__(self, type, data):
self.type = type
self.data = data # Can be a tuple

View File

@@ -0,0 +1,296 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import time
import logging
from threading import Thread
from threading import Event
from AWSIoTPythonSDK.core.protocol.internal.events import EventTypes
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus
from AWSIoTPythonSDK.core.protocol.internal.queues import OfflineRequestQueue
from AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes
from AWSIoTPythonSDK.core.protocol.paho.client import topic_matches_sub
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_DRAINING_INTERNAL_SEC
class EventProducer(object):
_logger = logging.getLogger(__name__)
def __init__(self, cv, event_queue):
self._cv = cv
self._event_queue = event_queue
def on_connect(self, client, user_data, flags, rc):
self._add_to_queue(FixedEventMids.CONNACK_MID, EventTypes.CONNACK, rc)
self._logger.debug("Produced [connack] event")
def on_disconnect(self, client, user_data, rc):
self._add_to_queue(FixedEventMids.DISCONNECT_MID, EventTypes.DISCONNECT, rc)
self._logger.debug("Produced [disconnect] event")
def on_publish(self, client, user_data, mid):
self._add_to_queue(mid, EventTypes.PUBACK, None)
self._logger.debug("Produced [puback] event")
def on_subscribe(self, client, user_data, mid, granted_qos):
self._add_to_queue(mid, EventTypes.SUBACK, granted_qos)
self._logger.debug("Produced [suback] event")
def on_unsubscribe(self, client, user_data, mid):
self._add_to_queue(mid, EventTypes.UNSUBACK, None)
self._logger.debug("Produced [unsuback] event")
def on_message(self, client, user_data, message):
self._add_to_queue(FixedEventMids.MESSAGE_MID, EventTypes.MESSAGE, message)
self._logger.debug("Produced [message] event")
def _add_to_queue(self, mid, event_type, data):
with self._cv:
self._event_queue.put((mid, event_type, data))
self._cv.notify()
class EventConsumer(object):
MAX_DISPATCH_INTERNAL_SEC = 0.01
_logger = logging.getLogger(__name__)
def __init__(self, cv, event_queue, internal_async_client,
subscription_manager, offline_requests_manager, client_status):
self._cv = cv
self._event_queue = event_queue
self._internal_async_client = internal_async_client
self._subscription_manager = subscription_manager
self._offline_requests_manager = offline_requests_manager
self._client_status = client_status
self._is_running = False
self._draining_interval_sec = DEFAULT_DRAINING_INTERNAL_SEC
self._dispatch_methods = {
EventTypes.CONNACK : self._dispatch_connack,
EventTypes.DISCONNECT : self._dispatch_disconnect,
EventTypes.PUBACK : self._dispatch_puback,
EventTypes.SUBACK : self._dispatch_suback,
EventTypes.UNSUBACK : self._dispatch_unsuback,
EventTypes.MESSAGE : self._dispatch_message
}
self._offline_request_handlers = {
RequestTypes.PUBLISH : self._handle_offline_publish,
RequestTypes.SUBSCRIBE : self._handle_offline_subscribe,
RequestTypes.UNSUBSCRIBE : self._handle_offline_unsubscribe
}
self._stopper = Event()
def update_offline_requests_manager(self, offline_requests_manager):
self._offline_requests_manager = offline_requests_manager
def update_draining_interval_sec(self, draining_interval_sec):
self._draining_interval_sec = draining_interval_sec
def get_draining_interval_sec(self):
return self._draining_interval_sec
def is_running(self):
return self._is_running
def start(self):
self._stopper.clear()
self._is_running = True
dispatch_events = Thread(target=self._dispatch)
dispatch_events.daemon = True
dispatch_events.start()
self._logger.debug("Event consuming thread started")
def stop(self):
if self._is_running:
self._is_running = False
self._clean_up()
self._logger.debug("Event consuming thread stopped")
def _clean_up(self):
self._logger.debug("Cleaning up before stopping event consuming")
with self._event_queue.mutex:
self._event_queue.queue.clear()
self._logger.debug("Event queue cleared")
self._internal_async_client.stop_background_network_io()
self._logger.debug("Network thread stopped")
self._internal_async_client.clean_up_event_callbacks()
self._logger.debug("Event callbacks cleared")
def wait_until_it_stops(self, timeout_sec):
self._logger.debug("Waiting for event consumer to completely stop")
return self._stopper.wait(timeout=timeout_sec)
def is_fully_stopped(self):
return self._stopper.is_set()
def _dispatch(self):
while self._is_running:
with self._cv:
if self._event_queue.empty():
self._cv.wait(self.MAX_DISPATCH_INTERNAL_SEC)
else:
while not self._event_queue.empty():
self._dispatch_one()
self._stopper.set()
self._logger.debug("Exiting dispatching loop...")
def _dispatch_one(self):
mid, event_type, data = self._event_queue.get()
if mid:
self._dispatch_methods[event_type](mid, data)
self._internal_async_client.invoke_event_callback(mid, data=data)
# We need to make sure disconnect event gets dispatched and then we stop the consumer
if self._need_to_stop_dispatching(mid):
self.stop()
def _need_to_stop_dispatching(self, mid):
status = self._client_status.get_status()
return (ClientStatus.USER_DISCONNECT == status or ClientStatus.CONNECT == status) \
and mid == FixedEventMids.DISCONNECT_MID
def _dispatch_connack(self, mid, rc):
status = self._client_status.get_status()
self._logger.debug("Dispatching [connack] event")
if self._need_recover():
if ClientStatus.STABLE != status: # To avoid multiple connack dispatching
self._logger.debug("Has recovery job")
clean_up_debt = Thread(target=self._clean_up_debt)
clean_up_debt.start()
else:
self._logger.debug("No need for recovery")
self._client_status.set_status(ClientStatus.STABLE)
def _need_recover(self):
return self._subscription_manager.list_records() or self._offline_requests_manager.has_more()
def _clean_up_debt(self):
self._handle_resubscribe()
self._handle_draining()
self._client_status.set_status(ClientStatus.STABLE)
def _handle_resubscribe(self):
subscriptions = self._subscription_manager.list_records()
if subscriptions and not self._has_user_disconnect_request():
self._logger.debug("Start resubscribing")
self._client_status.set_status(ClientStatus.RESUBSCRIBE)
for topic, (qos, message_callback, ack_callback) in subscriptions:
if self._has_user_disconnect_request():
self._logger.debug("User disconnect detected")
break
self._internal_async_client.subscribe(topic, qos, ack_callback)
def _handle_draining(self):
if self._offline_requests_manager.has_more() and not self._has_user_disconnect_request():
self._logger.debug("Start draining")
self._client_status.set_status(ClientStatus.DRAINING)
while self._offline_requests_manager.has_more():
if self._has_user_disconnect_request():
self._logger.debug("User disconnect detected")
break
offline_request = self._offline_requests_manager.get_next()
if offline_request:
self._offline_request_handlers[offline_request.type](offline_request)
time.sleep(self._draining_interval_sec)
def _has_user_disconnect_request(self):
return ClientStatus.USER_DISCONNECT == self._client_status.get_status()
def _dispatch_disconnect(self, mid, rc):
self._logger.debug("Dispatching [disconnect] event")
status = self._client_status.get_status()
if ClientStatus.USER_DISCONNECT == status or ClientStatus.CONNECT == status:
pass
else:
self._client_status.set_status(ClientStatus.ABNORMAL_DISCONNECT)
# For puback, suback and unsuback, ack callback invocation is handled in dispatch_one
# Do nothing in the event dispatching itself
def _dispatch_puback(self, mid, rc):
self._logger.debug("Dispatching [puback] event")
def _dispatch_suback(self, mid, rc):
self._logger.debug("Dispatching [suback] event")
def _dispatch_unsuback(self, mid, rc):
self._logger.debug("Dispatching [unsuback] event")
def _dispatch_message(self, mid, message):
self._logger.debug("Dispatching [message] event")
subscriptions = self._subscription_manager.list_records()
if subscriptions:
for topic, (qos, message_callback, _) in subscriptions:
if topic_matches_sub(topic, message.topic) and message_callback:
message_callback(None, None, message) # message_callback(client, userdata, message)
def _handle_offline_publish(self, request):
topic, payload, qos, retain = request.data
self._internal_async_client.publish(topic, payload, qos, retain)
self._logger.debug("Processed offline publish request")
def _handle_offline_subscribe(self, request):
topic, qos, message_callback, ack_callback = request.data
self._subscription_manager.add_record(topic, qos, message_callback, ack_callback)
self._internal_async_client.subscribe(topic, qos, ack_callback)
self._logger.debug("Processed offline subscribe request")
def _handle_offline_unsubscribe(self, request):
topic, ack_callback = request.data
self._subscription_manager.remove_record(topic)
self._internal_async_client.unsubscribe(topic, ack_callback)
self._logger.debug("Processed offline unsubscribe request")
class SubscriptionManager(object):
_logger = logging.getLogger(__name__)
def __init__(self):
self._subscription_map = dict()
def add_record(self, topic, qos, message_callback, ack_callback):
self._logger.debug("Adding a new subscription record: %s qos: %d", topic, qos)
self._subscription_map[topic] = qos, message_callback, ack_callback # message_callback and/or ack_callback could be None
def remove_record(self, topic):
self._logger.debug("Removing subscription record: %s", topic)
if self._subscription_map.get(topic): # Ignore topics that are never subscribed to
del self._subscription_map[topic]
else:
self._logger.warn("Removing attempt for non-exist subscription record: %s", topic)
def list_records(self):
return list(self._subscription_map.items())
class OfflineRequestsManager(object):
_logger = logging.getLogger(__name__)
def __init__(self, max_size, drop_behavior):
self._queue = OfflineRequestQueue(max_size, drop_behavior)
def has_more(self):
return len(self._queue) > 0
def add_one(self, request):
return self._queue.append(request)
def get_next(self):
if self.has_more():
return self._queue.pop(0)
else:
return None

View File

@@ -0,0 +1,373 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import AWSIoTPythonSDK
from AWSIoTPythonSDK.core.protocol.internal.clients import InternalAsyncMqttClient
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatusContainer
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus
from AWSIoTPythonSDK.core.protocol.internal.workers import EventProducer
from AWSIoTPythonSDK.core.protocol.internal.workers import EventConsumer
from AWSIoTPythonSDK.core.protocol.internal.workers import SubscriptionManager
from AWSIoTPythonSDK.core.protocol.internal.workers import OfflineRequestsManager
from AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes
from AWSIoTPythonSDK.core.protocol.internal.requests import QueueableRequest
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_OPERATION_TIMEOUT_SEC
from AWSIoTPythonSDK.core.protocol.internal.defaults import METRICS_PREFIX
from AWSIoTPythonSDK.core.protocol.internal.defaults import ALPN_PROTCOLS
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS
from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueFullException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueDisabledException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueFullException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueDisabledException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueFullException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueDisabledException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeTimeoutException
from AWSIoTPythonSDK.core.protocol.internal.queues import AppendResults
from AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes
from AWSIoTPythonSDK.core.protocol.paho.client import MQTTv31
from threading import Condition
from threading import Event
import logging
import sys
if sys.version_info[0] < 3:
from Queue import Queue
else:
from queue import Queue
class MqttCore(object):
_logger = logging.getLogger(__name__)
def __init__(self, client_id, clean_session, protocol, use_wss):
self._use_wss = use_wss
self._username = ""
self._password = None
self._enable_metrics_collection = True
self._event_queue = Queue()
self._event_cv = Condition()
self._event_producer = EventProducer(self._event_cv, self._event_queue)
self._client_status = ClientStatusContainer()
self._internal_async_client = InternalAsyncMqttClient(client_id, clean_session, protocol, use_wss)
self._subscription_manager = SubscriptionManager()
self._offline_requests_manager = OfflineRequestsManager(-1, DropBehaviorTypes.DROP_NEWEST) # Infinite queue
self._event_consumer = EventConsumer(self._event_cv,
self._event_queue,
self._internal_async_client,
self._subscription_manager,
self._offline_requests_manager,
self._client_status)
self._connect_disconnect_timeout_sec = DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
self._operation_timeout_sec = DEFAULT_OPERATION_TIMEOUT_SEC
self._init_offline_request_exceptions()
self._init_workers()
self._logger.info("MqttCore initialized")
self._logger.info("Client id: %s" % client_id)
self._logger.info("Protocol version: %s" % ("MQTTv3.1" if protocol == MQTTv31 else "MQTTv3.1.1"))
self._logger.info("Authentication type: %s" % ("SigV4 WebSocket" if use_wss else "TLSv1.2 certificate based Mutual Auth."))
def _init_offline_request_exceptions(self):
self._offline_request_queue_disabled_exceptions = {
RequestTypes.PUBLISH : publishQueueDisabledException(),
RequestTypes.SUBSCRIBE : subscribeQueueDisabledException(),
RequestTypes.UNSUBSCRIBE : unsubscribeQueueDisabledException()
}
self._offline_request_queue_full_exceptions = {
RequestTypes.PUBLISH : publishQueueFullException(),
RequestTypes.SUBSCRIBE : subscribeQueueFullException(),
RequestTypes.UNSUBSCRIBE : unsubscribeQueueFullException()
}
def _init_workers(self):
self._internal_async_client.register_internal_event_callbacks(self._event_producer.on_connect,
self._event_producer.on_disconnect,
self._event_producer.on_publish,
self._event_producer.on_subscribe,
self._event_producer.on_unsubscribe,
self._event_producer.on_message)
def _start_workers(self):
self._event_consumer.start()
def use_wss(self):
return self._use_wss
# Used for general message event reception
def on_message(self, message):
pass
# Used for general online event notification
def on_online(self):
pass
# Used for general offline event notification
def on_offline(self):
pass
def configure_cert_credentials(self, cert_credentials_provider):
self._logger.info("Configuring certificates...")
self._internal_async_client.set_cert_credentials_provider(cert_credentials_provider)
def configure_iam_credentials(self, iam_credentials_provider):
self._logger.info("Configuring custom IAM credentials...")
self._internal_async_client.set_iam_credentials_provider(iam_credentials_provider)
def configure_endpoint(self, endpoint_provider):
self._logger.info("Configuring endpoint...")
self._internal_async_client.set_endpoint_provider(endpoint_provider)
def configure_connect_disconnect_timeout_sec(self, connect_disconnect_timeout_sec):
self._logger.info("Configuring connect/disconnect time out: %f sec" % connect_disconnect_timeout_sec)
self._connect_disconnect_timeout_sec = connect_disconnect_timeout_sec
def configure_operation_timeout_sec(self, operation_timeout_sec):
self._logger.info("Configuring MQTT operation time out: %f sec" % operation_timeout_sec)
self._operation_timeout_sec = operation_timeout_sec
def configure_reconnect_back_off(self, base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec):
self._logger.info("Configuring reconnect back off timing...")
self._logger.info("Base quiet time: %f sec" % base_reconnect_quiet_sec)
self._logger.info("Max quiet time: %f sec" % max_reconnect_quiet_sec)
self._logger.info("Stable connection time: %f sec" % stable_connection_sec)
self._internal_async_client.configure_reconnect_back_off(base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec)
def configure_alpn_protocols(self):
self._logger.info("Configuring alpn protocols...")
self._internal_async_client.configure_alpn_protocols([ALPN_PROTCOLS])
def configure_last_will(self, topic, payload, qos, retain=False):
self._logger.info("Configuring last will...")
self._internal_async_client.configure_last_will(topic, payload, qos, retain)
def clear_last_will(self):
self._logger.info("Clearing last will...")
self._internal_async_client.clear_last_will()
def configure_username_password(self, username, password=None):
self._logger.info("Configuring username and password...")
self._username = username
self._password = password
def configure_socket_factory(self, socket_factory):
self._logger.info("Configuring socket factory...")
self._internal_async_client.set_socket_factory(socket_factory)
def enable_metrics_collection(self):
self._enable_metrics_collection = True
def disable_metrics_collection(self):
self._enable_metrics_collection = False
def configure_offline_requests_queue(self, max_size, drop_behavior):
self._logger.info("Configuring offline requests queueing: max queue size: %d", max_size)
self._offline_requests_manager = OfflineRequestsManager(max_size, drop_behavior)
self._event_consumer.update_offline_requests_manager(self._offline_requests_manager)
def configure_draining_interval_sec(self, draining_interval_sec):
self._logger.info("Configuring offline requests queue draining interval: %f sec", draining_interval_sec)
self._event_consumer.update_draining_interval_sec(draining_interval_sec)
def connect(self, keep_alive_sec):
self._logger.info("Performing sync connect...")
event = Event()
self.connect_async(keep_alive_sec, self._create_blocking_ack_callback(event))
if not event.wait(self._connect_disconnect_timeout_sec):
self._logger.error("Connect timed out")
raise connectTimeoutException()
return True
def connect_async(self, keep_alive_sec, ack_callback=None):
self._logger.info("Performing async connect...")
self._logger.info("Keep-alive: %f sec" % keep_alive_sec)
self._start_workers()
self._load_callbacks()
self._load_username_password()
try:
self._client_status.set_status(ClientStatus.CONNECT)
rc = self._internal_async_client.connect(keep_alive_sec, ack_callback)
if MQTT_ERR_SUCCESS != rc:
self._logger.error("Connect error: %d", rc)
raise connectError(rc)
except Exception as e:
# Provided any error in connect, we should clean up the threads that have been created
self._event_consumer.stop()
if not self._event_consumer.wait_until_it_stops(self._connect_disconnect_timeout_sec):
self._logger.error("Time out in waiting for event consumer to stop")
else:
self._logger.debug("Event consumer stopped")
self._client_status.set_status(ClientStatus.IDLE)
raise e
return FixedEventMids.CONNACK_MID
def _load_callbacks(self):
self._logger.debug("Passing in general notification callbacks to internal client...")
self._internal_async_client.on_online = self.on_online
self._internal_async_client.on_offline = self.on_offline
self._internal_async_client.on_message = self.on_message
def _load_username_password(self):
username_candidate = self._username
if self._enable_metrics_collection:
username_candidate += METRICS_PREFIX
username_candidate += AWSIoTPythonSDK.__version__
self._internal_async_client.set_username_password(username_candidate, self._password)
def disconnect(self):
self._logger.info("Performing sync disconnect...")
event = Event()
self.disconnect_async(self._create_blocking_ack_callback(event))
if not event.wait(self._connect_disconnect_timeout_sec):
self._logger.error("Disconnect timed out")
raise disconnectTimeoutException()
if not self._event_consumer.wait_until_it_stops(self._connect_disconnect_timeout_sec):
self._logger.error("Disconnect timed out in waiting for event consumer")
raise disconnectTimeoutException()
return True
def disconnect_async(self, ack_callback=None):
self._logger.info("Performing async disconnect...")
self._client_status.set_status(ClientStatus.USER_DISCONNECT)
rc = self._internal_async_client.disconnect(ack_callback)
if MQTT_ERR_SUCCESS != rc:
self._logger.error("Disconnect error: %d", rc)
raise disconnectError(rc)
return FixedEventMids.DISCONNECT_MID
def publish(self, topic, payload, qos, retain=False):
self._logger.info("Performing sync publish...")
ret = False
if ClientStatus.STABLE != self._client_status.get_status():
self._handle_offline_request(RequestTypes.PUBLISH, (topic, payload, qos, retain))
else:
if qos > 0:
event = Event()
rc, mid = self._publish_async(topic, payload, qos, retain, self._create_blocking_ack_callback(event))
if not event.wait(self._operation_timeout_sec):
self._internal_async_client.remove_event_callback(mid)
self._logger.error("Publish timed out")
raise publishTimeoutException()
else:
self._publish_async(topic, payload, qos, retain)
ret = True
return ret
def publish_async(self, topic, payload, qos, retain=False, ack_callback=None):
self._logger.info("Performing async publish...")
if ClientStatus.STABLE != self._client_status.get_status():
self._handle_offline_request(RequestTypes.PUBLISH, (topic, payload, qos, retain))
return FixedEventMids.QUEUED_MID
else:
rc, mid = self._publish_async(topic, payload, qos, retain, ack_callback)
return mid
def _publish_async(self, topic, payload, qos, retain=False, ack_callback=None):
rc, mid = self._internal_async_client.publish(topic, payload, qos, retain, ack_callback)
if MQTT_ERR_SUCCESS != rc:
self._logger.error("Publish error: %d", rc)
raise publishError(rc)
return rc, mid
def subscribe(self, topic, qos, message_callback=None):
self._logger.info("Performing sync subscribe...")
ret = False
if ClientStatus.STABLE != self._client_status.get_status():
self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback, None))
else:
event = Event()
rc, mid = self._subscribe_async(topic, qos, self._create_blocking_ack_callback(event), message_callback)
if not event.wait(self._operation_timeout_sec):
self._internal_async_client.remove_event_callback(mid)
self._logger.error("Subscribe timed out")
raise subscribeTimeoutException()
ret = True
return ret
def subscribe_async(self, topic, qos, ack_callback=None, message_callback=None):
self._logger.info("Performing async subscribe...")
if ClientStatus.STABLE != self._client_status.get_status():
self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback, ack_callback))
return FixedEventMids.QUEUED_MID
else:
rc, mid = self._subscribe_async(topic, qos, ack_callback, message_callback)
return mid
def _subscribe_async(self, topic, qos, ack_callback=None, message_callback=None):
self._subscription_manager.add_record(topic, qos, message_callback, ack_callback)
rc, mid = self._internal_async_client.subscribe(topic, qos, ack_callback)
if MQTT_ERR_SUCCESS != rc:
self._logger.error("Subscribe error: %d", rc)
raise subscribeError(rc)
return rc, mid
def unsubscribe(self, topic):
self._logger.info("Performing sync unsubscribe...")
ret = False
if ClientStatus.STABLE != self._client_status.get_status():
self._handle_offline_request(RequestTypes.UNSUBSCRIBE, (topic, None))
else:
event = Event()
rc, mid = self._unsubscribe_async(topic, self._create_blocking_ack_callback(event))
if not event.wait(self._operation_timeout_sec):
self._internal_async_client.remove_event_callback(mid)
self._logger.error("Unsubscribe timed out")
raise unsubscribeTimeoutException()
ret = True
return ret
def unsubscribe_async(self, topic, ack_callback=None):
self._logger.info("Performing async unsubscribe...")
if ClientStatus.STABLE != self._client_status.get_status():
self._handle_offline_request(RequestTypes.UNSUBSCRIBE, (topic, ack_callback))
return FixedEventMids.QUEUED_MID
else:
rc, mid = self._unsubscribe_async(topic, ack_callback)
return mid
def _unsubscribe_async(self, topic, ack_callback=None):
self._subscription_manager.remove_record(topic)
rc, mid = self._internal_async_client.unsubscribe(topic, ack_callback)
if MQTT_ERR_SUCCESS != rc:
self._logger.error("Unsubscribe error: %d", rc)
raise unsubscribeError(rc)
return rc, mid
def _create_blocking_ack_callback(self, event):
def ack_callback(mid, data=None):
event.set()
return ack_callback
def _handle_offline_request(self, type, data):
self._logger.info("Offline request detected!")
offline_request = QueueableRequest(type, data)
append_result = self._offline_requests_manager.add_one(offline_request)
if AppendResults.APPEND_FAILURE_QUEUE_DISABLED == append_result:
self._logger.error("Offline request queue has been disabled")
raise self._offline_request_queue_disabled_exceptions[type]
if AppendResults.APPEND_FAILURE_QUEUE_FULL == append_result:
self._logger.error("Offline request queue is full")
raise self._offline_request_queue_full_exceptions[type]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,430 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import json
import logging
import uuid
from threading import Timer, Lock, Thread
class _shadowRequestToken:
URN_PREFIX_LENGTH = 9
def getNextToken(self):
return uuid.uuid4().urn[self.URN_PREFIX_LENGTH:] # We only need the uuid digits, not the urn prefix
class _basicJSONParser:
def setString(self, srcString):
self._rawString = srcString
self._dictionObject = None
def regenerateString(self):
return json.dumps(self._dictionaryObject)
def getAttributeValue(self, srcAttributeKey):
return self._dictionaryObject.get(srcAttributeKey)
def setAttributeValue(self, srcAttributeKey, srcAttributeValue):
self._dictionaryObject[srcAttributeKey] = srcAttributeValue
def validateJSON(self):
try:
self._dictionaryObject = json.loads(self._rawString)
except ValueError:
return False
return True
class deviceShadow:
_logger = logging.getLogger(__name__)
def __init__(self, srcShadowName, srcIsPersistentSubscribe, srcShadowManager):
"""
The class that denotes a local/client-side device shadow instance.
Users can perform shadow operations on this instance to retrieve and modify the
corresponding shadow JSON document in AWS IoT Cloud. The following shadow operations
are available:
- Get
- Update
- Delete
- Listen on delta
- Cancel listening on delta
This is returned from :code:`AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTShadowClient.createShadowWithName` function call.
No need to call directly from user scripts.
"""
if srcShadowName is None or srcIsPersistentSubscribe is None or srcShadowManager is None:
raise TypeError("None type inputs detected.")
self._shadowName = srcShadowName
# Tool handler
self._shadowManagerHandler = srcShadowManager
self._basicJSONParserHandler = _basicJSONParser()
self._tokenHandler = _shadowRequestToken()
# Properties
self._isPersistentSubscribe = srcIsPersistentSubscribe
self._lastVersionInSync = -1 # -1 means not initialized
self._isGetSubscribed = False
self._isUpdateSubscribed = False
self._isDeleteSubscribed = False
self._shadowSubscribeCallbackTable = dict()
self._shadowSubscribeCallbackTable["get"] = None
self._shadowSubscribeCallbackTable["delete"] = None
self._shadowSubscribeCallbackTable["update"] = None
self._shadowSubscribeCallbackTable["delta"] = None
self._shadowSubscribeStatusTable = dict()
self._shadowSubscribeStatusTable["get"] = 0
self._shadowSubscribeStatusTable["delete"] = 0
self._shadowSubscribeStatusTable["update"] = 0
self._tokenPool = dict()
self._dataStructureLock = Lock()
def _doNonPersistentUnsubscribe(self, currentAction):
self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, currentAction)
self._logger.info("Unsubscribed to " + currentAction + " accepted/rejected topics for deviceShadow: " + self._shadowName)
def generalCallback(self, client, userdata, message):
# In Py3.x, message.payload comes in as a bytes(string)
# json.loads needs a string input
with self._dataStructureLock:
currentTopic = message.topic
currentAction = self._parseTopicAction(currentTopic) # get/delete/update/delta
currentType = self._parseTopicType(currentTopic) # accepted/rejected/delta
payloadUTF8String = message.payload.decode('utf-8')
# get/delete/update: Need to deal with token, timer and unsubscribe
if currentAction in ["get", "delete", "update"]:
# Check for token
self._basicJSONParserHandler.setString(payloadUTF8String)
if self._basicJSONParserHandler.validateJSON(): # Filter out invalid JSON
currentToken = self._basicJSONParserHandler.getAttributeValue(u"clientToken")
if currentToken is not None:
self._logger.debug("shadow message clientToken: " + currentToken)
if currentToken is not None and currentToken in self._tokenPool.keys(): # Filter out JSON without the desired token
# Sync local version when it is an accepted response
self._logger.debug("Token is in the pool. Type: " + currentType)
if currentType == "accepted":
incomingVersion = self._basicJSONParserHandler.getAttributeValue(u"version")
# If it is get/update accepted response, we need to sync the local version
if incomingVersion is not None and incomingVersion > self._lastVersionInSync and currentAction != "delete":
self._lastVersionInSync = incomingVersion
# If it is a delete accepted, we need to reset the version
else:
self._lastVersionInSync = -1 # The version will always be synced for the next incoming delta/GU-accepted response
# Cancel the timer and clear the token
self._tokenPool[currentToken].cancel()
del self._tokenPool[currentToken]
# Need to unsubscribe?
self._shadowSubscribeStatusTable[currentAction] -= 1
if not self._isPersistentSubscribe and self._shadowSubscribeStatusTable.get(currentAction) <= 0:
self._shadowSubscribeStatusTable[currentAction] = 0
processNonPersistentUnsubscribe = Thread(target=self._doNonPersistentUnsubscribe, args=[currentAction])
processNonPersistentUnsubscribe.start()
# Custom callback
if self._shadowSubscribeCallbackTable.get(currentAction) is not None:
processCustomCallback = Thread(target=self._shadowSubscribeCallbackTable[currentAction], args=[payloadUTF8String, currentType, currentToken])
processCustomCallback.start()
# delta: Watch for version
else:
currentType += "/" + self._parseTopicShadowName(currentTopic)
# Sync local version
self._basicJSONParserHandler.setString(payloadUTF8String)
if self._basicJSONParserHandler.validateJSON(): # Filter out JSON without version
incomingVersion = self._basicJSONParserHandler.getAttributeValue(u"version")
if incomingVersion is not None and incomingVersion > self._lastVersionInSync:
self._lastVersionInSync = incomingVersion
# Custom callback
if self._shadowSubscribeCallbackTable.get(currentAction) is not None:
processCustomCallback = Thread(target=self._shadowSubscribeCallbackTable[currentAction], args=[payloadUTF8String, currentType, None])
processCustomCallback.start()
def _parseTopicAction(self, srcTopic):
ret = None
fragments = srcTopic.split('/')
if fragments[5] == "delta":
ret = "delta"
else:
ret = fragments[4]
return ret
def _parseTopicType(self, srcTopic):
fragments = srcTopic.split('/')
return fragments[5]
def _parseTopicShadowName(self, srcTopic):
fragments = srcTopic.split('/')
return fragments[2]
def _timerHandler(self, srcActionName, srcToken):
with self._dataStructureLock:
# Don't crash if we try to remove an unknown token
if srcToken not in self._tokenPool:
self._logger.warn('Tried to remove non-existent token from pool: %s' % str(srcToken))
return
# Remove the token
del self._tokenPool[srcToken]
# Need to unsubscribe?
self._shadowSubscribeStatusTable[srcActionName] -= 1
if not self._isPersistentSubscribe and self._shadowSubscribeStatusTable.get(srcActionName) <= 0:
self._shadowSubscribeStatusTable[srcActionName] = 0
self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, srcActionName)
# Notify time-out issue
if self._shadowSubscribeCallbackTable.get(srcActionName) is not None:
self._logger.info("Shadow request with token: " + str(srcToken) + " has timed out.")
self._shadowSubscribeCallbackTable[srcActionName]("REQUEST TIME OUT", "timeout", srcToken)
def shadowGet(self, srcCallback, srcTimeout):
"""
**Description**
Retrieve the device shadow JSON document from AWS IoT by publishing an empty JSON document to the
corresponding shadow topics. Shadow response topics will be subscribed to receive responses from
AWS IoT regarding the result of the get operation. Retrieved shadow JSON document will be available
in the registered callback. If no response is received within the provided timeout, a timeout
notification will be passed into the registered callback.
**Syntax**
.. code:: python
# Retrieve the shadow JSON document from AWS IoT, with a timeout set to 5 seconds
BotShadow.shadowGet(customCallback, 5)
**Parameters**
*srcCallback* - Function to be called when the response for this shadow request comes back. Should
be in form :code:`customCallback(payload, responseStatus, token)`, where :code:`payload` is the
JSON document returned, :code:`responseStatus` indicates whether the request has been accepted,
rejected or is a delta message, :code:`token` is the token used for tracing in this request.
*srcTimeout* - Timeout to determine whether the request is invalid. When a request gets timeout,
a timeout notification will be generated and put into the registered callback to notify users.
**Returns**
The token used for tracing in this shadow request.
"""
with self._dataStructureLock:
# Update callback data structure
self._shadowSubscribeCallbackTable["get"] = srcCallback
# Update number of pending feedback
self._shadowSubscribeStatusTable["get"] += 1
# clientToken
currentToken = self._tokenHandler.getNextToken()
self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["get", currentToken])
self._basicJSONParserHandler.setString("{}")
self._basicJSONParserHandler.validateJSON()
self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken)
currentPayload = self._basicJSONParserHandler.regenerateString()
# Two subscriptions
if not self._isPersistentSubscribe or not self._isGetSubscribed:
self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "get", self.generalCallback)
self._isGetSubscribed = True
self._logger.info("Subscribed to get accepted/rejected topics for deviceShadow: " + self._shadowName)
# One publish
self._shadowManagerHandler.basicShadowPublish(self._shadowName, "get", currentPayload)
# Start the timer
self._tokenPool[currentToken].start()
return currentToken
def shadowDelete(self, srcCallback, srcTimeout):
"""
**Description**
Delete the device shadow from AWS IoT by publishing an empty JSON document to the corresponding
shadow topics. Shadow response topics will be subscribed to receive responses from AWS IoT
regarding the result of the get operation. Responses will be available in the registered callback.
If no response is received within the provided timeout, a timeout notification will be passed into
the registered callback.
**Syntax**
.. code:: python
# Delete the device shadow from AWS IoT, with a timeout set to 5 seconds
BotShadow.shadowDelete(customCallback, 5)
**Parameters**
*srcCallback* - Function to be called when the response for this shadow request comes back. Should
be in form :code:`customCallback(payload, responseStatus, token)`, where :code:`payload` is the
JSON document returned, :code:`responseStatus` indicates whether the request has been accepted,
rejected or is a delta message, :code:`token` is the token used for tracing in this request.
*srcTimeout* - Timeout to determine whether the request is invalid. When a request gets timeout,
a timeout notification will be generated and put into the registered callback to notify users.
**Returns**
The token used for tracing in this shadow request.
"""
with self._dataStructureLock:
# Update callback data structure
self._shadowSubscribeCallbackTable["delete"] = srcCallback
# Update number of pending feedback
self._shadowSubscribeStatusTable["delete"] += 1
# clientToken
currentToken = self._tokenHandler.getNextToken()
self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["delete", currentToken])
self._basicJSONParserHandler.setString("{}")
self._basicJSONParserHandler.validateJSON()
self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken)
currentPayload = self._basicJSONParserHandler.regenerateString()
# Two subscriptions
if not self._isPersistentSubscribe or not self._isDeleteSubscribed:
self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "delete", self.generalCallback)
self._isDeleteSubscribed = True
self._logger.info("Subscribed to delete accepted/rejected topics for deviceShadow: " + self._shadowName)
# One publish
self._shadowManagerHandler.basicShadowPublish(self._shadowName, "delete", currentPayload)
# Start the timer
self._tokenPool[currentToken].start()
return currentToken
def shadowUpdate(self, srcJSONPayload, srcCallback, srcTimeout):
"""
**Description**
Update the device shadow JSON document string from AWS IoT by publishing the provided JSON
document to the corresponding shadow topics. Shadow response topics will be subscribed to
receive responses from AWS IoT regarding the result of the get operation. Response will be
available in the registered callback. If no response is received within the provided timeout,
a timeout notification will be passed into the registered callback.
**Syntax**
.. code:: python
# Update the shadow JSON document from AWS IoT, with a timeout set to 5 seconds
BotShadow.shadowUpdate(newShadowJSONDocumentString, customCallback, 5)
**Parameters**
*srcJSONPayload* - JSON document string used to update shadow JSON document in AWS IoT.
*srcCallback* - Function to be called when the response for this shadow request comes back. Should
be in form :code:`customCallback(payload, responseStatus, token)`, where :code:`payload` is the
JSON document returned, :code:`responseStatus` indicates whether the request has been accepted,
rejected or is a delta message, :code:`token` is the token used for tracing in this request.
*srcTimeout* - Timeout to determine whether the request is invalid. When a request gets timeout,
a timeout notification will be generated and put into the registered callback to notify users.
**Returns**
The token used for tracing in this shadow request.
"""
# Validate JSON
self._basicJSONParserHandler.setString(srcJSONPayload)
if self._basicJSONParserHandler.validateJSON():
with self._dataStructureLock:
# clientToken
currentToken = self._tokenHandler.getNextToken()
self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["update", currentToken])
self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken)
JSONPayloadWithToken = self._basicJSONParserHandler.regenerateString()
# Update callback data structure
self._shadowSubscribeCallbackTable["update"] = srcCallback
# Update number of pending feedback
self._shadowSubscribeStatusTable["update"] += 1
# Two subscriptions
if not self._isPersistentSubscribe or not self._isUpdateSubscribed:
self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "update", self.generalCallback)
self._isUpdateSubscribed = True
self._logger.info("Subscribed to update accepted/rejected topics for deviceShadow: " + self._shadowName)
# One publish
self._shadowManagerHandler.basicShadowPublish(self._shadowName, "update", JSONPayloadWithToken)
# Start the timer
self._tokenPool[currentToken].start()
else:
raise ValueError("Invalid JSON file.")
return currentToken
def shadowRegisterDeltaCallback(self, srcCallback):
"""
**Description**
Listen on delta topics for this device shadow by subscribing to delta topics. Whenever there
is a difference between the desired and reported state, the registered callback will be called
and the delta payload will be available in the callback.
**Syntax**
.. code:: python
# Listen on delta topics for BotShadow
BotShadow.shadowRegisterDeltaCallback(customCallback)
**Parameters**
*srcCallback* - Function to be called when the response for this shadow request comes back. Should
be in form :code:`customCallback(payload, responseStatus, token)`, where :code:`payload` is the
JSON document returned, :code:`responseStatus` indicates whether the request has been accepted,
rejected or is a delta message, :code:`token` is the token used for tracing in this request.
**Returns**
None
"""
with self._dataStructureLock:
# Update callback data structure
self._shadowSubscribeCallbackTable["delta"] = srcCallback
# One subscription
self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "delta", self.generalCallback)
self._logger.info("Subscribed to delta topic for deviceShadow: " + self._shadowName)
def shadowUnregisterDeltaCallback(self):
"""
**Description**
Cancel listening on delta topics for this device shadow by unsubscribing to delta topics. There will
be no delta messages received after this API call even though there is a difference between the
desired and reported state.
**Syntax**
.. code:: python
# Cancel listening on delta topics for BotShadow
BotShadow.shadowUnregisterDeltaCallback()
**Parameters**
None
**Returns**
None
"""
with self._dataStructureLock:
# Update callback data structure
del self._shadowSubscribeCallbackTable["delta"]
# One unsubscription
self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, "delta")
self._logger.info("Unsubscribed to delta topics for deviceShadow: " + self._shadowName)

View File

@@ -0,0 +1,83 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import logging
import time
from threading import Lock
class _shadowAction:
_actionType = ["get", "update", "delete", "delta"]
def __init__(self, srcShadowName, srcActionName):
if srcActionName is None or srcActionName not in self._actionType:
raise TypeError("Unsupported shadow action.")
self._shadowName = srcShadowName
self._actionName = srcActionName
self.isDelta = srcActionName == "delta"
if self.isDelta:
self._topicDelta = "$aws/things/" + str(self._shadowName) + "/shadow/update/delta"
else:
self._topicGeneral = "$aws/things/" + str(self._shadowName) + "/shadow/" + str(self._actionName)
self._topicAccept = "$aws/things/" + str(self._shadowName) + "/shadow/" + str(self._actionName) + "/accepted"
self._topicReject = "$aws/things/" + str(self._shadowName) + "/shadow/" + str(self._actionName) + "/rejected"
def getTopicGeneral(self):
return self._topicGeneral
def getTopicAccept(self):
return self._topicAccept
def getTopicReject(self):
return self._topicReject
def getTopicDelta(self):
return self._topicDelta
class shadowManager:
_logger = logging.getLogger(__name__)
def __init__(self, srcMQTTCore):
# Load in mqttCore
if srcMQTTCore is None:
raise TypeError("None type inputs detected.")
self._mqttCoreHandler = srcMQTTCore
self._shadowSubUnsubOperationLock = Lock()
def basicShadowPublish(self, srcShadowName, srcShadowAction, srcPayload):
currentShadowAction = _shadowAction(srcShadowName, srcShadowAction)
self._mqttCoreHandler.publish(currentShadowAction.getTopicGeneral(), srcPayload, 0, False)
def basicShadowSubscribe(self, srcShadowName, srcShadowAction, srcCallback):
with self._shadowSubUnsubOperationLock:
currentShadowAction = _shadowAction(srcShadowName, srcShadowAction)
if currentShadowAction.isDelta:
self._mqttCoreHandler.subscribe(currentShadowAction.getTopicDelta(), 0, srcCallback)
else:
self._mqttCoreHandler.subscribe(currentShadowAction.getTopicAccept(), 0, srcCallback)
self._mqttCoreHandler.subscribe(currentShadowAction.getTopicReject(), 0, srcCallback)
time.sleep(2)
def basicShadowUnsubscribe(self, srcShadowName, srcShadowAction):
with self._shadowSubUnsubOperationLock:
currentShadowAction = _shadowAction(srcShadowName, srcShadowAction)
if currentShadowAction.isDelta:
self._mqttCoreHandler.unsubscribe(currentShadowAction.getTopicDelta())
else:
self._logger.debug(currentShadowAction.getTopicAccept())
self._mqttCoreHandler.unsubscribe(currentShadowAction.getTopicAccept())
self._logger.debug(currentShadowAction.getTopicReject())
self._mqttCoreHandler.unsubscribe(currentShadowAction.getTopicReject())

View File

@@ -0,0 +1,19 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
class DropBehaviorTypes(object):
DROP_OLDEST = 0
DROP_NEWEST = 1

View File

@@ -0,0 +1,92 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
class CredentialsProvider(object):
def __init__(self):
self._ca_path = ""
def set_ca_path(self, ca_path):
self._ca_path = ca_path
def get_ca_path(self):
return self._ca_path
class CertificateCredentialsProvider(CredentialsProvider):
def __init__(self):
CredentialsProvider.__init__(self)
self._cert_path = ""
self._key_path = ""
def set_cert_path(self,cert_path):
self._cert_path = cert_path
def set_key_path(self, key_path):
self._key_path = key_path
def get_cert_path(self):
return self._cert_path
def get_key_path(self):
return self._key_path
class IAMCredentialsProvider(CredentialsProvider):
def __init__(self):
CredentialsProvider.__init__(self)
self._aws_access_key_id = ""
self._aws_secret_access_key = ""
self._aws_session_token = ""
def set_access_key_id(self, access_key_id):
self._aws_access_key_id = access_key_id
def set_secret_access_key(self, secret_access_key):
self._aws_secret_access_key = secret_access_key
def set_session_token(self, session_token):
self._aws_session_token = session_token
def get_access_key_id(self):
return self._aws_access_key_id
def get_secret_access_key(self):
return self._aws_secret_access_key
def get_session_token(self):
return self._aws_session_token
class EndpointProvider(object):
def __init__(self):
self._host = ""
self._port = -1
def set_host(self, host):
self._host = host
def set_port(self, port):
self._port = port
def get_host(self):
return self._host
def get_port(self):
return self._port

View File

@@ -0,0 +1,153 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import AWSIoTPythonSDK.exception.operationTimeoutException as operationTimeoutException
import AWSIoTPythonSDK.exception.operationError as operationError
# Serial Exception
class acceptTimeoutException(Exception):
def __init__(self, msg="Accept Timeout"):
self.message = msg
# MQTT Operation Timeout Exception
class connectTimeoutException(operationTimeoutException.operationTimeoutException):
def __init__(self, msg="Connect Timeout"):
self.message = msg
class disconnectTimeoutException(operationTimeoutException.operationTimeoutException):
def __init__(self, msg="Disconnect Timeout"):
self.message = msg
class publishTimeoutException(operationTimeoutException.operationTimeoutException):
def __init__(self, msg="Publish Timeout"):
self.message = msg
class subscribeTimeoutException(operationTimeoutException.operationTimeoutException):
def __init__(self, msg="Subscribe Timeout"):
self.message = msg
class unsubscribeTimeoutException(operationTimeoutException.operationTimeoutException):
def __init__(self, msg="Unsubscribe Timeout"):
self.message = msg
# MQTT Operation Error
class connectError(operationError.operationError):
def __init__(self, errorCode):
self.message = "Connect Error: " + str(errorCode)
class disconnectError(operationError.operationError):
def __init__(self, errorCode):
self.message = "Disconnect Error: " + str(errorCode)
class publishError(operationError.operationError):
def __init__(self, errorCode):
self.message = "Publish Error: " + str(errorCode)
class publishQueueFullException(operationError.operationError):
def __init__(self):
self.message = "Internal Publish Queue Full"
class publishQueueDisabledException(operationError.operationError):
def __init__(self):
self.message = "Offline publish request dropped because queueing is disabled"
class subscribeError(operationError.operationError):
def __init__(self, errorCode):
self.message = "Subscribe Error: " + str(errorCode)
class subscribeQueueFullException(operationError.operationError):
def __init__(self):
self.message = "Internal Subscribe Queue Full"
class subscribeQueueDisabledException(operationError.operationError):
def __init__(self):
self.message = "Offline subscribe request dropped because queueing is disabled"
class unsubscribeError(operationError.operationError):
def __init__(self, errorCode):
self.message = "Unsubscribe Error: " + str(errorCode)
class unsubscribeQueueFullException(operationError.operationError):
def __init__(self):
self.message = "Internal Unsubscribe Queue Full"
class unsubscribeQueueDisabledException(operationError.operationError):
def __init__(self):
self.message = "Offline unsubscribe request dropped because queueing is disabled"
# Websocket Error
class wssNoKeyInEnvironmentError(operationError.operationError):
def __init__(self):
self.message = "No AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY detected in $ENV."
class wssHandShakeError(operationError.operationError):
def __init__(self):
self.message = "Error in WSS handshake."
# Greengrass Discovery Error
class DiscoveryDataNotFoundException(operationError.operationError):
def __init__(self):
self.message = "No discovery data found"
class DiscoveryTimeoutException(operationTimeoutException.operationTimeoutException):
def __init__(self, message="Discovery request timed out"):
self.message = message
class DiscoveryInvalidRequestException(operationError.operationError):
def __init__(self):
self.message = "Invalid discovery request"
class DiscoveryUnauthorizedException(operationError.operationError):
def __init__(self):
self.message = "Discovery request not authorized"
class DiscoveryThrottlingException(operationError.operationError):
def __init__(self):
self.message = "Too many discovery requests"
class DiscoveryFailure(operationError.operationError):
def __init__(self, message):
self.message = message
# Client Error
class ClientError(Exception):
def __init__(self, message):
self.message = message

View File

@@ -0,0 +1,19 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
class operationError(Exception):
def __init__(self, msg="Operation Error"):
self.message = msg

View File

@@ -0,0 +1,19 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
class operationTimeoutException(Exception):
def __init__(self, msg="Operation Timeout"):
self.message = msg

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
__version__ = "1.4.8"

View File

@@ -0,0 +1,466 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import json
KEY_GROUP_LIST = "GGGroups"
KEY_GROUP_ID = "GGGroupId"
KEY_CORE_LIST = "Cores"
KEY_CORE_ARN = "thingArn"
KEY_CA_LIST = "CAs"
KEY_CONNECTIVITY_INFO_LIST = "Connectivity"
KEY_CONNECTIVITY_INFO_ID = "Id"
KEY_HOST_ADDRESS = "HostAddress"
KEY_PORT_NUMBER = "PortNumber"
KEY_METADATA = "Metadata"
class ConnectivityInfo(object):
"""
Class the stores one set of the connectivity information.
This is the data model for easy access to the discovery information from the discovery request function call. No
need to call directly from user scripts.
"""
def __init__(self, id, host, port, metadata):
self._id = id
self._host = host
self._port = port
self._metadata = metadata
@property
def id(self):
"""
Connectivity Information Id.
"""
return self._id
@property
def host(self):
"""
Host address.
"""
return self._host
@property
def port(self):
"""
Port number.
"""
return self._port
@property
def metadata(self):
"""
Metadata string.
"""
return self._metadata
class CoreConnectivityInfo(object):
"""
Class that stores the connectivity information for a Greengrass core.
This is the data model for easy access to the discovery information from the discovery request function call. No
need to call directly from user scripts.
"""
def __init__(self, coreThingArn, groupId):
self._core_thing_arn = coreThingArn
self._group_id = groupId
self._connectivity_info_dict = dict()
@property
def coreThingArn(self):
"""
Thing arn for this Greengrass core.
"""
return self._core_thing_arn
@property
def groupId(self):
"""
Greengrass group id that this Greengrass core belongs to.
"""
return self._group_id
@property
def connectivityInfoList(self):
"""
The list of connectivity information that this Greengrass core has.
"""
return list(self._connectivity_info_dict.values())
def getConnectivityInfo(self, id):
"""
**Description**
Used for quickly accessing a certain set of connectivity information by id.
**Syntax**
.. code:: python
myCoreConnectivityInfo.getConnectivityInfo("CoolId")
**Parameters**
*id* - The id for the desired connectivity information.
**Return**
:code:`AWSIoTPythonSDK.core.greengrass.discovery.models.ConnectivityInfo` object.
"""
return self._connectivity_info_dict.get(id)
def appendConnectivityInfo(self, connectivityInfo):
"""
**Description**
Used for adding a new set of connectivity information to the list for this Greengrass core. This is used by the
SDK internally. No need to call directly from user scripts.
**Syntax**
.. code:: python
myCoreConnectivityInfo.appendConnectivityInfo(newInfo)
**Parameters**
*connectivityInfo* - :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.ConnectivityInfo` object.
**Returns**
None
"""
self._connectivity_info_dict[connectivityInfo.id] = connectivityInfo
class GroupConnectivityInfo(object):
"""
Class that stores the connectivity information for a specific Greengrass group.
This is the data model for easy access to the discovery information from the discovery request function call. No
need to call directly from user scripts.
"""
def __init__(self, groupId):
self._group_id = groupId
self._core_connectivity_info_dict = dict()
self._ca_list = list()
@property
def groupId(self):
"""
Id for this Greengrass group.
"""
return self._group_id
@property
def coreConnectivityInfoList(self):
"""
A list of Greengrass cores
(:code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo` object) that belong to this
Greengrass group.
"""
return list(self._core_connectivity_info_dict.values())
@property
def caList(self):
"""
A list of CA content strings for this Greengrass group.
"""
return self._ca_list
def getCoreConnectivityInfo(self, coreThingArn):
"""
**Description**
Used to retrieve the corresponding :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo`
object by core thing arn.
**Syntax**
.. code:: python
myGroupConnectivityInfo.getCoreConnectivityInfo("YourOwnArnString")
**Parameters**
coreThingArn - Thing arn for the desired Greengrass core.
**Returns**
:code:`AWSIoTPythonSDK.core.greengrass.discovery.CoreConnectivityInfo` object.
"""
return self._core_connectivity_info_dict.get(coreThingArn)
def appendCoreConnectivityInfo(self, coreConnectivityInfo):
"""
**Description**
Used to append new core connectivity information to this group connectivity information. This is used by the
SDK internally. No need to call directly from user scripts.
**Syntax**
.. code:: python
myGroupConnectivityInfo.appendCoreConnectivityInfo(newCoreConnectivityInfo)
**Parameters**
*coreConnectivityInfo* - :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo` object.
**Returns**
None
"""
self._core_connectivity_info_dict[coreConnectivityInfo.coreThingArn] = coreConnectivityInfo
def appendCa(self, ca):
"""
**Description**
Used to append new CA content string to this group connectivity information. This is used by the SDK internally.
No need to call directly from user scripts.
**Syntax**
.. code:: python
myGroupConnectivityInfo.appendCa("CaContentString")
**Parameters**
*ca* - Group CA content string.
**Returns**
None
"""
self._ca_list.append(ca)
class DiscoveryInfo(object):
"""
Class that stores the discovery information coming back from the discovery request.
This is the data model for easy access to the discovery information from the discovery request function call. No
need to call directly from user scripts.
"""
def __init__(self, rawJson):
self._raw_json = rawJson
@property
def rawJson(self):
"""
JSON response string that contains the discovery information. This is reserved in case users want to do
some process by themselves.
"""
return self._raw_json
def getAllCores(self):
"""
**Description**
Used to retrieve the list of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo`
object for this discovery information. The retrieved cores could be from different Greengrass groups. This is
designed for uses who want to iterate through all available cores at the same time, regardless of which group
those cores are in.
**Syntax**
.. code:: python
myDiscoveryInfo.getAllCores()
**Parameters**
None
**Returns**
List of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivtyInfo` object.
"""
groups_list = self.getAllGroups()
core_list = list()
for group in groups_list:
core_list.extend(group.coreConnectivityInfoList)
return core_list
def getAllCas(self):
"""
**Description**
Used to retrieve the list of :code:`(groupId, caContent)` pair for this discovery information. The retrieved
pairs could be from different Greengrass groups. This is designed for users who want to iterate through all
available cores/groups/CAs at the same time, regardless of which group those CAs belong to.
**Syntax**
.. code:: python
myDiscoveryInfo.getAllCas()
**Parameters**
None
**Returns**
List of :code:`(groupId, caContent)` string pair, where :code:`caContent` is the CA content string and
:code:`groupId` is the group id that this CA belongs to.
"""
group_list = self.getAllGroups()
ca_list = list()
for group in group_list:
for ca in group.caList:
ca_list.append((group.groupId, ca))
return ca_list
def getAllGroups(self):
"""
**Description**
Used to retrieve the list of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.GroupConnectivityInfo`
object for this discovery information. This is designed for users who want to iterate through all available
groups that this Greengrass aware device (GGAD) belongs to.
**Syntax**
.. code:: python
myDiscoveryInfo.getAllGroups()
**Parameters**
None
**Returns**
List of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.GroupConnectivityInfo` object.
"""
groups_dict = self.toObjectAtGroupLevel()
return list(groups_dict.values())
def toObjectAtGroupLevel(self):
"""
**Description**
Used to get a dictionary of Greengrass group discovery information, with group id string as key and the
corresponding :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.GroupConnectivityInfo` object as the
value. This is designed for users who know exactly which group, which core and which set of connectivity info
they want to use for the Greengrass aware device to connect.
**Syntax**
.. code:: python
# Get to the targeted connectivity information for a specific core in a specific group
groupLevelDiscoveryInfoObj = myDiscoveryInfo.toObjectAtGroupLevel()
groupConnectivityInfoObj = groupLevelDiscoveryInfoObj.toObjectAtGroupLevel("IKnowMyGroupId")
coreConnectivityInfoObj = groupConnectivityInfoObj.getCoreConnectivityInfo("IKnowMyCoreThingArn")
connectivityInfo = coreConnectivityInfoObj.getConnectivityInfo("IKnowMyConnectivityInfoSetId")
# Now retrieve the detailed information
caList = groupConnectivityInfoObj.caList
host = connectivityInfo.host
port = connectivityInfo.port
metadata = connectivityInfo.metadata
# Actual connecting logic follows...
"""
groups_object = json.loads(self._raw_json)
groups_dict = dict()
for group_object in groups_object[KEY_GROUP_LIST]:
group_info = self._decode_group_info(group_object)
groups_dict[group_info.groupId] = group_info
return groups_dict
def _decode_group_info(self, group_object):
group_id = group_object[KEY_GROUP_ID]
group_info = GroupConnectivityInfo(group_id)
for core in group_object[KEY_CORE_LIST]:
core_info = self._decode_core_info(core, group_id)
group_info.appendCoreConnectivityInfo(core_info)
for ca in group_object[KEY_CA_LIST]:
group_info.appendCa(ca)
return group_info
def _decode_core_info(self, core_object, group_id):
core_info = CoreConnectivityInfo(core_object[KEY_CORE_ARN], group_id)
for connectivity_info_object in core_object[KEY_CONNECTIVITY_INFO_LIST]:
connectivity_info = ConnectivityInfo(connectivity_info_object[KEY_CONNECTIVITY_INFO_ID],
connectivity_info_object[KEY_HOST_ADDRESS],
connectivity_info_object[KEY_PORT_NUMBER],
connectivity_info_object.get(KEY_METADATA,''))
core_info.appendConnectivityInfo(connectivity_info)
return core_info

View File

@@ -0,0 +1,426 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryInvalidRequestException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryUnauthorizedException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryDataNotFoundException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryThrottlingException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryFailure
from AWSIoTPythonSDK.core.greengrass.discovery.models import DiscoveryInfo
from AWSIoTPythonSDK.core.protocol.connection.alpn import SSLContextBuilder
import re
import sys
import ssl
import time
import errno
import logging
import socket
import platform
if platform.system() == 'Windows':
EAGAIN = errno.WSAEWOULDBLOCK
else:
EAGAIN = errno.EAGAIN
class DiscoveryInfoProvider(object):
REQUEST_TYPE_PREFIX = "GET "
PAYLOAD_PREFIX = "/greengrass/discover/thing/"
PAYLOAD_SUFFIX = " HTTP/1.1\r\n" # Space in the front
HOST_PREFIX = "Host: "
HOST_SUFFIX = "\r\n\r\n"
HTTP_PROTOCOL = r"HTTP/1.1 "
CONTENT_LENGTH = r"content-length: "
CONTENT_LENGTH_PATTERN = CONTENT_LENGTH + r"([0-9]+)\r\n"
HTTP_RESPONSE_CODE_PATTERN = HTTP_PROTOCOL + r"([0-9]+) "
HTTP_SC_200 = "200"
HTTP_SC_400 = "400"
HTTP_SC_401 = "401"
HTTP_SC_404 = "404"
HTTP_SC_429 = "429"
LOW_LEVEL_RC_COMPLETE = 0
LOW_LEVEL_RC_TIMEOUT = -1
_logger = logging.getLogger(__name__)
def __init__(self, caPath="", certPath="", keyPath="", host="", port=8443, timeoutSec=120):
"""
The class that provides functionality to perform a Greengrass discovery process to the cloud.
Users can perform Greengrass discovery process for a specific Greengrass aware device to retrieve
connectivity/identity information of Greengrass cores within the same group.
**Syntax**
.. code:: python
from AWSIoTPythonSDK.core.greengrass.discovery.providers import DiscoveryInfoProvider
# Create a discovery information provider
myDiscoveryInfoProvider = DiscoveryInfoProvider()
# Create a discovery information provider with custom configuration
myDiscoveryInfoProvider = DiscoveryInfoProvider(caPath=myCAPath, certPath=myCertPath, keyPath=myKeyPath, host=myHost, timeoutSec=myTimeoutSec)
**Parameters**
*caPath* - Path to read the root CA file.
*certPath* - Path to read the certificate file.
*keyPath* - Path to read the private key file.
*host* - String that denotes the host name of the user-specific AWS IoT endpoint.
*port* - Integer that denotes the port number to connect to. For discovery purpose, it is 8443 by default.
*timeoutSec* - Time out configuration in seconds to consider a discovery request sending/response waiting has
been timed out.
**Returns**
AWSIoTPythonSDK.core.greengrass.discovery.providers.DiscoveryInfoProvider object
"""
self._ca_path = caPath
self._cert_path = certPath
self._key_path = keyPath
self._host = host
self._port = port
self._timeout_sec = timeoutSec
self._expected_exception_map = {
self.HTTP_SC_400 : DiscoveryInvalidRequestException(),
self.HTTP_SC_401 : DiscoveryUnauthorizedException(),
self.HTTP_SC_404 : DiscoveryDataNotFoundException(),
self.HTTP_SC_429 : DiscoveryThrottlingException()
}
def configureEndpoint(self, host, port=8443):
"""
**Description**
Used to configure the host address and port number for the discovery request to hit. Should be called before
the discovery request happens.
**Syntax**
.. code:: python
# Using default port configuration, 8443
myDiscoveryInfoProvider.configureEndpoint(host="prefix.iot.us-east-1.amazonaws.com")
# Customize port configuration
myDiscoveryInfoProvider.configureEndpoint(host="prefix.iot.us-east-1.amazonaws.com", port=8888)
**Parameters**
*host* - String that denotes the host name of the user-specific AWS IoT endpoint.
*port* - Integer that denotes the port number to connect to. For discovery purpose, it is 8443 by default.
**Returns**
None
"""
self._host = host
self._port = port
def configureCredentials(self, caPath, certPath, keyPath):
"""
**Description**
Used to configure the credentials for discovery request. Should be called before the discovery request happens.
**Syntax**
.. code:: python
myDiscoveryInfoProvider.configureCredentials("my/ca/path", "my/cert/path", "my/key/path")
**Parameters**
*caPath* - Path to read the root CA file.
*certPath* - Path to read the certificate file.
*keyPath* - Path to read the private key file.
**Returns**
None
"""
self._ca_path = caPath
self._cert_path = certPath
self._key_path = keyPath
def configureTimeout(self, timeoutSec):
"""
**Description**
Used to configure the time out in seconds for discovery request sending/response waiting. Should be called before
the discovery request happens.
**Syntax**
.. code:: python
# Configure the time out for discovery to be 10 seconds
myDiscoveryInfoProvider.configureTimeout(10)
**Parameters**
*timeoutSec* - Time out configuration in seconds to consider a discovery request sending/response waiting has
been timed out.
**Returns**
None
"""
self._timeout_sec = timeoutSec
def discover(self, thingName):
"""
**Description**
Perform the discovery request for the given Greengrass aware device thing name.
**Syntax**
.. code:: python
myDiscoveryInfoProvider.discover(thingName="myGGAD")
**Parameters**
*thingName* - Greengrass aware device thing name.
**Returns**
:code:`AWSIoTPythonSDK.core.greengrass.discovery.models.DiscoveryInfo` object.
"""
self._logger.info("Starting discover request...")
self._logger.info("Endpoint: " + self._host + ":" + str(self._port))
self._logger.info("Target thing: " + thingName)
sock = self._create_tcp_connection()
ssl_sock = self._create_ssl_connection(sock)
self._raise_on_timeout(self._send_discovery_request(ssl_sock, thingName))
status_code, response_body = self._receive_discovery_response(ssl_sock)
return self._raise_if_not_200(status_code, response_body)
def _create_tcp_connection(self):
self._logger.debug("Creating tcp connection...")
try:
if (sys.version_info[0] == 2 and sys.version_info[1] < 7) or (sys.version_info[0] == 3 and sys.version_info[1] < 2):
sock = socket.create_connection((self._host, self._port))
else:
sock = socket.create_connection((self._host, self._port), source_address=("", 0))
return sock
except socket.error as err:
if err.errno != errno.EINPROGRESS and err.errno != errno.EWOULDBLOCK and err.errno != EAGAIN:
raise
self._logger.debug("Created tcp connection.")
def _create_ssl_connection(self, sock):
self._logger.debug("Creating ssl connection...")
ssl_protocol_version = ssl.PROTOCOL_SSLv23
if self._port == 443:
ssl_context = SSLContextBuilder()\
.with_ca_certs(self._ca_path)\
.with_cert_key_pair(self._cert_path, self._key_path)\
.with_cert_reqs(ssl.CERT_REQUIRED)\
.with_check_hostname(True)\
.with_ciphers(None)\
.with_alpn_protocols(['x-amzn-http-ca'])\
.build()
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=self._host, do_handshake_on_connect=False)
ssl_sock.do_handshake()
else:
ssl_sock = ssl.wrap_socket(sock,
certfile=self._cert_path,
keyfile=self._key_path,
ca_certs=self._ca_path,
cert_reqs=ssl.CERT_REQUIRED,
ssl_version=ssl_protocol_version)
self._logger.debug("Matching host name...")
if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 2):
self._tls_match_hostname(ssl_sock)
else:
ssl.match_hostname(ssl_sock.getpeercert(), self._host)
return ssl_sock
def _tls_match_hostname(self, ssl_sock):
try:
cert = ssl_sock.getpeercert()
except AttributeError:
# the getpeercert can throw Attribute error: object has no attribute 'peer_certificate'
# Don't let that crash the whole client. See also: http://bugs.python.org/issue13721
raise ssl.SSLError('Not connected')
san = cert.get('subjectAltName')
if san:
have_san_dns = False
for (key, value) in san:
if key == 'DNS':
have_san_dns = True
if self._host_matches_cert(self._host.lower(), value.lower()) == True:
return
if key == 'IP Address':
have_san_dns = True
if value.lower() == self._host.lower():
return
if have_san_dns:
# Only check subject if subjectAltName dns not found.
raise ssl.SSLError('Certificate subject does not match remote hostname.')
subject = cert.get('subject')
if subject:
for ((key, value),) in subject:
if key == 'commonName':
if self._host_matches_cert(self._host.lower(), value.lower()) == True:
return
raise ssl.SSLError('Certificate subject does not match remote hostname.')
def _host_matches_cert(self, host, cert_host):
if cert_host[0:2] == "*.":
if cert_host.count("*") != 1:
return False
host_match = host.split(".", 1)[1]
cert_match = cert_host.split(".", 1)[1]
if host_match == cert_match:
return True
else:
return False
else:
if host == cert_host:
return True
else:
return False
def _send_discovery_request(self, ssl_sock, thing_name):
request = self.REQUEST_TYPE_PREFIX + \
self.PAYLOAD_PREFIX + \
thing_name + \
self.PAYLOAD_SUFFIX + \
self.HOST_PREFIX + \
self._host + ":" + str(self._port) + \
self.HOST_SUFFIX
self._logger.debug("Sending discover request: " + request)
start_time = time.time()
desired_length_to_write = len(request)
actual_length_written = 0
while True:
try:
length_written = ssl_sock.write(request.encode("utf-8"))
actual_length_written += length_written
except socket.error as err:
if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE:
pass
if actual_length_written == desired_length_to_write:
return self.LOW_LEVEL_RC_COMPLETE
if start_time + self._timeout_sec < time.time():
return self.LOW_LEVEL_RC_TIMEOUT
def _receive_discovery_response(self, ssl_sock):
self._logger.debug("Receiving discover response header...")
rc1, response_header = self._receive_until(ssl_sock, self._got_two_crlfs)
status_code, body_length = self._handle_discovery_response_header(rc1, response_header.decode("utf-8"))
self._logger.debug("Receiving discover response body...")
rc2, response_body = self._receive_until(ssl_sock, self._got_enough_bytes, body_length)
response_body = self._handle_discovery_response_body(rc2, response_body.decode("utf-8"))
return status_code, response_body
def _receive_until(self, ssl_sock, criteria_function, extra_data=None):
start_time = time.time()
response = bytearray()
number_bytes_read = 0
while True: # Python does not have do-while
try:
response.append(self._convert_to_int_py3(ssl_sock.read(1)))
number_bytes_read += 1
except socket.error as err:
if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE:
pass
if criteria_function((number_bytes_read, response, extra_data)):
return self.LOW_LEVEL_RC_COMPLETE, response
if start_time + self._timeout_sec < time.time():
return self.LOW_LEVEL_RC_TIMEOUT, response
def _convert_to_int_py3(self, input_char):
try:
return ord(input_char)
except:
return input_char
def _got_enough_bytes(self, data):
number_bytes_read, response, target_length = data
return number_bytes_read == int(target_length)
def _got_two_crlfs(self, data):
number_bytes_read, response, extra_data_unused = data
number_of_crlf = 2
has_enough_bytes = number_bytes_read > number_of_crlf * 2 - 1
if has_enough_bytes:
end_of_received = response[number_bytes_read - number_of_crlf * 2 : number_bytes_read]
expected_end_of_response = b"\r\n" * number_of_crlf
return end_of_received == expected_end_of_response
else:
return False
def _handle_discovery_response_header(self, rc, response):
self._raise_on_timeout(rc)
http_status_code_matcher = re.compile(self.HTTP_RESPONSE_CODE_PATTERN)
http_status_code_matched_groups = http_status_code_matcher.match(response)
content_length_matcher = re.compile(self.CONTENT_LENGTH_PATTERN)
content_length_matched_groups = content_length_matcher.search(response)
return http_status_code_matched_groups.group(1), content_length_matched_groups.group(1)
def _handle_discovery_response_body(self, rc, response):
self._raise_on_timeout(rc)
return response
def _raise_on_timeout(self, rc):
if rc == self.LOW_LEVEL_RC_TIMEOUT:
raise DiscoveryTimeoutException()
def _raise_if_not_200(self, status_code, response_body): # response_body here is str in Py3
if status_code != self.HTTP_SC_200:
expected_exception = self._expected_exception_map.get(status_code)
if expected_exception:
raise expected_exception
else:
raise DiscoveryFailure(response_body)
return DiscoveryInfo(response_body)

View File

@@ -0,0 +1,156 @@
# /*
# * Copyright 2010-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import json
_BASE_THINGS_TOPIC = "$aws/things/"
_NOTIFY_OPERATION = "notify"
_NOTIFY_NEXT_OPERATION = "notify-next"
_GET_OPERATION = "get"
_START_NEXT_OPERATION = "start-next"
_WILDCARD_OPERATION = "+"
_UPDATE_OPERATION = "update"
_ACCEPTED_REPLY = "accepted"
_REJECTED_REPLY = "rejected"
_WILDCARD_REPLY = "#"
#Members of this enum are tuples
_JOB_ID_REQUIRED_INDEX = 1
_JOB_OPERATION_INDEX = 2
_STATUS_KEY = 'status'
_STATUS_DETAILS_KEY = 'statusDetails'
_EXPECTED_VERSION_KEY = 'expectedVersion'
_EXEXCUTION_NUMBER_KEY = 'executionNumber'
_INCLUDE_JOB_EXECUTION_STATE_KEY = 'includeJobExecutionState'
_INCLUDE_JOB_DOCUMENT_KEY = 'includeJobDocument'
_CLIENT_TOKEN_KEY = 'clientToken'
_STEP_TIMEOUT_IN_MINUTES_KEY = 'stepTimeoutInMinutes'
#The type of job topic.
class jobExecutionTopicType(object):
JOB_UNRECOGNIZED_TOPIC = (0, False, '')
JOB_GET_PENDING_TOPIC = (1, False, _GET_OPERATION)
JOB_START_NEXT_TOPIC = (2, False, _START_NEXT_OPERATION)
JOB_DESCRIBE_TOPIC = (3, True, _GET_OPERATION)
JOB_UPDATE_TOPIC = (4, True, _UPDATE_OPERATION)
JOB_NOTIFY_TOPIC = (5, False, _NOTIFY_OPERATION)
JOB_NOTIFY_NEXT_TOPIC = (6, False, _NOTIFY_NEXT_OPERATION)
JOB_WILDCARD_TOPIC = (7, False, _WILDCARD_OPERATION)
#Members of this enum are tuples
_JOB_SUFFIX_INDEX = 1
#The type of reply topic, or #JOB_REQUEST_TYPE for topics that are not replies.
class jobExecutionTopicReplyType(object):
JOB_UNRECOGNIZED_TOPIC_TYPE = (0, '')
JOB_REQUEST_TYPE = (1, '')
JOB_ACCEPTED_REPLY_TYPE = (2, '/' + _ACCEPTED_REPLY)
JOB_REJECTED_REPLY_TYPE = (3, '/' + _REJECTED_REPLY)
JOB_WILDCARD_REPLY_TYPE = (4, '/' + _WILDCARD_REPLY)
_JOB_STATUS_INDEX = 1
class jobExecutionStatus(object):
JOB_EXECUTION_STATUS_NOT_SET = (0, None)
JOB_EXECUTION_QUEUED = (1, 'QUEUED')
JOB_EXECUTION_IN_PROGRESS = (2, 'IN_PROGRESS')
JOB_EXECUTION_FAILED = (3, 'FAILED')
JOB_EXECUTION_SUCCEEDED = (4, 'SUCCEEDED')
JOB_EXECUTION_CANCELED = (5, 'CANCELED')
JOB_EXECUTION_REJECTED = (6, 'REJECTED')
JOB_EXECUTION_UNKNOWN_STATUS = (99, None)
def _getExecutionStatus(jobStatus):
try:
return jobStatus[_JOB_STATUS_INDEX]
except KeyError:
return None
def _isWithoutJobIdTopicType(srcJobExecTopicType):
return (srcJobExecTopicType == jobExecutionTopicType.JOB_GET_PENDING_TOPIC or srcJobExecTopicType == jobExecutionTopicType.JOB_START_NEXT_TOPIC
or srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_TOPIC or srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC)
class thingJobManager:
def __init__(self, thingName, clientToken = None):
self._thingName = thingName
self._clientToken = clientToken
def getJobTopic(self, srcJobExecTopicType, srcJobExecTopicReplyType=jobExecutionTopicReplyType.JOB_REQUEST_TYPE, jobId=None):
if self._thingName is None:
return None
#Verify topics that only support request type, actually have request type specified for reply
if (srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_TOPIC or srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) and srcJobExecTopicReplyType != jobExecutionTopicReplyType.JOB_REQUEST_TYPE:
return None
#Verify topics that explicitly do not want a job ID do not have one specified
if (jobId is not None and _isWithoutJobIdTopicType(srcJobExecTopicType)):
return None
#Verify job ID is present if the topic requires one
if jobId is None and srcJobExecTopicType[_JOB_ID_REQUIRED_INDEX]:
return None
#Ensure the job operation is a non-empty string
if srcJobExecTopicType[_JOB_OPERATION_INDEX] == '':
return None
if srcJobExecTopicType[_JOB_ID_REQUIRED_INDEX]:
return '{0}{1}/jobs/{2}/{3}{4}'.format(_BASE_THINGS_TOPIC, self._thingName, str(jobId), srcJobExecTopicType[_JOB_OPERATION_INDEX], srcJobExecTopicReplyType[_JOB_SUFFIX_INDEX])
elif srcJobExecTopicType == jobExecutionTopicType.JOB_WILDCARD_TOPIC:
return '{0}{1}/jobs/#'.format(_BASE_THINGS_TOPIC, self._thingName)
else:
return '{0}{1}/jobs/{2}{3}'.format(_BASE_THINGS_TOPIC, self._thingName, srcJobExecTopicType[_JOB_OPERATION_INDEX], srcJobExecTopicReplyType[_JOB_SUFFIX_INDEX])
def serializeJobExecutionUpdatePayload(self, status, statusDetails=None, expectedVersion=0, executionNumber=0, includeJobExecutionState=False, includeJobDocument=False, stepTimeoutInMinutes=None):
executionStatus = _getExecutionStatus(status)
if executionStatus is None:
return None
payload = {_STATUS_KEY: executionStatus}
if statusDetails:
payload[_STATUS_DETAILS_KEY] = statusDetails
if expectedVersion > 0:
payload[_EXPECTED_VERSION_KEY] = str(expectedVersion)
if executionNumber > 0:
payload[_EXEXCUTION_NUMBER_KEY] = str(executionNumber)
if includeJobExecutionState:
payload[_INCLUDE_JOB_EXECUTION_STATE_KEY] = True
if includeJobDocument:
payload[_INCLUDE_JOB_DOCUMENT_KEY] = True
if self._clientToken is not None:
payload[_CLIENT_TOKEN_KEY] = self._clientToken
if stepTimeoutInMinutes is not None:
payload[_STEP_TIMEOUT_IN_MINUTES_KEY] = stepTimeoutInMinutes
return json.dumps(payload)
def serializeDescribeJobExecutionPayload(self, executionNumber=0, includeJobDocument=True):
payload = {_INCLUDE_JOB_DOCUMENT_KEY: includeJobDocument}
if executionNumber > 0:
payload[_EXEXCUTION_NUMBER_KEY] = executionNumber
if self._clientToken is not None:
payload[_CLIENT_TOKEN_KEY] = self._clientToken
return json.dumps(payload)
def serializeStartNextPendingJobExecutionPayload(self, statusDetails=None, stepTimeoutInMinutes=None):
payload = {}
if self._clientToken is not None:
payload[_CLIENT_TOKEN_KEY] = self._clientToken
if statusDetails is not None:
payload[_STATUS_DETAILS_KEY] = statusDetails
if stepTimeoutInMinutes is not None:
payload[_STEP_TIMEOUT_IN_MINUTES_KEY] = stepTimeoutInMinutes
return json.dumps(payload)
def serializeClientTokenPayload(self):
return json.dumps({_CLIENT_TOKEN_KEY: self._clientToken}) if self._clientToken is not None else '{}'

View File

@@ -0,0 +1,63 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
try:
import ssl
except:
ssl = None
class SSLContextBuilder(object):
def __init__(self):
self.check_supportability()
self._ssl_context = ssl.create_default_context()
def check_supportability(self):
if ssl is None:
raise RuntimeError("This platform has no SSL/TLS.")
if not hasattr(ssl, "SSLContext"):
raise NotImplementedError("This platform does not support SSLContext. Python 2.7.10+/3.5+ is required.")
if not hasattr(ssl.SSLContext, "set_alpn_protocols"):
raise NotImplementedError("This platform does not support ALPN as TLS extensions. Python 2.7.10+/3.5+ is required.")
def with_ca_certs(self, ca_certs):
self._ssl_context.load_verify_locations(ca_certs)
return self
def with_cert_key_pair(self, cert_file, key_file):
self._ssl_context.load_cert_chain(cert_file, key_file)
return self
def with_cert_reqs(self, cert_reqs):
self._ssl_context.verify_mode = cert_reqs
return self
def with_check_hostname(self, check_hostname):
self._ssl_context.check_hostname = check_hostname
return self
def with_ciphers(self, ciphers):
if ciphers is not None:
self._ssl_context.set_ciphers(ciphers) # set_ciphers() does not allow None input. Use default (do nothing) if None
return self
def with_alpn_protocols(self, alpn_protocols):
self._ssl_context.set_alpn_protocols(alpn_protocols)
return self
def build(self):
return self._ssl_context

View File

@@ -0,0 +1,699 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
# This class implements the progressive backoff logic for auto-reconnect.
# It manages the reconnect wait time for the current reconnect, controling
# when to increase it and when to reset it.
import re
import sys
import ssl
import errno
import struct
import socket
import base64
import time
import threading
import logging
import os
from datetime import datetime
import hashlib
import hmac
from AWSIoTPythonSDK.exception.AWSIoTExceptions import ClientError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import wssNoKeyInEnvironmentError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import wssHandShakeError
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
try:
from urllib.parse import quote # Python 3+
except ImportError:
from urllib import quote
# INI config file handling
try:
from configparser import ConfigParser # Python 3+
from configparser import NoOptionError
from configparser import NoSectionError
except ImportError:
from ConfigParser import ConfigParser
from ConfigParser import NoOptionError
from ConfigParser import NoSectionError
class ProgressiveBackOffCore:
# Logger
_logger = logging.getLogger(__name__)
def __init__(self, srcBaseReconnectTimeSecond=1, srcMaximumReconnectTimeSecond=32, srcMinimumConnectTimeSecond=20):
# The base reconnection time in seconds, default 1
self._baseReconnectTimeSecond = srcBaseReconnectTimeSecond
# The maximum reconnection time in seconds, default 32
self._maximumReconnectTimeSecond = srcMaximumReconnectTimeSecond
# The minimum time in milliseconds that a connection must be maintained in order to be considered stable
# Default 20
self._minimumConnectTimeSecond = srcMinimumConnectTimeSecond
# Current backOff time in seconds, init to equal to 0
self._currentBackoffTimeSecond = 1
# Handler for timer
self._resetBackoffTimer = None
# For custom progressiveBackoff timing configuration
def configTime(self, srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond):
if srcBaseReconnectTimeSecond < 0 or srcMaximumReconnectTimeSecond < 0 or srcMinimumConnectTimeSecond < 0:
self._logger.error("init: Negative time configuration detected.")
raise ValueError("Negative time configuration detected.")
if srcBaseReconnectTimeSecond >= srcMinimumConnectTimeSecond:
self._logger.error("init: Min connect time should be bigger than base reconnect time.")
raise ValueError("Min connect time should be bigger than base reconnect time.")
self._baseReconnectTimeSecond = srcBaseReconnectTimeSecond
self._maximumReconnectTimeSecond = srcMaximumReconnectTimeSecond
self._minimumConnectTimeSecond = srcMinimumConnectTimeSecond
self._currentBackoffTimeSecond = 1
# Block the reconnect logic for _currentBackoffTimeSecond
# Update the currentBackoffTimeSecond for the next reconnect
# Cancel the in-waiting timer for resetting backOff time
# This should get called only when a disconnect/reconnect happens
def backOff(self):
self._logger.debug("backOff: current backoff time is: " + str(self._currentBackoffTimeSecond) + " sec.")
if self._resetBackoffTimer is not None:
# Cancel the timer
self._resetBackoffTimer.cancel()
# Block the reconnect logic
time.sleep(self._currentBackoffTimeSecond)
# Update the backoff time
if self._currentBackoffTimeSecond == 0:
# This is the first attempt to connect, set it to base
self._currentBackoffTimeSecond = self._baseReconnectTimeSecond
else:
# r_cur = min(2^n*r_base, r_max)
self._currentBackoffTimeSecond = min(self._maximumReconnectTimeSecond, self._currentBackoffTimeSecond * 2)
# Start the timer for resetting _currentBackoffTimeSecond
# Will be cancelled upon calling backOff
def startStableConnectionTimer(self):
self._resetBackoffTimer = threading.Timer(self._minimumConnectTimeSecond,
self._connectionStableThenResetBackoffTime)
self._resetBackoffTimer.start()
def stopStableConnectionTimer(self):
if self._resetBackoffTimer is not None:
# Cancel the timer
self._resetBackoffTimer.cancel()
# Timer callback to reset _currentBackoffTimeSecond
# If the connection is stable for longer than _minimumConnectTimeSecond,
# reset the currentBackoffTimeSecond to _baseReconnectTimeSecond
def _connectionStableThenResetBackoffTime(self):
self._logger.debug(
"stableConnection: Resetting the backoff time to: " + str(self._baseReconnectTimeSecond) + " sec.")
self._currentBackoffTimeSecond = self._baseReconnectTimeSecond
class SigV4Core:
_logger = logging.getLogger(__name__)
def __init__(self):
self._aws_access_key_id = ""
self._aws_secret_access_key = ""
self._aws_session_token = ""
self._credentialConfigFilePath = "~/.aws/credentials"
def setIAMCredentials(self, srcAWSAccessKeyID, srcAWSSecretAccessKey, srcAWSSessionToken):
self._aws_access_key_id = srcAWSAccessKeyID
self._aws_secret_access_key = srcAWSSecretAccessKey
self._aws_session_token = srcAWSSessionToken
def _createAmazonDate(self):
# Returned as a unicode string in Py3.x
amazonDate = []
currentTime = datetime.utcnow()
YMDHMS = currentTime.strftime('%Y%m%dT%H%M%SZ')
YMD = YMDHMS[0:YMDHMS.index('T')]
amazonDate.append(YMD)
amazonDate.append(YMDHMS)
return amazonDate
def _sign(self, key, message):
# Returned as a utf-8 byte string in Py3.x
return hmac.new(key, message.encode('utf-8'), hashlib.sha256).digest()
def _getSignatureKey(self, key, dateStamp, regionName, serviceName):
# Returned as a utf-8 byte string in Py3.x
kDate = self._sign(('AWS4' + key).encode('utf-8'), dateStamp)
kRegion = self._sign(kDate, regionName)
kService = self._sign(kRegion, serviceName)
kSigning = self._sign(kService, 'aws4_request')
return kSigning
def _checkIAMCredentials(self):
# Check custom config
ret = self._checkKeyInCustomConfig()
# Check environment variables
if not ret:
ret = self._checkKeyInEnv()
# Check files
if not ret:
ret = self._checkKeyInFiles()
# All credentials returned as unicode strings in Py3.x
return ret
def _checkKeyInEnv(self):
ret = dict()
self._aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')
self._aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
self._aws_session_token = os.environ.get('AWS_SESSION_TOKEN')
if self._aws_access_key_id is not None and self._aws_secret_access_key is not None:
ret["aws_access_key_id"] = self._aws_access_key_id
ret["aws_secret_access_key"] = self._aws_secret_access_key
# We do not necessarily need session token...
if self._aws_session_token is not None:
ret["aws_session_token"] = self._aws_session_token
self._logger.debug("IAM credentials from env var.")
return ret
def _checkKeyInINIDefault(self, srcConfigParser, sectionName):
ret = dict()
# Check aws_access_key_id and aws_secret_access_key
try:
ret["aws_access_key_id"] = srcConfigParser.get(sectionName, "aws_access_key_id")
ret["aws_secret_access_key"] = srcConfigParser.get(sectionName, "aws_secret_access_key")
except NoOptionError:
self._logger.warn("Cannot find IAM keyID/secretKey in credential file.")
# We do not continue searching if we cannot even get IAM id/secret right
if len(ret) == 2:
# Check aws_session_token, optional
try:
ret["aws_session_token"] = srcConfigParser.get(sectionName, "aws_session_token")
except NoOptionError:
self._logger.debug("No AWS Session Token found.")
return ret
def _checkKeyInFiles(self):
credentialFile = None
credentialConfig = None
ret = dict()
# Should be compatible with aws cli default credential configuration
# *NIX/Windows
try:
# See if we get the file
credentialConfig = ConfigParser()
credentialFilePath = os.path.expanduser(self._credentialConfigFilePath) # Is it compatible with windows? \/
credentialConfig.read(credentialFilePath)
# Now we have the file, start looking for credentials...
# 'default' section
ret = self._checkKeyInINIDefault(credentialConfig, "default")
if not ret:
# 'DEFAULT' section
ret = self._checkKeyInINIDefault(credentialConfig, "DEFAULT")
self._logger.debug("IAM credentials from file.")
except IOError:
self._logger.debug("No IAM credential configuration file in " + credentialFilePath)
except NoSectionError:
self._logger.error("Cannot find IAM 'default' section.")
return ret
def _checkKeyInCustomConfig(self):
ret = dict()
if self._aws_access_key_id != "" and self._aws_secret_access_key != "":
ret["aws_access_key_id"] = self._aws_access_key_id
ret["aws_secret_access_key"] = self._aws_secret_access_key
# We do not necessarily need session token...
if self._aws_session_token != "":
ret["aws_session_token"] = self._aws_session_token
self._logger.debug("IAM credentials from custom config.")
return ret
def createWebsocketEndpoint(self, host, port, region, method, awsServiceName, path):
# Return the endpoint as unicode string in 3.x
# Gather all the facts
amazonDate = self._createAmazonDate()
amazonDateSimple = amazonDate[0] # Unicode in 3.x
amazonDateComplex = amazonDate[1] # Unicode in 3.x
allKeys = self._checkIAMCredentials() # Unicode in 3.x
if not self._hasCredentialsNecessaryForWebsocket(allKeys):
raise wssNoKeyInEnvironmentError()
else:
# Because of self._hasCredentialsNecessaryForWebsocket(...), keyID and secretKey should not be None from here
keyID = allKeys["aws_access_key_id"]
secretKey = allKeys["aws_secret_access_key"]
# amazonDateSimple and amazonDateComplex are guaranteed not to be None
queryParameters = "X-Amz-Algorithm=AWS4-HMAC-SHA256" + \
"&X-Amz-Credential=" + keyID + "%2F" + amazonDateSimple + "%2F" + region + "%2F" + awsServiceName + "%2Faws4_request" + \
"&X-Amz-Date=" + amazonDateComplex + \
"&X-Amz-Expires=86400" + \
"&X-Amz-SignedHeaders=host" # Unicode in 3.x
hashedPayload = hashlib.sha256(str("").encode('utf-8')).hexdigest() # Unicode in 3.x
# Create the string to sign
signedHeaders = "host"
canonicalHeaders = "host:" + host + "\n"
canonicalRequest = method + "\n" + path + "\n" + queryParameters + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedPayload # Unicode in 3.x
hashedCanonicalRequest = hashlib.sha256(str(canonicalRequest).encode('utf-8')).hexdigest() # Unicoede in 3.x
stringToSign = "AWS4-HMAC-SHA256\n" + amazonDateComplex + "\n" + amazonDateSimple + "/" + region + "/" + awsServiceName + "/aws4_request\n" + hashedCanonicalRequest # Unicode in 3.x
# Sign it
signingKey = self._getSignatureKey(secretKey, amazonDateSimple, region, awsServiceName)
signature = hmac.new(signingKey, (stringToSign).encode("utf-8"), hashlib.sha256).hexdigest()
# generate url
url = "wss://" + host + ":" + str(port) + path + '?' + queryParameters + "&X-Amz-Signature=" + signature
# See if we have STS token, if we do, add it
awsSessionTokenCandidate = allKeys.get("aws_session_token")
if awsSessionTokenCandidate is not None and len(awsSessionTokenCandidate) != 0:
aws_session_token = allKeys["aws_session_token"]
url += "&X-Amz-Security-Token=" + quote(aws_session_token.encode("utf-8")) # Unicode in 3.x
self._logger.debug("createWebsocketEndpoint: Websocket URL: " + url)
return url
def _hasCredentialsNecessaryForWebsocket(self, allKeys):
awsAccessKeyIdCandidate = allKeys.get("aws_access_key_id")
awsSecretAccessKeyCandidate = allKeys.get("aws_secret_access_key")
# None value is NOT considered as valid entries
validEntries = awsAccessKeyIdCandidate is not None and awsAccessKeyIdCandidate is not None
if validEntries:
# Empty value is NOT considered as valid entries
validEntries &= (len(awsAccessKeyIdCandidate) != 0 and len(awsSecretAccessKeyCandidate) != 0)
return validEntries
# This is an internal class that buffers the incoming bytes into an
# internal buffer until it gets the full desired length of bytes.
# At that time, this bufferedReader will be reset.
# *Error handling:
# For retry errors (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, EAGAIN),
# leave them to the paho _packet_read for further handling (ignored and try
# again when data is available.
# For other errors, leave them to the paho _packet_read for error reporting.
class _BufferedReader:
_sslSocket = None
_internalBuffer = None
_remainedLength = -1
_bufferingInProgress = False
def __init__(self, sslSocket):
self._sslSocket = sslSocket
self._internalBuffer = bytearray()
self._bufferingInProgress = False
def _reset(self):
self._internalBuffer = bytearray()
self._remainedLength = -1
self._bufferingInProgress = False
def read(self, numberOfBytesToBeBuffered):
if not self._bufferingInProgress: # If last read is completed...
self._remainedLength = numberOfBytesToBeBuffered
self._bufferingInProgress = True # Now we start buffering a new length of bytes
while self._remainedLength > 0: # Read in a loop, always try to read in the remained length
# If the data is temporarily not available, socket.error will be raised and catched by paho
dataChunk = self._sslSocket.read(self._remainedLength)
# There is a chance where the server terminates the connection without closing the socket.
# If that happens, let's raise an exception and enter the reconnect flow.
if not dataChunk:
raise socket.error(errno.ECONNABORTED, 0)
self._internalBuffer.extend(dataChunk) # Buffer the data
self._remainedLength -= len(dataChunk) # Update the remained length
# The requested length of bytes is buffered, recover the context and return it
# Otherwise error should be raised
ret = self._internalBuffer
self._reset()
return ret # This should always be bytearray
# This is the internal class that sends requested data out chunk by chunk according
# to the availablity of the socket write operation. If the requested bytes of data
# (after encoding) needs to be sent out in separate socket write operations (most
# probably be interrupted by the error socket.error (errno = ssl.SSL_ERROR_WANT_WRITE).)
# , the write pointer is stored to ensure that the continued bytes will be sent next
# time this function gets called.
# *Error handling:
# For retry errors (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, EAGAIN),
# leave them to the paho _packet_read for further handling (ignored and try
# again when data is available.
# For other errors, leave them to the paho _packet_read for error reporting.
class _BufferedWriter:
_sslSocket = None
_internalBuffer = None
_writingInProgress = False
_requestedDataLength = -1
def __init__(self, sslSocket):
self._sslSocket = sslSocket
self._internalBuffer = bytearray()
self._writingInProgress = False
self._requestedDataLength = -1
def _reset(self):
self._internalBuffer = bytearray()
self._writingInProgress = False
self._requestedDataLength = -1
# Input data for this function needs to be an encoded wss frame
# Always request for packet[pos=0:] (raw MQTT data)
def write(self, encodedData, payloadLength):
# encodedData should always be bytearray
# Check if we have a frame that is partially sent
if not self._writingInProgress:
self._internalBuffer = encodedData
self._writingInProgress = True
self._requestedDataLength = payloadLength
# Now, write as much as we can
lengthWritten = self._sslSocket.write(self._internalBuffer)
self._internalBuffer = self._internalBuffer[lengthWritten:]
# This MQTT packet has been sent out in a wss frame, completely
if len(self._internalBuffer) == 0:
ret = self._requestedDataLength
self._reset()
return ret
# This socket write is half-baked...
else:
return 0 # Ensure that the 'pos' inside the MQTT packet never moves since we have not finished the transmission of this encoded frame
class SecuredWebSocketCore:
# Websocket Constants
_OP_CONTINUATION = 0x0
_OP_TEXT = 0x1
_OP_BINARY = 0x2
_OP_CONNECTION_CLOSE = 0x8
_OP_PING = 0x9
_OP_PONG = 0xa
# Websocket Connect Status
_WebsocketConnectInit = -1
_WebsocketDisconnected = 1
_logger = logging.getLogger(__name__)
def __init__(self, socket, hostAddress, portNumber, AWSAccessKeyID="", AWSSecretAccessKey="", AWSSessionToken=""):
self._connectStatus = self._WebsocketConnectInit
# Handlers
self._sslSocket = socket
self._sigV4Handler = self._createSigV4Core()
self._sigV4Handler.setIAMCredentials(AWSAccessKeyID, AWSSecretAccessKey, AWSSessionToken)
# Endpoint Info
self._hostAddress = hostAddress
self._portNumber = portNumber
# Section Flags
self._hasOpByte = False
self._hasPayloadLengthFirst = False
self._hasPayloadLengthExtended = False
self._hasMaskKey = False
self._hasPayload = False
# Properties for current websocket frame
self._isFIN = False
self._RSVBits = None
self._opCode = None
self._needMaskKey = False
self._payloadLengthBytesLength = 1
self._payloadLength = 0
self._maskKey = None
self._payloadDataBuffer = bytearray() # Once the whole wss connection is lost, there is no need to keep the buffered payload
try:
self._handShake(hostAddress, portNumber)
except wssNoKeyInEnvironmentError: # Handle SigV4 signing and websocket handshaking errors
raise ValueError("No Access Key/KeyID Error")
except wssHandShakeError:
raise ValueError("Websocket Handshake Error")
except ClientError as e:
raise ValueError(e.message)
# Now we have a socket with secured websocket...
self._bufferedReader = _BufferedReader(self._sslSocket)
self._bufferedWriter = _BufferedWriter(self._sslSocket)
def _createSigV4Core(self):
return SigV4Core()
def _generateMaskKey(self):
return bytearray(os.urandom(4))
# os.urandom returns ascii str in 2.x, converted to bytearray
# os.urandom returns bytes in 3.x, converted to bytearray
def _reset(self): # Reset the context for wss frame reception
# Control info
self._hasOpByte = False
self._hasPayloadLengthFirst = False
self._hasPayloadLengthExtended = False
self._hasMaskKey = False
self._hasPayload = False
# Frame Info
self._isFIN = False
self._RSVBits = None
self._opCode = None
self._needMaskKey = False
self._payloadLengthBytesLength = 1
self._payloadLength = 0
self._maskKey = None
# Never reset the payloadData since we might have fragmented MQTT data from the pervious frame
def _generateWSSKey(self):
return base64.b64encode(os.urandom(128)) # Bytes
def _verifyWSSResponse(self, response, clientKey):
# Check if it is a 101 response
rawResponse = response.strip().lower()
if b"101 switching protocols" not in rawResponse or b"upgrade: websocket" not in rawResponse or b"connection: upgrade" not in rawResponse:
return False
# Parse out the sec-websocket-accept
WSSAcceptKeyIndex = response.strip().index(b"sec-websocket-accept: ") + len(b"sec-websocket-accept: ")
rawSecWebSocketAccept = response.strip()[WSSAcceptKeyIndex:].split(b"\r\n")[0].strip()
# Verify the WSSAcceptKey
return self._verifyWSSAcceptKey(rawSecWebSocketAccept, clientKey)
def _verifyWSSAcceptKey(self, srcAcceptKey, clientKey):
GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
verifyServerAcceptKey = base64.b64encode((hashlib.sha1(clientKey + GUID)).digest()) # Bytes
return srcAcceptKey == verifyServerAcceptKey
def _handShake(self, hostAddress, portNumber):
CRLF = "\r\n"
IOT_ENDPOINT_PATTERN = r"^[0-9a-zA-Z]+(\.ats|-ats)?\.iot\.(.*)\.amazonaws\..*"
matched = re.compile(IOT_ENDPOINT_PATTERN, re.IGNORECASE).match(hostAddress)
if not matched:
raise ClientError("Invalid endpoint pattern for wss: %s" % hostAddress)
region = matched.group(2)
signedURL = self._sigV4Handler.createWebsocketEndpoint(hostAddress, portNumber, region, "GET", "iotdata", "/mqtt")
# Now we got a signedURL
path = signedURL[signedURL.index("/mqtt"):]
# Assemble HTTP request headers
Method = "GET " + path + " HTTP/1.1" + CRLF
Host = "Host: " + hostAddress + CRLF
Connection = "Connection: " + "Upgrade" + CRLF
Upgrade = "Upgrade: " + "websocket" + CRLF
secWebSocketVersion = "Sec-WebSocket-Version: " + "13" + CRLF
rawSecWebSocketKey = self._generateWSSKey() # Bytes
secWebSocketKey = "sec-websocket-key: " + rawSecWebSocketKey.decode('utf-8') + CRLF # Should be randomly generated...
secWebSocketProtocol = "Sec-WebSocket-Protocol: " + "mqttv3.1" + CRLF
secWebSocketExtensions = "Sec-WebSocket-Extensions: " + "permessage-deflate; client_max_window_bits" + CRLF
# Send the HTTP request
# Ensure that we are sending bytes, not by any chance unicode string
handshakeBytes = Method + Host + Connection + Upgrade + secWebSocketVersion + secWebSocketProtocol + secWebSocketExtensions + secWebSocketKey + CRLF
handshakeBytes = handshakeBytes.encode('utf-8')
self._sslSocket.write(handshakeBytes)
# Read it back (Non-blocking socket)
timeStart = time.time()
wssHandshakeResponse = bytearray()
while len(wssHandshakeResponse) == 0:
try:
wssHandshakeResponse += self._sslSocket.read(1024) # Response is always less than 1024 bytes
except socket.error as err:
if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE:
if time.time() - timeStart > self._getTimeoutSec():
raise err # We make sure that reconnect gets retried in Paho upon a wss reconnect response timeout
else:
raise err
# Verify response
# Now both wssHandshakeResponse and rawSecWebSocketKey are byte strings
if not self._verifyWSSResponse(wssHandshakeResponse, rawSecWebSocketKey):
raise wssHandShakeError()
else:
pass
def _getTimeoutSec(self):
return DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
# Used to create a single wss frame
# Assume that the maximum length of a MQTT packet never exceeds the maximum length
# for a wss frame. Therefore, the FIN bit for the encoded frame will always be 1.
# Frames are encoded as BINARY frames.
def _encodeFrame(self, rawPayload, opCode, masked=1):
ret = bytearray()
# Op byte
opByte = 0x80 | opCode # Always a FIN, no RSV bits
ret.append(opByte)
# Payload Length bytes
maskBit = masked
payloadLength = len(rawPayload)
if payloadLength <= 125:
ret.append((maskBit << 7) | payloadLength)
elif payloadLength <= 0xffff: # 16-bit unsigned int
ret.append((maskBit << 7) | 126)
ret.extend(struct.pack("!H", payloadLength))
elif payloadLength <= 0x7fffffffffffffff: # 64-bit unsigned int (most significant bit must be 0)
ret.append((maskBit << 7) | 127)
ret.extend(struct.pack("!Q", payloadLength))
else: # Overflow
raise ValueError("Exceeds the maximum number of bytes for a single websocket frame.")
if maskBit == 1:
# Mask key bytes
maskKey = self._generateMaskKey()
ret.extend(maskKey)
# Mask the payload
payloadBytes = bytearray(rawPayload)
if maskBit == 1:
for i in range(0, payloadLength):
payloadBytes[i] ^= maskKey[i % 4]
ret.extend(payloadBytes)
# Return the assembled wss frame
return ret
# Used for the wss client to close a wss connection
# Create and send a masked wss closing frame
def _closeWssConnection(self):
# Frames sent from client to server must be masked
self._sslSocket.write(self._encodeFrame(b"", self._OP_CONNECTION_CLOSE, masked=1))
# Used for the wss client to respond to a wss PING from server
# Create and send a masked PONG frame
def _sendPONG(self):
# Frames sent from client to server must be masked
self._sslSocket.write(self._encodeFrame(b"", self._OP_PONG, masked=1))
# Override sslSocket read. Always read from the wss internal payload buffer, which
# contains the masked MQTT packet. This read will decode ONE wss frame every time
# and load in the payload for MQTT _packet_read. At any time, MQTT _packet_read
# should be able to read a complete MQTT packet from the payload (buffered per wss
# frame payload). If the MQTT packet is break into separate wss frames, different
# chunks will be buffered in separate frames and MQTT _packet_read will not be able
# to collect a complete MQTT packet to operate on until the necessary payload is
# fully buffered.
# If the requested number of bytes are not available, SSL_ERROR_WANT_READ will be
# raised to trigger another call of _packet_read when the data is available again.
def read(self, numberOfBytes):
# Check if we have enough data for paho
# _payloadDataBuffer will not be empty ony when the payload of a new wss frame
# has been unmasked.
if len(self._payloadDataBuffer) >= numberOfBytes:
ret = self._payloadDataBuffer[0:numberOfBytes]
self._payloadDataBuffer = self._payloadDataBuffer[numberOfBytes:]
# struct.unpack(fmt, string) # Py2.x
# struct.unpack(fmt, buffer) # Py3.x
# Here ret is always in bytes (buffer interface)
if sys.version_info[0] < 3: # Py2.x
ret = str(ret)
return ret
# Emmm, We don't. Try to buffer from the socket (It's a new wss frame).
if not self._hasOpByte: # Check if we need to buffer OpByte
opByte = self._bufferedReader.read(1)
self._isFIN = (opByte[0] & 0x80) == 0x80
self._RSVBits = (opByte[0] & 0x70)
self._opCode = (opByte[0] & 0x0f)
self._hasOpByte = True # Finished buffering opByte
# Check if any of the RSV bits are set, if so, close the connection
# since client never sends negotiated extensions
if self._RSVBits != 0x0:
self._closeWssConnection()
self._connectStatus = self._WebsocketDisconnected
self._payloadDataBuffer = bytearray()
raise socket.error(ssl.SSL_ERROR_WANT_READ, "RSV bits set with NO negotiated extensions.")
if not self._hasPayloadLengthFirst: # Check if we need to buffer First Payload Length byte
payloadLengthFirst = self._bufferedReader.read(1)
self._hasPayloadLengthFirst = True # Finished buffering first byte of payload length
self._needMaskKey = (payloadLengthFirst[0] & 0x80) == 0x80
payloadLengthFirstByteArray = bytearray()
payloadLengthFirstByteArray.extend(payloadLengthFirst)
self._payloadLength = (payloadLengthFirstByteArray[0] & 0x7f)
if self._payloadLength == 126:
self._payloadLengthBytesLength = 2
self._hasPayloadLengthExtended = False # Force to buffer the extended
elif self._payloadLength == 127:
self._payloadLengthBytesLength = 8
self._hasPayloadLengthExtended = False # Force to buffer the extended
else: # _payloadLength <= 125:
self._hasPayloadLengthExtended = True # No need to buffer extended payload length
if not self._hasPayloadLengthExtended: # Check if we need to buffer Extended Payload Length bytes
payloadLengthExtended = self._bufferedReader.read(self._payloadLengthBytesLength)
self._hasPayloadLengthExtended = True
if sys.version_info[0] < 3:
payloadLengthExtended = str(payloadLengthExtended)
if self._payloadLengthBytesLength == 2:
self._payloadLength = struct.unpack("!H", payloadLengthExtended)[0]
else: # _payloadLengthBytesLength == 8
self._payloadLength = struct.unpack("!Q", payloadLengthExtended)[0]
if self._needMaskKey: # Response from server is masked, close the connection
self._closeWssConnection()
self._connectStatus = self._WebsocketDisconnected
self._payloadDataBuffer = bytearray()
raise socket.error(ssl.SSL_ERROR_WANT_READ, "Server response masked, closing connection and try again.")
if not self._hasPayload: # Check if we need to buffer the payload
payloadForThisFrame = self._bufferedReader.read(self._payloadLength)
self._hasPayload = True
# Client side should never received a masked packet from the server side
# Unmask it as needed
#if self._needMaskKey:
# for i in range(0, self._payloadLength):
# payloadForThisFrame[i] ^= self._maskKey[i % 4]
# Append it to the internal payload buffer
self._payloadDataBuffer.extend(payloadForThisFrame)
# Now we have the complete wss frame, reset the context
# Check to see if it is a wss closing frame
if self._opCode == self._OP_CONNECTION_CLOSE:
self._connectStatus = self._WebsocketDisconnected
self._payloadDataBuffer = bytearray() # Ensure that once the wss closing frame comes, we have nothing to read and start all over again
# Check to see if it is a wss PING frame
if self._opCode == self._OP_PING:
self._sendPONG() # Nothing more to do here, if the transmission of the last wssMQTT packet is not finished, it will continue
self._reset()
# Check again if we have enough data for paho
if len(self._payloadDataBuffer) >= numberOfBytes:
ret = self._payloadDataBuffer[0:numberOfBytes]
self._payloadDataBuffer = self._payloadDataBuffer[numberOfBytes:]
# struct.unpack(fmt, string) # Py2.x
# struct.unpack(fmt, buffer) # Py3.x
# Here ret is always in bytes (buffer interface)
if sys.version_info[0] < 3: # Py2.x
ret = str(ret)
return ret
else: # Fragmented MQTT packets in separate wss frames
raise socket.error(ssl.SSL_ERROR_WANT_READ, "Not a complete MQTT packet payload within this wss frame.")
def write(self, bytesToBeSent):
# When there is a disconnection, select will report a TypeError which triggers the reconnect.
# In reconnect, Paho will set the socket object (mocked by wss) to None, blocking other ops
# before a connection is re-established.
# This 'low-level' socket write op should always be able to write to plain socket.
# Error reporting is performed by Python socket itself.
# Wss closing frame handling is performed in the wss read.
return self._bufferedWriter.write(self._encodeFrame(bytesToBeSent, self._OP_BINARY, 1), len(bytesToBeSent))
def close(self):
if self._sslSocket is not None:
self._sslSocket.close()
self._sslSocket = None
def getpeercert(self):
return self._sslSocket.getpeercert()
def getSSLSocket(self):
if self._connectStatus != self._WebsocketDisconnected:
return self._sslSocket
else:
return None # Leave the sslSocket to Paho to close it. (_ssl.close() -> wssCore.close())

View File

@@ -0,0 +1,244 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import ssl
import logging
from threading import Lock
from numbers import Number
import AWSIoTPythonSDK.core.protocol.paho.client as mqtt
from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
class ClientStatus(object):
IDLE = 0
CONNECT = 1
RESUBSCRIBE = 2
DRAINING = 3
STABLE = 4
USER_DISCONNECT = 5
ABNORMAL_DISCONNECT = 6
class ClientStatusContainer(object):
def __init__(self):
self._status = ClientStatus.IDLE
def get_status(self):
return self._status
def set_status(self, status):
if ClientStatus.USER_DISCONNECT == self._status: # If user requests to disconnect, no status updates other than user connect
if ClientStatus.CONNECT == status:
self._status = status
else:
self._status = status
class InternalAsyncMqttClient(object):
_logger = logging.getLogger(__name__)
def __init__(self, client_id, clean_session, protocol, use_wss):
self._paho_client = self._create_paho_client(client_id, clean_session, None, protocol, use_wss)
self._use_wss = use_wss
self._event_callback_map_lock = Lock()
self._event_callback_map = dict()
def _create_paho_client(self, client_id, clean_session, user_data, protocol, use_wss):
self._logger.debug("Initializing MQTT layer...")
return mqtt.Client(client_id, clean_session, user_data, protocol, use_wss)
# TODO: Merge credentials providers configuration into one
def set_cert_credentials_provider(self, cert_credentials_provider):
# History issue from Yun SDK where AR9331 embedded Linux only have Python 2.7.3
# pre-installed. In this version, TLSv1_2 is not even an option.
# SSLv23 is a work-around which selects the highest TLS version between the client
# and service. If user installs opensslv1.0.1+, this option will work fine for Mutual
# Auth.
# Note that we cannot force TLSv1.2 for Mutual Auth. in Python 2.7.3 and TLS support
# in Python only starts from Python2.7.
# See also: https://docs.python.org/2/library/ssl.html#ssl.PROTOCOL_SSLv23
if self._use_wss:
ca_path = cert_credentials_provider.get_ca_path()
self._paho_client.tls_set(ca_certs=ca_path, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23)
else:
ca_path = cert_credentials_provider.get_ca_path()
cert_path = cert_credentials_provider.get_cert_path()
key_path = cert_credentials_provider.get_key_path()
self._paho_client.tls_set(ca_certs=ca_path,certfile=cert_path, keyfile=key_path,
cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23)
def set_iam_credentials_provider(self, iam_credentials_provider):
self._paho_client.configIAMCredentials(iam_credentials_provider.get_access_key_id(),
iam_credentials_provider.get_secret_access_key(),
iam_credentials_provider.get_session_token())
def set_endpoint_provider(self, endpoint_provider):
self._endpoint_provider = endpoint_provider
def configure_last_will(self, topic, payload, qos, retain=False):
self._paho_client.will_set(topic, payload, qos, retain)
def configure_alpn_protocols(self, alpn_protocols):
self._paho_client.config_alpn_protocols(alpn_protocols)
def clear_last_will(self):
self._paho_client.will_clear()
def set_username_password(self, username, password=None):
self._paho_client.username_pw_set(username, password)
def set_socket_factory(self, socket_factory):
self._paho_client.socket_factory_set(socket_factory)
def configure_reconnect_back_off(self, base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec):
self._paho_client.setBackoffTiming(base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec)
def connect(self, keep_alive_sec, ack_callback=None):
host = self._endpoint_provider.get_host()
port = self._endpoint_provider.get_port()
with self._event_callback_map_lock:
self._logger.debug("Filling in fixed event callbacks: CONNACK, DISCONNECT, MESSAGE")
self._event_callback_map[FixedEventMids.CONNACK_MID] = self._create_combined_on_connect_callback(ack_callback)
self._event_callback_map[FixedEventMids.DISCONNECT_MID] = self._create_combined_on_disconnect_callback(None)
self._event_callback_map[FixedEventMids.MESSAGE_MID] = self._create_converted_on_message_callback()
rc = self._paho_client.connect(host, port, keep_alive_sec)
if MQTT_ERR_SUCCESS == rc:
self.start_background_network_io()
return rc
def start_background_network_io(self):
self._logger.debug("Starting network I/O thread...")
self._paho_client.loop_start()
def stop_background_network_io(self):
self._logger.debug("Stopping network I/O thread...")
self._paho_client.loop_stop()
def disconnect(self, ack_callback=None):
with self._event_callback_map_lock:
rc = self._paho_client.disconnect()
if MQTT_ERR_SUCCESS == rc:
self._logger.debug("Filling in custom disconnect event callback...")
combined_on_disconnect_callback = self._create_combined_on_disconnect_callback(ack_callback)
self._event_callback_map[FixedEventMids.DISCONNECT_MID] = combined_on_disconnect_callback
return rc
def _create_combined_on_connect_callback(self, ack_callback):
def combined_on_connect_callback(mid, data):
self.on_online()
if ack_callback:
ack_callback(mid, data)
return combined_on_connect_callback
def _create_combined_on_disconnect_callback(self, ack_callback):
def combined_on_disconnect_callback(mid, data):
self.on_offline()
if ack_callback:
ack_callback(mid, data)
return combined_on_disconnect_callback
def _create_converted_on_message_callback(self):
def converted_on_message_callback(mid, data):
self.on_message(data)
return converted_on_message_callback
# For client online notification
def on_online(self):
pass
# For client offline notification
def on_offline(self):
pass
# For client message reception notification
def on_message(self, message):
pass
def publish(self, topic, payload, qos, retain=False, ack_callback=None):
with self._event_callback_map_lock:
rc, mid = self._paho_client.publish(topic, payload, qos, retain)
if MQTT_ERR_SUCCESS == rc and qos > 0 and ack_callback:
self._logger.debug("Filling in custom puback (QoS>0) event callback...")
self._event_callback_map[mid] = ack_callback
return rc, mid
def subscribe(self, topic, qos, ack_callback=None):
with self._event_callback_map_lock:
rc, mid = self._paho_client.subscribe(topic, qos)
if MQTT_ERR_SUCCESS == rc and ack_callback:
self._logger.debug("Filling in custom suback event callback...")
self._event_callback_map[mid] = ack_callback
return rc, mid
def unsubscribe(self, topic, ack_callback=None):
with self._event_callback_map_lock:
rc, mid = self._paho_client.unsubscribe(topic)
if MQTT_ERR_SUCCESS == rc and ack_callback:
self._logger.debug("Filling in custom unsuback event callback...")
self._event_callback_map[mid] = ack_callback
return rc, mid
def register_internal_event_callbacks(self, on_connect, on_disconnect, on_publish, on_subscribe, on_unsubscribe, on_message):
self._logger.debug("Registering internal event callbacks to MQTT layer...")
self._paho_client.on_connect = on_connect
self._paho_client.on_disconnect = on_disconnect
self._paho_client.on_publish = on_publish
self._paho_client.on_subscribe = on_subscribe
self._paho_client.on_unsubscribe = on_unsubscribe
self._paho_client.on_message = on_message
def unregister_internal_event_callbacks(self):
self._logger.debug("Unregistering internal event callbacks from MQTT layer...")
self._paho_client.on_connect = None
self._paho_client.on_disconnect = None
self._paho_client.on_publish = None
self._paho_client.on_subscribe = None
self._paho_client.on_unsubscribe = None
self._paho_client.on_message = None
def invoke_event_callback(self, mid, data=None):
with self._event_callback_map_lock:
event_callback = self._event_callback_map.get(mid)
# For invoking the event callback, we do not need to acquire the lock
if event_callback:
self._logger.debug("Invoking custom event callback...")
if data is not None:
event_callback(mid=mid, data=data)
else:
event_callback(mid=mid)
if isinstance(mid, Number): # Do NOT remove callbacks for CONNACK/DISCONNECT/MESSAGE
self._logger.debug("This custom event callback is for pub/sub/unsub, removing it after invocation...")
with self._event_callback_map_lock:
del self._event_callback_map[mid]
def remove_event_callback(self, mid):
with self._event_callback_map_lock:
if mid in self._event_callback_map:
self._logger.debug("Removing custom event callback...")
del self._event_callback_map[mid]
def clean_up_event_callbacks(self):
with self._event_callback_map_lock:
self._event_callback_map.clear()
def get_event_callback_map(self):
return self._event_callback_map

View File

@@ -0,0 +1,20 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC = 30
DEFAULT_OPERATION_TIMEOUT_SEC = 5
DEFAULT_DRAINING_INTERNAL_SEC = 0.5
METRICS_PREFIX = "?SDK=Python&Version="
ALPN_PROTCOLS = "x-amzn-mqtt-ca"

View File

@@ -0,0 +1,29 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
class EventTypes(object):
CONNACK = 0
DISCONNECT = 1
PUBACK = 2
SUBACK = 3
UNSUBACK = 4
MESSAGE = 5
class FixedEventMids(object):
CONNACK_MID = "CONNECTED"
DISCONNECT_MID = "DISCONNECTED"
MESSAGE_MID = "MESSAGE"
QUEUED_MID = "QUEUED"

View File

@@ -0,0 +1,87 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import logging
from AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes
class AppendResults(object):
APPEND_FAILURE_QUEUE_FULL = -1
APPEND_FAILURE_QUEUE_DISABLED = -2
APPEND_SUCCESS = 0
class OfflineRequestQueue(list):
_logger = logging.getLogger(__name__)
def __init__(self, max_size, drop_behavior=DropBehaviorTypes.DROP_NEWEST):
if not isinstance(max_size, int) or not isinstance(drop_behavior, int):
self._logger.error("init: MaximumSize/DropBehavior must be integer.")
raise TypeError("MaximumSize/DropBehavior must be integer.")
if drop_behavior != DropBehaviorTypes.DROP_OLDEST and drop_behavior != DropBehaviorTypes.DROP_NEWEST:
self._logger.error("init: Drop behavior not supported.")
raise ValueError("Drop behavior not supported.")
list.__init__([])
self._drop_behavior = drop_behavior
# When self._maximumSize > 0, queue is limited
# When self._maximumSize == 0, queue is disabled
# When self._maximumSize < 0. queue is infinite
self._max_size = max_size
def _is_enabled(self):
return self._max_size != 0
def _need_drop_messages(self):
# Need to drop messages when:
# 1. Queue is limited and full
# 2. Queue is disabled
is_queue_full = len(self) >= self._max_size
is_queue_limited = self._max_size > 0
is_queue_disabled = not self._is_enabled()
return (is_queue_full and is_queue_limited) or is_queue_disabled
def set_behavior_drop_newest(self):
self._drop_behavior = DropBehaviorTypes.DROP_NEWEST
def set_behavior_drop_oldest(self):
self._drop_behavior = DropBehaviorTypes.DROP_OLDEST
# Override
# Append to a queue with a limited size.
# Return APPEND_SUCCESS if the append is successful
# Return APPEND_FAILURE_QUEUE_FULL if the append failed because the queue is full
# Return APPEND_FAILURE_QUEUE_DISABLED if the append failed because the queue is disabled
def append(self, data):
ret = AppendResults.APPEND_SUCCESS
if self._is_enabled():
if self._need_drop_messages():
# We should drop the newest
if DropBehaviorTypes.DROP_NEWEST == self._drop_behavior:
self._logger.warn("append: Full queue. Drop the newest: " + str(data))
ret = AppendResults.APPEND_FAILURE_QUEUE_FULL
# We should drop the oldest
else:
current_oldest = super(OfflineRequestQueue, self).pop(0)
self._logger.warn("append: Full queue. Drop the oldest: " + str(current_oldest))
super(OfflineRequestQueue, self).append(data)
ret = AppendResults.APPEND_FAILURE_QUEUE_FULL
else:
self._logger.debug("append: Add new element: " + str(data))
super(OfflineRequestQueue, self).append(data)
else:
self._logger.debug("append: Queue is disabled. Drop the message: " + str(data))
ret = AppendResults.APPEND_FAILURE_QUEUE_DISABLED
return ret

View File

@@ -0,0 +1,27 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
class RequestTypes(object):
CONNECT = 0
DISCONNECT = 1
PUBLISH = 2
SUBSCRIBE = 3
UNSUBSCRIBE = 4
class QueueableRequest(object):
def __init__(self, type, data):
self.type = type
self.data = data # Can be a tuple

View File

@@ -0,0 +1,296 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import time
import logging
from threading import Thread
from threading import Event
from AWSIoTPythonSDK.core.protocol.internal.events import EventTypes
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus
from AWSIoTPythonSDK.core.protocol.internal.queues import OfflineRequestQueue
from AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes
from AWSIoTPythonSDK.core.protocol.paho.client import topic_matches_sub
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_DRAINING_INTERNAL_SEC
class EventProducer(object):
_logger = logging.getLogger(__name__)
def __init__(self, cv, event_queue):
self._cv = cv
self._event_queue = event_queue
def on_connect(self, client, user_data, flags, rc):
self._add_to_queue(FixedEventMids.CONNACK_MID, EventTypes.CONNACK, rc)
self._logger.debug("Produced [connack] event")
def on_disconnect(self, client, user_data, rc):
self._add_to_queue(FixedEventMids.DISCONNECT_MID, EventTypes.DISCONNECT, rc)
self._logger.debug("Produced [disconnect] event")
def on_publish(self, client, user_data, mid):
self._add_to_queue(mid, EventTypes.PUBACK, None)
self._logger.debug("Produced [puback] event")
def on_subscribe(self, client, user_data, mid, granted_qos):
self._add_to_queue(mid, EventTypes.SUBACK, granted_qos)
self._logger.debug("Produced [suback] event")
def on_unsubscribe(self, client, user_data, mid):
self._add_to_queue(mid, EventTypes.UNSUBACK, None)
self._logger.debug("Produced [unsuback] event")
def on_message(self, client, user_data, message):
self._add_to_queue(FixedEventMids.MESSAGE_MID, EventTypes.MESSAGE, message)
self._logger.debug("Produced [message] event")
def _add_to_queue(self, mid, event_type, data):
with self._cv:
self._event_queue.put((mid, event_type, data))
self._cv.notify()
class EventConsumer(object):
MAX_DISPATCH_INTERNAL_SEC = 0.01
_logger = logging.getLogger(__name__)
def __init__(self, cv, event_queue, internal_async_client,
subscription_manager, offline_requests_manager, client_status):
self._cv = cv
self._event_queue = event_queue
self._internal_async_client = internal_async_client
self._subscription_manager = subscription_manager
self._offline_requests_manager = offline_requests_manager
self._client_status = client_status
self._is_running = False
self._draining_interval_sec = DEFAULT_DRAINING_INTERNAL_SEC
self._dispatch_methods = {
EventTypes.CONNACK : self._dispatch_connack,
EventTypes.DISCONNECT : self._dispatch_disconnect,
EventTypes.PUBACK : self._dispatch_puback,
EventTypes.SUBACK : self._dispatch_suback,
EventTypes.UNSUBACK : self._dispatch_unsuback,
EventTypes.MESSAGE : self._dispatch_message
}
self._offline_request_handlers = {
RequestTypes.PUBLISH : self._handle_offline_publish,
RequestTypes.SUBSCRIBE : self._handle_offline_subscribe,
RequestTypes.UNSUBSCRIBE : self._handle_offline_unsubscribe
}
self._stopper = Event()
def update_offline_requests_manager(self, offline_requests_manager):
self._offline_requests_manager = offline_requests_manager
def update_draining_interval_sec(self, draining_interval_sec):
self._draining_interval_sec = draining_interval_sec
def get_draining_interval_sec(self):
return self._draining_interval_sec
def is_running(self):
return self._is_running
def start(self):
self._stopper.clear()
self._is_running = True
dispatch_events = Thread(target=self._dispatch)
dispatch_events.daemon = True
dispatch_events.start()
self._logger.debug("Event consuming thread started")
def stop(self):
if self._is_running:
self._is_running = False
self._clean_up()
self._logger.debug("Event consuming thread stopped")
def _clean_up(self):
self._logger.debug("Cleaning up before stopping event consuming")
with self._event_queue.mutex:
self._event_queue.queue.clear()
self._logger.debug("Event queue cleared")
self._internal_async_client.stop_background_network_io()
self._logger.debug("Network thread stopped")
self._internal_async_client.clean_up_event_callbacks()
self._logger.debug("Event callbacks cleared")
def wait_until_it_stops(self, timeout_sec):
self._logger.debug("Waiting for event consumer to completely stop")
return self._stopper.wait(timeout=timeout_sec)
def is_fully_stopped(self):
return self._stopper.is_set()
def _dispatch(self):
while self._is_running:
with self._cv:
if self._event_queue.empty():
self._cv.wait(self.MAX_DISPATCH_INTERNAL_SEC)
else:
while not self._event_queue.empty():
self._dispatch_one()
self._stopper.set()
self._logger.debug("Exiting dispatching loop...")
def _dispatch_one(self):
mid, event_type, data = self._event_queue.get()
if mid:
self._dispatch_methods[event_type](mid, data)
self._internal_async_client.invoke_event_callback(mid, data=data)
# We need to make sure disconnect event gets dispatched and then we stop the consumer
if self._need_to_stop_dispatching(mid):
self.stop()
def _need_to_stop_dispatching(self, mid):
status = self._client_status.get_status()
return (ClientStatus.USER_DISCONNECT == status or ClientStatus.CONNECT == status) \
and mid == FixedEventMids.DISCONNECT_MID
def _dispatch_connack(self, mid, rc):
status = self._client_status.get_status()
self._logger.debug("Dispatching [connack] event")
if self._need_recover():
if ClientStatus.STABLE != status: # To avoid multiple connack dispatching
self._logger.debug("Has recovery job")
clean_up_debt = Thread(target=self._clean_up_debt)
clean_up_debt.start()
else:
self._logger.debug("No need for recovery")
self._client_status.set_status(ClientStatus.STABLE)
def _need_recover(self):
return self._subscription_manager.list_records() or self._offline_requests_manager.has_more()
def _clean_up_debt(self):
self._handle_resubscribe()
self._handle_draining()
self._client_status.set_status(ClientStatus.STABLE)
def _handle_resubscribe(self):
subscriptions = self._subscription_manager.list_records()
if subscriptions and not self._has_user_disconnect_request():
self._logger.debug("Start resubscribing")
self._client_status.set_status(ClientStatus.RESUBSCRIBE)
for topic, (qos, message_callback, ack_callback) in subscriptions:
if self._has_user_disconnect_request():
self._logger.debug("User disconnect detected")
break
self._internal_async_client.subscribe(topic, qos, ack_callback)
def _handle_draining(self):
if self._offline_requests_manager.has_more() and not self._has_user_disconnect_request():
self._logger.debug("Start draining")
self._client_status.set_status(ClientStatus.DRAINING)
while self._offline_requests_manager.has_more():
if self._has_user_disconnect_request():
self._logger.debug("User disconnect detected")
break
offline_request = self._offline_requests_manager.get_next()
if offline_request:
self._offline_request_handlers[offline_request.type](offline_request)
time.sleep(self._draining_interval_sec)
def _has_user_disconnect_request(self):
return ClientStatus.USER_DISCONNECT == self._client_status.get_status()
def _dispatch_disconnect(self, mid, rc):
self._logger.debug("Dispatching [disconnect] event")
status = self._client_status.get_status()
if ClientStatus.USER_DISCONNECT == status or ClientStatus.CONNECT == status:
pass
else:
self._client_status.set_status(ClientStatus.ABNORMAL_DISCONNECT)
# For puback, suback and unsuback, ack callback invocation is handled in dispatch_one
# Do nothing in the event dispatching itself
def _dispatch_puback(self, mid, rc):
self._logger.debug("Dispatching [puback] event")
def _dispatch_suback(self, mid, rc):
self._logger.debug("Dispatching [suback] event")
def _dispatch_unsuback(self, mid, rc):
self._logger.debug("Dispatching [unsuback] event")
def _dispatch_message(self, mid, message):
self._logger.debug("Dispatching [message] event")
subscriptions = self._subscription_manager.list_records()
if subscriptions:
for topic, (qos, message_callback, _) in subscriptions:
if topic_matches_sub(topic, message.topic) and message_callback:
message_callback(None, None, message) # message_callback(client, userdata, message)
def _handle_offline_publish(self, request):
topic, payload, qos, retain = request.data
self._internal_async_client.publish(topic, payload, qos, retain)
self._logger.debug("Processed offline publish request")
def _handle_offline_subscribe(self, request):
topic, qos, message_callback, ack_callback = request.data
self._subscription_manager.add_record(topic, qos, message_callback, ack_callback)
self._internal_async_client.subscribe(topic, qos, ack_callback)
self._logger.debug("Processed offline subscribe request")
def _handle_offline_unsubscribe(self, request):
topic, ack_callback = request.data
self._subscription_manager.remove_record(topic)
self._internal_async_client.unsubscribe(topic, ack_callback)
self._logger.debug("Processed offline unsubscribe request")
class SubscriptionManager(object):
_logger = logging.getLogger(__name__)
def __init__(self):
self._subscription_map = dict()
def add_record(self, topic, qos, message_callback, ack_callback):
self._logger.debug("Adding a new subscription record: %s qos: %d", topic, qos)
self._subscription_map[topic] = qos, message_callback, ack_callback # message_callback and/or ack_callback could be None
def remove_record(self, topic):
self._logger.debug("Removing subscription record: %s", topic)
if self._subscription_map.get(topic): # Ignore topics that are never subscribed to
del self._subscription_map[topic]
else:
self._logger.warn("Removing attempt for non-exist subscription record: %s", topic)
def list_records(self):
return list(self._subscription_map.items())
class OfflineRequestsManager(object):
_logger = logging.getLogger(__name__)
def __init__(self, max_size, drop_behavior):
self._queue = OfflineRequestQueue(max_size, drop_behavior)
def has_more(self):
return len(self._queue) > 0
def add_one(self, request):
return self._queue.append(request)
def get_next(self):
if self.has_more():
return self._queue.pop(0)
else:
return None

View File

@@ -0,0 +1,373 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import AWSIoTPythonSDK
from AWSIoTPythonSDK.core.protocol.internal.clients import InternalAsyncMqttClient
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatusContainer
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus
from AWSIoTPythonSDK.core.protocol.internal.workers import EventProducer
from AWSIoTPythonSDK.core.protocol.internal.workers import EventConsumer
from AWSIoTPythonSDK.core.protocol.internal.workers import SubscriptionManager
from AWSIoTPythonSDK.core.protocol.internal.workers import OfflineRequestsManager
from AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes
from AWSIoTPythonSDK.core.protocol.internal.requests import QueueableRequest
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_OPERATION_TIMEOUT_SEC
from AWSIoTPythonSDK.core.protocol.internal.defaults import METRICS_PREFIX
from AWSIoTPythonSDK.core.protocol.internal.defaults import ALPN_PROTCOLS
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS
from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueFullException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueDisabledException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueFullException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueDisabledException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueFullException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueDisabledException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeError
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeTimeoutException
from AWSIoTPythonSDK.core.protocol.internal.queues import AppendResults
from AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes
from AWSIoTPythonSDK.core.protocol.paho.client import MQTTv31
from threading import Condition
from threading import Event
import logging
import sys
if sys.version_info[0] < 3:
from Queue import Queue
else:
from queue import Queue
class MqttCore(object):
_logger = logging.getLogger(__name__)
def __init__(self, client_id, clean_session, protocol, use_wss):
self._use_wss = use_wss
self._username = ""
self._password = None
self._enable_metrics_collection = True
self._event_queue = Queue()
self._event_cv = Condition()
self._event_producer = EventProducer(self._event_cv, self._event_queue)
self._client_status = ClientStatusContainer()
self._internal_async_client = InternalAsyncMqttClient(client_id, clean_session, protocol, use_wss)
self._subscription_manager = SubscriptionManager()
self._offline_requests_manager = OfflineRequestsManager(-1, DropBehaviorTypes.DROP_NEWEST) # Infinite queue
self._event_consumer = EventConsumer(self._event_cv,
self._event_queue,
self._internal_async_client,
self._subscription_manager,
self._offline_requests_manager,
self._client_status)
self._connect_disconnect_timeout_sec = DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
self._operation_timeout_sec = DEFAULT_OPERATION_TIMEOUT_SEC
self._init_offline_request_exceptions()
self._init_workers()
self._logger.info("MqttCore initialized")
self._logger.info("Client id: %s" % client_id)
self._logger.info("Protocol version: %s" % ("MQTTv3.1" if protocol == MQTTv31 else "MQTTv3.1.1"))
self._logger.info("Authentication type: %s" % ("SigV4 WebSocket" if use_wss else "TLSv1.2 certificate based Mutual Auth."))
def _init_offline_request_exceptions(self):
self._offline_request_queue_disabled_exceptions = {
RequestTypes.PUBLISH : publishQueueDisabledException(),
RequestTypes.SUBSCRIBE : subscribeQueueDisabledException(),
RequestTypes.UNSUBSCRIBE : unsubscribeQueueDisabledException()
}
self._offline_request_queue_full_exceptions = {
RequestTypes.PUBLISH : publishQueueFullException(),
RequestTypes.SUBSCRIBE : subscribeQueueFullException(),
RequestTypes.UNSUBSCRIBE : unsubscribeQueueFullException()
}
def _init_workers(self):
self._internal_async_client.register_internal_event_callbacks(self._event_producer.on_connect,
self._event_producer.on_disconnect,
self._event_producer.on_publish,
self._event_producer.on_subscribe,
self._event_producer.on_unsubscribe,
self._event_producer.on_message)
def _start_workers(self):
self._event_consumer.start()
def use_wss(self):
return self._use_wss
# Used for general message event reception
def on_message(self, message):
pass
# Used for general online event notification
def on_online(self):
pass
# Used for general offline event notification
def on_offline(self):
pass
def configure_cert_credentials(self, cert_credentials_provider):
self._logger.info("Configuring certificates...")
self._internal_async_client.set_cert_credentials_provider(cert_credentials_provider)
def configure_iam_credentials(self, iam_credentials_provider):
self._logger.info("Configuring custom IAM credentials...")
self._internal_async_client.set_iam_credentials_provider(iam_credentials_provider)
def configure_endpoint(self, endpoint_provider):
self._logger.info("Configuring endpoint...")
self._internal_async_client.set_endpoint_provider(endpoint_provider)
def configure_connect_disconnect_timeout_sec(self, connect_disconnect_timeout_sec):
self._logger.info("Configuring connect/disconnect time out: %f sec" % connect_disconnect_timeout_sec)
self._connect_disconnect_timeout_sec = connect_disconnect_timeout_sec
def configure_operation_timeout_sec(self, operation_timeout_sec):
self._logger.info("Configuring MQTT operation time out: %f sec" % operation_timeout_sec)
self._operation_timeout_sec = operation_timeout_sec
def configure_reconnect_back_off(self, base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec):
self._logger.info("Configuring reconnect back off timing...")
self._logger.info("Base quiet time: %f sec" % base_reconnect_quiet_sec)
self._logger.info("Max quiet time: %f sec" % max_reconnect_quiet_sec)
self._logger.info("Stable connection time: %f sec" % stable_connection_sec)
self._internal_async_client.configure_reconnect_back_off(base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec)
def configure_alpn_protocols(self):
self._logger.info("Configuring alpn protocols...")
self._internal_async_client.configure_alpn_protocols([ALPN_PROTCOLS])
def configure_last_will(self, topic, payload, qos, retain=False):
self._logger.info("Configuring last will...")
self._internal_async_client.configure_last_will(topic, payload, qos, retain)
def clear_last_will(self):
self._logger.info("Clearing last will...")
self._internal_async_client.clear_last_will()
def configure_username_password(self, username, password=None):
self._logger.info("Configuring username and password...")
self._username = username
self._password = password
def configure_socket_factory(self, socket_factory):
self._logger.info("Configuring socket factory...")
self._internal_async_client.set_socket_factory(socket_factory)
def enable_metrics_collection(self):
self._enable_metrics_collection = True
def disable_metrics_collection(self):
self._enable_metrics_collection = False
def configure_offline_requests_queue(self, max_size, drop_behavior):
self._logger.info("Configuring offline requests queueing: max queue size: %d", max_size)
self._offline_requests_manager = OfflineRequestsManager(max_size, drop_behavior)
self._event_consumer.update_offline_requests_manager(self._offline_requests_manager)
def configure_draining_interval_sec(self, draining_interval_sec):
self._logger.info("Configuring offline requests queue draining interval: %f sec", draining_interval_sec)
self._event_consumer.update_draining_interval_sec(draining_interval_sec)
def connect(self, keep_alive_sec):
self._logger.info("Performing sync connect...")
event = Event()
self.connect_async(keep_alive_sec, self._create_blocking_ack_callback(event))
if not event.wait(self._connect_disconnect_timeout_sec):
self._logger.error("Connect timed out")
raise connectTimeoutException()
return True
def connect_async(self, keep_alive_sec, ack_callback=None):
self._logger.info("Performing async connect...")
self._logger.info("Keep-alive: %f sec" % keep_alive_sec)
self._start_workers()
self._load_callbacks()
self._load_username_password()
try:
self._client_status.set_status(ClientStatus.CONNECT)
rc = self._internal_async_client.connect(keep_alive_sec, ack_callback)
if MQTT_ERR_SUCCESS != rc:
self._logger.error("Connect error: %d", rc)
raise connectError(rc)
except Exception as e:
# Provided any error in connect, we should clean up the threads that have been created
self._event_consumer.stop()
if not self._event_consumer.wait_until_it_stops(self._connect_disconnect_timeout_sec):
self._logger.error("Time out in waiting for event consumer to stop")
else:
self._logger.debug("Event consumer stopped")
self._client_status.set_status(ClientStatus.IDLE)
raise e
return FixedEventMids.CONNACK_MID
def _load_callbacks(self):
self._logger.debug("Passing in general notification callbacks to internal client...")
self._internal_async_client.on_online = self.on_online
self._internal_async_client.on_offline = self.on_offline
self._internal_async_client.on_message = self.on_message
def _load_username_password(self):
username_candidate = self._username
if self._enable_metrics_collection:
username_candidate += METRICS_PREFIX
username_candidate += AWSIoTPythonSDK.__version__
self._internal_async_client.set_username_password(username_candidate, self._password)
def disconnect(self):
self._logger.info("Performing sync disconnect...")
event = Event()
self.disconnect_async(self._create_blocking_ack_callback(event))
if not event.wait(self._connect_disconnect_timeout_sec):
self._logger.error("Disconnect timed out")
raise disconnectTimeoutException()
if not self._event_consumer.wait_until_it_stops(self._connect_disconnect_timeout_sec):
self._logger.error("Disconnect timed out in waiting for event consumer")
raise disconnectTimeoutException()
return True
def disconnect_async(self, ack_callback=None):
self._logger.info("Performing async disconnect...")
self._client_status.set_status(ClientStatus.USER_DISCONNECT)
rc = self._internal_async_client.disconnect(ack_callback)
if MQTT_ERR_SUCCESS != rc:
self._logger.error("Disconnect error: %d", rc)
raise disconnectError(rc)
return FixedEventMids.DISCONNECT_MID
def publish(self, topic, payload, qos, retain=False):
self._logger.info("Performing sync publish...")
ret = False
if ClientStatus.STABLE != self._client_status.get_status():
self._handle_offline_request(RequestTypes.PUBLISH, (topic, payload, qos, retain))
else:
if qos > 0:
event = Event()
rc, mid = self._publish_async(topic, payload, qos, retain, self._create_blocking_ack_callback(event))
if not event.wait(self._operation_timeout_sec):
self._internal_async_client.remove_event_callback(mid)
self._logger.error("Publish timed out")
raise publishTimeoutException()
else:
self._publish_async(topic, payload, qos, retain)
ret = True
return ret
def publish_async(self, topic, payload, qos, retain=False, ack_callback=None):
self._logger.info("Performing async publish...")
if ClientStatus.STABLE != self._client_status.get_status():
self._handle_offline_request(RequestTypes.PUBLISH, (topic, payload, qos, retain))
return FixedEventMids.QUEUED_MID
else:
rc, mid = self._publish_async(topic, payload, qos, retain, ack_callback)
return mid
def _publish_async(self, topic, payload, qos, retain=False, ack_callback=None):
rc, mid = self._internal_async_client.publish(topic, payload, qos, retain, ack_callback)
if MQTT_ERR_SUCCESS != rc:
self._logger.error("Publish error: %d", rc)
raise publishError(rc)
return rc, mid
def subscribe(self, topic, qos, message_callback=None):
self._logger.info("Performing sync subscribe...")
ret = False
if ClientStatus.STABLE != self._client_status.get_status():
self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback, None))
else:
event = Event()
rc, mid = self._subscribe_async(topic, qos, self._create_blocking_ack_callback(event), message_callback)
if not event.wait(self._operation_timeout_sec):
self._internal_async_client.remove_event_callback(mid)
self._logger.error("Subscribe timed out")
raise subscribeTimeoutException()
ret = True
return ret
def subscribe_async(self, topic, qos, ack_callback=None, message_callback=None):
self._logger.info("Performing async subscribe...")
if ClientStatus.STABLE != self._client_status.get_status():
self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback, ack_callback))
return FixedEventMids.QUEUED_MID
else:
rc, mid = self._subscribe_async(topic, qos, ack_callback, message_callback)
return mid
def _subscribe_async(self, topic, qos, ack_callback=None, message_callback=None):
self._subscription_manager.add_record(topic, qos, message_callback, ack_callback)
rc, mid = self._internal_async_client.subscribe(topic, qos, ack_callback)
if MQTT_ERR_SUCCESS != rc:
self._logger.error("Subscribe error: %d", rc)
raise subscribeError(rc)
return rc, mid
def unsubscribe(self, topic):
self._logger.info("Performing sync unsubscribe...")
ret = False
if ClientStatus.STABLE != self._client_status.get_status():
self._handle_offline_request(RequestTypes.UNSUBSCRIBE, (topic, None))
else:
event = Event()
rc, mid = self._unsubscribe_async(topic, self._create_blocking_ack_callback(event))
if not event.wait(self._operation_timeout_sec):
self._internal_async_client.remove_event_callback(mid)
self._logger.error("Unsubscribe timed out")
raise unsubscribeTimeoutException()
ret = True
return ret
def unsubscribe_async(self, topic, ack_callback=None):
self._logger.info("Performing async unsubscribe...")
if ClientStatus.STABLE != self._client_status.get_status():
self._handle_offline_request(RequestTypes.UNSUBSCRIBE, (topic, ack_callback))
return FixedEventMids.QUEUED_MID
else:
rc, mid = self._unsubscribe_async(topic, ack_callback)
return mid
def _unsubscribe_async(self, topic, ack_callback=None):
self._subscription_manager.remove_record(topic)
rc, mid = self._internal_async_client.unsubscribe(topic, ack_callback)
if MQTT_ERR_SUCCESS != rc:
self._logger.error("Unsubscribe error: %d", rc)
raise unsubscribeError(rc)
return rc, mid
def _create_blocking_ack_callback(self, event):
def ack_callback(mid, data=None):
event.set()
return ack_callback
def _handle_offline_request(self, type, data):
self._logger.info("Offline request detected!")
offline_request = QueueableRequest(type, data)
append_result = self._offline_requests_manager.add_one(offline_request)
if AppendResults.APPEND_FAILURE_QUEUE_DISABLED == append_result:
self._logger.error("Offline request queue has been disabled")
raise self._offline_request_queue_disabled_exceptions[type]
if AppendResults.APPEND_FAILURE_QUEUE_FULL == append_result:
self._logger.error("Offline request queue is full")
raise self._offline_request_queue_full_exceptions[type]

View File

@@ -0,0 +1,430 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import json
import logging
import uuid
from threading import Timer, Lock, Thread
class _shadowRequestToken:
URN_PREFIX_LENGTH = 9
def getNextToken(self):
return uuid.uuid4().urn[self.URN_PREFIX_LENGTH:] # We only need the uuid digits, not the urn prefix
class _basicJSONParser:
def setString(self, srcString):
self._rawString = srcString
self._dictionObject = None
def regenerateString(self):
return json.dumps(self._dictionaryObject)
def getAttributeValue(self, srcAttributeKey):
return self._dictionaryObject.get(srcAttributeKey)
def setAttributeValue(self, srcAttributeKey, srcAttributeValue):
self._dictionaryObject[srcAttributeKey] = srcAttributeValue
def validateJSON(self):
try:
self._dictionaryObject = json.loads(self._rawString)
except ValueError:
return False
return True
class deviceShadow:
_logger = logging.getLogger(__name__)
def __init__(self, srcShadowName, srcIsPersistentSubscribe, srcShadowManager):
"""
The class that denotes a local/client-side device shadow instance.
Users can perform shadow operations on this instance to retrieve and modify the
corresponding shadow JSON document in AWS IoT Cloud. The following shadow operations
are available:
- Get
- Update
- Delete
- Listen on delta
- Cancel listening on delta
This is returned from :code:`AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTShadowClient.createShadowWithName` function call.
No need to call directly from user scripts.
"""
if srcShadowName is None or srcIsPersistentSubscribe is None or srcShadowManager is None:
raise TypeError("None type inputs detected.")
self._shadowName = srcShadowName
# Tool handler
self._shadowManagerHandler = srcShadowManager
self._basicJSONParserHandler = _basicJSONParser()
self._tokenHandler = _shadowRequestToken()
# Properties
self._isPersistentSubscribe = srcIsPersistentSubscribe
self._lastVersionInSync = -1 # -1 means not initialized
self._isGetSubscribed = False
self._isUpdateSubscribed = False
self._isDeleteSubscribed = False
self._shadowSubscribeCallbackTable = dict()
self._shadowSubscribeCallbackTable["get"] = None
self._shadowSubscribeCallbackTable["delete"] = None
self._shadowSubscribeCallbackTable["update"] = None
self._shadowSubscribeCallbackTable["delta"] = None
self._shadowSubscribeStatusTable = dict()
self._shadowSubscribeStatusTable["get"] = 0
self._shadowSubscribeStatusTable["delete"] = 0
self._shadowSubscribeStatusTable["update"] = 0
self._tokenPool = dict()
self._dataStructureLock = Lock()
def _doNonPersistentUnsubscribe(self, currentAction):
self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, currentAction)
self._logger.info("Unsubscribed to " + currentAction + " accepted/rejected topics for deviceShadow: " + self._shadowName)
def generalCallback(self, client, userdata, message):
# In Py3.x, message.payload comes in as a bytes(string)
# json.loads needs a string input
with self._dataStructureLock:
currentTopic = message.topic
currentAction = self._parseTopicAction(currentTopic) # get/delete/update/delta
currentType = self._parseTopicType(currentTopic) # accepted/rejected/delta
payloadUTF8String = message.payload.decode('utf-8')
# get/delete/update: Need to deal with token, timer and unsubscribe
if currentAction in ["get", "delete", "update"]:
# Check for token
self._basicJSONParserHandler.setString(payloadUTF8String)
if self._basicJSONParserHandler.validateJSON(): # Filter out invalid JSON
currentToken = self._basicJSONParserHandler.getAttributeValue(u"clientToken")
if currentToken is not None:
self._logger.debug("shadow message clientToken: " + currentToken)
if currentToken is not None and currentToken in self._tokenPool.keys(): # Filter out JSON without the desired token
# Sync local version when it is an accepted response
self._logger.debug("Token is in the pool. Type: " + currentType)
if currentType == "accepted":
incomingVersion = self._basicJSONParserHandler.getAttributeValue(u"version")
# If it is get/update accepted response, we need to sync the local version
if incomingVersion is not None and incomingVersion > self._lastVersionInSync and currentAction != "delete":
self._lastVersionInSync = incomingVersion
# If it is a delete accepted, we need to reset the version
else:
self._lastVersionInSync = -1 # The version will always be synced for the next incoming delta/GU-accepted response
# Cancel the timer and clear the token
self._tokenPool[currentToken].cancel()
del self._tokenPool[currentToken]
# Need to unsubscribe?
self._shadowSubscribeStatusTable[currentAction] -= 1
if not self._isPersistentSubscribe and self._shadowSubscribeStatusTable.get(currentAction) <= 0:
self._shadowSubscribeStatusTable[currentAction] = 0
processNonPersistentUnsubscribe = Thread(target=self._doNonPersistentUnsubscribe, args=[currentAction])
processNonPersistentUnsubscribe.start()
# Custom callback
if self._shadowSubscribeCallbackTable.get(currentAction) is not None:
processCustomCallback = Thread(target=self._shadowSubscribeCallbackTable[currentAction], args=[payloadUTF8String, currentType, currentToken])
processCustomCallback.start()
# delta: Watch for version
else:
currentType += "/" + self._parseTopicShadowName(currentTopic)
# Sync local version
self._basicJSONParserHandler.setString(payloadUTF8String)
if self._basicJSONParserHandler.validateJSON(): # Filter out JSON without version
incomingVersion = self._basicJSONParserHandler.getAttributeValue(u"version")
if incomingVersion is not None and incomingVersion > self._lastVersionInSync:
self._lastVersionInSync = incomingVersion
# Custom callback
if self._shadowSubscribeCallbackTable.get(currentAction) is not None:
processCustomCallback = Thread(target=self._shadowSubscribeCallbackTable[currentAction], args=[payloadUTF8String, currentType, None])
processCustomCallback.start()
def _parseTopicAction(self, srcTopic):
ret = None
fragments = srcTopic.split('/')
if fragments[5] == "delta":
ret = "delta"
else:
ret = fragments[4]
return ret
def _parseTopicType(self, srcTopic):
fragments = srcTopic.split('/')
return fragments[5]
def _parseTopicShadowName(self, srcTopic):
fragments = srcTopic.split('/')
return fragments[2]
def _timerHandler(self, srcActionName, srcToken):
with self._dataStructureLock:
# Don't crash if we try to remove an unknown token
if srcToken not in self._tokenPool:
self._logger.warn('Tried to remove non-existent token from pool: %s' % str(srcToken))
return
# Remove the token
del self._tokenPool[srcToken]
# Need to unsubscribe?
self._shadowSubscribeStatusTable[srcActionName] -= 1
if not self._isPersistentSubscribe and self._shadowSubscribeStatusTable.get(srcActionName) <= 0:
self._shadowSubscribeStatusTable[srcActionName] = 0
self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, srcActionName)
# Notify time-out issue
if self._shadowSubscribeCallbackTable.get(srcActionName) is not None:
self._logger.info("Shadow request with token: " + str(srcToken) + " has timed out.")
self._shadowSubscribeCallbackTable[srcActionName]("REQUEST TIME OUT", "timeout", srcToken)
def shadowGet(self, srcCallback, srcTimeout):
"""
**Description**
Retrieve the device shadow JSON document from AWS IoT by publishing an empty JSON document to the
corresponding shadow topics. Shadow response topics will be subscribed to receive responses from
AWS IoT regarding the result of the get operation. Retrieved shadow JSON document will be available
in the registered callback. If no response is received within the provided timeout, a timeout
notification will be passed into the registered callback.
**Syntax**
.. code:: python
# Retrieve the shadow JSON document from AWS IoT, with a timeout set to 5 seconds
BotShadow.shadowGet(customCallback, 5)
**Parameters**
*srcCallback* - Function to be called when the response for this shadow request comes back. Should
be in form :code:`customCallback(payload, responseStatus, token)`, where :code:`payload` is the
JSON document returned, :code:`responseStatus` indicates whether the request has been accepted,
rejected or is a delta message, :code:`token` is the token used for tracing in this request.
*srcTimeout* - Timeout to determine whether the request is invalid. When a request gets timeout,
a timeout notification will be generated and put into the registered callback to notify users.
**Returns**
The token used for tracing in this shadow request.
"""
with self._dataStructureLock:
# Update callback data structure
self._shadowSubscribeCallbackTable["get"] = srcCallback
# Update number of pending feedback
self._shadowSubscribeStatusTable["get"] += 1
# clientToken
currentToken = self._tokenHandler.getNextToken()
self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["get", currentToken])
self._basicJSONParserHandler.setString("{}")
self._basicJSONParserHandler.validateJSON()
self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken)
currentPayload = self._basicJSONParserHandler.regenerateString()
# Two subscriptions
if not self._isPersistentSubscribe or not self._isGetSubscribed:
self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "get", self.generalCallback)
self._isGetSubscribed = True
self._logger.info("Subscribed to get accepted/rejected topics for deviceShadow: " + self._shadowName)
# One publish
self._shadowManagerHandler.basicShadowPublish(self._shadowName, "get", currentPayload)
# Start the timer
self._tokenPool[currentToken].start()
return currentToken
def shadowDelete(self, srcCallback, srcTimeout):
"""
**Description**
Delete the device shadow from AWS IoT by publishing an empty JSON document to the corresponding
shadow topics. Shadow response topics will be subscribed to receive responses from AWS IoT
regarding the result of the get operation. Responses will be available in the registered callback.
If no response is received within the provided timeout, a timeout notification will be passed into
the registered callback.
**Syntax**
.. code:: python
# Delete the device shadow from AWS IoT, with a timeout set to 5 seconds
BotShadow.shadowDelete(customCallback, 5)
**Parameters**
*srcCallback* - Function to be called when the response for this shadow request comes back. Should
be in form :code:`customCallback(payload, responseStatus, token)`, where :code:`payload` is the
JSON document returned, :code:`responseStatus` indicates whether the request has been accepted,
rejected or is a delta message, :code:`token` is the token used for tracing in this request.
*srcTimeout* - Timeout to determine whether the request is invalid. When a request gets timeout,
a timeout notification will be generated and put into the registered callback to notify users.
**Returns**
The token used for tracing in this shadow request.
"""
with self._dataStructureLock:
# Update callback data structure
self._shadowSubscribeCallbackTable["delete"] = srcCallback
# Update number of pending feedback
self._shadowSubscribeStatusTable["delete"] += 1
# clientToken
currentToken = self._tokenHandler.getNextToken()
self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["delete", currentToken])
self._basicJSONParserHandler.setString("{}")
self._basicJSONParserHandler.validateJSON()
self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken)
currentPayload = self._basicJSONParserHandler.regenerateString()
# Two subscriptions
if not self._isPersistentSubscribe or not self._isDeleteSubscribed:
self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "delete", self.generalCallback)
self._isDeleteSubscribed = True
self._logger.info("Subscribed to delete accepted/rejected topics for deviceShadow: " + self._shadowName)
# One publish
self._shadowManagerHandler.basicShadowPublish(self._shadowName, "delete", currentPayload)
# Start the timer
self._tokenPool[currentToken].start()
return currentToken
def shadowUpdate(self, srcJSONPayload, srcCallback, srcTimeout):
"""
**Description**
Update the device shadow JSON document string from AWS IoT by publishing the provided JSON
document to the corresponding shadow topics. Shadow response topics will be subscribed to
receive responses from AWS IoT regarding the result of the get operation. Response will be
available in the registered callback. If no response is received within the provided timeout,
a timeout notification will be passed into the registered callback.
**Syntax**
.. code:: python
# Update the shadow JSON document from AWS IoT, with a timeout set to 5 seconds
BotShadow.shadowUpdate(newShadowJSONDocumentString, customCallback, 5)
**Parameters**
*srcJSONPayload* - JSON document string used to update shadow JSON document in AWS IoT.
*srcCallback* - Function to be called when the response for this shadow request comes back. Should
be in form :code:`customCallback(payload, responseStatus, token)`, where :code:`payload` is the
JSON document returned, :code:`responseStatus` indicates whether the request has been accepted,
rejected or is a delta message, :code:`token` is the token used for tracing in this request.
*srcTimeout* - Timeout to determine whether the request is invalid. When a request gets timeout,
a timeout notification will be generated and put into the registered callback to notify users.
**Returns**
The token used for tracing in this shadow request.
"""
# Validate JSON
self._basicJSONParserHandler.setString(srcJSONPayload)
if self._basicJSONParserHandler.validateJSON():
with self._dataStructureLock:
# clientToken
currentToken = self._tokenHandler.getNextToken()
self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["update", currentToken])
self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken)
JSONPayloadWithToken = self._basicJSONParserHandler.regenerateString()
# Update callback data structure
self._shadowSubscribeCallbackTable["update"] = srcCallback
# Update number of pending feedback
self._shadowSubscribeStatusTable["update"] += 1
# Two subscriptions
if not self._isPersistentSubscribe or not self._isUpdateSubscribed:
self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "update", self.generalCallback)
self._isUpdateSubscribed = True
self._logger.info("Subscribed to update accepted/rejected topics for deviceShadow: " + self._shadowName)
# One publish
self._shadowManagerHandler.basicShadowPublish(self._shadowName, "update", JSONPayloadWithToken)
# Start the timer
self._tokenPool[currentToken].start()
else:
raise ValueError("Invalid JSON file.")
return currentToken
def shadowRegisterDeltaCallback(self, srcCallback):
"""
**Description**
Listen on delta topics for this device shadow by subscribing to delta topics. Whenever there
is a difference between the desired and reported state, the registered callback will be called
and the delta payload will be available in the callback.
**Syntax**
.. code:: python
# Listen on delta topics for BotShadow
BotShadow.shadowRegisterDeltaCallback(customCallback)
**Parameters**
*srcCallback* - Function to be called when the response for this shadow request comes back. Should
be in form :code:`customCallback(payload, responseStatus, token)`, where :code:`payload` is the
JSON document returned, :code:`responseStatus` indicates whether the request has been accepted,
rejected or is a delta message, :code:`token` is the token used for tracing in this request.
**Returns**
None
"""
with self._dataStructureLock:
# Update callback data structure
self._shadowSubscribeCallbackTable["delta"] = srcCallback
# One subscription
self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "delta", self.generalCallback)
self._logger.info("Subscribed to delta topic for deviceShadow: " + self._shadowName)
def shadowUnregisterDeltaCallback(self):
"""
**Description**
Cancel listening on delta topics for this device shadow by unsubscribing to delta topics. There will
be no delta messages received after this API call even though there is a difference between the
desired and reported state.
**Syntax**
.. code:: python
# Cancel listening on delta topics for BotShadow
BotShadow.shadowUnregisterDeltaCallback()
**Parameters**
None
**Returns**
None
"""
with self._dataStructureLock:
# Update callback data structure
del self._shadowSubscribeCallbackTable["delta"]
# One unsubscription
self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, "delta")
self._logger.info("Unsubscribed to delta topics for deviceShadow: " + self._shadowName)

View File

@@ -0,0 +1,83 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import logging
import time
from threading import Lock
class _shadowAction:
_actionType = ["get", "update", "delete", "delta"]
def __init__(self, srcShadowName, srcActionName):
if srcActionName is None or srcActionName not in self._actionType:
raise TypeError("Unsupported shadow action.")
self._shadowName = srcShadowName
self._actionName = srcActionName
self.isDelta = srcActionName == "delta"
if self.isDelta:
self._topicDelta = "$aws/things/" + str(self._shadowName) + "/shadow/update/delta"
else:
self._topicGeneral = "$aws/things/" + str(self._shadowName) + "/shadow/" + str(self._actionName)
self._topicAccept = "$aws/things/" + str(self._shadowName) + "/shadow/" + str(self._actionName) + "/accepted"
self._topicReject = "$aws/things/" + str(self._shadowName) + "/shadow/" + str(self._actionName) + "/rejected"
def getTopicGeneral(self):
return self._topicGeneral
def getTopicAccept(self):
return self._topicAccept
def getTopicReject(self):
return self._topicReject
def getTopicDelta(self):
return self._topicDelta
class shadowManager:
_logger = logging.getLogger(__name__)
def __init__(self, srcMQTTCore):
# Load in mqttCore
if srcMQTTCore is None:
raise TypeError("None type inputs detected.")
self._mqttCoreHandler = srcMQTTCore
self._shadowSubUnsubOperationLock = Lock()
def basicShadowPublish(self, srcShadowName, srcShadowAction, srcPayload):
currentShadowAction = _shadowAction(srcShadowName, srcShadowAction)
self._mqttCoreHandler.publish(currentShadowAction.getTopicGeneral(), srcPayload, 0, False)
def basicShadowSubscribe(self, srcShadowName, srcShadowAction, srcCallback):
with self._shadowSubUnsubOperationLock:
currentShadowAction = _shadowAction(srcShadowName, srcShadowAction)
if currentShadowAction.isDelta:
self._mqttCoreHandler.subscribe(currentShadowAction.getTopicDelta(), 0, srcCallback)
else:
self._mqttCoreHandler.subscribe(currentShadowAction.getTopicAccept(), 0, srcCallback)
self._mqttCoreHandler.subscribe(currentShadowAction.getTopicReject(), 0, srcCallback)
time.sleep(2)
def basicShadowUnsubscribe(self, srcShadowName, srcShadowAction):
with self._shadowSubUnsubOperationLock:
currentShadowAction = _shadowAction(srcShadowName, srcShadowAction)
if currentShadowAction.isDelta:
self._mqttCoreHandler.unsubscribe(currentShadowAction.getTopicDelta())
else:
self._logger.debug(currentShadowAction.getTopicAccept())
self._mqttCoreHandler.unsubscribe(currentShadowAction.getTopicAccept())
self._logger.debug(currentShadowAction.getTopicReject())
self._mqttCoreHandler.unsubscribe(currentShadowAction.getTopicReject())

View File

@@ -0,0 +1,19 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
class DropBehaviorTypes(object):
DROP_OLDEST = 0
DROP_NEWEST = 1

View File

@@ -0,0 +1,92 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
class CredentialsProvider(object):
def __init__(self):
self._ca_path = ""
def set_ca_path(self, ca_path):
self._ca_path = ca_path
def get_ca_path(self):
return self._ca_path
class CertificateCredentialsProvider(CredentialsProvider):
def __init__(self):
CredentialsProvider.__init__(self)
self._cert_path = ""
self._key_path = ""
def set_cert_path(self,cert_path):
self._cert_path = cert_path
def set_key_path(self, key_path):
self._key_path = key_path
def get_cert_path(self):
return self._cert_path
def get_key_path(self):
return self._key_path
class IAMCredentialsProvider(CredentialsProvider):
def __init__(self):
CredentialsProvider.__init__(self)
self._aws_access_key_id = ""
self._aws_secret_access_key = ""
self._aws_session_token = ""
def set_access_key_id(self, access_key_id):
self._aws_access_key_id = access_key_id
def set_secret_access_key(self, secret_access_key):
self._aws_secret_access_key = secret_access_key
def set_session_token(self, session_token):
self._aws_session_token = session_token
def get_access_key_id(self):
return self._aws_access_key_id
def get_secret_access_key(self):
return self._aws_secret_access_key
def get_session_token(self):
return self._aws_session_token
class EndpointProvider(object):
def __init__(self):
self._host = ""
self._port = -1
def set_host(self, host):
self._host = host
def set_port(self, port):
self._port = port
def get_host(self):
return self._host
def get_port(self):
return self._port

View File

@@ -0,0 +1,153 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
import AWSIoTPythonSDK.exception.operationTimeoutException as operationTimeoutException
import AWSIoTPythonSDK.exception.operationError as operationError
# Serial Exception
class acceptTimeoutException(Exception):
def __init__(self, msg="Accept Timeout"):
self.message = msg
# MQTT Operation Timeout Exception
class connectTimeoutException(operationTimeoutException.operationTimeoutException):
def __init__(self, msg="Connect Timeout"):
self.message = msg
class disconnectTimeoutException(operationTimeoutException.operationTimeoutException):
def __init__(self, msg="Disconnect Timeout"):
self.message = msg
class publishTimeoutException(operationTimeoutException.operationTimeoutException):
def __init__(self, msg="Publish Timeout"):
self.message = msg
class subscribeTimeoutException(operationTimeoutException.operationTimeoutException):
def __init__(self, msg="Subscribe Timeout"):
self.message = msg
class unsubscribeTimeoutException(operationTimeoutException.operationTimeoutException):
def __init__(self, msg="Unsubscribe Timeout"):
self.message = msg
# MQTT Operation Error
class connectError(operationError.operationError):
def __init__(self, errorCode):
self.message = "Connect Error: " + str(errorCode)
class disconnectError(operationError.operationError):
def __init__(self, errorCode):
self.message = "Disconnect Error: " + str(errorCode)
class publishError(operationError.operationError):
def __init__(self, errorCode):
self.message = "Publish Error: " + str(errorCode)
class publishQueueFullException(operationError.operationError):
def __init__(self):
self.message = "Internal Publish Queue Full"
class publishQueueDisabledException(operationError.operationError):
def __init__(self):
self.message = "Offline publish request dropped because queueing is disabled"
class subscribeError(operationError.operationError):
def __init__(self, errorCode):
self.message = "Subscribe Error: " + str(errorCode)
class subscribeQueueFullException(operationError.operationError):
def __init__(self):
self.message = "Internal Subscribe Queue Full"
class subscribeQueueDisabledException(operationError.operationError):
def __init__(self):
self.message = "Offline subscribe request dropped because queueing is disabled"
class unsubscribeError(operationError.operationError):
def __init__(self, errorCode):
self.message = "Unsubscribe Error: " + str(errorCode)
class unsubscribeQueueFullException(operationError.operationError):
def __init__(self):
self.message = "Internal Unsubscribe Queue Full"
class unsubscribeQueueDisabledException(operationError.operationError):
def __init__(self):
self.message = "Offline unsubscribe request dropped because queueing is disabled"
# Websocket Error
class wssNoKeyInEnvironmentError(operationError.operationError):
def __init__(self):
self.message = "No AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY detected in $ENV."
class wssHandShakeError(operationError.operationError):
def __init__(self):
self.message = "Error in WSS handshake."
# Greengrass Discovery Error
class DiscoveryDataNotFoundException(operationError.operationError):
def __init__(self):
self.message = "No discovery data found"
class DiscoveryTimeoutException(operationTimeoutException.operationTimeoutException):
def __init__(self, message="Discovery request timed out"):
self.message = message
class DiscoveryInvalidRequestException(operationError.operationError):
def __init__(self):
self.message = "Invalid discovery request"
class DiscoveryUnauthorizedException(operationError.operationError):
def __init__(self):
self.message = "Discovery request not authorized"
class DiscoveryThrottlingException(operationError.operationError):
def __init__(self):
self.message = "Too many discovery requests"
class DiscoveryFailure(operationError.operationError):
def __init__(self, message):
self.message = message
# Client Error
class ClientError(Exception):
def __init__(self, message):
self.message = message

View File

@@ -0,0 +1,19 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
class operationError(Exception):
def __init__(self, msg="Operation Error"):
self.message = msg

View File

@@ -0,0 +1,19 @@
# /*
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License").
# * You may not use this file except in compliance with the License.
# * A copy of the License is located at
# *
# * http://aws.amazon.com/apache2.0
# *
# * or in the "license" file accompanying this file. This file is distributed
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# * express or implied. See the License for the specific language governing
# * permissions and limitations under the License.
# */
class operationTimeoutException(Exception):
def __init__(self, msg="Operation Timeout"):
self.message = msg

View File

@@ -0,0 +1,2 @@
[metadata]
description-file = README.rst

View File

@@ -0,0 +1,34 @@
import sys
sys.path.insert(0, 'AWSIoTPythonSDK')
import AWSIoTPythonSDK
currentVersion = AWSIoTPythonSDK.__version__
from distutils.core import setup
setup(
name = 'AWSIoTPythonSDK',
packages=['AWSIoTPythonSDK', 'AWSIoTPythonSDK.core',
'AWSIoTPythonSDK.core.util', 'AWSIoTPythonSDK.core.shadow', 'AWSIoTPythonSDK.core.protocol',
'AWSIoTPythonSDK.core.jobs',
'AWSIoTPythonSDK.core.protocol.paho', 'AWSIoTPythonSDK.core.protocol.internal',
'AWSIoTPythonSDK.core.protocol.connection', 'AWSIoTPythonSDK.core.greengrass',
'AWSIoTPythonSDK.core.greengrass.discovery', 'AWSIoTPythonSDK.exception'],
version = currentVersion,
description = 'SDK for connecting to AWS IoT using Python.',
author = 'Amazon Web Service',
author_email = '',
url = 'https://github.com/aws/aws-iot-device-sdk-python.git',
download_url = 'https://s3.amazonaws.com/aws-iot-device-sdk-python/aws-iot-device-sdk-python-latest.zip',
keywords = ['aws', 'iot', 'mqtt'],
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Natural Language :: English",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.3",
"Programming Language :: Python :: 3.4",
"Programming Language :: Python :: 3.5"
]
)